diff --git a/src/buffer.jl b/src/buffer.jl index 488b7a6..5ce8aeb 100644 --- a/src/buffer.jl +++ b/src/buffer.jl @@ -87,8 +87,8 @@ end grad_mut(b::Buffer) = fill!(similar(b.data, Any), nothing) grad_mut(b::Buffer{T}) where T<:Number = fill!(similar(b.data, float(T)), 0) -function ZygoteRules._pullback(cx::AContext, ::typeof(Buffer), args...) - Buffer(args...), _ -> nothing +function ZygoteRules._pullback(cx::AContext, ::Type{T}, args...) where {T<:Buffer} + T(args...), _ -> nothing end @adjoint function getindex(b::Buffer, i...)