add Flux.state(x) #2239
Conversation
Related discussion: FluxML/Functors.jl#56.
Unlike `fmapstructure`, this recursion doesn't know about shared arrays. Does this matter?
If it doesn't, or we don't want it, that's an argument for it being here, not in Functors.jl.
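To make the shared-arrays concern concrete, here is a minimal, hypothetical illustration in plain Julia (no Functors.jl involved; `naive_state` is invented for this sketch): a structural walk that doesn't track object identity produces a state tree in which the tie between shared arrays is lost.

```julia
# Hypothetical sketch: a recursion that ignores object identity.
shared = [1.0, 2.0]
model = (layer1 = (w = shared,), layer2 = (w = shared,))  # w is tied

naive_state(x::NamedTuple) = map(naive_state, x)
naive_state(x::AbstractArray) = copy(x)

s = naive_state(model)
model.layer1.w === model.layer2.w  # true: the original arrays are shared
s.layer1.w === s.layer2.w          # false: the tie is lost in the state
```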
Bikeshedding the name: is `state` specific enough? We have `opt_state`, with which this should never be confused. Since it pairs with `loadmodel!(m, x)`, should its name echo that... like `modelstate(m)`? Or "bare" or "Base" or "tree"? (Honestly a tree if it ignores identity.)
It doesn't need to, since it doesn't transform (except for returning `nothing`).
This is the common pattern that I expect for this method:

```julia
model = ...
opt_state = Optimisers.setup(Adam(), model)
model_state = Flux.state(model)

BSON.@save "checkpoint.bson" opt_state model_state
```

It reads nice and clear as is, I think; less nice with `Flux.state(opt::Optimisers.AbstractRule, x) = Optimisers.setup(opt, x)`, but that's probably not a good idea anyway, since it doesn't convey the idea of initialization.
What should we A slim state would be given by
Not sure if this PR is completely ready yet, but I just tried to save weights using
Thanks for the investigation! Yes, the problem surely is with
One proposal is this: After some

This excludes not just anonymous functions but all activation functions etc, which seems fine. This rule would include RNG state. Do we want that? An alternative rule would append "and survives

Edit: In the present state the behaviour is this; note that

```julia
julia> state(Dropout(0.1))
(p = 0.1, dims = nothing, active = nothing, rng = TaskLocalRNG())

julia> state(Dropout(0.1) |> trainmode!)
(p = 0.1, dims = nothing, active = true, rng = TaskLocalRNG())

julia> Flux.loadmodel!(trainmode!(Dropout(0.2)), state(Dropout(0.1))) |> dump
Dropout{Float64, Colon, TaskLocalRNG}
  p: Float64 0.2          # p = 0.1 was included in state, but not copied
  dims: Colon() (function of type Colon)
  active: Bool true       # active = nothing in state was not copied
  rng: TaskLocalRNG TaskLocalRNG()
```

Next, what is the use case for changing

One thing I can picture wanting to configure is this: Instead of getting all state, get only the trainable parameters, as a similar nested structure. That cannot be controlled by

```julia
julia> Flux.loadmodel!(BatchNorm(1), Flux.trainable(BatchNorm(1)))
ERROR: ArgumentError: Tried to load (β = Float32[0.0], γ = Float32[1.0]) into BatchNorm(1) but the structures do not match.
```
The current implementation is equivalent to

Re: what's kept.
The current solution in
I'd like to settle for a minimal interface allowing robust serialization of the state (for whatever definition of state covers common needs such as the ones of Metalhead.jl). If we allow anything except functions to be in the state (what this PR does right now) then we can have compatibility issues between julia versions (mostly due to rngs I guess, but maybe there are other examples):

```julia
# in julia 1.9
julia> s = Flux.state(Dropout(0.2))
(p = 0.2, dims = nothing, active = nothing, rng = Random.TaskLocalRNG())

julia> using BSON, JLD2

julia> BSON.@save "test.bson" s

julia> JLD2.jldsave("test.jld2"; s)
```

```julia
# now in julia 1.6
julia> using BSON, JLD2

julia> BSON.@load "test.bson" s
ERROR: UndefVarError: TaskLocalRNG not defined
Stacktrace:
 [1] (::BSON.var"#31#32")(m::Module, f::String)
   @ BSON ~/.julia/packages/BSON/DOYqe/src/extensions.jl:35
....

julia> s = JLD2.load("test.jld2", "s")
┌ Warning: type parameters for NamedTuple{(:p, :dims, :active, :rng),Tuple} do not match type NamedTuple in workspace; reconstructing
└ @ JLD2 ~/.julia/packages/JLD2/ryhNR/src/data/reconstructing_datatypes.jl:475
JLD2.ReconstructedTypes.var"##NamedTuple{(:p, :dims, :active, :rng),Tuple}#276"(0.2)

julia> s.p
0.2

julia> s.active
ERROR: type ##NamedTuple{(:p, :dims, :active, :rng),Tuple}#276 has no field active
Stacktrace:
 [1] getproperty(x::JLD2.ReconstructedTypes.var"##NamedTuple{(:p, :dims, :active, :rng),Tuple}#276", f::Symbol)
   @ Base ./Base.jl:33
 [2] top-level scope
   @ REPL[20]:1
....
```

What about exposing only a
So what I'm proposing is that this whole PR becomes

```julia
state_arrays(x) = fmapstructure(x -> x isa AbstractArray ? x : nothing, x)
```
Given what
Skimming through some prior art and possible inspiration:
When it comes to serialization, I think our approach would be closest to how JAX encourages serializing PyTrees. A 1-1 match would require keeping around non-arrays (while excluding stuff like RNGs) and adding metadata tags for what type each node in the tree was created from. For now though, we can get away with far less metadata.
I think in pytorch the closest thing would be `.state_dict()`.

For our state I would prefer an opt-in approach for the leaf types that get into the state. We can make the contract that future (non-breaking) releases can modify the returned state by allowing for more types to get in, but never less. The implementation could be

```julia
state(x) = fmapstructure(x -> _keep(x) ? x : missing, x)
_keep(x::TOKEEP) = true
_keep(x) = false
```

Then it is not clear where we should set the bar. These are possible candidates

In the future we can expose some interface for customization. Maybe we start conservative with 1 or 2 and then see if 3 or more is desired?
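As a sketch of what the opt-in version might look like in full (assuming `fmapstructure` from Functors.jl; the `TOKEEP` union shown is one possible conservative choice for illustration, not a decided API):

```julia
using Functors: fmapstructure

# Illustrative whitelist: numeric arrays and numbers only.
const TOKEEP = Union{AbstractArray{<:Number}, Number}

_keep(::TOKEEP) = true
_keep(_) = false

# Leaves outside the whitelist are replaced with `missing`,
# so the returned tree contains only plain, portable data.
state(x) = fmapstructure(y -> _keep(y) ? y : missing, x)

# e.g. state((w = [1.0, 2.0], σ = tanh)) should keep w and drop σ.
```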
force-pushed from be419fb to 85961da
In the end I went with

I also rewrote the "Saving and Loading" docs to encourage the serialization of the state and discourage the serialization of the model struct. In the new docs I emphasized JLD2 over BSON, since JLD2 is more actively maintained thanks to the work of @JonasIsensee.
I'm generally in favour of this now. Thoughts on JLD2 vs the Serialization stdlib? I know the former sometimes requires maintenance for new Julia versions, but the latter has its own issues.
Maybe this is a good moment for me to chime in (also since I've been tagged the second time ;) ).

From what I can tell, you've decided to largely restrict yourself to storing "plain" data such as

Now, concerning

Lastly, I would very much recommend building your own

into separate jld2 datasets e.g.:

and an analog for loading.
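The snippet this comment refers to isn't in the quoted text, but a minimal sketch of the idea might look like the following. The names `save_state`/`load_state` are hypothetical, and it assumes a flat NamedTuple of plain-data leaves:

```julia
using JLD2

# Save each top-level entry of the state as its own JLD2 dataset,
# so entries stay individually addressable in the file.
function save_state(path, state::NamedTuple)
    jldopen(path, "w") do file
        for (name, value) in pairs(state)
            file[string(name)] = value
        end
    end
end

# Analogue for loading: rebuild a NamedTuple from the stored datasets.
function load_state(path)
    jldopen(path, "r") do file
        ks = keys(file)
        NamedTuple{Tuple(Symbol.(ks))}(Tuple(file[k] for k in ks))
    end
end
```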
Once the build has completed, you can preview any updated documentation at this URL: https://fluxml.ai/Flux.jl/previews/PR2239/ in ~20 minutes
Some minor comments below, but I think this basically looks great.
(And do you really have/need `<:AbstractString` that are not `String`?)

This would be super-easy to standardise on `String` to be safe. We could surely also convert all Symbols to `String`, and have `loadmodel!` convert back when it sees that the original model has a `Symbol` there.
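A sketch of that round-trip rule, with hypothetical helper names (not part of Flux):

```julia
# On the way out: standardise Symbols to Strings for portable serialization.
save_leaf(x::Symbol) = String(x)
save_leaf(x) = x

# On the way back: restore a Symbol whenever the destination field
# in the original model holds one.
restore_leaf(dst::Symbol, src::AbstractString) = Symbol(src)
restore_leaf(dst, src) = src

save_leaf(:auto)               # "auto"
restore_leaf(:auto, "train")   # :train
restore_leaf("keep", "train")  # "train" (destination isn't a Symbol)
```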
`Nothing`, `Symbol`, `Missing`, etc.

This PR uses `missing` for unsaved fields. I think another obvious candidate would be `()`. Quick tests are that `loadmodel!` understands this fine. (This is what `Optimisers.setup` uses.)

At present `nothing` is meaningful as the default automatic test/train state of `Dropout`, etc. We could consider changing them all to something like `mode = :auto / :train / :test`. Converting to, say, a string `"nothing"` might be tricky for `loadmodel!` to reverse, as these mutable structs take `Union{Nothing,Bool}`, but perhaps someone else does `Union{Nothing,String}` or something.
I don't think the plan is to depend on any serialization libraries in Flux itself. At best this would be handled by extensions or left to users.
It looks good to go, but one minor point worth discussing since changing it would require a breaking release. I think this got brought up previously but went undiscussed.
Should be ready for merge
Just a couple more tests, but looks ready
I'm a little unsure of this prune-the-tree idea. I agree it makes for a tidier state to look at. But it adds complexity for anything consuming it: instead of knowing that field names will always match, some will be empty.

And, in addition, accidentally using the wrong saved state is less likely to give you an informative error.

In the AD world, Zygote always makes a NamedTuple for all fields. ChainRules allows some to be missing via Tangent, and IMO this is a mistake; it adds complexity without much gain. (Maybe for Dict it makes sense to allow missing entries, as they aren't free.)
You are right, I had thought a missing key would be cleaner and the same as not recursing a node, but it is not. We have

Now I understand why the change to
I reverted the change, so state is not dropping fields anymore, and I preserved the change I made to
Why not just put that change in the future PR that also adds
I'm not using
merge like this?
Isn't
I am not saying your change is using
From memory my confusion about
restored previous behaviour for
Late but...
```diff
-    (keys(ldsts) == keys(lsrcs)) ||
-        throw(ArgumentError("Tried to load $(keys(lsrcs)) into $(keys(ldsts)) but the structures do not match."))
-
-    foreach(ldsts, lsrcs) do ldst, lsrc
+    keys_ldsts = keys(ldsts)
+    keys_lsrcs = keys(lsrcs)
+    collect(keys_ldsts) == collect(keys_lsrcs) || throw(ArgumentError("Tried to load $(keys_lsrcs) into $(keys_ldsts) but the structures do not match."))
+
+    for k in keys_lsrcs
+        lsrc, ldst = lsrcs[k], ldsts[k]
         if ldst in cache # we already loaded this parameter before
-            _tie_check(ldst, lsrc) && return ldst
+            _tie_check(ldst, lsrc)
```
Why this change to collect?
This implements the idea expressed multiple times that we should have something like `fmapstructure` for grabbing the internal state. The aim is to provide a less fragile solution than BSON for serialization.

PR Checklist