Skip to content

Commit

Permalink
fix some other rules
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Sep 1, 2022
1 parent 6162295 commit c2cab6d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
23 changes: 14 additions & 9 deletions src/extra_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function ChainRulesCore.rrule(::DiffractorRuleConfig, g::∇getindex, Δ)
g(Δ), Δ′′->(nothing, Δ′′[1][g.i...])
end

function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(getindex), xs::Array, i...)
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(getindex), xs::Array{<:Number}, i...)
xs[i...], ∇getindex(xs, i)
end

Expand Down Expand Up @@ -220,26 +220,31 @@ struct BackMap{T}
end
(f::BackMap{N})(args...) where {N} = ∂⃖¹(getfield(f, :f), args...)
back_apply(x, y) = x(y)
back_apply_zero(x) = x(Zero())
back_apply_zero(x) = x(Zero()) # Zero is not defined

function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), f, args::Tuple)
a, b = unzip_tuple(map(BackMap(f), args))
function back(Δ)
function map_back(Δ)
(fs, xs) = unzip_tuple(map(back_apply, b, Δ))
(NoTangent(), sum(fs), xs)
end
function back::ZeroTangent)
(fs, xs) = unzip_tuple(map(back_apply_zero, b))
(NoTangent(), sum(fs), xs)
end
a, back
map_back::AbstractZero) = (NoTangent(), NoTangent(), NoTangent())
# function back(Δ::ZeroTangent)
# (fs, xs) = unzip_tuple(map(back_apply_zero, b))
# (NoTangent(), sum(fs), xs)
# end
a, map_back
end

ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), f, args::Tuple{}) = (), _ -> (NoTangent(), NoTangent(), NoTangent())

function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(Base.ntuple), f, n)
a, b = unzip_tuple(ntuple(BackMap(f), n))
a, function (Δ)
function ntuple_back(Δ)
(NoTangent(), sum(map(back_apply, b, Δ)), NoTangent())
end
ntuple_back(::AbstractZero) = (NoTangent(), NoTangent(), NoTangent())
a, ntuple_back
end

function ChainRules.frule(::DiffractorRuleConfig, _, ::Type{Vector{T}}, undef::UndefInitializer, dims::Int...) where {T}
Expand Down
5 changes: 3 additions & 2 deletions src/stage1/generated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,13 @@ function (::∂⃖{N})(::typeof(Core.getfield), s, field::Symbol) where {N}
end

# TODO: Temporary - make better
function (::∂⃖{N})(::typeof(Base.getindex), a::Array, inds...) where {N}
function (::∂⃖{N})(::typeof(Base.getindex), a::Array{<:Number}, inds...) where {N}
getindex(a, inds...), let
EvenOddOdd{1, c_order(N)}(
(@Base.constprop :aggressive Δ->begin
Δ isa AbstractZero && return (NoTangent(), Δ, map(Returns(Δ), inds)...)
BB = zero(a)
BB[inds...] = Δ
BB[inds...] = unthunk(Δ)
(NoTangent(), BB, map(x->NoTangent(), inds)...)
end),
(@Base.constprop :aggressive (_, Δ, _)->begin
Expand All @@ -334,6 +334,7 @@ struct tuple_back{M}; end
(::tuple_back)(Δ::Tuple) = Core.tuple(NoTangent(), Δ...)
(::tuple_back{N})(Δ::AbstractZero) where {N} = Core.tuple(NoTangent(), ntuple(i->Δ, N)...)
(::tuple_back{N})(Δ::Tangent) where {N} = Core.tuple(NoTangent(), ntuple(i->lifted_getfield(Δ, i), N)...)
(t::tuple_back)(Δ::AbstractThunk) = t(unthunk(Δ))

function (::∂⃖{N})(::typeof(Core.tuple), args::Vararg{Any, M}) where {N, M}
Core.tuple(args...),
Expand Down

0 comments on commit c2cab6d

Please sign in to comment.