From 2d423c90ee187c9daf72c41f2120d42d79189c79 Mon Sep 17 00:00:00 2001
From: Ben Murrell
Date: Fri, 13 Dec 2024 11:02:29 +0100
Subject: [PATCH 1/7] Add eager updating

---
 src/lib/grad.jl | 36 ++++++++++++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/src/lib/grad.jl b/src/lib/grad.jl
index 6b9002f73..3581baa6c 100644
--- a/src/lib/grad.jl
+++ b/src/lib/grad.jl
@@ -27,6 +27,42 @@ function Zygote._pullback(ctx::Zygote.AContext, ::typeof(checkpointed), f, xs...
     return y, pullback_checkpointed
 end
 
+
+"""
+    eager_update(f, update, state, xs...)
+
+Allows training large models when the gradients cannot all fit in memory simultaneously.
+
+A combination of gradient checkpointing and eagerly updating the model parameters, discarding the gradients once they have been applied.
+Assumes that `f` is a callable struct, `state` is the optimization state (eg. from Optimisers.jl) matching `f`, and
+`update` is the function that updates the parameters of `f` from the state and the gradients, called as `update(state, f, grads)`.
+
+If eg. `model.layers[i]` is a layer in a transformer, then:
+
+```
+for i in 1:length(model.layers)
+    h = eager_update(model.layers[i], Optimisers.update!, opt_state.layers[i], h, other_args...)
+end
+```
+
+!!! warning
+    If different layers share trainable parameters, then `eager_update` will likely give wrong results.
+"""
+eager_update(f, update, state, xs...) = f(xs...)
+
+function Zygote._pullback(ctx::Zygote.AContext, ::typeof(eager_update), f, update, state, xs...)
+    y = f(xs...)
+    function pullback_eager_update(Δy)
+        y, pb = Zygote._pullback(ctx, f, xs...)
+        ret = pb(Δy)
+        update(state, f, ret[1])
+        return (nothing, nothing, nothing, nothing, ret[2:end]...)
+    end
+    return y, pullback_eager_update
+end
+
+
 """
     hessian(f, x)
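A minimal end-to-end sketch of how the API introduced in [PATCH 1/7] is meant to be driven (not part of the patch itself; Optimisers.jl is assumed, and the `Layer` struct and loss are hypothetical stand-ins for a real model):

```julia
using Zygote, Optimisers, Functors

# Hypothetical single-layer "model": a callable struct holding its parameters.
struct Layer
    W::Matrix{Float64}
end
Functors.@functor Layer
(l::Layer)(h) = tanh.(l.W * h)

layer = Layer(randn(4, 4))
opt_state = Optimisers.setup(Optimisers.Descent(0.1), layer)

# The custom pullback recomputes the forward pass (as in `checkpointed`),
# applies `Optimisers.update!` to `layer`, and discards the parameter
# gradient rather than returning it.
loss(h) = sum(abs2, Zygote.eager_update(layer, Optimisers.update!, opt_state, h))
Zygote.gradient(loss, randn(4))  # the backward pass updates `layer` in place
```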
From e8748360fe6a1209a30127609239bc05cf77d734 Mon Sep 17 00:00:00 2001
From: Ben Murrell
Date: Sat, 14 Dec 2024 02:10:10 +0100
Subject: [PATCH 2/7] Making eager_update! work not just for callable structs

---
 src/lib/grad.jl | 44 ++++++++++++++++++++++++--------------------
 1 file changed, 24 insertions(+), 20 deletions(-)

diff --git a/src/lib/grad.jl b/src/lib/grad.jl
index 3581baa6c..8dd54dfaf 100644
--- a/src/lib/grad.jl
+++ b/src/lib/grad.jl
@@ -28,40 +28,44 @@ function Zygote._pullback(ctx::Zygote.AContext, ::typeof(checkpointed), f, xs...
 end
 
 
 """
-    eager_update(f, update, state, xs...)
+    eager_update!((modelf, xs...), (update!, state)) = modelf(xs...) and update!(state, modelf, grads)
+    eager_update!(f, (model, xs...), (update!, state)) = f(model, xs...) and update!(state, model, grads)
 
-Allows training large models when the gradients cannot all fit in memory simultaneously.
-
 A combination of gradient checkpointing and eagerly updating the model parameters, discarding the gradients once they have been applied.
-Assumes that `f` is a callable struct, `state` is the optimization state (eg. from Optimisers.jl) matching `f`, and
-`update` is the function that updates the parameters of `f` from the state and the gradients, called as `update(state, f, grads)`.
+Works when `modelf` is a callable struct that also stores the parameters to be updated, or if you have a function `f`
+that takes the `model` as the first argument. `state` is the optimization state (eg. from Optimisers.jl) matching your model, and
+`update!` is the function that updates the parameters of `modelf`/`model` from the state and the gradients, called as `update!(state, model, grads)`.
 
-If eg. `model.layers[i]` is a layer in a transformer, then:
+If eg. `model.layers[i]` is a layer in a transformer, and is callable as `model.layers[i](h, other_args...)`, then:
 
 ```
-for i in 1:length(model.layers)
-    h = eager_update(model.layers[i], Optimisers.update!, opt_state.layers[i], h, other_args...)
-end
+h = eager_update!((model.layers[i], h, other_args...), (Optimisers.update!, opt_state.layers[i]))
 ```
 
-!!! warning
-    If different layers share trainable parameters, then `eager_update` will likely give wrong results.
-"""
-eager_update(f, update, state, xs...) = f(xs...)
+If eg. `f` needs to call `model.layers[i]` (which holds the parameters) as `f(model.layers[i], h, other_args...)`, then:
 
-function Zygote._pullback(ctx::Zygote.AContext, ::typeof(eager_update), f, update, state, xs...)
-    y = f(xs...)
-    function pullback_eager_update(Δy)
-        y, pb = Zygote._pullback(ctx, f, xs...)
+```
+h = eager_update!(f, (model.layers[i], h, other_args...), (Optimisers.update!, opt_state.layers[i]))
+```
+"""
+eager_update!(f, (model, xs...), (update!, state)) = f(model, xs...)
+
+function Zygote._pullback(ctx::Zygote.AContext, ::typeof(eager_update!), f, (model, xs...), (update!, state))
+    y = f(model, xs...)
+    function pullback_eager_update!(Δy)
+        y, pb = Zygote._pullback(ctx, f, model, xs...)
         ret = pb(Δy)
-        update(state, f, ret[1])
-        return (nothing, nothing, nothing, nothing, ret[2:end]...)
+        update(state, model, ret[2])
+        return (nothing, nothing, (nothing, ret[3:end]...), nothing)
     end
-    return y, pullback_eager_update
+    return y, pullback_eager_update!
 end
 
+eager_update!((modelf, xs...), (update!, state)) = eager_update!((m, xs...) -> m(xs...), (modelf, xs...), (update!, state))
+
+
 """
     hessian(f, x)

From d057d86bbd875ff1554846f2b034d9cba9bf14aa Mon Sep 17 00:00:00 2001
From: Ben Murrell
Date: Sat, 14 Dec 2024 02:17:45 +0100
Subject: [PATCH 3/7] Typo fix

---
 src/lib/grad.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lib/grad.jl b/src/lib/grad.jl
index 8dd54dfaf..6fc6b2ea5 100644
--- a/src/lib/grad.jl
+++ b/src/lib/grad.jl
@@ -57,7 +57,7 @@ function Zygote._pullback(ctx::Zygote.AContext, ::typeof(eager_update!), f, (model, xs...), (update!, state))
     function pullback_eager_update!(Δy)
         y, pb = Zygote._pullback(ctx, f, model, xs...)
         ret = pb(Δy)
-        update(state, model, ret[2])
+        update!(state, model, ret[2])
         return (nothing, nothing, (nothing, ret[3:end]...), nothing)
     end
     return y, pullback_eager_update!
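Context for the pullback's return value above (not from the patch): Zygote's `_pullback` convention is that the pullback returns one gradient slot per argument, including the function itself, and a tuple argument receives a tuple of gradients. Hence `(nothing, nothing, (nothing, ret[3:end]...), nothing)`: slots for `eager_update!` and `f`, then the `(model, xs...)` tuple, whose `model` slot is `nothing` because that gradient was already consumed by `update!`, then the `(update!, state)` tuple. A tiny sketch, assuming Zygote's internal `_pullback(f, args...)` entry point behaves as on current versions:

```julia
using Zygote

# One slot per argument: `nothing` for the function `*` itself,
# then the gradients with respect to 2.0 and 3.0.
y, pb = Zygote._pullback(*, 2.0, 3.0)  # y == 6.0
pb(1.0)                                # (nothing, 3.0, 2.0)
```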
From 421dbcdad7e023deca3c8fbd907bdf76d565f617 Mon Sep 17 00:00:00 2001
From: Ben Murrell
Date: Sat, 14 Dec 2024 21:10:17 +0100
Subject: [PATCH 4/7] Adding warning back.

---
 src/lib/grad.jl | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/lib/grad.jl b/src/lib/grad.jl
index 6fc6b2ea5..3ef586421 100644
--- a/src/lib/grad.jl
+++ b/src/lib/grad.jl
@@ -50,6 +50,9 @@ If eg. `f` needs to call `model.layers[i]` (which holds the parameters) as `f(model.layers[i], h, other_args...)`, then:
 ```
 h = eager_update!(f, (model.layers[i], h, other_args...), (Optimisers.update!, opt_state.layers[i]))
 ```
+
+!!! warning
+    If different layers share trainable parameters, then `eager_update!` will likely give wrong results.
 """
 eager_update!(f, (model, xs...), (update!, state)) = f(model, xs...)
 
 function Zygote._pullback(ctx::Zygote.AContext, ::typeof(eager_update!), f, (model, xs...), (update!, state))

From 4397b283dfa6e0e5f7a4fa0fe16e0d4e894b560c Mon Sep 17 00:00:00 2001
From: Ben Murrell
Date: Sun, 15 Dec 2024 16:41:13 +0100
Subject: [PATCH 5/7] Switching to hook

---
 src/lib/grad.jl | 45 +++++++++++++++++++++------------------------
 1 file changed, 21 insertions(+), 24 deletions(-)

diff --git a/src/lib/grad.jl b/src/lib/grad.jl
index 3ef586421..35c0bc253 100644
--- a/src/lib/grad.jl
+++ b/src/lib/grad.jl
@@ -31,45 +31,42 @@ end
 
 """
-    eager_update!((modelf, xs...), (update!, state)) = modelf(xs...) and update!(state, modelf, grads)
-    eager_update!(f, (model, xs...), (update!, state)) = f(model, xs...) and update!(state, model, grads)
+    eager_update!(state, model, update!)
 
-A combination of gradient checkpointing and eagerly updating the model parameters, discarding the gradients once they have been applied.
-Works when `modelf` is a callable struct that also stores the parameters to be updated, or if you have a function `f`
-that takes the `model` as the first argument. `state` is the optimization state (eg. from Optimisers.jl) matching your model, and
-`update!` is the function that updates the parameters of `modelf`/`model` from the state and the gradients, called as `update!(state, model, grads)`.
+Eagerly updates the model parameters, discarding each gradient as soon as it has been applied, to save memory.
+`model` stores the parameters to be updated, `state` is the optimization state (eg. from Optimisers.jl) matching your model component, and
+`update!` is the function that updates the parameters (eg. from Optimisers.jl), usually called as `update!(state, model, grads)`.
 
-If eg. `model.layers[i]` is a layer in a transformer, and is callable as `model.layers[i](h, other_args...)`, then:
+If `f` is a function that takes a single layer, called as `h = f(model.layers[i], h, other_args...)`, then we can eagerly update with:
 
 ```
-h = eager_update!((model.layers[i], h, other_args...), (Optimisers.update!, opt_state.layers[i]))
+h = f(Zygote.eager_update!(state.layers[i], model.layers[i], Optimisers.update!), h, other_args...)
 ```
 
-If eg. `f` needs to call `model.layers[i]` (which holds the parameters) as `f(model.layers[i], h, other_args...)`, then:
+or combine this with gradient checkpointing (for additional memory saving at the cost of increased execution time) with:
 
 ```
-h = eager_update!(f, (model.layers[i], h, other_args...), (Optimisers.update!, opt_state.layers[i]))
+h = Zygote.checkpointed(f, Zygote.eager_update!(state.layers[i], model.layers[i], Optimisers.update!), h, other_args...)
+```
+
+If `model.layers[i]` itself is callable, we can use the above by first wrapping it:
+
+```
+f(model, xs...) = model(xs...)
+h = f(Zygote.eager_update!(state.layers[i], model.layers[i], Optimisers.update!), h, other_args...)
 ```
 
 !!! warning
     If different layers share trainable parameters, then `eager_update!` will likely give wrong results.
 """
-eager_update!(f, (model, xs...), (update!, state)) = f(model, xs...)
-
-function Zygote._pullback(ctx::Zygote.AContext, ::typeof(eager_update!), f, (model, xs...), (update!, state))
-    y = f(model, xs...)
-    function pullback_eager_update!(Δy)
-        y, pb = Zygote._pullback(ctx, f, model, xs...)
-        ret = pb(Δy)
-        update!(state, model, ret[2])
-        return (nothing, nothing, (nothing, ret[3:end]...), nothing)
+function eager_update!(state, model, update!)
+    function update_hook(dmodel)
+        update!(state, model, dmodel)
+        return nothing
     end
-    return y, pullback_eager_update!
+    return Zygote.hook(update_hook, model)
 end
-
-eager_update!((modelf, xs...), (update!, state)) = eager_update!((m, xs...) -> m(xs...), (modelf, xs...), (update!, state))
-
 
 """
     hessian(f, x)
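A self-contained sketch of the final hook-based API applied layer by layer, combining `eager_update!` with `checkpointed` as the docstring recommends. Not part of the patch: the `Layer` struct, optimiser choice, and loss are hypothetical stand-ins.

```julia
using Zygote, Optimisers, Functors

struct Layer
    W::Matrix{Float64}
end
Functors.@functor Layer
(l::Layer)(h) = tanh.(l.W * h)

layers = [Layer(randn(4, 4)) for _ in 1:3]
states = [Optimisers.setup(Optimisers.Adam(1e-3), l) for l in layers]

f(layer, h) = layer(h)  # wrapper so the layer is an explicit argument

function loss(h)
    for i in eachindex(layers)
        # `checkpointed` recomputes this layer's forward pass during the
        # backward pass; the hook installed by `eager_update!` then applies
        # the update and drops the gradient, so no layer gradient is stored.
        h = Zygote.checkpointed(f,
            Zygote.eager_update!(states[i], layers[i], Optimisers.update!), h)
    end
    return sum(abs2, h)
end

Zygote.gradient(loss, randn(4))  # running the backward pass trains the layers
```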
From 28962454d8b8ac131566d4d0b28237940e46c490 Mon Sep 17 00:00:00 2001
From: Ben Murrell
Date: Sun, 29 Dec 2024 14:09:08 +0100
Subject: [PATCH 6/7] Tweaking docs

---
 src/lib/grad.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/lib/grad.jl b/src/lib/grad.jl
index 35c0bc253..92be71d34 100644
--- a/src/lib/grad.jl
+++ b/src/lib/grad.jl
@@ -39,19 +39,19 @@ Eagerly updates the model parameters, discarding each gradient as soon as it has been applied, to save memory.
 
 If `f` is a function that takes a single layer, called as `h = f(model.layers[i], h, other_args...)`, then we can eagerly update with:
 
-```
+```julia
 h = f(Zygote.eager_update!(state.layers[i], model.layers[i], Optimisers.update!), h, other_args...)
 ```
 
 or combine this with gradient checkpointing (for additional memory saving at the cost of increased execution time) with:
 
-```
+```julia
 h = Zygote.checkpointed(f, Zygote.eager_update!(state.layers[i], model.layers[i], Optimisers.update!), h, other_args...)
 ```
 
 If `model.layers[i]` itself is callable, we can use the above by first wrapping it:
 
-```
+```julia
 f(model, xs...) = model(xs...)
 h = f(Zygote.eager_update!(state.layers[i], model.layers[i], Optimisers.update!), h, other_args...)
 ```

From 864ec15f07a51bfa0d285259af77bf65e5cf9eb8 Mon Sep 17 00:00:00 2001
From: Ben Murrell
Date: Sun, 29 Dec 2024 14:13:07 +0100
Subject: [PATCH 7/7] Doc ref in utils.md

---
 docs/src/utils.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/src/utils.md b/docs/src/utils.md
index 3adb1d4c1..fbe8a57cc 100644
--- a/docs/src/utils.md
+++ b/docs/src/utils.md
@@ -26,6 +26,7 @@ Zygote.hook
 Zygote.Buffer
 Zygote.forwarddiff
 Zygote.checkpointed
+Zygote.eager_update!
 ```
 
 `Params` and `Grads` can be copied to and from arrays using the `copy!` function.
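For reference, a small sketch of the `Zygote.hook` behaviour that the final implementation builds on (existing Zygote functionality, not added by this PR): `hook(f̄, x)` is the identity on `x` in the forward pass, applies `f̄` to `x`'s gradient in the backward pass, and returning `nothing` drops that gradient, which is exactly how `eager_update!` consumes the parameter gradient after applying it.

```julia
using Zygote

Zygote.gradient(x -> 3x + Zygote.hook(dx -> 10dx, x), 1.0)    # (13.0,)
Zygote.gradient(x -> 3x + Zygote.hook(_ -> nothing, x), 1.0)  # (3.0,)
```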