-
Notifications
You must be signed in to change notification settings - Fork 281
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
added option to do backward AG over smaller set of gpus instead of full DDP world #1125
base: ngoyal_bf16_changes
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -332,6 +332,7 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
offload_config: Optional[OffloadConfig] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
state_dict_on_rank_0_only: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
gradient_predivide_factor: Optional[float] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
zero2_process_group: Optional[ProcessGroup] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import torch._C | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -380,6 +381,9 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"parameter uses all the available ranks for the optimal performance." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.zero2_process_group = zero2_process_group | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.disable_reshard_on_root = disable_reshard_on_root | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.mixed_precision = mixed_precision | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.fp32_reduce_scatter = fp32_reduce_scatter | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -518,6 +522,9 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if isinstance(m, FullyShardedDataParallel): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
m._free_ssd_offload() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.zero2_process_group is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert not self.move_params_to_cpu | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _get_gradient_predivide_factor(self, world_size: int) -> float: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
factor: int = 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
while world_size % factor == 0 and world_size / factor > factor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1419,7 +1426,10 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
outputs = self.module(*args, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.reshard_after_forward: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._free_full_params() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.zero2_process_group is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._zero2_shard_to_smaller_group() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._free_full_params() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.mixed_precision or self.move_params_to_cpu: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._free_fp16_param_shard() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -1499,7 +1509,10 @@ def _pre_backward_hook(*unused: Any) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# idempotent. So in case they are called unnecessarily, they don't incur much | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# overhead. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.reshard_after_forward: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._rebuild_full_params() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.zero2_process_group is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._zero2_rebuild_full_params() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._rebuild_full_params() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.reshard_after_forward | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
and self._fsdp_forward_ordering is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -2006,6 +2019,126 @@ def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
torch.cuda.current_stream().wait_stream(self._streams["all_gather"]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return output_tensors | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@torch.no_grad() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _zero2_rebuild_full_params(self, force_full_precision: bool = False, wait_for_all_gather = True) -> Optional[List[Tuple[torch.Tensor, bool]]]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Gather all shards of params. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Note, this is idempotent if full params are already gathered. Callers | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assume the idempotency. So please keep it that way. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
force_full_precision (bool, Optional): by default params will be gathered | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
in ``compute_dtype`` (e.g., FP16), unless *force_full_precision* is | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
``True``, in which case they will be gathered in full precision | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
(e.g., FP32), possibly in fresh storage. The parameter that's being | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
rebuilt will end up in full precision as well. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
A list of tuples, where the first element is the full-sized param | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
and the second element is a bool indicating if it's safe for the | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
caller to free the full-sized param. This will be ``None`` if | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
``force_full_precision=False`` and the full params are already gathered. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
output_tensors: List[Tuple[torch.Tensor, bool]] = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Helper function to update p.data pointer. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
custom_output_tensor (torch.Tensor, Optional): if not None, this | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
tensor contains the data we just gathered. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if custom_output_tensor is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert p._is_sharded | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p.data = custom_output_tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
output_tensors.append((p.data, True)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif not p._is_sharded: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert p._fp16_shard is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p.data = p._fp16_shard | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
output_tensors.append((p.data, True)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Here p.data == p._fp32_shard, so it's not safe to free. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
output_tensors.append((p.data, False)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p.data = p._full_param_padded | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
output_tensors.append((p.data, True)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Trim any padding and reshape to match original size. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p.data = p.data[: p._orig_size.numel()].view(p._orig_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self._has_shared_params: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# self.has_full_params flag can be out of sync if a shared param is | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# sharded by another FSDP instance. An example is that in eval case | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# with reshard_after_forward=False but the sharing instance has | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# reshard_after_forward=True. Then, on the second forward, the | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# other instance can shard the shared param and but this instance | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# can mistakenly think the full param is already gathered from the | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# has_full_params flag. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Therefore, we update the flag accordingly here. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.has_full_params = not any(p._full_param_padded.storage().size() == 0 for p in self.params) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Early exit if we already have full params and don't need full precision. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.has_full_params and not force_full_precision: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if wait_for_all_gather: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
torch.cuda.current_stream().wait_stream(self._streams["all_gather"]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for p in self.params: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
update_p_data() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return output_tensors | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.has_full_params = True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
with torch.cuda.stream(self._streams["all_gather"]): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for p in self.params: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if not p._is_sharded: # e.g., when world_size == 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
update_p_data() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Skip if already built. Only shared param can be rebuilt multiple times. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# A corner case is p._orig_size = (1,), which means the shape equality is | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# not a perfect check. But we assume we don't share a param with shape (1,). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if p.data.shape == p._orig_size and hasattr(p, "_is_shared") and p._is_shared: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# If self.move_params_to_cpu and force_full_precision, we need to cast | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# the FP32 CPU param to CUDA for the all-gather. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p_data = p.data.to(p._full_param_padded.device, non_blocking=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p_size = p._full_param_padded.size() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert p_size.numel() % self.world_size == 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.mixed_precision and force_full_precision: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Allocate fresh tensor in full precision since we are in | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# mixed precision and full precision rebuild is asked. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
output_tensor = p_data.new_zeros(p_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if p._full_param_padded.storage().size() != p_size.numel(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Allocate based on full size from all shards. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
alloc_storage_(p._full_param_padded, size=p_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
output_tensor = p._full_param_padded | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Fill output_tensor with (p.data for each shard in self.world_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dist._all_gather_base(output_tensor, p._zero2_fp16_shard , group=self.zero2_process_group) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
chunks = list(output_tensor.chunk(self.world_size)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dist.all_gather(chunks, p._zero2_fp16_shard, group=self.zero2_process_group) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Set p.data = output_tensor (with padding trimmed) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
update_p_data(output_tensor) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._free_zero2_param_shard([p]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.move_params_to_cpu and (self.params[0].dtype == self.compute_dtype): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._free_zero2_param_shard([p]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if wait_for_all_gather: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
torch.cuda.current_stream().wait_stream(self._streams["all_gather"]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return output_tensors | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@torch.no_grad() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _use_full_params(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -2074,6 +2207,38 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
free_storage_(p._full_param_padded) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
torch.cuda.current_stream().synchronize() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _zero2_shard_to_smaller_group(self, params: Optional[List[Parameter]] = None): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if params is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
params = self.params | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.has_full_params = False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
current_stream = torch.cuda.current_stream() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for p in params: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if not p._is_sharded: # e.g., world_size == 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.mixed_precision or self.move_params_to_cpu: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._free_fp16_param_shard([p]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Cases for when zero2 world size > 1 but less than zero3 size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
zero2_world_size = dist.get_world_size(self.zero2_process_group) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
zero2_rank = dist.get_rank(self.zero2_process_group) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
chunks = p._full_param_padded.chunk(zero2_world_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just want to mention that there is a divisibility assumption here (ZeRO-2 world size divides the ZeRO-3 world size), which should always hold in practice. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p._zero2_fp16_shard = torch.empty_like(chunks[zero2_rank]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p._zero2_fp16_shard.copy_(chunks[zero2_rank]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Don't let PyTorch reuse this memory until all work in the current | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# stream is complete. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p._full_param_padded.record_stream(current_stream) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# There may be external references to the Tensor Storage that we | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# can't modify, such as references that are created by | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# ctx.save_for_backward in the forward pass. Thus when we | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# unshard parameters, we should reuse the original Tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Storage object and unshard it in-place. For now, just resize | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# the Storage to 0 to save memory. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
free_storage_(p._full_param_padded) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
torch.cuda.current_stream().synchronize() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def local_metadata_dict(self) -> Dict[str, Any]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Get the information needed to reconstruct the model from shards offline. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -2238,6 +2403,19 @@ def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> No | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p._fp16_shard.record_stream(current_stream) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
free_storage_(p._fp16_shard) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@torch.no_grad() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _free_zero2_param_shard(self, params: Optional[List[Parameter]] = None) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Free storage for FP16 shards for a list of params.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if params is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
params = self.params | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
current_stream = torch.cuda.current_stream() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for p in params: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if p._zero2_fp16_shard is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# _fp16_shard is allocated in "fp32_to_fp16" stream, so we can't | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# free it until the work in the current stream completes. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
p._zero2_fp16_shard.record_stream(current_stream) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
free_storage_(p._zero2_fp16_shard) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py Lines 1428 to 1430 in 0b77de4
_zero2_fp16_shard is consumed in the "all_gather" stream, and this _free_zero2_param_shard() is called from that "all_gather" stream as well:fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py Lines 2121 to 2136 in 0b77de4
In that case, I do agree this fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py Lines 2381 to 2391 in 0b77de4
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Assert we are in the given state.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Since assert can be turned off and this error checking | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like the only difference compared to
_rebuild_full_params()
is no SSD offload, no CPU offload, and usingp._zero2_fp16_shard
,self.zero2_process_group
, andself._free_zero2_param_shard()
-- this makes sense to me.