Skip to content

Commit

Permalink
support fp32 all_reduce and reduce_scatter
Browse files Browse the repository at this point in the history
  • Loading branch information
sallyjunjun committed Dec 6, 2024
1 parent 6b7df0b commit c22a1d1
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 6 deletions.
34 changes: 34 additions & 0 deletions internlm/core/parallel/comm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,32 @@ def wait(self) -> None:
DUMMY_HANDLE_CONST = DummyAsyncCommHandle()


class WrappedHandle:
"""
Handle precision conversion when async all_reduce or reduce_scatter
"""
def __init__(self, handle, output):
self.handle = handle
self.output = output

def wait(self):
self.handle.wait()
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
self.output.data = self.output.to(gpc.config.model.dtype)
self.output = None


# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
input_ = input_.to(gpc.config.reduce_comm_dtype)
input_ = input_.contiguous()
handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
if async_op is False:
input_ = input_.to(gpc.config.model.dtype)
else:
handle = WrappedHandle(handle=handle, output=input_)
return input_, handle


Expand Down Expand Up @@ -122,7 +144,11 @@ def _reduce(input_, parallel_mode):
return input_

group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
input_ = input_.to(gpc.config.reduce_comm_dtype)
dist.all_reduce(input_, group=group)
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
input_ = input_.to(gpc.config.model.dtype)

return input_

Expand Down Expand Up @@ -241,6 +267,9 @@ def reduce_scatter_raw(
if world_size <= 1:
return input_, None

if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
input_ = input_.to(gpc.config.reduce_comm_dtype)

shape_list = list(input_.shape)
shape_list[reduce_dim] = shape_list[reduce_dim] // world_size

Expand All @@ -251,6 +280,11 @@ def reduce_scatter_raw(
).contiguous()

handle = dist.reduce_scatter_tensor(output, input_.contiguous(), op=op, group=process_group, async_op=async_op)
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
if async_op is False:
output = output.to(gpc.config.model.dtype)
else:
handle = WrappedHandle(handle=handle, output=output)
return output, handle


Expand Down
4 changes: 4 additions & 0 deletions internlm/core/scheduler/no_pipeline_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,11 @@ def _train_one_batch(
# the moe_loss is computed among the "tensor" group if sequence parallel is enabled,
# so we need to do allreduce
if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp:
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
moe_loss = moe_loss.to(gpc.config.reduce_comm_dtype)
dist.all_reduce(moe_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
moe_loss = moe_loss.to(gpc.config.model.dtype)
moe_loss.div_(gpc.get_world_size(ParallelMode.TENSOR))
moe_loss /= scale_loss
loss /= scale_loss
Expand Down
21 changes: 21 additions & 0 deletions internlm/core/scheduler/pipeline_scheduler_1f1b.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,11 @@ def _forward_step(
)
# the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce
if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp:
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
moe_loss = moe_loss.to(gpc.config.reduce_comm_dtype)
dist.all_reduce(moe_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
moe_loss = moe_loss.to(gpc.config.model.dtype)
moe_loss.div_(gpc.get_world_size(ParallelMode.TENSOR))
moe_loss /= self.num_microbatches
accum_moe_loss.add_(moe_loss.detach())
Expand Down Expand Up @@ -454,7 +458,11 @@ def _forward_only_step(self, engine, return_loss=True, return_output_label=True)
output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)

if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
accum_moe_loss = accum_moe_loss.to(gpc.config.reduce_comm_dtype)
dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
accum_moe_loss = accum_moe_loss.to(gpc.config.model.dtype)

if accum_loss is not None:
accum_loss += accum_moe_loss
Expand Down Expand Up @@ -658,7 +666,11 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)

if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
accum_moe_loss = accum_moe_loss.to(gpc.config.reduce_comm_dtype)
dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
accum_moe_loss = accum_moe_loss.to(gpc.config.model.dtype)

if accum_loss is not None:
accum_loss += accum_moe_loss
Expand Down Expand Up @@ -879,7 +891,12 @@ def _forward_step(self, engine, chunk_id, input_obj=None):
)
# the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce
if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp:
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
moe_loss = moe_loss.to(gpc.config.reduce_comm_dtype)
dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
moe_loss = moe_loss.to(gpc.config.model.dtype)

moe_loss /= self.num_microbatches

if self._accum_moe_loss is not None:
Expand Down Expand Up @@ -1410,7 +1427,11 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo
output, label = (None, None)

if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
self._accum_moe_loss = self._accum_moe_loss.to(gpc.config.reduce_comm_dtype)
dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
self._accum_moe_loss = self._accum_moe_loss.to(gpc.config.model.dtype)
accum_moe_loss = self._accum_moe_loss

accum_loss = self._accum_loss
Expand Down
5 changes: 4 additions & 1 deletion internlm/core/scheduler/pipeline_scheduler_zb.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,11 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)

if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
accum_moe_loss = accum_moe_loss.to(gpc.config.reduce_comm_dtype)
dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
accum_moe_loss = accum_moe_loss.to(gpc.config.model.dtype)

if accum_loss is not None:
accum_loss += accum_moe_loss
Expand Down Expand Up @@ -901,7 +905,6 @@ def _run_steady_loop(
else:
next_unit_chunk_id = 1

# import pdb; pdb.set_trace()
if unit_step == num_units_stage1 - 1:
chunk0_B_need_recv_prev_chunk0_output = False
else:
Expand Down
17 changes: 17 additions & 0 deletions internlm/initialize/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,23 @@ def args_sanity_check():
gpc.config.parallel["expert_weight"]["overlap"] = False
if gpc.config.parallel["expert"].get("no_tp", None) is None:
gpc.config.parallel["expert"]["no_tp"] = False

# the comm_dtype for reduce communication
if gpc.config.get("reduce_comm_dtype", None) is None:
gpc.config.reduce_comm_dtype = gpc.config.model.dtype
else:
if gpc.config.reduce_comm_dtype == "torch.bfloat16":
gpc.config.reduce_comm_dtype = torch.bfloat16
elif gpc.config.reduce_comm_dtype == "torch.float32":
gpc.config.reduce_comm_dtype = torch.float32
else:
assert gpc.config.reduce_comm_dtype in [
"torch.bfloat16",
"torch.float32",
]
if gpc.config.model.dtype == torch.float32:
assert gpc.config.reduce_comm_dtype == gpc.config.model.dtype

# currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:
assert (
Expand Down
30 changes: 25 additions & 5 deletions internlm/solver/optimizer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,22 @@ def split_half_float_double(tensor_list):
return buckets


class WrappedHandle:
"""
Handle precision conversion when async all_reduce or reduce_scatter
"""
def __init__(self, handle, output, dtype):
self.handle = handle
self.output = output
self.dtype = dtype

def wait(self):
self.handle.wait()
if gpc.config.reduce_comm_dtype != self.dtype:
self.output.data = self.output.to(self.dtype)
self.output = None


def reduce_tensor(
tensor,
dtype=None,
Expand All @@ -106,13 +122,12 @@ def reduce_tensor(
# use the original dtype
# if dtype is None:
assert dtype is None
dtype = tensor.dtype
dtype = gpc.config.reduce_comm_dtype
tensor_dtype = tensor.dtype

# cast the data to specified dtype for reduce/all-reduce
# if tensor.dtype != dtype:
# tensor_to_reduce = tensor.to(dtype)
# else:
# tensor_to_reduce = tensor
if tensor_dtype != dtype:
tensor = tensor.to(dtype)

# world_size = gpc.get_world_size(parallel_mode)
# tensor.div_(world_size)
Expand All @@ -129,6 +144,11 @@ def reduce_tensor(
global_rank = ranks_in_group[dst_rank]
handle = dist.reduce(tensor=tensor, dst=global_rank, group=group, op=op_type, async_op=async_op)

if tensor_dtype != dtype:
if async_op:
handle = WrappedHandle(handle=handle, output=tensor, dtype=tensor_dtype)
else:
tensor = tensor.to(tensor_dtype)
return handle


Expand Down

0 comments on commit c22a1d1

Please sign in to comment.