Add DP related changes to prepare for EP #1192

Merged
merged 1 commit on Jan 12, 2025
16 changes: 14 additions & 2 deletions fairscale/nn/model_parallel/initialize.py
@@ -31,10 +31,11 @@
_MODEL_PARALLEL_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
_DATA_PARALLEL_GROUP_RANKS = None
# Pipeline parallel group that the current rank belongs to.
_PIPELINE_PARALLEL_GROUP = None
_PIPELINE_PARALLEL_RANKS = None

# Context parallel group that the current rank belongs to.
_CONTEXT_PARALLEL_GROUP = None
_CONTEXT_PARALLEL_GROUP_RANKS = None

@@ -111,12 +112,15 @@ def initialize_model_parallel(
    # Build the data parallel groups.
    global _DATA_PARALLEL_GROUP
    assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
    global _DATA_PARALLEL_GROUP_RANKS
    for i in range(pipeline_length):
        for j in range(context_parallel_size):
            for k in range(model_parallel_size):
                group = torch.distributed.new_group(groups[:, i, j, k].tolist(), backend=ddp_backend, timeout=timeout)
                ranks = groups[:, i, j, k].tolist()
                group = torch.distributed.new_group(ranks, backend=ddp_backend, timeout=timeout)
                if i == found[1] and j == found[2] and k == found[3]:
                    _DATA_PARALLEL_GROUP = group
                    _DATA_PARALLEL_GROUP_RANKS = ranks


# Build the model parallel groups.
@@ -244,13 +248,21 @@ def get_data_parallel_rank() -> int:
    return torch.distributed.get_rank(group=get_data_parallel_group())


def get_data_parallel_ranks() -> List[int]:
"""Return data parallel ranks for the data parallel group."""
assert _DATA_PARALLEL_GROUP_RANKS is not None, "data parallel group is not initialized"
return _DATA_PARALLEL_GROUP_RANKS


def destroy_model_parallel() -> None:
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
_MODEL_PARALLEL_GROUP = None

global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = None
global _DATA_PARALLEL_RANKS
_DATA_PARALLEL_RANKS = None

global _PIPELINE_PARALLEL_GROUP
_PIPELINE_PARALLEL_GROUP = None
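
For context on what the new ranks variable captures: initialize_model_parallel arranges all world ranks in a four-dimensional grid, and the slice groups[:, i, j, k] selects the ranks that share pipeline, context, and model coordinates but differ only along the data-parallel axis. The sketch below illustrates that slicing in isolation; the axis order and sizes are assumptions made for the example, not taken from the PR.

import torch

# Hypothetical sizes: 4-way data parallel, no pipeline/context parallelism, 2-way model parallel.
dp, pp, cp, mp = 4, 1, 1, 2
world_size = dp * pp * cp * mp  # 8 ranks in total

# Rank grid with the axis order the loop nest implies: [data, pipeline, context, model].
groups = torch.arange(world_size).reshape(dp, pp, cp, mp)

# Ranks sharing pipeline/context/model coordinates (i, j, k) form one data-parallel group.
i, j, k = 0, 0, 0
ranks = groups[:, i, j, k].tolist()  # [0, 2, 4, 6] for the sizes above

Storing this list in _DATA_PARALLEL_GROUP_RANKS alongside the group handle lets later code recover the global rank ids of its data-parallel peers without rebuilding the grid.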
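
A hypothetical usage sketch of the new accessor (not part of this PR): because get_data_parallel_ranks returns global rank ids, a caller can, for example, pick the lowest-ranked member of its data-parallel group as a broadcast source without translating group-local ranks. The import path and the two accessor names follow the diff; the helper function itself is only an illustration.

import torch
from fairscale.nn.model_parallel.initialize import (
    get_data_parallel_group,
    get_data_parallel_ranks,
)

def broadcast_from_dp_source(tensor: torch.Tensor) -> torch.Tensor:
    # torch.distributed.broadcast expects a global source rank; the ranks list
    # exposed by this PR makes choosing one straightforward.
    src = min(get_data_parallel_ranks())
    torch.distributed.broadcast(tensor, src=src, group=get_data_parallel_group())
    return tensor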