Add minimal FSDP example #23
Conversation
fsdp/minimal-fsdp/fsdp.py (outdated diff)
if next_idx == 0:
    generator.manual_seed(42 + rank + 100000 * is_validation)
Just curious: why this if block? Why not seed immediately after creating the generator?
Also, this isn't very important, but I guess the restarting logic starts the dataloader over, rather than continuing from the last batch. I wouldn't fix this, just maybe note it in a comment.
Goal: I'm trying to have the generator yield a recurring sequence of length simulated_size_in_batches. I do this by reseeding to the same initial seed after simulated_size_in_batches steps, which comes from setting next_idx = (next_idx + 1) % simulated_size_in_batches. I avoid repeating the code by having the initial seed and the reseeds share the same if condition here.
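For reference, a minimal sketch of the pattern being described, reusing the names from the diff (generator, next_idx, simulated_size_in_batches, rank, is_validation); batch_size, seq_len, and vocab_size are hypothetical stand-ins:

import torch

def batch_stream(simulated_size_in_batches, batch_size, seq_len, vocab_size,
                 rank=0, is_validation=False):
    # Yields random token batches that repeat every `simulated_size_in_batches` steps.
    generator = torch.Generator()
    next_idx = 0
    while True:
        if next_idx == 0:
            # Reseeding whenever the index wraps back to 0 makes the stream a
            # recurring cycle of length `simulated_size_in_batches`.
            generator.manual_seed(42 + rank + 100000 * is_validation)
        yield torch.randint(vocab_size, (batch_size, seq_len), generator=generator)
        next_idx = (next_idx + 1) % simulated_size_in_batches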
ah, I missed the modulo, got it now
fsdp/minimal-fsdp/fsdp.py (outdated diff)
# Wrap the embedding layer, the lm head, and each transformer block into its own FSDP unit:
auto_wrap_policy = ModuleWrapPolicy([TransformerBlock, EmbedAndEncode, LMHead])
I should have caught this before, but there are some small use_amp = True related issues we probably want to address.

First, there is some special handling for the FSDP weight/comms/etc. dtypes when using amp, as in

from torch.distributed.fsdp import MixedPrecision
fsdp_model = FSDP(model,
                  mixed_precision=MixedPrecision(param_dtype=torch.bfloat16,
                                                 reduce_dtype=torch.bfloat16), ...)

This isn't needed for correctness, just efficiency.

Second, for the autocast in get_loss it's better practice to specify the dtype=torch.bfloat16 arg, as in

with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.bfloat16):
    outputs = fsdp_model(inputs)

The default is dtype=torch.float16 (not bfloat16).

Third, with bfloat16 there should be no need for the ShardedGradScaler; it's only needed for float16.
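Putting the three points together, a rough sketch (assuming the model, auto_wrap_policy, use_amp, and inputs names from the example; the loss computation is a stand-in for whatever get_loss actually does):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# Keep FSDP's sharded params and gradient reductions in bfloat16 so autocast
# doesn't have to down-cast them on every forward (efficiency, not correctness).
bf16_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, mixed_precision=bf16_policy)

# Be explicit about the autocast dtype; the default is float16, not bfloat16.
with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.bfloat16):
    outputs = fsdp_model(inputs)
    loss = outputs.mean()  # stand-in for the real loss from get_loss

# With bfloat16 there is no need for ShardedGradScaler; backprop directly.
loss.backward()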
Apologies, I should have caught these. Think I missed them when just looking at the diffs.
First, there is some special handling for the FSDP weight/comms/etc types when using amp ... Not needed for correctness, just efficiency.

Good catch, wasn't aware of this; I assume autocast just fails to recognize / handle some of the buffers and operations used inside of FSDP without it?

Second, for the autocast in get_loss it's better practice to specify the dtype=torch.bfloat16 arg, as in

bfloat16 isn't supported on pre-A100 GPUs, right? I imagine we want our example to run on any GPU, hence defaulting to float16. Doing a quick Google search, it looks like this can be made conditional via e.g.

compute_capability = torch.cuda.get_device_capability()
if compute_capability[0] < 8:
    ...

So we could switch to using bfloat16 iff it's supported (a sketch of this follows below).

Third, with bfloat16 there should be no need for the ShardedGradScaler; only needed for float16.

See above, but yes -- if we switch to bfloat16 entirely or conditionally, we could omit it.
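A possible shape for the conditional version (amp_dtype is just an illustrative name, per the hparam idea below; ShardedGradScaler lives under torch.distributed.fsdp.sharded_grad_scaler in recent PyTorch versions):

import torch
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

# bfloat16 needs compute capability >= 8 (Ampere and newer); otherwise fall
# back to float16, which does need gradient scaling.
amp_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

mixed_precision = MixedPrecision(param_dtype=amp_dtype, reduce_dtype=amp_dtype)
scaler = ShardedGradScaler(enabled=(amp_dtype is torch.float16))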
Good call on hardware compatibility. Apparently if the hardware doesn't support bfloat16, it silently falls back to float32? Seems weird.

Anyway, your call on what to support there. Could also make it configurable as another hparam, amp_dtype?

I assume autocast just fails to recognize / handle some of the buffers and operations used inside of FSDP without it?

Nah, it's a little different, IIUC. The actual weights used for the forward pass are held in whatever precision you specify in MixedPrecision, or float32 if omitted. Under autocasting, various tensors have their dtypes changed so that specific operations occur in either high or low precision, e.g. matmuls in low precision and softmax in high precision. MixedPrecision affects which direction most of the casts occur in (among other things).

For instance, in the non-FSDP case weights would always be kept in high precision and down-cast as needed. With FSDP and MixedPrecision(param_dtype=torch.bfloat16), say, you also avoid those down-casts because the weights are already in the desired precision. But you then might need some extra up-casts elsewhere.

There are no explicit failures if you don't specify MixedPrecision w/ FSDP and run under autocast; just perf differences, and likely some (hopefully) small numerical differences too.

I likely got some details incorrect here, too. The FSDP + amp API isn't super well documented. I talked to one of the OLMo researchers about exactly this topic a while ago and how confusing it is. (They are heavily FSDP-based.)
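To illustrate the non-FSDP casting direction mentioned above, a small standalone example (assumes a CUDA GPU with bfloat16 support):

import torch

# Plain autocast, no FSDP: the master weights stay in float32 and are
# down-cast on the fly for low-precision-eligible ops like matmul.
linear = torch.nn.Linear(16, 16).cuda()
x = torch.randn(4, 16, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = linear(x)
print(linear.weight.dtype)  # torch.float32 -- weights untouched
print(out.dtype)            # torch.bfloat16 -- the matmul ran in low precision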
Alternatively, if you specify MixedPrecision in FSDP but don't wrap the forward in autocast, none of the casts ever happen and you can end up running every part of the forward pass in low precision, rather than mixed.
Looks good! I missed some things about mixed precision that might be worth fixing, but up to you.
Add an example (originally by Garrett with small updates from me) that shows how to use torch's FSDP implementation for LLM training alongside Core API.