feat(simulator): support parallel cost simulator for internevo #243

Draft
wants to merge 3 commits into base: develop
Changes from all commits
4 changes: 4 additions & 0 deletions gen_profiler_data.py
@@ -0,0 +1,4 @@
from internlm.simulator.profiler.perf_comm import gen_perf

if __name__ == "__main__":
gen_perf()
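A possible way to run this new helper script (an assumption: the launch requirements of gen_perf live in internlm.simulator.profiler.perf_comm, which is not part of this diff, so the exact command may differ):

python gen_profiler_data.py  # or under torchrun/srun if gen_perf benchmarks collectives across multiple ranks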
261 changes: 227 additions & 34 deletions internlm/core/context/parallel_context.py

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions internlm/core/context/process_group_initializer.py
@@ -62,6 +62,10 @@ class ParallelMode(Enum):

# grouped query attention
GQA = "gqa"

INTRA_DP_SZIE = "intra_dp"

INTER_DP_SZIE = "inter_dp"


class ProcessGroupInitializer(ABC):
227 changes: 227 additions & 0 deletions internlm/core/context/process_group_initializer_simplified.py
@@ -0,0 +1,227 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from copy import deepcopy

import torch
import torch.distributed as dist

from internlm.utils.timeout import LLM_NCCL_TIMEOUT

class ParallelMeta:
def __init__(self, parallel_size, mode) -> None:
self.parallel_size = parallel_size
self.mode = mode

def __str__(self) -> str:
return self.__repr__()

def __repr__(self) -> str:
return f"{self.mode}, {self.parallel_size}"


def determine_intra_inter_size_of_group(one_group_indexs, intra_range=8):
    """Determine the intra-node size and inter-node size of a rank group."""
    group_size = len(one_group_indexs)
    if group_size == 1:
        return 1, 1
    else:
        group_stride = one_group_indexs[1] - one_group_indexs[0]
        if group_stride >= intra_range:
            return 1, group_size
        else:
            intra_size = intra_range // group_stride
            inter_size = group_size // intra_size
            return max(1, intra_size), max(1, inter_size)


class Initializer:
def __init__(
self,
rank: int,
world_size: int,
fake_mode: bool = False,
tensor_mode: str = "fsp",
parallel_info: dict = None,
):
"""Initialize communication groups

Args:
rank (int): global rank
world_size (int): world size
            fake_mode (bool, optional): Whether to create actual NCCL communication
                groups. Defaults to False.
tensor_mode (str, optional): ISP/FSP/MSP. Defaults to "fsp".
parallel_info (dict, optional): parallel_info. Defaults to None.
"""
self.rank = rank
self.world_size = world_size
self.fake_mode = fake_mode
self.tensor_mode = tensor_mode
self.parallel_info = parallel_info

# assert sequence_parallel_size == tensor_parallel_size
super().__init__()

def init_dist_group(self, use_cpu: bool = False):
parallel_info, world_size = self.parallel_info, self.world_size

wp_size = parallel_info["wp"].parallel_size
# tp_size = parallel_info["tp"].parallel_size
# pp_size = parallel_info["pp"].parallel_size
wdp_size = parallel_info["wdp"].parallel_size
zero1_size = parallel_info["zero1"].parallel_size
ep_size = parallel_info["ep"].parallel_size
edp_size = parallel_info["edp"].parallel_size

re_group_args = {}

# stride_order means the placement priority of PG groups.
stride_order = ["tp", "dp", "pp"]
strides = {}

def assemble_group(all_ranks, dim_name):
for ranks in all_ranks:
if self.fake_mode or len(all_ranks) == 1:
group, group_cpu = None, None
else:
group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
if use_cpu:
group_cpu = (
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
if dist.get_backend() != "gloo"
else group
)
else:
group_cpu = None

if self.rank in ranks:
local_rank = ranks.tolist().index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks.tolist()

new_all_ranks = []
for ranks in all_ranks:
new_all_ranks.append(ranks.tolist())

return (
local_rank,
group_world_size,
process_group,
cpu_group,
ranks_in_group,
new_all_ranks,
parallel_info[dim_name].mode,
)

def split_orthogonal_sub_group(dim_name, indexs, size, stride):
            assert size <= world_size, f"{dim_name} size: {size} should be less than world size: {world_size}!"

indexs = indexs.reshape(-1, stride).T.reshape(-1)
all_ranks = torch.split(indexs, size)

return indexs, assemble_group(all_ranks, dim_name)

def split_horizontal_sub_group(dim_name, indexs, size, stride):
            assert size <= world_size, f"{dim_name} size: {size} should be less than world size: {world_size}!"

indexs = indexs.reshape(stride, -1).reshape(-1)
all_ranks = torch.split(indexs, size)

return indexs, assemble_group(all_ranks, dim_name)

count = 0
for dim_name in stride_order:
parallel_size = parallel_info[dim_name].parallel_size
if parallel_size == 1:
continue

if count == 0:
strides[dim_name] = 1
else:
strides[dim_name] = strides[old_dim_name] * parallel_info[old_dim_name].parallel_size

father_indexs, group_args = split_orthogonal_sub_group(
dim_name, torch.arange(start=0, end=world_size), size=parallel_size, stride=strides[dim_name]
)
re_group_args[dim_name] = group_args

if dim_name == "dp":
"""
"EP, EDP, and ZeRO are auxiliary parallel modes within DP."
"""
if wp_size == 1 and self.tensor_mode != "isp":
re_group_args["zero1"] = split_horizontal_sub_group("zero1", father_indexs, zero1_size, zero1_size)[
1
]
print(f"re_group_args['zero1']: {re_group_args['zero1']}")

# MoE expert group is subgroup of data parallel group
if ep_size > 1:
ep_indexs, group_ep_args = split_horizontal_sub_group(
"ep", father_indexs, size=ep_size, stride=ep_size
)
re_group_args["ep"] = group_ep_args
re_group_args["edp"] = split_orthogonal_sub_group("edp", ep_indexs, edp_size, ep_size)[1]

                one_group_indexs = group_args[4]  # ranks_in_group of this rank's DP group
intra_dp_size, inter_dp_size = determine_intra_inter_size_of_group(one_group_indexs)

# It will be used in drawing heatmap.
parallel_info["intra_dp"].parallel_size = intra_dp_size
parallel_info["inter_dp"].parallel_size = inter_dp_size

# The only parallel group with a higher priority than DP is TP.
# see: stride_order = ["tp", "dp", "pp"]
high_priority_group = parallel_info["tp"].parallel_size

re_group_args["intra_dp"] = split_horizontal_sub_group(
"intra_dp", father_indexs, size=intra_dp_size, stride=high_priority_group
)[1]

re_group_args["inter_dp"] = split_orthogonal_sub_group(
"inter_dp", father_indexs, size=inter_dp_size, stride=intra_dp_size
)[1]

elif dim_name == "tp":
"""
The situation with isp is somewhat complex. When using isp, the head/embedding is partitioned
according to the Megatron-TP method and uses the TP communication group, while other modules
are partitioned according to the WP communication group and reuse the TP communication group
(but perform DeepSpeed-Ulysses instead of Megatron-TP). Therefore,
for head/embedding, their Zero1 communication group is orthogonal to the TP group,
for other modules, their Zero1 communication group is the Wdp communication group
(orthogonal to the WP/TP communication groups).
FIXME: Can this be further simplified?
"""
if self.tensor_mode == "isp":
if wp_size == 1:
re_group_args["zero1"] = split_horizontal_sub_group(
"zero1", father_indexs, zero1_size, zero1_size
)[1]
else:
wp_index, re_group_args["wp"] = split_horizontal_sub_group(
"wp", torch.arange(start=0, end=world_size), wp_size, wp_size
)
re_group_args["wdp"] = split_orthogonal_sub_group("wdp", wp_index, wdp_size, wp_size)[1]
re_group_args["zero1"] = split_orthogonal_sub_group(
"zero1", father_indexs, zero1_size, wp_size
)[1]

count += 1
old_dim_name = dim_name

for name, info in parallel_info.items():
if info.parallel_size == 1:
# If the degree of parallelism is 1, for logical consistency,
# we still need to create a logical communication group
re_group_args[name] = assemble_group([torch.tensor([self.rank])], name)

# If two groups are orthogonal to each other and one group has a parallelism degree of 1,
# then the parallelism degree of the other group is world_size.
if parallel_info["wp"].parallel_size == 1:
re_group_args["wdp"] = tuple(list(deepcopy(re_group_args["dp"]))[0:-1] + [parallel_info["wdp"].mode])

return re_group_args
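A quick illustration of the helpers above may help. This is a minimal sketch (the import path is the module added by this PR, and the printed values follow directly from the code shown above): determine_intra_inter_size_of_group assumes intra_range=8 ranks per node, and the core of split_orthogonal_sub_group is a stride-interleaving reshape followed by chunking.

import torch

from internlm.core.context.process_group_initializer_simplified import (
    determine_intra_inter_size_of_group,
)

# Ranks 0..7 on one node: the whole DP group is intra-node.
print(determine_intra_inter_size_of_group([0, 1, 2, 3, 4, 5, 6, 7]))     # (8, 1)
# Stride-2 group of 8 ranks: 4 ranks share a node, spread over 2 nodes.
print(determine_intra_inter_size_of_group([0, 2, 4, 6, 8, 10, 12, 14]))  # (4, 2)
# Stride >= 8: every rank of the group sits on a different node.
print(determine_intra_inter_size_of_group([0, 8]))                       # (1, 2)

# Core of split_orthogonal_sub_group: interleave by stride, then chunk into groups.
indexs = torch.arange(8).reshape(-1, 2).T.reshape(-1)  # stride=2 -> tensor([0, 2, 4, 6, 1, 3, 5, 7])
print(torch.split(indexs, 4))  # (tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))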
15 changes: 15 additions & 0 deletions internlm/core/context/random.py
@@ -2,6 +2,7 @@
# -*- encoding: utf-8 -*-
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context

import os
from contextlib import contextmanager

from torch import Tensor
@@ -10,6 +11,8 @@

from .process_group_initializer import ParallelMode

fake_mode = "fake_mode" in os.environ

internlm_accelerator = get_accelerator()


@@ -35,11 +38,15 @@ def seed_states(self):

def set_state(self, parallel_mode: ParallelMode, state: Tensor):
"""Sets the state of the seed manager for `parallel_mode`."""
if fake_mode:
return
assert parallel_mode in self._seed_states, f"{parallel_mode} not found in seed manager"
self._seed_states[parallel_mode] = state

def set_mode(self, parallel_mode: ParallelMode, update_rng_current_mode: bool = True):
"""Sets the current mode of the seed manager."""
if fake_mode:
return
if update_rng_current_mode and self.current_mode:
# save state for current mode
self._seed_states[self._current_mode] = internlm_accelerator.get_rng_state()
@@ -50,6 +57,8 @@ def set_mode(self, parallel_mode: ParallelMode, update_rng_current_mode: bool =

def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
"""Adds a seed to the seed manager for `parallel_mode`."""
if fake_mode:
return
assert isinstance(parallel_mode, ParallelMode), "Invalid ParallelMode"
if not overwrite:
assert parallel_mode not in self._seed_states, f"Seed for {parallel_mode} exists"
@@ -63,6 +72,8 @@ def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = Fal
internlm_accelerator.set_rng_state(current_state)

def reset(self):
if fake_mode:
return
self._current_mode = None
self._seeds = {}
self._seed_states = {}
@@ -131,3 +142,7 @@ def seed(parallel_mode: ParallelMode):
yield _SEED_MANAGER.set_mode(parallel_mode)
finally:
_SEED_MANAGER.set_mode(current_mode)


def reset_seed():
_SEED_MANAGER.reset()
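A minimal sketch of how the new fake-mode switch is meant to be driven (an assumption about the caller; only behavior visible in this diff is relied on). The flag is read once at import time via "fake_mode" in os.environ, so it must be exported before internlm.core.context.random is imported.

import os

os.environ["fake_mode"] = "1"  # only membership is checked; any value enables it

from internlm.core.context.random import reset_seed

# Under fake mode, SeedManager.add_seed / set_mode / set_state return early, so no
# accelerator RNG state is created or swapped; reset_seed() simply clears the
# (empty) bookkeeping between simulated configurations.
reset_seed()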
12 changes: 6 additions & 6 deletions internlm/core/parallel/comm/isp.py
@@ -523,7 +523,7 @@ def register_prerequisite_for_forward_prefetch_hooks(self, prerequisite_func: Ca
def weight_hook(
self, tensor: torch.Tensor, async_op: bool = False, module: nn.Module = None, is_bias: bool = False
) -> torch.Tensor:
if dist.get_world_size(self.process_group) <= 1:
if gpc.get_group_size(self.process_group) <= 1:
return tensor

if not self.overlap:
@@ -545,7 +545,7 @@ def grad_hook(
reduce_op: dist.ReduceOp = dist.ReduceOp.AVG,
is_bias: bool = False,
) -> Tuple[torch.Tensor, AsyncCommHandle]:
if dist.get_world_size(self.process_group) <= 1:
if gpc.get_group_size(self.process_group) <= 1:
return tensor, DUMMY_HANDLE_CONST

if not self.overlap:
@@ -573,7 +573,7 @@ def grad_hook(
result, handle = (
self._get_constant_zero(
(
tensor.shape[0] // dist.get_world_size(self.process_group),
tensor.shape[0] // gpc.get_group_size(self.process_group),
*tensor.shape[1:],
)
),
@@ -634,10 +634,10 @@ def forward(ctx, group: dist.ProcessGroup, input_: torch.Tensor, scatter_idx: in
ctx.scatter_idx = scatter_idx
ctx.gather_idx = gather_idx

if dist.get_world_size(group) <= 1:
if gpc.get_group_size(group) <= 1:
return input_

seq_world_size = dist.get_world_size(group)
seq_world_size = gpc.get_group_size(group)

input_list = [t.contiguous() for t in torch.tensor_split(input_, seq_world_size, scatter_idx)]
output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
@@ -647,7 +647,7 @@ def forward(ctx, group: dist.ProcessGroup, input_: torch.Tensor, scatter_idx: in

@staticmethod
def backward(ctx, *grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, None, None]:
if dist.get_world_size(ctx.group) <= 1:
if gpc.get_group_size(ctx.group) <= 1:
return (None, *grad_output, None, None)

return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None)
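Note on this file's change: each dist.get_world_size(group) call is replaced with gpc.get_group_size(group). This is presumably needed because, under the simulator's fake mode, assemble_group() in process_group_initializer_simplified.py (shown above) returns None instead of a real NCCL group, so torch.distributed cannot report a size; gpc.get_group_size is assumed to fall back to the group sizes recorded during initialization. gpc itself lives in internlm/core/context/parallel_context.py, whose large diff is not rendered here.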