Eager parameter updating #1541
Is there precedent for this in other libraries?
There is a PyTorch tutorial describing a per-parameter version here:

```python
import torch

model = torch.nn.Linear(10, 10)  # placeholder; any torch.nn.Module works

# Instead of having just *one* optimizer, we will have a ``dict`` of optimizers
# for every parameter so we could reference them in our hook.
optimizer_dict = {p: torch.optim.Adam([p], foreach=False) for p in model.parameters()}

# Define our hook, which will call the optimizer ``step()`` and ``zero_grad()``
def optimizer_hook(parameter) -> None:
    optimizer_dict[parameter].step()
    optimizer_dict[parameter].zero_grad()

# Register the hook onto every parameter
for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)
```

It would be great if we could have a function like `grad_and_update!(m -> loss(m, x, y), model, opt_state)` in an Optimisers extension, also interacting nicely with
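The same idea can be sketched without PyTorch. The toy below is a hypothetical, dependency-free Python illustration (not the tutorial's code and not this PR's implementation): a chain of scalar "layers" whose backward pass hands each weight gradient to a callback as soon as it is computed. Plain SGD stands in for Adam. Because no weight is shared, updating eagerly should give exactly the same weights as collecting all gradients first.

```python
# Toy, dependency-free sketch of "optimizer step in backward" (hypothetical names).


def forward(ws, x):
    """Chain of scalar layers h <- w * h; record each layer's input for backward."""
    inputs = []
    h = x
    for w in ws:
        inputs.append(h)
        h = w * h
    return h, inputs


def backward(ws, inputs, dy, on_param_grad):
    """Backpropagate last-to-first, handing each weight gradient to a callback.

    The input gradient is computed with the weight *before* the callback runs,
    so a callback that updates ws[k] in place does not corrupt the backward pass.
    """
    for k in reversed(range(len(ws))):
        grad_w = dy * inputs[k]
        dy = dy * ws[k]
        on_param_grad(k, grad_w)


def sgd_step_eager(ws, x, target, lr):
    """Update each weight as soon as its gradient is ready; no gradients are kept."""
    y, inputs = forward(ws, x)

    def hook(k, g):
        ws[k] -= lr * g  # the gradient g is dropped after this line

    backward(ws, inputs, y - target, hook)  # dL/dy for L = 0.5 * (y - target)**2
    return ws


def sgd_step_standard(ws, x, target, lr):
    """Reference: collect every gradient first, then update all weights."""
    y, inputs = forward(ws, x)
    grads = [0.0] * len(ws)

    def collect(k, g):
        grads[k] = g

    backward(ws, inputs, y - target, collect)
    for k, g in enumerate(grads):
        ws[k] -= lr * g
    return ws


print(sgd_step_eager([0.5, 1.5, -2.0], 1.0, 0.3, 0.1))
print(sgd_step_standard([0.5, 1.5, -2.0], 1.0, 0.3, 0.1))
```

The one ordering detail that matters is that each layer's input gradient is computed with the weight before the callback updates it.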
I've never used any other libraries :)
Nice! I figure it must be a common trick, but I was quite happy with how easy it was to get such a big gain in this ecosystem. And this version luckily doesn't require you to split your optimizer up across the different layers, because you can just pass the right part of the larger `opt_state` into it. It is now ~one line to halve your model's memory requirements. One thing I might want to tweak before merging: this works when
Do you mean automatically tracking which bit of the optimizer state would go into
The most appropriate place for this may be Optimisers.jl, as the technique could be applicable to ADs beyond Zygote. That said, I'm not quite sure I understand how it's meant to work. The abridged example in the docstring does not look like the PyTorch one Carlo shared. Is there a complete, working minimal example that demonstrates this functionality? The main thing I'd like to understand is how it would pick up on the final accumulated gradients for a parameter.
I think Flux is the appropriate place for this. And I would simplify the interface to

```julia
eager_update!(opt_state, model, xs...) = model(xs...)
eager_update!(f::Function, opt_state, model, xs...) = f(model, xs...)
```

since we can assume that
Perhaps, but it is possible that this might be useful for non-Optimisers.jl optimizers, e.g. if someone rolls their own (not using Optimisers) or is outside the Flux ecosystem entirely. Maybe Flux should have something with your simplified interface that calls this version? Note: rolling your own optimizer is very likely when working with very large models.
I was originally considering an `rrule` for this, but then there was some discussion on … Overall, as for where this should go: it should probably be wherever
I was going to push one to Jjama3.jl once we've finalized the interface, but this is the core of it:

```julia
using Flux, Zygote, Optimisers, Jjama3

# Model function definition, where model.layers[i] is a Jjama3 TransformerBlock
# (just an RMSNorm -> attention -> RMSNorm -> feed_forward)
function forward(model::Transformer, tokens::AbstractArray{Int}, opt_state) # <- note opt_state passed in here
    h = model.tok_embeddings(tokens)
    rope = model.rope[model.pos+1:model.pos+size(tokens, 1)]
    mask = Jjama3.create_mask(h)
    for i in 1:length(model.layers)
        # Instead of:
        # h = model.layers[i](h, model.pos, rope, mask)
        # we do:
        h = Zygote.eager_update!((model.layers[i], h, model.pos, rope, mask), (Optimisers.update!, opt_state.layers[i]))
    end
    h = model.norm(h)
    output = model.output(h)
    return output
end

# Normal opt_state setup:
opt_state = Flux.setup(Apollo(lr), model)

# Training loop:
for i in 1:1000000000
    train_toks = ...
    l, grads = Flux.withgradient(model) do m
        logit_loss(forward(m, train_toks[1:end-1,:], opt_state), train_toks[2:end,:])
    end
    # Then a catch-all for any parameters not updated eagerly:
    Flux.update!(opt_state, model, grads[1])
end
```

I have verified that this trains a model whose weights and gradients can't both fit on the GPU, in combination with the new Optimisers PR I opened, which removes the need to also store the moments of the gradients (using a low-rank trick): FluxML/Optimisers.jl#196 (comment). Together, these bring the memory footprint down from a minimum of 4x the weights (weights + gradients + first moment + second moment) to ~1.3x the weights (with some overhead for the low-rank projections and the activations themselves). I've tested models of up to 7.5 billion parameters.
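To put these multiples in absolute terms, here is a rough estimate for the largest model mentioned. The 4x and ~1.3x multiples and the 7.5 billion parameter count come from the comment; the 4 bytes per parameter (Float32) is an assumption, and activations and low-rank projection overhead are ignored.

```python
# Back-of-the-envelope memory estimate; assumes Float32 (4 bytes per parameter)
# and ignores activations, low-rank projection overhead and framework buffers.

BYTES_PER_PARAM = 4  # assumption: Float32 weights, gradients and moments


def footprint_gb(n_params, multiple, bytes_per_param=BYTES_PER_PARAM):
    """Memory in GB needed for `multiple` weight-sized buffers."""
    return n_params * bytes_per_param * multiple / 1e9


N_PARAMS = 7.5e9  # largest model size mentioned above

weights_only = footprint_gb(N_PARAMS, 1)        # weights alone
standard = footprint_gb(N_PARAMS, 4)            # weights + grads + moment1 + moment2
eager_low_rank = footprint_gb(N_PARAMS, 1.3)    # the ~1.3x figure quoted above

print(f"weights alone:              {weights_only:.0f} GB")
print(f"weights + grads + moments:  {standard:.0f} GB")
print(f"eager + low-rank moments:   {eager_low_rank:.0f} GB")
```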
OK, we can have a wrapper in Flux.
Thanks, this really helps. It seems like the main difference here is that

```julia
shared_layer = Dense(...)
model = (;
    branch_1 = Chain(shared_layer, more_layers...),
    branch_2 = Chain(other_layer, shared_layer, another_layer)
)
opt_state = Flux.setup(Apollo(lr), model)

function forward_incorrect(model, x, opt_state) # <- note opt_state passed in here
    # using the `eager_update!(opt_state, model, xs...) = model(xs...)` method proposed above for brevity
    # optimizer step run for shared_layer using branch_1 gradients only!
    y_1 = eager_update!(opt_state.branch_1, model.branch_1, x)
    # optimizer step run for shared_layer using branch_2 gradients only!
    y_2 = eager_update!(opt_state.branch_2, model.branch_2, x)
    return y_1 + y_2
end

function forward_correct(model, x, opt_state) # <- note opt_state passed in here
    # using the `eager_update!(f::Function, opt_state, model, xs...) = f(model, xs...)` method proposed above for brevity
    # optimizer step run for shared_layer using accumulated gradients from both branches
    y = eager_update!(opt_state, model, x) do model, x
        y_1 = model.branch_1(x)
        y_2 = model.branch_2(x)
        return y_1 + y_2
    end
    return y
end
```

The problem I see is that one has to wrap every code path
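A toy, scalar illustration of the pitfall in `forward_incorrect`, written in plain Python with hypothetical names: one weight shared by two branches, and a stateful optimizer (heavy-ball momentum, chosen here only because its state makes the effect easy to see). Stepping once per branch advances the optimizer state twice, whereas `forward_correct` corresponds to a single step on the summed gradient.

```python
# Toy illustration (hypothetical names): eager updates per branch vs. one update
# on the accumulated gradient, for a weight shared by two branches.


class Momentum:
    """Minimal heavy-ball momentum: v <- beta * v + g; w <- w - lr * v."""

    def __init__(self, lr=0.1, beta=0.9):
        self.lr, self.beta, self.v = lr, beta, 0.0

    def step(self, w, g):
        self.v = self.beta * self.v + g
        return w - self.lr * self.v


def branch_grads(w, x):
    """Gradients w.r.t. a shared scalar weight w from two branch losses.

    Branch 1: L1 = 0.5 * (w * x - 1)**2      -> dL1/dw = (w * x - 1) * x
    Branch 2: L2 = 0.5 * (2 * w * x + 1)**2  -> dL2/dw = (2 * w * x + 1) * 2 * x
    """
    return (w * x - 1) * x, (2 * w * x + 1) * 2 * x


def step_per_branch(w, x, opt):
    """Like forward_incorrect: each branch triggers its own optimizer step."""
    g1, g2 = branch_grads(w, x)  # both gradients come from the same forward pass
    w = opt.step(w, g2)          # the backward pass reaches branch_2 first...
    w = opt.step(w, g1)          # ...then branch_1: the optimizer state advances twice
    return w


def step_accumulated(w, x, opt):
    """Like forward_correct: a single step on the summed gradient."""
    g1, g2 = branch_grads(w, x)
    return opt.step(w, g1 + g2)


print(step_per_branch(0.5, 1.0, Momentum()))   # approximately -0.21
print(step_accumulated(0.5, 1.0, Momentum()))  # approximately 0.15
```

With `beta=0.0` (plain SGD) the two functions agree up to rounding in this toy, so the discrepancy here comes from stepping the optimizer state once per branch rather than once on the accumulated gradient.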
I think that may be a misunderstanding based on out-of-date historical discussion? Zygote is basically on life support at this point, and Flux wants to be rid of it as soon as possible. As such, nice new functionality should ideally find a different home.
Agreed re: docs. I had this warning in my first commit:
Models with a repeating core that don't have any shared parameters across the repeating layers are a large and critical class, so having a simple trick that helps with these but doesn't help when there are shared parameters seems fine to me? If a layer shares all its parameters then you don't do this sort of thing, and if it shares some of its parameters then often you can rewrite the layers themselves to separate out the components that share parameters from those that don't, and then use this for the components that don't.
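The refactoring described above (eagerly update the components that aren't shared, and leave the shared ones to a catch-all update) can be sketched with a toy scalar model in plain Python. All names are hypothetical. Each repeated block has its own weight plus one weight shared by every block; only the block-local gradients are consumed during the backward pass, while the shared gradient is accumulated across blocks and applied once at the end. With plain SGD this should reproduce a standard full-gradient step exactly.

```python
# Toy sketch (hypothetical names): block-local weights own[k] are updated eagerly;
# the weight `shared`, used by every block, is updated once by a catch-all step.


def forward(own, shared, x):
    inputs = []
    h = x
    for w in own:
        inputs.append(h)
        h = w * shared * h  # block k: h <- own[k] * shared * h
    return h, inputs


def backward(own, shared, inputs, dy, on_own_grad):
    """Return the accumulated shared gradient; hand block gradients to a callback."""
    shared_grad = 0.0
    for k in reversed(range(len(own))):
        h = inputs[k]
        g_own = dy * shared * h
        shared_grad += dy * own[k] * h
        dy = dy * own[k] * shared  # uses own[k] before the callback updates it
        on_own_grad(k, g_own)
    return shared_grad


def step_split(own, shared, x, target, lr):
    y, inputs = forward(own, shared, x)

    def eager(k, g):
        own[k] -= lr * g  # unshared: update now and drop the gradient

    shared_grad = backward(own, shared, inputs, y - target, eager)
    return own, shared - lr * shared_grad  # catch-all update for the shared part


def step_reference(own, shared, x, target, lr):
    y, inputs = forward(own, shared, x)
    grads = [0.0] * len(own)

    def collect(k, g):
        grads[k] = g

    shared_grad = backward(own, shared, inputs, y - target, collect)
    return [w - lr * g for w, g in zip(own, grads)], shared - lr * shared_grad


print(step_split([0.5, 1.5, -2.0], 0.8, 1.0, 0.3, 0.1))
print(step_reference([0.5, 1.5, -2.0], 0.8, 1.0, 0.3, 0.1))
```

Only the shared part still needs a full gradient buffer, which is why the saving depends on how much of each block is unshared.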
Yes, that was the discussion I saw. And as I said I couldn't follow the argument to know whether or not this would fit as an rrule. Is there an rrule for the equivalent of
Well, Enzyme errors whenever we look at it, so I hope this shift isn't too precipitous. But then this becomes a question of whether the Flux wrapper for this should use, e.g., an `rrule` version of this trick (instead of the Zygote one), and not a question of whether or not this should be in Zygote.
The docstring warning LGTM and should probably be sufficient for now. If this feature becomes widely used, we can think about more guardrails.
The feature that would allow someone to write a … On that note, this discussion reminded me of another Zygote utility function:

```julia
using Zygote

function eager_update!(f, (model, xs...), (update!, state))
    function update_hook(dmodel)
        update!(state, model, dmodel)
        return nothing
    end
    return Zygote.checkpointed(f, Zygote.hook(update_hook, model), xs...)
end
```

While we're at it, perhaps the interface could be simplified as well. I think the key here is realizing that checkpointing and eager updates can be decoupled:

```julia
using Zygote, Optimisers

function eager_update!(model, opt_state, update! = Optimisers.update!)
    function update_hook(dmodel)
        update!(opt_state, model, dmodel)
        return nothing
    end
    return Zygote.hook(update_hook, model)
end

# So instead of
eager_update!(f, (model, xs...), (Optimisers.update!, opt_state))
# you could write
Zygote.checkpointed(f, eager_update!(model, opt_state), xs...)
# or even
Zygote.checkpointed(f, eager_update!(model1, opt_state1), eager_update!(model2, opt_state2), xs...)
# or not checkpoint at all, which would be equivalent to https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
f(eager_update!(model, opt_state), xs...)
```

Either way, I think this exercise suggests that
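The decoupling argument can be illustrated with a toy in plain Python. `hook` and `checkpointed` below are hypothetical stand-ins that only mimic the roles of `Zygote.hook` and `Zygote.checkpointed`; they are not their implementations. The gradient hook only decides what happens to a parameter's gradient (here: apply an SGD step and drop it), and checkpointing only decides whether a layer's pullback is kept or recomputed. So the two can be combined or used separately, and in this sketch they give the same weights either way.

```python
# Toy sketch (hypothetical names, not Zygote's API): a gradient hook and
# checkpointing as two independent wrappers.


def dense(w, x):
    """Scalar layer y = w * x; returns the value and a pullback dy -> (dw, dx)."""
    return w * x, lambda dy: (dy * x, dy * w)


def hook(fn, w):
    """Stand-in for a gradient hook: w passes through the forward pass unchanged,
    and fn is what the backward pass will apply to w's gradient."""
    return w, fn


def checkpointed(f, w, x, stats):
    """Run f forward but keep only its inputs; rebuild the pullback in backward."""
    y, _ = f(w, x)  # the pullback (and anything it captured) is discarded

    def pullback(dy):
        stats["recomputes"] += 1
        _, pb = f(w, x)  # recompute the forward pass to recover the pullback
        return pb(dy)

    return y, pullback


def train_step(ws, x, target, lr, checkpoint):
    """One eager-SGD step on a chain of `dense` layers, optionally checkpointed."""
    stats = {"recomputes": 0}
    tape = []
    h = x
    for k, w in enumerate(ws):
        def update(dw, k=k):
            ws[k] -= lr * dw  # eager update...
            return None       # ...and nothing is passed on: the gradient is dropped

        w, grad_hook = hook(update, w)
        h, pb = checkpointed(dense, w, h, stats) if checkpoint else dense(w, h)
        tape.append((pb, grad_hook))

    dy = h - target  # dL/dy for L = 0.5 * (y - target)**2
    for pb, grad_hook in reversed(tape):
        dw, dy = pb(dy)
        grad_hook(dw)
    return ws, stats["recomputes"]


print(train_step([0.5, 1.5, -2.0], 1.0, 0.3, 0.1, checkpoint=False))
print(train_step([0.5, 1.5, -2.0], 1.0, 0.3, 0.1, checkpoint=True))
```

Checkpointing trades one extra forward evaluation per layer for not keeping its pullback alive; the eager update is unaffected by that choice.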
Maybe Fluxperimental.jl is the right place? I see roughly what this is, but perhaps the very best way to wrap this up isn't entirely clear yet.
Fluxperimental could work too. Another motivation for having this in a higher-level package than Zygote is that we could define an overload of this function (and
Wow, this looks very nice. For discoverability I would prefer to have
Yes, this is the way. Two points:
I've updated my PR accordingly.
Review comments on src/lib/grad.jl (outdated):

Docstring: "If `model.layers[i]` itself is callable, we can use the above by first wrapping it:"
Suggested change: tag the code fence that follows as `julia`.

Docstring: "If `f` is a function that takes a single layer, called as `h = f(model.layers[i], h, other_args...)` then we can eagerly update with:"
Suggested change: tag the code fence that follows as `julia`.

Docstring: "or combine this with gradient checkpointing (for additional memory saving at the cost of increased execution time) with:"
Suggested change: tag the code fence that follows as `julia`.
Can you add a reference in
Thank you everyone for your suggestions and reviews!
This adds something like Zygote's `checkpointed`, but additionally accepts an optimizer state and an update function. The model parameters are updated during the backward pass and the gradients are then discarded, allowing you to train models when you can't fit both the model weights and the full gradients in memory together. I wasn't sure whether this should be PR'd to Flux instead.
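Why discarding gradients during the backward pass saves memory can be seen with a simple counting model, assuming (hypothetically) that each layer's gradient is a single buffer that exists from the moment its pullback runs until its optimizer step. Without eager updates every buffer survives until the final update, so the peak is the sum over layers; with eager updates the peak is the largest single layer. Activations and optimizer state are ignored, and the layer sizes are made up.

```python
# Back-of-the-envelope model (hypothetical): peak number of gradient values alive
# at once during a backward pass over layers with the given parameter counts.
# Ignores activations, optimizer state and allocator behaviour.


def peak_live_gradients(param_counts, eager):
    live = peak = 0
    for n in reversed(param_counts):  # backward visits layers last-to-first
        live += n                     # this layer's gradient now exists
        peak = max(peak, live)
        if eager:
            live -= n                 # optimizer step applied, gradient freed
    return peak


layers = [4096, 16384, 16384, 4096]  # made-up parameter counts
print(peak_live_gradients(layers, eager=False))  # 40960: every gradient held until the update
print(peak_live_gradients(layers, eager=True))   # 16384: one layer's gradient at a time
```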
PR Checklist