diff --git a/.github/workflows/AD.yml b/.github/workflows/AD.yml index 3777b346..7d0aa4ae 100644 --- a/.github/workflows/AD.yml +++ b/.github/workflows/AD.yml @@ -23,13 +23,13 @@ jobs: AD: - Enzyme - ForwardDiff - - Tapir + - Mooncake - Tracker - ReverseDiff - Zygote exclude: - version: 1.6 - AD: Tapir + AD: Mooncake # TODO(mhauru) Hopefully can enable Enzyme on older versions at some point, see # discussion in https://github.com/TuringLang/Bijectors.jl/pull. - version: 1.6 diff --git a/Project.toml b/Project.toml index 62400e20..283b706c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.13.19" +version = "0.14.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" @@ -30,7 +30,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -40,7 +40,7 @@ BijectorsEnzymeExt = ["Enzyme", "EnzymeCore"] BijectorsForwardDiffExt = "ForwardDiff" BijectorsLazyArraysExt = "LazyArrays" BijectorsReverseDiffExt = "ReverseDiff" -BijectorsTapirExt = "Tapir" +BijectorsMooncakeExt = "Mooncake" BijectorsTrackerExt = "Tracker" BijectorsZygoteExt = "Zygote" @@ -67,7 +67,7 @@ Requires = "0.5, 1" ReverseDiff = "1" Roots = "1.3.4, 2" Statistics = "1" -Tapir = "0.2.23" +Mooncake = "0.4.19" Tracker = "0.2" Zygote = "0.6.63" julia = "1.6" @@ -79,6 +79,6 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/ext/BijectorsEnzymeExt.jl b/ext/BijectorsEnzymeExt.jl index 57f2e4b0..303fd92f 100644 --- a/ext/BijectorsEnzymeExt.jl +++ b/ext/BijectorsEnzymeExt.jl @@ -1,599 +1,18 @@ module BijectorsEnzymeExt if isdefined(Base, :get_extension) - using Enzyme: Enzyme - using EnzymeCore: EnzymeCore - using Bijectors: Bijectors, ChainRulesCore + using Enzyme: @import_rrule, @import_frule + using Bijectors: find_alpha else - using ..Enzyme: Enzyme - using ..EnzymeCore: EnzymeCore - using ..Bijectors: Bijectors, ChainRulesCore + using ..Enzyme: @import_rrule, @import_frule + using ..Bijectors: find_alpha end -#= NOTE(penelopeysm): -Changes made to the way extensions were loaded in Julia 1.11.1 mean that it -is no longer sufficient to call Enzyme.@import_rrule and -Enzyme.@import_frule, as we did in previous versions. This is because both of -those macros rely on a method which is defined in EnzymeChainRulesCoreExt, -and on 1.11.1+, that extension is _not_ loaded before BijectorsEnzymeExt is -loaded. (In the past, for reasons which are not fully clear, -EnzymeChainRulesCoreExt _does_ get loaded first.) - -See https://github.com/TuringLang/Bijectors.jl/pull/333 for further context. - -However, on versions of Julia where the 'default' extension resolution occurs, -we can still use the macros (see the else clause below). We do this to ensure -that the code is compatible with what may potentially be different versions of -Enzyme. - -The code in the if clause was derived by calling @macroexpand on @import_rrule -and @import_frule, then replacing `$(Expr(:meta, :inline))` with -`Base.@_inline_meta`. - -Note that this was done using Enzyme v0.12.36. This code will fail to track any -upstream changes to EnzymeChainRulesCoreExt, so there is no guarantee that this -code will work with later versions of Enzyme. -=# @static if v"1.11.1" <= VERSION < v"1.12" - function (Enzyme.EnzymeRules).augmented_primal( - var"#238#config", - var"#239#fn"::var"#246#FA", - ::Enzyme.Type{var"#245#RetAnnotation"}, - var"#241#arg_1"::var"#247#AN_1", - var"#242#arg_2"::var"#248#AN_2", - var"#243#arg_3"::var"#249#AN_3"; - var"#244#kwargs"..., - ) where { - var"#245#RetAnnotation", - var"#246#FA"<:Enzyme.Annotation{<:typeof(Bijectors.find_alpha)}, - var"#247#AN_1"<:Enzyme.Annotation{<:Real}, - var"#248#AN_2"<:Enzyme.Annotation{<:Real}, - var"#249#AN_3"<:Enzyme.Annotation{<:Real}, - } - var"#231#primcopy_1" = - if ((EnzymeCore.EnzymeRules.overwritten)(var"#238#config"))[1 + 1] - Enzyme.deepcopy((var"#241#arg_1").val) - else - (var"#241#arg_1").val - end - var"#232#primcopy_2" = - if ((EnzymeCore.EnzymeRules.overwritten)(var"#238#config"))[2 + 1] - Enzyme.deepcopy((var"#242#arg_2").val) - else - (var"#242#arg_2").val - end - var"#233#primcopy_3" = - if ((EnzymeCore.EnzymeRules.overwritten)(var"#238#config"))[3 + 1] - Enzyme.deepcopy((var"#243#arg_3").val) - else - (var"#243#arg_3").val - end - (var"#234#res", var"#235#pullback") = if var"#245#RetAnnotation" <: Enzyme.Const - ( - (var"#239#fn").val( - var"#231#primcopy_1", - var"#232#primcopy_2", - var"#233#primcopy_3"; - var"#244#kwargs"..., - ), - Enzyme.nothing, - ) - else - (ChainRulesCore).rrule( - (var"#239#fn").val, - var"#231#primcopy_1", - var"#232#primcopy_2", - var"#233#primcopy_3"; - var"#244#kwargs"..., - ) - end - var"#236#primal" = if (Enzyme.EnzymeRules).needs_primal(var"#238#config") - var"#234#res" - else - Enzyme.nothing - end - var"#237#shadow" = if !((Enzyme.EnzymeRules).needs_shadow(var"#238#config")) - Enzyme.nothing - else - if (Enzyme.EnzymeRules).width(var"#238#config") == 1 - (Enzyme.Enzyme).make_zero(var"#234#res") - else - Enzyme.ntuple( - Enzyme.Val((Enzyme.EnzymeRules).width(var"#238#config")) - ) do var"#250#j" - Base.@_inline_meta - (Enzyme.Enzyme).make_zero(var"#234#res") - end - end - end - return (Enzyme.EnzymeRules).AugmentedReturn( - var"#236#primal", var"#237#shadow", (var"#237#shadow", var"#235#pullback") - ) - end - - function (Enzyme.EnzymeRules).reverse( - var"#254#config", - var"#255#fn"::var"#264#FA", - ::Enzyme.Type{var"#262#RetAnnotation"}, - var"#257#tape"::var"#263#TapeTy", - var"#258#arg_1"::var"#265#AN_1", - var"#259#arg_2"::var"#266#AN_2", - var"#260#arg_3"::var"#267#AN_3"; - var"#261#kwargs"..., - ) where { - var"#262#RetAnnotation", - var"#263#TapeTy", - var"#264#FA"<:Enzyme.Annotation{<:typeof(Bijectors.find_alpha)}, - var"#265#AN_1"<:Enzyme.Annotation{<:Real}, - var"#266#AN_2"<:Enzyme.Annotation{<:Real}, - var"#267#AN_3"<:Enzyme.Annotation{<:Real}, - } - if !(var"#262#RetAnnotation" <: Enzyme.Const) - (var"#251#shadow", var"#252#pullback") = var"#257#tape" - var"#253#tcomb" = Enzyme.ntuple( - Enzyme.Val((Enzyme.EnzymeRules).width(var"#254#config")) - ) do var"#272#batch_i" - Base.@_inline_meta - var"#268#shad" = if (Enzyme.EnzymeRules).width(var"#254#config") == 1 - var"#251#shadow" - else - var"#251#shadow"[var"#272#batch_i"] - end - var"#269#res" = var"#252#pullback"(var"#268#shad") - for (var"#270#cr", var"#271#en") in Enzyme.zip( - var"#269#res", - (var"#255#fn", var"#258#arg_1", var"#259#arg_2", var"#260#arg_3"), - ) - if var"#271#en" isa Enzyme.Const || - var"#270#cr" isa (ChainRulesCore).NoTangent - continue - end - if var"#271#en" isa Enzyme.Active - continue - end - if (Enzyme.EnzymeRules).width(var"#254#config") == 1 - (var"#271#en").dval .+= var"#270#cr" - else - (var"#271#en").dval[var"#272#batch_i"] .+= var"#270#cr" - end - end - ( - if var"#255#fn" isa Enzyme.Active - var"#269#res"[1] - else - Enzyme.nothing - end, - if var"#258#arg_1" isa Enzyme.Active - if var"#269#res"[1 + 1] isa (ChainRulesCore).NoTangent - Enzyme.zero(var"#258#arg_1") - else - (ChainRulesCore).unthunk(var"#269#res"[1 + 1]) - end - else - Enzyme.nothing - end, - if var"#259#arg_2" isa Enzyme.Active - if var"#269#res"[2 + 1] isa (ChainRulesCore).NoTangent - Enzyme.zero(var"#259#arg_2") - else - (ChainRulesCore).unthunk(var"#269#res"[2 + 1]) - end - else - Enzyme.nothing - end, - if var"#260#arg_3" isa Enzyme.Active - if var"#269#res"[3 + 1] isa (ChainRulesCore).NoTangent - Enzyme.zero(var"#260#arg_3") - else - (ChainRulesCore).unthunk(var"#269#res"[3 + 1]) - end - else - Enzyme.nothing - end, - ) - end - return ( - begin - if var"#258#arg_1" isa Enzyme.Active - if (Enzyme.EnzymeRules).width(var"#254#config") == 1 - (var"#253#tcomb"[1])[1 + 1] - else - Enzyme.ntuple( - Enzyme.Val((Enzyme.EnzymeRules).width(var"#254#config")) - ) do var"#273#batch_i" - Base.@_inline_meta - (var"#253#tcomb"[var"#273#batch_i"])[1 + 1] - end - end - else - Enzyme.nothing - end - end, - begin - if var"#259#arg_2" isa Enzyme.Active - if (Enzyme.EnzymeRules).width(var"#254#config") == 1 - (var"#253#tcomb"[1])[2 + 1] - else - Enzyme.ntuple( - Enzyme.Val((Enzyme.EnzymeRules).width(var"#254#config")) - ) do var"#274#batch_i" - Base.@_inline_meta - (var"#253#tcomb"[var"#274#batch_i"])[2 + 1] - end - end - else - Enzyme.nothing - end - end, - begin - if var"#260#arg_3" isa Enzyme.Active - if (Enzyme.EnzymeRules).width(var"#254#config") == 1 - (var"#253#tcomb"[1])[3 + 1] - else - Enzyme.ntuple( - Enzyme.Val((Enzyme.EnzymeRules).width(var"#254#config")) - ) do var"#275#batch_i" - Base.@_inline_meta - (var"#253#tcomb"[var"#275#batch_i"])[3 + 1] - end - end - else - Enzyme.nothing - end - end, - ) - end - return (Enzyme.nothing, Enzyme.nothing, Enzyme.nothing) - end - - function (Enzyme.EnzymeRules).reverse( - var"#280#config", - var"#281#fn"::var"#290#FA", - var"#282#dval"::Enzyme.Active{var"#288#RetAnnotation"}, - var"#283#tape"::var"#289#TapeTy", - var"#284#arg_1"::var"#291#AN_1", - var"#285#arg_2"::var"#292#AN_2", - var"#286#arg_3"::var"#293#AN_3"; - var"#287#kwargs"..., - ) where { - var"#288#RetAnnotation", - var"#289#TapeTy", - var"#290#FA"<:Enzyme.Annotation{<:typeof(Bijectors.find_alpha)}, - var"#291#AN_1"<:Enzyme.Annotation{<:Real}, - var"#292#AN_2"<:Enzyme.Annotation{<:Real}, - var"#293#AN_3"<:Enzyme.Annotation{<:Real}, - } - (var"#276#oldshadow", var"#277#pullback") = var"#283#tape" - var"#278#shadow" = (var"#282#dval").val - var"#279#tcomb" = Enzyme.ntuple( - Enzyme.Val((Enzyme.EnzymeRules).width(var"#280#config")) - ) do var"#298#batch_i" - Base.@_inline_meta - var"#294#shad" = if (Enzyme.EnzymeRules).width(var"#280#config") == 1 - var"#278#shadow" - else - var"#278#shadow"[var"#298#batch_i"] - end - var"#295#res" = var"#277#pullback"(var"#294#shad") - for (var"#296#cr", var"#297#en") in Enzyme.zip( - var"#295#res", - (var"#281#fn", var"#284#arg_1", var"#285#arg_2", var"#286#arg_3"), - ) - if var"#297#en" isa Enzyme.Const || var"#296#cr" isa (ChainRulesCore).NoTangent - continue - end - if var"#297#en" isa Enzyme.Active - continue - end - if (Enzyme.EnzymeRules).width(var"#280#config") == 1 - (var"#297#en").dval .+= var"#296#cr" - else - (var"#297#en").dval[var"#298#batch_i"] .+= var"#296#cr" - end - end - ( - if var"#281#fn" isa Enzyme.Active - var"#295#res"[1] - else - Enzyme.nothing - end, - if var"#284#arg_1" isa Enzyme.Active - if var"#295#res"[1 + 1] isa (ChainRulesCore).NoTangent - Enzyme.zero(var"#284#arg_1") - else - (ChainRulesCore).unthunk(var"#295#res"[1 + 1]) - end - else - Enzyme.nothing - end, - if var"#285#arg_2" isa Enzyme.Active - if var"#295#res"[2 + 1] isa (ChainRulesCore).NoTangent - Enzyme.zero(var"#285#arg_2") - else - (ChainRulesCore).unthunk(var"#295#res"[2 + 1]) - end - else - Enzyme.nothing - end, - if var"#286#arg_3" isa Enzyme.Active - if var"#295#res"[3 + 1] isa (ChainRulesCore).NoTangent - Enzyme.zero(var"#286#arg_3") - else - (ChainRulesCore).unthunk(var"#295#res"[3 + 1]) - end - else - Enzyme.nothing - end, - ) - end - return ( - begin - if var"#284#arg_1" isa Enzyme.Active - if (Enzyme.EnzymeRules).width(var"#280#config") == 1 - (var"#279#tcomb"[1])[1 + 1] - else - Enzyme.ntuple( - Enzyme.Val((Enzyme.EnzymeRules).width(var"#280#config")) - ) do var"#299#batch_i" - Base.@_inline_meta - (var"#279#tcomb"[var"#299#batch_i"])[1 + 1] - end - end - else - Enzyme.nothing - end - end, - begin - if var"#285#arg_2" isa Enzyme.Active - if (Enzyme.EnzymeRules).width(var"#280#config") == 1 - (var"#279#tcomb"[1])[2 + 1] - else - Enzyme.ntuple( - Enzyme.Val((Enzyme.EnzymeRules).width(var"#280#config")) - ) do var"#300#batch_i" - Base.@_inline_meta - (var"#279#tcomb"[var"#300#batch_i"])[2 + 1] - end - end - else - Enzyme.nothing - end - end, - begin - if var"#286#arg_3" isa Enzyme.Active - if (Enzyme.EnzymeRules).width(var"#280#config") == 1 - (var"#279#tcomb"[1])[3 + 1] - else - Enzyme.ntuple( - Enzyme.Val((Enzyme.EnzymeRules).width(var"#280#config")) - ) do var"#301#batch_i" - Base.@_inline_meta - (var"#279#tcomb"[var"#301#batch_i"])[3 + 1] - end - end - else - Enzyme.nothing - end - end, - ) - end - - function (Enzyme.EnzymeRules).forward( - var"#308#fn"::var"#315#FA", - ::Enzyme.Type{var"#314#RetAnnotation"}, - var"#310#arg_1"::var"#316#AN_1", - var"#311#arg_2"::var"#317#AN_2", - var"#312#arg_3"::var"#318#AN_3"; - var"#313#kwargs"..., - ) where { - var"#314#RetAnnotation", - var"#315#FA"<:Enzyme.Annotation{<:typeof(Bijectors.find_alpha)}, - var"#316#AN_1"<:Enzyme.Annotation{<:Real}, - var"#317#AN_2"<:Enzyme.Annotation{<:Real}, - var"#318#AN_3"<:Enzyme.Annotation{<:Real}, - } - var"#302#batchsize" = Enzyme.same_or_one( - 1, var"#310#arg_1", var"#311#arg_2", var"#312#arg_3" - ) - if var"#302#batchsize" == 1 - var"#306#dfn" = if var"#308#fn" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#308#fn").dval - end - var"#303#cres" = (ChainRulesCore).frule( - ( - var"#306#dfn", - if var"#310#arg_1" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#310#arg_1").dval - end, - if var"#311#arg_2" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#311#arg_2").dval - end, - if var"#312#arg_3" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#312#arg_3").dval - end, - ), - (var"#308#fn").val, - (var"#310#arg_1").val, - (var"#311#arg_2").val, - (var"#312#arg_3").val; - var"#313#kwargs"..., - ) - if var"#314#RetAnnotation" <: Enzyme.Const - return var"#303#cres"[2]::Enzyme.eltype(var"#314#RetAnnotation") - elseif var"#314#RetAnnotation" <: Enzyme.Duplicated - return Enzyme.Duplicated(var"#303#cres"[1], var"#303#cres"[2]) - elseif var"#314#RetAnnotation" <: Enzyme.DuplicatedNoNeed - return var"#303#cres"[2]::Enzyme.eltype(var"#314#RetAnnotation") - else - if false - nothing - else - Base.throw(Base.AssertionError("false")) - end - end - else - if var"#314#RetAnnotation" <: Enzyme.Const - var"#303#cres" = - Enzyme.ntuple(Enzyme.Val(var"#302#batchsize")) do var"#305#i" - Base.@_inline_meta - var"#306#dfn" = if var"#308#fn" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#308#fn").dval[var"#305#i"] - end - (ChainRulesCore).frule( - ( - var"#306#dfn", - if var"#310#arg_1" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#310#arg_1").dval[var"#305#i"] - end, - if var"#311#arg_2" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#311#arg_2").dval[var"#305#i"] - end, - if var"#312#arg_3" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#312#arg_3").dval[var"#305#i"] - end, - ), - (var"#308#fn").val, - (var"#310#arg_1").val, - (var"#311#arg_2").val, - (var"#312#arg_3").val; - var"#313#kwargs"..., - ) - end - return (var"#303#cres"[1])[2]::Enzyme.eltype(var"#314#RetAnnotation") - elseif var"#314#RetAnnotation" <: Enzyme.BatchDuplicated - var"#304#cres1" = begin - var"#305#i" = 1 - var"#306#dfn" = if var"#308#fn" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#308#fn").dval[var"#305#i"] - end - (ChainRulesCore).frule( - ( - var"#306#dfn", - if var"#310#arg_1" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#310#arg_1").dval[var"#305#i"] - end, - if var"#311#arg_2" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#311#arg_2").dval[var"#305#i"] - end, - if var"#312#arg_3" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#312#arg_3").dval[var"#305#i"] - end, - ), - (var"#308#fn").val, - (var"#310#arg_1").val, - (var"#311#arg_2").val, - (var"#312#arg_3").val; - var"#313#kwargs"..., - ) - end - var"#307#batches" = - Enzyme.ntuple(Enzyme.Val(var"#302#batchsize" - 1)) do var"#323#j" - Base.@_inline_meta - var"#305#i" = var"#323#j" + 1 - var"#306#dfn" = if var"#308#fn" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#308#fn").dval[var"#305#i"] - end - ((ChainRulesCore).frule( - ( - var"#306#dfn", - if var"#310#arg_1" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#310#arg_1").dval[var"#305#i"] - end, - if var"#311#arg_2" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#311#arg_2").dval[var"#305#i"] - end, - if var"#312#arg_3" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#312#arg_3").dval[var"#305#i"] - end, - ), - (var"#308#fn").val, - (var"#310#arg_1").val, - (var"#311#arg_2").val, - (var"#312#arg_3").val; - var"#313#kwargs"..., - ))[2] - end - return Enzyme.BatchDuplicated( - var"#304#cres1"[1], (var"#304#cres1"[2], var"#307#batches"...) - ) - elseif var"#314#RetAnnotation" <: Enzyme.BatchDuplicatedNoNeed - Enzyme.ntuple(Enzyme.Val(var"#302#batchsize")) do var"#305#i" - Base.@_inline_meta - var"#306#dfn" = if var"#308#fn" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#308#fn").dval[var"#305#i"] - end - ((ChainRulesCore).frule( - ( - var"#306#dfn", - if var"#310#arg_1" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#310#arg_1").dval[var"#305#i"] - end, - if var"#311#arg_2" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#311#arg_2").dval[var"#305#i"] - end, - if var"#312#arg_3" isa Enzyme.Const - (ChainRulesCore).NoTangent() - else - (var"#312#arg_3").dval[var"#305#i"] - end, - ), - (var"#308#fn").val, - (var"#310#arg_1").val, - (var"#311#arg_2").val, - (var"#312#arg_3").val; - var"#313#kwargs"..., - ))[2] - end - else - if false - nothing - else - Base.throw(Base.AssertionError("false")) - end - end - end - end + @warn "Bijectors and Enzyme do not work together on Julia $VERSION" else - Enzyme.@import_rrule typeof(Bijectors.find_alpha) Real Real Real - Enzyme.@import_frule typeof(Bijectors.find_alpha) Real Real Real + @import_rrule typeof(find_alpha) Real Real Real + @import_frule typeof(find_alpha) Real Real Real end end # module diff --git a/ext/BijectorsTapirExt.jl b/ext/BijectorsMooncakeExt.jl similarity index 77% rename from ext/BijectorsTapirExt.jl rename to ext/BijectorsMooncakeExt.jl index 70805a82..d7285bf6 100644 --- a/ext/BijectorsTapirExt.jl +++ b/ext/BijectorsMooncakeExt.jl @@ -1,10 +1,11 @@ -module BijectorsTapirExt +module BijectorsMooncakeExt if isdefined(Base, :get_extension) - using Tapir: @is_primitive, MinimalCtx, Tapir, CoDual, primal, tangent_type, @from_rrule + using Mooncake: + @is_primitive, MinimalCtx, Mooncake, CoDual, primal, tangent_type, @from_rrule using Bijectors: find_alpha, ChainRulesCore else - using ..Tapir: @is_primitive, MinimalCtx, Tapir, primal, tangent_type, @from_rrule + using ..Mooncake: @is_primitive, MinimalCtx, Mooncake, primal, tangent_type, @from_rrule using ..Bijectors: find_alpha, ChainRulesCore end @@ -19,20 +20,20 @@ end # unusual Integer type is encountered. @is_primitive(MinimalCtx, Tuple{typeof(find_alpha),P,P,Integer} where {P<:Base.IEEEFloat}) -function Tapir.rrule!!( +function Mooncake.rrule!!( ::CoDual{typeof(find_alpha)}, x::CoDual{P}, y::CoDual{P}, z::CoDual{I} ) where {P<:Base.IEEEFloat,I<:Integer} # Require that the integer is non-differentiable. - if tangent_type(I) != Tapir.NoTangent + if tangent_type(I) != Mooncake.NoTangent msg = "Integer argument has tangent type $(tangent_type(I)), should be NoTangent." throw(ArgumentError(msg)) end out, pb = ChainRulesCore.rrule(find_alpha, primal(x), primal(y), primal(z)) function find_alpha_pb(dout::P) _, dx, dy, _ = pb(dout) - return Tapir.NoRData(), P(dx), P(dy), Tapir.NoRData() + return Mooncake.NoRData(), P(dx), P(dy), Mooncake.NoRData() end - return Tapir.zero_fcodual(out), find_alpha_pb + return Mooncake.zero_fcodual(out), find_alpha_pb end end diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index bcdb9523..a2c13df1 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -27,9 +27,9 @@ end test_frule(Bijectors.find_alpha, x, y, z) test_rrule(Bijectors.find_alpha, x, y, z) - if @isdefined Tapir + if @isdefined Mooncake rng = Xoshiro(123456) - Tapir.TestUtils.test_rule( + Mooncake.TestUtils.test_rule( rng, Bijectors.find_alpha, x, @@ -37,9 +37,9 @@ end z; is_primitive=true, perf_flag=:none, - interp=Tapir.TapirInterpreter(), + interp=Mooncake.MooncakeInterpreter(), ) - Tapir.TestUtils.test_rule( + Mooncake.TestUtils.test_rule( rng, Bijectors.find_alpha, x, @@ -47,9 +47,9 @@ end 3; is_primitive=true, perf_flag=:none, - interp=Tapir.TapirInterpreter(), + interp=Mooncake.MooncakeInterpreter(), ) - Tapir.TestUtils.test_rule( + Mooncake.TestUtils.test_rule( rng, Bijectors.find_alpha, x, @@ -57,7 +57,7 @@ end UInt32(3); is_primitive=true, perf_flag=:none, - interp=Tapir.TapirInterpreter(), + interp=Mooncake.MooncakeInterpreter(), ) end diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 3e21e693..2e709491 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -7,7 +7,7 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) b in ( :ForwardDiff, :Zygote, - :Tapir, + :Mooncake, :ReverseDiff, :Enzyme, :EnzymeForward, @@ -78,27 +78,39 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) end end - if (AD == "All" || AD == "Tapir") && VERSION >= v"1.10" - rule = Tapir.build_rrule(f, x; safety_on=false) - if :tapir in broken - @test_broken( - isapprox( - Tapir.value_and_gradient!!(rule, f, x)[2][2], - finitediff; - rtol=rtol, - atol=atol, - ) - ) - else - @test( - isapprox( - Tapir.value_and_gradient!!(rule, f, x)[2][2], - finitediff; - rtol=rtol, - atol=atol, - ) - ) + if (AD == "All" || AD == "Mooncake") && VERSION >= v"1.10" + try + Mooncake.build_rrule(f, x) + catch exc + # TODO(penelopeysm): + # @test_throws AssertionError (expr...) doesn't work, unclear why + @test exc isa AssertionError end + # TODO: The above @test_throws happens because of + # https://github.com/compintell/Mooncake.jl/issues/319. If that test + # fails, it probably means that the issue was fixed, in which case + # we can remove that block and uncomment the following instead. + + # rule = Mooncake.build_rrule(f, x) + # if :Mooncake in broken + # @test_broken ( + # isapprox( + # Mooncake.value_and_gradient!!(rule, f, x)[2][2], + # finitediff; + # rtol=rtol, + # atol=atol, + # ) + # ) + # else + # @test( + # isapprox( + # Mooncake.value_and_gradient!!(rule, f, x)[2][2], + # finitediff; + # rtol=rtol, + # atol=atol, + # ) + # ) + # end end return nothing diff --git a/test/bijectors/ordered.jl b/test/bijectors/ordered.jl index b2115fe2..60354005 100644 --- a/test/bijectors/ordered.jl +++ b/test/bijectors/ordered.jl @@ -127,12 +127,12 @@ end end end # Check that the quantiles are reasonable, i.e. within - # 5 standard errors of the true quantiles (and that the MCSE is + # 6 standard errors of the true quantiles (and that the MCSE is # not too large). for i in 1:k for j in 1:length(qts) @test qs_mcse[i, j] < abs(qs_true[i, end] - qs_true[i, 1]) / 2 - @test abs(qs[i, j] - qs_true[i, j]) < 5 * qs_mcse[i, j] + @test abs(qs[i, j] - qs_true[i, j]) < 6 * qs_mcse[i, j] end end end diff --git a/test/runtests.jl b/test/runtests.jl index 914c0e32..638bd15c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,12 +34,12 @@ if VERSION < v"1.9" using Compat: stack end -# Sadly, Tapir.jl cannot be installed on version 1.6, so we have to add it if we're testing +# Sadly, Mooncake.jl cannot be installed on version 1.6, so we have to add it if we're testing # on at least version 1.10. if VERSION >= v"1.10" using Pkg - Pkg.add("Tapir") - using Tapir + Pkg.add("Mooncake") + using Mooncake end const GROUP = get(ENV, "GROUP", "All")