NNlib Support #171
Comments
Can we re-use the |
In short: yes. We don't have a nice macro for doing it at the minute, but it ought to be quite doable provided that we impose some of our own constraints on the types. In the case of dropout: this one actually might be a little tricky because it uses kwargs, and I've not looked at supporting kwargs yet (this should largely happen automatically), but it should otherwise be fine. Something like batched_transpose should be quite straightforward. You would just do something along the lines of

    function rrule!!(::CoDual{typeof(batched_transpose)}, A::CoDual{<:Array{<:Any,3}})
        # Re-use the existing ChainRules rule on the primal value.
        B, pb = rrule(batched_transpose, primal(A))
        function pb!!(dB)
            _, dA_inc = pb(dB)
            # Accumulate the cotangent into the tangent carried by A.
            increment!!(tangent(A), dA_inc)
            return NoRData(), NoRData()
        end
        return zero_fcodual(B), pb!!
    end

Note that I've restricted the array type here -- we would just need to be careful with how we extend this rule to other array types, as it's only valid in Tapir.jl for arrays whose tangents are arrays, rather than |
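The adapter pattern in the sketch above can be tried out without Tapir.jl at all. The following is a minimal, dependency-free illustration, assuming toy stand-ins throughout: `ToyCoDual`, `batched_transpose_toy`, `toy_rrule` and `toy_rrule!!` are hypothetical names invented for this example and are not part of Tapir.jl, ChainRules or NNlib. It only shows the shape of the idea: call an out-of-place rule, then accumulate its result into the tangent stored alongside the primal.

```julia
# Toy stand-in (NOT a Tapir.jl type): a primal value paired with a tangent buffer.
struct ToyCoDual{P,T}
    primal::P
    tangent::T
end

# Hypothetical eager version of a batched transpose: swap the first two dims.
batched_transpose_toy(A::Array{<:Any,3}) = permutedims(A, (2, 1, 3))

# Hypothetical ChainRules-style rule: returns the output and an out-of-place pullback.
function toy_rrule(::typeof(batched_transpose_toy), A::Array{<:Any,3})
    pb(dB) = (nothing, permutedims(dB, (2, 1, 3)))
    return batched_transpose_toy(A), pb
end

# Adapter: re-use the out-of-place rule, then increment the argument's tangent in place.
function toy_rrule!!(f::typeof(batched_transpose_toy), A::ToyCoDual{<:Array{<:Any,3}})
    B, pb = toy_rrule(f, A.primal)
    function pb!!(dB)
        _, dA_inc = pb(dB)
        A.tangent .+= dA_inc   # mutates the existing buffer, does not rebind the field
        return nothing
    end
    return ToyCoDual(B, zero(B)), pb!!
end

A = ToyCoDual(rand(2, 3, 4), zeros(2, 3, 4))
B, pb!! = toy_rrule!!(batched_transpose_toy, A)
pb!!(ones(3, 2, 4))   # afterwards A.tangent holds the accumulated cotangent
```

The restriction to `Array{<:Any,3}` mirrors the one in the comment above: the in-place increment only makes sense when the tangent of the argument is itself a plain array.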
Thanks @willtebbutt. A follow-up question on this topic: if we put mutation aside, is it true in general, |
In the narrow sense that lots of However, recall that this gives you an overly optimistic sense of how much more general a |
Was the current |
A bit of both. It's motivated by trying to write rules for things like Zygote, where you really want your rule system to do a lot of work, because Zygote can't differentiate much itself due to lack of mutation support. It's an oversight in the sense that we didn't realise how hard it is to write robust rules which work for lots of types. The upshot is the kinds of issues we've discussed previously (mainly around poor composition performance).
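One way to see why rules over abstract types are hard to make robust is a rule written once for every `AbstractMatrix`. The snippet below is an illustrative assumption, not an example taken from this thread or from ChainRules: `toy_sum_rrule` is a hypothetical name, and the only library used is the `LinearAlgebra` standard library.

```julia
using LinearAlgebra

# Hypothetical rule for sum, written once for all AbstractMatrix inputs.
function toy_sum_rrule(x::AbstractMatrix)
    pb(dy) = fill(dy, size(x))   # always produces a dense Matrix
    return sum(x), pb
end

D = Diagonal([1.0, 2.0, 3.0])
y, pb = toy_sum_rrule(D)
dD = pb(1.0)

# The primal is structured, but the cotangent the rule returns is dense and
# has non-zero off-diagonal entries that do not correspond to any stored
# degree of freedom of D.
println(typeof(dD))
println(dD[1, 2])
```

The rule is reasonable for a dense `Matrix`, yet for a `Diagonal` input it silently returns a cotangent of a different type and structure, which downstream rules then have to cope with.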
But Or does writing rules with concrete argument types always require something like typed tangents? IIUC, the typed tangents in Tapir seem like an extension of |
Correct. Equally, Tapir.jl doesn't prevent you from writing rules with abstract argument types, it's just not generally a good idea.
I'm not sure what you mean by "typed tangents" -- could you expand?
I am referring to the fact that, in Tapir, each primal type has a unique tangent type. Do writing |
Hmm okay. Could you provide a small code example showing what you mean? I'm still not quite following what you're trying to say / ask, so I think I need some examples.
This is partially addressed by #254, but there remains work to be done on GPU integration, and generalising the functionality added to |
With the merging of #435, I'm closing this issue. If we discover more NNlib rules that need incorporating later, I'll re-open. Note that closing this doesn't mean anyone should expect Mooncake to work on Flux.jl or Lux.jl at this point in time -- more GPU rules are required, just hopefully none from NNlib.jl.
#169 highlights that we need rules for a range of functionality which lives in NNlib.jl -- this is not surprising, and has largely already been done (see e.g. the Enzyme extension). Someone needs to systematically work through and test that Tapir works on everything in NNlib.