From c22a1d185f16874e9eb7c0fd74323c495f0dda4a Mon Sep 17 00:00:00 2001 From: sallyjunjun Date: Fri, 6 Dec 2024 15:42:37 +0800 Subject: [PATCH] support fp32 all_reduce and reduce_scatter --- internlm/core/parallel/comm/utils.py | 34 +++++++++++++++++++ .../core/scheduler/no_pipeline_scheduler.py | 4 +++ .../core/scheduler/pipeline_scheduler_1f1b.py | 21 ++++++++++++ .../core/scheduler/pipeline_scheduler_zb.py | 5 ++- internlm/initialize/launch.py | 17 ++++++++++ internlm/solver/optimizer/utils.py | 30 +++++++++++++--- 6 files changed, 105 insertions(+), 6 deletions(-) diff --git a/internlm/core/parallel/comm/utils.py b/internlm/core/parallel/comm/utils.py index 5cd8cb79..2c4bf71f 100644 --- a/internlm/core/parallel/comm/utils.py +++ b/internlm/core/parallel/comm/utils.py @@ -29,10 +29,32 @@ def wait(self) -> None: DUMMY_HANDLE_CONST = DummyAsyncCommHandle() +class WrappedHandle: + """ + Handle precision conversion when async all_reduce or reduce_scatter + """ + def __init__(self, handle, output): + self.handle = handle + self.output = output + + def wait(self): + self.handle.wait() + if gpc.config.reduce_comm_dtype != gpc.config.model.dtype: + self.output.data = self.output.to(gpc.config.model.dtype) + self.output = None + + # Raw operation, does not support autograd, but does support async def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): + if gpc.config.reduce_comm_dtype != gpc.config.model.dtype: + input_ = input_.to(gpc.config.reduce_comm_dtype) input_ = input_.contiguous() handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op) + if gpc.config.reduce_comm_dtype != gpc.config.model.dtype: + if async_op is False: + input_ = input_.to(gpc.config.model.dtype) + else: + handle = WrappedHandle(handle=handle, output=input_) return input_, handle @@ -122,7 +144,11 @@ def _reduce(input_, parallel_mode): return input_ group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode) + if gpc.config.reduce_comm_dtype != gpc.config.model.dtype: + input_ = input_.to(gpc.config.reduce_comm_dtype) dist.all_reduce(input_, group=group) + if gpc.config.reduce_comm_dtype != gpc.config.model.dtype: + input_ = input_.to(gpc.config.model.dtype) return input_ @@ -241,6 +267,9 @@ def reduce_scatter_raw( if world_size <= 1: return input_, None + if gpc.config.reduce_comm_dtype != gpc.config.model.dtype: + input_ = input_.to(gpc.config.reduce_comm_dtype) + shape_list = list(input_.shape) shape_list[reduce_dim] = shape_list[reduce_dim] // world_size @@ -251,6 +280,11 @@ def reduce_scatter_raw( ).contiguous() handle = dist.reduce_scatter_tensor(output, input_.contiguous(), op=op, group=process_group, async_op=async_op) + if gpc.config.reduce_comm_dtype != gpc.config.model.dtype: + if async_op is False: + output = output.to(gpc.config.model.dtype) + else: + handle = WrappedHandle(handle=handle, output=output) return output, handle diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index 84b94dbf..7e309beb 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -136,7 +136,11 @@ def _train_one_batch( # the moe_loss is computed among the "tensor" group if sequence parallel is enabled, # so we need to do allreduce if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp: + if gpc.config.reduce_comm_dtype != gpc.config.model.dtype: + moe_loss = moe_loss.to(gpc.config.reduce_comm_dtype) dist.all_reduce(moe_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR)) + if gpc.config.reduce_comm_dtype != gpc.config.model.dtype: + moe_loss = moe_loss.to(gpc.config.model.dtype) moe_loss.div_(gpc.get_world_size(ParallelMode.TENSOR)) moe_loss /= scale_loss loss /= scale_loss diff --git a/internlm/core/scheduler/pipeline_scheduler_1f1b.py b/internlm/core/scheduler/pipeline_scheduler_1f1b.py index 4864c77f..c4142955 100644 --- a/internlm/core/scheduler/pipeline_scheduler_1f1b.py +++ b/internlm/core/scheduler/pipeline_scheduler_1f1b.py @@ -316,7 +316,11 @@ def _forward_step( ) # the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp: + if gpc.config.reduce_comm_dtype != gpc.config.model.dtype: + moe_loss = moe_loss.to(gpc.config.reduce_comm_dtype) dist.all_reduce(moe_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR)) + if gpc.config.reduce_comm_dtype != gpc.config.model.dtype: + moe_loss = moe_loss.to(gpc.config.model.dtype) moe_loss.div_(gpc.get_world_size(ParallelMode.TENSOR)) moe_loss /= self.num_microbatches accum_moe_loss.add_(moe_loss.detach()) @@ -454,7 +458,11 @@ def _forward_only_step(self, engine, return_loss=True, return_output_label=True) output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None) if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: + if gpc.config.reduce_comm_dtype != gpc.config.model.dtype: + accum_moe_loss = accum_moe_loss.to(gpc.config.reduce_comm_dtype) dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE)) + if gpc.config.reduce_comm_dtype != gpc.config.model.dtype: + accum_moe_loss = accum_moe_loss.to(gpc.config.model.dtype) if accum_loss is not None: accum_loss += accum_moe_loss @@ -658,7 +666,11 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None) if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: + if gpc.config.reduce_comm_dtype != gpc.config.model.dtype: + accum_moe_loss = accum_moe_loss.to(gpc.config.reduce_comm_dtype) dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE)) + if gpc.config.reduce_comm_dtype != gpc.config.model.dtype: + accum_moe_loss = accum_moe_loss.to(gpc.config.model.dtype) if accum_loss is not None: accum_loss += accum_moe_loss @@ -879,7 +891,12 @@ def _forward_step(self, engine, chunk_id, input_obj=None): ) # the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp: + if gpc.config.reduce_comm_dtype != gpc.config.model.dtype: + moe_loss = moe_loss.to(gpc.config.reduce_comm_dtype) dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR)) + if gpc.config.reduce_comm_dtype != gpc.config.model.dtype: + moe_loss = moe_loss.to(gpc.config.model.dtype) + moe_loss /= self.num_microbatches if self._accum_moe_loss is not None: @@ -1410,7 +1427,11 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo output, label = (None, None) if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: + if gpc.config.reduce_comm_dtype != gpc.config.model.dtype: + self._accum_moe_loss = self._accum_moe_loss.to(gpc.config.reduce_comm_dtype) dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE)) + if gpc.config.reduce_comm_dtype != gpc.config.model.dtype: + self._accum_moe_loss = self._accum_moe_loss.to(gpc.config.model.dtype) accum_moe_loss = self._accum_moe_loss accum_loss = self._accum_loss diff --git a/internlm/core/scheduler/pipeline_scheduler_zb.py b/internlm/core/scheduler/pipeline_scheduler_zb.py index 75cf1844..4f8fb676 100644 --- a/internlm/core/scheduler/pipeline_scheduler_zb.py +++ b/internlm/core/scheduler/pipeline_scheduler_zb.py @@ -351,7 +351,11 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None) if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: + if gpc.config.reduce_comm_dtype != gpc.config.model.dtype: + accum_moe_loss = accum_moe_loss.to(gpc.config.reduce_comm_dtype) dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE)) + if gpc.config.reduce_comm_dtype != gpc.config.model.dtype: + accum_moe_loss = accum_moe_loss.to(gpc.config.model.dtype) if accum_loss is not None: accum_loss += accum_moe_loss @@ -901,7 +905,6 @@ def _run_steady_loop( else: next_unit_chunk_id = 1 - # import pdb; pdb.set_trace() if unit_step == num_units_stage1 - 1: chunk0_B_need_recv_prev_chunk0_output = False else: diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 1ac8ef31..034d55e5 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -455,6 +455,23 @@ def args_sanity_check(): gpc.config.parallel["expert_weight"]["overlap"] = False if gpc.config.parallel["expert"].get("no_tp", None) is None: gpc.config.parallel["expert"]["no_tp"] = False + + # the comm_dtype for reduce communication + if gpc.config.get("reduce_comm_dtype", None) is None: + gpc.config.reduce_comm_dtype = gpc.config.model.dtype + else: + if gpc.config.reduce_comm_dtype == "torch.bfloat16": + gpc.config.reduce_comm_dtype = torch.bfloat16 + elif gpc.config.reduce_comm_dtype == "torch.float32": + gpc.config.reduce_comm_dtype = torch.float32 + else: + assert gpc.config.reduce_comm_dtype in [ + "torch.bfloat16", + "torch.float32", + ] + if gpc.config.model.dtype == torch.float32: + assert gpc.config.reduce_comm_dtype == gpc.config.model.dtype + # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1: assert ( diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py index a0180a59..b111361c 100644 --- a/internlm/solver/optimizer/utils.py +++ b/internlm/solver/optimizer/utils.py @@ -81,6 +81,22 @@ def split_half_float_double(tensor_list): return buckets +class WrappedHandle: + """ + Handle precision conversion when async all_reduce or reduce_scatter + """ + def __init__(self, handle, output, dtype): + self.handle = handle + self.output = output + self.dtype = dtype + + def wait(self): + self.handle.wait() + if gpc.config.reduce_comm_dtype != self.dtype: + self.output.data = self.output.to(self.dtype) + self.output = None + + def reduce_tensor( tensor, dtype=None, @@ -106,13 +122,12 @@ def reduce_tensor( # use the original dtype # if dtype is None: assert dtype is None - dtype = tensor.dtype + dtype = gpc.config.reduce_comm_dtype + tensor_dtype = tensor.dtype # cast the data to specified dtype for reduce/all-reduce - # if tensor.dtype != dtype: - # tensor_to_reduce = tensor.to(dtype) - # else: - # tensor_to_reduce = tensor + if tensor_dtype != dtype: + tensor = tensor.to(dtype) # world_size = gpc.get_world_size(parallel_mode) # tensor.div_(world_size) @@ -129,6 +144,11 @@ def reduce_tensor( global_rank = ranks_in_group[dst_rank] handle = dist.reduce(tensor=tensor, dst=global_rank, group=group, op=op_type, async_op=async_op) + if tensor_dtype != dtype: + if async_op: + handle = WrappedHandle(handle=handle, output=tensor, dtype=tensor_dtype) + else: + tensor = tensor.to(tensor_dtype) return handle