diff --git a/docs/src/api.md b/docs/src/api.md
index 6c021f25..feeaef63 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -25,8 +25,9 @@ In addition to the main course, you may wish to order some of these condiments:
 Optimisers.AccumGrad
 Optimisers.ClipGrad
 Optimisers.ClipNorm
-Optimisers.WeightDecay
+Optimisers.MixedPrecision
 Optimisers.OptimiserChain
+Optimisers.WeightDecay
 ```
 
 ## Model Interface
diff --git a/src/Optimisers.jl b/src/Optimisers.jl
index 20fc8aad..ecd6d0bc 100644
--- a/src/Optimisers.jl
+++ b/src/Optimisers.jl
@@ -15,7 +15,7 @@ include("rules.jl")
 export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
        AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
        WeightDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
-       AccumGrad
+       AccumGrad, MixedPrecision
 
 ###
 ### one-array functions
diff --git a/src/rules.jl b/src/rules.jl
index e994b740..1ec2d450 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -752,3 +752,66 @@ function apply!(o::AccumGrad, state, x, dx)
     return (accum_dx, counter + 1), nothing
   end
 end
+
+"""
+    MixedPrecision([T = Float32,] opt)
+
+An optimiser that wraps another optimiser `opt` in order to perform mixed-precision
+training [1].
+
+The state of `MixedPrecision{T}` contains a copy `xT` of each trainable parameter `x`
+in precision `T`, as well as the internal state of `opt`, also kept at precision `T`.
+If `T` is not specified, it defaults to `Float32`.
+
+Let `g` be the gradient of `x`. Both `g` and `x` are typically stored in a precision
+lower than `T` (e.g. `Float16`).
+
+In the `update!(opt_state, x, g)` call, `opt` updates `xT` instead of `x`,
+and `x` is then updated with the value of `xT`.
+
+[1] Micikevicius et al. 2017, "Mixed Precision Training", https://arxiv.org/abs/1710.03740
+
+# Examples
+
+```julia
+x = rand(Float16, 2) # A trainable parameter in low precision
+
+opt = MixedPrecision(Adam(1e-3)) # Equivalent to MixedPrecision(Float32, Adam(1e-3))
+opt_state = Optimisers.setup(opt, x) # The state contains a copy of x in Float32 precision
+
+g = rand(Float16, 2) # A gradient in low precision
+
+# Accumulation is performed in high precision,
+# then the low-precision x is synced with the result
+Optimisers.update!(opt_state, x, g)
+```
+"""
+struct MixedPrecision{T<:Number, O<:AbstractRule} <: AbstractRule
+  opt::O
+end
+
+@functor MixedPrecision
+
+MixedPrecision(opt::AbstractRule) = MixedPrecision{Float32, typeof(opt)}(opt)
+MixedPrecision(T::Type, opt::AbstractRule) = MixedPrecision{T, typeof(opt)}(opt)
+
+function init(o::MixedPrecision{T}, x::AbstractArray) where T
+  xT = T.(x)
+  return (xT, init(o.opt, xT))
+end
+
+function apply!(o::MixedPrecision{T}, state, x, dx) where T
+  xT, st = state
+  st′, dx′ = apply!(o.opt, st, xT, dx)
+  xT = subtract!(xT, dx′)
+  if maywrite(x)
+    x .= xT
+    dx′ = nothing
+  else
+    dx′ = x .- eltype(x).(xT)
+  end
+  return (xT, st′), dx′
+end
+
+adjust(o::MixedPrecision{T}, eta::Real) where T = MixedPrecision(T, adjust(o.opt, eta))
+adjust(o::MixedPrecision{T}; kw...) where T = MixedPrecision(T, adjust(o.opt; kw...))
diff --git a/test/rules.jl b/test/rules.jl
index a10e055f..235f52e3 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -9,6 +9,7 @@ RULES = [
   Descent(), Adam(), Momentum(), Nesterov(), Rprop(), RMSProp(),
   AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(), AdamW(), RAdam(), OAdam(), AdaBelief(),
   Lion(),
+  MixedPrecision(Float64, Adam()),
   # A few chained combinations:
   OptimiserChain(WeightDecay(), Adam(0.001)),
   OptimiserChain(ClipNorm(), Adam(0.001)),
@@ -266,4 +267,20 @@ end
 
   tree, x4 = Optimisers.update(tree, x3, g4)
   @test x4 ≈ x3
-end
\ No newline at end of file
+end
+
+@testset "MixedPrecision" begin
+  x = rand(Float16, 2)
+  opt_state = Optimisers.setup(MixedPrecision(Adam(1e-3)), x)
+  @test opt_state.state[1] isa Vector{Float32}
+  @test opt_state.state[2][1] isa Vector{Float32}
+  g = rand(Float16, 2)
+  new_state, new_x = Optimisers.update(opt_state, x, g)
+  @test new_x == Float16.(new_state.state[1])
+  @test new_x ≈ x .- 1e-3 .* g
+
+  x = rand(Float16, 2)
+  opt_state = Optimisers.setup(MixedPrecision(Float64, Adam(1e-3)), x)
+  @test opt_state.state[1] isa Vector{Float64}
+  @test opt_state.state[2][1] isa Vector{Float64}
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 4e02f4d0..62fac535 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -258,6 +258,20 @@ y2z(x) = x
   @test sc2.γ.rule.opts[1].delta == 2.5
   @test sc2.γ.rule.opts[2].eta === 0.001f0 # unchanged
   @test sc2.γ.state[2][1] ≈ [0.1, 0.2, 0.2]
+
+  # MixedPrecision
+  mp = Optimisers.setup(MixedPrecision(Momentum(0.1, 0.9)), m)
+  mp1, mp2 = Optimisers.update(mp, m, (α = nothing, γ = [1,10,100],))
+  @test mp1.γ.rule.opt.eta == 0.1
+  @test mp1.γ.state[2] ≈ [0.1, 1, 10]
+
+  mp2 = Optimisers.adjust(mp1, 0.2)
+  @test mp2.γ.rule.opt.eta == 0.2
+  @test mp2.γ.rule.opt.rho == 0.9
+
+  mp3 = Optimisers.adjust(mp1; eta=0.3, rho=0.7)
+  @test mp3.γ.rule.opt.eta == 0.3
+  @test mp3.γ.rule.opt.rho == 0.7
 end
 
 @testset "adjusting parameters, in-place" begin
@@ -302,6 +316,20 @@ y2z(x) = x
   @test sc1.γ.rule.opts[1].delta == 2.5
   @test sc1.γ.rule.opts[2].eta === 0.2f0 # unchanged
   @test sc1.γ.state[2][1] ≈ [0.1, 0.2, 0.2]
+
+  # MixedPrecision
+  mp = Optimisers.setup(MixedPrecision(Momentum(0.1, 0.9)), m)
+  mp1, mp2 = Optimisers.update(mp, m, (α = nothing, γ = [1,10,100],))
+  @test mp1.γ.rule.opt.eta == 0.1
+  @test mp1.γ.state[2] ≈ [0.1, 1, 10]
+
+  Optimisers.adjust!(mp1, 0.2)
+  @test mp1.γ.rule.opt.eta == 0.2
+  @test mp1.γ.rule.opt.rho == 0.9
+
+  Optimisers.adjust!(mp1; eta=0.3, rho=0.7)
+  @test mp1.γ.rule.opt.eta == 0.3
+  @test mp1.γ.rule.opt.rho == 0.7
 end
 
 @testset "freeze/thaw" begin
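As a quick way to exercise the new rule end to end, here is a minimal usage sketch (not part of the patch), assuming the changes above are applied. The `demo` function, the quadratic loss, its hand-written gradient, the learning rate, and the step count are illustrative stand-ins; only `MixedPrecision`, `Adam`, `setup`, and `update!` come from the package.

```julia
using Optimisers  # assumes this patch is applied, so MixedPrecision is exported

# Illustrative only: pull a Float16 parameter vector towards `target`,
# while Adam's accumulation happens in Float32 (the default T).
function demo(nsteps = 100)
  θ = randn(Float16, 4)     # low-precision "model" parameters
  target = zeros(Float16, 4)

  state = Optimisers.setup(MixedPrecision(Adam(1e-2)), θ)

  for _ in 1:nsteps
    g = θ .- target         # hand-written gradient of 0.5 * sum(abs2, θ .- target)
    state, θ = Optimisers.update!(state, θ, g)
  end
  return θ                  # moved towards target; the leaf state still holds a Float32 copy
end

demo()
```

Under the changes in `src/rules.jl`, `setup` stores a `Float32` copy of `θ` in the leaf state, `update!` applies Adam to that copy, and the result is then written back into the `Float16` array in place (or returned as the effective update when the parameter is not writable).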