diff --git a/src/utils.jl b/src/utils.jl index 9786c54..faa68ac 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -66,12 +66,17 @@ function ChainRulesCore.rrule(config::RuleConfig, pf::PrefixedFunction, args...) return y, PrefixedFunctionPullback(back, num_input, num_f_args) end -struct SkipFirstArg{F} <: Function - f::F +# https://github.com/FluxML/NNlib.jl/blob/7369244c1a64317eef5b0a20c142316947a85bb3/src/utils.jl#L131-L141 +function _fast_broadcast2!(f::F, dst::Array, x, yz...) where {F<:Function} + bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, yz...)) + @simd ivdep for I in eachindex(bc) + @inbounds x[I] = bc[I] + end + return x +end +function _fast_broadcast2!(f::F, dst::AbstractArray, x, yz...) where {F<:Function} + return broadcast!(f, dst, x, yz...) end -@inline (_f::SkipFirstArg)(dst, xs...) = _f.f(xs...) - using NNlib: _fast_broadcast! @inline _fast_broadcast(f, x, yz...) = _fast_broadcast!(f, copy(x), yz...) @inline _fast_broadcast2(f, x, yz...) = _fast_broadcast2!(f, similar(x), x, yz...) -@inline _fast_broadcast2!(f, dst, x, yz...) = _fast_broadcast!(SkipFirstArg(f), dst, x, yz...)