-
Notifications
You must be signed in to change notification settings - Fork 70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add import_frule (reprised) #1333
Conversation
34da7e1
to
c475ab8
Compare
Is there a reason the previous reverse mode one failed? |
It think it had a syntax error, but also since I'm not really sure what I'm doing I wanted to focus on one thing at a time. |
test/ext/chainrulescore.jl
Outdated
Enzyme.@import_frule typeof(Base.sort) Any | ||
for Tret in (Duplicated, DuplicatedNoNeed) | ||
for Tx in (Duplicated, BatchDuplicated) | ||
test_forward(sort, Tret, (x, Tx)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this fails for Tx == BatchDuplicated. Any suggestions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's the failure?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
julia> x = [1.0, 2.0, 0.0];
julia> Enzyme.@import_frule typeof(Base.sort) Any
julia> test_forward(Base.sort, Duplicated, (x, BatchDuplicated))
test_forward: sort with return activity Duplicated on (::Vector{Float64}, BatchDuplicated): Error During Test at /Users/carlo/.julia/packages/EnzymeTestUtils/CV2ed/src/test_forward.jl:68
Got exception outside of a @test
DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 3 and 6
Stacktrace:
[1] _bcs1
@ ./broadcast.jl:555 [inlined]
[2] _bcs
@ ./broadcast.jl:549 [inlined]
[3] broadcast_shape
@ ./broadcast.jl:543 [inlined]
[4] combine_axes
@ ./broadcast.jl:524 [inlined]
[5] instantiate
@ ./broadcast.jl:306 [inlined]
[6] materialize
@ ./broadcast.jl:903 [inlined]
[7] (::FiniteDifferences.var"#85#86"{FiniteDifferences.var"#87#88"{EnzymeTestUtils.var"#fnew#28"{EnzymeTestUtils.var"#call_with_copy#38"{@NamedTuple{}}, Tuple{typeof(sort), Vector{Float64}}, Tuple{Bool, Bool}}, typeof(identity)}, Vector{Float64}, Vector{Float64}})(ε::Float64)
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/zWRHl/src/grad.jl:48
[8] newf
@ ~/.julia/packages/StaticArrays/EHHaF/src/broadcast.jl:186 [inlined]
[9] macro expansion
@ ~/.julia/packages/StaticArrays/EHHaF/src/broadcast.jl:135 [inlined]
[10] __broadcast
@ ~/.julia/packages/StaticArrays/EHHaF/src/broadcast.jl:123 [inlined]
[11] _broadcast
@ ~/.julia/packages/StaticArrays/EHHaF/src/broadcast.jl:119 [inlined]
[12] copy
@ ~/.julia/packages/StaticArrays/EHHaF/src/broadcast.jl:60 [inlined]
[13] materialize
@ ./broadcast.jl:903 [inlined]
[14] _eval_function(m::FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}, f::FiniteDifferences.var"#85#86"{FiniteDifferences.var"#87#88"{EnzymeTestUtils.var"#fnew#28"{EnzymeTestUtils.var"#call_with_copy#38"{@NamedTuple{}}, Tuple{typeof(sort), Vector{Float64}}, Tuple{Bool, Bool}}, typeof(identity)}, Vector{Float64}, Vector{Float64}}, x::Float64, step::Float64)
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/zWRHl/src/methods.jl:249
[15] _estimate_magnitudes(m::FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}, f::FiniteDifferences.var"#85#86"{FiniteDifferences.var"#87#88"{EnzymeTestUtils.var"#fnew#28"{EnzymeTestUtils.var"#call_with_copy#38"{@NamedTuple{}}, Tuple{typeof(sort), Vector{Float64}}, Tuple{Bool, Bool}}, typeof(identity)}, Vector{Float64}, Vector{Float64}}, x::Float64)
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/zWRHl/src/methods.jl:378
[16] estimate_step(m::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::FiniteDifferences.var"#85#86"{FiniteDifferences.var"#87#88"{EnzymeTestUtils.var"#fnew#28"{EnzymeTestUtils.var"#call_with_copy#38"{@NamedTuple{}}, Tuple{typeof(sort), Vector{Float64}}, Tuple{Bool, Bool}}, typeof(identity)}, Vector{Float64}, Vector{Float64}}, x::Float64)
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/zWRHl/src/methods.jl:365
[17] (::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}})(f::FiniteDifferences.var"#85#86"{FiniteDifferences.var"#87#88"{EnzymeTestUtils.var"#fnew#28"{EnzymeTestUtils.var"#call_with_copy#38"{@NamedTuple{}}, Tuple{typeof(sort), Vector{Float64}}, Tuple{Bool, Bool}}, typeof(identity)}, Vector{Float64}, Vector{Float64}}, x::Float64)
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/zWRHl/src/methods.jl:193
[18] _jvp
@ ~/.julia/packages/FiniteDifferences/zWRHl/src/grad.jl:48 [inlined]
[19] jvp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::EnzymeTestUtils.var"#fnew#28"{EnzymeTestUtils.var"#call_with_copy#38"{@NamedTuple{}}, Tuple{typeof(sort), Vector{Float64}}, Tuple{Bool, Bool}}, ::Tuple{Vector{Float64}, Tuple{Vector{Float64}, Vector{Float64}}})
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/zWRHl/src/grad.jl:60
[20] _fd_forward(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, rettype::Type, y::Vector{Float64}, activities::Tuple{Const{typeof(sort)}, BatchDuplicated{Vector{Float64}, 2}})
@ EnzymeTestUtils ~/.julia/packages/EnzymeTestUtils/CV2ed/src/finite_difference_calls.jl:30
[21] macro expansion
@ ~/.julia/packages/EnzymeTestUtils/CV2ed/src/test_forward.jl:77 [inlined]
[22] macro expansion
@ ~/.julia/juliaup/julia-1.10.2+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
[23] test_forward(f::Function, ret_activity::Type, args::Tuple{Vector{Float64}, UnionAll}; fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, fkwargs::@NamedTuple{}, rtol::Float64, atol::Float64, testset_name::Nothing)
@ EnzymeTestUtils ~/.julia/packages/EnzymeTestUtils/CV2ed/src/test_forward.jl:70
[24] test_forward(f::Function, ret_activity::Type, args::Tuple{Vector{Float64}, UnionAll})
@ EnzymeTestUtils ~/.julia/packages/EnzymeTestUtils/CV2ed/src/test_forward.jl:53
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@sethaxen this looks like a bug in enzymetestutils?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As far as I know, no. As I asked in #1264 (comment), does Enzyme now allow one to mix Duplicated
and BatchDuplicated
? That used to cause an error. EnzymeTestUtils assumes these are not mixed and provides are_activities_compatible
to skip cases in TestSet
s that would mix them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, no you need one single batch width for the whole program (e.g. all duplicated, or all batchduplicated with the same width).
However, we do automatically upgrade a duplicated/dupicatednoneed return to whatever the width of the args were, if they were batch (since no data is in the return).
It would be nice for this shorthand to work. But indeed @CarloLucibello the alternate is testing {Duplicated, DuplicatedNoNeed} ret x {Const, Duplicated} input and {BatchDuplicated, BatchDuplicatedNoNeed} ret x {Const, BatchDuplicated} input
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Irrespectively, either we should upgrade testutils to handle this case (since here enzyme actually supports this, by upgrading to batchduplicated), or we should throw a nicer error here rather than bailing out in finite differences internals.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok. These 4 are fine
test_forward(Base.sort, Duplicated, (x, Duplicated))
test_forward(Base.sort, BatchDuplicated, (x, BatchDuplicated))
test_forward(Base.sort, DuplicatedNoNeed, (x, Duplicated))
test_forward(Base.sort, BatchDuplicatedNoNeed, (x, BatchDuplicated))
But anything involving Const
errors
test_forward(Base.sort, DuplicatedNoNeed, (x, Const))
UndefVarError: `ChainRulesCore` not defined
Stacktrace:
[1] forward
@ ~/.julia/dev/Enzyme/ext/EnzymeChainRulesCoreExt.jl:67
[2] forward
@ ~/.julia/dev/Enzyme/ext/EnzymeChainRulesCoreExt.jl:62 [inlined]
[3] call_with_kwargs
@ ~/.julia/packages/EnzymeTestUtils/CV2ed/src/test_forward.jl:64 [inlined]
[4] fwddiffejulia_call_with_kwargs_10402wrap
@ ~/.julia/packages/EnzymeTestUtils/CV2ed/src/test_forward.jl:0
[5] macro expansion
@ ~/.julia/dev/Enzyme/src/compiler.jl:5440 [inlined]
[6] enzyme_call
@ ~/.julia/dev/Enzyme/src/compiler.jl:5118 [inlined]
[7] ForwardModeThunk
@ ~/.julia/dev/Enzyme/src/compiler.jl:5003 [inlined]
[8] autodiff
@ ~/.julia/dev/Enzyme/src/Enzyme.jl:384 [inlined]
[9] autodiff(::ForwardMode{FFIABI}, ::EnzymeTestUtils.var"#call_with_kwargs#39"{@NamedTuple{}}, ::Type{DuplicatedNoNeed}, ::Const{typeof(sort)}, ::Const{Vector{Float64}})
@ Enzyme ~/.julia/dev/Enzyme/src/Enzyme.jl:287
[10] macro expansion
@ ~/.julia/packages/EnzymeTestUtils/CV2ed/src/test_forward.jl:79 [inlined]
[11] macro expansion
@ ~/.julia/juliaup/julia-1.10.2+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
[12] test_forward(f::Function, ret_activity::Type, args::Tuple{Vector{Float64}, UnionAll}; fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, fkwargs::@NamedTuple{}, rtol::Float64, atol::Float64, testset_name::Nothing)
@ EnzymeTestUtils ~/.julia/packages/EnzymeTestUtils/CV2ed/src/test_forward.jl:70
[13] test_forward(f::Function, ret_activity::Type, args::Tuple{Vector{Float64}, UnionAll})
@ EnzymeTestUtils ~/.julia/packages/EnzymeTestUtils/CV2ed/src/test_forward.jl:53
[14] macro expansion
@ ~/.julia/dev/Enzyme/test/ext/chainrulescore.jl:34 [inlined]
[15] macro expansion
@ ~/.julia/juliaup/julia-1.10.2+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
[16] macro expansion
@ ~/.julia/dev/Enzyme/test/ext/chainrulescore.jl:25 [inlined]
[17] macro expansion
@ ~/.julia/juliaup/julia-1.10.2+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
[18] top-level scope
@ ~/.julia/dev/Enzyme/test/ext/chainrulescore.jl:11
I get similar errors for
test_forward(Base.sort, DuplicatedNoNeed, (x, Const))
test_forward(Base.sort, Duplicated, (x, Const))
test_forward(Base.sort, BatchDuplicatedNoNeed, (x, Const))
test_forward(Base.sort, BatchDuplicated, (x, Const))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Huh, yeah we definitely need that one to pass. Maybe try changing
:($ty <: Const ? ChainRulesCore.NoTangent() : $val.dval)
into
:($ty <: Const ? $(ChainRulesCore.NoTangent()) : $val.dval)
to fix it? (and similar throughout).
cc @vchuravy since Julia macros are not my forte
end | ||
|
||
quote | ||
function EnzymeRules.forward(fn::FA, ::Type{RetAnnotation}, $(exprs...); kwargs...) where {RetAnnotation, FA<:Annotation{<:$(esc(fn))}, $(anns...)} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
function EnzymeRules.forward(fn::FA, ::Type{RetAnnotation}, $(exprs...); kwargs...) where {RetAnnotation, FA<:Annotation{<:$(esc(fn))}, $(anns...)} | |
function EnzymeRules.forward(fn::FA, ::Type{RetAnnotation}, $(exprs...); | |
kwargs...) where {RetAnnotation, | |
FA<:Annotation{<:$(esc(fn))}, | |
$(anns...)} |
batchsize = same_or_one(1, $(vals...)) | ||
if batchsize == 1 | ||
dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval | ||
cres = $ChainRulesCore.frule((dfn, $(tangents...),), fn.val, $(primals...); kwargs...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
cres = $ChainRulesCore.frule((dfn, $(tangents...),), fn.val, $(primals...); kwargs...) | |
cres = $ChainRulesCore.frule((dfn, $(tangents...)), fn.val, $(primals...); | |
kwargs...) |
ntuple(Val(batchsize)) do i | ||
Base.@_inline_meta | ||
dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i] | ||
$ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
$ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...) | |
return $ChainRulesCore.frule((dfn, $(tangentsi...)), fn.val, | |
$(primals...); kwargs...) |
cres1 = begin | ||
i = 1 | ||
dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i] | ||
$ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
$ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...) | |
$ChainRulesCore.frule((dfn, $(tangentsi...)), fn.val, $(primals...); | |
kwargs...) |
dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i] | ||
$ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...) | ||
end | ||
batches = ntuple(Val(batchsize-1)) do j |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
batches = ntuple(Val(batchsize-1)) do j | |
batches = ntuple(Val(batchsize - 1)) do j |
@testset "batch duplicated" begin | ||
x = [1.0, 2.0, 0.0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
@testset "batch duplicated" begin | |
x = [1.0, 2.0, 0.0] | |
@testset "batch duplicated" begin | |
x = [1.0, 2.0, 0.0] |
# TEST EXTENSIONS | ||
@static if VERSION ≥ v"1.9-" | ||
using SpecialFunctions | ||
@testset "SpecialFunctions ext" begin | ||
lgabsg(x) = SpecialFunctions.logabsgamma(x)[1] | ||
test_scalar(lgabsg, 1.0; rtol = 1.0e-5, atol = 1.0e-5) | ||
test_scalar(lgabsg, 1.0f0; rtol = 1.0e-5, atol = 1.0e-5) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
# TEST EXTENSIONS | |
@static if VERSION ≥ v"1.9-" | |
using SpecialFunctions | |
@testset "SpecialFunctions ext" begin | |
lgabsg(x) = SpecialFunctions.logabsgamma(x)[1] | |
test_scalar(lgabsg, 1.0; rtol = 1.0e-5, atol = 1.0e-5) | |
test_scalar(lgabsg, 1.0f0; rtol = 1.0e-5, atol = 1.0e-5) | |
end | |
res = first(Enzyme.autodiff(Reverse, Base.hvcat_fill!, Const, Duplicated(ar, dar), | |
Active((1, 2.2, 3, 4.4, 5, 6.6)))) |
using ChainRulesCore | ||
@testset "ChainRulesCore ext" begin | ||
include("ext/chainrulescore.jl") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
using ChainRulesCore | |
@testset "ChainRulesCore ext" begin | |
include("ext/chainrulescore.jl") | |
@test res[2][1] == 0 | |
@test res[2][2] ≈ 2.0 | |
@test res[2][3] ≈ 0 | |
@test res[2][4] ≈ 4.0 | |
@test res[2][5] ≈ 0 | |
@test res[2][6] ≈ 6.0 |
end | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
end |
end | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
# TEST EXTENSIONS | |
@static if VERSION ≥ v"1.9-" | |
using SpecialFunctions | |
@testset "SpecialFunctions ext" begin | |
lgabsg(x) = SpecialFunctions.logabsgamma(x)[1] | |
test_scalar(lgabsg, 1.0; rtol=1.0e-5, atol=1.0e-5) | |
test_scalar(lgabsg, 1.0f0; rtol=1.0e-5, atol=1.0e-5) | |
end | |
Enzyme.@import_frule typeof(Base.sort) Any | ||
|
||
test_forward(Base.sort, Duplicated, (x, Duplicated)) | ||
# Unsupported by EnzymeTestUtils |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@sethaxen would it be possible to add support for duplicatednoneed and variants to enzymetestutils.
I'm going to approve/merge this without those, but it would be nice to enable later
Continuation of #996, partially addressing #583. I made the following changes:
TODO