From 184898351629206815446f503fda3ada9c530789 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 18 Sep 2024 07:58:33 +0000 Subject: [PATCH 1/7] fix gradient accumulation for z2+offload --- deepspeed/runtime/zero/stage_1_and_2.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 83cf996ca019..df7a2f83e3bc 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -39,6 +39,7 @@ OPTIMIZER_GRADIENTS_TIMER = 'optimizer_gradients' OPTIMIZER_STEP_TIMER = 'optimizer_step' OPTIMIZER_TIMERS = [OPTIMIZER_ALLGATHER_TIMER, OPTIMIZER_GRADIENTS_TIMER, OPTIMIZER_STEP_TIMER] +INITIAL_MICRO_STEP_ID = -1 def input(msg): @@ -224,7 +225,7 @@ def __init__(self, self.gradient_predivide_factor = gradient_predivide_factor self.postscale_gradients = postscale_gradients self.gradient_accumulation_steps = gradient_accumulation_steps - self.micro_step_id = 0 + self.micro_step_id = INITIAL_MICRO_STEP_ID self.ignore_unused_parameters = ignore_unused_parameters self.round_robin_gradients = round_robin_gradients @@ -1231,9 +1232,7 @@ def copy_gradients_to_cpu(): if self.micro_step_id > 0: accumulate_gradients() - - # at the boundary we will send 32bit directly - if not self.is_gradient_accumulation_boundary: + else: copy_gradients_to_cpu() def set_norm_for_param_grad(self, param): @@ -1824,7 +1823,7 @@ def step(self, closure=None): """ Not supporting closure. """ - self.micro_step_id = -1 + self.micro_step_id = INITIAL_MICRO_STEP_ID see_memory_usage(f"In step before checking overflow") From 84ca923c8e72a894ee5471fc60b0902a7299e315 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 18 Sep 2024 20:03:36 +0000 Subject: [PATCH 2/7] improve consistency of zero_grad --- deepspeed/runtime/base_optimizer.py | 10 +++++- deepspeed/runtime/engine.py | 37 ++++++++++++---------- deepspeed/runtime/utils.py | 31 ++++++++++++++++++- deepspeed/runtime/zero/stage3.py | 41 ++++++++++++------------- deepspeed/runtime/zero/stage_1_and_2.py | 21 +++++-------- 5 files changed, 87 insertions(+), 53 deletions(-) diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py index 6cfd66f1cc38..d494008bb315 100644 --- a/deepspeed/runtime/base_optimizer.py +++ b/deepspeed/runtime/base_optimizer.py @@ -5,10 +5,11 @@ import os import torch +from typing import Callable, Iterable from deepspeed.utils import logger from deepspeed.utils.tensor_fragment import map_to_flat_opt_states -from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank +from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank, zero_grad_params class DeepSpeedOptimizer(object): @@ -61,3 +62,10 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec if key == 'params': continue param_group[key] = value + + def _do_zero_grad(self, + params: Iterable[torch.nn.Parameter], + set_to_none_fn: Callable[[torch.Tensor], None], + set_to_none: bool = True, + force: bool = False) -> None: + zero_grad_params(params, set_to_none_fn, self.is_gradient_accumulation_boundary, set_to_none, force) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 61e6da2663cf..643f44d47065 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -74,7 +74,7 @@ from deepspeed.utils.debug import debug_extract_module_and_param_names, debug_clear_module_and_param_names from deepspeed.monitor.monitor import MonitorMaster from 
deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop -from deepspeed.runtime.utils import clip_grad_norm_ +from deepspeed.runtime.utils import clip_grad_norm_, zero_grad_params from deepspeed.runtime.eigenvalue import Eigenvalue from deepspeed.runtime.data_pipeline.constants import DATA_SAMPLING, \ DATA_ROUTING, DATA_SAMPLING_ENABLED, CURRICULUM_LEARNING, \ @@ -2097,12 +2097,27 @@ def set_gradient_accumulation_boundary(self, is_boundary): self._is_gradient_accumulation_boundary = is_boundary self.optimizer.is_gradient_accumulation_boundary = is_boundary - def zero_grad(self): + def zero_grad(self, set_to_none: bool = True, force: bool = False) -> None: """ Zero parameter grads. """ - for param_name, param in self.module.named_parameters(): - param.grad = None + # zero grad in basic optimizer could be unreliable and may not exhibit + # the behavior that we want + if self.bfloat16_enabled(): + # TODO: Temporary until bf16_optimizer and zero_optimizer are integrated + if self.zero_optimization() and hasattr(self.optimizer, "zero_grad"): + self.optimizer.zero_grad(set_to_none, force) + else: + pass + elif self.zero_optimization() or self.fp16_enabled() or self.amp_enabled(): + self.optimizer.zero_grad(set_to_none, force) + else: + + def set_to_none_fn(param): + param.grad = None + + zero_grad_params(self.module.parameters(), set_to_none_fn, self.is_gradient_accumulation_boundary(), + set_to_none, force) def clip_fp32_gradients(self): clip_grad_norm_(parameters=self.module.parameters(), max_norm=self.gradient_clipping(), mpu=self.mpu) @@ -2132,18 +2147,8 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}): self.eigenvalue_enabled(), block_eigenvalue, ) - # zero grad in basic optimizer could be unreliable and may not exhibit - # the behavior that we want - if self.bfloat16_enabled(): - # TODO: Temporary until bf16_optimizer and zero_optimizer are integrated - if self.zero_optimization() and hasattr(self.optimizer, "zero_grad"): - self.optimizer.zero_grad() - else: - pass - elif self.zero_optimization() or self.fp16_enabled() or self.amp_enabled(): - self.optimizer.zero_grad() - else: - self.zero_grad() + + self.zero_grad(force=True) report_progress = self.global_rank == 0 if self.global_rank else True diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 2c01c3475a70..91f9a0bdbd4f 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -8,7 +8,7 @@ Helper functions and classes from multiple sources. """ -from collections.abc import Iterable +from collections.abc import Iterable, Callable from deepspeed.moe.utils import is_moe_param import os import psutil @@ -1065,3 +1065,32 @@ def to_tensor(v): total_norm = -1 return total_norm + + +warn_zero_grad_shown = False + + +def warn_zero_grad() -> None: + global warn_zero_grad_shown + if not warn_zero_grad_shown: + msg = "zero_grad() was called but gradients are not cleared because " \ + "the current iteration is not a gradient accumulation boundary. " \ + "If you want to clear gradients, please set force=True." 
+ logger.info(msg) + warn_zero_grad_shown = True + return + + +def zero_grad_params(params: Iterable[torch.nn.Parameter], set_to_none_fn: Callable[[torch.Tensor], None], + is_gradient_accumulation_boundary: bool, set_to_none: bool, force: bool) -> None: + if not is_gradient_accumulation_boundary and not force: + warn_zero_grad() + return + + for param in params: + if set_to_none: + set_to_none_fn(param) + else: + if param.grad is not None: + param.grad.detach_() + param.grad.zero_() diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 796957a4c6e5..44a137424aa4 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -103,7 +103,7 @@ def unwrap_model_for_generation(model): return -INITIAL_MICRO_STEP_ID = -1 +INITIAL_MICRO_STEP_ID = 0 class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer): @@ -293,7 +293,8 @@ def __init__( self.gradient_predivide_factor = gradient_predivide_factor self.postscale_gradients = postscale_gradients self.gradient_accumulation_steps = gradient_accumulation_steps - self.micro_step_id = 0 + self.micro_step_id = INITIAL_MICRO_STEP_ID + self.force_overwrite_grads = False self.reduce_bucket_size = int(reduce_bucket_size) if self.all2all_process_group is not None: @@ -1463,7 +1464,7 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L # move or accumulate gradient partition to target buffer grad_buffer = self.__param_id_to_grad_partition[param.ds_id].narrow(0, 0, grad_partition.numel()) buffers.append(grad_buffer) - if self.micro_step_id == 0: # don't accumulate + if self.micro_step_id == 0 or self.force_overwrite_grads: # don't accumulate grad_buffer.copy_(grad_partition, non_blocking=True) # ensure grad buffer is a CUDA buffer to speed up the next few # operations and so it can be used asynchronously @@ -1504,6 +1505,8 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L param.grad.record_stream(get_accelerator().current_stream()) param.grad = None + self.force_overwrite_grads = False + if self.offload_optimizer and self.swap_optimizer: for i in offload_fp32_gradients.keys(): self.optimizer_swapper.swap_out_gradients(parameter=self.fp32_partitioned_groups_flat[i], @@ -1719,24 +1722,18 @@ def get_partition_info(self, tensor_list, partition_size, partition_id): return params_in_partition, params_not_in_partition, first_offset @instrument_w_nvtx - def zero_grad(self, set_to_none=True): - """ - Zero FP16 parameter grads. - """ - self.micro_step_id = 0 - - # FP32 grad should never exist. - # For speed, set model fp16 grad to None by default - for group in self.fp16_groups: - for p in group: - if set_to_none: - if p.grad is not None and get_accelerator().on_accelerator(p.grad): - p.grad.record_stream(get_accelerator().current_stream()) - p.grad = None - else: - if p.grad is not None: - p.grad.detach_() - p.grad.zero_() + def zero_grad(self, set_to_none=True, force=False): + + def set_grad_to_none(p): + if p.grad is not None and get_accelerator().on_accelerator(p.grad): + p.grad.record_stream(get_accelerator().current_stream()) + p.grad = None + + params = [p for group in self.fp16_groups for p in group] + self._do_zero_grad(params, set_grad_to_none, set_to_none, force) + + # Flag to indicate that the reduced gradients should be copied to the buffer, not accumulated + self.force_overwrite_grads = True def _model_parallel_all_reduce(self, tensor, op): """ Perform all reduce within model parallel group, if any. 
@@ -1856,7 +1853,7 @@ def reset_cpu_buffers(self): self.norm_for_param_grads = {} def _pre_step(self): - self.micro_step_id = 0 + self.micro_step_id = INITIAL_MICRO_STEP_ID print_rank_0(f"Inside Step function") see_memory_usage(f"In step before checking overflow", force=False) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 83cf996ca019..c4998c8b9ad6 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -1633,22 +1633,17 @@ def get_partition_info(self, tensor_list, partition_size, partition_id): return params_in_partition, params_not_in_partition, first_offset - def zero_grad(self, set_to_none=True): + def zero_grad(self, set_to_none=True, force=False): """ Zero FP16 parameter grads. """ - # FP32 grad should never exist. - # For speed, set model fp16 grad to None by default - # zero all pointers to grad tensors - for group in self.bit16_groups: - for p in group: - if set_to_none: - p.grad = None # epilogue and in step - p.grad_accum = None - else: - if p.grad is not None: - p.grad.detach_() - p.grad.zero_() + + def set_grad_to_none(p): + p.grad = None # epilogue and in step + p.grad_accum = None + + params = [p for group in self.bit16_groups for p in group] + self._do_zero_grad(params, set_grad_to_none, set_to_none, force) def _model_parallel_all_reduce(self, tensor, op): """ Perform all reduce within model parallel group, if any. From fea4811e20cd407a3f1b5796224f81e71aaec702 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 18 Sep 2024 20:49:30 +0000 Subject: [PATCH 3/7] fix zero_grad for z1/2 optimizer --- deepspeed/runtime/base_optimizer.py | 5 +++++ deepspeed/runtime/zero/stage3.py | 8 ++------ deepspeed/runtime/zero/stage_1_and_2.py | 5 ++++- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py index d494008bb315..56e2576fbbfa 100644 --- a/deepspeed/runtime/base_optimizer.py +++ b/deepspeed/runtime/base_optimizer.py @@ -18,6 +18,9 @@ class DeepSpeedOptimizer(object): class ZeROOptimizer(DeepSpeedOptimizer): + def __init__(self): + self.force_overwrite_grads = False + def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None: checkpoint_dir = os.path.join(checkpoint_dir, "zero") optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt") @@ -69,3 +72,5 @@ def _do_zero_grad(self, set_to_none: bool = True, force: bool = False) -> None: zero_grad_params(params, set_to_none_fn, self.is_gradient_accumulation_boundary, set_to_none, force) + # Flag to indicate that the reduced gradients should be copied to the buffer, not accumulated + self.force_overwrite_grads = True diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 44a137424aa4..4b78d7d7c533 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -156,6 +156,8 @@ def __init__( zero_quantized_weights=False, zero_quantized_nontrainable_weights=False, ): + super().__init__() + see_memory_usage("Stage 3 initialize beginning", force=True) print_rank_0(f"initialized {__class__.__name__} with args: {locals()}", force=False) @@ -294,7 +296,6 @@ def __init__( self.postscale_gradients = postscale_gradients self.gradient_accumulation_steps = gradient_accumulation_steps self.micro_step_id = INITIAL_MICRO_STEP_ID - self.force_overwrite_grads = False self.reduce_bucket_size = int(reduce_bucket_size) if self.all2all_process_group is not 
None: @@ -1505,8 +1506,6 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L param.grad.record_stream(get_accelerator().current_stream()) param.grad = None - self.force_overwrite_grads = False - if self.offload_optimizer and self.swap_optimizer: for i in offload_fp32_gradients.keys(): self.optimizer_swapper.swap_out_gradients(parameter=self.fp32_partitioned_groups_flat[i], @@ -1732,9 +1731,6 @@ def set_grad_to_none(p): params = [p for group in self.fp16_groups for p in group] self._do_zero_grad(params, set_grad_to_none, set_to_none, force) - # Flag to indicate that the reduced gradients should be copied to the buffer, not accumulated - self.force_overwrite_grads = True - def _model_parallel_all_reduce(self, tensor, op): """ Perform all reduce within model parallel group, if any. """ diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index cfd0f03cbae5..dc45782da064 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -138,6 +138,8 @@ def __init__(self, fp16_master_weights_and_gradients=False, elastic_checkpoint=False): + super().__init__() + if offload_optimizer_config is not None and offload_optimizer_config.device != OffloadDeviceEnum.none: self.cpu_offload = True self.cpu_offload_pin_memory = offload_optimizer_config.pin_memory @@ -1230,7 +1232,7 @@ def copy_gradients_to_cpu(): if param_id not in self.accumulated_grads_in_cpu: self.accumulated_grads_in_cpu[param_id] = buffer_to_accumulate_to_in_cpu() - if self.micro_step_id > 0: + if self.micro_step_id > 0 or self.force_overwrite_grads: accumulate_gradients() else: copy_gradients_to_cpu() @@ -1416,6 +1418,7 @@ def reduce_ipg_grads(self): self.params_in_ipg_bucket = [] self.ipg_bucket_has_moe_params = False self.elements_in_ipg_bucket = 0 + self.force_overwrite_grads = False ##################################################################### def reduce_ready_partitions_and_remove_grads(self, param, i): From b17019601ee91f8f3b05b7896c64b26b93dd010f Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 18 Sep 2024 20:54:53 +0000 Subject: [PATCH 4/7] remove type hint for python 3.7 --- deepspeed/runtime/base_optimizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py index 56e2576fbbfa..965e2523e66a 100644 --- a/deepspeed/runtime/base_optimizer.py +++ b/deepspeed/runtime/base_optimizer.py @@ -5,7 +5,7 @@ import os import torch -from typing import Callable, Iterable +from typing import Iterable from deepspeed.utils import logger from deepspeed.utils.tensor_fragment import map_to_flat_opt_states @@ -68,7 +68,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec def _do_zero_grad(self, params: Iterable[torch.nn.Parameter], - set_to_none_fn: Callable[[torch.Tensor], None], + set_to_none_fn, set_to_none: bool = True, force: bool = False) -> None: zero_grad_params(params, set_to_none_fn, self.is_gradient_accumulation_boundary, set_to_none, force) From 86da91da573dd9b0beb2cbd83569121594fdb0d5 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 18 Sep 2024 20:59:38 +0000 Subject: [PATCH 5/7] remove type hint --- deepspeed/runtime/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 91f9a0bdbd4f..b173f08072b5 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -8,7 +8,7 @@ 
Helper functions and classes from multiple sources. """ -from collections.abc import Iterable, Callable +from collections.abc import Iterable from deepspeed.moe.utils import is_moe_param import os import psutil @@ -1081,8 +1081,8 @@ def warn_zero_grad() -> None: return -def zero_grad_params(params: Iterable[torch.nn.Parameter], set_to_none_fn: Callable[[torch.Tensor], None], - is_gradient_accumulation_boundary: bool, set_to_none: bool, force: bool) -> None: +def zero_grad_params(params: Iterable[torch.nn.Parameter], set_to_none_fn, is_gradient_accumulation_boundary: bool, + set_to_none: bool, force: bool) -> None: if not is_gradient_accumulation_boundary and not force: warn_zero_grad() return From 4a59a2b4c7ec367a1e2722646af07a1eee7cc9a3 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 18 Sep 2024 21:04:47 +0000 Subject: [PATCH 6/7] remove type hint --- deepspeed/runtime/base_optimizer.py | 7 +------ deepspeed/runtime/utils.py | 4 ++-- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py index 965e2523e66a..8a14373b2e72 100644 --- a/deepspeed/runtime/base_optimizer.py +++ b/deepspeed/runtime/base_optimizer.py @@ -5,7 +5,6 @@ import os import torch -from typing import Iterable from deepspeed.utils import logger from deepspeed.utils.tensor_fragment import map_to_flat_opt_states @@ -66,11 +65,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec continue param_group[key] = value - def _do_zero_grad(self, - params: Iterable[torch.nn.Parameter], - set_to_none_fn, - set_to_none: bool = True, - force: bool = False) -> None: + def _do_zero_grad(self, params, set_to_none_fn, set_to_none: bool = True, force: bool = False) -> None: zero_grad_params(params, set_to_none_fn, self.is_gradient_accumulation_boundary, set_to_none, force) # Flag to indicate that the reduced gradients should be copied to the buffer, not accumulated self.force_overwrite_grads = True diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index b173f08072b5..7efad2b88e81 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -1081,8 +1081,8 @@ def warn_zero_grad() -> None: return -def zero_grad_params(params: Iterable[torch.nn.Parameter], set_to_none_fn, is_gradient_accumulation_boundary: bool, - set_to_none: bool, force: bool) -> None: +def zero_grad_params(params, set_to_none_fn, is_gradient_accumulation_boundary: bool, set_to_none: bool, + force: bool) -> None: if not is_gradient_accumulation_boundary and not force: warn_zero_grad() return From 958dfc16b7c6c5d2fcf68f13316dcb92b20c911c Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 18 Sep 2024 21:55:55 +0000 Subject: [PATCH 7/7] fix for non-zero optimizer --- deepspeed/runtime/engine.py | 11 +++-------- deepspeed/runtime/utils.py | 9 +++++++++ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 643f44d47065..bf76fed17ba8 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -74,7 +74,7 @@ from deepspeed.utils.debug import debug_extract_module_and_param_names, debug_clear_module_and_param_names from deepspeed.monitor.monitor import MonitorMaster from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop -from deepspeed.runtime.utils import clip_grad_norm_, zero_grad_params +from deepspeed.runtime.utils import clip_grad_norm_, zero_grad_with_grad_acc_boundary_check from deepspeed.runtime.eigenvalue 
import Eigenvalue from deepspeed.runtime.data_pipeline.constants import DATA_SAMPLING, \ DATA_ROUTING, DATA_SAMPLING_ENABLED, CURRICULUM_LEARNING, \ @@ -2109,15 +2109,10 @@ def zero_grad(self, set_to_none: bool = True, force: bool = False) -> None: self.optimizer.zero_grad(set_to_none, force) else: pass - elif self.zero_optimization() or self.fp16_enabled() or self.amp_enabled(): + elif self.zero_optimization(): self.optimizer.zero_grad(set_to_none, force) else: - - def set_to_none_fn(param): - param.grad = None - - zero_grad_params(self.module.parameters(), set_to_none_fn, self.is_gradient_accumulation_boundary(), - set_to_none, force) + zero_grad_with_grad_acc_boundary_check(self.optimizer, self.is_gradient_accumulation_boundary(), force) def clip_fp32_gradients(self): clip_grad_norm_(parameters=self.module.parameters(), max_norm=self.gradient_clipping(), mpu=self.mpu) diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 7efad2b88e81..22b938670bad 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -1094,3 +1094,12 @@ def zero_grad_params(params, set_to_none_fn, is_gradient_accumulation_boundary: if param.grad is not None: param.grad.detach_() param.grad.zero_() + + +def zero_grad_with_grad_acc_boundary_check(optimizer: torch.optim.Optimizer, is_gradient_accumulation_boundary: bool, + force: bool) -> None: + if not is_gradient_accumulation_boundary and not force: + warn_zero_grad() + return + + optimizer.zero_grad()
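
The core of patches 2–7 is the boundary-guarded `zero_grad_params()` helper added to `deepspeed/runtime/utils.py`: gradients are only cleared at a gradient accumulation boundary unless `force=True` is passed, and `_take_model_step` now clears them explicitly via `self.zero_grad(force=True)`. Below is a minimal, self-contained sketch of that helper's behavior. The function body mirrors the patched `zero_grad_params`; the toy `Linear` model, the `lambda` passed as `set_to_none_fn`, and the `print` standing in for `logger.info` are illustrative assumptions, not DeepSpeed code.

```python
import torch


def zero_grad_params(params, set_to_none_fn, is_gradient_accumulation_boundary,
                     set_to_none=True, force=False):
    # Skip clearing mid-accumulation so partially accumulated grads survive.
    if not is_gradient_accumulation_boundary and not force:
        print("zero_grad() skipped: not at a gradient accumulation boundary "
              "(pass force=True to clear anyway)")
        return
    for p in params:
        if set_to_none:
            set_to_none_fn(p)
        elif p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()


model = torch.nn.Linear(4, 2)
model(torch.randn(3, 4)).sum().backward()

# Mid-accumulation: the call is a guarded no-op and grads are kept.
zero_grad_params(model.parameters(), lambda p: setattr(p, "grad", None),
                 is_gradient_accumulation_boundary=False)
assert all(p.grad is not None for p in model.parameters())

# At the boundary (or with force=True): grads are released.
zero_grad_params(model.parameters(), lambda p: setattr(p, "grad", None),
                 is_gradient_accumulation_boundary=True)
assert all(p.grad is None for p in model.parameters())
```

Each ZeRO stage supplies its own `set_to_none_fn` (stage 3 additionally records the CUDA stream before dropping the grad), so the boundary/force decision lives in one place while the release mechanics stay per-optimizer.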
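
Patch 1's change is that `micro_step_id` now starts at `INITIAL_MICRO_STEP_ID = -1` in `__init__` as well as after `step()`, so the counter is incremented at the start of every backward and the first micro step of *every* accumulation window sees id 0 and overwrites the offloaded buffer, while later micro steps accumulate; relatedly, `force_overwrite_grads` (patches 2–3) makes the reduction after a forced `zero_grad()` overwrite rather than accumulate stale values. The toy class below illustrates only the counter lifecycle under those assumptions; it is not the actual ZeRO-2 offload path, which shuttles gradient partitions between GPU buffers and pinned CPU memory.

```python
INITIAL_MICRO_STEP_ID = -1


class ToyAccumulator:
    """Toy stand-in for the accumulation counter; not DeepSpeed code."""

    def __init__(self, gradient_accumulation_steps):
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.micro_step_id = INITIAL_MICRO_STEP_ID
        self.cpu_grad = 0.0          # stands in for the offloaded CPU buffer

    def backward(self, grad):
        self.micro_step_id += 1      # incremented at the start of every backward
        if self.micro_step_id == 0:  # first micro step of a window: overwrite
            self.cpu_grad = grad
        else:                        # later micro steps: accumulate
            self.cpu_grad += grad

    def step(self):
        print(f"optimizer step with accumulated grad = {self.cpu_grad}")
        self.micro_step_id = INITIAL_MICRO_STEP_ID   # same reset as __init__


acc = ToyAccumulator(gradient_accumulation_steps=4)
for _ in range(2):                   # two accumulation windows
    for _ in range(acc.gradient_accumulation_steps):
        acc.backward(1.0)
    acc.step()                       # prints 4.0 for both windows
```

With the old initial value of 0, the very first window behaved differently from every later one (its first backward already looked like a continuation step), which is the inconsistency the ZeRO-2 + offload accumulation fix removes.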