From 95f34d62434869f0f190f0256ecfecc7c1a84d9d Mon Sep 17 00:00:00 2001
From: Alvaro Moran <6949769+tengomucho@users.noreply.github.com>
Date: Mon, 13 Jan 2025 10:37:00 +0100
Subject: [PATCH] feat(tpu): remove nprocs from xla.spawn (#3324)

This parameter will cause issues on recent versions of torch_xla.
---
 src/accelerate/commands/launch.py |  7 ++++++-
 src/accelerate/launchers.py       |  7 +++----
 tests/xla_spawn.py                | 16 ++++++++++++++--
 3 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/src/accelerate/commands/launch.py b/src/accelerate/commands/launch.py
index 49579467dd4..e0074f80762 100644
--- a/src/accelerate/commands/launch.py
+++ b/src/accelerate/commands/launch.py
@@ -854,6 +854,7 @@ def deepspeed_launcher(args):
 
 def tpu_launcher(args):
     import torch_xla.distributed.xla_multiprocessing as xmp
+    from torch_xla import device_count
 
     if args.no_python:
         raise ValueError("--no_python cannot be used with TPU launcher")
@@ -874,13 +875,17 @@ def tpu_launcher(args):
             f"Your training script should have a function named {args.main_training_function}, or you should pass a "
             "different value to `--main_training_function`."
         )
+    if args.num_processes and args.num_processes != device_count():
+        raise ValueError(
+            f"Number of processes ({args.num_processes}) must match the number of TPU devices ({device_count()})"
+        )
 
     # Patch sys.argv
     sys.argv = [mod.__file__] + args.training_script_args
 
     main_function = getattr(mod, args.main_training_function)
     with patch_environment(**current_env):
-        xmp.spawn(PrepareForLaunch(main_function), args=(), nprocs=args.num_processes)
+        xmp.spawn(PrepareForLaunch(main_function), args=())
 
 
 def tpu_pod_launcher(args):
diff --git a/src/accelerate/launchers.py b/src/accelerate/launchers.py
index 3b8dc5aef23..b1b95a46ade 100644
--- a/src/accelerate/launchers.py
+++ b/src/accelerate/launchers.py
@@ -135,6 +135,7 @@ def train(*args):
     if (in_colab or in_kaggle) and (os.environ.get("TPU_NAME", None) is not None):
         # TPU launch
         import torch_xla.distributed.xla_multiprocessing as xmp
+        from torch_xla import device_count
 
         if len(AcceleratorState._shared_state) > 0:
             raise ValueError(
@@ -142,12 +143,10 @@ def train(*args):
                 "your training function. Restart your notebook and make sure no cells initializes an "
                 "`Accelerator`."
             )
-        if num_processes is None:
-            num_processes = 8
 
         launcher = PrepareForLaunch(function, distributed_type="XLA")
-        print(f"Launching a training on {num_processes} TPU cores.")
-        xmp.spawn(launcher, args=args, nprocs=num_processes, start_method="fork")
+        print(f"Launching a training on {device_count()} TPU cores.")
+        xmp.spawn(launcher, args=args, start_method="fork")
     elif in_colab and get_gpu_info()[1] < 2:
         # No need for a distributed launch otherwise as it's either CPU or one GPU.
         if torch.cuda.is_available():
diff --git a/tests/xla_spawn.py b/tests/xla_spawn.py
index c97f272b7e1..66ed5ee4817 100644
--- a/tests/xla_spawn.py
+++ b/tests/xla_spawn.py
@@ -30,6 +30,7 @@
 from pathlib import Path
 
 import torch_xla.distributed.xla_multiprocessing as xmp
+from torch_xla import device_count
 
 
 def parse_args():
@@ -46,7 +47,13 @@ def parse_args():
     )
 
     # Optional arguments for the launch helper
-    parser.add_argument("--num_cores", type=int, default=1, help="Number of TPU cores to use (1 or 8).")
+    num_devices = device_count()
+    parser.add_argument(
+        "--num_cores",
+        type=int,
+        default=num_devices,
+        help="Number of TPU cores to use (1 or number of available devices).",
+    )
 
     # positional
     parser.add_argument(
@@ -76,7 +83,12 @@ def main():
     mod = importlib.import_module(mod_name)
 
     # Patch sys.argv
-    sys.argv = [args.training_script] + args.training_script_args + ["--tpu_num_cores", str(args.num_cores)]
+    sys.argv = [args.training_script] + args.training_script_args
+    num_cores = args.num_cores
+    if num_cores == device_count() and num_cores != 1:
+        # There is an error in xmp.spawn that causes it to fail when num_cores is specified and not 1, so we set it to
+        # None when it matches the number of devices.
+        num_cores = None
 
     xmp.spawn(mod._mp_fn, args=(), nprocs=args.num_cores)
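
For context, here is a minimal sketch (not part of the patch) of the calling convention the change relies on. On recent torch_xla releases, xmp.spawn is expected to derive the process count from the available devices, with explicit nprocs values other than 1 or None liable to fail, which is why the launchers above simply drop the argument and use torch_xla.device_count() only for reporting and validation.

    # Minimal sketch, assuming a TPU runtime with torch_xla installed.
    import torch_xla.distributed.xla_multiprocessing as xmp
    from torch_xla import device_count


    def _mp_fn(index):
        # xmp.spawn calls the target with the process index as its first argument.
        print(f"started process {index} of {device_count()}")


    if __name__ == "__main__":
        # No nprocs: recent torch_xla spawns one process per available device.
        xmp.spawn(_mp_fn, args=())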