[enzyme] broken MultiHeadAttention gradient #2567

Open
CarloLucibello opened this issue Dec 31, 2024 · 5 comments

@CarloLucibello
Member

CarloLucibello commented Dec 31, 2024

using Flux, Enzyme, Statistics, Random

function enzyme_withgradient(f, xs...)
    args = []
    for x in xs
        if x isa Number
            push!(args, Enzyme.Active(x))  # scalars are passed as Active values
        else
            push!(args, Enzyme.Duplicated(x, Enzyme.make_zero(x)))  # zeroed shadow accumulates the gradient
        end
    end
    ad = Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal)
    ret = Enzyme.autodiff(ad, Enzyme.Const(f), Enzyme.Active, args...)
    g = ntuple(i -> xs[i] isa Number ? ret[1][i] : args[i].dval, length(xs))
    return ret[2], g  # (primal value, gradients)
end

loss(model, x) = mean(model(x)[1])
model = MultiHeadAttention(16)
x = randn(Float32, 16, 5, 2)
enzyme_withgradient(loss, model, x)

Output:

ERROR: MethodError: no method matching function_attributes(::LLVM.UserOperandSet)
The function `function_attributes` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  function_attributes(::LLVM.Function)
   @ LLVM ~/.julia/packages/LLVM/wMjUU/src/core/function.jl:127

Stacktrace:
  [1] check_ir!(job::GPUCompiler.CompilerJob, errors::Vector{…}, imported::Set{…}, f::LLVM.Function, deletedfns::Vector{…}, mod::LLVM.Module)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler/validation.jl:402
  [2] check_ir!(job::GPUCompiler.CompilerJob, errors::Vector{Tuple{String, Vector{…}, Any}}, mod::LLVM.Module)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler/validation.jl:210
  [3] check_ir
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler/validation.jl:179 [inlined]
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:3413
  [5] codegen
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:3338 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5387
  [7] _thunk
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5387 [inlined]
  [8] cached_compilation
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5439 [inlined]
  [9] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5550
 [10] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5735
 [11] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Duplicated{…})
    @ Enzyme ~/.julia/packages/Enzyme/DiEvV/src/Enzyme.jl:485
 [12] enzyme_withgradient(::Function, ::MultiHeadAttention{Dense{…}, Dropout{…}, Dense{…}}, ::Vararg{Any})
    @ Main ./REPL[14]:11
 [13] top-level scope
    @ ~/.julia/dev/Flux/prova.jl:16
Some type information was truncated. Use `show(err)` to see complete types.

cc @wsmoses

@wsmoses
Contributor

wsmoses commented Dec 31, 2024

This should probably resolve the issue above: EnzymeAD/Enzyme.jl#2239

@CarloLucibello
Member Author

Fixed, but it now gives some warnings:

julia> enzyme_withgradient(loss, model, x)
┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/2CW9L/src/utils.jl:59
┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/2CW9L/src/utils.jl:59
(-0.06945832f0, (MultiHeadAttention(16; nheads=8), Float32[0.0098144375 0.019144177  0.005306968 0.009218991; -0.004492016 -0.007321746  0.00065406406 -0.0012907407;  ; 0.00060874515 -0.00632565  -0.0024870127 -0.0071006473; -0.010946414 -0.0075595975  0.006028706 -0.0074309492;;; 0.0026188223 0.010693712  0.02191737 0.0128253205; -0.002983723 0.0007982431  -0.0038285365 0.0014048074;  ; -0.002875765 -0.004498572  -0.0061464626 -0.0040150927; -0.004544966 -0.002571526  0.001993513 0.005268162]))

@wsmoses
Contributor

wsmoses commented Jan 1, 2025

I think they can likely be ignored: that warning is overly conservative and prints any time you have a spawn.
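
One way to gain some confidence that the gradient is still right despite the warnings is to compare it against central finite differences. The following is a minimal sketch only: `fd_gradient` and `toy_loss` are hypothetical names that are not part of Flux or Enzyme, and a toy quadratic stands in for the model so the block runs without any packages.

```julia
# Hypothetical helper (not part of Flux or Enzyme): central finite-difference
# gradient of a scalar function f with respect to a real array x.
function fd_gradient(f, x::AbstractArray{<:Real}; h=1e-6)
    xw = Float64.(x)          # work on a Float64 copy to limit round-off
    g = similar(xw)
    for i in eachindex(xw)
        old = xw[i]
        xw[i] = old + h
        fp = f(xw)
        xw[i] = old - h
        fm = f(xw)
        xw[i] = old           # restore the perturbed entry
        g[i] = (fp - fm) / (2h)
    end
    return g
end

# Toy stand-in for the loss; its exact gradient is 2x / length(x).
toy_loss(x) = sum(abs2, x) / length(x)

x = [0.5, -1.0, 2.0]
g_fd = fd_gradient(toy_loss, x)
g_exact = 2 .* x ./ length(x)
max_err = maximum(abs.(g_fd .- g_exact))
```

For the example in this thread, the same idea would be to pass `x -> loss(model, x)` and compare the result with the gradient for `x` returned by `enzyme_withgradient`. Because the model is `Float32`, a larger step `h` and a loose tolerance would probably be needed.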

@CarloLucibello
Member Author

@wsmoses this is still failing on Julia 1.10 (it works on 1.11):

julia> enzyme_withgradient(loss, model, x)
┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/2CW9L/src/utils.jl:59
┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/2CW9L/src/utils.jl:59
┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/2CW9L/src/utils.jl:59
┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/2CW9L/src/utils.jl:59
ERROR: Enzyme compilation failed due to an internal error.
 Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl
 To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)

Illegal replace ficticious phi for:   %_replacementE = phi {} addrspace(10)* , !dbg !319 of   %107 = call fastcc nonnull {} addrspace(10)* @julia_wait_11104() #438, !dbg !362

Stacktrace:
 [1] #wait#645
   @ ./condition.jl:130

Stacktrace:
  [1] julia_error(msg::String, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing}, data2::Ptr{LLVM.API.LLVMOpaqueValue}, B::Ptr{LLVM.API.LLVMOpaqueBuilder})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/errors.jl:384
  [2] julia_error(cstr::Cstring, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing}, data2::Ptr{LLVM.API.LLVMOpaqueValue}, B::Ptr{LLVM.API.LLVMOpaqueBuilder})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/errors.jl:210
  [3] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/R6sE8/src/api.jl:268
  [4] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:1706
  [5] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4550
  [6] codegen
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:3353 [inlined]
  [7] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5410
  [8] _thunk
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5410 [inlined]
  [9] cached_compilation
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5462 [inlined]
 [10] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::Tuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5573
 [11] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::Tuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5758
 [12] autodiff(::ReverseMode{true, true, FFIABI, false, false}, ::Const{typeof(loss)}, ::Type{Active}, ::Duplicated{MultiHeadAttention{Dense{…}, Dropout{…}, Dense{…}}}, ::Duplicated{Array{Float32, 3}})
    @ Enzyme ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:485
 [13] enzyme_withgradient(::Function, ::MultiHeadAttention{Dense{typeof(identity), Matrix{Float32}, Bool}, Dropout{Float64, Colon, TaskLocalRNG}, Dense{typeof(identity), Matrix{Float32}, Bool}}, ::Vararg{Any})
    @ Main ./REPL[3]:11
 [14] top-level scope
    @ REPL[7]:1
Some type information was truncated. Use `show(err)` to see complete types.

@wsmoses
Contributor

wsmoses commented Jan 6, 2025

Ah yeah, I just fixed that 1.11 intrinsic.

This one is weirder.

Can you try to make a smaller MWE and open an issue in Enzyme?
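
A possible starting point for such a reduction, offered as a guess rather than a confirmed reproducer: the warnings mention `jl_new_task` and the failing frame is a `wait`, which suggests a spawned task somewhere in the forward pass. The sketch below defines a hypothetical function, `spawned_sumsq` (not taken from Flux or Enzyme), that has the same spawn-and-wait shape without any model code.

```julia
# Hypothetical reduced function: splits the work across two spawned tasks
# and waits on both via fetch. Whether this alone triggers the Julia 1.10
# failure is an assumption that has not been tested here.
function spawned_sumsq(x::Vector{Float32})
    mid = div(length(x), 2)
    t1 = Threads.@spawn sum(abs2, view(x, 1:mid))
    t2 = Threads.@spawn sum(abs2, view(x, (mid + 1):length(x)))
    # fetch is not type-stable, so assert the result type explicitly.
    return (fetch(t1)::Float32) + (fetch(t2)::Float32)
end

x = Float32[1, 2, 3, 4]
y = spawned_sumsq(x)   # primal value only; no differentiation happens here
```

Calling `enzyme_withgradient(spawned_sumsq, x)` with the helper from the first comment on Julia 1.10 would show whether this pattern is enough to hit the error. As the error text itself says, setting `Enzyme.Compiler.VERBOSE_ERRORS[] = true` beforehand gives the fuller log needed for a bug report.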
