From 2cc0305bcda76eee1fc2603b8aa2a0b701cf3650 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 21 Nov 2022 08:01:38 +0100 Subject: [PATCH 01/10] allow non-tuple data --- Project.toml | 1 - src/train.jl | 18 ++++++++++-------- test/train.jl | 8 ++++++++ 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 84e20d8e9c..a01cab0f9f 100644 --- a/Project.toml +++ b/Project.toml @@ -21,7 +21,6 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] diff --git a/src/train.jl b/src/train.jl index 919821b710..2fb9914005 100644 --- a/src/train.jl +++ b/src/train.jl @@ -56,11 +56,12 @@ end Uses a `loss` function and training `data` to improve the `model`'s parameters according to a particular optimisation rule `opt`. Iterates through `data` once, -evaluating `loss(model, d...)` for each `d` in data. +evaluating `loss(model, d...)` for each `d` in `data` in case of tuple iterates, +and `loss(model, d)` otherwise. For example, with these definitions... ``` -data = [(x1, y1), (x2, y2), (x3, y3)]; # each element must be a tuple +data = [(x1, y1), (x2, y2), (x3, y3)] loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument @@ -76,7 +77,7 @@ end ``` You can also write this loop yourself, if you need more flexibility. For this reason `train!` is not highly extensible. -It adds only a few featurs to the loop above: +It adds only a few features to the loop above: * Stop with a `DomainError` if the loss is infinite or `NaN` at any point. @@ -88,9 +89,8 @@ It adds only a few featurs to the loop above: (This is to move away from Zygote's "implicit" parameter handling, with `Grads`.) * Instead of `loss` being a function which accepts only the data, now it must also accept the `model` itself, as the first argument. - * `data` must iterate tuples, otherwise you get an error. - (Previously non-tuple types were not splatted into the loss. - Pass in `((d,) for d in data)` to simulate this.) + * If `data` iterates over tuples, these will be splatted when passed to `loss`. + If you want to avoid the splatting, you can pass `((d,) for d in data)` instead. * `opt` should be the result of [`Flux.setup`](@ref). Using an optimiser such as `Adam()` without this step should give you a warning. * Callback functions are not supported. @@ -100,8 +100,7 @@ function train!(loss, model, data, opt; cb = nothing) isnothing(cb) || error("""train! does not support callback functions. For more control use a loop with `gradient` and `update!`.""") @withprogress for (i,d) in enumerate(data) - d isa Tuple || error("""train! expects as data an iterator producing tuples, but got $(typeof(d)). - Pass it `((d,) for d in data)`, or use `gradient` and `update!` for more control.""") + d = batchmemaybe(d) l, gs = Zygote.withgradient(m -> loss(m, d...), model) if !isfinite(l) throw(DomainError("Loss is $l on data item $i, stopping training")) @@ -116,6 +115,9 @@ function train!(loss, model, data, rule::Optimisers.AbstractRule) train!(loss, model, data, _rule_to_state(model, rule)) end +batchmemaybe(x) = tuple(x) +batchmemaybe(x::Tuple) = x + function _rule_to_state(model, rule::Optimisers.AbstractRule) state = setup(rule, model) @gensym warn_id diff --git a/test/train.jl b/test/train.jl index 49ecf9c751..8b3f87526c 100644 --- a/test/train.jl +++ b/test/train.jl @@ -30,6 +30,14 @@ using Random Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) @test loss(model, rand(10, 10)) < 0.01 end + + @testset "non-tuple data" begin + loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + model = (weight=copy(w2), bias=zeros(10)) + opt = Flux.setup(AdamW(), model) + Flux.train!(loss, model, (rand(10) for _ in 1: 10^5), opt) + @test loss(model, rand(10, 10)) < 0.01 + end end @testset "Explicit Flux.train! features" begin From c5312f422ef0513d9b9e66fdf71e9eafd4501924 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 21 Nov 2022 08:49:10 +0100 Subject: [PATCH 02/10] cl/batchme --- src/train.jl | 36 ++++++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/src/train.jl b/src/train.jl index 2fb9914005..87c70aa4dc 100644 --- a/src/train.jl +++ b/src/train.jl @@ -52,7 +52,7 @@ function setup(rule::Optimisers.AbstractRule, model) end """ - train!(loss, model, data, opt) + train!(loss, model, data, opt; cb=nothing) Uses a `loss` function and training `data` to improve the `model`'s parameters according to a particular optimisation rule `opt`. Iterates through `data` once, @@ -97,13 +97,14 @@ It adds only a few features to the loop above: But any code can be included in the above `for` loop. """ function train!(loss, model, data, opt; cb = nothing) - isnothing(cb) || error("""train! does not support callback functions. - For more control use a loop with `gradient` and `update!`.""") + cb = old_cb_deprecation(cb) + cb = runall(cb) @withprogress for (i,d) in enumerate(data) - d = batchmemaybe(d) - l, gs = Zygote.withgradient(m -> loss(m, d...), model) + ds = batchmemaybe(d) + l, gs = Zygote.withgradient(m -> loss(m, ds...), model) + cb((; model, data=d, opt, step=i, loss=l, gradient=gs[1])) if !isfinite(l) - throw(DomainError("Loss is $l on data item $i, stopping training")) + throw(DomainError("Loss is $(l) on data item $i, stopping training")) end opt, model = Optimisers.update!(opt, model, gs[1]) @logprogress Base.haslength(data) ? i/length(data) : nothing @@ -111,13 +112,32 @@ function train!(loss, model, data, opt; cb = nothing) end # This method let you use Optimisers.Descent() without setup, when there is no state -function train!(loss, model, data, rule::Optimisers.AbstractRule) - train!(loss, model, data, _rule_to_state(model, rule)) +function train!(loss, model, data, rule::Optimisers.AbstractRule; cb=nothing) + train!(loss, model, data, _rule_to_state(model, rule); cb) end batchmemaybe(x) = tuple(x) batchmemaybe(x::Tuple) = x +call(f, xs...) = f(xs...) +runall(f) = f +runall(fs::AbstractVector) = x -> foreach(call, fs, x) + +old_cb_deprecation(f::AbstractVector) = [old_cb_deprecation(f) for f in f] + +function old_cb_deprecation(f) + try + f(x) + catch e + if e isa MethodError + @warn "Callback functions must accept a named tuple argument. See the docs for `train!`." + f() + else + rethrow(e) + end + end +end + function _rule_to_state(model, rule::Optimisers.AbstractRule) state = setup(rule, model) @gensym warn_id From 63967dc37417d85771162f511e60b1d309941f36 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 21 Nov 2022 15:37:23 +0100 Subject: [PATCH 03/10] add tests --- src/train.jl | 24 +++++++++++++----------- test/train.jl | 30 +++++++++++++++++++++++++----- 2 files changed, 38 insertions(+), 16 deletions(-) diff --git a/src/train.jl b/src/train.jl index 87c70aa4dc..4d57148f8d 100644 --- a/src/train.jl +++ b/src/train.jl @@ -52,7 +52,7 @@ function setup(rule::Optimisers.AbstractRule, model) end """ - train!(loss, model, data, opt; cb=nothing) + train!(loss, model, data, opt; [cb]) Uses a `loss` function and training `data` to improve the `model`'s parameters according to a particular optimisation rule `opt`. Iterates through `data` once, @@ -96,7 +96,7 @@ It adds only a few features to the loop above: * Callback functions are not supported. But any code can be included in the above `for` loop. """ -function train!(loss, model, data, opt; cb = nothing) +function train!(loss, model, data, opt; cb = x -> nothing) cb = old_cb_deprecation(cb) cb = runall(cb) @withprogress for (i,d) in enumerate(data) @@ -112,7 +112,7 @@ function train!(loss, model, data, opt; cb = nothing) end # This method let you use Optimisers.Descent() without setup, when there is no state -function train!(loss, model, data, rule::Optimisers.AbstractRule; cb=nothing) +function train!(loss, model, data, rule::Optimisers.AbstractRule; cb=x->nothing) train!(loss, model, data, _rule_to_state(model, rule); cb) end @@ -126,14 +126,16 @@ runall(fs::AbstractVector) = x -> foreach(call, fs, x) old_cb_deprecation(f::AbstractVector) = [old_cb_deprecation(f) for f in f] function old_cb_deprecation(f) - try - f(x) - catch e - if e isa MethodError - @warn "Callback functions must accept a named tuple argument. See the docs for `train!`." - f() - else - rethrow(e) + return x -> begin + try + f(x) + catch e + if e isa MethodError + @warn "Callback functions must accept a named tuple argument. See the docs for `train!`." + f() + else + rethrow(e) + end end end end diff --git a/test/train.jl b/test/train.jl index 8b3f87526c..900e95539b 100644 --- a/test/train.jl +++ b/test/train.jl @@ -52,14 +52,34 @@ end @test CNT == 51 # stopped early @test m1.weight[1] ≈ -5 # did not corrupt weights end - @testset "data must give tuples" begin + + @testset "deprecated callback style" begin m1 = Dense(1 => 1) - @test_throws ErrorException Flux.train!((args...,) -> 1, m1, [(x=1, y=2) for _ in 1:3], Descent(0.1)) + cb = () -> println("this should not be printed") + Flux.train!((args...,) -> 1, m1, [(1,2)], Descent(0.1); cb) end - @testset "callbacks give helpful error" begin + + + @testset "callback" begin m1 = Dense(1 => 1) - cb = () -> println("this should not be printed") - @test_throws ErrorException Flux.train!((args...,) -> 1, m1, [(1,2)], Descent(0.1); cb) + i = 0 + data = [rand(1) for _ in 1:5] + res = [] + cb = x -> push!(res, x) + opt = Flux.setup(AdamW(), m1) + Flux.train!((m, x) -> sum(m(x)), m1, data, opt; cb) + + @test length(res) == length(data) + for (i,x) in enumerate(res) + @test x isa NamedTuple + @test x.step == i + @test haskey(x, :loss) + @test x.gradient.weight isa Matrix + @test x.gradient.bias isa Vector + @test x.model === m1 + @test haskey(x, :data) + @test x.opt === opt + end end end From 22f84cec17c954d48f22fee06f1d851d71d2d565 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 21 Nov 2022 16:00:38 +0100 Subject: [PATCH 04/10] test multiple callback --- src/train.jl | 35 ++++++++--------------------------- test/train.jl | 16 +++++++++++++--- 2 files changed, 21 insertions(+), 30 deletions(-) diff --git a/src/train.jl b/src/train.jl index 4d57148f8d..67fa5708c3 100644 --- a/src/train.jl +++ b/src/train.jl @@ -56,8 +56,8 @@ end Uses a `loss` function and training `data` to improve the `model`'s parameters according to a particular optimisation rule `opt`. Iterates through `data` once, -evaluating `loss(model, d...)` for each `d` in `data` in case of tuple iterates, -and `loss(model, d)` otherwise. +evaluating for each `d in data` either `loss(model, d...)` if `d isa Tuple`, +or else `loss(model, d)` for other `d`. For example, with these definitions... ``` @@ -97,11 +97,11 @@ It adds only a few features to the loop above: But any code can be included in the above `for` loop. """ function train!(loss, model, data, opt; cb = x -> nothing) - cb = old_cb_deprecation(cb) cb = runall(cb) + @show cb @withprogress for (i,d) in enumerate(data) - ds = batchmemaybe(d) - l, gs = Zygote.withgradient(m -> loss(m, ds...), model) + d_splat = d isa Tuple ? d : (d,) + l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model) cb((; model, data=d, opt, step=i, loss=l, gradient=gs[1])) if !isfinite(l) throw(DomainError("Loss is $(l) on data item $i, stopping training")) @@ -112,33 +112,14 @@ function train!(loss, model, data, opt; cb = x -> nothing) end # This method let you use Optimisers.Descent() without setup, when there is no state -function train!(loss, model, data, rule::Optimisers.AbstractRule; cb=x->nothing) +function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = x -> nothing) train!(loss, model, data, _rule_to_state(model, rule); cb) end -batchmemaybe(x) = tuple(x) -batchmemaybe(x::Tuple) = x - call(f, xs...) = f(xs...) runall(f) = f -runall(fs::AbstractVector) = x -> foreach(call, fs, x) - -old_cb_deprecation(f::AbstractVector) = [old_cb_deprecation(f) for f in f] - -function old_cb_deprecation(f) - return x -> begin - try - f(x) - catch e - if e isa MethodError - @warn "Callback functions must accept a named tuple argument. See the docs for `train!`." - f() - else - rethrow(e) - end - end - end -end + +runall(fs::AbstractVector) = x -> [f(x) for f in fs] function _rule_to_state(model, rule::Optimisers.AbstractRule) state = setup(rule, model) diff --git a/test/train.jl b/test/train.jl index 900e95539b..21ae2a224a 100644 --- a/test/train.jl +++ b/test/train.jl @@ -59,10 +59,8 @@ end Flux.train!((args...,) -> 1, m1, [(1,2)], Descent(0.1); cb) end - - @testset "callback" begin + @testset "single callback" begin m1 = Dense(1 => 1) - i = 0 data = [rand(1) for _ in 1:5] res = [] cb = x -> push!(res, x) @@ -81,6 +79,18 @@ end @test x.opt === opt end end + + @testset "multiple callbacks" begin + m1 = Dense(1 => 1) + i1, i2 = 0, 0 + data = [rand(1) for _ in 1:5] + cb1 = res -> i1 += 1 + cb2 = res -> i2 += res.step + opt = Flux.setup(AdamW(), m1) + Flux.train!((m, x) -> sum(m(x)), m1, data, opt; cb = [cb1, cb2]) + @test i1 == length(data) + @test i2 == sum(1:length(data)) + end end @testset "Explicit Flux.update! features" begin From 34b2ab7a81c1a4579219063dcd355953f9a69b73 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 21 Nov 2022 16:02:53 +0100 Subject: [PATCH 05/10] cleanup notes --- src/train.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/train.jl b/src/train.jl index 67fa5708c3..0fa5525ab8 100644 --- a/src/train.jl +++ b/src/train.jl @@ -89,12 +89,9 @@ It adds only a few features to the loop above: (This is to move away from Zygote's "implicit" parameter handling, with `Grads`.) * Instead of `loss` being a function which accepts only the data, now it must also accept the `model` itself, as the first argument. - * If `data` iterates over tuples, these will be splatted when passed to `loss`. - If you want to avoid the splatting, you can pass `((d,) for d in data)` instead. * `opt` should be the result of [`Flux.setup`](@ref). Using an optimiser such as `Adam()` without this step should give you a warning. - * Callback functions are not supported. - But any code can be included in the above `for` loop. + * Callback functions now receive a named tuple as input. """ function train!(loss, model, data, opt; cb = x -> nothing) cb = runall(cb) From 1101f53dda4df5b0a3c68307105662568bae24a8 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 21 Nov 2022 16:08:14 +0100 Subject: [PATCH 06/10] cleanup --- test/train.jl | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/test/train.jl b/test/train.jl index 21ae2a224a..176b628922 100644 --- a/test/train.jl +++ b/test/train.jl @@ -51,12 +51,7 @@ end end @test CNT == 51 # stopped early @test m1.weight[1] ≈ -5 # did not corrupt weights - end - - @testset "deprecated callback style" begin - m1 = Dense(1 => 1) - cb = () -> println("this should not be printed") - Flux.train!((args...,) -> 1, m1, [(1,2)], Descent(0.1); cb) + end @testset "single callback" begin @@ -91,6 +86,7 @@ end @test i1 == length(data) @test i2 == sum(1:length(data)) end + end @testset "Explicit Flux.update! features" begin From 3759e9e8752e5a46da39e82cbab3a160af572c48 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 21 Nov 2022 16:45:09 +0100 Subject: [PATCH 07/10] cleanup --- src/train.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/train.jl b/src/train.jl index 0fa5525ab8..0a9ed7339c 100644 --- a/src/train.jl +++ b/src/train.jl @@ -95,7 +95,6 @@ It adds only a few features to the loop above: """ function train!(loss, model, data, opt; cb = x -> nothing) cb = runall(cb) - @show cb @withprogress for (i,d) in enumerate(data) d_splat = d isa Tuple ? d : (d,) l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model) From a08785e4026fcc46fba8d9f6ee9965ffba3fc67d Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 22 Nov 2022 06:45:46 +0100 Subject: [PATCH 08/10] remove callbacks --- src/train.jl | 20 ++++++++------------ test/train.jl | 34 ---------------------------------- 2 files changed, 8 insertions(+), 46 deletions(-) diff --git a/src/train.jl b/src/train.jl index 0a9ed7339c..885365da2b 100644 --- a/src/train.jl +++ b/src/train.jl @@ -52,7 +52,7 @@ function setup(rule::Optimisers.AbstractRule, model) end """ - train!(loss, model, data, opt; [cb]) + train!(loss, model, data, opt) Uses a `loss` function and training `data` to improve the `model`'s parameters according to a particular optimisation rule `opt`. Iterates through `data` once, @@ -91,16 +91,17 @@ It adds only a few features to the loop above: now it must also accept the `model` itself, as the first argument. * `opt` should be the result of [`Flux.setup`](@ref). Using an optimiser such as `Adam()` without this step should give you a warning. - * Callback functions now receive a named tuple as input. + * Callback functions are not supported. + But any code can be included in the above `for` loop. """ -function train!(loss, model, data, opt; cb = x -> nothing) - cb = runall(cb) +function train!(loss, model, data, opt; cb = nothing) + isnothing(cb) || error("""train! does not support callback functions. + For more control use a loop with `gradient` and `update!`.""") @withprogress for (i,d) in enumerate(data) d_splat = d isa Tuple ? d : (d,) l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model) - cb((; model, data=d, opt, step=i, loss=l, gradient=gs[1])) if !isfinite(l) - throw(DomainError("Loss is $(l) on data item $i, stopping training")) + throw(DomainError("Loss is $l on data item $i, stopping training")) end opt, model = Optimisers.update!(opt, model, gs[1]) @logprogress Base.haslength(data) ? i/length(data) : nothing @@ -108,15 +109,10 @@ function train!(loss, model, data, opt; cb = x -> nothing) end # This method let you use Optimisers.Descent() without setup, when there is no state -function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = x -> nothing) +function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing) train!(loss, model, data, _rule_to_state(model, rule); cb) end -call(f, xs...) = f(xs...) -runall(f) = f - -runall(fs::AbstractVector) = x -> [f(x) for f in fs] - function _rule_to_state(model, rule::Optimisers.AbstractRule) state = setup(rule, model) @gensym warn_id diff --git a/test/train.jl b/test/train.jl index 176b628922..083c34d6f3 100644 --- a/test/train.jl +++ b/test/train.jl @@ -53,40 +53,6 @@ end @test m1.weight[1] ≈ -5 # did not corrupt weights end - - @testset "single callback" begin - m1 = Dense(1 => 1) - data = [rand(1) for _ in 1:5] - res = [] - cb = x -> push!(res, x) - opt = Flux.setup(AdamW(), m1) - Flux.train!((m, x) -> sum(m(x)), m1, data, opt; cb) - - @test length(res) == length(data) - for (i,x) in enumerate(res) - @test x isa NamedTuple - @test x.step == i - @test haskey(x, :loss) - @test x.gradient.weight isa Matrix - @test x.gradient.bias isa Vector - @test x.model === m1 - @test haskey(x, :data) - @test x.opt === opt - end - end - - @testset "multiple callbacks" begin - m1 = Dense(1 => 1) - i1, i2 = 0, 0 - data = [rand(1) for _ in 1:5] - cb1 = res -> i1 += 1 - cb2 = res -> i2 += res.step - opt = Flux.setup(AdamW(), m1) - Flux.train!((m, x) -> sum(m(x)), m1, data, opt; cb = [cb1, cb2]) - @test i1 == length(data) - @test i2 == sum(1:length(data)) - end - end @testset "Explicit Flux.update! features" begin From 8e6e9ac2020acf43c4d4c554a14d88966b1f987e Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 22 Nov 2022 06:47:46 +0100 Subject: [PATCH 09/10] cleanup --- src/train.jl | 2 +- test/train.jl | 23 ++++++++++++++--------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/train.jl b/src/train.jl index 885365da2b..7bb74f04e4 100644 --- a/src/train.jl +++ b/src/train.jl @@ -61,7 +61,7 @@ or else `loss(model, d)` for other `d`. For example, with these definitions... ``` -data = [(x1, y1), (x2, y2), (x3, y3)] +data = [(x1, y1), (x2, y2), (x3, y3)] loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument diff --git a/test/train.jl b/test/train.jl index 083c34d6f3..cfadde7d9b 100644 --- a/test/train.jl +++ b/test/train.jl @@ -30,14 +30,6 @@ using Random Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) @test loss(model, rand(10, 10)) < 0.01 end - - @testset "non-tuple data" begin - loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) - model = (weight=copy(w2), bias=zeros(10)) - opt = Flux.setup(AdamW(), model) - Flux.train!(loss, model, (rand(10) for _ in 1: 10^5), opt) - @test loss(model, rand(10, 10)) < 0.01 - end end @testset "Explicit Flux.train! features" begin @@ -51,7 +43,20 @@ end end @test CNT == 51 # stopped early @test m1.weight[1] ≈ -5 # did not corrupt weights - + end + + @testset "non-tuple data" begin + loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) + model = (weight=copy(w2), bias=zeros(10)) + opt = Flux.setup(AdamW(), model) + Flux.train!(loss, model, (rand(10) for _ in 1: 10^5), opt) + @test loss(model, rand(10, 10)) < 0.01 + end + + @testset "callbacks give helpful error" begin + m1 = Dense(1 => 1) + cb = () -> println("this should not be printed") + @test_throws ErrorException Flux.train!((args...,) -> 1, m1, [(1,2)], Descent(0.1); cb) end end From b8b3d8aa8016a869bbee025cdc86c757a43b3015 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 24 Nov 2022 15:51:34 +0100 Subject: [PATCH 10/10] Update src/train.jl Co-authored-by: Kyle Daruwalla --- src/train.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/train.jl b/src/train.jl index 7bb74f04e4..d548e0ac02 100644 --- a/src/train.jl +++ b/src/train.jl @@ -57,7 +57,7 @@ end Uses a `loss` function and training `data` to improve the `model`'s parameters according to a particular optimisation rule `opt`. Iterates through `data` once, evaluating for each `d in data` either `loss(model, d...)` if `d isa Tuple`, -or else `loss(model, d)` for other `d`. +or else `loss(model, d)` for other `d`. For example, with these definitions... ```