diff --git a/.github/workflows/build-image.yaml b/.github/workflows/build-image.yaml index b27a884a..90d7b7e6 100644 --- a/.github/workflows/build-image.yaml +++ b/.github/workflows/build-image.yaml @@ -31,6 +31,7 @@ jobs: uses: actions/checkout@v2 with: submodules: true + path: buildimage - name: Free disk space run: | mkdir -p /tmp/emptydir @@ -54,7 +55,7 @@ jobs: if [[ "${{ github.event_name }}" == "release" ]]; then TAGS=$(sed "s/main/${GITHUB_REF##*/}/g" <<< ${TAGS}) fi - DOCKERFILE=dockerfile/${{ matrix.name }}.dockerfile + DOCKERFILE=buildimage/dockerfile/${{ matrix.name }}.dockerfile CACHE_FROM="type=registry,ref=$(cut -d, -f1 <<< ${TAGS})" CACHE_TO="" @@ -87,7 +88,7 @@ jobs: uses: docker/build-push-action@v2 with: platforms: linux/amd64 - context: . + context: ./buildimage file: ${{ steps.metadata.outputs.dockerfile }} push: ${{ github.event_name != 'pull_request' }} tags: ${{ steps.metadata.outputs.tags }} diff --git a/.github/workflows/unit-tests.yaml b/.github/workflows/unit-tests.yaml index 5680f381..fae3a106 100644 --- a/.github/workflows/unit-tests.yaml +++ b/.github/workflows/unit-tests.yaml @@ -57,10 +57,6 @@ jobs: export LD_PRELOAD="/usr/local/lib/libmsamp_dist.so:/usr/local/lib/libnccl.so:${LD_PRELOAD}" cd ${{ matrix.dir }}/ python3 setup.py test - - name: Clean repository - if: always() - run: | - rm -rf ${{ matrix.dir }}/ # - name: Report coverage results # run: | # bash <(curl -s https://codecov.io/bash) diff --git a/docs/getting-started/run-msamp.md b/docs/getting-started/run-msamp.md index 5465aaf7..e1c1f73f 100644 --- a/docs/getting-started/run-msamp.md +++ b/docs/getting-started/run-msamp.md @@ -52,4 +52,8 @@ deepspeed cifar10_deepspeed.py --deepspeed --deepspeed_config ds_config_zero_msa deepspeed cifar10_deepspeed_te.py --deepspeed --deepspeed_config ds_config_zero_te_msamp.json ``` +:::note Note +If you get "ModuleNotFoundError: No module named 'timm'" error when running this example, you need to install timm using `pip install timm`. +::: + For more comprehensive examples, please go to [MS-AMP-Examples](https://github.com/Azure/MS-AMP-Examples). diff --git a/docs/introduction.md b/docs/introduction.md index d502db63..0a60864c 100644 --- a/docs/introduction.md +++ b/docs/introduction.md @@ -36,7 +36,7 @@ Here are the results for GPT-3, Swin-T, DeiT-S and RoBERTa-B. ### System performance -MS-AMP preserves high-precision's accuracy while using only a fraction of the memory footprint on a range of tasks, including GPT-3, DeiT and Swin Transformer. For example, when training GPT-175B on NVIDIA H100 platform, MS-AMP achieves a notable 42% reduction in real memory usage compared with BF16 mixed-precision approach and reduces training time by 17% compared with Transformer Engine. For small models, MS-AMP with O2 mode can achieve 44% memory saving for Swin-1.0B and 26% memory saving for ViT-1.2B, comparing with FP16 AMP. +MS-AMP preserves high-precision's accuracy while using only a fraction of the memory footprint on a range of tasks, including GPT-3, DeiT and Swin Transformer. For example, when training GPT-175B on NVIDIA H100 platform, MS-AMP achieves a notable 39% reduction in real memory usage compared with BF16 mixed-precision approach and reduces training time by 37% compared with Transformer Engine. For small models, MS-AMP with O2 mode can achieve 44% memory saving for Swin-1.0B and 26% memory saving for ViT-1.2B, comparing with FP16 AMP. 
Here are the resuls for GPT-3: diff --git a/msamp/megatron/optimizer/distrib_optimizer.py b/msamp/megatron/optimizer/distrib_optimizer.py index 3b8f8148..5707aef3 100644 --- a/msamp/megatron/optimizer/distrib_optimizer.py +++ b/msamp/megatron/optimizer/distrib_optimizer.py @@ -351,28 +351,131 @@ def get_model_parallel_group(self): return None def state_dict(self): - """The state dict must contain the fp32-from-float16 and fp16-from-fp8 shards.""" + """Return the state dict of this optimizer. + + The state dict contains all non-DP-rank-dependent (i.e., non-parameter- + related) optimizer variables. The returned state dict can be stored in + the standard model/RNG checkpoint file. The parameter and dependent + optimizer state (e.g., exp_avg, exp_avg_sq) are stored in a separate + checkpoint file by calling 'save_parameter_state()'. + """ + # MS-AMP: Store step in param group. + if hasattr(self.optimizer, 'exp_avg_dtype'): + for group in self.optimizer.param_groups: + step = 0 + for param in group['params']: + if param.grad is not None: + step = self.optimizer.state[param]['step'] + break + group['step'] = step + state_dict = {} - state_dict['optimizer'] = self.optimizer.state_dict() + + # Optimizer state (do not store parameter state here). + state_dict['optimizer'] = {k: v for k, v in self.optimizer.state_dict().items() if k != 'state'} + + for param_group in state_dict['optimizer']['param_groups']: + del param_group['params'] + + # Grad scaler state. if self.grad_scaler: state_dict['grad_scaler'] = self.grad_scaler.state_dict() - # shared master weight - state_dict['shard_fp32_from_float16_groups'] = \ - self.shard_fp32_from_float16_groups - state_dict['shard_hp_from_fp8_groups'] = \ - self.shard_hp_from_fp8_groups + return state_dict def load_state_dict(self, state_dict): - """Load the state dict.""" - optimizer_key = 'optimizer' - if optimizer_key not in state_dict: - optimizer_key = 'optimizer_state_dict' - print_rank_0('***WARNING*** loading optimizer from ' - 'an old checkpoint ...') - # convert optimizer states - ckpt_state_dict = state_dict[optimizer_key] - self.optimizer.load_state_dict(ckpt_state_dict) + """Load the state dict. + + As detailed in state_dict(), the state dict contains all non- + parameter-related variables. This method is notably longer than + state_dict(), because the Torch optimizers state has yet to be + allocated at this point, and so we must do a cross referencing between + the optimizers state (and the ordering it expects for parameter state) + and this DP rank's shards. The optimizer at this point does not contain + any tensor dimension information, so we must get these dimensions from + the DP shards mapped during DistributedOptimizer.__init__(). + + The tensor parameter state is loaded via load_parameter_state(), and + so this method also must populate the loaded state dict with dummy + tensor data (i.e., via torch.empty() below). This will be overwritten + during load_parameter_state(). + + ** Note: Torch optimizer's state structure. ** + The Torch optimizer stores its state in two levels. The top level is a + list of groups, where each group contains a list of integer indexes + (corresponding to parameters) that index into a master parameter list + that is shared by all groups. As such, three values are necessary for + maintaining this ordering: + + - group_index : The group to which a parameter belongs. + - group_order : The index of a parameter within its group. + - state_order : The index of a parameter within the shared parameter + list. 
+ """ + + # Get the Torch optimizer's state dict. + # - This 'inner' optimizer at this point is unallocated, and only + # contains an integer odering of parameters within each group, and + # the ordering of parameters within its flattened parameter state + # list. + inner_state_dict = self.optimizer.state_dict() + state_dict_param_groups = [ + { + **group, + 'params': list(inner_state_dict['param_groups'][idx]['params']), + } for idx, group in enumerate(state_dict['optimizer']['param_groups']) + ] + + # Allocate 'dummy' data for optimizer state (i.e., torch.empty() below) + # - Real data is overwritten during load_parameter_state(). + state_dict_state = [] + for gbuf_range_maps in self.model_gbuf_ranges: + for gbuf_range_map in gbuf_range_maps.values(): + for model_param, param_range_map in gbuf_range_map['param_map'].items(): + + # Get parameter ordering information (see method docstring + # for details). + group_index, group_order = \ + self.model_param_group_index_map[model_param] + state_order = inner_state_dict['param_groups'][group_index]['params'][group_order] + + # Allocate dummy tensors. + numel = len(param_range_map['gbuf_world']) + # MS-AMP: Allocate dummy tensors for exp_avg and exp_avg_sq and cast to ScalingTensor + if hasattr(self.optimizer, 'exp_avg_dtype') and self.optimizer.exp_avg_dtype != torch.float32: + step = state_dict['optimizer']['param_groups'][group_index]['step'] + exp_avg_qtype = Dtypes.dtype_to_qtype[self.optimizer.exp_avg_dtype] + exp_avg_sq_qtype = Dtypes.dtype_to_qtype[self.optimizer.exp_avg_sq_dtype] + exp_avg = torch.empty((numel, ), dtype=torch.float32, + device=torch.cuda.current_device()).cast(exp_avg_qtype) + exp_avg_sq = torch.empty((numel, ), dtype=torch.float32, + device=torch.cuda.current_device()).cast(exp_avg_sq_qtype) + state_dict_state.append( + (state_order, { + 'exp_avg': exp_avg, + 'exp_avg_sq': exp_avg_sq, + 'step': step + }) + ) + else: + init_shard = lambda: torch.empty( # noqa: E731 + (numel, ), dtype=torch.float32, device=torch.cuda.current_device() + ) + + state_dict_state.append((state_order, { + 'exp_avg': init_shard(), + 'exp_avg_sq': init_shard(), + })) + + # Sort by state order (see method docstring for details). + state_dict_state.sort(key=lambda s: s[0]) + state_dict_state = {s[0]: s[1] for s in state_dict_state} + + # Optimizer. + self.optimizer.load_state_dict({ + 'state': state_dict_state, + 'param_groups': state_dict_param_groups, + }) # Grad scaler. if 'grad_scaler' not in state_dict: @@ -389,23 +492,187 @@ def load_state_dict(self, state_dict): 'Skipping loading grad scaler ...' ) - # Copy data for the main params. - for current_group, saved_group in zip( - self.shard_fp32_from_float16_groups, state_dict['shard_fp32_from_float16_groups'] - ): - for current_param, saved_param in zip(current_group, saved_group): - current_param.data.copy_(saved_param.data) + def save_parameter_state(self, filename): + """Save parameter state (i.e., parameter & optimizer tensors). 
- for current_group, saved_group in zip(self.shard_hp_from_fp8_groups, state_dict['shard_hp_from_fp8_groups']): - for current_param, saved_param in zip(current_group, saved_group): - if current_param.data.qtype == saved_param.data.qtype: - current_param.data.copy_(saved_param.data) - else: - # when the data type of optimizer's master weight and checkpoint's is different - current_param.data.copy_( - saved_param.data.to(current_param.data.device).cast(current_param.data.qtype) + This method performs three steps: + - For each DP rank, copy param & optimizer shards to contiguous CPU + buffers. (e.g., one buffer each for main_param, exp_avg, and + exp_avg_sq). + - Gather contiguous buffers on DP rank 0 and concatenate to world + buffers. + - Save world buffers to disk (i.e., distrib_opt.pt). + """ + + # Data parallelism variables. + data_parallel_world_size = mpu.get_data_parallel_world_size() + data_parallel_rank = mpu.get_data_parallel_rank() + data_parallel_group_gloo = mpu.get_data_parallel_group_gloo() + data_parallel_global_ranks = list(mpu._DATA_PARALLEL_GLOBAL_RANKS) + + # Collect param states. + state = {} + for model_idx, gbuf_range_maps in enumerate(self.model_gbuf_ranges): + + # Iterate grad buffers (by data type). + dtype_state = {} + + # MS-AMP: We use FP8 + FP32 now, so we don't need this assert. + # assert len(gbuf_range_maps) == 1, "single dtype supported, for now." + + for dtype, gbuf_range_map in gbuf_range_maps.items(): + + # Compute local DP contiguous shard's size. + model = self.models[model_idx] + gbuf_world_numel = model._grad_buffers[dtype].numel_padded + gbuf_local_numel = int(gbuf_world_numel / data_parallel_world_size) + local_shards = { + key: torch.empty((gbuf_local_numel, ), dtype=torch.float32, device='cpu') + for key in ('param', 'exp_avg', 'exp_avg_sq') + } + + # Build contiguous DP rank shards (for param + optim states). + for model_param, param_range_map in gbuf_range_map['param_map'].items(): + + # Main param & optimizer states. + group_index, group_order = self.model_param_group_index_map[model_param] + main_param = self.optimizer.param_groups[group_index]['params'][group_order] + optim_state = self.optimizer.state[main_param] + + tensors = { + 'param': main_param, + **optim_state, + } + + # Copy states into contiguous shard. + gbuf_local_start = param_range_map['gbuf_local'].start + gbuf_local_end = param_range_map['gbuf_local'].end + for key in local_shards: + # MS-AMP: Convert to float32 for ScalingTensor. + if isinstance(tensors[key], ScalingTensor): + local_shards[key][gbuf_local_start:gbuf_local_end] \ + .data.copy_(tensors[key].detach().float().view(-1).cpu()) + else: + local_shards[key][gbuf_local_start:gbuf_local_end] \ + .data.copy_(tensors[key].detach().cpu()) + + # Gather contiguous shards on DP rank 0. + world_tensors = {} + for key, send_tensor in local_shards.items(): + + # Gather tensor list. + if data_parallel_rank == 0: + recv_tensors = [ + torch.empty((gbuf_local_numel, ), dtype=torch.float32, device='cpu') + for _ in range(data_parallel_world_size) + ] + else: + recv_tensors = None + + # Gather. + torch.distributed.gather( + send_tensor, + recv_tensors, + data_parallel_global_ranks[0], + data_parallel_group_gloo, + ) + + # Concatenate. + if data_parallel_rank == 0: + world_tensors[key] = torch.cat(recv_tensors) + + # Collect world state. + dtype_state[dtype] = world_tensors + state[model_idx] = dtype_state + + # Save param state. 
+ if data_parallel_rank == 0: + torch.save(state, filename) + + def load_parameter_state(self, filename): + """Load parameter state (i.e., parameter & optimizer tensors). + + This method performs the reverse of save_parameter_state(): + - Load world buffers from disk (i.e., distrib_opt.pt). + - Scatter contiguous buffers from DP rank 0 to each DP rank (each DP + rank receives its relevant subset of the world buffers). + - For each DP rank, copy param & optimizer shards from contiguous CPU + buffers. (e.g., one buffer each for main_param, exp_avg, and + exp_avg_sq). + """ + + # Data parallelism variables. + data_parallel_world_size = mpu.get_data_parallel_world_size() + data_parallel_rank = mpu.get_data_parallel_rank() + data_parallel_group_gloo = mpu.get_data_parallel_group_gloo() + data_parallel_global_ranks = list(mpu._DATA_PARALLEL_GLOBAL_RANKS) + + # Load on DP rank 0. + if data_parallel_rank == 0: + loaded_state = torch.load(filename) + + # Scatter tensors to all DP ranks. + for model_idx, gbuf_range_maps in enumerate(self.model_gbuf_ranges): + for dtype, gbuf_range_map in gbuf_range_maps.items(): + + # Compute local DP contiguous shard's size. + model = self.models[model_idx] + gbuf_world_numel = model._grad_buffers[dtype].numel_padded + gbuf_local_numel = int(gbuf_world_numel / data_parallel_world_size) + + # Contiguous local shards (received from DP rank 0). + local_shards = { + key: torch.empty((gbuf_local_numel, ), dtype=torch.float32, device='cpu') + for key in ('param', 'exp_avg', 'exp_avg_sq') + } + + # Scatter local shards from DP rank 0. + for key, recv_tensor in local_shards.items(): + + # Scatter tensor list. + if data_parallel_rank == 0: + world_tensor = loaded_state[model_idx][dtype][key] + gbuf_start_idxs = \ + list(range(0, gbuf_world_numel, gbuf_local_numel)) + send_tensors = [world_tensor[i:(i + gbuf_local_numel)] for i in gbuf_start_idxs] + else: + send_tensors = None + + # Scatter. + torch.distributed.scatter( + recv_tensor, + send_tensors, + data_parallel_global_ranks[0], + data_parallel_group_gloo, ) + # Copy local contiguous shards to param/optim shards. + for model_param, param_range_map in gbuf_range_map['param_map'].items(): + + # Main param & optimizer states. + group_index, group_order = \ + self.model_param_group_index_map[model_param] + main_param = self.optimizer.param_groups[group_index]['params'][group_order] + optim_state = self.optimizer.state[main_param] + + tensors = { + 'param': main_param, + **optim_state, + } + + # Copy states into contiguous shard. + gbuf_local_start = param_range_map['gbuf_local'].start + gbuf_local_end = param_range_map['gbuf_local'].end + for key in local_shards: + if isinstance(tensors[key], ScalingTensor): + tensors[key].copy_( + local_shards[key][gbuf_local_start:gbuf_local_end].view_as( + tensors[key].value + ).cuda().cast(tensors[key].meta.qtype) + ) + else: + tensors[key].data.copy_(local_shards[key][gbuf_local_start:gbuf_local_end]) + def zero_grad(self, set_to_none=True): """Zero grads. 
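The distributed-optimizer changes above split checkpointing in two: `state_dict()`/`load_state_dict()` now carry only the non-parameter state (param groups, step, grad scaler), while `save_parameter_state()`/`load_parameter_state()` move the DP-sharded parameter and Adam-moment tensors through contiguous per-rank CPU buffers that are gathered to, and scattered from, data-parallel rank 0. The sketch below shows that gather/scatter pattern in isolation; it is not the MS-AMP implementation, the function names and flat-shard layout are placeholders, and it uses the default process group where the real code uses the gloo data-parallel group and global ranks. The real code additionally flattens `ScalingTensor` states to FP32 before the gather and casts them back after the scatter.

```python
import torch
import torch.distributed as dist


def save_flat_shard(flat_shard: torch.Tensor, filename: str) -> None:
    """Gather each rank's contiguous CPU shard onto rank 0 and save one world buffer."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Only the destination rank allocates receive buffers.
    recv = [torch.empty_like(flat_shard) for _ in range(world_size)] if rank == 0 else None
    dist.gather(flat_shard, recv, dst=0)

    if rank == 0:
        torch.save(torch.cat(recv), filename)  # world buffer laid out in rank order


def load_flat_shard(numel_per_rank: int, filename: str) -> torch.Tensor:
    """Reverse of save_flat_shard: rank 0 reloads the world buffer and scatters the slices."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    if rank == 0:
        world = torch.load(filename)
        send = [world[i * numel_per_rank:(i + 1) * numel_per_rank].contiguous()
                for i in range(world_size)]
    else:
        send = None

    shard = torch.empty((numel_per_rank, ), dtype=torch.float32, device='cpu')
    dist.scatter(shard, send, src=0)
    return shard
```

Each of `param`, `exp_avg`, and `exp_avg_sq` gets one such flat buffer per grad-buffer dtype; because the staging buffers live on the CPU, the actual optimizer runs these collectives over a gloo group.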
diff --git a/msamp/optim/adamw.py b/msamp/optim/adamw.py index d71b41c7..8cf68e10 100644 --- a/msamp/optim/adamw.py +++ b/msamp/optim/adamw.py @@ -185,8 +185,12 @@ def adamw_fn( # noqa: C901 for i, param in enumerate(params): grad = grads[i].float() if not maximize else -grads[i].float() - exp_avgs[i].meta.scale = _new_exp_avg_factors[i] if self.tensor_scale else 1.0 - exp_avg_sqs[i].meta.scale = _new_exp_avg_sq_factors[i] if self.tensor_scale else 1.0 + exp_avgs[i].meta.scale = _new_exp_avg_factors[i] if self.tensor_scale else torch.ones((), device='cuda') + exp_avgs[i].meta.scale_inv.fill_(1.0 / exp_avgs[i].meta.scale) + exp_avg_sqs[i].meta.scale = _new_exp_avg_sq_factors[i] if self.tensor_scale else torch.ones( + (), device='cuda' + ) + exp_avg_sqs[i].meta.scale_inv.fill_(1.0 / exp_avg_sqs[i].meta.scale) # update state msamp_adamw.adamw_fp8_stage2_compute( grad, exp_avgs[i].value, _exp_avg_inv_factors[i], exp_avgs[i].meta.scale, beta1, diff --git a/msamp/te/extension.py b/msamp/te/extension.py index a831240f..742fbafa 100644 --- a/msamp/te/extension.py +++ b/msamp/te/extension.py @@ -9,6 +9,7 @@ from msamp.common.dtype import Dtypes from msamp.common.tensor import ScalingTensor +from msamp.nn import ScalingParameter class TeExtensionOverrider: @@ -24,6 +25,7 @@ class TeExtensionOverrider: original_fused_cast_transpose = tex.fused_cast_transpose original_cast_to_fp8 = te.cpp_extensions.cast_to_fp8 original_fp8_cast_transpose_fused = te.cpp_extensions.fp8_cast_transpose_fused + original_cast_if_needed = te.utils.cast_if_needed @staticmethod @torch.no_grad() @@ -119,6 +121,24 @@ def cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype, out=None): return TeExtensionOverrider.original_cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype) return TeExtensionOverrider.original_cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype, out) + @staticmethod + def cast_if_needed(tensor, dtype): + """Cast tensor to dtype. + + Args: + tensor (torch.Tensor or ScalingParameter): Input tensor. + dtype (torch.dtype): Output dtype. + + Returns: + torch.Tensor: Output tensor. 
+ """ + with torch.enable_grad(): + if isinstance(tensor, ScalingParameter): + new_tensor = tensor.to(dtype) + new_tensor.requires_grad = tensor.requires_grad + return new_tensor + return TeExtensionOverrider.original_cast_if_needed(tensor, dtype) + @staticmethod def override(): """Override transformer engine extension functions.""" @@ -127,5 +147,9 @@ def override(): te.module.linear.cast_to_fp8 = TeExtensionOverrider.cast_to_fp8 te.cpp_extensions.fp8_cast_transpose_fused = TeExtensionOverrider.fp8_cast_transpose_fused + te.module.layernorm_linear.cast_if_needed = TeExtensionOverrider.cast_if_needed + te.module.linear.cast_if_needed = TeExtensionOverrider.cast_if_needed + te.module.layernorm_mlp.cast_if_needed = TeExtensionOverrider.cast_if_needed + TeExtensionOverrider.override() diff --git a/msamp/te/modules.py b/msamp/te/modules.py index 20c41057..9c8fb7ab 100644 --- a/msamp/te/modules.py +++ b/msamp/te/modules.py @@ -229,8 +229,6 @@ def _override_classes(cls): te.attention.Linear = MSAMPLinear te.attention.LayerNormLinear = MSAMPLayerNormLinear - te.transformer.Linear = MSAMPLinear - te.transformer.LayerNormLinear = MSAMPLayerNormLinear te.transformer.LayerNormMLP = MSAMPLayerNormMLP @staticmethod diff --git a/tests/te/test_te_replacer.py b/tests/te/test_te_replacer.py index 4071de1b..9c56e411 100644 --- a/tests/te/test_te_replacer.py +++ b/tests/te/test_te_replacer.py @@ -5,6 +5,7 @@ import os import unittest +from contextlib import nullcontext import torch import torch.distributed as dist @@ -65,17 +66,16 @@ def _check_model(model): scaling_params = [p for p in model.parameters() if isinstance(p, ScalingParameter)] assert len(scaling_params) == 4 - is_fp8_available = te.fp8.check_fp8_support() - if is_fp8_available: - # Do a forward pass to make sure the model is working. - fp8_format = Format.HYBRID - fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max') - x = torch.rand(self.sequence_length, self.batch_size, self.hidden_size).cuda().to(dtype=self.dtype) - - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - y = model(x, attention_mask=None) - assert y.shape == (self.sequence_length, self.batch_size, self.hidden_size) - y.sum().backward() + is_fp8_available, _ = te.fp8.check_fp8_support() + # Do a forward pass to make sure the model is working. 
+ fp8_format = Format.HYBRID + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max') + x = torch.rand(self.sequence_length, self.batch_size, self.hidden_size).cuda().to(dtype=self.dtype) + + with te.fp8_autocast(enabled=is_fp8_available, fp8_recipe=fp8_recipe) if is_fp8_available else nullcontext(): + y = model(x, attention_mask=None) + assert y.shape == (self.sequence_length, self.batch_size, self.hidden_size) + y.sum().backward() @decorator.cuda_test def test_te_with_deepspeed(self): @@ -100,12 +100,13 @@ def test_te_with_deepspeed(self): fp8_format = Format.HYBRID fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max') - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + is_fp8_available, _ = te.fp8.check_fp8_support() + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe) if is_fp8_available else nullcontext(): input = torch.randn(self.sequence_length, self.batch_size, self.hidden_size).cuda().to(dtype=self.dtype) output = model(input, attention_mask=None) - loss = output.sum() - model.backward(loss) - model.step() + loss = output.sum() + model.backward(loss) + model.step() class TeReplacerDistributedTestCast(MultiProcessTestCase): @@ -163,9 +164,10 @@ def test_fp8_ddp_with_te(self): x = torch.randn(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype) fp8_format = Format.HYBRID fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max') - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + is_fp8_available, _ = te.fp8.check_fp8_support() + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe) if is_fp8_available else nullcontext(): output = model(x, attention_mask=None, is_first_microbatch=True) - output.sum().backward() - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + output.sum().backward() + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe) if is_fp8_available else nullcontext(): output = model(x, attention_mask=None, is_first_microbatch=False) - output.sum().backward() + output.sum().backward()
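The test updates above stop assuming FP8 hardware: `te.fp8.check_fp8_support()` gates `fp8_autocast`, and a `nullcontext()` stands in when FP8 is unavailable so the same test body still runs. Below is a minimal sketch of that guard, assuming Transformer Engine is installed; `model`, `x`, and `forward_backward` are placeholders, while the recipe arguments mirror the ones used in the tests.

```python
from contextlib import nullcontext

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format


def forward_backward(model, x):
    """Run one step under fp8_autocast when FP8 is supported, else as a plain forward pass."""
    is_fp8_available, _ = te.fp8.check_fp8_support()
    recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo='max')

    # Fall back to a no-op context manager on GPUs without FP8 support.
    ctx = te.fp8_autocast(enabled=True, fp8_recipe=recipe) if is_fp8_available else nullcontext()
    with ctx:
        y = model(x)
    y.sum().backward()  # backward runs outside the autocast region
    return y
```

Constructing the recipe unconditionally and gating only the context keeps the FP8 and non-FP8 code paths identical apart from the autocast, which is the point of the `nullcontext()` fallback in the diff.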