Skip to content

Commit

Permalink
allow non-tuple data in the new train! (#2119)
Browse files Browse the repository at this point in the history
* allow non-tuple data

* cl/batchme

* add tests

* test multiple callback

* cleanup notes

* cleanup

* cleanup

* remove callbacks

* cleanup

* Update src/train.jl

Co-authored-by: Kyle Daruwalla <[email protected]>

Co-authored-by: Kyle Daruwalla <[email protected]>
  • Loading branch information
CarloLucibello and darsnack authored Nov 24, 2022
1 parent da8ce81 commit a5e5546
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 15 deletions.
1 change: 0 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Expand Down
19 changes: 8 additions & 11 deletions src/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,12 @@ end
Uses a `loss` function and training `data` to improve the `model`'s parameters
according to a particular optimisation rule `opt`. Iterates through `data` once,
evaluating `loss(model, d...)` for each `d` in data.
evaluating for each `d in data` either `loss(model, d...)` if `d isa Tuple`,
or else `loss(model, d)` for other `d`.
For example, with these definitions...
```
data = [(x1, y1), (x2, y2), (x3, y3)]; # each element must be a tuple
data = [(x1, y1), (x2, y2), (x3, y3)]
loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument
Expand All @@ -76,7 +77,7 @@ end
```
You can also write this loop yourself, if you need more flexibility.
For this reason `train!` is not highly extensible.
It adds only a few featurs to the loop above:
It adds only a few features to the loop above:
* Stop with a `DomainError` if the loss is infinite or `NaN` at any point.
Expand All @@ -88,9 +89,6 @@ It adds only a few featurs to the loop above:
(This is to move away from Zygote's "implicit" parameter handling, with `Grads`.)
* Instead of `loss` being a function which accepts only the data,
now it must also accept the `model` itself, as the first argument.
* `data` must iterate tuples, otherwise you get an error.
(Previously non-tuple types were not splatted into the loss.
Pass in `((d,) for d in data)` to simulate this.)
* `opt` should be the result of [`Flux.setup`](@ref). Using an optimiser
such as `Adam()` without this step should give you a warning.
* Callback functions are not supported.
Expand All @@ -100,9 +98,8 @@ function train!(loss, model, data, opt; cb = nothing)
isnothing(cb) || error("""train! does not support callback functions.
For more control use a loop with `gradient` and `update!`.""")
@withprogress for (i,d) in enumerate(data)
d isa Tuple || error("""train! expects as data an iterator producing tuples, but got $(typeof(d)).
Pass it `((d,) for d in data)`, or use `gradient` and `update!` for more control.""")
l, gs = Zygote.withgradient(m -> loss(m, d...), model)
d_splat = d isa Tuple ? d : (d,)
l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model)
if !isfinite(l)
throw(DomainError("Loss is $l on data item $i, stopping training"))
end
Expand All @@ -112,8 +109,8 @@ function train!(loss, model, data, opt; cb = nothing)
end

# This method lets you use Optimisers.Descent() without setup, when there is no state
function train!(loss, model, data, rule::Optimisers.AbstractRule)
train!(loss, model, data, _rule_to_state(model, rule))
function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing)
train!(loss, model, data, _rule_to_state(model, rule); cb)
end

function _rule_to_state(model, rule::Optimisers.AbstractRule)
Expand Down
11 changes: 8 additions & 3 deletions test/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,15 @@ end
@test CNT == 51 # stopped early
@test m1.weight[1] ≈ -5 # did not corrupt weights
end
@testset "data must give tuples" begin
m1 = Dense(1 => 1)
@test_throws ErrorException Flux.train!((args...,) -> 1, m1, [(x=1, y=2) for _ in 1:3], Descent(0.1))

@testset "non-tuple data" begin
loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
model = (weight=copy(w2), bias=zeros(10))
opt = Flux.setup(AdamW(), model)
Flux.train!(loss, model, (rand(10) for _ in 1: 10^5), opt)
@test loss(model, rand(10, 10)) < 0.01
end

@testset "callbacks give helpful error" begin
m1 = Dense(1 => 1)
cb = () -> println("this should not be printed")
Expand Down

0 comments on commit a5e5546

Please sign in to comment.