Commit

Fix linting
coreystatendet committed Jul 2, 2024
1 parent 3c4f1a4 commit d4dc298
Showing 2 changed files with 49 additions and 15 deletions.
52 changes: 40 additions & 12 deletions fsdp/minimal-fsdp/fsdp.py
@@ -4,18 +4,18 @@
 import random
 from typing import Any, Dict, Generator, Optional, TypedDict
 
-import determined as det
 import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
-from model import EmbedAndEncode, LMHead, Transformer, TransformerBlock
 from torch.distributed.fsdp import FullStateDictConfig
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import ShardingStrategy, StateDictType
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 
+import determined as det
+from model import EmbedAndEncode, LMHead, Transformer, TransformerBlock
+
 """
 Minimal transformer model FSDP script with Core API.
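
Aside, not part of the diff: the hunk above just moves the first-party imports (determined, model) into their own group after the third-party block, per the usual import-sorting convention. For context, a hedged sketch of how these imports typically fit together, assuming torch.distributed is already initialized and `model` is a Transformer instance built elsewhere in fsdp.py:

# Sketch only: `model` is an already-constructed Transformer, and the
# process group must be initialized before wrapping.
fsdp_model = FSDP(
    model,
    auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}),  # one shard unit per block
    sharding_strategy=ShardingStrategy.FULL_SHARD,
)
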
@@ -29,7 +29,7 @@ def get_fake_data_iter(
     rank: int,
     device: torch.device,
     is_validation: bool,
-    simulated_size_in_batches: int = 10
+    simulated_size_in_batches: int = 10,
 ) -> Generator[tuple[torch.Tensor, torch.Tensor], None, None]:
     """
     Fake dataloader. Yields a different set of data for each rank, and for train vs validation.
@@ -41,12 +41,16 @@
         if next_idx == 0:
             generator.manual_seed(42 + rank + 100000 * is_validation)
         fake_sequence = torch.randint(
-            vocab_size, (batch_size, max_seq_len + 1), device=device, generator=generator
+            vocab_size,
+            (batch_size, max_seq_len + 1),
+            device=device,
+            generator=generator,
         )
         inputs, targets = fake_sequence[..., :-1], fake_sequence[..., 1:]
         yield inputs, targets
         next_idx = (next_idx + 1) % simulated_size_in_batches
 
+
 def get_loss(
     fsdp_model: FSDP, batch: tuple[torch.Tensor, torch.Tensor], use_amp: bool
 ) -> torch.Tensor:
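
A standalone sketch of the seeding pattern this hunk reformats: each (rank, split) pair seeds its own torch.Generator, so every rank draws a distinct but reproducible stream of fake batches. The sizes below are illustrative, not values from fsdp.py:

import torch

rank, is_validation = 0, False
vocab_size, batch_size, max_seq_len = 1000, 4, 64  # made-up sizes

generator = torch.Generator(device="cpu")
generator.manual_seed(42 + rank + 100000 * is_validation)
fake_sequence = torch.randint(
    vocab_size, (batch_size, max_seq_len + 1), generator=generator
)
# Next-token setup: positions 0..n-1 are inputs, positions 1..n are targets.
inputs, targets = fake_sequence[..., :-1], fake_sequence[..., 1:]
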
@@ -75,7 +79,9 @@ def get_reduced_loss_and_report(
     if core_context.distributed.rank == 0:
         reduced_loss = loss_history_t.item()
         # TypedDict pattern to satisfy mypy.
-        ReportArgs = TypedDict("ReportArgs", {"steps_completed": int, "metrics": Dict[str, float]})
+        ReportArgs = TypedDict(
+            "ReportArgs", {"steps_completed": int, "metrics": Dict[str, float]}
+        )
         report_args: ReportArgs = {
             "steps_completed": steps_completed,
             "metrics": {"loss": reduced_loss},
@@ -107,7 +113,9 @@ def save_checkpoint(
     optim_state_dict = FSDP.optim_state_dict(fsdp_model, optimizer)
 
     if core_context.distributed.rank == 0:
-        with core_context.checkpoint.store_path(metadata={"steps_completed": steps_completed}) as (
+        with core_context.checkpoint.store_path(
+            metadata={"steps_completed": steps_completed}
+        ) as (
             path,
             _,
         ):
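
The body under the store_path context manager sits outside this hunk; a hedged sketch of what the rank-0 save plausibly looks like, assuming a model_state_dict gathered earlier in save_checkpoint (the file names follow the load_checkpoint hunk below):

# Sketch only; the script's exact body may differ.
with core_context.checkpoint.store_path(
    metadata={"steps_completed": steps_completed}
) as (path, _):
    torch.save(model_state_dict, path.joinpath("model.bin"))
    torch.save(optim_state_dict, path.joinpath("optim.bin"))
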
@@ -134,8 +142,12 @@ def load_checkpoint(
         StateDictType.FULL_STATE_DICT,
         FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
     ):
-        fsdp_model.load_state_dict(torch.load(path.joinpath("model.bin"), map_location=device))
-        optim_state_dict = torch.load(path.joinpath("optim.bin"), map_location=device)
+        fsdp_model.load_state_dict(
+            torch.load(path.joinpath("model.bin"), map_location=device)
+        )
+        optim_state_dict = torch.load(
+            path.joinpath("optim.bin"), map_location=device
+        )
         optim_state_dict_to_load = FSDP.optim_state_dict_to_load(
             model=fsdp_model,
             optim=optimizer,
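
The call is truncated by the hunk boundary; a hedged sketch of how it plausibly completes: FSDP.optim_state_dict_to_load re-shards the full optimizer state for the wrapped model, and the result is handed to the optimizer:

# Continuation sketch (not verbatim from fsdp.py):
optim_state_dict_to_load = FSDP.optim_state_dict_to_load(
    model=fsdp_model,
    optim=optimizer,
    optim_state_dict=optim_state_dict,
)
optimizer.load_state_dict(optim_state_dict_to_load)
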
@@ -223,7 +235,13 @@ def main(
     # If a previous checkpoint exists, load it now and correct the steps_completed:
     if checkpoint_uuid is not None:
         steps_completed = load_checkpoint(
-            fsdp_model, optimizer, scaler, use_amp, core_context, device, checkpoint_uuid
+            fsdp_model,
+            optimizer,
+            scaler,
+            use_amp,
+            core_context,
+            device,
+            checkpoint_uuid,
         )
     # If torch profiler enabled, write profiling results to TensorBoard accessible through WebUI.
     if use_torch_profiler:
@@ -260,20 +278,30 @@
             )
             train_loss_history.clear()
             # Compute and report an average validation loss.
-            validation_data_iter = get_fake_data_iter(is_validation=True, **data_iter_arguments)
+            validation_data_iter = get_fake_data_iter(
+                is_validation=True, **data_iter_arguments
+            )
             validation_loss_history = []
             with torch.inference_mode():
                 for i in range(validation_batches):
                     batch = next(validation_data_iter)
                     loss = get_loss(fsdp_model, batch, use_amp)
                     validation_loss_history.append(loss)
                 last_validation_loss = get_reduced_loss_and_report(
-                    validation_loss_history, steps_completed, core_context, validation=True
+                    validation_loss_history,
+                    steps_completed,
+                    core_context,
+                    validation=True,
                 )
 
         if steps_completed % checkpoint_rate == 0 or this_is_the_last_step:
             save_checkpoint(
-                fsdp_model, optimizer, scaler, use_amp, core_context, steps_completed
+                fsdp_model,
+                optimizer,
+                scaler,
+                use_amp,
+                core_context,
+                steps_completed,
             )
             # Since should_preempt is blocking, we only check at checkpoint_rate to
             # maintain performance.
12 changes: 9 additions & 3 deletions fsdp/minimal-fsdp/model.py
@@ -44,7 +44,9 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         bsz, seqlen, _ = inputs.shape
 
         # Get queries, keys, and values
-        q, k, v = self.wqkv(inputs).split([self.d_model, self.d_model, self.d_model], dim=-1)
+        q, k, v = self.wqkv(inputs).split(
+            [self.d_model, self.d_model, self.d_model], dim=-1
+        )
         q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
         k = k.view(bsz, seqlen, self.n_heads, self.head_dim)
         v = v.view(bsz, seqlen, self.n_heads, self.head_dim)
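
A self-contained sketch of the fused-QKV pattern this hunk touches: a single Linear produces queries, keys, and values in one matmul, and split plus view then separates the heads. All dimensions below are illustrative, not values from model.py:

import torch
import torch.nn as nn

d_model, n_heads = 64, 4
head_dim = d_model // n_heads
wqkv = nn.Linear(d_model, 3 * d_model)  # fused q/k/v projection

x = torch.randn(2, 8, d_model)  # (batch, seq, d_model)
q, k, v = wqkv(x).split([d_model, d_model, d_model], dim=-1)
q = q.view(2, 8, n_heads, head_dim)  # likewise for k and v
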
@@ -120,7 +122,9 @@ def __init__(
         super().__init__()
         # Learned positional encoding and embedding layer:
         self.max_seq_len = max_seq_len
-        self.learned_pos_enc = nn.Parameter(torch.zeros(max_seq_len, d_model, device=device))
+        self.learned_pos_enc = nn.Parameter(
+            torch.zeros(max_seq_len, d_model, device=device)
+        )
         self.tok_embeddings = nn.Embedding(vocab_size, d_model, device=device)
 
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
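
How a learned positional encoding like the one above is typically consumed in forward (the forward body sits outside this hunk); a sketch with illustrative sizes, not the module's actual code:

import torch
import torch.nn as nn

max_seq_len, d_model, vocab_size = 64, 32, 1000  # illustrative sizes
learned_pos_enc = nn.Parameter(torch.zeros(max_seq_len, d_model))
tok_embeddings = nn.Embedding(vocab_size, d_model)

tokens = torch.randint(vocab_size, (2, 16))             # (batch, seq)
hidden = tok_embeddings(tokens) + learned_pos_enc[:16]  # broadcast over batch
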
@@ -160,7 +164,9 @@ def __init__(
         super().__init__()
 
         # Embed/encode
-        self.embed_and_encode = EmbedAndEncode(d_model, vocab_size, max_seq_len, device=device)
+        self.embed_and_encode = EmbedAndEncode(
+            d_model, vocab_size, max_seq_len, device=device
+        )
 
         # Transformer blocks
         self.layers = nn.ModuleList(
