Skip to content

Commit

Permalink
Merge pull request #339 from FluxML/cl/scatdim
Browse files Browse the repository at this point in the history
dstsize for scatter
  • Loading branch information
CarloLucibello authored Jul 30, 2021
2 parents cfe4600 + 2c5360e commit 95f9d0b
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "NNlib"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.7.25"
version = "0.7.26"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
15 changes: 9 additions & 6 deletions src/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ end


"""
scatter(op, src, idx; [init])
scatter(op, src, idx; [init, dstsize])
Scatter operation allocating a destination array `dst` and
calling `scatter!(op, dst, src, idx)` on it.
Expand All @@ -83,16 +83,19 @@ If `init` is provided, it is used to initialize the content of `dst`.
Otherwise, the init values is inferred from the reduction operator `op`
for some common operators (e.g. `init = 0` for `op = +`).
If `dstsize` is provided, it will be used to define the size of
destination array, otherwise it will be inferred by `src` and `idx`.
See [`scatter!`](@ref) for the details.
"""
function scatter(op,
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx,Nidx};
init = nothing) where {Tsrc,Tidx,Nsrc,Nidx}
init = nothing, dstsize = nothing) where {Tsrc,Tidx,Nsrc,Nidx}

dims = Nsrc - Nidx
dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
dst = similar(src, Tsrc, dstsize)
dstsz = isnothing(dstsize) ? (size(src)[1:dims]..., maximum_dims(idx)...) : dstsize
dst = similar(src, Tsrc, dstsz)
xinit = isnothing(init) ? scatter_empty(op, Tsrc) : init
fill!(dst, xinit)
scatter!(op, dst, src, idx)
Expand Down Expand Up @@ -156,8 +159,8 @@ function rrule(::typeof(scatter!), op, dst::AbstractArray, src::AbstractArray, i
dst, scatter!_pullback
end

function rrule(::typeof(scatter), op, src::AbstractArray, idx::AbstractArray)
y = scatter(op, src, idx)
function rrule(::typeof(scatter), op, src::AbstractArray, idx::AbstractArray; kws...)
y = scatter(op, src, idx; kws...)
scatter_pullback(Δ) = (NoTangent(), NoTangent(), ∇scatter_src(op, Δ, y, src, idx), NoTangent())
y, scatter_pullback
end
16 changes: 12 additions & 4 deletions test/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ res = Dict(
4. 4. 6. 5. 5.],
)

types = [UInt8, UInt16, UInt32, UInt64, UInt128,
Int8, Int16, Int32, Int64, Int128, BigInt,
Float16, Float32, Float64, BigFloat, Rational]
types = [UInt8, UInt32, UInt128,
Int16, Int64, BigInt,
Float32, Float64, Rational]

@testset "scatter" begin
for T = types
Expand Down Expand Up @@ -146,7 +146,7 @@ types = [UInt8, UInt16, UInt32, UInt64, UInt128,
end
end

for T = [Float16, Float32, Float64, BigFloat, Rational]
for T = [Float16, Float32, Rational]
@testset "$T" begin
PT = promote_type(T, Float64)
@testset "/" begin
Expand Down Expand Up @@ -178,6 +178,14 @@ types = [UInt8, UInt16, UInt32, UInt64, UInt128,
@test_throws AssertionError scatter!(+, dsts[0], srcs[(1, true)], idxs[:int])
idx = [1 2 3 4; 4 2 1 3; 6 7 8 9]
@test_throws BoundsError scatter!(+, dsts[1], srcs[(1, true)], idx)

@testset "dstsize" begin
idx = [2, 2, 3, 4, 4]
src = ones(3, 5)
y = scatter(+, src, idx, dstsize=(3, 6))
@test size(y) == (3, 6)
gradtest(x -> scatter(+, x, idx, dstsize=(3,6)), src)
end
end

@testset "∇scatter" begin
Expand Down

2 comments on commit 95f9d0b

@CarloLucibello
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/41846

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.26 -m "<description of version>" 95f9d0bcd92fda888dbbbee48b327ee9912a455b
git push origin v0.7.26

Please sign in to comment.