Added option to do backward all-gather over a smaller set of GPUs instead of the full DDP world #1125

Open · wants to merge 1 commit into base: ngoyal_bf16_changes
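A minimal usage sketch of the option this PR adds (only the zero2_process_group keyword comes from the diff below; model is assumed to be an existing nn.Module, and the 4-rank group layout is illustrative):

import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

dist.init_process_group(backend="nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()

# Build one smaller group per block of 4 consecutive ranks (e.g. one node).
# Every rank must take part in every new_group call and keeps only the group
# it belongs to; the group size must divide the full world size.
zero2_pg = None
for start in range(0, world_size, 4):
    ranks = list(range(start, start + 4))
    pg = dist.new_group(ranks=ranks)
    if rank in ranks:
        zero2_pg = pg

# Parameters stay sharded over the full world (ZeRO-3); only the backward
# all-gather runs over the smaller group (ZeRO-2 style).
fsdp_model = FSDP(model, zero2_process_group=zero2_pg)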
182 changes: 180 additions & 2 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -332,6 +332,7 @@ def __init__(
offload_config: Optional[OffloadConfig] = None,
state_dict_on_rank_0_only: bool = False,
gradient_predivide_factor: Optional[float] = None,
zero2_process_group: Optional[ProcessGroup] = None,
):
try:
import torch._C
@@ -380,6 +381,9 @@ def __init__(
"parameter uses all the available ranks for the optimal performance."
)
self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward

self.zero2_process_group = zero2_process_group

self.disable_reshard_on_root = disable_reshard_on_root
self.mixed_precision = mixed_precision
self.fp32_reduce_scatter = fp32_reduce_scatter
@@ -518,6 +522,9 @@ def __init__(
if isinstance(m, FullyShardedDataParallel):
m._free_ssd_offload()

if self.zero2_process_group is not None:
assert not self.move_params_to_cpu

def _get_gradient_predivide_factor(self, world_size: int) -> float:
factor: int = 1
while world_size % factor == 0 and world_size / factor > factor:
@@ -1419,7 +1426,10 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
outputs = self.module(*args, **kwargs)

if self.reshard_after_forward:
self._free_full_params()
if self.zero2_process_group is not None:
self._zero2_shard_to_smaller_group()
else:
self._free_full_params()
if self.mixed_precision or self.move_params_to_cpu:
self._free_fp16_param_shard()

@@ -1499,7 +1509,10 @@ def _pre_backward_hook(*unused: Any) -> None:
# idempotent. So in case they are called unnecessarily, they don't incur much
# overhead.
if self.reshard_after_forward:
self._rebuild_full_params()
if self.zero2_process_group is not None:
self._zero2_rebuild_full_params()
else:
self._rebuild_full_params()
if (
self.reshard_after_forward
and self._fsdp_forward_ordering is not None
@@ -2006,6 +2019,126 @@ def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
return output_tensors


@torch.no_grad()
def _zero2_rebuild_full_params(self, force_full_precision: bool = False, wait_for_all_gather: bool = True) -> Optional[List[Tuple[torch.Tensor, bool]]]:
"""
Gather all shards of params.

Note, this is idempotent if full params are already gathered. Callers
assume the idempotency. So please keep it that way.

Args:
force_full_precision (bool, Optional): by default params will be gathered
in ``compute_dtype`` (e.g., FP16), unless *force_full_precision* is
``True``, in which case they will be gathered in full precision
(e.g., FP32), possibly in fresh storage. The parameter that's being
rebuilt will end up in full precision as well.

Returns:
A list of tuples, where the first element is the full-sized param
and the second element is a bool indicating if it's safe for the
caller to free the full-sized param. This will be ``None`` if
``force_full_precision=False`` and the full params are already gathered.
"""
output_tensors: List[Tuple[torch.Tensor, bool]] = []

def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
"""
Helper function to update p.data pointer.

Args:
custom_output_tensor (torch.Tensor, Optional): if not None, this
tensor contains the data we just gathered.
"""
if custom_output_tensor is not None:
assert p._is_sharded
p.data = custom_output_tensor
output_tensors.append((p.data, True))
elif not p._is_sharded:
if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
assert p._fp16_shard is not None
p.data = p._fp16_shard
output_tensors.append((p.data, True))
else:
# Here p.data == p._fp32_shard, so it's not safe to free.
output_tensors.append((p.data, False))
else:
p.data = p._full_param_padded
output_tensors.append((p.data, True))
# Trim any padding and reshape to match original size.
p.data = p.data[: p._orig_size.numel()].view(p._orig_size)

if self._has_shared_params:
# self.has_full_params flag can be out of sync if a shared param is
# sharded by another FSDP instance. An example is that in eval case
# with reshard_after_forward=False but the sharing instance has
# reshard_after_forward=True. Then, on the second forward, the
# other instance can shard the shared param and but this instance
# can mistakenly think the full param is already gathered from the
# has_full_params flag.
#
# Therefore, we update the flag accordingly here.
self.has_full_params = not any(p._full_param_padded.storage().size() == 0 for p in self.params)

# Early exit if we already have full params and don't need full precision.
if self.has_full_params and not force_full_precision:
if wait_for_all_gather:
torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
for p in self.params:
update_p_data()
return output_tensors

self.has_full_params = True

with torch.cuda.stream(self._streams["all_gather"]):

for p in self.params:
if not p._is_sharded: # e.g., when world_size == 1
update_p_data()
else:
# Skip if already built. Only shared param can be rebuilt multiple times.
# A corner case is p._orig_size = (1,), which means the shape equality is
# not a perfect check. But we assume we don't share a param with shape (1,).
if p.data.shape == p._orig_size and hasattr(p, "_is_shared") and p._is_shared:
continue
# If self.move_params_to_cpu and force_full_precision, we need to cast
# the FP32 CPU param to CUDA for the all-gather.
p_data = p.data.to(p._full_param_padded.device, non_blocking=True)

p_size = p._full_param_padded.size()
assert p_size.numel() % self.world_size == 0
if self.mixed_precision and force_full_precision:
# Allocate fresh tensor in full precision since we are in
# mixed precision and full precision rebuild is asked.
output_tensor = p_data.new_zeros(p_size)
else:
if p._full_param_padded.storage().size() != p_size.numel():
# Allocate based on full size from all shards.
alloc_storage_(p._full_param_padded, size=p_size)
output_tensor = p._full_param_padded

# Fill output_tensor with (p.data for each shard in self.world_size)
if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
# New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
dist._all_gather_base(output_tensor, p._zero2_fp16_shard, group=self.zero2_process_group)
else:
chunks = list(output_tensor.chunk(dist.get_world_size(self.zero2_process_group)))
dist.all_gather(chunks, p._zero2_fp16_shard, group=self.zero2_process_group)

# Set p.data = output_tensor (with padding trimmed)
update_p_data(output_tensor)

if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
self._free_zero2_param_shard([p])

if self.move_params_to_cpu and (self.params[0].dtype == self.compute_dtype):
self._free_zero2_param_shard([p])
if wait_for_all_gather:
torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
return output_tensors

Review comment:
It looks like the only difference compared to _rebuild_full_params() is no SSD offload, no CPU offload, and using p._zero2_fp16_shard, self.zero2_process_group, and self._free_zero2_param_shard() -- this makes sense to me.
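For readers skimming the diff, a stripped-down sketch of the round trip those pieces implement: keep only this rank's chunk over the smaller group after forward, then all-gather it back before backward (the helper names here are illustrative; the real code operates in place on p._full_param_padded and p._zero2_fp16_shard as shown above).

import torch
import torch.distributed as dist

def shard_to_group(full_padded: torch.Tensor, group) -> torch.Tensor:
    # Keep only this rank's chunk of the padded full tensor (the ZeRO-2 shard).
    ws = dist.get_world_size(group)
    rank = dist.get_rank(group)
    return full_padded.chunk(ws)[rank].clone()

def regather_from_group(shard: torch.Tensor, group) -> torch.Tensor:
    # All-gather the shards back into the padded full tensor before backward.
    ws = dist.get_world_size(group)
    chunks = [torch.empty_like(shard) for _ in range(ws)]
    dist.all_gather(chunks, shard, group=group)
    return torch.cat(chunks)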


@torch.no_grad()
def _use_full_params(self) -> None:
"""
@@ -2074,6 +2207,38 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
free_storage_(p._full_param_padded)
torch.cuda.current_stream().synchronize()


def _zero2_shard_to_smaller_group(self, params: Optional[List[Parameter]] = None) -> None:
if params is None:
params = self.params
self.has_full_params = False
current_stream = torch.cuda.current_stream()
for p in params:
if not p._is_sharded: # e.g., world_size == 1
if self.mixed_precision or self.move_params_to_cpu:
self._free_fp16_param_shard([p])
continue
# Cases for when zero2 world size > 1 but less than zero3 size
zero2_world_size = dist.get_world_size(self.zero2_process_group)
zero2_rank = dist.get_rank(self.zero2_process_group)
chunks = p._full_param_padded.chunk(zero2_world_size)
Review comment:
I just want to mention that there is a divisibility assumption here (ZeRO-2 world size divides the ZeRO-3 world size), which should always hold in practice.
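A tiny numeric illustration of that assumption (the sizes are made up): when the ZeRO-2 group size divides the ZeRO-3 world size, the padded full parameter splits evenly into ZeRO-2 chunks.

import torch

world_size, zero2_world_size = 8, 4          # ZeRO-3 world and ZeRO-2 group
assert world_size % zero2_world_size == 0    # divisibility assumption

full_param_padded = torch.zeros(128)         # padded so 128 % world_size == 0
chunks = full_param_padded.chunk(zero2_world_size)
assert all(c.numel() == 128 // zero2_world_size for c in chunks)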


p._zero2_fp16_shard = torch.empty_like(chunks[zero2_rank])
p._zero2_fp16_shard.copy_(chunks[zero2_rank])

# Don't let PyTorch reuse this memory until all work in the current
# stream is complete.
p._full_param_padded.record_stream(current_stream)
# There may be external references to the Tensor Storage that we
# can't modify, such as references that are created by
# ctx.save_for_backward in the forward pass. Thus when we
# unshard parameters, we should reuse the original Tensor
# Storage object and unshard it in-place. For now, just resize
# the Storage to 0 to save memory.
free_storage_(p._full_param_padded)
torch.cuda.current_stream().synchronize()


def local_metadata_dict(self) -> Dict[str, Any]:
"""
Get the information needed to reconstruct the model from shards offline.
@@ -2238,6 +2403,19 @@ def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
p._fp16_shard.record_stream(current_stream)
free_storage_(p._fp16_shard)

@torch.no_grad()
def _free_zero2_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
"""Free storage for FP16 shards for a list of params."""
if params is None:
params = self.params
current_stream = torch.cuda.current_stream()
for p in params:
if p._zero2_fp16_shard is not None:
# _fp16_shard is allocated in "fp32_to_fp16" stream, so we can't
# free it until the work in the current stream completes.
p._zero2_fp16_shard.record_stream(current_stream)
free_storage_(p._zero2_fp16_shard)

Review comment:
It looks like _zero2_fp16_shard is allocated in the default stream (since _zero2_shard_to_smaller_group() is called from forward() without an explicit stream context manager):

if self.reshard_after_forward:
if self.zero2_process_group is not None:
self._zero2_shard_to_smaller_group()

_zero2_fp16_shard is consumed in the "all_gather" stream, and this _free_zero2_param_shard() is called from that "all_gather" stream as well:
# Fill output_tensor with (p.data for each shard in self.world_size)
if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
# New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
dist._all_gather_base(output_tensor, p._zero2_fp16_shard , group=self.zero2_process_group)
else:
chunks = list(output_tensor.chunk(self.world_size))
dist.all_gather(chunks, p._zero2_fp16_shard, group=self.zero2_process_group)
# Set p.data = output_tensor (with padding trimmed)
update_p_data(output_tensor)
if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
self._free_zero2_param_shard([p])
if self.move_params_to_cpu and (self.params[0].dtype == self.compute_dtype):
self._free_zero2_param_shard([p])

In that case, I do agree this p._zero2_fp16_shard.record_stream(current_stream) call is necessary to notify the caching allocator of the usage in the "all_gather" stream. However, I think the comment can be changed to say that it was allocated in the default stream. Alternatively, you can do something like _cast_fp32_param_shards_to_fp16(), but I am not sure if there is any actual overlap opportunity given the data dependencies.

with torch.cuda.stream(self._streams["fp32_to_fp16"]):
for p in params:
assert p._fp16_shard is not None
alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
p._fp16_shard.copy_(
# If move_params_to_cpu is True, this will be non-blocking
# because _fp32_shard is pinned, otherwise it's a no-op.
p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
)
p.data = p._fp16_shard
torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])
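For reference, a minimal sketch of the record_stream pattern being discussed, assuming a CUDA device: the tensor is allocated on the default stream and consumed on a side stream, mirroring how _zero2_fp16_shard is produced in forward() and consumed in the "all_gather" stream.

import torch

side_stream = torch.cuda.Stream()
t = torch.empty(1024, device="cuda")      # allocated on the default stream
with torch.cuda.stream(side_stream):
    out = t * 2                           # consumed on the side stream
    # Tell the caching allocator that t is in use on side_stream, so freeing
    # it below does not let its memory be reused before that work finishes.
    t.record_stream(side_stream)
del t                                     # safe: allocator waits on side_stream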

def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None:
"""Assert we are in the given state."""
# Since assert can be turned off and this error checking