Skip to content

Commit

Permalink
Use CR OneElement
Browse files Browse the repository at this point in the history
  • Loading branch information
ToucheSir authored Jul 8, 2023
1 parent f8779f8 commit 84b3079
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractFFTs = "1.3.1"
ChainRules = "1.44.1"
ChainRules = "1.51.0"
ChainRulesCore = "1.9"
ChainRulesTestUtils = "1"
Colors = "0.12"
Expand Down
2 changes: 1 addition & 1 deletion src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback,
literal_getproperty, literal_getfield, unthunk_tangent

using ChainRulesCore
using ChainRules: ChainRules, rrule, unthunk, canonicalize
using ChainRules: ChainRules, rrule, unthunk, canonicalize, OneElement
using IRTools
using MacroTools, Requires
using MacroTools: @forward
Expand Down
15 changes: 0 additions & 15 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,6 @@ end
@adjoint (::Type{T})(sz) where {T<:Zeros} = T(sz), Δ->(nothing,)
@adjoint (::Type{T})(sz) where {T<:Ones} = T(sz), Δ->(nothing,)

"""
OneElement(val, ind, axes) <: AbstractArray
Extremely simple `struct` used for the gradient of scalar `getindex`.
"""
struct OneElement{T,N,I,A} <: AbstractArray{T,N}
val::T
ind::I
axes::A
OneElement(val::T, ind::I, axes::A) where {T<:Number, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} = new{T,N,I,A}(val, ind, axes)
end
Base.size(A::OneElement) = map(length, A.axes)
Base.axes(A::OneElement) = A.axes
Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T))


_zero(xs::AbstractArray{<:Number}, T::Type{Nothing}) = fill!(similar(xs), zero(eltype(xs)))
_zero(xs::AbstractArray{<:Number}, T) = fill!(similar(xs, T), false)
Expand Down

0 comments on commit 84b3079

Please sign in to comment.