Default synced_gpus to True when using FullyShardedDataParallel (huggingface#33483)

* Default synced_gpus to True when using FullyShardedDataParallel

Fixes huggingface#30228

Related:

* pytorch/pytorch#100069
* pytorch/pytorch#123962

Similar to DeepSpeed ZeRO Stage 3, when using FSDP with multiple GPUs and differently sized data per rank, the ranks finish generating at different steps and end up waiting at different synchronization points, leading to a deadlock.

To avoid this, we can automatically set synced_gpus to True when we detect that a PreTrainedModel is managed by FSDP, using the _is_fsdp_managed_module attribute that PyTorch 2.0.0 added for torch.compile: https://github.com/pytorch/pytorch/blob/v2.0.0/torch/distributed/fsdp/_dynamo_utils.py
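
The core of the change to generate() boils down to the default below. This is a standalone sketch for illustration only; the resolve_synced_gpus helper is not part of the library, but the two checks it combines are the ones added in this PR.

```python
# Standalone sketch of the new default (illustrative; resolve_synced_gpus is not a library function).
from typing import Optional

import torch.distributed as dist
from torch import nn

from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.integrations.fsdp import is_fsdp_managed_module


def resolve_synced_gpus(model: nn.Module, synced_gpus: Optional[bool] = None) -> bool:
    if synced_gpus is not None:  # an explicit value always wins
        return synced_gpus
    # `and` short-circuits, so dist.get_world_size() is only evaluated when ZeRO-3
    # or FSDP is actually in play, i.e. when a process group has been initialized.
    return (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(model)) and dist.get_world_size() > 1


print(resolve_synced_gpus(nn.Linear(2, 2)))  # False for a plain module in a single process
```

In other words, FSDP users no longer have to remember to pass synced_gpus=True to generate(); passing a value explicitly still overrides the default.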

* Remove test file

* ruff formatting

* ruff format

* Update copyright year

Co-authored-by: Arthur <[email protected]>

* Add test for FSDP-wrapped model generation

Before huggingface#33483, these tests would have hung for 10 minutes before crashing due to a timeout error

* Ruff format

* Move argparse import

* Remove barrier

I think this might cause more problems if one of the workers were killed

* Move import into function to decrease load time

huggingface#33483 (comment)

* Add test for accelerate and Trainer

huggingface#33483 (comment)

* Refactor imports

* Ruff format

* Use nullcontext

---------

Co-authored-by: Arthur <[email protected]>
ringohoffman and ArthurZucker authored Oct 10, 2024
1 parent 24b82f3 commit 70b07d9
Showing 26 changed files with 434 additions and 98 deletions.
33 changes: 19 additions & 14 deletions src/transformers/generation/utils.py
@@ -35,6 +35,7 @@
)
from ..configuration_utils import PretrainedConfig
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..integrations.fsdp import is_fsdp_managed_module
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..pytorch_utils import isin_mps_friendly
from ..tokenization_utils import ExtensionsTrie
@@ -1913,9 +1914,9 @@ def generate(
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://arxiv.org/abs/2010.00904).
synced_gpus (`bool`, *optional*):
Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
`True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
generating before other GPUs. Otherwise it'll be set to `False`.
Whether to continue running the while loop until max_length. Unless overridden, this flag will be set
to `True` if using `FullyShardedDataParallel` or DeepSpeed ZeRO Stage 3 with multiple GPUs to avoid
deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`.
assistant_model (`PreTrainedModel`, *optional*):
An assistant model that can be used to accelerate generation. The assistant model must have the exact
same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model
@@ -1962,10 +1963,7 @@ def generate(

# 2. Set generation parameters if not already defined
if synced_gpus is None:
if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
synced_gpus = True
else:
synced_gpus = False
synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1

logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
@@ -2499,7 +2497,8 @@ def _dola_decoding(
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
@@ -2702,7 +2701,8 @@ def _contrastive_search(
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
@@ -3105,7 +3105,8 @@ def _sample(
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
@@ -3307,7 +3308,8 @@ def _beam_search(
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -3585,7 +3587,8 @@ def _group_beam_search(
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
model_kwargs:
Additional model specific kwargs that will be forwarded to the `forward` function of the model. If
model is an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -3874,7 +3877,8 @@ def _constrained_beam_search(
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -4112,7 +4116,8 @@ def _assisted_decoding(
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
2 changes: 2 additions & 0 deletions src/transformers/integrations/__init__.py
@@ -54,6 +54,7 @@
],
"eetq": ["replace_with_eetq_linear"],
"fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"],
"fsdp": ["is_fsdp_managed_module"],
"ggml": [
"GGUF_CONFIG_MAPPING",
"GGUF_TENSOR_MAPPING",
@@ -155,6 +156,7 @@
)
from .eetq import replace_with_eetq_linear
from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear
from .fsdp import is_fsdp_managed_module
from .ggml import (
GGUF_CONFIG_MAPPING,
GGUF_TENSOR_MAPPING,
33 changes: 33 additions & 0 deletions src/transformers/integrations/fsdp.py
@@ -0,0 +1,33 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from ..utils import is_torch_available


if TYPE_CHECKING:
from torch import nn


def is_fsdp_managed_module(module: nn.Module) -> bool:
if not is_torch_available():
return False

import torch.distributed.fsdp

return isinstance(module, torch.distributed.fsdp.FullyShardedDataParallel) or getattr(
module, "_is_fsdp_managed_module", False
)
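
A quick, illustrative sketch of how this helper behaves; the private flag is set by hand below purely to exercise the attribute-based path that FSDP normally sets on the modules it manages (since torch 2.0):

```python
# Illustrative only -- runs on CPU, no process group needed.
from torch import nn

from transformers.integrations.fsdp import is_fsdp_managed_module

plain = nn.Linear(4, 4)
print(is_fsdp_managed_module(plain))  # False: not wrapped, no FSDP flag

# FSDP tags managed modules with `_is_fsdp_managed_module = True`; set it by hand
# here only to show the second code path the helper checks.
plain._is_fsdp_managed_module = True
print(is_fsdp_managed_module(plain))  # True
```

Checking the attribute, and not just isinstance(module, FullyShardedDataParallel), is what lets the check succeed when generate() runs on the wrapped PreTrainedModel rather than on the outer FSDP object.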
7 changes: 4 additions & 3 deletions src/transformers/models/data2vec/modeling_data2vec_audio.py
@@ -26,6 +26,7 @@

from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
@@ -826,7 +827,7 @@ def forward(
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)

deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

for layer in self.layers:
if output_hidden_states:
@@ -836,8 +837,8 @@ def forward(
dropout_probability = torch.rand([])

skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
7 changes: 4 additions & 3 deletions src/transformers/models/deprecated/mctct/modeling_mctct.py
@@ -24,6 +24,7 @@
from ....activations import ACT2FN
from ....file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ....integrations.deepspeed import is_deepspeed_zero3_enabled
from ....integrations.fsdp import is_fsdp_managed_module
from ....modeling_attn_mask_utils import _prepare_4d_attention_mask
from ....modeling_outputs import BaseModelOutput, CausalLMOutput
from ....modeling_utils import (
@@ -579,7 +580,7 @@ def forward(
f"but it is for {head_mask.size()[0]}."
)

deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
@@ -588,8 +589,8 @@ def forward(
dropout_probability = torch.rand([])

skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
13 changes: 7 additions & 6 deletions src/transformers/models/hubert/modeling_hubert.py
@@ -25,6 +25,7 @@

from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
@@ -968,7 +969,7 @@ def forward(
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)

deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

for layer in self.layers:
if output_hidden_states:
@@ -978,8 +979,8 @@ def forward(
dropout_probability = torch.rand([])

skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
@@ -1055,7 +1056,7 @@ def forward(
hidden_states = hidden_states + position_embeddings
hidden_states = self.dropout(hidden_states)

deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

for layer in self.layers:
if output_hidden_states:
@@ -1065,8 +1066,8 @@ def forward(
dropout_probability = torch.rand([])

skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
13 changes: 7 additions & 6 deletions src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -24,6 +24,7 @@
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
@@ -1055,7 +1056,7 @@ def forward(
f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
@@ -1065,8 +1066,8 @@ def forward(
dropout_probability = torch.rand([])

skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
@@ -1312,7 +1313,7 @@ def forward(
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
@@ -1322,8 +1323,8 @@ def forward(
dropout_probability = torch.rand([])

skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync

past_key_value = past_key_values[idx] if past_key_values is not None else None

6 changes: 4 additions & 2 deletions src/transformers/models/musicgen/modeling_musicgen.py
@@ -1506,7 +1506,8 @@ def generate(
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
@@ -2513,7 +2514,8 @@ def generate(
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.