From c8fe70db7b03ca34658cd3b07a93ce38857da9b2 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Sat, 30 Dec 2023 20:28:02 -0800 Subject: [PATCH 1/2] Remove superfluous methods on `Grads` As far as I can tell, these were never used, tested or documented. Moreover, they don't make sense semantically since `Grads` behaves like a Dict rather than an ordinal-indexed collection like an Array. Meanwhile, their continued existence is causing issues like https://github.com/FluxML/Zygote.jl/issues/1484. --- src/compiler/interface.jl | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index c09d6db31..99833b9b3 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -375,40 +375,6 @@ function Base.copy(gs::Grads) merge!(gs_new, gs) end -broadcasted(f, gs::Grads, gss::ADictOrGrads...) = map(f, gs, gss...) - -broadcasted(f, a::Numeric, gs::Grads) = map(x -> f(a, x), gs) -broadcasted(f, gs::Grads, a::Numeric) = map(x -> f(x, a), gs) - -function materialize!(gs1::Grads, gs2::Grads) - issetequal(gs1.params, gs2.params) || - throw(ArgumentError("Expected Grads objects with the same Params.")) - for p in gs1.params - gs1[p] = gs2[p] - end - return gs1 -end - - -function Base.map(f, gs1::Grads, gss::ADictOrGrads...) - gsout = Grads(IdDict{Any,Any}(), Params(gs1.params)) - return map!(f, gsout, gs1, gss...) -end - -function Base.map!(f, gsout::Grads, gss::ADictOrGrads...) - all(issetequal(gsout.params, keys(gs)) for gs in gss) || - throw(ArgumentError("map! expects Grads objects with the same Params.")) - for p in gsout.params - gsout[p] = f((_getformap(gs, p) for gs in gss)...) - end - return gsout -end - -function _getformap(gs, p) - g = gs[p] - isnothing(g) ? fill!(similar(p), 0) : g -end - function pullback(f, ps::Params) cx = Context{true}(nothing) y, back = _pullback(cx, f) From 2cd611eed24ed36c0973f60aae87ecf1e8e55d30 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Sat, 30 Dec 2023 20:34:02 -0800 Subject: [PATCH 2/2] Remove broadcast-related ambiguity with `Params` --- src/compiler/interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 99833b9b3..cdd2f2372 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -257,7 +257,7 @@ function Base.delete!(ps::Params, x) return ps end -Base.Broadcast.broadcasted(f, ps::Params) = broadcasted(f, ps.order) +Base.Broadcast.broadcastable(ps::Params) = ps.order @adjoint function Broadcast.broadcasted(f::Function, ps::Params) f.(ps), _ -> throw(ArgumentError("Zygote.Params does not support broadcasting within gradients, try iteration `for p in ps`"))