Skip to content

Commit

Permalink
fixup, rm many comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Oct 18, 2022
1 parent 7f56d8d commit b88d70f
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 72 deletions.
20 changes: 0 additions & 20 deletions src/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,26 +121,6 @@ end
# Path 4: The most generic, save all the pullbacks. Can be 1000x slower.
# While broadcast makes no guarantee about order of calls, it's cheap to reverse the iteration.

#=
julia> Yota.grad(xs -> sum(abs2, (x -> abs(x)).(xs)), [1,2,3.0])
┌ Debug: split broadcasting generic
│ f = #69 (generic function with 1 method)
│ N = 1
└ @ ChainRules ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:126
(14.0, (ZeroTangent(), [2.0, 4.0, 6.0]))
julia> ENV["JULIA_DEBUG"] = nothing
julia> @btime Yota.grad(xs -> sum(abs2, (x -> abs(x)).(xs)), $(rand(1000)));
min 1.321 ms, mean 1.434 ms (23010 allocations, 594.66 KiB) # with unzip_map, as before
min 1.279 ms, mean 1.393 ms (23029 allocations, 595.73 KiB) # with unzip_map_reversed
julia> @btime Yota.grad(xs -> sum(abs2, abs.(xs)), $(randn(1000))); # Debug: split broadcasting derivative
min 2.144 μs, mean 6.620 μs (6 allocations, 23.88 KiB)
=#

function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N}
@debug("split broadcasting generic", f, N)
ys3, backs = unzip_broadcast(args...) do a...
Expand Down
49 changes: 0 additions & 49 deletions src/rulesets/Base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,57 +17,10 @@ function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect), gen::G) whe
ys, generator_pullback
end

# Needed for Yota, but shouldn't these be automatic?
# Constructor rule for `Base.Generator`: the primal result is the generator built from
# `f` and `iter`; the pullback passes the `f` and `iter` fields of the incoming tangent
# straight through, with `NoTangent()` for the constructor itself.
function ChainRulesCore.rrule(::Type{<:Base.Generator}, f, iter)
    generator_ctor_pullback(dy) = (NoTangent(), dy.f, dy.iter)
    return Base.Generator(f, iter), generator_ctor_pullback
end
# Constructor rule for `Iterators.ProductIterator`: the primal result wraps `iters`;
# the pullback returns the `iterators` field of the incoming tangent, with
# `NoTangent()` for the constructor itself.
function ChainRulesCore.rrule(::Type{<:Iterators.ProductIterator}, iters)
    product_ctor_pullback(dy) = (NoTangent(), dy.iterators)
    return Iterators.ProductIterator(iters), product_ctor_pullback
end

#=
Yota.grad(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3)
Diffractor.gradient(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3)
Yota.grad((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) # ERROR: all field arrays must have same shape
Diffractor.gradient((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) # ERROR: type Array has no field iterators
Yota.grad(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3)
Diffractor.gradient(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3) # fails internally
Yota.grad(xs -> sum(abs, [sin(x/y) for (x,y) in zip(xs, 1:2)]), [1,2,3]pi/3)
Diffractor.gradient(xs -> sum(abs, [sin(x/y) for (x,y) in zip(xs, 1:2)]), [1,2,3]pi/3)
Yota.grad(xs -> sum(abs, map((x,y) -> sin(x/y), xs, 1:2)), [1,2,3]pi/3)
Diffractor.gradient(xs -> sum(abs, map((x,y) -> sin(x/y), xs, 1:2)), [1,2,3]pi/3)
@btime Yota.grad($(rand(1000))) do xs
sum(abs2, [sqrt(x) for x in xs])
end
# Yota min 759.000 μs, mean 800.754 μs (22041 allocations, 549.62 KiB)
# Diffractor min 559.000 μs, mean 622.464 μs (18051 allocations, 612.34 KiB)
# Zygote min 3.198 μs, mean 6.849 μs (20 allocations, 40.11 KiB)
@btime Yota.grad($(rand(1000)), $(rand(1000))) do xs, ys
zs = map(xs, ys) do x, y
atan(x/y)
end
sum(abs2, zs)
end
# Yota + CR: min 1.598 ms, mean 1.691 ms (38030 allocations, 978.75 KiB)
# Diffractor + CR: min 767.250 μs, mean 847.640 μs (26045 allocations, 838.66 KiB)
# Zygote: min 13.417 μs, mean 22.896 μs (26 allocations, 79.59 KiB) -- 100x faster
=#


#####
##### `zip`
#####


function rrule(::typeof(zip), xs::AbstractArray...)
function zip_pullback(dy)
@debug "zip array pullback" summary(dy)
Expand All @@ -94,8 +47,6 @@ function _unmap_pad(x::AbstractArray, dx::AbstractArray)
@debug "_unmap_pad is extending gradient" length(x) == length(dx)
i1 = firstindex(x)
∇getindex(x, vec(dx), i1:i1+length(dx)-1)
# dx2 = vcat(vec(dx), similar(x, ZeroTangent, length(x) - length(dx)))
# ProjectTo(x)(reshape(dx2, axes(x)))
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ BT1 = Broadcast.BroadcastStyle(Tuple)

@testset "split 2: derivatives" begin
test_rrule(copybroadcasted, BS1, log, rand(3) .+ 1)
test_rrule(copybroadcasted, BT1, log, Tuple(rand(3) .+ 1))
test_rrule(copybroadcasted, BT1, log, Tuple(rand(3) .+ 1), check_inferred=false) # return type Tuple{NoTangent, NoTangent, NoTangent, Tangent{Tuple{Float64, Float64, Float64}, Tuple{Float64, Float64, Float64}}} does not match inferred return type Tuple{NoTangent, NoTangent, NoTangent, Union{NoTangent, Tangent{Tuple{Float64, Float64, Float64}, Tuple{Float64, Float64, Float64}}}}

# Two args uses StructArrays
test_rrule(copybroadcasted, BS1, atan, rand(3), rand(3))
Expand Down
4 changes: 2 additions & 2 deletions test/rulesets/Base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

y2, bk2 = rrule(CFG, collect, Iterators.map(Counter(), [11, 12, 13.0]))
@test y2 == map(Counter(), 11:13)
@test bk2(ones(3))[2].iter == [93, 83, 73]
@test bk2(ones(3))[2].iter == [33, 23, 13]
end
end

Expand All @@ -23,4 +23,4 @@ end
test_rrule(collectzip, rand(3), rand(5))
test_rrule(collectzip, rand(3,2), rand(5))
end
end
end

0 comments on commit b88d70f

Please sign in to comment.