From 2c67ab991a0c7c9774da5259c928b318e83f5959 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Wed, 26 Jul 2023 09:16:36 +0200
Subject: [PATCH 1/9] mixed precision

---
 src/Optimisers.jl |  2 +-
 src/rules.jl      | 22 ++++++++++++++++++++++
 test/rules.jl     | 19 ++++++++++++++++++-
 3 files changed, 41 insertions(+), 2 deletions(-)

diff --git a/src/Optimisers.jl b/src/Optimisers.jl
index 20fc8aad..ecd6d0bc 100644
--- a/src/Optimisers.jl
+++ b/src/Optimisers.jl
@@ -15,7 +15,7 @@ include("rules.jl")
 export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
        AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
        WeightDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
-       AccumGrad
+       AccumGrad, MixedPrecision
 
 ###
 ### one-array functions
diff --git a/src/rules.jl b/src/rules.jl
index e994b740..99d1090d 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -752,3 +752,25 @@ function apply!(o::AccumGrad, state, x, dx)
     return (accum_dx, counter + 1), nothing
   end
 end
+
+struct MixedPrecision{T<:Number, O<:AbstractRule} <: AbstractRule
+  opt::O
+end
+
+MixedPrecision(opt::AbstractRule) = MixedPrecision{Float32, typeof(opt)}(opt)
+MixedPrecision{T}(opt::AbstractRule) where T = MixedPrecision{T, typeof(opt)}(opt)
+
+to_precision(::Type{T}, x::AbstractArray) where T = convert(AbstractArray{T}, x)
+
+function init(o::MixedPrecision{T}, x) where T
+  xT = to_precision(T, x)
+  return (xT, init(o.opt, xT))
+end
+
+function apply!(o::MixedPrecision{T}, state, x, dx) where T
+  xT, st = state
+  st′, dx′ = apply!(o.opt, st, xT, dx)
+  subtract!(xT, dx′)
+  @. x = xT
+  return (xT, st′), nothing
+end
diff --git a/test/rules.jl b/test/rules.jl
index a10e055f..c5c9b94f 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -9,6 +9,7 @@ RULES = [
   Descent(), Adam(), Momentum(), Nesterov(), Rprop(), RMSProp(),
   AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(), AdamW(), RAdam(), OAdam(), AdaBelief(),
   Lion(),
+  MixedPrecision{Float64}(Adam()),
   # A few chained combinations:
   OptimiserChain(WeightDecay(), Adam(0.001)),
   OptimiserChain(ClipNorm(), Adam(0.001)),
@@ -266,4 +267,20 @@ end
 
   tree, x4 = Optimisers.update(tree, x3, g4)
   @test x4 ≈ x3
-end
\ No newline at end of file
+end
+
+@testset "MixedPrecision" begin
+  x = rand(Float16, 2)
+  opt_state = Optimisers.setup(MixedPrecision(Adam(1e-3)), x)
+  @test opt_state.state[1] isa Vector{Float32}
+  @test opt_state.state[2][1] isa Vector{Float32}
+  g = rand(Float16, 2)
+  new_state, new_x = Optimisers.update(opt_state, x, g)
+  @test new_x == Float16.(new_state.state[1])
+  @test new_x ≈ x .- 1e-3 .* g
+
+  x = rand(Float16, 2)
+  opt_state = Optimisers.setup(MixedPrecision{Float64}(Adam(1e-3)), x)
+  @test opt_state.state[1] isa Vector{Float64}
+  @test opt_state.state[2][1] isa Vector{Float64}
+end

From fcc163a7489f8e244d4a7ecd4861f6af6a7e77b9 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Wed, 26 Jul 2023 09:39:32 +0200
Subject: [PATCH 2/9] docs

---
 docs/src/api.md |  3 ++-
 src/rules.jl    | 35 +++++++++++++++++++++++++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/docs/src/api.md b/docs/src/api.md
index 6c021f25..feeaef63 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -25,8 +25,9 @@ In addition to the main course, you may wish to order some of these condiments:
 Optimisers.AccumGrad
 Optimisers.ClipGrad
 Optimisers.ClipNorm
-Optimisers.WeightDecay
+Optimisers.MixedPrecision
 Optimisers.OptimiserChain
+Optimisers.WeightDecay
 ```
 
 ## Model Interface
diff --git a/src/rules.jl b/src/rules.jl
index 99d1090d..f0fb21e4 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -753,6 +753,41 @@ function apply!(o::AccumGrad, state, x, dx)
   end
 end
 
+"""
+    MixedPrecision(opt)
+    MixedPrecision{T}(opt)
+
+An optimiser that wraps another optimiser `opt` in order to perform mixed precision
+training [1].
+
+The state of `MixedPrecision` will contain a copy in precision `T` of the trainable parameter `x`,
+call it `xT`.
+The internal state of `opt` also operates at precision `T`.
+If `T` is not specified, it defaults to `Float32`.
+
+Call `g` the gradient of `x`. Both `g` and `x` are typically in a precision lower than `T`
+(e.g. `Float16`).
+
+In the `update!(opt_state, x, g)` call, `opt` is used to update `xT` instead of `x`,
+then `x` is updated with the value of `xT`.
+
+[1] Micikevicius et al. '17, "Mixed Precision Training", https://arxiv.org/abs/1710.03740 .
+
+# Examples
+
+```julia
+x = rand(Float16, 2) # a trainable parameter in low precision
+
+opt = MixedPrecision(Adam(1e-3))
+opt_state = Optimisers.setup(opt, x) # the state contains a copy of x in Float32 precision
+
+g = rand(Float16, 2) # a gradient in low precision
+
+# accumulation is performed in high precision,
+# then also the low precision x is synced
+Optimisers.update!(opt_state, x, g)
+```
+"""
 struct MixedPrecision{T<:Number, O<:AbstractRule} <: AbstractRule
   opt::O
 end

From bb87a66f1122e036867f56401011b829e656eef3 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Wed, 26 Jul 2023 11:00:15 +0200
Subject: [PATCH 3/9] handle non-writeable

---
 src/rules.jl | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/src/rules.jl b/src/rules.jl
index f0fb21e4..3ed7d4b1 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -795,17 +795,20 @@ end
 MixedPrecision(opt::AbstractRule) = MixedPrecision{Float32, typeof(opt)}(opt)
 MixedPrecision{T}(opt::AbstractRule) where T = MixedPrecision{T, typeof(opt)}(opt)
 
-to_precision(::Type{T}, x::AbstractArray) where T = convert(AbstractArray{T}, x)
-
-function init(o::MixedPrecision{T}, x) where T
-  xT = to_precision(T, x)
+function init(o::MixedPrecision{T}, x::AbstractArray) where T
+  xT = T.(x)
   return (xT, init(o.opt, xT))
 end
 
 function apply!(o::MixedPrecision{T}, state, x, dx) where T
   xT, st = state
   st′, dx′ = apply!(o.opt, st, xT, dx)
-  subtract!(xT, dx′)
-  @. x = xT
-  return (xT, st′), nothing
+  xT = subtract!(xT, dx′)
+  if maywrite(x)
+    x .= xT
+    dx′ = nothing
+  else
+    dx′ = x .- eltype(x).(xT)
+  end
+  return (xT, st′), dx′
 end

From 6885d28b8181a645b8df54af5369b9cf4f2079e1 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Thu, 27 Jul 2023 01:20:03 +0200
Subject: [PATCH 4/9] adjust

---
 src/rules.jl     |  5 +++++
 test/runtests.jl | 14 ++++++++++++++
 2 files changed, 19 insertions(+)

diff --git a/src/rules.jl b/src/rules.jl
index 3ed7d4b1..563efe21 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -792,6 +792,8 @@ struct MixedPrecision{T<:Number, O<:AbstractRule} <: AbstractRule
   opt::O
 end
 
+@functor MixedPrecision
+
 MixedPrecision(opt::AbstractRule) = MixedPrecision{Float32, typeof(opt)}(opt)
 MixedPrecision{T}(opt::AbstractRule) where T = MixedPrecision{T, typeof(opt)}(opt)
 
@@ -812,3 +814,6 @@ function apply!(o::MixedPrecision{T}, state, x, dx) where T
   end
   return (xT, st′), dx′
 end
+
+adjust(o::MixedPrecision, eta::Real) = MixedPrecision(adjust(o.opt, eta))
+adjust(o::MixedPrecision; kw...) = MixedPrecision(adjust(o.opt; kw...))
diff --git a/test/runtests.jl b/test/runtests.jl
index 4e02f4d0..8f713d94 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -302,6 +302,20 @@ y2z(x) = x
   @test sc1.γ.rule.opts[1].delta == 2.5
   @test sc1.γ.rule.opts[2].eta === 0.2f0  # unchanged
   @test sc1.γ.state[2][1] ≈ [0.1, 0.2, 0.2]
+
+  # MixedPrecision
+  mp = Optimisers.setup(MixedPrecision(Momentum(0.1, 0.9)), m)
+  mp1, mp2 = Optimisers.update(mp, m, (α = nothing, γ = [1,10,100],))
+  @test mp1.γ.rule.opt.eta == 0.1
+  @test mp1.γ.state[2] ≈ [0.1, 1, 10]
+
+  Optimisers.adjust!(mp1, 0.2)
+  @test mp1.γ.rule.opt.eta == 0.2
+  @test mp1.γ.rule.opt.rho == 0.9
+
+  Optimisers.adjust!(mp1; eta=0.3, rho=0.7)
+  @test mp1.γ.rule.opt.eta == 0.3
+  @test mp1.γ.rule.opt.rho == 0.7
 end
 
 @testset "freeze/thaw" begin

From fb9ad29a3294cd2052d3239b074d1550baf6ce7b Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Thu, 27 Jul 2023 07:43:42 +0200
Subject: [PATCH 5/9] more tests

---
 test/runtests.jl | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/test/runtests.jl b/test/runtests.jl
index 8f713d94..62fac535 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -258,6 +258,20 @@ y2z(x) = x
   @test sc2.γ.rule.opts[1].delta == 2.5
   @test sc2.γ.rule.opts[2].eta === 0.001f0  # unchanged
   @test sc2.γ.state[2][1] ≈ [0.1, 0.2, 0.2]
+
+  # MixedPrecision
+  mp = Optimisers.setup(MixedPrecision(Momentum(0.1, 0.9)), m)
+  mp1, mp2 = Optimisers.update(mp, m, (α = nothing, γ = [1,10,100],))
+  @test mp1.γ.rule.opt.eta == 0.1
+  @test mp1.γ.state[2] ≈ [0.1, 1, 10]
+
+  mp2 = Optimisers.adjust(mp1, 0.2)
+  @test mp2.γ.rule.opt.eta == 0.2
+  @test mp2.γ.rule.opt.rho == 0.9
+
+  mp3 = Optimisers.adjust(mp1; eta=0.3, rho=0.7)
+  @test mp3.γ.rule.opt.eta == 0.3
+  @test mp3.γ.rule.opt.rho == 0.7
 end
 
 @testset "adjusting parameters, in-place" begin

From 84704b2ca7ed248450a2a5f43840cc2cf021ff2c Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Sun, 30 Jul 2023 07:26:20 -0500
Subject: [PATCH 6/9] Update src/rules.jl

Co-authored-by: Kyle Daruwalla
---
 src/rules.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/rules.jl b/src/rules.jl
index 563efe21..545b2a63 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -808,7 +808,7 @@ function apply!(o::MixedPrecision{T}, state, x, dx) where T
   xT = subtract!(xT, dx′)
   if maywrite(x)
     x .= xT
-    dx′ = nothing
+    dx′ = Zero()
   else
     dx′ = x .- eltype(x).(xT)
   end

From caadc52c47f876f7273e103afe39190cfc7c017a Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Sun, 30 Jul 2023 07:26:33 -0500
Subject: [PATCH 7/9] Update src/rules.jl

Co-authored-by: Kyle Daruwalla
---
 src/rules.jl | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/rules.jl b/src/rules.jl
index 545b2a63..a946491b 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -760,9 +760,8 @@ end
 An optimiser that wraps another optimiser `opt` in order to perform mixed precision
 training [1].
 
-The state of `MixedPrecision` will contain a copy in precision `T` of the trainable parameter `x`,
-call it `xT`.
-The internal state of `opt` also operates at precision `T`.
+The state of `MixedPrecision{T}` will contain a copy in precision `T` of any trainable parameter `x`,
+call it `xT`, as well as the internal state of `opt` also at precision `T`.
 If `T` is not specified, it defaults to `Float32`.
 
 Call `g` the gradient of `x`. Both `g` and `x` are typically in a precision lower than `T`
From 34e99afb52d60bc0dfc729ce32401bc89b1c7878 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Sun, 30 Jul 2023 07:33:48 -0500
Subject: [PATCH 8/9] change constructor

---
 src/rules.jl  | 15 +++++++--------
 test/rules.jl |  4 ++--
 2 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/src/rules.jl b/src/rules.jl
index a946491b..ad244f1e 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -754,8 +754,7 @@ function apply!(o::AccumGrad, state, x, dx)
 end
 
 """
-    MixedPrecision(opt)
-    MixedPrecision{T}(opt)
+    MixedPrecision([T = Float32,] opt)
 
 An optimiser that wraps another optimiser `opt` in order to perform mixed precision
 training [1].
@@ -775,14 +774,14 @@ then `x` is updated with the value of `xT`.
 # Examples
 
 ```julia
-x = rand(Float16, 2) # a trainable parameter in low precision
+x = rand(Float16, 2) # A trainable parameter in low precision
 
-opt = MixedPrecision(Adam(1e-3))
-opt_state = Optimisers.setup(opt, x) # the state contains a copy of x in Float32 precision
+opt = MixedPrecision(Adam(1e-3)) # Equivalent to MixedPrecision(Float32, Adam(1e-3))
+opt_state = Optimisers.setup(opt, x) # The state contains a copy of x in Float32 precision
 
-g = rand(Float16, 2) # a gradient in low precision
+g = rand(Float16, 2) # A gradient in low precision
 
-# accumulation is performed in high precision,
+# Accumulation is performed in high precision,
 # then also the low precision x is synced
 Optimisers.update!(opt_state, x, g)
 ```
@@ -794,7 +793,7 @@ end
 @functor MixedPrecision
 
 MixedPrecision(opt::AbstractRule) = MixedPrecision{Float32, typeof(opt)}(opt)
-MixedPrecision{T}(opt::AbstractRule) where T = MixedPrecision{T, typeof(opt)}(opt)
+MixedPrecision(T::Type, opt::AbstractRule) = MixedPrecision{T, typeof(opt)}(opt)
 
 function init(o::MixedPrecision{T}, x::AbstractArray) where T
   xT = T.(x)
diff --git a/test/rules.jl b/test/rules.jl
index c5c9b94f..235f52e3 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -9,7 +9,7 @@ RULES = [
   Descent(), Adam(), Momentum(), Nesterov(), Rprop(), RMSProp(),
   AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(), AdamW(), RAdam(), OAdam(), AdaBelief(),
   Lion(),
-  MixedPrecision{Float64}(Adam()),
+  MixedPrecision(Float64, Adam()),
   # A few chained combinations:
   OptimiserChain(WeightDecay(), Adam(0.001)),
   OptimiserChain(ClipNorm(), Adam(0.001)),
@@ -280,7 +280,7 @@ end
   @test new_x ≈ x .- 1e-3 .* g
 
   x = rand(Float16, 2)
-  opt_state = Optimisers.setup(MixedPrecision{Float64}(Adam(1e-3)), x)
+  opt_state = Optimisers.setup(MixedPrecision(Float64, Adam(1e-3)), x)
   @test opt_state.state[1] isa Vector{Float64}
   @test opt_state.state[2][1] isa Vector{Float64}
 end

From f7798c8ec70f2eb58e92cc69750e98b3683f9c3f Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Mon, 31 Jul 2023 05:24:14 -0600
Subject: [PATCH 9/9] add back nothing

---
 src/rules.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/rules.jl b/src/rules.jl
index ad244f1e..1ec2d450 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -806,7 +806,7 @@ function apply!(o::MixedPrecision{T}, state, x, dx) where T
   xT = subtract!(xT, dx′)
   if maywrite(x)
     x .= xT
-    dx′ = Zero()
+    dx′ = nothing
   else
     dx′ = x .- eltype(x).(xT)
   end