support fp32 all_reduce and reduce_scatter #389

Open · wants to merge 1 commit into base: develop
35 changes: 35 additions & 0 deletions internlm/core/parallel/comm/utils.py
@@ -29,10 +29,33 @@ def wait(self) -> None:
DUMMY_HANDLE_CONST = DummyAsyncCommHandle()


class WrappedHandle:
"""
Handle precision conversion after an async all_reduce or reduce_scatter completes.
"""

def __init__(self, handle, output):
self.handle = handle
self.output = output

def wait(self):
self.handle.wait()
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
self.output.data = self.output.to(gpc.config.model.dtype)
self.output = None


# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
input_ = input_.to(gpc.config.reduce_comm_dtype)
input_ = input_.contiguous()
handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
if async_op is False:
input_ = input_.to(gpc.config.model.dtype)
else:
handle = WrappedHandle(handle=handle, output=input_)
return input_, handle


@@ -122,7 +145,11 @@ def _reduce(input_, parallel_mode):
return input_

group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
input_ = input_.to(gpc.config.reduce_comm_dtype)
dist.all_reduce(input_, group=group)
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
input_ = input_.to(gpc.config.model.dtype)

return input_

@@ -241,6 +268,9 @@ def reduce_scatter_raw(
if world_size <= 1:
return input_, None

if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
input_ = input_.to(gpc.config.reduce_comm_dtype)

shape_list = list(input_.shape)
shape_list[reduce_dim] = shape_list[reduce_dim] // world_size

@@ -251,6 +281,11 @@
).contiguous()

handle = dist.reduce_scatter_tensor(output, input_.contiguous(), op=op, group=process_group, async_op=async_op)
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
if async_op is False:
output = output.to(gpc.config.model.dtype)
else:
handle = WrappedHandle(handle=handle, output=output)
return output, handle


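For reference, the pattern these helpers implement can be sketched in isolation: cast the buffer to the communication dtype, launch the collective, and cast back to the model dtype either immediately (sync) or inside the handle's wait() (async). The sketch below is standalone; the names FP32ReduceHandle, all_reduce_fp32, and model_dtype are illustrative, while the actual helpers read both dtypes from gpc.config.

```python
import torch
import torch.distributed as dist


class FP32ReduceHandle:
    """Sketch of the WrappedHandle idea: cast the buffer back to the model dtype
    once the asynchronous collective has finished."""

    def __init__(self, handle, output, model_dtype):
        self.handle = handle
        self.output = output
        self.model_dtype = model_dtype

    def wait(self):
        self.handle.wait()
        # In-place dtype restore so callers keep their reference to the same tensor object.
        self.output.data = self.output.to(self.model_dtype)
        self.output = None


def all_reduce_fp32(tensor, group, model_dtype=torch.bfloat16, async_op=False):
    # Accumulate the reduction in fp32 regardless of the model's working dtype.
    buf = tensor.to(torch.float32).contiguous()
    handle = dist.all_reduce(buf, group=group, async_op=async_op)
    if async_op:
        return buf, FP32ReduceHandle(handle, buf, model_dtype)
    return buf.to(model_dtype), None
```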
4 changes: 4 additions & 0 deletions internlm/core/scheduler/no_pipeline_scheduler.py
@@ -136,7 +136,11 @@ def _train_one_batch(
# the moe_loss is computed among the "tensor" group if sequence parallel is enabled,
# so we need to do allreduce
if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp:
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
moe_loss = moe_loss.to(gpc.config.reduce_comm_dtype)
dist.all_reduce(moe_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
moe_loss = moe_loss.to(gpc.config.model.dtype)
moe_loss.div_(gpc.get_world_size(ParallelMode.TENSOR))
moe_loss /= scale_loss
loss /= scale_loss
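The schedulers apply the same cast-reduce-cast sequence synchronously to the MoE loss. A compact standalone sketch of that sequence follows; the average_moe_loss name and explicit arguments are illustrative, since the scheduler pulls the group, world size, and dtypes from gpc.

```python
import torch
import torch.distributed as dist


def average_moe_loss(moe_loss, tensor_group, tensor_world_size, model_dtype=torch.bfloat16):
    # Sum the per-rank MoE losses in fp32, cast back, then average over the tensor group.
    moe_loss = moe_loss.to(torch.float32)
    dist.all_reduce(moe_loss, op=dist.ReduceOp.SUM, group=tensor_group)
    moe_loss = moe_loss.to(model_dtype)
    moe_loss.div_(tensor_world_size)
    return moe_loss
```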
21 changes: 21 additions & 0 deletions internlm/core/scheduler/pipeline_scheduler_1f1b.py
@@ -316,7 +316,11 @@ def _forward_step(
)
# the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce
if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp:
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
moe_loss = moe_loss.to(gpc.config.reduce_comm_dtype)
dist.all_reduce(moe_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
moe_loss = moe_loss.to(gpc.config.model.dtype)
moe_loss.div_(gpc.get_world_size(ParallelMode.TENSOR))
moe_loss /= self.num_microbatches
accum_moe_loss.add_(moe_loss.detach())
@@ -454,7 +458,11 @@ def _forward_only_step(self, engine, return_loss=True, return_output_label=True)
output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)

if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
accum_moe_loss = accum_moe_loss.to(gpc.config.reduce_comm_dtype)
dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
accum_moe_loss = accum_moe_loss.to(gpc.config.model.dtype)

if accum_loss is not None:
accum_loss += accum_moe_loss
@@ -658,7 +666,11 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)

if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
accum_moe_loss = accum_moe_loss.to(gpc.config.reduce_comm_dtype)
dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
accum_moe_loss = accum_moe_loss.to(gpc.config.model.dtype)

if accum_loss is not None:
accum_loss += accum_moe_loss
@@ -879,7 +891,12 @@ def _forward_step(self, engine, chunk_id, input_obj=None):
)
# the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce
if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp:
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
moe_loss = moe_loss.to(gpc.config.reduce_comm_dtype)
dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
moe_loss = moe_loss.to(gpc.config.model.dtype)

moe_loss /= self.num_microbatches

if self._accum_moe_loss is not None:
@@ -1410,7 +1427,11 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo
output, label = (None, None)

if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
self._accum_moe_loss = self._accum_moe_loss.to(gpc.config.reduce_comm_dtype)
dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
self._accum_moe_loss = self._accum_moe_loss.to(gpc.config.model.dtype)
accum_moe_loss = self._accum_moe_loss

accum_loss = self._accum_loss
5 changes: 4 additions & 1 deletion internlm/core/scheduler/pipeline_scheduler_zb.py
@@ -351,7 +351,11 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)

if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
accum_moe_loss = accum_moe_loss.to(gpc.config.reduce_comm_dtype)
dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
if gpc.config.reduce_comm_dtype != gpc.config.model.dtype:
accum_moe_loss = accum_moe_loss.to(gpc.config.model.dtype)

if accum_loss is not None:
accum_loss += accum_moe_loss
@@ -901,7 +905,6 @@ def _run_steady_loop(
else:
next_unit_chunk_id = 1

# import pdb; pdb.set_trace()
if unit_step == num_units_stage1 - 1:
chunk0_B_need_recv_prev_chunk0_output = False
else:
17 changes: 17 additions & 0 deletions internlm/initialize/launch.py
@@ -455,6 +455,23 @@ def args_sanity_check():
gpc.config.parallel["expert_weight"]["overlap"] = False
if gpc.config.parallel["expert"].get("no_tp", None) is None:
gpc.config.parallel["expert"]["no_tp"] = False

# the comm_dtype for reduce communication
if gpc.config.get("reduce_comm_dtype", None) is None:
gpc.config.reduce_comm_dtype = gpc.config.model.dtype
else:
if gpc.config.reduce_comm_dtype == "torch.bfloat16":
gpc.config.reduce_comm_dtype = torch.bfloat16
elif gpc.config.reduce_comm_dtype == "torch.float32":
gpc.config.reduce_comm_dtype = torch.float32
else:
assert gpc.config.reduce_comm_dtype in [
"torch.bfloat16",
"torch.float32",
]
if gpc.config.model.dtype == torch.float32:
assert gpc.config.reduce_comm_dtype == gpc.config.model.dtype

# currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:
assert (
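Based on this sanity check, a training config would opt into fp32 reductions with a top-level reduce_comm_dtype entry, as in the hypothetical excerpt below (only the dtype fields are meaningful; the other model fields are placeholders). Omitting the entry keeps the previous behaviour of reducing in the model dtype, and the final assert rejects bf16 reductions when the model itself runs in fp32.

```python
# Hypothetical training-config excerpt: keep parameters and activations in bf16,
# but run all_reduce / reduce_scatter in fp32.
model = dict(
    dtype="torch.bfloat16",
    # ... remaining model fields ...
)

# New top-level option parsed by args_sanity_check(); if omitted, it defaults to model.dtype.
reduce_comm_dtype = "torch.float32"
```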
31 changes: 26 additions & 5 deletions internlm/solver/optimizer/utils.py
@@ -81,6 +81,23 @@ def split_half_float_double(tensor_list):
return buckets


class WrappedHandle:
"""
Handle precision conversion after an async all_reduce or reduce_scatter completes.
"""

def __init__(self, handle, output, dtype):
self.handle = handle
self.output = output
self.dtype = dtype

def wait(self):
self.handle.wait()
if gpc.config.reduce_comm_dtype != self.dtype:
self.output.data = self.output.to(self.dtype)
self.output = None


def reduce_tensor(
tensor,
dtype=None,
@@ -106,13 +123,12 @@ def reduce_tensor(
# use the original dtype
# if dtype is None:
assert dtype is None
dtype = tensor.dtype
dtype = gpc.config.reduce_comm_dtype
tensor_dtype = tensor.dtype

# cast the data to specified dtype for reduce/all-reduce
# if tensor.dtype != dtype:
# tensor_to_reduce = tensor.to(dtype)
# else:
# tensor_to_reduce = tensor
if tensor_dtype != dtype:
tensor = tensor.to(dtype)

# world_size = gpc.get_world_size(parallel_mode)
# tensor.div_(world_size)
@@ -129,6 +145,11 @@
global_rank = ranks_in_group[dst_rank]
handle = dist.reduce(tensor=tensor, dst=global_rank, group=group, op=op_type, async_op=async_op)

if tensor_dtype != dtype:
if async_op:
handle = WrappedHandle(handle=handle, output=tensor, dtype=tensor_dtype)
else:
tensor = tensor.to(tensor_dtype)
return handle


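One difference worth noting: this copy of WrappedHandle restores the dtype the tensor had when reduce_tensor was called (passed in as dtype), rather than reading gpc.config.model.dtype as the comm/utils.py version does. A minimal standalone sketch of that variant, with illustrative names (DtypeRestoringHandle, reduce_grad_fp32) that are not part of the PR:

```python
import torch
import torch.distributed as dist


class DtypeRestoringHandle:
    """Captures the tensor's original dtype at call time and restores it in wait()."""

    def __init__(self, handle, output, dtype):
        self.handle = handle
        self.output = output
        self.dtype = dtype

    def wait(self):
        self.handle.wait()
        self.output.data = self.output.to(self.dtype)
        self.output = None


def reduce_grad_fp32(grad, group, dst, async_op=False):
    original_dtype = grad.dtype  # remembered per tensor, not taken from the model config
    buf = grad.to(torch.float32)
    handle = dist.reduce(buf, dst=dst, group=group, async_op=async_op)
    if async_op:
        return buf, DtypeRestoringHandle(handle, buf, original_dtype)
    return buf.to(original_dtype), None
```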