Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Flux.state(x) #2239

Merged
merged 15 commits into from
May 5, 2023
Merged

add Flux.state(x) #2239

merged 15 commits into from
May 5, 2023

Conversation

CarloLucibello
Copy link
Member

@CarloLucibello CarloLucibello commented Apr 25, 2023

This implements the idea expressed multiple times that we should have something like fmapstructure for grabbing the internal state. The aim is to provide a less fragile solution than bson for serialization.

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

@ToucheSir
Copy link
Member

Related discussion: FluxML/Functors.jl#56.

Copy link
Member

@mcabbott mcabbott left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unlike fmapstructure, this recursion doesn't know about shared arrays. Does this matter?

If it doesn't, or we don't want it, that's an argument for it being here not in Functors.jl

Bikeshedding the name, is state specific enough? We have opt_state with which this should never be confused. Since it pairs with loadmodel!(m, x), should its name echo that... like modelstate(m)? Or "bare" or "Base" or "tree"? (Honestly a tree if it ignores identity.)

test/utils.jl Outdated Show resolved Hide resolved
src/loading.jl Outdated Show resolved Hide resolved
src/loading.jl Outdated Show resolved Hide resolved
docs/src/models/saving.md Outdated Show resolved Hide resolved
docs/make.jl Outdated Show resolved Hide resolved
@CarloLucibello
Copy link
Member Author

CarloLucibello commented Apr 25, 2023

Unlike fmapstructure, this recursion doesn't know about shared arrays. Does this matter?

it doesn't need to since it doesn't transform (except for returning nothing).
If needed, we can provide in the future serializers and desarializers handling shared arrays.

Bikeshedding the name, is state specific enough? We have opt_state with which this should never be confused. Since it pairs with loadmodel!(m, x), should its name echo that... like modelstate(m)? Or "bare" or "Base" or "tree"? (Honestly a tree if it ignores identity.)

This is the common pattern that I expect for this method:

model = ...
opt_state = Optimisers.setup(Adam(), model)
model_state = Flux.state(model)
BSON.@save "checkpoint.bson" opt_state model_state

It reads nice and clear as it, I think less nice with modelstate, bare, Base or tree.
I would actually rename loadmodel! to loadstate!.
We could also have

Flux.state(opt::Optimisers.AbstractRule, x) = Optimisers.setup(opt, x)

but probably not a good idea, doesn't convey the idea of initialization.

@CarloLucibello
Copy link
Member Author

CarloLucibello commented Apr 25, 2023

What should we keep by default in the state? Since the main purpose is serialization, at the very least we should exclude anonymous functions. This way, we could recommend JLD2.jl as the main serialization solution, since it is more robust and maintained compared to BSON.jl (cc @JonasIsensee).

A slim state would be given by keep = leaf -> leaf isa AbstractArray{<:Number}. This way though we would lose the active field in Dropout and BatchNorm, which I don't know if it's ok or not.

@theabhirath
Copy link
Member

Not sure if this PR is completely ready yet, but I just tried to save weights using Flux.state (using BSON.jl, admittedly) and loadmodel! blows up with both memory and time usage. For only a ResNet-18 model, I get close to 3.5 GB of memory being used and about 2 whole minutes for the model to load. FluxML/Metalhead.jl#225 (comment) leads me to believe this might be something that needs to be addressed from the loadmodel! side of things? Fixing this sort of storing and loading of weights without storing anything more than that needs to be stored (read: RNGs etc. when we store the entire Flux model as is the docs' convention now) is one of the major things that can improve pretrained model support in Metalhead, so I am very invested 😅

@CarloLucibello
Copy link
Member Author

Thanks for the investigation! Yes, the problem surely is with loadmodel! although I don't know the cause. So that shouldn't block this PR, but needs to be investigated.

@mcabbott
Copy link
Member

mcabbott commented Apr 25, 2023

What should we keep by default in the state?

One proposal is this: After some m = make_model(128), anything which can change during training etc. is state, and should be kept.

This excludes not just anonymous functions but all activation functions etc, which seems fine.

This rule would include RNG state. Do we want that? An alternative rule would append "and survives m |> gpu |> cpu" which I believe RNG state does not.

Edit: In the present state the behaviour is this; note that nothing has two meanings here, and that p = 0.1, active = nothing are not copied by loadmodel!.

julia> state(Dropout(0.1))
(p = 0.1, dims = nothing, active = nothing, rng = TaskLocalRNG())

julia> state(Dropout(0.1) |> trainmode!)
(p = 0.1, dims = nothing, active = true, rng = TaskLocalRNG())

julia> Flux.loadmodel!(trainmode!(Dropout(0.2)), state(Dropout(0.1))) |> dump
Dropout{Float64, Colon, TaskLocalRNG}
  p: Float64 0.2  # p = 0.1 was included in state, but not copied
  dims: Colon() (function of type Colon)
  active: Bool true  # active = nothing in state was not copied
  rng: TaskLocalRNG TaskLocalRNG()

Next, what is the use case for changing keep?

One thing I can picture wanting to configure is this: Instead of getting all state, get only the trainable parameters, as a similar nested structure. That cannot be controlled by keep, but would instead want children to be replaced by trainable. No idea whether loadmodel! will not at present understand such a truncated tree, but it could be made to.

julia> Flux.loadmodel!(BatchNorm(1), Flux.trainable(BatchNorm(1)))
ERROR: ArgumentError: Tried to load (β = Float32[0.0], γ = Float32[1.0]) into BatchNorm(1) but the structures do not match.

@darsnack
Copy link
Member

darsnack commented Apr 25, 2023

The current implementation is equivalent to fmapstructure(identity, m; exclude = keep, cache = nothing) (keep might need to change slightly). My point is we should not write yet another fmap-like loop somewhere in the ecosystem. I worry about maintaining all these different loops when decisions about trainable, children, walks, etc. are made. So unless it warrants a custom loop like loadmodel! or Optimisers.jl, then we should try to implement it with Functors. It will also let us easily turn on sharing if desired (better to keep it hidden for now).

Re: what's kept. loadmodel! presently ignores functions, RNGs, etc. So those are obvious candidates not to save. It does understand non-trainable state. I believe it assume the active field of the destination model.

@darsnack
Copy link
Member

One thing I can picture wanting to configure is this: Instead of getting all state, get only the trainable parameters, as a similar nested structure. That cannot be controlled by keep, but would instead want children to be replaced by trainable. No idea whether loadmodel! will at present understand such a truncated tree, but it could be made to.

The current solution in loadmodel! is the filter keyword to ignore the truncated sub-tree, then manually handle it if necessary. Perhaps a simple filter for the use-case you describe is if some sentinel value is written at that node in the stored tree.

@CarloLucibello
Copy link
Member Author

CarloLucibello commented Apr 25, 2023

I'd like to settle for a minimal interface allowing robust serialization of the state (for whatever definition of state covers common needs such as the ones of Metalhead.jl).

If we allow anything except functions to be in the state (what this PR does right now) then we can have compatibility issues between julia versions (mostly due to rngs I guess, but maybe there are other examples):

# in julia 1.9

julia> s = Flux.state(Dropout(0.2))
(p = 0.2, dims = nothing, active = nothing, rng = Random.TaskLocalRNG())

julia> using BSON, JLD2

julia> BSON.@save "test.bson" s

julia> JLD2.jldsave("test.jld2"; s)
# now in julia 1,6
julia> using BSON, JLD2

julia> BSON.@load "test.bson" s
ERROR: UndefVarError: TaskLocalRNG not defined
Stacktrace:
  [1] (::BSON.var"#31#32")(m::Module, f::String)
    @ BSON ~/.julia/packages/BSON/DOYqe/src/extensions.jl:35
  ....

julia> s = JLD2.load("test.jld2", "s")
┌ Warning: type parameters for NamedTuple{(:p, :dims, :active, :rng),Tuple} do not match type NamedTuple in workspace; reconstructing
└ @ JLD2 ~/.julia/packages/JLD2/ryhNR/src/data/reconstructing_datatypes.jl:475
JLD2.ReconstructedTypes.var"##NamedTuple{(:p, :dims, :active, :rng),Tuple}#276"(0.2)

julia> s.p
0.2

julia> s.active
ERROR: type ##NamedTuple{(:p, :dims, :active, :rng),Tuple}#276 has no field active
Stacktrace:
 [1] getproperty(x::JLD2.ReconstructedTypes.var"##NamedTuple{(:p, :dims, :active, :rng),Tuple}#276", f::Symbol)
   @ Base ./Base.jl:33
 [2] top-level scope
   @ REPL[20]:1
  ....

What about exposing only a Flux.state_arrays(model) method for the time being, corresponding to keep = leaf -> leaf isa AbstractArray?

@CarloLucibello
Copy link
Member Author

So what I'm proposing is that this whole PR becomes

state_arrays(x) = fmapstructure(x -> x isa AbstractArray ? x : nothing, x)

@darsnack
Copy link
Member

Given what loadmodel! understands, I think your proposal is the correct thing to do right now. Minor change might be using Optimisers.numeric instead of checking just for arrays.

@ToucheSir
Copy link
Member

Skimming through some prior art and possible inspiration:

  • PyTorch has named_parameters(), which flattens the model tree but implicitly tracks structure in the names. Objax does something similiar with vars().
  • TF has an interesting combination of get_config() and get_weights(). The names are so-so, but I like the idea of having a way of accessing both even if we usually only save the latter.
  • The well-known JAX-based frameworks like Flax or Haiku already store parameters outside the callable model structure, but they generally use params or variables to refer to them (e.g.).

When it comes to serialization, I think our approach would be closest to how JAX encourages serializing PyTrees. A 1-1 match would require keeping around non-arrays (while excluding stuff like RNGs) and adding metadata tags for what type each node in the tree was created from. For now though, we can get away with far less metadata.

@CarloLucibello
Copy link
Member Author

CarloLucibello commented Apr 26, 2023

I think in pytorch the closest thing would be .dict_state()

For our state I would prefer an opt-in approach for the leaf types that get into the state. We can make the contract that future (non-breaking) releases can modify the returned state by allowing for more types to get in, but never less.

The implementation could be

state(x) = fmapstructure(x -> _keep(x) ? x : missing, x)

_keep(x::TOKEEP) = true
_keep(x) = false

Then it is not clear where we should set the bar. These are possible candidates

  1. TOKEEP = AbstractArray
  2. TOKEEP = Union{AbstractArray, Number}
  3. TOKEEP = Union{AbstractArray, Number, Nothing, AbstractString, Symbol}

In the future we can expose some interface for customization.

Maybe we start conservative with 1 or 2 and then see if 3 or more is desired?

@CarloLucibello
Copy link
Member Author

CarloLucibello commented Apr 28, 2023

In the end I went with TOKEEP = Union{AbstractArray, Number, Nothing, AbstractString, Symbol}, I don't see any harm in saving a few more simple types.

I also rewrote the "Saving and Loading" docs to encourage the serialization of the state and discourage the serialization of the model struct. In the new docs I emphasized JLD2 over BSON, since JLD2 is more actively maintained thanks to the work @JonasIsensee

@ToucheSir
Copy link
Member

I'm generally in favour of this now. Thoughts on JLD2 vs the Serialization stdlib? I know the former sometimes requires maintenance for new Julia versions, but the latter has its own issues.

@JonasIsensee
Copy link

Maybe this is a good moment for me to chime in. (also since I've been tagged the second time ;) )

From what I can tell, you've decided to largely restrict yourself to storing "plain" data such as Strings, Floats, Ints, and Arrays of such. In that case, JLD2 output matches regular HDF5 output. (Great, because you can use the inspection tools and load everything independently of julia)

Now, concerning Nothing, Symbol, Missing, etc. (and do you really have/need <:AbstractString that are not String ?).
These do not exist a priori in HDF5 but JLD2 will happily encode them. Given the stage of Julia development, I don't think these will ever break either. (except for the strings?)

Lastly, I would very much recommend building your own serializetojld2 function in order to unroll the different named fields
(citing from above)

julia> s = Flux.state(Dropout(0.2))
(p = 0.2, dims = nothing, active = nothing, rng = Random.TaskLocalRNG())

into separate jld2 datasets e.g.:

function writestatetojld2(fname, nt)
     jldopen(fname) do f
         wsession = JLD2.JLDWriteSession() # objid tracker for identity preservation
        for (k,v) in pairs(nt)
            write(f, string(k), v, wsession)
        end
     end
end

and an analog for loading.
While this is a bit more work on your side, it allows for much more graceful / gradual failing when some datasets (e.g. RNG states) fail to reconstruct.

@github-actions
Copy link
Contributor

Once the build has completed, you can preview any updated documentation at this URL: https://fluxml.ai/Flux.jl/previews/PR2239/ in ~20 minutes

Copy link
Member

@mcabbott mcabbott left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some minor comments below, but I think this basically looks great.

(and do you really have/need <:AbstractString that are not String ?)

This would be super-easy to standardise on String to be safe.

We could surely also convert all Symbols to String, and have loadmodel! convert back when it sees that the original model has a Symbol there.

Nothing, Symbol, Missing, etc.

This PR uses missing for unsaved fields. I think another obvious candidate would be (). Quick tests are that loadmodel! understands this fine. (This is what Optimisers.setup uses.)

At present nothing is meaningful as the default automatic test/train state of Dropout, etc. We could consider changing them all to something like mode = :auto / :train / :test. Converting so say a string "nothing" might be tricky for loadmodel! to reverse, as these mutable structs take Union{Nothing,Bool}, but perhaps someone else does Union{Nothing,String} or something.

src/loading.jl Outdated Show resolved Hide resolved
src/loading.jl Outdated Show resolved Hide resolved
src/loading.jl Outdated Show resolved Hide resolved
@ToucheSir
Copy link
Member

I don't think the plan is to depend on any serialization libraries in Flux itself. At best this would be handled by extensions or left to users.

Copy link
Member

@darsnack darsnack left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks good to go, but one minor point worth discussing since changing it would require a breaking release. I think this got brought up previously but went undiscussed.

NEWS.md Outdated Show resolved Hide resolved
src/loading.jl Outdated Show resolved Hide resolved
@CarloLucibello
Copy link
Member Author

Should be ready for merge

Copy link
Member

@darsnack darsnack left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a couple more tests, but looks ready

test/loading.jl Outdated Show resolved Hide resolved
test/loading.jl Outdated Show resolved Hide resolved
@mcabbott
Copy link
Member

mcabbott commented May 5, 2023

I'm a little unsure of this prune-the-tree idea. I agree it makes for a tidier state to look at.

But it adds complexity for anything consuming it: Instead of knowing that field names will always match, and some will be empty (::Missing or ::Tuple{}), now you need to also handle the field names being a subset. You can't use map(f, ::NamedTuple, ::NamedTuple), and I think you can't use fmap(f, model, state). Yet you are not free from knowing about a special token for empty spaces at all, since this is still needed for when x::Tuple.

And, in addition, accidentally using the wrong saved state is less likely to give you an informative error.

In the AD world, Zygote always makes a NamedTuple for all fields. ChainRules allows some to be missing via Tangent, and IMO this is a mistake, it adds complexity without much gain. (Maybe for Dict it makes sense to allow missing entries, as they aren't free.)

@darsnack
Copy link
Member

darsnack commented May 5, 2023

You are right, I had thought a missing key would be cleaner and the same as not recursing a node, but it is not. We have Functors.NoChildren which is exactly (). It may be ugly to use NoChildren for a missing leaf, but it is more robust than deleting the leaf.

Now I understand why the change to loadmodel! was necessary. I was scratching my head before, but the loop behaves differently for asymmetric keys. We spent a lot of time discussing that function and what should be considered correct. I feel like needing to change it to make the current state work is another indicator that () better matches the semantics we have already established.

darsnack
darsnack previously approved these changes May 5, 2023
@CarloLucibello
Copy link
Member Author

I reverted the change, so state is not dropping fields anymore, and () is for non-state leaves.

I preserved the change I made to loadmodel! though, I think it is good for it to be able to accept a subtree as a source.
If needed, we can add an option strict=true to it forcing all fields to be present in the source. Future PR material possibly.

@darsnack darsnack self-requested a review May 5, 2023 13:34
@darsnack darsnack dismissed their stale review May 5, 2023 13:34

Oops

@darsnack
Copy link
Member

darsnack commented May 5, 2023

Why not just put that change in the future PR that also adds strict? I can see both sides for accepting subtrees and not. The current loadmodel! forces you to use filter (albeit undocumented) to handle the subtree case. For saving/loading, I think it is better to be overly strict with tooling for handling edges. The filter keyword is a weak implementation of Orbax's transforms. Orbax might be guidance for us, because Jax/Flax most closely matches our approach to layers.

@CarloLucibello
Copy link
Member Author

I'm not using filter, filter could be removed in fact. I'm simply iterating over the source children

@CarloLucibello
Copy link
Member Author

merge like this?

@ToucheSir
Copy link
Member

Isn't filter still needed for Metalhead? I confess I haven't followed the implementation here very closely.

@darsnack
Copy link
Member

darsnack commented May 5, 2023

I am not saying your change is using filter. I am saying filter is an existing mechanism for handling a partial copy of a subtree by skipping it and doing it manually. All I'm suggesting is that the change to loadmodel! is no longer required to make this PR work, and that we should do it in another PR where there is time to discuss it. Isn't it breaking anyways? Something that previously intentionally errored does not error anymore.

@darsnack
Copy link
Member

darsnack commented May 5, 2023

filter was added for a use-case that came up in Metalhead, but once we start saving just the state like this PR (i.e. in 0.8), we won't need it anymore. But the original use-case has not disappeared. I do think filter should be re-hauled with something more sophisticated like Orbax's transforms. The way filter works is confusing, so I am all for talking about it as well as better handling of partial trees, but it just seems big enough to move to another PR.

@mcabbott
Copy link
Member

mcabbott commented May 5, 2023

From memory my confusion about filter was exactly the intersection with something like state -- if it goes by the type of the thing being skipped, then it may treat load!(m, m2) and load!(m, state(m2)) differently. There ought to be a super-nice way, but I agree better as another PR.

@CarloLucibello
Copy link
Member Author

restored previous behaviour for loadmodel!

@CarloLucibello CarloLucibello merged commit 5fe8ada into master May 5, 2023
Copy link
Member

@mcabbott mcabbott left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Late but...

Comment on lines -91 to +100
(keys(ldsts) == keys(lsrcs)) ||
throw(ArgumentError("Tried to load $(keys(lsrcs)) into $(keys(ldsts)) but the structures do not match."))

foreach(ldsts, lsrcs) do ldst, lsrc
keys_ldsts = keys(ldsts)
keys_lsrcs = keys(lsrcs)
collect(keys_ldsts) == collect(keys_lsrcs) || throw(ArgumentError("Tried to load $(keys_lsrcs) into $(keys_ldsts) but the structures do not match."))

for k in keys_lsrcs
lsrc, ldst = lsrcs[k], ldsts[k]
if ldst in cache # we already loaded this parameter before
_tie_check(ldst, lsrc) && return ldst
_tie_check(ldst, lsrc)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this change to collect?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants