feat(tpu): remove nprocs from xla.spawn (#3324)
This parameter will cause issues on recent versions of torch_xla.
tengomucho authored Jan 13, 2025
1 parent ba90f85 commit 95f34d6
Showing 3 changed files with 23 additions and 7 deletions.
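
For context before the per-file diffs, here is a minimal sketch of the calling convention the commit moves to. It assumes a recent torch_xla release using the PJRT runtime, where (to my understanding) xmp.spawn only accepts nprocs=None (use every visible device) or nprocs=1; the exact constraint and error message may vary between releases.

# Minimal sketch; assumes recent torch_xla (PJRT), where nprocs must be None or 1.
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla import device_count


def _mp_fn(index):
    # xmp.spawn calls the target with the process ordinal as its first argument.
    print(f"process {index} of {device_count()} started")


if __name__ == "__main__":
    # Before this commit, accelerate forwarded an explicit count:
    #     xmp.spawn(_mp_fn, args=(), nprocs=args.num_processes)
    # which can fail on recent torch_xla when the value is neither None nor 1.
    # After this commit, nprocs is omitted and every available device is used:
    xmp.spawn(_mp_fn, args=())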
7 changes: 6 additions & 1 deletion src/accelerate/commands/launch.py
@@ -854,6 +854,7 @@ def deepspeed_launcher(args):
 
 def tpu_launcher(args):
     import torch_xla.distributed.xla_multiprocessing as xmp
+    from torch_xla import device_count
 
     if args.no_python:
         raise ValueError("--no_python cannot be used with TPU launcher")
@@ -874,13 +875,17 @@ def tpu_launcher(args):
             f"Your training script should have a function named {args.main_training_function}, or you should pass a "
             "different value to `--main_training_function`."
         )
+    if args.num_processes and args.num_processes != device_count():
+        raise ValueError(
+            f"Number of processes ({args.num_processes}) must match the number of TPU devices ({device_count()})"
+        )
 
     # Patch sys.argv
     sys.argv = [mod.__file__] + args.training_script_args
 
     main_function = getattr(mod, args.main_training_function)
     with patch_environment(**current_env):
-        xmp.spawn(PrepareForLaunch(main_function), args=(), nprocs=args.num_processes)
+        xmp.spawn(PrepareForLaunch(main_function), args=())
 
 
 def tpu_pod_launcher(args):
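
For illustration, a standalone sketch of the guard added to tpu_launcher above; the Namespace value is a hypothetical stand-in for what `accelerate launch --num_processes ...` would parse.

# Standalone sketch of the added check; the args value below is hypothetical.
from argparse import Namespace

from torch_xla import device_count

args = Namespace(num_processes=4)  # e.g. the user passed --num_processes 4 on an 8-device host
if args.num_processes and args.num_processes != device_count():
    raise ValueError(
        f"Number of processes ({args.num_processes}) must match the number of TPU devices ({device_count()})"
    )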
7 changes: 3 additions & 4 deletions src/accelerate/launchers.py
@@ -135,19 +135,18 @@ def train(*args):
     if (in_colab or in_kaggle) and (os.environ.get("TPU_NAME", None) is not None):
         # TPU launch
         import torch_xla.distributed.xla_multiprocessing as xmp
+        from torch_xla import device_count
 
         if len(AcceleratorState._shared_state) > 0:
             raise ValueError(
                 "To train on TPU in Colab or Kaggle Kernel, the `Accelerator` should only be initialized inside "
                 "your training function. Restart your notebook and make sure no cells initializes an "
                 "`Accelerator`."
             )
-        if num_processes is None:
-            num_processes = 8
 
         launcher = PrepareForLaunch(function, distributed_type="XLA")
-        print(f"Launching a training on {num_processes} TPU cores.")
-        xmp.spawn(launcher, args=args, nprocs=num_processes, start_method="fork")
+        print(f"Launching a training on {device_count()} TPU cores.")
+        xmp.spawn(launcher, args=args, start_method="fork")
     elif in_colab and get_gpu_info()[1] < 2:
         # No need for a distributed launch otherwise as it's either CPU or one GPU.
         if torch.cuda.is_available():
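
As a usage note for the notebook path above: a hedged sketch of launching from Colab or Kaggle after this change. notebook_launcher is the public accelerate entry point; with the diff above, the reported TPU core count comes from device_count() instead of a hard-coded default of 8, and the count is no longer forwarded to xmp.spawn.

# Hedged sketch: launching a training function on TPU from a notebook.
from accelerate import Accelerator, notebook_launcher


def training_function():
    # On TPU in Colab/Kaggle the Accelerator must be created inside the
    # training function (enforced by the check in the diff above).
    accelerator = Accelerator()
    accelerator.print(f"Running on {accelerator.num_processes} processes")


# All visible TPU devices are used; num_processes is not passed to xmp.spawn.
notebook_launcher(training_function, args=())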
16 changes: 14 additions & 2 deletions tests/xla_spawn.py
@@ -30,6 +30,7 @@
 from pathlib import Path
 
 import torch_xla.distributed.xla_multiprocessing as xmp
+from torch_xla import device_count
 
 
 def parse_args():
@@ -46,7 +47,13 @@ def parse_args():
     )
 
     # Optional arguments for the launch helper
-    parser.add_argument("--num_cores", type=int, default=1, help="Number of TPU cores to use (1 or 8).")
+    num_devices = device_count()
+    parser.add_argument(
+        "--num_cores",
+        type=int,
+        default=num_devices,
+        help="Number of TPU cores to use (1 or number of available devices).",
+    )
 
     # positional
     parser.add_argument(
@@ -76,7 +83,12 @@ def main():
     mod = importlib.import_module(mod_name)
 
     # Patch sys.argv
-    sys.argv = [args.training_script] + args.training_script_args + ["--tpu_num_cores", str(args.num_cores)]
+    sys.argv = [args.training_script] + args.training_script_args
+    num_cores = args.num_cores
+    if num_cores == device_count() and num_cores != 1:
+        # There is an error in xmp.spawn that causes it to fail when num_cores is specified and not 1, so we set it to
+        # None when it matches the number of devices.
+        num_cores = None
     xmp.spawn(mod._mp_fn, args=(), nprocs=num_cores)
 
 
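
The comment in the last hunk explains the workaround: when --num_cores equals the number of attached devices, the value is mapped to None before calling xmp.spawn. Below is a small standalone sketch of that normalization, with a hypothetical helper name, assuming recent xmp.spawn only accepts an explicit count of 1.

# Hedged sketch of the nprocs normalization used in tests/xla_spawn.py.
# normalize_nprocs is a hypothetical helper, not part of the commit.
from typing import Optional

from torch_xla import device_count


def normalize_nprocs(requested: int) -> Optional[int]:
    # Map "all devices" to None, since recent xmp.spawn rejects explicit
    # counts other than 1.
    if requested == device_count() and requested != 1:
        return None
    return requested


# e.g. xmp.spawn(mod._mp_fn, args=(), nprocs=normalize_nprocs(args.num_cores))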
