diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py
index 6cfd66f1cc38..8a14373b2e72 100644
--- a/deepspeed/runtime/base_optimizer.py
+++ b/deepspeed/runtime/base_optimizer.py
@@ -8,7 +8,7 @@
 
 from deepspeed.utils import logger
 from deepspeed.utils.tensor_fragment import map_to_flat_opt_states
-from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank
+from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank, zero_grad_params
 
 
 class DeepSpeedOptimizer(object):
@@ -17,6 +17,9 @@ class DeepSpeedOptimizer(object):
 
 class ZeROOptimizer(DeepSpeedOptimizer):
 
+    def __init__(self):
+        self.force_overwrite_grads = False
+
     def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None:
         checkpoint_dir = os.path.join(checkpoint_dir, "zero")
         optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
@@ -61,3 +64,8 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec
                 if key == 'params':
                     continue
                 param_group[key] = value
+
+    def _do_zero_grad(self, params, set_to_none_fn, set_to_none: bool = True, force: bool = False) -> None:
+        zero_grad_params(params, set_to_none_fn, self.is_gradient_accumulation_boundary, set_to_none, force)
+        # Flag to indicate that the reduced gradients should be copied to the buffer, not accumulated
+        self.force_overwrite_grads = True
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 61e6da2663cf..bf76fed17ba8 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -74,7 +74,7 @@
 from deepspeed.utils.debug import debug_extract_module_and_param_names, debug_clear_module_and_param_names
 from deepspeed.monitor.monitor import MonitorMaster
 from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
-from deepspeed.runtime.utils import clip_grad_norm_
+from deepspeed.runtime.utils import clip_grad_norm_, zero_grad_with_grad_acc_boundary_check
 from deepspeed.runtime.eigenvalue import Eigenvalue
 from deepspeed.runtime.data_pipeline.constants import DATA_SAMPLING, \
     DATA_ROUTING, DATA_SAMPLING_ENABLED, CURRICULUM_LEARNING, \
@@ -2097,12 +2097,22 @@ def set_gradient_accumulation_boundary(self, is_boundary):
         self._is_gradient_accumulation_boundary = is_boundary
         self.optimizer.is_gradient_accumulation_boundary = is_boundary
 
-    def zero_grad(self):
+    def zero_grad(self, set_to_none: bool = True, force: bool = False) -> None:
         """
         Zero parameter grads.
         """
-        for param_name, param in self.module.named_parameters():
-            param.grad = None
+        # zero grad in basic optimizer could be unreliable and may not exhibit
+        # the behavior that we want
+        if self.bfloat16_enabled():
+            # TODO: Temporary until bf16_optimizer and zero_optimizer are integrated
+            if self.zero_optimization() and hasattr(self.optimizer, "zero_grad"):
+                self.optimizer.zero_grad(set_to_none, force)
+            else:
+                pass
+        elif self.zero_optimization():
+            self.optimizer.zero_grad(set_to_none, force)
+        else:
+            zero_grad_with_grad_acc_boundary_check(self.optimizer, self.is_gradient_accumulation_boundary(), force)
 
     def clip_fp32_gradients(self):
         clip_grad_norm_(parameters=self.module.parameters(), max_norm=self.gradient_clipping(), mpu=self.mpu)
@@ -2132,18 +2142,8 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
                 self.eigenvalue_enabled(),
                 block_eigenvalue,
             )
-        # zero grad in basic optimizer could be unreliable and may not exhibit
-        # the behavior that we want
-        if self.bfloat16_enabled():
-            # TODO: Temporary until bf16_optimizer and zero_optimizer are integrated
-            if self.zero_optimization() and hasattr(self.optimizer, "zero_grad"):
-                self.optimizer.zero_grad()
-            else:
-                pass
-        elif self.zero_optimization() or self.fp16_enabled() or self.amp_enabled():
-            self.optimizer.zero_grad()
-        else:
-            self.zero_grad()
+
+        self.zero_grad(force=True)
 
         report_progress = self.global_rank == 0 if self.global_rank else True
 
diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py
index 2c01c3475a70..22b938670bad 100755
--- a/deepspeed/runtime/utils.py
+++ b/deepspeed/runtime/utils.py
@@ -1065,3 +1065,41 @@ def to_tensor(v):
         total_norm = -1
 
     return total_norm
+
+
+warn_zero_grad_shown = False
+
+
+def warn_zero_grad() -> None:
+    global warn_zero_grad_shown
+    if not warn_zero_grad_shown:
+        msg = "zero_grad() was called but gradients are not cleared because " \
+              "the current iteration is not a gradient accumulation boundary. " \
+              "If you want to clear gradients, please set force=True."
+        logger.info(msg)
+        warn_zero_grad_shown = True
+    return
+
+
+def zero_grad_params(params, set_to_none_fn, is_gradient_accumulation_boundary: bool, set_to_none: bool,
+                     force: bool) -> None:
+    if not is_gradient_accumulation_boundary and not force:
+        warn_zero_grad()
+        return
+
+    for param in params:
+        if set_to_none:
+            set_to_none_fn(param)
+        else:
+            if param.grad is not None:
+                param.grad.detach_()
+                param.grad.zero_()
+
+
+def zero_grad_with_grad_acc_boundary_check(optimizer: torch.optim.Optimizer, is_gradient_accumulation_boundary: bool,
+                                           force: bool) -> None:
+    if not is_gradient_accumulation_boundary and not force:
+        warn_zero_grad()
+        return
+
+    optimizer.zero_grad()
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 796957a4c6e5..4b78d7d7c533 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -103,7 +103,7 @@ def unwrap_model_for_generation(model):
         return
 
 
-INITIAL_MICRO_STEP_ID = -1
+INITIAL_MICRO_STEP_ID = 0
 
 
 class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
@@ -156,6 +156,8 @@ def __init__(
         zero_quantized_weights=False,
         zero_quantized_nontrainable_weights=False,
     ):
+        super().__init__()
+
         see_memory_usage("Stage 3 initialize beginning", force=True)
 
         print_rank_0(f"initialized {__class__.__name__} with args: {locals()}", force=False)
@@ -293,7 +295,7 @@ def __init__(
         self.gradient_predivide_factor = gradient_predivide_factor
         self.postscale_gradients = postscale_gradients
         self.gradient_accumulation_steps = gradient_accumulation_steps
-        self.micro_step_id = 0
+        self.micro_step_id = INITIAL_MICRO_STEP_ID
         self.reduce_bucket_size = int(reduce_bucket_size)
 
         if self.all2all_process_group is not None:
@@ -1463,7 +1465,7 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L
             # move or accumulate gradient partition to target buffer
             grad_buffer = self.__param_id_to_grad_partition[param.ds_id].narrow(0, 0, grad_partition.numel())
             buffers.append(grad_buffer)
-            if self.micro_step_id == 0:  # don't accumulate
+            if self.micro_step_id == 0 or self.force_overwrite_grads:  # don't accumulate
                 grad_buffer.copy_(grad_partition, non_blocking=True)
                 # ensure grad buffer is a CUDA buffer to speed up the next few
                 # operations and so it can be used asynchronously
@@ -1719,24 +1721,15 @@ def get_partition_info(self, tensor_list, partition_size, partition_id):
         return params_in_partition, params_not_in_partition, first_offset
 
     @instrument_w_nvtx
-    def zero_grad(self, set_to_none=True):
-        """
-        Zero FP16 parameter grads.
-        """
-        self.micro_step_id = 0
-
-        # FP32 grad should never exist.
-        # For speed, set model fp16 grad to None by default
-        for group in self.fp16_groups:
-            for p in group:
-                if set_to_none:
-                    if p.grad is not None and get_accelerator().on_accelerator(p.grad):
-                        p.grad.record_stream(get_accelerator().current_stream())
-                    p.grad = None
-                else:
-                    if p.grad is not None:
-                        p.grad.detach_()
-                        p.grad.zero_()
+    def zero_grad(self, set_to_none=True, force=False):
+
+        def set_grad_to_none(p):
+            if p.grad is not None and get_accelerator().on_accelerator(p.grad):
+                p.grad.record_stream(get_accelerator().current_stream())
+            p.grad = None
+
+        params = [p for group in self.fp16_groups for p in group]
+        self._do_zero_grad(params, set_grad_to_none, set_to_none, force)
 
     def _model_parallel_all_reduce(self, tensor, op):
         """ Perform all reduce within model parallel group, if any.
@@ -1856,7 +1849,7 @@ def reset_cpu_buffers(self):
         self.norm_for_param_grads = {}
 
     def _pre_step(self):
-        self.micro_step_id = 0
+        self.micro_step_id = INITIAL_MICRO_STEP_ID
         print_rank_0(f"Inside Step function")
         see_memory_usage(f"In step before checking overflow", force=False)
 
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index df7a2f83e3bc..dc45782da064 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -138,6 +138,8 @@ def __init__(self,
                  fp16_master_weights_and_gradients=False,
                  elastic_checkpoint=False):
 
+        super().__init__()
+
         if offload_optimizer_config is not None and offload_optimizer_config.device != OffloadDeviceEnum.none:
             self.cpu_offload = True
             self.cpu_offload_pin_memory = offload_optimizer_config.pin_memory
@@ -1230,7 +1232,7 @@ def copy_gradients_to_cpu():
         if param_id not in self.accumulated_grads_in_cpu:
             self.accumulated_grads_in_cpu[param_id] = buffer_to_accumulate_to_in_cpu()
 
-        if self.micro_step_id > 0:
+        if self.micro_step_id > 0 or self.force_overwrite_grads:
             accumulate_gradients()
         else:
             copy_gradients_to_cpu()
@@ -1416,6 +1418,7 @@ def reduce_ipg_grads(self):
         self.params_in_ipg_bucket = []
         self.ipg_bucket_has_moe_params = False
         self.elements_in_ipg_bucket = 0
+        self.force_overwrite_grads = False
         #####################################################################
 
     def reduce_ready_partitions_and_remove_grads(self, param, i):
@@ -1632,22 +1635,17 @@ def get_partition_info(self, tensor_list, partition_size, partition_id):
 
         return params_in_partition, params_not_in_partition, first_offset
 
-    def zero_grad(self, set_to_none=True):
+    def zero_grad(self, set_to_none=True, force=False):
         """
         Zero FP16 parameter grads.
         """
-        # FP32 grad should never exist.
-        # For speed, set model fp16 grad to None by default
-        # zero all pointers to grad tensors
-        for group in self.bit16_groups:
-            for p in group:
-                if set_to_none:
-                    p.grad = None  # epilogue and in step
-                    p.grad_accum = None
-                else:
-                    if p.grad is not None:
-                        p.grad.detach_()
-                        p.grad.zero_()
+
+        def set_grad_to_none(p):
+            p.grad = None  # epilogue and in step
+            p.grad_accum = None
+
+        params = [p for group in self.bit16_groups for p in group]
+        self._do_zero_grad(params, set_grad_to_none, set_to_none, force)
 
     def _model_parallel_all_reduce(self, tensor, op):
         """ Perform all reduce within model parallel group, if any.
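Below is a minimal usage sketch of the engine-level `zero_grad(set_to_none, force)` API introduced by this patch, assuming ZeRO with gradient accumulation enabled. It is not part of the patch: `net`, `ds_config`, and `loader` are hypothetical placeholders for a model, DeepSpeed config, and data loader.

```python
# Hypothetical usage sketch: `net`, `ds_config`, and `loader` are placeholders.
import deepspeed

engine, _, _, _ = deepspeed.initialize(model=net, model_parameters=net.parameters(), config=ds_config)

for batch in loader:
    loss = engine(batch)  # assumes the wrapped module maps a batch to a scalar loss
    engine.backward(loss)

    if not engine.is_gradient_accumulation_boundary():
        # Inside an accumulation window, a plain zero_grad() does not clear
        # gradients: zero_grad_params() only logs a one-time notice.
        engine.zero_grad()
        # Dropping the partial gradients requires an explicit opt-in; this also
        # raises force_overwrite_grads so the next reduction copies into the
        # grad buffers instead of accumulating onto cleared values.
        # engine.zero_grad(force=True)

    # At the boundary, step() applies the update and then calls
    # zero_grad(force=True) itself (see _take_model_step above).
    engine.step()
```

The `force_overwrite_grads` flag set by `_do_zero_grad` is what lets the Stage 1/2 and Stage 3 reduction paths overwrite their partitioned gradient buffers on the next reduce instead of accumulating into values the caller just discarded; Stage 1/2 clears the flag again in `reduce_ipg_grads`.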