From 41545ce94538c2269e8cf71cf53a4cc3280b8685 Mon Sep 17 00:00:00 2001
From: KimmiShi
Date: Fri, 24 May 2024 21:40:55 +0800
Subject: [PATCH] Fix(mha,linear): fix norm_head and mha inference (#234)

Co-authored-by: shidongxing
---
 internlm/core/parallel/comm/__init__.py |  0
 internlm/core/parallel/comm/tensor.py   |  8 +++---
 internlm/model/modules/linear.py        |  2 +-
 internlm/model/modules/mha.py           | 38 +++++++++++++++----------
 internlm/utils/utils.py                 |  7 +++--
 5 files changed, 32 insertions(+), 23 deletions(-)
 create mode 100644 internlm/core/parallel/comm/__init__.py

diff --git a/internlm/core/parallel/comm/__init__.py b/internlm/core/parallel/comm/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/internlm/core/parallel/comm/tensor.py b/internlm/core/parallel/comm/tensor.py
index 47086ad9..ca8c1900 100644
--- a/internlm/core/parallel/comm/tensor.py
+++ b/internlm/core/parallel/comm/tensor.py
@@ -233,7 +233,7 @@ def grad_output_hook(
         if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1:
             return grad_output, DUMMY_HANDLE_CONST

-        return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1)
+        return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST

     def output_hook(
         self, output: torch.Tensor, async_op: bool = False  # pylint: disable=W0613
@@ -244,7 +244,7 @@ def output_hook(
         if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1:
             return output, DUMMY_HANDLE_CONST

-        return _gather(output, parallel_mode=self._parallel_mode, dim=-1)
+        return _gather(output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST


 class HeadSequenceParallelCommunicator(SequenceParallelCommunicator):
@@ -274,7 +274,7 @@ def grad_output_hook(
         if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1:
             return grad_output, DUMMY_HANDLE_CONST

-        return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1)
+        return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST

     # rewrite ouput communication hook
     def output_hook(
@@ -286,7 +286,7 @@ def output_hook(
         if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1:
             return output, DUMMY_HANDLE_CONST

-        return _gather(output, parallel_mode=self._parallel_mode, dim=-1)
+        return _gather(output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST


 class MoESequenceParallelCommunicator:
diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py
index 44353970..1ea65c6f 100644
--- a/internlm/model/modules/linear.py
+++ b/internlm/model/modules/linear.py
@@ -492,7 +492,7 @@ def forward(self, input):  # pylint: disable=W0622

         return fused_dense_func(
             input,
-            self.weight,
+            weight,
             communicator=self._communicator,
             module=self,
             bias=self.bias,
diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py
index e0669726..3c08fb0a 100644
--- a/internlm/model/modules/mha.py
+++ b/internlm/model/modules/mha.py
@@ -184,9 +184,13 @@ def _convert_unpacked_qkv_to_packed(
         max_seqlen_q = attention_mask.shape[-1]
         max_seqlen_k = attention_mask.shape[-1]

-        q_packed = q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1])
-        kv_packed = kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1)).view(
-            -1, kv.shape[-3], kv.shape[-2], kv.shape[-1]
+        q_packed = (
+            q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]).unsqueeze(0)
+        )
+        kv_packed = (
+            kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1))
+            .view(-1, kv.shape[-3], kv.shape[-2], kv.shape[-1])
+            .unsqueeze(0)
         )

         return q_packed, kv_packed, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
@@ -194,8 +198,8 @@ def _convert_unpacked_qkv_to_packed(
     def _inference(self, x, inference_params, **kwargs):  # pylint: disable=W0613
         assert inference_params is not None, "inference_params is required for inference"
         assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
-        attention_mask = inference_params.get("attention_mask", None)
-        sequence_len_offset = inference_params.get("sequence_len_offset", 0)
+        attention_mask = inference_params.attention_mask
+        sequence_len_offset = inference_params.sequence_len_offset
         batch_size = x.shape[0]

         # wqkv, output: q, kv
@@ -230,21 +234,21 @@ def _inference(self, x, inference_params, **kwargs):  # pylint: disable=W0613
                 q = self.rotary_emb(
                     q, offsets=sequence_len_offset, cache_type="query", interleaved=self.interleaved
                 )
-                k = kv[:, :, 0].squeueze(2)
+                k = kv[:, :, 0].squeeze(2)
                 self.rotary_emb(
                     k, offsets=0, cache_type="key", interleaved=self.interleaved, in_place=True
                 )  # in-place is important
             else:
                 if self.rotary_emb_dim > 0:
                     q = self.rotary_emb(q, offsets=0, cache_type="query", interleaved=self.interleaved)
-                    k = kv[:, :, 0].squeueze(2)
+                    k = kv[:, :, 0].squeeze(2)
                     self.rotary_emb(
                         k, offsets=0, cache_type="key", interleaved=self.interleaved, in_place=True
                     )  # in-place is important

         else:
             assert self.rotary_emb_dim > 0, "You should use rotary_emb."
-            k, v = kv[:, :, 0].squeueze(2), kv[:, :, 1].squeueze(2)
+            k, v = kv[:, :, 0].squeeze(2), kv[:, :, 1].squeeze(2)

             if attention_mask is None:
                 q = self.rotary_emb(q, offsets=sequence_len_offset, cache_type="query", interleaved=self.interleaved)
@@ -474,9 +478,13 @@ def _convert_unpacked_qkv_to_packed(
         max_seqlen_q = attention_mask.shape[-1]
         max_seqlen_k = attention_mask.shape[-1]

-        q_packed = q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1])
-        kv_packed = kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1)).view(
-            -1, kv.shape[-3], kv.shape[-2], kv.shape[-1]
+        q_packed = (
+            q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]).unsqueeze(0)
+        )
+        kv_packed = (
+            kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1))
+            .view(-1, kv.shape[-3], kv.shape[-2], kv.shape[-1])
+            .unsqueeze(0)
         )

         return q_packed, kv_packed, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
@@ -484,9 +492,9 @@ def _convert_unpacked_qkv_to_packed(
     def _inference(self, x, inference_params, **kwargs):  # pylint: disable=W0613
         assert inference_params is not None, "inference_params is required for inference"
         assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
-        attention_mask = inference_params.get("attention_mask", None)
-        sequence_len_offset = inference_params.get("sequence_len_offset", 0)
-        window_size = inference_params.get("window_size", None)
+        attention_mask = inference_params.attention_mask
+        sequence_len_offset = inference_params.sequence_len_offset
+        window_size = inference_params.window_size

         batch_size = x.shape[0]

@@ -494,7 +502,7 @@ def _inference(self, x, inference_params, **kwargs):  # pylint: disable=W0613
         if self.enable_qkv_fusion:
             qkv = self.wqkv(x)
             qkv = rearrange(qkv, "b s (h gs d) -> b s h gs d", gs=self.q_per_kv + 2, d=self.head_dim)
-            q, k, v = (qkv[..., : self.q_per_kv, :], qkv[..., -2, :].unsqueeze(-2), qkv[..., -1, :].unsqueeze(-2))
+            q, k, v = (qkv[..., : self.q_per_kv, :], qkv[..., -2, :], qkv[..., -1, :])
             q = rearrange(q, "b s h gs d -> b s (h gs) d")
         else:
             q, k, v = self.wq(x), self.wk(x), self.wv(x)
diff --git a/internlm/utils/utils.py b/internlm/utils/utils.py
index e8e76e70..34766b3b 100644
--- a/internlm/utils/utils.py
+++ b/internlm/utils/utils.py
@@ -62,14 +62,15 @@ def __kv_checker(num_args: int):
         # kv: [batch, seqlen, 3, n_head, headdim]
         return len(args[2].shape) == 5

-    def __cu_seqlens_checker(num_args: int, check_idx: int):
+    def __cu_seqlens_checker(args, check_idx: int):
+        num_args = len(args)
         if num_args < (check_idx + 1):
             if check_idx == 2:
                 return "cu_seqlens" in kwargs and kwargs["cu_seqlens"] is not None
             else:
                 return "cu_seqlens_q" in kwargs and kwargs["cu_seqlens_q"] is not None
         else:
-            return isinstance(num_args[check_idx], torch.Tensor)
+            return isinstance(args[check_idx], torch.Tensor)

     if __qkv_checker(len(args)):
         # qkv packed, and we should check cu_seqlens with index 2
@@ -81,7 +82,7 @@ def __cu_seqlens_checker(num_args: int, check_idx: int):
         # qkv splited, and we should check cu_seqlens with index 4
         qkv_pack_type = int(QKVPackType.QKVSPLITED)

-    with_cu_seqlens = __cu_seqlens_checker(len(args), qkv_pack_type)
+    with_cu_seqlens = __cu_seqlens_checker(args, qkv_pack_type)

     return str(qkv_pack_type), str(with_cu_seqlens)
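For context on the tensor.py hunks: both branches of the rewritten grad_output/output hooks now return a (tensor, handle) pair, so callers can always unpack two values. A minimal sketch of that contract, using simplified stand-in names rather than the actual InternLM communicator API:

import torch

DUMMY_HANDLE_CONST = None  # stand-in for "no asynchronous work to wait on"


def _gather_stub(t: torch.Tensor) -> torch.Tensor:
    # placeholder for the real _gather(); here it just returns the tensor
    return t


def output_hook(output: torch.Tensor, retain_out_sharded: bool, world_size: int):
    # the early-return branch already produced a (tensor, handle) pair
    if retain_out_sharded or world_size <= 1:
        return output, DUMMY_HANDLE_CONST
    # the fix appends DUMMY_HANDLE_CONST on this branch as well, so that
    # `out, handle = hook(...)` works no matter which branch is taken
    return _gather_stub(output), DUMMY_HANDLE_CONST


out, handle = output_hook(torch.randn(2, 4), retain_out_sharded=False, world_size=1)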
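The linear.py hunk passes the locally computed weight, rather than self.weight, into fused_dense_func; when a norm_head-style rescaling is applied inside forward(), the rescaled copy must actually be the one used. A sketch of that failure mode, where the normalization shown is only an assumed stand-in for whatever rescaling the layer performs:

import torch
import torch.nn.functional as F


def head_forward(x: torch.Tensor, weight_param: torch.Tensor, norm_head: bool) -> torch.Tensor:
    weight = weight_param
    if norm_head:
        weight = F.normalize(weight_param, dim=-1)  # assumed rescaling, for illustration only
    # the original bug: passing weight_param (i.e. self.weight) below would
    # silently discard the rescaled copy computed above
    return F.linear(x, weight)


x = torch.randn(2, 16)
w = torch.randn(32, 16)
print(head_forward(x, w, norm_head=True).shape)  # torch.Size([2, 32])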
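The unsqueeze(0) added to _convert_unpacked_qkv_to_packed restores a leading batch dimension of 1 after masked_select flattens every kept token in the batch into one packed sequence. A toy-shaped run of the same chain of calls (the sizes are arbitrary and only chosen to make the shapes visible):

import torch

batch_size, seqlen, n_head, head_dim = 2, 4, 3, 8
q = torch.randn(batch_size, seqlen, n_head, head_dim)
kv = torch.randn(batch_size, seqlen, 2, n_head, head_dim)
# True marks tokens that are kept: 3 valid tokens in sample 0, 2 in sample 1
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)

q_packed = (
    q.masked_select(attention_mask.view(batch_size, -1, 1, 1))
    .view(-1, q.shape[-2], q.shape[-1])
    .unsqueeze(0)  # the added call: restore a leading batch dim of 1
)
kv_packed = (
    kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1))
    .view(-1, kv.shape[-3], kv.shape[-2], kv.shape[-1])
    .unsqueeze(0)
)
print(q_packed.shape)   # torch.Size([1, 5, 3, 8]): 5 kept tokens across the batch
print(kv_packed.shape)  # torch.Size([1, 5, 2, 3, 8])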
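Finally, the utils.py hunk: the old __cu_seqlens_checker received len(args) but then tried to index that integer, which could only raise a TypeError. A self-contained approximation of the corrected logic, written as a free function with explicit args/kwargs parameters purely for illustration (in the source it is a closure over the caller's args and kwargs):

import torch


def cu_seqlens_checker(args, check_idx, kwargs):
    if len(args) < check_idx + 1:
        key = "cu_seqlens" if check_idx == 2 else "cu_seqlens_q"
        return key in kwargs and kwargs[key] is not None
    return isinstance(args[check_idx], torch.Tensor)


print(cu_seqlens_checker((None, None, torch.tensor([0, 4, 8])), 2, {}))  # True: positional tensor
print(cu_seqlens_checker((), 2, {"cu_seqlens": None}))                   # False: kwarg is None
print(cu_seqlens_checker((), 4, {"cu_seqlens_q": torch.tensor([0, 4])})) # True: kwarg provided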