Skip to content

Commit

Permalink
works for fp8 with deepspeed
Browse files Browse the repository at this point in the history
  • Loading branch information
XiaobingSuper committed Jan 23, 2025
1 parent 4c2c89e commit a2ded14
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,13 +433,14 @@ def __init__(
**kwargs,
)

if self.state.mixed_precision == "fp8" and self.fp8_recipe_handler is None:
self._mixed_precision = mixed_precision
if mixed_precision == "fp8" and self.fp8_recipe_handler is None:
self.fp8_recipe_handler = FP8RecipeKwargs()

self.delayed_fp8_autocast = False
if self.fp8_recipe_handler is not None:
# We already check if FP8 is available during `self.state`
if self.state.mixed_precision != "fp8" and (
if mixed_precision != "fp8" and (
self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED)
):
raise ValueError("Passing in a `FP8RecipeKwargs` object requires setting `mixed_precision='fp8'`.")
Expand Down Expand Up @@ -507,7 +508,10 @@ def __init__(
if mixed_precision == "bf16" and not self.native_amp and not is_torch_xla_available():
raise ValueError("bf16 mixed precision requires PyTorch >= 1.10 and a supported device.")

elif self.state.mixed_precision == "fp8":
# For DeepSpeed, `self.state.mixed_precision` is always "bf16" (never "fp8"), so the
# `mixed_precision` argument is checked here as well. See
# https://github.com/huggingface/accelerate/blob/main/src/accelerate/state.py#L968 and
# https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L1263.
elif mixed_precision == "fp8" or self.state.mixed_precision == "fp8":
# We always enable `native_amp` for FP8
self.native_amp = True
if self.fp8_backend == "MSAMP":
Expand Down Expand Up @@ -3600,7 +3604,7 @@ def lomo_backward(self, loss: torch.Tensor, learning_rate: float) -> None:
@property
def fp8_backend(self):
"Returns the configured backend for training in FP8"
if self.mixed_precision == "fp8" and self.fp8_recipe_handler is not None:
if self._mixed_precision == "fp8" and self.fp8_recipe_handler is not None:
return self.fp8_recipe_handler.backend
elif self.state.deepspeed_plugin is not None and self.state.deepspeed_plugin.enable_msamp:
return "MSAMP"
Expand Down

0 comments on commit a2ded14

Please sign in to comment.