Merge pull request #344 from FluxML/dg/cr
Add generic fallbacks for rrules not expecting Thunks
DhairyaLGandhi authored Aug 3, 2021
2 parents 95f9d0b + 725a8d4 commit a5ff4b5
Showing 13 changed files with 55 additions and 50 deletions.
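
Every hunk below applies one pattern: with ChainRulesCore 1.0, a pullback may receive its cotangent wrapped in a lazy `Thunk` rather than as a plain array, and NNlib's gradient kernels only accept arrays, so each pullback now calls `unthunk` at its boundary (`unthunk` is a no-op on plain values). A minimal sketch of the pattern — `double` and its rrule are hypothetical, not NNlib code:

    using ChainRulesCore

    double(x) = 2 .* x   # toy example function

    function ChainRulesCore.rrule(::typeof(double), x::AbstractArray)
        function double_pullback(Δ)
            # unthunk materializes a lazy Thunk; on a plain array it is a no-op
            (NoTangent(), 2 .* unthunk(Δ))
        end
        return double(x), double_pullback
    end

    y, back = rrule(double, ones(3))
    back(ones(3))          # plain cotangent: fine
    back(@thunk ones(3))   # lazy cotangent: also fine, thanks to unthunk
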
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
name = "NNlib"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.7.26"
+version = "0.7.27"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Adapt = "2, 3.2"
-ChainRulesCore = "0.9.45, 0.10"
+ChainRulesCore = "0.9.45, 0.10, 1"
Compat = "3.14"
Requires = "0.5, 1.0"
julia = "1.6"
4 changes: 2 additions & 2 deletions src/batched/batchedadjtrans.jl
@@ -94,11 +94,11 @@ Base.unsafe_convert(::Type{Ptr{T}}, A::BatchedAdjOrTrans{T}) where {T} =

# Gradients
function rrule(::typeof(batched_transpose), A::AbstractArray{<:Any,3})
-b_transpose_back(Δ) = (NoTangent(), batched_transpose(Δ))
+b_transpose_back(Δ) = (NoTangent(), batched_transpose(unthunk(Δ)))
batched_transpose(A), b_transpose_back
end
function rrule(::typeof(batched_adjoint), A::AbstractArray{<:Any,3})
-b_adjoint_back(Δ) = (NoTangent(), batched_adjoint(Δ))
+b_adjoint_back(Δ) = (NoTangent(), batched_adjoint(unthunk(Δ)))
batched_adjoint(A), b_adjoint_back
end
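
Without the `unthunk`, a lazy cotangent arriving here would be passed straight to `batched_transpose`, which has no method for `ChainRulesCore.Thunk`, and the pullback would throw a MethodError. A quick check of the fixed behavior (a sketch, assuming NNlib at the version tagged by this commit):

    using NNlib, ChainRulesCore

    A = rand(2, 3, 4)
    y, back = rrule(batched_transpose, A)   # the rrule from the hunk above
    _, dA = back(@thunk rand(3, 2, 4))      # lazy cotangent, shaped like y
    size(dA)                                # (2, 3, 4), matching A
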

3 changes: 2 additions & 1 deletion src/batched/batchedmul.jl
@@ -91,7 +91,8 @@ end
# Gradient, allowing that size(A,3)==1 means it's "broadcasted" out to size(B,3)

function rrule(::typeof(batched_mul), A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3})
-function batched_mul_pullback(Δ)
+function batched_mul_pullback(_Δ)
+    Δ = unthunk(_Δ)
Athunk = @thunk begin
tmp = batched_mul(Δ, batched_adjoint(B))
size(A,3) == 1 ? sum(tmp, dims=3) : tmp
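
Here the cotangent feeds two thunks (the gradients for `A` and `B`), so it is materialized once up front; calling `unthunk` inside each `@thunk` would force a lazy cotangent twice. A self-contained sketch of that design choice — `my_batched_mul_rrule` is an illustrative stand-in, not the NNlib definition:

    using NNlib, ChainRulesCore

    function my_batched_mul_rrule(A, B)
        C = batched_mul(A, B)
        function pullback(_Δ)
            Δ = unthunk(_Δ)   # materialize once ...
            dA = @thunk batched_mul(Δ, batched_adjoint(B))   # ... reuse here
            dB = @thunk batched_mul(batched_adjoint(A), Δ)   # ... and here
            (NoTangent(), dA, dB)
        end
        return C, pullback
    end

    C, back = my_batched_mul_rrule(rand(2,3,5), rand(3,4,5))
    _, dA, dB = back(@thunk ones(2,4,5))
    size(unthunk(dA)), size(unthunk(dB))   # ((2, 3, 5), (3, 4, 5))
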
8 changes: 4 additions & 4 deletions src/conv.jl
@@ -310,8 +310,8 @@ for conv in [:conv, :depthwiseconv]
Δ = colmajor(Δ)
return (
NoTangent(),
-@thunk($∇conv_data(Δ, w, cdims, kw...)),
-@thunk($∇conv_filter(x, Δ, cdims, kw...)),
+@thunk($∇conv_data(unthunk(Δ), w, cdims, kw...)),
+@thunk($∇conv_filter(x, unthunk(Δ), cdims, kw...)),
NoTangent(),
)
end
@@ -323,8 +323,8 @@ for conv in [:conv, :depthwiseconv]
Δ = colmajor(Δ)
return (
NoTangent(),
-@thunk($conv(Δ, w, cdims, kw...)),
-@thunk($∇conv_filter(Δ, x, cdims, kw...)),
+@thunk($conv(unthunk(Δ), w, cdims, kw...)),
+@thunk($∇conv_filter(unthunk(Δ), x, cdims, kw...)),
NoTangent(),
)
end
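
In the convolution rules the `unthunk` sits inside the `@thunk`, so a lazy cotangent is only materialized if that particular gradient is actually demanded. The calling convention stays the usual rrule shape (a sketch, assuming Float32 inputs and the `DenseConvDims` constructor):

    using NNlib, ChainRulesCore

    x = rand(Float32, 8, 8, 3, 1)   # W×H×C×N input
    w = rand(Float32, 3, 3, 3, 4)   # 3×3 kernel, 3 in / 4 out channels
    cdims = DenseConvDims(x, w)
    y, back = rrule(conv, x, w, cdims)
    # dx and dw are Thunks; forcing one runs ∇conv_data / ∇conv_filter
    _, dx, dw, _ = back(@thunk ones(Float32, size(y)))
    size(unthunk(dx)) == size(x)    # true
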
2 changes: 1 addition & 1 deletion src/gather.jl
@@ -69,6 +69,6 @@ end
function rrule(::typeof(gather!), dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
y = gather!(dst, src, idx)
src_size = size(src)
-gather!_pullback(Δ) = (NoTangent(), NoTangent(), ∇gather_src(Δ, src_size, idx), NoTangent())
+gather!_pullback(Δ) = (NoTangent(), NoTangent(), ∇gather_src(unthunk(Δ), src_size, idx), NoTangent())
y, gather!_pullback
end
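
The gradient of `gather!` with respect to `src` scatters the cotangent back through `idx`, and `∇gather_src` needs the materialized array, hence the `unthunk`. A small sketch of the shapes involved:

    using NNlib, ChainRulesCore

    src = rand(5)
    idx = [1, 3, 3, 5]
    y, back = rrule(gather!, zeros(4), src, idx)
    _, _, dsrc, _ = back(@thunk ones(4))
    # dsrc has src's shape; src[3] was gathered twice, so its entry is 2.0
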
2 changes: 1 addition & 1 deletion src/padding.jl
@@ -155,7 +155,7 @@ function rrule(::typeof(pad_constant), x::AbstractArray{T,N},
function pad_constant_pullback(Δ)
p = gen_pad(pad, dims, N)
outsize, center = size_and_center(x, p)
-(NoTangent(), @thunk(Δ[center...]), NoTangent(), NoTangent(),)
+(NoTangent(), @thunk(unthunk(Δ)[center...]), NoTangent(), NoTangent(),)
end
return y, pad_constant_pullback
end
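
The adjoint of constant padding just crops the cotangent back to the input's extent, so the rule is an indexing expression wrapped in `@thunk`; the inner `unthunk` makes that slice valid for lazy cotangents too. The pattern in isolation, with `center` as an assumed region standing in for `size_and_center`'s result:

    using ChainRulesCore

    center = (2:3, 2:3)   # assumed 2×2 center of a 4×4 padded array
    pad_pullback(Δ) = @thunk(unthunk(Δ)[center...])
    unthunk(pad_pullback(@thunk ones(4, 4)))   # 2×2 block of ones
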
2 changes: 1 addition & 1 deletion src/pooling.jl
@@ -173,7 +173,7 @@ for pool in [:maxpool, :meanpool]
pullback = Symbol(pool, :_pullback)
@eval function rrule(::typeof($pool), x, pdims::PoolDims; kw...)
Ω = $pool(x, pdims; kw...)
-$pullback(Δ) = (NoTangent(), $∇pool(Δ, Ω, x, pdims; kw...), NoTangent())
+$pullback(Δ) = (NoTangent(), $∇pool(unthunk(Δ), Ω, x, pdims; kw...), NoTangent())
return Ω, $pullback
end
end
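
Both pooling backward kernels (`∇maxpool`, `∇meanpool`) take the primal output `Ω` and input `x` alongside the cotangent, which is why the pullback closes over them; only the cotangent itself needs unthunking. A sketch of a call, assuming the `PoolDims` constructor:

    using NNlib, ChainRulesCore

    x = rand(Float32, 4, 4, 1, 1)
    pdims = PoolDims(x, 2)   # 2×2 windows, stride 2
    Ω, back = rrule(maxpool, x, pdims)
    # each window's cotangent flows to the entry that won the max
    _, dx, _ = back(@thunk ones(Float32, size(Ω)))
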
7 changes: 5 additions & 2 deletions src/scatter.jl
@@ -152,15 +152,18 @@ function ∇scatter_src(::typeof(mean), Δ, dst,
divide_by_counts!(gather(Δ, idx), idx, dims)
end

+∇scatter_src(op, Δ, dst, src, idx) =
+    ∇scatter_src(op, unthunk(Δ), dst, src, idx)

function rrule(::typeof(scatter!), op, dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
dst_old = copy(dst)
scatter!(op, dst, src, idx)
-scatter!_pullback(Δ) = (NoTangent(), NoTangent(), ∇scatter!_dst(op, Δ, dst_old, dst), ∇scatter!_src(op, Δ, dst, src, idx), NoTangent())
+scatter!_pullback(Δ) = (NoTangent(), NoTangent(), ∇scatter!_dst(op, unthunk(Δ), dst_old, dst), ∇scatter!_src(op, unthunk(Δ), dst, src, idx), NoTangent())
dst, scatter!_pullback
end

function rrule(::typeof(scatter), op, src::AbstractArray, idx::AbstractArray; kws...)
y = scatter(op, src, idx; kws...)
-scatter_pullback(Δ) = (NoTangent(), NoTangent(), ∇scatter_src(op, Δ, y, src, idx), NoTangent())
+scatter_pullback(Δ) = (NoTangent(), NoTangent(), ∇scatter_src(op, unthunk(Δ), y, src, idx), NoTangent())
y, scatter_pullback
end
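
The new untyped `∇scatter_src` method above unthunks and re-dispatches, covering callers that hand it a Thunk directly; the two rrules also unthunk eagerly because the cotangent is used in more than one place. For `+`, the source gradient is essentially a gather: each source column receives the cotangent of the destination slot it was scattered into. A sketch:

    using NNlib, ChainRulesCore

    src = ones(3, 5)
    idx = [1, 1, 2, 2, 2]
    y, back = rrule(scatter, +, src, idx)   # y is 3×2
    _, _, dsrc, _ = back(@thunk ones(size(y)))
    size(dsrc)                              # (3, 5), matching src
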
39 changes: 20 additions & 19 deletions src/softmax.jl
@@ -9,19 +9,18 @@ export softmax,
logsumexp

"""
-softmax(x; dims=1)
+softmax(x; dims = 1)
[Softmax](https://en.wikipedia.org/wiki/Softmax_function) turns input array `x`
into probability distributions that sum to 1 along the dimensions specified by `dims`.
It is semantically equivalent to the following:
-softmax(x; dims=1) = exp.(x) ./ sum(exp.(x), dims=dims)
+softmax(x; dims = 1) = exp.(x) ./ sum(exp.(x), dims = dims)
with additional manipulations enhancing numerical stability.
-For a matrix input `x` it will by default (`dims=1`) treat it as a batch of vectors,
-with each column independent. Keyword `dims=2` will instead treat rows independently,
-etc...
+For a matrix input `x` it will by default (`dims = 1`) treat it as a batch of vectors,
+with each column independent. Keyword `dims = 2` will instead treat rows independently, and so on.
See also [`logsoftmax`](@ref).
@@ -61,9 +60,10 @@ end

∇softmax(Δ::AbstractArray{T}, x::AbstractArray, y::AbstractArray{S}; dims = 1) where {T,S} =
    ∇softmax!(similar(y, promote_type(T, S)), Δ, x, y; dims = dims)
+∇softmax(Δ, x, y; dims = 1) = ∇softmax(unthunk(Δ), x, y, dims = dims)

-## Can introduce at the end of deprecation cycle of ∇softmax!(out, Δ, x; dims = 1)
-#∇softmax!(Δ, x, y; dims = 1) = ∇softmax!(Δ, Δ, x, y; dims = dims)
+# Can introduce at the end of deprecation cycle of ∇softmax!(out, Δ, x; dims = 1)
+# ∇softmax!(Δ, x, y; dims = 1) = ∇softmax!(Δ, Δ, x, y; dims = dims)

function ∇softmax!(out::AbstractArray, Δ::AbstractArray,
x::AbstractArray, y::AbstractArray; dims = 1)
@@ -72,26 +72,26 @@ function ∇softmax!(out::AbstractArray, Δ::AbstractArray,
end

# Old 2-arg version recomputing forward
-∇softmax(Δ, x; dims=1) = ∇softmax(Δ, x, softmax(x, dims=dims); dims=dims)
-∇softmax!(Δ, x; dims=1) = ∇softmax!(Δ, Δ, x, softmax(x, dims=dims); dims=dims)
-∇softmax!(out, Δ, x; dims=1) = ∇softmax!(out, Δ, x, softmax(x, dims=dims); dims=dims)
+∇softmax(Δ, x; dims = 1) = ∇softmax(Δ, x, softmax(x, dims = dims); dims = dims)
+∇softmax!(Δ, x; dims = 1) = ∇softmax!(Δ, Δ, x, softmax(x, dims = dims); dims = dims)
+∇softmax!(out, Δ, x; dims = 1) = ∇softmax!(out, Δ, x, softmax(x, dims = dims); dims = dims)

function rrule(::typeof(softmax), xs; dims=1)
y = softmax(xs; dims=dims)
-softmax_pullback(Δ) = (NoTangent(), ∇softmax(Δ, xs, y, dims=dims))
+softmax_pullback(Δ) = (NoTangent(), ∇softmax(unthunk(Δ), xs, y, dims = dims))
return y, softmax_pullback
end
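
The added untyped `∇softmax` method is the "generic fallback" of the commit title: the typed method does the work, and the untyped one catches anything that is not an `AbstractArray` — in practice a Thunk — materializes it, and re-dispatches. The idiom in isolation, with hypothetical names (note it assumes `unthunk` eventually yields an `AbstractArray`; a non-array, non-thunk cotangent would recurse):

    using ChainRulesCore

    grad_kernel(Δ::AbstractArray, y::AbstractArray) = Δ .* y   # does the work
    grad_kernel(Δ, y) = grad_kernel(unthunk(Δ), y)             # Thunk fallback

    grad_kernel(@thunk(ones(3)), rand(3))   # fallback fires, then the typed method
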

"""
-logsoftmax(x; dims=1)
+logsoftmax(x; dims = 1)
Computes the log of softmax in a more numerically stable
way than directly taking `log.(softmax(xs))`. Commonly used in
computing cross entropy loss.
It is semantically equivalent to the following:
-logsoftmax(x; dims=1) = x .- log.(sum(exp.(x), dims=dims))
+logsoftmax(x; dims = 1) = x .- log.(sum(exp.(x), dims = dims))
See also [`softmax`](@ref).
"""
@@ -112,11 +112,12 @@ end

∇logsoftmax(Δ::AbstractArray{T}, x::AbstractArray, y::AbstractArray{S}; dims = 1) where {T,S} =
    ∇logsoftmax!(similar(y, promote_type(T, S)), Δ, x, y; dims = dims)
+∇logsoftmax(Δ, x, y; dims = 1) = ∇logsoftmax(unthunk(Δ), x, y, dims = dims)

# Old 2-arg version recomputing forward
-∇logsoftmax(Δ, x; dims=1) = ∇logsoftmax(Δ, x, logsoftmax(x, dims=dims); dims=dims)
-∇logsoftmax!(Δ, x; dims=1) = ∇logsoftmax!(Δ, Δ, x, logsoftmax(x, dims=dims); dims=dims)
-∇logsoftmax!(out, Δ, x; dims=1) = ∇logsoftmax!(out, Δ, x, logsoftmax(x, dims=dims); dims=dims)
+∇logsoftmax(Δ, x; dims = 1) = ∇logsoftmax(Δ, x, logsoftmax(x, dims = dims); dims = dims)
+∇logsoftmax!(Δ, x; dims = 1) = ∇logsoftmax!(Δ, Δ, x, logsoftmax(x, dims = dims); dims = dims)
+∇logsoftmax!(out, Δ, x; dims = 1) = ∇logsoftmax!(out, Δ, x, logsoftmax(x, dims = dims); dims = dims)

function ∇logsoftmax!(out::AbstractArray, Δ::AbstractArray,
x::AbstractArray, y::AbstractArray; dims = 1)
@@ -125,14 +126,14 @@ end

function rrule(::typeof(logsoftmax), xs; dims=1)
y = logsoftmax(xs; dims=dims)
-logsoftmax_pullback(Δ) = (NoTangent(), ∇logsoftmax(Δ, xs, y, dims=dims))
+logsoftmax_pullback(Δ) = (NoTangent(), ∇logsoftmax(unthunk(Δ), xs, y, dims = dims))
return y, logsoftmax_pullback
end

"""
-logsumexp(x; dims=:)
+logsumexp(x; dims = :)
-Computes `log.(sum(exp.(x); dims=dims))` in a numerically stable
+Computes `log.(sum(exp.(x); dims = dims))` in a numerically stable
way.
See also [`logsoftmax`](@ref).
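
"Numerically stable" here is the usual max-shift trick: `logsumexp` subtracts `maximum(x)` before exponentiating, so large inputs do not overflow. For example:

    using NNlib

    x = [1000.0, 1000.0]
    log(sum(exp.(x)))   # Inf, since exp(1000) overflows Float64
    logsumexp(x)        # ≈ 1000.6931, i.e. 1000 + log(2)
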
8 changes: 4 additions & 4 deletions src/upsample.jl
@@ -77,7 +77,7 @@ end

function rrule(::typeof(upsample_nearest), x::AbstractArray, s::Tuple)
Ω = upsample_nearest(x, s)
-upsample_nearest_pullback(Δ) = (NoTangent(), ∇upsample_nearest(Δ, s), NoTangent())
+upsample_nearest_pullback(Δ) = (NoTangent(), ∇upsample_nearest(unthunk(Δ), s), NoTangent())
return Ω, upsample_nearest_pullback
end
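
`∇upsample_nearest` is the adjoint of pixel repetition: it sums the (now unthunked) cotangent over each scale-factor block, so with a 2×2 scale every input pixel collects four cotangent entries. A sketch:

    using NNlib, ChainRulesCore

    x = rand(Float32, 2, 2, 1, 1)
    y, back = rrule(upsample_nearest, x, (2, 2))   # y is 4×4×1×1
    _, dx, _ = back(@thunk ones(Float32, size(y)))
    dx                                             # 2×2×1×1 array of 4.0f0
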

@@ -203,7 +203,7 @@ end
function rrule(::typeof(upsample_linear), x; size)
Ω = upsample_linear(x; size=size)
function upsample_linear_pullback(Δ)
-(NoTangent(), ∇upsample_linear(Δ; size=Base.size(x,1)))
+(NoTangent(), ∇upsample_linear(unthunk(Δ); size=Base.size(x,1)))
end
return Ω, upsample_linear_pullback
end
@@ -368,7 +368,7 @@ end
function rrule(::typeof(upsample_bilinear), x; size)
Ω = upsample_bilinear(x; size=size)
function upsample_bilinear_pullback(Δ)
-(NoTangent(), ∇upsample_bilinear(Δ; size=(Base.size(x,1),Base.size(x,2))))
+(NoTangent(), ∇upsample_bilinear(unthunk(Δ); size=(Base.size(x,1),Base.size(x,2))))
end
return Ω, upsample_bilinear_pullback
end
@@ -518,7 +518,7 @@ end
function rrule(::typeof(upsample_trilinear), x; size)
Ω = upsample_trilinear(x; size=size)
function upsample_trilinear_pullback(Δ)
-(NoTangent(), ∇upsample_trilinear(Δ; size=(Base.size(x,1), Base.size(x,2), Base.size(x,3))))
+(NoTangent(), ∇upsample_trilinear(unthunk(Δ); size=(Base.size(x,1), Base.size(x,2), Base.size(x,3))))
end
return Ω, upsample_trilinear_pullback
end
8 changes: 4 additions & 4 deletions test/scatter.jl
@@ -70,7 +70,7 @@ res = Dict(

types = [UInt8, UInt32, UInt128,
Int16, Int64, BigInt,
-Float32, Float64, Rational]
+Float16, Float32, Float64, BigFloat, Rational]

@testset "scatter" begin
for T = types
@@ -146,7 +146,7 @@ types = [UInt8, UInt32, UInt128,
end
end

-for T = [Float16, Float32, Rational]
+for T = [Float16, Float32, BigFloat, Rational]
@testset "$T" begin
PT = promote_type(T, Float64)
@testset "/" begin
@@ -182,9 +182,9 @@ types = [UInt8, UInt32, UInt128,
@testset "dstsize" begin
idx = [2, 2, 3, 4, 4]
src = ones(3, 5)
-y = scatter(+, src, idx, dstsize=(3, 6))
+y = scatter(+, src, idx, dstsize = (3, 6))
@test size(y) == (3, 6)
-gradtest(x -> scatter(+, x, idx, dstsize=(3,6)), src)
+gradtest(x -> scatter(+, x, idx, dstsize = (3,6)), src)
end
end
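
The `dstsize` keyword fixes the destination size when `maximum(idx)` would underestimate it; in this test the destination gets six columns even though `idx` only reaches four:

    using NNlib

    idx = [2, 2, 3, 4, 4]
    src = ones(3, 5)
    y = scatter(+, src, idx, dstsize = (3, 6))
    size(y)   # (3, 6); columns 1, 5 and 6 stay zero
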

16 changes: 8 additions & 8 deletions test/softmax.jl
@@ -112,16 +112,16 @@ end

@testset "AutoDiff" begin
for f in (softmax, logsoftmax), d in (:, 1, 2)
-gradtest(f, (3,4); fkwargs=(; dims=d), check_rrule=true)
+gradtest(f, (3,4); fkwargs = (dims = d,), check_rrule = true)
end
-gradtest(x -> softmax(x).*(1:3), 3)
-gradtest(x -> softmax(x).*(1:3), (3,5), atol=1e-4)
-gradtest(x -> softmax(x, dims=2).*(1:3), (3,5), atol=1e-4)
-gradtest(x -> logsoftmax(x).*(1:3), 3)
-gradtest(x -> logsoftmax(x).*(1:3), (3,5))
-gradtest(x -> logsoftmax(x, dims=2).*(1:3), (3,5))
+gradtest(x -> softmax(x) .* (1:3), 3)
+gradtest(x -> softmax(x) .* (1:3), (3,5), atol = 1e-4)
+gradtest(x -> softmax(x, dims = 2) .* (1:3), (3,5), atol = 1e-4)
+gradtest(x -> logsoftmax(x) .* (1:3), 3)
+gradtest(x -> logsoftmax(x) .* (1:3), (3,5))
+gradtest(x -> logsoftmax(x, dims = 2) .* (1:3), (3,5))

for d in (:, 1, 2)
-gradtest(logsumexp, (3,4), fkwargs=(; dims=d))
+gradtest(logsumexp, (3,4), fkwargs = (dims = d,))
end
end
2 changes: 1 addition & 1 deletion test/test_utils.jl
@@ -10,7 +10,7 @@ given by Zygote. `f` has to be a scalar valued function.
Applies also `ChainRulesTestUtils.test_rrule` if the rrule for `f` is explicitly defined.
"""
-function gradtest(f, xs...; atol = 1e-6, rtol = 1e-6, fkwargs=NamedTuple(),
+function gradtest(f, xs...; atol = 1e-6, rtol = 1e-6, fkwargs = NamedTuple(),
check_rrule = false,
fdm = :central,
check_broadcast = false,

2 comments on commit a5ff4b5

@DhairyaLGandhi
Member Author


@JuliaRegistrator register()

@JuliaRegistrator


Registration pull request created: JuliaRegistries/General/42124

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.7.27 -m "<description of version>" a5ff4b54cf162104b40f55c8b8a86710e7c2051a
git push origin v0.7.27
