Skip to content

Commit

Permalink
Remove bool rule too and correct test
Browse files Browse the repository at this point in the history
  • Loading branch information
ToucheSir authored Sep 6, 2023
1 parent c5042e2 commit 1037852
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 15 deletions.
13 changes: 0 additions & 13 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -329,19 +329,6 @@ end
end

# Reductions
#=
@adjoint function sum(xs::AbstractArray; dims = :)
if dims === (:)
sum(xs), Δ -> (Fill(Δ, size(xs)),)
else
sum(xs, dims = dims), Δ -> (similar(xs) .= Δ,)
end
end
=#

@adjoint function sum(xs::AbstractArray{Bool}; dims = :)
sum(xs, dims = dims), Δ -> (nothing,)
end

function _pullback(cx::AContext, ::typeof(prod), f, xs::AbstractArray)
return _pullback(cx, (f, xs) -> prod(f.(xs)), f, xs)
Expand Down
4 changes: 2 additions & 2 deletions test/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ end
@testset "dictionary comprehension" begin
d = Dict(1 => 5, 2 => 6)
g = gradient(d -> sum([v^2 for (_,v) in d]), d)[1]
@test g isa Dict{Int, Int}
@test g == Dict(1 => 10, 2 => 12)
@test g isa Dict{Int, Float64}
@test g == Dict(1 => 10.0, 2 => 12.0)


w = randn(5)
Expand Down

0 comments on commit 1037852

Please sign in to comment.