Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly set within_autodiff (#442) #490

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ function set_reactant_abi(
if f === Reactant.call_with_reactant
arginfo2 = ArgInfo(fargs isa Nothing ? nothing : fargs[2:end], argtypes[2:end])
return abstract_call(interp, arginfo2::ArgInfo, si, sv, max_methods)
elseif !(interp.within_autodiff_rewrite) && f === overload_autodiff
interp′ = Enzyme.Compiler.Interpreter.EnzymeInterpreter(
interp; within_autodiff_rewrite=true
)
return Base.@invoke abstract_call_known(
interp′::Enzyme.Compiler.Interpreter.EnzymeInterpreter,
f,
arginfo,
si,
sv,
max_methods,
)
end

return Base.@invoke abstract_call_known(
Expand All @@ -59,7 +71,9 @@ end
@static if Enzyme.GPUCompiler.HAS_INTEGRATED_CACHE
struct ReactantCacheToken end

function ReactantInterpreter(; world::UInt=Base.get_world_counter())
function ReactantInterpreter(;
world::UInt=Base.get_world_counter(), within_autodiff=false
)
return Enzyme.Compiler.Interpreter.EnzymeInterpreter(
ReactantCacheToken(),
REACTANT_METHOD_TABLE,
Expand All @@ -68,14 +82,17 @@ end
false, #=reverse_rules=#
false, #=inactive_rules=#
false, #=broadcast_rewrite=#
within_autodiff, #=within_autodiff_rewrite=#
set_reactant_abi,
)
end
else
const REACTANT_CACHE = Enzyme.GPUCompiler.CodeCache()

function ReactantInterpreter(;
world::UInt=Base.get_world_counter(), code_cache=REACTANT_CACHE
world::UInt=Base.get_world_counter(),
code_cache=REACTANT_CACHE,
within_autodiff=false,
)
return Enzyme.Compiler.Interpreter.EnzymeInterpreter(
REACTANT_CACHE,
Expand All @@ -85,6 +102,7 @@ else
false, #=reverse_rules=#
false, #=inactive_rules=#
false, #=broadcast_rewrite=#
within_autodiff, #=within_autodiff_rewrite=#
set_reactant_abi,
)
end
Expand Down Expand Up @@ -238,7 +256,7 @@ function overload_autodiff(
primargs = ((v.val for v in args)...,)

fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = TracedUtils.make_mlir_fn(
primf, primargs, (), string(f) * "_autodiff", false
primf, primargs, (), string(f) * "_autodiff", false; within_autodiff=true
)

activity = Int32[]
Expand Down
8 changes: 6 additions & 2 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ function make_mlir_fn(
no_args_in_result::Bool=false,
construct_function_without_args::Bool=false,
do_transpose=true,
within_autodiff=false,
)
if sizeof(typeof(f)) != 0 || f isa Base.BroadcastFunction
return (
Expand Down Expand Up @@ -180,8 +181,11 @@ function make_mlir_fn(
arg.mlir_data = row_maj_arg
end
end

Reactant.call_with_reactant(f, traced_args...)
if within_autodiff
Reactant.call_with_reactant_within_autodiff(f, traced_args...)
else
Reactant.call_with_reactant(f, traced_args...)
end
finally
MLIR.IR.deactivate!(fnbody)
end
Expand Down
11 changes: 10 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ function apply(f, args...; kwargs...)
end

function call_with_reactant end
function call_with_reactant_within_autodiff end

function maybe_argextype(@nospecialize(x), src)
return try
Expand Down Expand Up @@ -483,7 +484,11 @@ function call_with_reactant_generator(
MethodError($REDUB_ARGUMENTS_NAME[1], $REDUB_ARGUMENTS_NAME[2:end], $world)
))

interp = ReactantInterpreter(; world)
if self == typeof(Reactant.call_with_reactant_within_autodiff)
interp = ReactantInterpreter(; world, within_autodiff=true)
else
interp = ReactantInterpreter(; world, within_autodiff=false)
end

min_world = Ref{UInt}(typemin(UInt))
max_world = Ref{UInt}(typemax(UInt))
Expand Down Expand Up @@ -728,3 +733,7 @@ end
$(Expr(:meta, :generated_only))
return $(Expr(:meta, :generated, call_with_reactant_generator))
end
@eval function call_with_reactant_within_autodiff($REDUB_ARGUMENTS_NAME...)
$(Expr(:meta, :generated_only))
return $(Expr(:meta, :generated, call_with_reactant_generator))
end
15 changes: 15 additions & 0 deletions test/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,21 @@ fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y))
@test res1[1] ≈ ores1[1]
end

function error_not_within_autodiff()
!Enzyme.within_autodiff() && error("Not within autodiff")
return nothing
end

fwd_within_autodiff(Mode, RT) = Enzyme.autodiff(Mode, error_not_within_autodiff, RT)

@testset "within_autodiff" begin
@test_throws ErrorException error_not_within_autodiff()
@test fwd_within_autodiff(Forward, Const) == ()

@test_throws ErrorException @jit error_not_within_autodiff()
@test (@jit fwd_within_autodiff(Forward, Const)) == ()
end

function gw(z)
return Enzyme.gradient(Forward, sum, z; chunk=Val(1))
end
Expand Down
Loading