diff --git a/gen_profiler_data.py b/gen_profiler_data.py
index b22ed5b5..37e3b78a 100644
--- a/gen_profiler_data.py
+++ b/gen_profiler_data.py
@@ -1,5 +1,4 @@
-
 from internlm.simulator.profiler.perf_comm import gen_perf
 
 if __name__ == "__main__":
-    gen_perf()
\ No newline at end of file
+    gen_perf()
diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py
index 8fa4b2fe..0314b802 100644
--- a/internlm/core/context/parallel_context.py
+++ b/internlm/core/context/parallel_context.py
@@ -18,13 +18,12 @@ import torch.distributed as dist
 
 from internlm.accelerator import get_accelerator
-from internlm.core.context.process_group_initializer_simplified import Initializer, ParallelMeta
 from internlm.utils.common import SingletonMeta
 from internlm.utils.logger import get_logger
 from internlm.utils.timeout import LLM_NCCL_TIMEOUT
 
 from . import process_group_initializer as pgroup_initializer
-from .process_group_initializer_simplified import ParallelMode
+from .process_group_initializer import ParallelMode
 from .random import add_seed, get_seeds, set_mode
 from internlm.utils.common import get_args
 
@@ -422,20 +421,6 @@ def init_global_dist(
             use_cpu (bool): whether to set up cpu process group.
         """
-        # find cluster info
-        if "clusters" not in self.config:
-            nv_info = {
-                "rank_range": [0, 8],
-                "peak_tflops": 320,
-                "capacity": 80 * 1024**3,
-                "intra_bw": 150,
-                "inter_bw": 100,
-            }
-            self.set_cluster_info("nv_cluster", nv_info)
-        else:
-            for cluster in self.config.clusters:
-                self.clusters.append(ClusterInfo(**cluster))
-
         # initialize the default process group
         if not fake_mode:
             init_method = f"tcp://[{host}]:{port}"
@@ -667,7 +652,7 @@ def _init_pg(self, rank, world_size, parallel_config):
             initializers.append(pgroup_initializer.Initializer_ISP_Data(*initializer_args))
         if (
             isinstance(parallel_config["tensor"], dict)
-            and parallel_config["tensor"]["mode"] == TensorParallelMode.isp.name
+            and parallel_config["tensor"]["mode"] == "isp"
         ):
             initializers.append(pgroup_initializer.Initializer_Zero1_ISP(*initializer_args))
         else:
@@ -688,6 +673,8 @@ def _init_pg(self, rank, world_size, parallel_config):
             self._register_dist(*parallel_setting)
 
     def _init_use_simplified_pg(self, rank, world_size, parallel_config):
+        from internlm.core.context.process_group_initializer_simplified import Initializer, ParallelMeta
+
         try:
             self.tensor_mode = parallel_config["tensor"]["mode"]
         except AttributeError:
@@ -861,14 +848,14 @@ def check_pg_is_intra(self, parallel_mode: ParallelMode):
         return (max_rank - min_rank) <= 7
 
     def same_group_in_one_node(self, parallel_mode: ParallelMode):
-        """获得一个节点内有多少个相同类型的PG, 在跨节点通信时会存在带宽竞争
-        这里返回的相同PG的数量会乘上每个rank的通信数据量大小
+        """Get the number of process groups (PGs) of the same type within one node; such groups compete
+        for bandwidth during cross-node communication. This count is multiplied by each rank's communication data size.
 
         Args:
             parallel_mode (ParallelMode):
 
         Returns:
-            int: 一个节点内相同类型的PG的数量
+            int: the number of process groups of the same type within one node.
""" pg_group_ranks = self.get_ranks_in_group(parallel_mode) pg_group_ranks = sorted(pg_group_ranks) @@ -881,68 +868,7 @@ def same_group_in_one_node(self, parallel_mode: ParallelMode): else: return stride - # def set_cluster_info(self, name: str, info: dict): - # self.clusters[name] = ClusterInfo(**info) - - def get_cluster_info(self, name: str): - return self.clusters[name] - - def get_cluster_name_from_ip(self): - """ - node_ip_list = [ - 'metax-c500-1', - 'metax-c500-2', - 'nvidia-node-1', - 'nvidia-node-2', - ] - """ - hostname = socket.gethostname() - cluster_name = hostname.split("-")[0] - return cluster_name - - def sort_rank_based_on_ip_and_capacity(self): - Capacity = [] - - def sort_rank(x, y): - x_name = self.get_cluster_name_from_ip(x) - y_name = self.get_cluster_name_from_ip(y) - if x_name == y_name: - return x_name > y_name - else: - x_c = self.clusters[x_name]["capacity"] - y_c = self.clusters[y_name]["capacity"] - return x_c > y_c - for cluster_name, cluster_info in self.clusters.items(): - peak_tflops.append(cluster_info["peak_tflops"]) - # Alpha.append(cluster_info.rank_range[-1] - cluster_info.rank_range[-1] + 1) - Capacity.append(cluster_info["capacity"]) - - def switch_topology_aware_rank_scheduling(): - """ - Switch topology-aware rank scheduling can optimize the performance of small-scale - collective communications. Currently only supported in Alibaba Cloud. - """ - - local_rank = int(os.environ["LOCAL_RANK"]) - cluster_name = get_cluster_name_from_ip() - - try: - if cluster_name == "Ali": - pass - else: - rank = int(os.environ["MLP_WORKER_RACK_RANK_INDEX"]) * 8 + local_rank - except Exception as e: - logger.error( - f"The switch topology awareness error is reported, the reason is: {e}", - "but don’t worry, this error will not affect normal training.", - "If you train on Alibaba or Volcano Cloud, please contact wangguoteng or lijiaxing", - ) - else: - # If there is no any error, hack torch rank. 
-            os.environ["RANK"] = str(rank)
-            if local_rank == 0:
-                logger.info("Successfully bound node switch affinity!")
 
 
 global_context = ParallelContext()
diff --git a/internlm/core/context/process_group_initializer.py b/internlm/core/context/process_group_initializer.py
index 5519c7d8..1d27bf93 100644
--- a/internlm/core/context/process_group_initializer.py
+++ b/internlm/core/context/process_group_initializer.py
@@ -62,6 +62,10 @@ class ParallelMode(Enum):
 
     # grouped query attention
     GQA = "gqa"
+
+    INTRA_DP_SZIE = "intra_dp"
+
+    INTER_DP_SZIE = "inter_dp"
 
 
 class ProcessGroupInitializer(ABC):
diff --git a/internlm/core/context/process_group_initializer_simplified.py b/internlm/core/context/process_group_initializer_simplified.py
index c1423a5a..257b4f66 100644
--- a/internlm/core/context/process_group_initializer_simplified.py
+++ b/internlm/core/context/process_group_initializer_simplified.py
@@ -2,13 +2,11 @@
 # -*- encoding: utf-8 -*-
 
 from copy import deepcopy
-from enum import Enum
 
 import torch
 import torch.distributed as dist
 
 from internlm.utils.timeout import LLM_NCCL_TIMEOUT
-from internlm.core.context.process_group_initializer import ParallelMode
 
 class ParallelMeta:
     def __init__(self, parallel_size, mode) -> None:
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 9a719cf8..e9508403 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -12,7 +12,7 @@ from internlm.accelerator import AcceleratorType, get_accelerator
 from internlm.core.context import Config
 from internlm.core.context import global_context as gpc
-from internlm.core.context.process_group_initializer_simplified import ParallelMode
+from internlm.core.context.process_group_initializer import ParallelMode
 from internlm.utils.common import get_master_node
 from internlm.utils.gputest import warmup_process_group
 from internlm.utils.logger import get_logger
@@ -86,7 +86,8 @@ def add_simulator_arguments(parser):
     group.add_argument(
         "--pre_profiling_data_path", type=str, help="The path to pre-profiled performance data on the target cluster."
     )
-    group.add_argument("--use_simplified_gp_init", action="store_true", default=False)
+    group.add_argument("--use_simplified_gp_init", action="store_true", default=True)
+
     return parser
diff --git a/internlm/model/ops/linear.py b/internlm/model/ops/linear.py
index da3eda5b..6ea9ac62 100644
--- a/internlm/model/ops/linear.py
+++ b/internlm/model/ops/linear.py
@@ -14,10 +14,6 @@
 from internlm.accelerator import AcceleratorType, get_accelerator
 from internlm.core.context import global_context as gpc
-from internlm.simulator.ops.linear import (
-    _fake_linear_bwdward_op,
-    _fake_linear_forward_op,
-)
 
 try:
     from fused_dense_lib import linear_bias_wgrad as _flash_linear_backward_op
diff --git a/internlm/simulator/README.md b/internlm/simulator/README.md
new file mode 100644
index 00000000..6e089a94
--- /dev/null
+++ b/internlm/simulator/README.md
@@ -0,0 +1,54 @@
+# InternLM Simulator
+
+
+## 1. Introduction
+The simulator mainly consists of two components:
+1. `profiling`: Collects the time consumed by each stage of the model training process in advance and saves it as data files and image files.
+2. `simulation`: Simulates the model training process based on the collected data files and outputs the time consumed by each stage of training.
+
+## 2. Usage
+
+### 2.1 Generate profiling data
+
+There are two types of profiling data:
+1. `linear` profiling data, which includes: [`LINEAR`]
+2. `Communication` profiling data, which includes: [`ALL2ALL`, `ALLREDUCE`, `REDUCESCATTER`, `ALLGATHER`, `BROADCAST`]
+
+
+Note:
+1. It is recommended to use more than 64 GPUs for data collection to ensure more accurate communication data.
+2. `Flash Attention` information is not collected in advance; it is collected on the fly during the simulation and stored in a cache, because many variables affect flash attention performance and pre-collection cannot cover them all.
+
+```bash
+# generate profiling data
+torchrun --nproc-per-node=8 gen_profiler_data.py
+
+# the profiling data will be saved in the following path
+./prof_data
+├── data.pt
+└── pics
+    ├── cal
+    │   └── linear.jpg
+    └── comm
+        ├── all2all_intra_2_inter_1.jpg
+        ├── all2all_intra_4_inter_1.jpg
+        ├── all_gather_intra_2_inter_1.jpg
+        ├── all_gather_intra_4_inter_1.jpg
+        ├── all_reduce_intra_2_inter_1.jpg
+        ├── all_reduce_intra_4_inter_1.jpg
+        ├── broadcast_intra_2_inter_1.jpg
+        ├── broadcast_intra_4_inter_1.jpg
+        ├── reduce_scatter_intra_2_inter_1.jpg
+        └── reduce_scatter_intra_4_inter_1.jpg
+
+```
+
+### 2.2 Run simulation
+```bash
+python simulation_train_formulaic.py --pre_profiling_data_path ./data/profiling_data.json --config configs/exp_simluator.py
+
+```
+
+
+
+## 3. Contribution
diff --git a/internlm/simulator/profiler/benchmark/multi_head_attn.py b/internlm/simulator/profiler/benchmark/multi_head_attn.py
index f4cf5d73..bec18e1c 100644
--- a/internlm/simulator/profiler/benchmark/multi_head_attn.py
+++ b/internlm/simulator/profiler/benchmark/multi_head_attn.py
@@ -60,150 +60,3 @@ def run():
         t_bwds += t_bwd
 
     return t_fwds / trials, t_bwds / trials
-
-
-# from .base_benchmark import UnitBench
-# import math
-
-# import torch
-# from einops import rearrange
-# from torch import nn
-
-# from internlm.model.registry import benchmark_initializer
-# from internlm.simulator.common import TP_SIZE_RANGE, K, get_local_rank
-# from internlm.utils.common import get_current_device
-
-# try:
-#     from flash_attn.flash_attn_interface import (
-#         flash_attn_qkvpacked_func,
-#         flash_attn_varlen_func,
-#     )
-#     from flash_attn.modules.mha import FlashSelfAttention, SelfAttention
-# except ModuleNotFoundError:
-#     print("import fa failed!", flush=True)
-#     try:
-#         from deeplink_ext.internevo_ops import FlashCrossAttention, FlashSelfAttention
-#     except ModuleNotFoundError:
-#         flash_attn_qkvpacked_func = None
-#         FlashSelfAttention = None
-#         SelfAttention = None
-#         print("import dipu fa failed!", flush=True)
-
-
-# @benchmark_initializer.register_module(module_name=BENCH_TYPE)
-
-# 对于FA,我们还是用on the fly的方式 profiling,并用cache缓存中间结果
-# class UnitMultiHeadAttn(UnitBench):
-#     # test_loop = {
-#     #     "seq_len": [
-#     #         64 * K,
-#     #         int(0.25 * K),
-#     #         int(0.5 * K),
-#     #         1 * K,
-#     #         2 * K,
-#     #         4 * K,
-#     #         8 * K,
-#     #         32 * K,
-#     #         16 * K,
-#     #     ],  # 256 * K, 128 * K,
-#     #     "head_H": [(64, 8192), (48, 6144), (32, 4096), (40, 5120)],  # (80, 10240),
-#     #     "dtype": [torch.bfloat16],
-#     #     "micro_bsz": [2, 1],  # 4,
-#     #     "tp_size": TP_SIZE_RANGE,
-#     #     "is_fwd": [True, False],
-#     # }
-
-#     def __init__(self, seq_len, num_heads_and_hidden_dim, dtype, micro_bsz, tp_size, is_fwd) -> None:
-#         q_head, kv_head, embed_dim = num_heads_and_hidden_dim
-#         self.num_heads_and_hidden_dim = num_heads_and_hidden_dim
-#         self.TP = tp_size
-#         self.S = seq_len
-#         self.N = num_heads
-#         self.H = embed_dim // self.N
-#         self.dtype = dtype
-#         self.dtype_size = 2 if self.dtype == torch.bfloat16 else 4
-#         self.B = micro_bsz
-#         self.oom = False
-#         self.is_fwd = is_fwd
-# self.causal = True - -# assert num_heads % self.TP == 0, "num_heads must be divisible by tp_size" -# assert num_heads >= tp_size, f"head nums must bigger then tp_size: {tp_size}" - -# self.num_atten_head_tp = num_heads // self.TP -# self.head_dim = self.H // num_heads -# self.tp_embedding_dim = self.H // self.TP - -# self.packed_length = self.S * self.B -# self.device = f"cuda:{get_local_rank()}" -# cu_seqlens = [i * self.S for i in range(self.B + 1)] - -# weights_mem_used = self.packed_length * 3 * self.H * self.dtype_size -# attn_activation = 11 * self.packed_length * self.H -# mem_used = attn_activation + weights_mem_used - -# self.inner_attn = FlashSelfAttention(causal=True, softmax_scale=self.H ** (0.5), attention_dropout=0.0) - -# oom = False -# if mem_used > 75 * 1024**3: -# oom = True - -# # 约束1: seqlen最大不能超过256K(不含) -# # 约束2: embed_dim在被tp切过之后若大于6144, 则packed_length不能大于256k -# if self.packed_length >= 256 * K and (self.H / self.TP) >= 6144: -# oom = True -# if self.S >= 256 * K and self.B > 1: -# oom = True -# if self.packed_length >= 524288 and (self.H / self.TP) >= 3072: -# oom = True -# if self.packed_length >= 1048576 and (self.H / self.TP) >= 2048: -# oom = True - -# if oom: -# assert ( -# False -# ), f"warning : mem_used: {mem_used/1024**3:.2f} GB, seq_len: {self.S}, embed_dim: {self.H}, tp_size: {self.TP}" - -# self.qkv = torch.rand( -# size=(self.B * self.S, 3, self.N // self.TP, self.H), -# dtype=self.dtype, -# device=self.device, -# requires_grad=True, -# ) - -# self.dtype_size = self.qkv.element_size() -# self.cu_seqlens = torch.tensor(data=cu_seqlens, dtype=torch.int32, device=self.device) -# self.max_seqlen = self.S -# if not self.is_fwd: -# self.output = self.run_fwd() -# self.grad = torch.randn_like(self.output) / 32 # avoid grad is too large. 
- -# def run(self): -# if self.is_fwd: -# self.run_fwd() -# else: -# self.run_bwd(self.output, self.grad) - -# def run_fwd(self): -# context = self.inner_attn(self.qkv, cu_seqlens=self.cu_seqlens, max_seqlen=self.max_seqlen, causal=self.causal) -# return context - -# def run_bwd(self, output, grad): -# output.backward(grad, retain_graph=True) - -# @staticmethod -# def gen_store_key(micro_bsz, seq_len, num_heads_and_hidden_dim, tp_size, is_fwd): -# _, embed_dim = num_heads_and_hidden_dim -# tp_embedding_dim = embed_dim // tp_size -# return f"b_{micro_bsz}_s_{seq_len}_h_{tp_embedding_dim}_fwd_{is_fwd}" - -# def complexity(self): -# return UnitMultiHeadAttn.gen_store_key(self.B, self.S, self.num_heads_and_hidden_dim, self.TP, self.is_fwd) -# # return f"{self.S} * {self.hidden_dim} * {self.hidden_dim}" - - -if __name__ == "__main__": - - micro_bsz, seqlen, hidden_size, q_head, kv_head, dtype = 1, 4096, 4096, 32, 8, torch.bfloat16 - t_fwd, t_bwd = run_fwd(micro_bsz, seqlen, hidden_size, q_head, kv_head, dtype) - print(f"t_fwd: {t_fwd}, t_bwd: {t_bwd}", flush=True) diff --git a/internlm/simulator/profiler/perf_comm.py b/internlm/simulator/profiler/perf_comm.py index 58402bfb..78260990 100644 --- a/internlm/simulator/profiler/perf_comm.py +++ b/internlm/simulator/profiler/perf_comm.py @@ -125,7 +125,8 @@ def gen_perf(): ) group = dist.GroupMember.WORLD - gpc._register_dist(rank, world_size, group, None, list(range(world_size)), ParallelMode.GLOBAL) + # local_rank, world_size, process_group, cpu_group, ranks_in_group, all_ranks, mode + gpc._register_dist(rank, world_size, group, None, list(range(world_size)), list(range(world_size)), ParallelMode.GLOBAL) gpc._global_ranks[ParallelMode.GLOBAL] = rank gpc.set_device(local_rank) @@ -162,29 +163,30 @@ def gen_perf(): sync_all() - for i in range(inter_comm_nums): - for j in range(intra_comm_nums): - inter_size, intra_size = 2**i, 2**j - if inter_size * intra_size != 1: - - x_idx, y_idx = get_group_id(rank, gpus_per_node, intra_size, inter_size) - groups = new_process_group(world_size, gpus_per_node, intra_size, inter_size) - - for test_type in comm_test_list: - key = gen_comm_key(test_op, intra_size, inter_size) - if dist.get_rank() == 0: - print( - f"key: {key}, inter_size: {inter_size}, intra_size: {intra_size}, ranks: {groups[y_idx][x_idx][1]}", - flush=True, - ) - pg = groups[y_idx][x_idx][0] - assert ( - pg != -100 - ), f"key: {key}, x_idx: {x_idx}, y_idx: {y_idx}, rank: {gpc.get_global_rank()}, ranks: {groups[y_idx][x_idx][1]}" - comm_vols, bws = run_comm_profile(test_type, pg, key) - sync_all() - if dist.get_rank() == 0: - spline_model_dict[key] = draw_pics(comm_pic_path, key, comm_vols, bws) + for test_op in comm_test_list: + for i in range(inter_comm_nums): + for j in range(intra_comm_nums): + inter_size, intra_size = 2**i, 2**j + if inter_size * intra_size != 1: + + x_idx, y_idx = get_group_id(rank, gpus_per_node, intra_size, inter_size) + groups = new_process_group(world_size, gpus_per_node, intra_size, inter_size) + + for test_type in comm_test_list: + key = gen_comm_key(test_op, intra_size, inter_size) + if dist.get_rank() == 0: + print( + f"key: {key}, inter_size: {inter_size}, intra_size: {intra_size}, ranks: {groups[y_idx][x_idx][1]}", + flush=True, + ) + pg = groups[y_idx][x_idx][0] + assert ( + pg != -100 + ), f"key: {key}, x_idx: {x_idx}, y_idx: {y_idx}, rank: {gpc.get_global_rank()}, ranks: {groups[y_idx][x_idx][1]}" + comm_vols, bws = run_comm_profile(test_type, pg, key) + sync_all() + if dist.get_rank() == 0: + 
spline_model_dict[key] = draw_pics(comm_pic_path, key, comm_vols, bws) print(f"rank: {gpc.get_global_rank()}, all done!", flush=True) diff --git a/internlm/simulator/profiler/profiler.py b/internlm/simulator/profiler/profiler.py index 9d28bbfd..adea31fd 100644 --- a/internlm/simulator/profiler/profiler.py +++ b/internlm/simulator/profiler/profiler.py @@ -234,24 +234,7 @@ def draw_pics(base_path, plot_name, comm_vols, bws): def draw_cal_pics(base_path, plot_name, tflop, tflops): - # x, y = [], [] - spline_model = interp1d(tflop, tflops, kind="slinear") - - # start = tflop[0] - # end = tflop[-1] - # for complexity in range(start, end+1): - # try: - # predice_tflops = spline_model(complexity) - # except ValueError: - # if complexity < tflop[0]: - # predice_tflops = spline_model(tflop[0]) - # elif complexity > tflop[-1]: - # predice_tflops = spline_model(tflop[-1]) - - # x.append(complexity) - # y.append(predice_tflops) - pic_path = os.path.join(base_path, plot_name + ".jpg") tflop = list(map(lambda x: x / 10**12, tflop)) tflops = list(map(lambda x: x / 10**12, tflops)) diff --git a/simulation_train_formulaic.py b/simulation_train_formulaic.py index 858bd05d..8b3654ab 100644 --- a/simulation_train_formulaic.py +++ b/simulation_train_formulaic.py @@ -12,6 +12,7 @@ from internlm.core.context import global_context as gpc from internlm.core.context.random import reset_seed from internlm.core.parallel.shard import cluster_load_balance, partition_uniform +from internlm.initialize import get_default_parser from internlm.initialize.launch import args_sanity_check, launch from internlm.simulator.common import AlgoType, cal_block_p_elem, cal_model_p_elem @@ -82,6 +83,8 @@ def comm_dp_cost(dtype_size, algo, pp_blocks_elem, embedding_elem, zp) -> float: block_zp_latency = zp * broadcast(dtype_size * pp_blocks_elem / zp, ParallelMode.ZERO1, comm_nums=zp) embedding_zp_latency = broadcast(dtype_size * embedding_elem, ParallelMode.DATA) zp_latency = max(block_zp_latency, embedding_zp_latency) + else: + raise ValueError(f"Invalid algo type: {algo}") return zp_latency, wdp_latency @@ -384,16 +387,20 @@ def overlaped_fwd_bwd_cost(): def run_loop( global_bsz, world_size, - args, + config_path, use_fixed_micro_bsz=False, use_strict_bsz=True, global_bsz_max=1, global_bsz_min=1, debug=True, ): - gpc.load_config(config=Config.from_file(args.config)) + gpc.load_config(config=Config.from_file(config_path)) gpc.set_fake_mode(True) + if "multiple_of" not in gpc.config.model: + gpc.config.model["multiple_of"] = 256 + print(f"multiple_of not in config, use default value: {gpc.config.model.multiple_of}") + min_comm_cost, msp_min_cost, fsp_min_cost, isp_min_cost = ( float("inf"), float("inf"), @@ -566,9 +573,9 @@ def run_loop( return solutions_list, min_comm_cost, min_cost_solution, msp_min_solu, fsp_min_solu, isp_min_solu -def run_warrper(global_bsz, world_size, args): +def run_warrper(world_size, global_bsz, config_path): solutions_list, min_comm_cost, min_cost_solution, msp_min_solu, fsp_min_solu, isp_min_solu = run_loop( - global_bsz=global_bsz, world_size=world_size, args=args + global_bsz=global_bsz, world_size=world_size, config_path=config_path ) if min_cost_solution is not None: @@ -606,7 +613,17 @@ def run_warrper(global_bsz, world_size, args): print("No solution found") -def run_single(global_bsz=4 * 1024 * 1024): +def get_world_size(): + if "WORLD_SIZE" in os.environ: + return int(os.environ["WORLD_SIZE"]) + else: + if "SLURM_NTASKS" in os.environ: + return int(os.environ["SLURM_NTASKS"]) + else: + return 1 + + 
+def run_single(world_size, global_bsz=4 * 1024 * 1024): gpc.load_config(config=Config.from_file(args.config)) gpc.set_fake_mode(True) print(f"gpc.config.parallel: {gpc.config.parallel}") @@ -675,17 +692,19 @@ def run_single(global_bsz=4 * 1024 * 1024): if __name__ == "__main__": - args = parse_args() + parser = get_default_parser() + args = parser.parse_args() + hostname = socket.gethostname() - world_size = args.world_size + global_batch_size = args.global_batch_size - init_cost_model(get_args().pre_profiling_data_path) + init_cost_model(args.pre_profiling_data_path) os.environ["fake_mode"] = "1" gloab_allocator.init_capcity = 80 * 1024**3 gloab_allocator.capcity = 80 * 1024**3 - if get_args().run_all_solu: - run_warrper(4096 * 1024, world_size, args) + if args.run_all_solu: + run_warrper(args.world_size, args.global_batch_size, args.config) else: - run_single(get_args().global_batch_size) + run_single(args.world_size, args.global_batch_size)