Rule for mixed precision training #152
base: master
Changes from all commits: 2c67ab9, fcc163a, bb87a66, 6885d28, fb9ad29, 84704b2, caadc52, 34e99af, f7798c8
@@ -752,3 +752,66 @@ function apply!(o::AccumGrad, state, x, dx)
    return (accum_dx, counter + 1), nothing
  end
end

"""
    MixedPrecision([T = Float32,] opt)

An optimiser that wraps another optimiser `opt` in order to perform mixed precision
training [1].

The state of `MixedPrecision{T}` contains a copy in precision `T` of each trainable parameter `x`,
call it `xT`, as well as the internal state of `opt`, also at precision `T`.
If `T` is not specified, it defaults to `Float32`.

Let `g` be the gradient of `x`. Both `g` and `x` are typically in a precision lower than `T`
(e.g. `Float16`).

In the `update!(opt_state, x, g)` call, `opt` is used to update `xT` instead of `x`,
then `x` is updated with the value of `xT`.

[1] Micikevicius et al. '17, "Mixed Precision Training", https://arxiv.org/abs/1710.03740

# Examples

```julia
x = rand(Float16, 2)  # a trainable parameter in low precision

opt = MixedPrecision(Adam(1e-3))      # equivalent to MixedPrecision(Float32, Adam(1e-3))
opt_state = Optimisers.setup(opt, x)  # the state contains a copy of x in Float32 precision

g = rand(Float16, 2)  # a gradient in low precision

# Accumulation is performed in high precision,
# then the low-precision x is synced with the result.
Optimisers.update!(opt_state, x, g)
```
"""
struct MixedPrecision{T<:Number, O<:AbstractRule} <: AbstractRule
  opt::O
[Review comment on `opt::O`: Maybe this should be `rule`? (suggested change)]
end

@functor MixedPrecision

MixedPrecision(opt::AbstractRule) = MixedPrecision{Float32, typeof(opt)}(opt)
MixedPrecision(T::Type, opt::AbstractRule) = MixedPrecision{T, typeof(opt)}(opt)

function init(o::MixedPrecision{T}, x::AbstractArray) where T
  xT = T.(x)
  return (xT, init(o.opt, xT))
end
function apply!(o::MixedPrecision{T}, state, x, dx) where T
  xT, st = state
  st′, dx′ = apply!(o.opt, st, xT, dx)
  xT = subtract!(xT, dx′)
  if maywrite(x)
    x .= xT
    dx′ = nothing

[Review comment (darsnack marked this conversation as resolved): I think this is correct.
But perhaps weird things will happen if you try to compose it, e.g.
(see the hypothetical composition sketch after this function)]

  else
    dx′ = x .- eltype(x).(xT)

[Review comment: On this path, should the subtraction happen in high or low precision,
does it matter? This is the sort of place that I worry about scaling & the range of Float16.
But haven't thought hard. (see the illustration after this function)]

  end
  return (xT, st′), dx′
end
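The composition concern above is cut off in the extracted page. As a purely hypothetical illustration of the kind of nesting it may refer to (the chained rules and thresholds here are assumptions, not from the PR), both orderings are syntactically valid combinations of `MixedPrecision` with existing Optimisers.jl rules; which one keeps the `Float32` copy and the low-precision parameter consistent is exactly the open question:

```julia
using Optimisers  # assumes the MixedPrecision rule from this PR is also defined

# Hypothetical compositions, for illustration only:
mp_inside  = OptimiserChain(ClipGrad(1.0f0), MixedPrecision(Adam(1e-3)))  # wrap only the final step
mp_outside = MixedPrecision(OptimiserChain(ClipGrad(1.0f0), Adam(1e-3)))  # wrap the whole chain
```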
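On the precision question in the second comment, a small sketch (not part of the PR, scalar values chosen only for illustration): the spacing of `Float16` just below `1.0` is about `5e-4`, so an update that the `Float32` copy `xT` tracks faithfully can round away entirely when the difference is formed in low precision:

```julia
x  = Float16(1.0)       # low-precision parameter
xT = 1.0f0 - 1f-4       # Float32 master copy after a small update

x - eltype(x)(xT)            # difference taken in Float16: Float16(0.0), the step is lost
eltype(x)(Float32(x) - xT)   # difference taken in Float32, then rounded: ≈ Float16(0.0001)
```

Either way `xT` retains the small step; the question is only what `dx′` should be on the non-mutating path.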

adjust(o::MixedPrecision, eta::Real) = MixedPrecision(adjust(o.opt, eta))
adjust(o::MixedPrecision; kw...) = MixedPrecision(adjust(o.opt; kw...))
[Review comment on lines +816 to +817: Don't these forget … (see the hedged sketch below)]
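If the truncated comment above is asking whether these `adjust` methods drop the precision parameter (they rebuild the wrapper through the one-argument constructor, which defaults `T` to `Float32`), one possible fix is to dispatch on the type parameter. This is a hedged sketch under that reading, not part of the PR:

```julia
# Hedged sketch: preserve T across adjust, assuming the constructors defined in this PR.
adjust(o::MixedPrecision{T}, eta::Real) where T = MixedPrecision(T, adjust(o.opt, eta))
adjust(o::MixedPrecision{T}; kw...) where T = MixedPrecision(T, adjust(o.opt; kw...))
```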