diff --git a/src/nanotron/distributed.py b/src/nanotron/distributed.py
index 6dbb0b26..01438719 100644
--- a/src/nanotron/distributed.py
+++ b/src/nanotron/distributed.py
@@ -9,6 +9,8 @@
 from torch.distributed import *  # noqa
 from torch.distributed.distributed_c10d import ProcessGroup

+from nanotron.utils import find_free_port
+
 torch_version_above_1_13 = version.parse(torch.__version__) >= version.parse("1.13.0")
 Work = dist.Work if torch_version_above_1_13 else dist._Work
 default_pg_timeout = datetime.timedelta(minutes=10)
@@ -238,7 +240,7 @@ def get_rank(group: Optional[ProcessGroup] = None) -> int:  # pylint: disable=fu
     return result


-def initialize_torch_distributed(port: int):
+def initialize_torch_distributed():
     """Initializes torch distributed with the environment variables"""
     rank = int(os.getenv("RANK", "0"))
     world_size = int(os.getenv("WORLD_SIZE", "1"))
@@ -257,6 +259,7 @@ def initialize_torch_distributed(port: int):
         backend = "gloo"

     # Call the init process.
+    port = find_free_port()
     init_method = f"env://localhost:{port}"
     dist.init_process_group(
         init_method=init_method, backend=backend, world_size=world_size, rank=rank, timeout=dist.default_pg_timeout
diff --git a/src/nanotron/parallel/context.py b/src/nanotron/parallel/context.py
index ba71805d..5063454a 100644
--- a/src/nanotron/parallel/context.py
+++ b/src/nanotron/parallel/context.py
@@ -1,5 +1,5 @@
 import os
-from typing import Literal, Optional, Tuple
+from typing import Literal, Tuple

 import numpy as np
 import torch
@@ -15,7 +15,6 @@ def __init__(
         tensor_parallel_size: int,
         pipeline_parallel_size: int,
         data_parallel_size: int,
-        port: Optional[int] = None,
         backend: DistributedBackend = "nccl",
     ):
         """Initialize parallel context."""
@@ -49,10 +48,10 @@
         assert backend == "nccl", "Only nccl backend is supported for now."

         if not dist.is_initialized():
-            from nanotron.utils import find_free_port
+            # from nanotron.utils import find_free_port

-            port = find_free_port() if port is None else port
-            dist.initialize_torch_distributed(port)
+            # port = find_free_port() if port is None else port
+            dist.initialize_torch_distributed()

         world_size = int(os.getenv("WORLD_SIZE", "1"))
         ranks = list(range(world_size))
diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py
index 4265c741..698f300f 100644
--- a/tests/helpers/utils.py
+++ b/tests/helpers/utils.py
@@ -77,15 +77,19 @@ def __init__(self, func, args, kwargs, tp: int, dp: int, pp: int):
     def __call__(self):
         with mock_os_environ(update_key_values={"WORLD_SIZE": f"{self.tp * self.dp * self.pp}"}):
             # NOTE: we use a different random RNG, so that each unit tests don't generate the same port
-            seed = random.randint(0, 9999)
-            with torch.random.fork_rng(devices=["cuda"]):
-                from nanotron.utils import find_free_port
-
-                torch.manual_seed(seed)
-                port = find_free_port()
-                parallel_context = ParallelContext(
-                    data_parallel_size=self.dp, pipeline_parallel_size=self.pp, tensor_parallel_size=self.tp, port=port
-                )
+            # seed = random.randint(0, 9999)
+            # with torch.random.fork_rng(devices=["cuda"]):
+            #     from nanotron.utils import find_free_port
+
+            import time
+
+            random.seed(time.time())
+
+            # torch.manual_seed(seed)
+            # port = find_free_port()
+            parallel_context = ParallelContext(
+                data_parallel_size=self.dp, pipeline_parallel_size=self.pp, tensor_parallel_size=self.tp
+            )

             assert "parallel_context" not in self.kwargs
             self.kwargs["parallel_context"] = parallel_context
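
For context, find_free_port (now imported from nanotron.utils and called inside initialize_torch_distributed) is not shown in this diff. Below is a minimal sketch of such a helper, assuming it simply binds a socket to port 0 so the OS picks an unused ephemeral port; the actual nanotron.utils implementation may differ.

import socket
from contextlib import closing


def find_free_port() -> int:
    # Bind to port 0 so the OS assigns an unused ephemeral port,
    # then report the chosen port and release the socket.
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(("localhost", 0))
        return sock.getsockname()[1]

With this change, initialize_torch_distributed() interpolates the returned port into the env://localhost:{port} init_method instead of receiving it as an argument from ParallelContext.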