Commit

feat: support tensor parallel & Data loader (#3173)
* feat: add dataloader for TP and n-dim parallel in non-dispatch mode

Signed-off-by: Mehant Kammakomati <[email protected]>

* feat: add support for CLI usage

Signed-off-by: Mehant Kammakomati <[email protected]>

* fix: test cases

Signed-off-by: Mehant Kammakomati <[email protected]>

* fix: when tp not in use fix num_procs

Signed-off-by: Mehant Kammakomati <[email protected]>

---------

Signed-off-by: Mehant Kammakomati <[email protected]>
kmehant authored Jan 29, 2025
1 parent 675e35b commit 0315365
Showing 13 changed files with 282 additions and 7 deletions.
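
For illustration, a minimal sketch of how the new plugin is meant to be wired up. This is not part of the diff: it assumes `TorchTensorParallelPlugin` accepts a `tp_size` argument mirroring the `tp_size` option added to the CLI and config below, and that the script runs under a distributed launcher.

```python
# Hedged sketch, assuming TorchTensorParallelPlugin(tp_size=...) mirrors the
# `tp_size` CLI/config option added in this commit, and that the script is
# started with `accelerate launch` (or torchrun) across tp_size processes.
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils import TorchTensorParallelPlugin

tp_plugin = TorchTensorParallelPlugin(tp_size=2)
accelerator = Accelerator(torch_tp_plugin=tp_plugin)

# A transformers model exposing a TP plan (see the `supports_tp_plan` check below)
# would be sharded via accelerator.prepare(); the dataloader is set up so that all
# TP ranks receive the same batch.
dataloader = DataLoader(TensorDataset(torch.arange(64, dtype=torch.float32)), batch_size=8)
dataloader = accelerator.prepare(dataloader)
```

The CLI equivalent, added further down in this commit, is `accelerate launch --use_tp --tp_size 2 train.py`.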
42 changes: 41 additions & 1 deletion src/accelerate/accelerator.py
@@ -67,6 +67,7 @@
ProjectConfiguration,
RNGType,
TorchDynamoPlugin,
TorchTensorParallelPlugin,
apply_fp8_autowrap,
check_os_kernel,
clean_state_dict_for_safetensors,
@@ -107,7 +108,12 @@
save_fsdp_optimizer,
wait_for_everyone,
)
from .utils.constants import FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME
from .utils.constants import (
BETA_TP_AVAILABLE_PYTORCH_VERSION,
BETA_TP_AVAILABLE_TRANSFORMERS_VERSION,
FSDP_PYTORCH_VERSION,
PROFILE_PATTERN_NAME,
)
from .utils.modeling import get_state_dict_offloaded_model
from .utils.other import is_compiled_module

@@ -189,6 +195,9 @@ class Accelerator:
fsdp_plugin ([`~utils.FullyShardedDataParallelPlugin`], *optional*):
Tweak your FSDP related args using this argument. This argument is optional and can be configured directly
using *accelerate config*
torch_tp_plugin ([`~utils.TorchTensorParallelPlugin`], *optional*):
Tweak your torch tensor parallel related args using this argument. This argument is optional and can be
configured directly using *accelerate config*
megatron_lm_plugin ([`~utils.MegatronLMPlugin`], *optional*):
Tweak your MegatronLM related args using this argument. This argument is optional and can be configured
directly using *accelerate config*
@@ -258,6 +267,7 @@ def __init__(
dataloader_config: DataLoaderConfiguration | None = None,
deepspeed_plugin: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
torch_tp_plugin: TorchTensorParallelPlugin | None = None,
megatron_lm_plugin: MegatronLMPlugin | None = None,
rng_types: list[str | RNGType] | None = None,
log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,
@@ -354,6 +364,15 @@ def __init__(
if not is_torch_version(">=", FSDP_PYTORCH_VERSION):
raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")

if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or isinstance(
torch_tp_plugin, TorchTensorParallelPlugin
):
if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
raise ValueError(f"TP requires PyTorch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}")

if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")

if fsdp_plugin is None: # init from env variables
fsdp_plugin = (
FullyShardedDataParallelPlugin() if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" else None
@@ -363,6 +382,15 @@
raise TypeError("`fsdp_plugin` must be a FullyShardedDataParallelPlugin object.")
os.environ["ACCELERATE_USE_FSDP"] = "true" # use FSDP if plugin is provided

if torch_tp_plugin is None:
torch_tp_plugin = (
TorchTensorParallelPlugin() if os.environ.get("ACCELERATE_USE_TP", "false") == "true" else None
)
else:
if not isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")
os.environ["ACCELERATE_USE_TP"] = "true"

if megatron_lm_plugin is None: # init from env variables
megatron_lm_plugin = (
MegatronLMPlugin() if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" else None
@@ -428,6 +456,7 @@ def __init__(
dynamo_plugin=dynamo_plugin,
deepspeed_plugin=deepspeed_plugins,
fsdp_plugin=fsdp_plugin,
torch_tp_plugin=torch_tp_plugin,
megatron_lm_plugin=megatron_lm_plugin,
_from_accelerator=True,
**kwargs,
@@ -1471,6 +1500,16 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
)
if self.ddp_handler is not None:
self.ddp_handler.register_comm_hook(model)
elif self.distributed_type == DistributedType.TP:
if hasattr(model, "supports_tp_plan") and not model.supports_tp_plan:
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
raise NotImplementedError(
"Provided model does not support tensor parallelism. \
Tensor parallelism plan can be added as base_model_tp_plan to model config class \
and _tp_plan attribute to model class."
)
model.tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
elif self.distributed_type == DistributedType.FSDP:
# We need to fix the optimizer *before* sharding the model
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
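
For reference, the TP branch above is roughly equivalent to the following sketch. The device mesh construction here is an assumption for illustration; in the real path the mesh comes from `TorchTensorParallelPlugin.torch_device_mesh`, which is not shown in this diff.

```python
from torch.distributed.device_mesh import init_device_mesh
import torch.nn as nn


def tensor_parallelize(model: nn.Module, tp_size: int) -> nn.Module:
    # Assumes torch.distributed is already initialized with tp_size ranks
    # (e.g. by `accelerate launch`); Accelerate itself reuses the plugin's mesh
    # rather than building one here.
    device_mesh = init_device_mesh("cuda", (tp_size,), mesh_dim_names=("tp",))
    if not getattr(model, "supports_tp_plan", False):
        raise NotImplementedError(
            "Provided model does not support tensor parallelism; add `base_model_tp_plan` "
            "to the model config class and a `_tp_plan` attribute to the model class."
        )
    # transformers applies the model's TP plan over the "tp" submesh
    model.tensor_parallel(device_mesh["tp"])
    return model
```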
@@ -2122,6 +2161,7 @@ def prepare_data_loader(
data_seed=self.dataloader_config.data_seed,
non_blocking=self.non_blocking,
use_stateful_dataloader=self.use_stateful_dataloader,
torch_device_mesh=self.state.torch_tp_plugin.torch_device_mesh if self.state.torch_tp_plugin else None,
)
self._dataloaders.append(prepared_data_loader)
return prepared_data_loader
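
The standalone helper gains the same keyword; a quick single-process sketch (no device mesh, so behaviour is unchanged from before this commit):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate.data_loader import prepare_data_loader

dataset = TensorDataset(torch.arange(64, dtype=torch.float32))
dataloader = DataLoader(dataset, batch_size=8)

# torch_device_mesh=None keeps the default behaviour; passing a mesh that
# contains a "tp" dimension triggers the rank remapping shown further down.
prepared = prepare_data_loader(dataloader, torch_device_mesh=None)
for (batch,) in prepared:
    print(batch.shape)
```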
18 changes: 17 additions & 1 deletion src/accelerate/commands/config/cluster.py
@@ -376,6 +376,7 @@ def get_cluster_input():
)

fsdp_config = {}
tp_config = {}
if distributed_type in [
DistributedType.MULTI_GPU,
DistributedType.MULTI_NPU,
@@ -475,7 +476,21 @@ def get_cluster_input():
default=False,
error_message="Please enter yes or no.",
)

if not use_fsdp:
use_tp = _ask_field(
"Do you want to use TensorParallel? [yes/NO]: ",
_convert_yes_no_to_bool,
default=False,
error_message="Please enter yes or no.",
)
if use_tp:
distributed_type = DistributedType.TP
if distributed_type == DistributedType.TP:
tp_config["tp_size"] = _ask_field(
"What should be your Tensor Parallel degree? [1]: ",
int,
default=1,
)
megatron_lm_config = {}
if distributed_type in [DistributedType.MULTI_GPU]:
use_megatron_lm = _ask_field(
@@ -810,6 +825,7 @@ def get_cluster_input():
fp8_config=fp8_config,
deepspeed_config=deepspeed_config,
fsdp_config=fsdp_config,
tp_config=tp_config,
megatron_lm_config=megatron_lm_config,
ipex_config=ipex_config,
mpirun_config=mpirun_config,
4 changes: 4 additions & 0 deletions src/accelerate/commands/config/config_args.py
@@ -194,6 +194,8 @@ class ClusterConfig(BaseConfig):
deepspeed_config: dict = None
# args for fsdp
fsdp_config: dict = None
# args for tp
tp_config: dict = None
# args for megatron_lm
megatron_lm_config: dict = None
# args for ipex
@@ -221,6 +223,8 @@ def __post_init__(self):
self.deepspeed_config = {}
if self.fsdp_config is None:
self.fsdp_config = {}
if self.tp_config is None:
self.tp_config = {}
if self.megatron_lm_config is None:
self.megatron_lm_config = {}
if self.ipex_config is None:
26 changes: 24 additions & 2 deletions src/accelerate/commands/launch.py
Expand Up @@ -73,6 +73,7 @@
"tpu": "TPU",
"use_deepspeed": "DeepSpeed Arguments",
"use_fsdp": "FSDP Arguments",
"use_tp": "PyTorch TP Arguments",
"use_megatron_lm": "Megatron-LM Arguments",
"fp8_backend": "FP8 Arguments",
}
@@ -261,6 +262,12 @@ def launch_command_parser(subparsers=None):
action="store_true",
help="Whether to use fsdp.",
)
paradigm_args.add_argument(
"--use_tp",
default=False,
action="store_true",
help="Whether to use PyTorch TP.",
)
paradigm_args.add_argument(
"--use_megatron_lm",
default=False,
@@ -588,6 +595,15 @@ def launch_command_parser(subparsers=None):
help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).",
)

# tp args
tp_args = parser.add_argument_group("TP Arguments", "Arguments related to Tensor Parallelism using PyTorch.")
tp_args.add_argument(
"--tp_size",
default=1,
type=int,
help="PyTorch Tensor Parallelism (TP) degree. Set a value greater than 1 to activate. (useful only when `use_tp` flag is passed)",
)

# megatron_lm args
megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.")
megatron_lm_args.add_argument(
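
The new flags can be exercised directly through the parser defined above; a small sketch (the positional `train.py` stands in for any training script):

```python
from accelerate.commands.launch import launch_command_parser

parser = launch_command_parser()
args = parser.parse_args(["--use_tp", "--tp_size", "2", "train.py"])
print(args.use_tp, args.tp_size)  # True 2
```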
@@ -969,9 +985,9 @@ def sagemaker_launcher(sagemaker_config: SageMakerConfig, args):

def _validate_launch_command(args):
# Sanity checks
if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp]) > 1:
if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp, args.use_tp]) > 1:
raise ValueError(
"You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp` at a time."
"You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp`, `--use_tp` at a time."
)
if args.multi_gpu and (args.num_processes is not None) and (args.num_processes < 2):
raise ValueError("You need to use at least 2 processes to use `--multi_gpu`.")
@@ -988,6 +1004,7 @@ def _validate_launch_command(args):
and not args.tpu_use_cluster
and not args.use_deepspeed
and not args.use_fsdp
and not args.use_tp
and not args.use_megatron_lm
):
args.use_deepspeed = defaults.distributed_type == DistributedType.DEEPSPEED
@@ -1005,6 +1022,7 @@ def _validate_launch_command(args):
)
args.tpu = defaults.distributed_type == DistributedType.XLA
args.use_fsdp = defaults.distributed_type == DistributedType.FSDP
args.use_tp = defaults.distributed_type == DistributedType.TP
args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM
args.tpu_use_cluster = defaults.tpu_use_cluster if args.tpu else False
if args.gpu_ids is None:
@@ -1032,6 +1050,8 @@ def _validate_launch_command(args):
if "fsdp" not in arg_to_set:
arg_to_set = "fsdp_" + arg_to_set
setattr(args, arg_to_set, defaults.fsdp_config[k])
for k in defaults.tp_config:
setattr(args, k, defaults.tp_config[k])
for k in defaults.megatron_lm_config:
setattr(args, k, defaults.megatron_lm_config[k])
for k in defaults.dynamo_config:
@@ -1157,6 +1177,8 @@ def launch_command(args):
deepspeed_launcher(args)
elif args.use_fsdp and not args.cpu:
multi_gpu_launcher(args)
elif args.use_tp and not args.cpu:
multi_gpu_launcher(args)
elif args.use_megatron_lm and not args.cpu:
multi_gpu_launcher(args)
elif args.multi_gpu and not args.cpu:
75 changes: 73 additions & 2 deletions src/accelerate/data_loader.py
@@ -539,6 +539,7 @@ def __init__(
use_stateful_dataloader=False,
_drop_last: bool = False,
_non_blocking: bool = False,
torch_device_mesh=None,
**kwargs,
):
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
@@ -726,6 +727,7 @@ def __init__(
_drop_last: bool = False,
_non_blocking: bool = False,
slice_fn=None,
torch_device_mesh=None,
**kwargs,
):
shuffle = False
@@ -744,26 +746,68 @@ def __init__(
self._drop_last = _drop_last
self._non_blocking = _non_blocking
self.skip_batches = skip_batches
self.torch_device_mesh = torch_device_mesh

self.slice_fn = slice_tensors if slice_fn is None else slice_fn
self.iteration = 0

# If a device mesh is provided, extract each relevant dimension (dp, fsdp, tp).
# The device mesh may hold any number of dimensions; however,
# the code below provides targeted support for dp, fsdp and tp only.

# The device mesh is used only when tp is involved,
# alone or in any multi-dimensional parallelism that includes tp:
# (dp, tp), (fsdp, tp), (dp, fsdp, tp).
# Otherwise the default behaviour, without a device mesh, is sufficient,
# since multi-dimensional parallelism without tp would in any case need
# different batches for each process, irrespective of dp or fsdp.
self.submesh_tp = None
self.submesh_dp = None
self.submesh_fsdp = None
if self.torch_device_mesh and "tp" in self.torch_device_mesh.mesh_dim_names:
self.submesh_tp = self.torch_device_mesh["tp"]
if "dp" in self.torch_device_mesh.mesh_dim_names:
self.submesh_dp = self.torch_device_mesh["dp"]
if "fsdp" in self.torch_device_mesh.mesh_dim_names:
self.submesh_fsdp = self.torch_device_mesh["fsdp"]
if self.submesh_tp and (self.submesh_dp or self.submesh_fsdp):
raise ValueError("TP + (DP/FSDP) is not yet supported in dispatch mode")

def _fetch_batches(self, iterator):
batches, batch = None, None
# On process 0, we gather the batch to dispatch.
if self.state.process_index == 0:
# The TP-only procedure is simpler,
# since the same batch of samples is dispatched to all ranks;
# this avoids the complexity of handling multiple tp rank groups when a TP + DP
# combination is involved.

try:
# For the TP case, avoid using split_batches,
# since that would require the dataloader to produce
# duplicate batches.
if self.split_batches:
# One batch of the main iterator is dispatched and split.
if self.submesh_tp:
logger.warning(
"Use of split_batches for TP would need the dataloader to produce duplicate batches,"
"otherwise, use dispatch_batches=True instead."
)
self._update_state_dict()
batch = next(iterator)
else:
# num_processes batches of the main iterator are concatenated then dispatched and split.
# We add the batches one by one so we have the remainder available when drop_last=False.
batches = []
for _ in range(self.state.num_processes):
if self.submesh_tp:
# when tp, extract single batch and then replicate
self._update_state_dict()
batches.append(next(iterator))
batch = next(iterator)
batches = [batch] * self.state.num_processes
else:
for _ in range(self.state.num_processes):
self._update_state_dict()
batches.append(next(iterator))
try:
batch = concatenate(batches, dim=0)
except RuntimeError as e:
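
Conceptually, the `submesh_tp` branch above replaces per-process batch concatenation with replication of a single fetched batch; a toy, non-distributed sketch:

```python
import torch

iterator = iter(torch.arange(24).reshape(3, 8))  # three "batches" of 8 samples each
num_processes = 4                                # stand-in for self.state.num_processes

batch = next(iterator)             # single fetch when submesh_tp is set
batches = [batch] * num_processes  # every rank receives identical samples
assert all(torch.equal(b, batch) for b in batches)
```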
@@ -954,6 +998,7 @@ def prepare_data_loader(
data_seed: Optional[int] = None,
non_blocking: bool = False,
use_stateful_dataloader: bool = False,
torch_device_mesh=None,
) -> DataLoader:
"""
Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
@@ -1021,6 +1066,8 @@
"If set to true, the dataloader prepared by the Accelerator will be backed by "
"[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
torch_device_mesh (`torch.distributed.DeviceMesh`, *optional*, defaults to `None`):
PyTorch device mesh.
Returns:
@@ -1045,9 +1092,32 @@
state = PartialState()
if num_processes is None:
num_processes = state.num_processes

if process_index is None:
process_index = state.process_index

# When a device mesh is used, specifically with TP,
# process_index and num_processes need to be updated
# so that the same batch is generated across TP ranks
# and different batches across FSDP and DP ranks.
# Example:
# if the device mesh is (dp, fsdp, tp) = (2, 2, 3),
# ranks range from 0...11;
# from the data perspective the indices should look like 0 0 0 1 1 1 2 2 2 3 3 3,
# and processes with the same index receive the same batch.
if torch_device_mesh:
submesh_fsdp_size = 1
submesh_dp_size = 1
submesh_tp_size = 1
if "tp" in torch_device_mesh.mesh_dim_names:
submesh_tp_size = torch_device_mesh["tp"].size()
if "dp" in torch_device_mesh.mesh_dim_names:
submesh_dp_size = torch_device_mesh["dp"].size()
if "fsdp" in torch_device_mesh.mesh_dim_names:
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
process_index = process_index // submesh_tp_size
num_processes = submesh_fsdp_size * submesh_dp_size

# Sanity check
if split_batches:
if dataloader.batch_size is not None:
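
The remapping in the hunk above can be sanity-checked with plain arithmetic for the `(dp, fsdp, tp) = (2, 2, 3)` example from the comments:

```python
# With a (dp, fsdp, tp) = (2, 2, 3) mesh, ranks 0..11 collapse to data indices
# 0 0 0 1 1 1 2 2 2 3 3 3 and the effective number of data-loading processes
# becomes dp * fsdp = 4.
dp_size, fsdp_size, tp_size = 2, 2, 3

data_indices = [rank // tp_size for rank in range(dp_size * fsdp_size * tp_size)]
num_processes = fsdp_size * dp_size

print(data_indices)   # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
print(num_processes)  # 4
```

This also explains why `num_processes` is set to `submesh_fsdp_size * submesh_dp_size` rather than the world size.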
@@ -1156,6 +1226,7 @@ def prepare_data_loader(
_non_blocking=non_blocking,
slice_fn=slice_fn_for_dispatch,
use_stateful_dataloader=use_stateful_dataloader,
torch_device_mesh=torch_device_mesh,
**kwargs,
)
elif sampler_is_batch_sampler: