Stateful Hand-Written Rules #403

Open
willtebbutt opened this issue Dec 1, 2024 · 4 comments
Labels: enhancement (performance), high priority

Comments

@willtebbutt
Member

willtebbutt commented Dec 1, 2024

At present, if an rrule!! needs to allocate storage for some value that it overwrites on the forwards pass, it must do so each time the rule is called. This is distinct from derived rules, whose state is preserved between calls.

There's no particular reason for things to be this way. This kind of functionality would be helpful, because it would make it possible to trade off some memory in exchange for runtime performance. Take, for example, BLAS.gemm!(transA, transB, alpha, A, B, beta, C) -- this rule has to allocate memory to hold whatever value C has upon entry, in order to restore it on the reverse-pass. Currently, it must allocate this memory each time the rule is called, but if we permit rules to be stateful we can just use the same heuristic that we use for derived rules and avoid deallocating this memory -- subsequent calls to the rule would then be fast.

This change could probably be made non-breaking (rrule!!s would remain unchanged). We would just add a function build_primitive_rrule or something, which returns rrule!! by default, but which one can override to return e.g. a custom callable struct instead.
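For concreteness, a minimal sketch of what that default might look like (the method below is purely illustrative, not an existing Mooncake API):

import Mooncake

# Hypothetical default: just hand back the existing rrule!! function. A rule author who
# wants state would add a method for their signature that returns a callable struct instead.
build_primitive_rrule(::Type{<:Tuple}) = Mooncake.rrule!!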

I first realised that this is probably going to be required while addressing #394. The primal is type-stable and non-allocating, but Mooncake performs a great deal of allocations. Now that a couple of performance bugs have been fixed (the generated IR was not type-stable due to these bugs), the removal of the remaining allocations will require solving this issue.

willtebbutt added the enhancement (performance) label on Dec 1, 2024
@RoyCCWang
Contributor

While I was going through The Rule Interface Round 2 of the current documentation, I wrote the following code, which should be equivalent to the rrule!! example in that section of the documentation; i.e., rrule_mod!! below should behave the same as the rrule!! from the example.

using LinearAlgebra
import Mooncake as MN

T = Float64

# This is from the example, the primal function.
function eval_model(p::Tuple{T, Vector{T}}) where T <: AbstractFloat
    a, b = p[1], p[2]
    return a + sum(b)
end

# This is from the example.
function rrule!!(
    ::MN.CoDual{typeof(eval_model)},
    x::MN.CoDual{
        Tuple{T, Vector{T}},
    },
    ) where T <: AbstractFloat

    dx_fdata = x.dx
    function df_adjoint(dy::T)

        dx_fdata[2] .+= dy
        dx_1_rdata = dy
        dx_rdata = (dx_1_rdata, MN.NoRData())

        return MN.NoRData(), dx_rdata
    end

    x_p = x.x
    return MN.CoDual(x_p[1] + sum(x_p[2]), MN.NoFData()), df_adjoint
end

# My alternatives that I used while exploring the documentation.
struct ReversePassCallable{T}
    aux_state::T # mutated whenever the callable is invoked.
end
function (A!::ReversePassCallable)(model_output::AbstractFloat)

    output_state = A!.aux_state[end]
    for i in eachindex(output_state)
        output_state[i] += model_output
    end
    return MN.NoRData(), (model_output, MN.NoRData())
end

function rrule_mod!!(
    ::MN.CoDual{typeof(eval_model)},
    state::MN.CoDual{
        Tuple{T, Vector{T}},
    },
    ) where T <: AbstractFloat

    run_reverse_pass = ReversePassCallable(get_aux(state))
    model_eval = eval_model(get_primal_inputs(state))
    return MN.CoDual(model_eval, MN.NoFData()), run_reverse_pass
end

function get_aux(A::MN.CoDual)
    return A.dx
end

function get_primal_inputs(A::MN.CoDual)
    return A.x
end

D = 2
a = T(5)
b = [T(1); T(2)]
y0 = zeros(D)

out, pb!! = rrule!!(
    MN.CoDual(
        eval_model,
        MN.NoFData(),
    ),
    MN.CoDual(
        (a, b),
        (MN.NoFData(), y0),
    ),
)

out_mod, pb_mod!! = rrule_mod!!(
    MN.CoDual(
        eval_model,
        MN.NoFData(),
    ),
    MN.CoDual(
        (a, b),
        (MN.NoFData(), y0),
    ),
)

Sanity-check: my alternative seems to match the example. My REPL showed:

julia> out
Mooncake.CoDual{Float64, Mooncake.NoFData}(8.0, Mooncake.NoFData())

julia> out_mod
Mooncake.CoDual{Float64, Mooncake.NoFData}(8.0, Mooncake.NoFData())

julia> pb!!(one(T))
(Mooncake.NoRData(), (1.0, Mooncake.NoRData()))

julia> pb_mod!!(one(T))
(Mooncake.NoRData(), (1.0, Mooncake.NoRData()))

Did you mean something like ReversePassCallable as a first possible approach to the automatic generation of stateful rrule!! callables?

@willtebbutt
Member Author

willtebbutt commented Dec 2, 2024

This is close to what I had in mind, but I don't think it quite captures the problem, because eval_model does not appear to mutate its inputs.

Here's a concrete example which involves overwriting memory, thus necessitating restoring it afterwards. Consider writing a rule for the signature

Tuple{typeof(mul!), Matrix{P}, Matrix{P}, Matrix{P}} where {P<:IEEEFloat}

whose semantics are (as usual) to overwrite the first matrix argument with the result of multiplying the second and third matrices.

The way to do it using rrule!! would be something like

using BenchmarkTools
using Mooncake
using Mooncake: NoRData, CoDual, zero_fcodual
using Base: IEEEFloat
using LinearAlgebra: mul!

function Mooncake.rrule!!(
    ::CoDual{typeof(mul!)}, C::CoDual{Matrix{P}}, A::CoDual{Matrix{P}}, B::CoDual{Matrix{P}}
) where {P<:IEEEFloat}

    # Make a copy of `C` and its adjoint. This is where allocations are introduced.
    C_copy = copy(C.x)
    dC_copy = copy(C.dx)

    # Run the forwards-pass.
    mul!(C.x, A.x, B.x)
    C.dx .= zero(P)

    function pb!!(::NoRData)
        # Do the computations needed to increment tangents of A and B.
        # code to increment A.dx
        # code to increment B.dx

        # Reset value of `C`.
        copy!(C.x, C_copy)
        copy!(C.dx, dC_copy)

        return NoRData(), NoRData(), NoRData(), NoRData()
    end
    return C, pb!!
end

A stateful version of this might be something along the lines of

# mutable struct which might be uninitialised. We would probably insist that this struct be
# used by all stateful rules.
mutable struct StatefulRRule{Tstate}
    state::Tstate
    StatefulRRule{Tstate}() where {Tstate} = new{Tstate}()
end

# Create a mutable struct with uninitialised state. It should be possible to add a macro, in
# the same vein as `@is_primitive`, to make this function simpler to add methods to.
function build_primitive_rrule(
    ::Type{<:Tuple{typeof(mul!), Matrix{P}, Matrix{P}, Matrix{P}}}
) where {P<:IEEEFloat}
    return StatefulRRule{Tuple{Matrix{P}, Matrix{P}}}()
end

# Note: the only difference between this signature and the one for the `rrule!!` above is the function itself.
function (rule::StatefulRRule)(
    ::CoDual{typeof(mul!)}, C::CoDual{Matrix{P}}, A::CoDual{Matrix{P}}, B::CoDual{Matrix{P}}
) where {P<:IEEEFloat}

    # If we don't already have some state allocated, allocate some. After this, the
    # remainder of the function is identical to the rrule!!.
    if !isdefined(rule, :state)
        rule.state = (copy(C.x), copy(C.dx))
    end

    # We can be sure that we have state in the rule at this point, so just make use of it.
    C_copy = rule.state[1]
    dC_copy = rule.state[2]
    copy!(rule.state[1], C.x)
    copy!(rule.state[2], C.dx)

    # Run the forwards-pass.
    mul!(C.x, A.x, B.x)
    C.dx .= zero(P)

    # The pullback can close over the `StatefulRule`.
    function pb!!(::NoRData)
        # Do the computations needed to increment tangents of A and B.
        # code to increment A.dx
        # code to increment B.dx

        # Reset value of `C`.
        copy!(C.x, C_copy)
        copy!(C.dx, dC_copy)

        return NoRData(), NoRData(), NoRData(), NoRData()
    end
    return C, pb!!
end

sig = Tuple{typeof(mul!), Matrix{Float64}, Matrix{Float64}, Matrix{Float64}};
stateful_rule = build_primitive_rrule(sig);

Observe that the first time a stateful rule is called it allocates the memory it needs; on subsequent calls, it simply reuses that memory. As a result, we get the following timing results:

julia> C, A, B = randn(16, 8), randn(16, 32), randn(32, 8);

julia> @btime mul!($C, $A, $B);
  289.234 ns (0 allocations: 0 bytes)

julia> _C, _A, _B = zero_fcodual(C), zero_fcodual(A), zero_fcodual(B);

julia> @btime Mooncake.rrule!!(zero_fcodual(mul!), $_C, $_A, $_B)[2](NoRData());
  395.835 ns (4 allocations: 2.22 KiB)

julia> @btime ($stateful_rule)(zero_fcodual(mul!), $_C, $_A, $_B)[2](NoRData());
  358.533 ns (0 allocations: 0 bytes)

Note that the implementation of the StatefulRRule for sig above is incomplete, because it cannot handle

  1. repeated calls in the forwards pass -- we would need to make use of a Mooncake.Stack of Tuples in the state, not just a single Tuple (see the rough sketch below), and
  2. changes in the size of the matrices passed in -- if the input dimensions change from call to call, the cache size will need to change.

These points are just to say that some thought is required regarding the caching structure, but hopefully they don't detract from the larger point: introducing a StatefulRRule type should enable us to (with some careful design) eliminate allocations in the forwards pass in exchange for increased peak memory usage.
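To make point 1 concrete, here is a rough sketch of what a stack-based cache might look like, using a plain Vector as a stand-in for Mooncake.Stack and omitting the tangent computations for A and B, as in the examples above (all names here are illustrative):

using LinearAlgebra: mul!
using Base: IEEEFloat
using Mooncake: CoDual, NoRData

# Sketch only: each forwards call pushes a copy of `C` and its tangent; each pullback call
# pops the matching copy, so repeated calls restore the correct values in LIFO order.
struct StackedRRule{P}
    cache::Vector{Tuple{Matrix{P}, Matrix{P}}}
end
StackedRRule{P}() where {P} = StackedRRule{P}(Tuple{Matrix{P}, Matrix{P}}[])

function (rule::StackedRRule{P})(
    ::CoDual{typeof(mul!)}, C::CoDual{Matrix{P}}, A::CoDual{Matrix{P}}, B::CoDual{Matrix{P}}
) where {P<:IEEEFloat}

    # Cache the current value of `C` and its tangent -- one entry per forwards call.
    push!(rule.cache, (copy(C.x), copy(C.dx)))

    # Run the forwards-pass.
    mul!(C.x, A.x, B.x)
    C.dx .= zero(P)

    function pb!!(::NoRData)
        # code to increment A.dx and B.dx goes here, as in the examples above.

        # Pop the copy made by the matching forwards call and restore `C`.
        C_copy, dC_copy = pop!(rule.cache)
        copy!(C.x, C_copy)
        copy!(C.dx, dC_copy)

        return NoRData(), NoRData(), NoRData(), NoRData()
    end
    return C, pb!!
end

As written, this still copies (and hence allocates) on every forwards call; recovering the allocation-free behaviour shown in the benchmark above would additionally require retaining and reusing the stack's buffers between sweeps, which is exactly the kind of caching-structure design these points are gesturing at.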

@willtebbutt
Member Author

Note that another benefit of permitting stateful rules would be that we could properly implement ChainRules.jl's call-back-into-AD mechanism. Currently this isn't really possible, because our AD requires carrying some state around in order to get optimal performance, but our rrule!!s do not admit state.
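Purely as an illustration of where such state could live (the type below does not exist in Mooncake), the config object that ChainRulesCore threads through rrule_via_ad would be a natural home for it:

using ChainRulesCore: RuleConfig, HasReverseMode

# Hypothetical config type: `cache` could map signatures to previously-built (stateful)
# rules, so that repeated calls back into Mooncake via rrule_via_ad can reuse memory.
struct MooncakeRuleConfig{Tcache} <: RuleConfig{Union{HasReverseMode}}
    cache::Tcache
end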

@RoyCCWang
Contributor

Thanks for the explanation! Please post additional tips or ideas here, because I'm interested in exploring this issue further after my documentation PR.
