From 68ab65d04e3f04b90461178c1b5e10e4d00a59c1 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Sat, 18 Mar 2023 21:31:43 -0700 Subject: [PATCH 1/3] Improve literal function and module detection Previously, instrumentation would be thrown off by storing a function in a variable/SSA value. Likewise, module references like `Main.Base` would be ignored because they're lowered as GlobalRefs. --- src/compiler/reverse.jl | 82 ++++++++++++++++++++++++----------------- test/compiler.jl | 2 + 2 files changed, 50 insertions(+), 34 deletions(-) diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl index 333323e83..992b4fbab 100644 --- a/src/compiler/reverse.jl +++ b/src/compiler/reverse.jl @@ -28,6 +28,18 @@ unwrapquote(x::QuoteNode) = x.value is_getproperty(ex) = iscall(ex, Base, :getproperty) +# Allows us to resolve constants which have been stored in Variables. +# e.g. `%1 = 1; %2 = %1`, or `%1 = identity; %1(...)`. +trylookup(ir::IR, @nospecialize(v)) = v +trylookup(ir::IR, v::Variable) = haskey(ir, v) ? trylookup(ir, ir[v].expr) : v +# Only resolve GlobalRefs to traverse module hierarchies +function trylookup(ir::IR, ref::GlobalRef) + isconst(ref.mod, ref.name) || return ref + val = getproperty(ref.mod, ref.name) + return val isa Module ? val : ref +end + + # The initial premise of literal_getproperty was in some ways inherently flawed, because for # getproperty it was intended that _pullback falls back to literal_getproperty, but we actually # want the opposite to happen, since Zygote should fall back to recursing into the getproperty @@ -35,16 +47,19 @@ is_getproperty(ex) = iscall(ex, Base, :getproperty) # literal_getproperty, though. We can't really have mutually recursive definitions here, so we # now always instrument getproperty as literal_getproperty, no matter whether the second # argument is a literal or not. 
-function instrument_getproperty!(ir, v, ex) - if is_getproperty(ex) - obj, prop = ex.args[2], ex.args[3] - if obj isa Module && prop isa QuoteNode && isconst(obj, unwrapquote(prop)) +function instrument_getproperty!(ir::Pipe, v, ex) + func = trylookup(ir.from, ex.args[1]) + if func == GlobalRef(Base, :getproperty) && length(ex.args) >= 3 + obj, prop = ex.args[2], trylookup(ir.from, ex.args[3]) + original = trylookup(ir.from, obj) + if original isa Module && prop isa QuoteNode && isconst(original, unwrapquote(prop)) # Metaprogramming can generate getproperty(::Module, ...) calls. # Like other types, these are type unstable without constprop. # However, literal_getproperty's heuristic is also not general enough for modules. # Thankfully, we can skip instrumenting these if they're const properties. - ex - elseif prop isa Union{QuoteNode,Integer} + return ex + end + if prop isa Union{QuoteNode,Integer} ir[v] = xcall(Zygote, :literal_getproperty, obj, Val(unwrapquote(prop))) else f = insert!(ir, v, :(Val($(prop)))) @@ -55,45 +70,48 @@ function instrument_getproperty!(ir, v, ex) end end -is_literal_getfield(ex) = - (iscall(ex, Core, :getfield) || iscall(ex, Base, :getfield)) && - ex.args[3] isa Union{QuoteNode,Integer} # Here, only instrumenting getfield with literals is fine, since users should never have to # define custom adjoints for literal_getfield function instrument_getfield!(ir, v, ex) - if is_literal_getfield(ex) - ir[v] = xcall(Zygote, :literal_getfield, ex.args[2], Val(unwrapquote(ex.args[3]))) - else - ex + func = trylookup(ir.from, ex.args[1]) + if func == GlobalRef(Core, :getfield) || func == GlobalRef(Base, :getfield) + obj, field = ex.args[2], trylookup(ir.from, ex.args[3]) + if field isa Union{QuoteNode,Integer} + call = xcall(Zygote, :literal_getfield, obj, Val(unwrapquote(field))) + return ir[v] = call + end end + return ex end -is_literal_getindex(ex) = - iscall(ex, Base, :getindex) && length(ex.args) == 3 && ex.args[3] isa Union{Integer,QuoteNode} 
- # TODO: is this always correct for user defined getindex methods? function instrument_getindex!(ir, v, ex) - if is_literal_getindex(ex) - ir[v] = xcall(Zygote, :literal_getindex, ex.args[2], Val(unwrapquote(ex.args[3]))) - else - ex + func = trylookup(ir.from, ex.args[1]) + if func == GlobalRef(Base, :getindex) && length(ex.args) == 3 + obj, idx = ex.args[2], trylookup(ir.from, ex.args[3]) + if idx isa Union{QuoteNode,Integer} + call = xcall(Zygote, :literal_getindex, obj, Val(unwrapquote(idx))) + return ir[v] = call + end end + return ex end -is_literal_iterate(ex) = - iscall(ex, Base, :indexed_iterate) && length(ex.args) >= 3 && ex.args[3] isa Union{Integer,QuoteNode} - function instrument_iterate!(ir, v, ex) - if is_literal_iterate(ex) - ir[v] = xcall(Zygote, :literal_indexed_iterate, ex.args[2], - Val(unwrapquote(ex.args[3])), ex.args[4:end]...) - else - ex + func = trylookup(ir.from, ex.args[1]) + if func == GlobalRef(Base, :indexed_iterate) && length(ex.args) >= 3 + obj, idx, rest = ex.args[2], trylookup(ir.from, ex.args[3]), ex.args[4:end] + if idx isa Union{QuoteNode,Integer} + call = xcall(Zygote, :literal_indexed_iterate, obj, Val(unwrapquote(idx)), rest...) + return ir[v] = call + end end + return ex end function instrument_literals!(ir, v, ex) + isexpr(ex, :call) || return ex ex = instrument_getproperty!(ir, v, ex) ex = instrument_getfield!(ir, v, ex) ex = instrument_getindex!(ir, v, ex) @@ -177,14 +195,10 @@ ignored_f(ir, f) = ignored_f(f) ignored_f(ir, f::Variable) = ignored_f(get(ir, f, nothing)) function ignored(ir, ex) - isexpr(ex, :call) || return false - f = ex.args[1] + f = trylookup(ir, ex.args[1]) ignored_f(ir, f) && return true - if f isa Variable && haskey(ir, f) - f = ir[f].expr - end if f == GlobalRef(Base, :getproperty) && length(ex.args) >= 3 - obj, prop = ex.args[2], ex.args[3] + obj, prop = trylookup(ir, ex.args[2]), trylookup(ir, ex.args[3]) # Metaprogramming can generate getproperty(::Module, ...) calls. 
# These are type unstable without constprop, which transforming to _pullback breaks. # However, we can skip differentiating these if they're const properties. diff --git a/test/compiler.jl b/test/compiler.jl index c9b091f78..214aeb2fd 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -169,6 +169,7 @@ module MyMod end @eval usesmod(x) = Base.getproperty($MyMod, :func)(x, Base.getproperty($MyMod, :C)) +usesmod2(x) = Base.getproperty(MyMod, :func)(x, Base.getproperty(MyMod, :C)) @testset "inference for `getproperty`" begin Gaussian = _Gaussian(:getproperty) @@ -221,6 +222,7 @@ end # Const properties on modules should be lowered as-is (not differentiated) @test @inferred gradient(usesmod, 1)[1] == 1.0 + @test @inferred gradient(usesmod2, 1)[1] == 1.0 end # issue 897 From 4fb6daa56133bfdcd2301b19af66cfb1b0664cce Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Sat, 18 Mar 2023 22:21:48 -0700 Subject: [PATCH 2/3] Don't lookup for indexed_iterate --- src/compiler/reverse.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl index 992b4fbab..7d00765dc 100644 --- a/src/compiler/reverse.jl +++ b/src/compiler/reverse.jl @@ -99,7 +99,7 @@ function instrument_getindex!(ir, v, ex) end function instrument_iterate!(ir, v, ex) - func = trylookup(ir.from, ex.args[1]) + func = ex.args[1] if func == GlobalRef(Base, :indexed_iterate) && length(ex.args) >= 3 obj, idx, rest = ex.args[2], trylookup(ir.from, ex.args[3]), ex.args[4:end] if idx isa Union{QuoteNode,Integer} From 6d983d5a4e710e814d21d05d4e6d4e5040f0f962 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Sun, 19 Mar 2023 13:10:55 -0700 Subject: [PATCH 3/3] Collapse arrays of CR zeros --- src/compiler/chainrules.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 9b8b60552..fe6b9499d 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -152,6 +152,7 @@ Convert `dx` from the 
format Zygote uses internally to differentials types Chain @inline wrap_chainrules_input(::Nothing) = ChainRules.ZeroTangent() @inline wrap_chainrules_input(::Tuple{Vararg{Nothing}}) = ChainRules.ZeroTangent() @inline wrap_chainrules_input(::AbstractArray{Nothing}) = ChainRules.ZeroTangent() +@inline wrap_chainrules_input(dxs::AbstractArray{T}) where {T<:AbstractZero} = first(dxs) @inline function wrap_chainrules_input(dxs::Union{Tuple, NamedTuple}) xp = map(wrap_chainrules_input, dxs) # This produces Tangent{Any} since it does not get to see the primal, `x`.