From cd79718f5acea2347e0a5a9f2464db1d6c5a26e3 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 31 Aug 2022 22:00:22 -0400 Subject: [PATCH 1/5] fixes --- src/rulesets/Base/broadcast.jl | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index d376a64f0..bdc419444 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -59,12 +59,13 @@ function split_bc_derivatives(f::F, arg) where {F} @debug("split broadcasting derivative", f) ys = f.(arg) function bc_one_back(dys) # For f.(x) we do not need StructArrays / unzip at all - delta = broadcast(unthunk(dys), ys, arg) do dy, y, a + delta = broadcast(dys, ys, arg) do dy, y, a das = only(derivatives_given_output(y, f, a)) dy * conj(only(das)) # possibly this * should be made nan-safe. end return (TRI_NO..., ProjectTo(arg)(delta)) end + bc_one_back(dys::AbstractThunk) = bc_one_back(unthunk(dys)) bc_one_back(z::AbstractZero) = (TRI_NO..., z) return ys, bc_one_back end @@ -72,13 +73,14 @@ function split_bc_derivatives(f::F, args::Vararg{Any,N}) where {F,N} @debug("split broadcasting derivatives", f, N) ys = f.(args...) function bc_many_back(dys) - deltas = unzip_broadcast(unthunk(dys), ys, args...) do dy, y, as... + deltas = unzip_broadcast(dys, ys, args...) do dy, y, as... das = only(derivatives_given_output(y, f, as...)) map(da -> dy * conj(da), das) # possibly this * should be made nan-safe. end dargs = map(unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of unzip_broadcast? return (TRI_NO..., dargs...) end + bc_many_back(dys::AbstractThunk) = bc_many_back(unthunk(dys)) bc_many_back(z::AbstractZero) = (TRI_NO..., map(Returns(z), args)...) return ys, bc_many_back end @@ -109,11 +111,12 @@ function split_bc_inner(frule_fun::R, cfg::RuleConfig, f::F, arg) where {R,F} frule_fun(cfg, (NoTangent(), one(a)), f, a) end function back_forwards(dys) - delta = broadcast(ydots, unthunk(dys), arg) do ydot, dy, a + delta = broadcast(ydots, dys, arg) do ydot, dy, a ProjectTo(a)(conj(ydot) * dy) # possibly this * should be made nan-safe. end return (TRI_NO..., ProjectTo(arg)(delta)) end + back_forwards(dys::AbstractThunk) = back_forwards(unthunk(dys)) back_forwards(z::AbstractZero) = (TRI_NO..., z) return ys, back_forwards end @@ -128,13 +131,14 @@ function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N} rrule_via_ad(cfg, f, a...) end function back_generic(dys) - deltas = unzip_broadcast(backs, unthunk(dys)) do back, dy # (could be map, sizes match) + deltas = unzip_broadcast(backs, dys) do back, dy # (could be map, sizes match) map(unthunk, back(dy)) end dargs = map(unbroadcast, args, Base.tail(deltas)) df = ProjectTo(f)(sum(first(deltas))) return (NoTangent(), NoTangent(), df, dargs...) end + back_generic(dys::AbstractThunk) = back_generic(unthunk(dys)) back_generic(z::AbstractZero) = (TRI_NO..., map(Returns(z), args)...) return ys3, back_generic end @@ -318,7 +322,7 @@ rrule(::typeof(broadcasted), ::typeof(complex), x::Number) = rrule(complex, x) | function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx_raw) dx = unthunk(dx_raw) - N = ndims(dx) + N = _ndims(dx) if length(x) == length(dx) ProjectTo(x)(dx) # handles trivial reshapes, offsets, structured matrices, row vectors else @@ -328,6 +332,9 @@ function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx_raw) end unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero) = dx +_ndims(x) = ndims(x) +_ndims(::Tuple) = 1 + function unbroadcast(x::T, dx_raw) where {T<:Tuple{Vararg{Any,N}}} where {N} dx = unthunk(dx_raw) val = if N == length(dx) From b931a4cea9e8808a9275259b0c38be917e6d33b6 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 31 Aug 2022 22:13:38 -0400 Subject: [PATCH 2/5] a test --- test/rulesets/Base/broadcast.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl index 219b45a71..331616cce 100644 --- a/test/rulesets/Base/broadcast.jl +++ b/test/rulesets/Base/broadcast.jl @@ -176,5 +176,6 @@ BT1 = Broadcast.BroadcastStyle(Tuple) @testset "bugs" begin @test ChainRules.unbroadcast((1, 2, [3]), [4, 5, [6]]) isa Tangent # earlier, NTuple demanded same type + @test ChainRules.unbroadcast(broadcasted(-, (1, 2), 3), (4, 5)) == (4, 5) # earlier, called ndims(::Tuple) end end \ No newline at end of file From be8bb594e3972ec25f8ed130bfd033f844c8728a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 1 Sep 2022 21:49:41 -0400 Subject: [PATCH 3/5] revert some as the broke inference --- src/rulesets/Base/broadcast.jl | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index bdc419444..3519a097b 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -59,13 +59,12 @@ function split_bc_derivatives(f::F, arg) where {F} @debug("split broadcasting derivative", f) ys = f.(arg) function bc_one_back(dys) # For f.(x) we do not need StructArrays / unzip at all - delta = broadcast(dys, ys, arg) do dy, y, a + delta = broadcast(unthunk(dys), ys, arg) do dy, y, a das = only(derivatives_given_output(y, f, a)) dy * conj(only(das)) # possibly this * should be made nan-safe. end return (TRI_NO..., ProjectTo(arg)(delta)) end - bc_one_back(dys::AbstractThunk) = bc_one_back(unthunk(dys)) bc_one_back(z::AbstractZero) = (TRI_NO..., z) return ys, bc_one_back end @@ -73,14 +72,13 @@ function split_bc_derivatives(f::F, args::Vararg{Any,N}) where {F,N} @debug("split broadcasting derivatives", f, N) ys = f.(args...) function bc_many_back(dys) - deltas = unzip_broadcast(dys, ys, args...) do dy, y, as... + deltas = unzip_broadcast(unthunk(dys), ys, args...) do dy, y, as... das = only(derivatives_given_output(y, f, as...)) map(da -> dy * conj(da), das) # possibly this * should be made nan-safe. end dargs = map(unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of unzip_broadcast? return (TRI_NO..., dargs...) end - bc_many_back(dys::AbstractThunk) = bc_many_back(unthunk(dys)) bc_many_back(z::AbstractZero) = (TRI_NO..., map(Returns(z), args)...) return ys, bc_many_back end @@ -111,12 +109,11 @@ function split_bc_inner(frule_fun::R, cfg::RuleConfig, f::F, arg) where {R,F} frule_fun(cfg, (NoTangent(), one(a)), f, a) end function back_forwards(dys) - delta = broadcast(ydots, dys, arg) do ydot, dy, a + delta = broadcast(ydots, unthunk(dys), arg) do ydot, dy, a ProjectTo(a)(conj(ydot) * dy) # possibly this * should be made nan-safe. end return (TRI_NO..., ProjectTo(arg)(delta)) end - back_forwards(dys::AbstractThunk) = back_forwards(unthunk(dys)) back_forwards(z::AbstractZero) = (TRI_NO..., z) return ys, back_forwards end From 4c6e33d97aaf064d433a8d39b7696d873026606b Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 1 Sep 2022 21:50:06 -0400 Subject: [PATCH 4/5] ignore contents of at-debug macro --- src/rulesets/Base/nondiff.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/rulesets/Base/nondiff.jl b/src/rulesets/Base/nondiff.jl index 4d208a95b..22aeb1748 100644 --- a/src/rulesets/Base/nondiff.jl +++ b/src/rulesets/Base/nondiff.jl @@ -477,6 +477,11 @@ end @non_differentiable Broadcast.result_style(::Any) @non_differentiable Broadcast.result_style(::Any, ::Any) +@non_differentiable Base.CoreLogging.current_logger_for_env(::Any...) +@non_differentiable Base.CoreLogging._invoked_shouldlog(::Any...) +@non_differentiable Base.CoreLogging.Base.fixup_stdlib_path(::Any) +@non_differentiable Base.CoreLogging.handle_message(::Any...) + @non_differentiable Libc.free(::Any) @non_differentiable Libc.getpid() @non_differentiable Libc.strptime(::AbstractString) From 56e5451f9d88080def14af7b75999b924d12d1bd Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 1 Sep 2022 21:52:20 -0400 Subject: [PATCH 5/5] version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 581934ac9..7ea574d09 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.44.5" +version = "1.44.6" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"