Commit: add python random seed

xrsrke committed Feb 11, 2024
1 parent 5b375f5 commit 081b17d
Showing 3 changed files with 21 additions and 15 deletions.
src/nanotron/distributed.py (5 changes: 4 additions & 1 deletion)

@@ -9,6 +9,8 @@
 from torch.distributed import *  # noqa
 from torch.distributed.distributed_c10d import ProcessGroup
 
+from nanotron.utils import find_free_port
+
 torch_version_above_1_13 = version.parse(torch.__version__) >= version.parse("1.13.0")
 Work = dist.Work if torch_version_above_1_13 else dist._Work
 default_pg_timeout = datetime.timedelta(minutes=10)

@@ -238,7 +240,7 @@ def get_rank(group: Optional[ProcessGroup] = None) -> int:  # pylint: disable=fu
     return result
 
 
-def initialize_torch_distributed(port: int):
+def initialize_torch_distributed():
     """Initializes torch distributed with the environment variables"""
     rank = int(os.getenv("RANK", "0"))
     world_size = int(os.getenv("WORLD_SIZE", "1"))

@@ -257,6 +259,7 @@ def initialize_torch_distributed(port: int):
         backend = "gloo"
 
     # Call the init process.
+    port = find_free_port()
     init_method = f"env://localhost:{port}"
     dist.init_process_group(
         init_method=init_method, backend=backend, world_size=world_size, rank=rank, timeout=dist.default_pg_timeout
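Note: with this change, callers no longer pass a port into initialize_torch_distributed; the function now picks one itself via find_free_port. A minimal usage sketch, not part of the commit (single-process defaults; RANK and WORLD_SIZE are read from the environment inside the function):

# Usage sketch only, assuming nanotron.distributed is imported as `dist`
# (as it is in src/nanotron/parallel/context.py below).
import os

from nanotron import distributed as dist

os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
dist.initialize_torch_distributed()  # a free port is now chosen internally via find_free_port()
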
src/nanotron/parallel/context.py (9 changes: 4 additions & 5 deletions)

@@ -1,5 +1,5 @@
 import os
-from typing import Literal, Optional, Tuple
+from typing import Literal, Tuple
 
 import numpy as np
 import torch

@@ -15,7 +15,6 @@ def __init__(
         tensor_parallel_size: int,
         pipeline_parallel_size: int,
         data_parallel_size: int,
-        port: Optional[int] = None,
         backend: DistributedBackend = "nccl",
     ):
         """Initialize parallel context."""

@@ -49,10 +48,10 @@ def __init__(
         assert backend == "nccl", "Only nccl backend is supported for now."
 
         if not dist.is_initialized():
-            from nanotron.utils import find_free_port
+            # from nanotron.utils import find_free_port
 
-            port = find_free_port() if port is None else port
-            dist.initialize_torch_distributed(port)
+            # port = find_free_port() if port is None else port
+            dist.initialize_torch_distributed()
 
         world_size = int(os.getenv("WORLD_SIZE", "1"))
         ranks = list(range(world_size))
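Note: with the port keyword removed from ParallelContext, callers construct it from the parallelism sizes alone, and torch.distributed is initialized on first use. A sketch, not part of the commit (the sizes below are placeholders for a single-process run):

# Sketch only: constructing a ParallelContext after this change (no port argument).
from nanotron.parallel.context import ParallelContext

parallel_context = ParallelContext(
    tensor_parallel_size=1,   # placeholder sizes; real values depend on the launch
    pipeline_parallel_size=1,
    data_parallel_size=1,
)
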
tests/helpers/utils.py (22 changes: 13 additions & 9 deletions)

@@ -77,15 +77,19 @@ def __init__(self, func, args, kwargs, tp: int, dp: int, pp: int):
     def __call__(self):
         with mock_os_environ(update_key_values={"WORLD_SIZE": f"{self.tp * self.dp * self.pp}"}):
             # NOTE: we use a different random RNG, so that each unit tests don't generate the same port
-            seed = random.randint(0, 9999)
-            with torch.random.fork_rng(devices=["cuda"]):
-                from nanotron.utils import find_free_port
-
-                torch.manual_seed(seed)
-                port = find_free_port()
-                parallel_context = ParallelContext(
-                    data_parallel_size=self.dp, pipeline_parallel_size=self.pp, tensor_parallel_size=self.tp, port=port
-                )
+            # seed = random.randint(0, 9999)
+            # with torch.random.fork_rng(devices=["cuda"]):
+            # from nanotron.utils import find_free_port
+
+            import time
+
+            random.seed(time.time())
+
+            # torch.manual_seed(seed)
+            # port = find_free_port()
+            parallel_context = ParallelContext(
+                data_parallel_size=self.dp, pipeline_parallel_size=self.pp, tensor_parallel_size=self.tp
+            )
 
             assert "parallel_context" not in self.kwargs
             self.kwargs["parallel_context"] = parallel_context
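Note: the helper previously forked the CUDA RNG and called torch.manual_seed so each test would draw a different port; it now seeds Python's random module from the wall clock instead. That only helps if find_free_port draws its candidate ports from random, which is an assumption here (the helper's body is not shown in this diff). A self-contained sketch of that interaction:

# Hypothetical sketch; find_free_port_sketch stands in for nanotron.utils.find_free_port
# under the assumption that it draws candidate ports from Python's `random` module.
import random
import socket
import time


def find_free_port_sketch(min_port: int = 2000, max_port: int = 65000) -> int:
    """Try random candidate ports until one can be bound."""
    while True:
        port = random.randint(min_port, max_port)
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.bind(("localhost", port))
                return port
        except OSError:
            continue  # candidate already in use, draw another


# Seeding with the wall clock makes concurrently launched test processes draw
# different port sequences, so they rarely race for the same port.
random.seed(time.time())
print(find_free_port_sketch())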
