Skip to content

Commit

Permalink
improve grad perf
Browse files Browse the repository at this point in the history
  • Loading branch information
chengchingwen committed Jun 2, 2024
1 parent c6c0310 commit 587913e
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/functional/dropout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ function ChainRulesCore.rrule(config::RuleConfig, ::typeof(dropout), rng::Abstra
m = GetIndexer(IndexerAdaptor(rng), RandomMask(p), _dropout_masksize(x, dims), scale)
function dropout_pullback(Ybar)
Ȳ = unthunk(Ybar)
thk = @thunk _fast_broadcast(*, Ȳ, m)
thk = @thunk _fast_broadcast2!(*, similar(x), Ȳ, m)
return (NoTangent(), NoTangent(), thk, NoTangent(), NoTangent())
end
return _fast_broadcast(*, x, m), dropout_pullback
Expand Down
2 changes: 1 addition & 1 deletion src/mask/dataless.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ adapt_structure(to::IndexerAdaptor, x::RandomMask) = RandomMask(x.p, adapt(to, @

Base.@propagate_inbounds maskgetindex(::Dims, m::RandomMask{Nothing}, _::Integer...) = rand(Float32) >= m.p
Base.@propagate_inbounds function maskgetindex(destsize::Dims, m::RandomMask, I::Integer...)
s = xor(((one(UInt32), UInt32.(Base.tail(destsize))...) .* UInt32.(reverse(I)))...) + one(UInt32)
s = +((shape2stride(unsafe_trunc.(UInt32, destsize)) .* unsafe_trunc.(UInt32, I))...)
v, rng = prand(Float32, setpos(m.rng, s), s)
return v >= m.p
end
Expand Down
8 changes: 6 additions & 2 deletions src/mask/grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ ChainRulesCore.@non_differentiable lengths(m)

function ChainRulesCore.rrule(::typeof(apply_mask), ::NaiveMaskOp, mask, score)
m = GetIndexer(mask, size(score))
naive_apply_mask_pullback(Ȳ) = (NoTangent(), NoTangent(), NoTangent(), _fast_broadcast(*, unthunk(Ȳ), m))
function naive_apply_mask_pullback(Ybar)
Ȳ = unthunk(Ybar)
thk = @thunk _fast_broadcast2!(*, similar(score), Ȳ, m)
(NoTangent(), NoTangent(), NoTangent(), thk)
end
return _fast_broadcast(*, score, m), naive_apply_mask_pullback
end

Expand Down Expand Up @@ -58,7 +62,7 @@ function ChainRulesCore.rrule(config::RuleConfig, ::typeof(apply_broadcast_mask)
m = GetIndexer(mask, size(score), convert(eltype(score), scale))
function apply_broadcast_mask_pullback(Ybar)
Ȳ = unthunk(Ybar)
thk = @thunk _fast_broadcast(*, Ȳ, m)
thk = @thunk _fast_broadcast2!(*, similar(score), Ȳ, m)
return (NoTangent(), NoTangent(), NoTangent(), thk, NoTangent())
end
return _fast_broadcast(*, score, m), apply_broadcast_mask_pullback
Expand Down
4 changes: 2 additions & 2 deletions src/mask/indexer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ Base.size(I::Indexer) = I.destsize

@inline Base.@propagate_inbounds Base.getindex(m::Indexer{Bool}, I::Integer...) = __maskgetindex__(m.destsize, m.mask, I...)
@inline Base.@propagate_inbounds Base.getindex(m::Indexer{Bool}, I::Tuple) = __maskgetindex__(m.destsize, m.mask, I...)
@inline Base.@propagate_inbounds Base.getindex(m::Indexer, I::Integer...) = m.scale * __maskgetindex__(m.destsize, m.mask, I...)
@inline Base.@propagate_inbounds Base.getindex(m::Indexer, I::Tuple) = m.scale * __maskgetindex__(m.destsize, m.mask, I...)
@inline Base.@propagate_inbounds Base.getindex(m::Indexer, I::Integer...) = ifelse(__maskgetindex__(m.destsize, m.mask, I...), m.scale, zero(m.scale))
@inline Base.@propagate_inbounds Base.getindex(m::Indexer, I::Tuple) = ifelse(__maskgetindex__(m.destsize, m.mask, I...), m.scale, zero(m.scale))

using Adapt
import Adapt: adapt_structure
Expand Down

0 comments on commit 587913e

Please sign in to comment.