FSDP as well
muellerzr committed Jan 17, 2025
1 parent 3820c40 commit 4dfa816
Showing 2 changed files with 86 additions and 50 deletions.
9 changes: 3 additions & 6 deletions benchmarks/fp8/torchao/ddp.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -22,19 +22,16 @@
 
 import evaluate
 import torch
-from datasets import load_dataset
+from fp8_utils import get_dataloaders
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import AdamW
-from torch.utils.data import DataLoader
 from torchao.float8 import convert_to_float8_training
-from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup
+from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup
 
 from accelerate import Accelerator
 from accelerate.state import AcceleratorState
 from accelerate.utils import AORecipeKwargs, set_seed
 
-from fp8_utils import get_dataloaders
-
 
 MODEL_NAME = "bert-base-cased"
 METRIC = evaluate.load("glue", "mrpc")
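For orientation, the torchao fp8 pattern that the updated ddp.py exercises can be condensed into the sketch below. This is an illustrative sketch rather than part of the commit: the toy model is a placeholder for the BERT checkpoint the benchmark loads, and running either path in fp8 assumes a CUDA GPU with torchao installed.

import torch
from torchao.float8 import convert_to_float8_training

from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs


# Toy stand-in for the BERT model the real benchmark loads.
def make_model():
    return torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64))


# Baseline path: convert eligible nn.Linear layers to float8 training by hand
# (ddp.py then wraps the converted model in DistributedDataParallel itself).
baseline_model = make_model()
convert_to_float8_training(baseline_model, module_filter_fn=lambda module, fqn: isinstance(module, torch.nn.Linear))

# Integration path: accelerate applies the same torchao conversion inside
# prepare() when mixed_precision="fp8" is paired with AORecipeKwargs.
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()])
integration_model = accelerator.prepare(make_model())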
127 changes: 83 additions & 44 deletions benchmarks/fp8/torchao/fsdp.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 """
-This script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`.
+This script tests to ensure that `accelerate` performs at the same level as raw `torchao`.
 This particular script verifies this for FSDP training.
 """
@@ -22,20 +22,19 @@
 
 import evaluate
 import torch
-import transformer_engine.common.recipe as te_recipe
-import transformer_engine.pytorch as te
-from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities
+from fp8_utils import get_dataloaders
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import MixedPrecision
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
-from transformer_engine.common.recipe import DelayedScaling
+from torch.optim import AdamW
+from torchao.float8 import convert_to_float8_training
+from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup
 from transformers.models.bert import BertLayer
 
 from accelerate import Accelerator
 from accelerate import FullyShardedDataParallelPlugin as FSDPPlugin
 from accelerate.state import AcceleratorState
-from accelerate.utils import FP8RecipeKwargs, set_seed
-from accelerate.utils.transformer_engine import convert_model
+from accelerate.utils import AORecipeKwargs, set_seed
 
 
 MODEL_NAME = "bert-base-cased"
@@ -44,23 +43,72 @@
 FSDP_WRAP_POLICY = partial(transformer_auto_wrap_policy, transformer_layer_cls={BertLayer})
 
 
+def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None):
+    """
+    Returns a tuple of:
+    - Model
+    - Optimizer
+    - Train dataloader (prepared)
+    - Eval dataloader (prepared)
+    - LR Scheduler
+    Suitable for training on the MRPC dataset
+    """
+
+    if accelerator is None:
+        accelerator = Accelerator()
+    model = AutoModelForSequenceClassification.from_pretrained(model_name)
+    train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size)
+    optimizer = AdamW(model.parameters(), lr=0.0001)
+    lr_scheduler = get_linear_schedule_with_warmup(
+        optimizer=optimizer,
+        num_warmup_steps=100,
+        num_training_steps=len(train_dataloader) * 2,
+    )
+    train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader)
+    return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
+
+
+def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None):
+    if isinstance(module, torch.nn.Linear):
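+        # torchao's float8 matmul kernels want in/out features divisible by 16; skip layers that are not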
+        if module.in_features % 16 != 0 or module.out_features % 16 != 0:
+            return False
+    # For stability reasons, we skip the first and last linear layers
+    # Otherwise can lead to the model not training or converging properly
+    if fqn in (first_layer_name, last_layer_name):
+        return False
+    return True
+
+
+def evaluate_model(model, dataloader, metric, accelerator=None):
+    "Turns model to .eval(), runs dataloader, calculates metric, then turns eval back on"
+    model.eval()
+    for step, batch in enumerate(dataloader):
+        with torch.no_grad():
+            outputs = model(**batch)
+        predictions = outputs.logits.argmax(dim=-1)
+        references = batch["labels"]
+        if accelerator is not None and accelerator.num_processes > 1:
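+            # gather predictions/labels across processes, dropping duplicates added by the distributed sampler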
+            predictions, references = accelerator.gather_for_metrics((predictions, references))
+        metric.add_batch(predictions=predictions, references=references)
+    return metric.compute()
+
+
 def train_baseline():
     set_seed(42)
     model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
+    first_linear = None
+    last_linear = None
+    for name, module in model.named_modules():
+        if isinstance(module, torch.nn.Linear):
+            if first_linear is None:
+                first_linear = name
+            last_linear = name
+    func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear)
     accelerator = Accelerator()
     device = accelerator.device
     model.to(device)
 
-    # Convert the model to TE
-    old_named_params = get_named_parameters(model)
-
-    with torch.no_grad():
-        convert_model(model)
-
-    FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"}
-    fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS)
-
-    new_named_params = get_named_parameters(model)
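+    # Swap eligible nn.Linear modules for torchao float8 training linears, in place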
+    convert_to_float8_training(model, module_filter_fn=func)
 
     # Convert the model to FSDP
     model = FSDP(
@@ -70,24 +118,18 @@ def train_baseline():
         auto_wrap_policy=FSDP_WRAP_POLICY,
     )
 
-    mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
-    for param_group in optimizer.param_groups:
-        param_group["params"] = [mapping[p] for p in param_group["params"]]
-
     base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
     model.train()
 
-    for _ in range(2):
-        for batch in train_dataloader:
-            with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
-                batch = batch.to(device)
-                outputs = model(**batch)
-                loss = outputs.loss
-            loss.backward()
-            optimizer.step()
-            optimizer.zero_grad()
-            lr_scheduler.step()
+    for batch in train_dataloader:
+        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+            batch = batch.to(device)
+            outputs = model(**batch)
+            loss = outputs.loss
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+        lr_scheduler.step()
 
     trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
 
@@ -102,15 +144,13 @@


 def train_integration():
-    FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"}
-    kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)]
     AcceleratorState()._reset_state(True)
     fsdp_plugin = FSDPPlugin(
         auto_wrap_policy=FSDP_WRAP_POLICY,
         use_orig_params=True,
         mixed_precision_policy=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
     )
-    accelerator = Accelerator(mixed_precision="fp8", fsdp_plugin=fsdp_plugin, kwargs_handlers=kwargs_handlers)
+    accelerator = Accelerator(mixed_precision="fp8", fsdp_plugin=fsdp_plugin, kwargs_handlers=[AORecipeKwargs()])
     set_seed(42)
     model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
         MODEL_NAME, accelerator=accelerator
@@ -120,14 +160,13 @@ def train_integration():
     base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
     model.train()
 
-    for _ in range(2):
-        for batch in train_dataloader:
-            outputs = model(**batch)
-            loss = outputs.loss
-            accelerator.backward(loss)
-            optimizer.step()
-            optimizer.zero_grad()
-            lr_scheduler.step()
+    for batch in train_dataloader:
+        outputs = model(**batch)
+        loss = outputs.loss
+        accelerator.backward(loss)
+        optimizer.step()
+        optimizer.zero_grad()
+        lr_scheduler.step()
 
     trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
 
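Condensing the diff above, the integration path that train_integration measures amounts to the following sketch. Again, this is illustrative rather than part of the commit; it assumes a CUDA machine with torchao installed and mirrors the wrap policy and plugin settings used in the benchmark.

from functools import partial

import torch
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.bert import BertLayer

from accelerate import Accelerator
from accelerate import FullyShardedDataParallelPlugin as FSDPPlugin
from accelerate.utils import AORecipeKwargs

# Shard at BertLayer boundaries, exactly as FSDP_WRAP_POLICY does in the benchmark.
wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={BertLayer})

fsdp_plugin = FSDPPlugin(
    auto_wrap_policy=wrap_policy,
    use_orig_params=True,
    mixed_precision_policy=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
)

# With mixed_precision="fp8" plus AORecipeKwargs, accelerator.prepare() runs
# torchao's float8 conversion on the model and then wraps it with FSDP, so the
# training loop only needs accelerator.backward(loss).
accelerator = Accelerator(mixed_precision="fp8", fsdp_plugin=fsdp_plugin, kwargs_handlers=[AORecipeKwargs()])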
