Stateful Hand-Written Rules #403
Comments
While I was going through The Rule Interface Round 2 of the current documentation, I wrote the following code that should be equivalent to the coded
Sanity-check: my alternative seems to match the example. My REPL showed:
Did you mean something like the |
This is close to what I had in mind, but I don't think it quite captures the problem, because Here's a concrete example which involves over-writing memory, thus necessitating restoring it afterwards. Consider writing a rule for signature `Tuple{typeof(mul!), Matrix{P}, Matrix{P}, Matrix{P}} where {P<:IEEEFloat}`, whose semantics are (as usual) to overwrite the first matrix argument with the result of multiplying the second and third matrices. The way to do it using `rrule!!` is:

using BenchmarkTools
using Mooncake
using Mooncake: NoRData, CoDual, zero_fcodual
using Base: IEEEFloat
using LinearAlgebra: mul!
function Mooncake.rrule!!(
    ::CoDual{typeof(mul!)}, C::CoDual{Matrix{P}}, A::CoDual{Matrix{P}}, B::CoDual{Matrix{P}}
) where {P<:IEEEFloat}
    # Make a copy of `C` and its adjoint. This is where allocations are introduced.
    C_copy = copy(C.x)
    dC_copy = copy(C.dx)
    # Run the forwards-pass.
    mul!(C.x, A.x, B.x)
    C.dx .= zero(P)
    function pb!!(::NoRData)
        # Do the computations needed to increment tangents of A and B.
        # code to increment A.dx
        # code to increment B.dx
        # Reset value of `C`.
        copy!(C.x, C_copy)
        copy!(C.dx, dC_copy)
        return NoRData(), NoRData(), NoRData(), NoRData()
    end
    return C, pb!!
end

A stateful version of this might be something along the lines of:

# mutable struct which might be uninitialised. We would probably insist that this struct be
# used by all stateful rules.
mutable struct StatefulRRule{Tstate}
    state::Tstate
    StatefulRRule{Tstate}() where {Tstate} = new{Tstate}()
end
# Create a mutable struct with uninitialised state. It should be possible to add a macro, in
# the same vein as `@is_primitive`, to make this function simpler to add methods to.
function build_primitive_rrule(
    ::Type{<:Tuple{typeof(mul!), Matrix{P}, Matrix{P}, Matrix{P}}}
) where {P<:IEEEFloat}
    return StatefulRRule{Tuple{Matrix{P}, Matrix{P}}}()
end
# Note: the only difference between this signature and the one for the `rrule!!` above is the function itself.
function (rule::StatefulRRule)(
    ::CoDual{typeof(mul!)}, C::CoDual{Matrix{P}}, A::CoDual{Matrix{P}}, B::CoDual{Matrix{P}}
) where {P<:IEEEFloat}
    # If we don't already have some state allocated, allocate some. After this, the
    # remainder of the function is identical to the rrule!!.
    if !isdefined(rule, :state)
        rule.state = (copy(C.x), copy(C.dx))
    end
    # We can be sure that we have state in the rule at this point, so just make use of it.
    C_copy = rule.state[1]
    dC_copy = rule.state[2]
    copy!(rule.state[1], C.x)
    copy!(rule.state[2], C.dx)
    # Run the forwards-pass.
    mul!(C.x, A.x, B.x)
    C.dx .= zero(P)
    # The pullback can close over the `StatefulRRule`.
    function pb!!(::NoRData)
        # Do the computations needed to increment tangents of A and B.
        # code to increment A.dx
        # code to increment B.dx
        # Reset value of `C`.
        copy!(C.x, C_copy)
        copy!(C.dx, dC_copy)
        return NoRData(), NoRData(), NoRData(), NoRData()
    end
    return C, pb!!
end
sig = Tuple{typeof(mul!), Matrix{Float64}, Matrix{Float64}, Matrix{Float64}};
stateful_rule = build_primitive_rrule(sig);

Observe that the first time that a stateful rule is called it allocates the memory needed. On subsequent visits, it will just reuse the memory. As a result, we get the following timing results:

julia> C, A, B = randn(16, 8), randn(16, 32), randn(32, 8);
julia> @btime mul!($C, $A, $B);
  289.234 ns (0 allocations: 0 bytes)
julia> _C, _A, _B = zero_fcodual(C), zero_fcodual(A), zero_fcodual(B);
julia> @btime Mooncake.rrule!!(zero_fcodual(mul!), $_C, $_A, $_B)[2](NoRData());
  395.835 ns (4 allocations: 2.22 KiB)
julia> @btime ($stateful_rule)(zero_fcodual(mul!), $_C, $_A, $_B)[2](NoRData());
  358.533 ns (0 allocations: 0 bytes)

Note that the implementation of the
These points are just to say that some thought is required regarding the caching structure, but hopefully they don't detract from the larger point: introducing a |
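The allocate-once pattern from the stateful rule above can be tried out without Mooncake. The following is only a minimal sketch under stated assumptions: `CachedMulRule` is a hypothetical name, tangents are passed as explicit arguments instead of being wrapped in `CoDual`, and the re-allocation when the shape of `C` changes is one guess at how a cache might cope with differently sized inputs, not something settled in this thread. The two tangent increments are the standard ones for a matrix product `C = A * B`.

```julia
using LinearAlgebra: mul!

# Hypothetical stand-in for a stateful rule. Both buffers are left undefined
# until the first call, mirroring the `isdefined` check used in the thread.
mutable struct CachedMulRule{T}
    C_saved::Matrix{T}
    dC_saved::Matrix{T}
    CachedMulRule{T}() where {T} = new{T}()
end

function (rule::CachedMulRule{T})(
    C::Matrix{T}, dC::Matrix{T}, A::Matrix{T}, dA::Matrix{T}, B::Matrix{T}, dB::Matrix{T}
) where {T}
    # Allocate on first use, or again if the shape of `C` has changed (an
    # assumption about how a cache might handle differently sized inputs).
    if !isdefined(rule, :C_saved) || size(rule.C_saved) != size(C)
        rule.C_saved = copy(C)
        rule.dC_saved = copy(dC)
    else
        copyto!(rule.C_saved, C)
        copyto!(rule.dC_saved, dC)
    end

    # Forwards pass: overwrite `C` and zero its tangent.
    mul!(C, A, B)
    fill!(dC, zero(T))

    function pullback!()
        mul!(dA, dC, B', true, true)  # dA += dC * B'
        mul!(dB, A', dC, true, true)  # dB += A' * dC
        copyto!(C, rule.C_saved)      # restore the primal value of `C`
        copyto!(dC, rule.dC_saved)    # restore the tangent of `C`
        return nothing
    end
    return C, pullback!
end

# Usage: run the rule twice and keep hold of the buffer created by the first call.
A, B, C = randn(4, 3), randn(3, 2), randn(4, 2)
dA, dB, dC = zero(A), zero(B), zero(C)
C_on_entry = copy(C)
rule = CachedMulRule{Float64}()

_, pullback! = rule(C, dC, A, dA, B, dB)
product = copy(C)        # value of `C` after the forwards pass
dC .= 1                  # pretend downstream code produced this cotangent
pullback!()

buffer_after_first_call = rule.C_saved
rule(C, dC, A, dA, B, dB)[2]()   # second call copies into the existing buffers
```

After the second call, `rule.C_saved` is still the array allocated by the first call, which is the behaviour the timings above rely on.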
Note that another benefit of permitting stateful rules would be that we could properly implement ChainRules.jl's call-back-into-AD mechanism. Currently it's not really possible to do, because our AD requires carrying some state around in order to get optimal performance, but our |
Thanks for the explanation! Please post additional tips or ideas here because I'm interested to explore this issue more after my documentation PR. |
At present, if an `rrule!!` needs to allocate storage for some value it overwrites on the forwards pass, it must do so each time the rule is called. This is distinct from derived rules, which have state which is preserved between calls.

There's no particular reason for things to be this way. This kind of functionality would be helpful, because it would make it possible to trade off some memory in exchange for runtime performance. Take, for example, `BLAS.gemm!(transA, transB, alpha, A, B, beta, C)` -- this rule has to allocate memory to hold whatever value `C` has upon entry, in order to restore it on the reverse-pass. Currently, it must allocate this memory each time the rule is called, but if we permit rules to be stateful we can just use the same heuristic that we use for the derived rules, and avoid de-allocating this memory -- subsequent calls to a rule would be fast.

This change could probably be made non-breaking (`rrule!!`s would remain unchanged). We would just add a function `build_primitive_rrule` or something, which returns `rrule!!` by default, but which one can over-ride to return e.g. a custom callable struct instead.

I first realised that this is probably going to be required while addressing #394. The primal is type-stable and non-allocating, but Mooncake has a great deal of allocations. Now that a couple of performance bugs have been fixed (the generated IR was not type stable due to these bugs), the removal of the remaining allocations will require solving this issue.
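The non-breaking dispatch described above might look roughly like the following. This is a sketch in plain Julia, not Mooncake's API: `build_primitive_rrule` is the name floated in this issue, while `stateless_rule` (playing the role of `rrule!!`) and `CallCounter` (playing the role of a custom callable struct) are hypothetical stand-ins. The counter is only there to make it visible that state survives between calls.

```julia
# Hypothetical stand-in for `rrule!!`: returns the primal value and a no-op pullback.
stateless_rule(f, args...) = (f(args...), () -> nothing)

# Default method: every signature gets the stateless rule, so existing rules
# would keep working unchanged.
build_primitive_rrule(::Type{<:Tuple}) = stateless_rule

# Hypothetical callable struct that carries state from one call to the next.
mutable struct CallCounter
    calls::Int
end

function (rule::CallCounter)(f, args...)
    rule.calls += 1
    return f(args...), () -> nothing
end

# Override for one specific signature; all other signatures still hit the default.
build_primitive_rrule(::Type{<:Tuple{typeof(sum), Vector{Float64}}}) = CallCounter(0)

# Usage
default_rule = build_primitive_rrule(Tuple{typeof(sin), Float64})
custom_rule = build_primitive_rrule(Tuple{typeof(sum), Vector{Float64}})
first_result, _ = custom_rule(sum, [1.0, 2.0])
custom_rule(sum, [3.0])
```

Because the override is just a more specific method, signatures without one are unaffected, which is what would make such a change non-breaking.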