diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py
index 57225243..d92de405 100644
--- a/src/nanotron/config/models_config.py
+++ b/src/nanotron/config/models_config.py
@@ -48,6 +48,9 @@ class LlamaConfig:
     rms_norm_eps: float = 1e-6
     rope_scaling: Optional[dict] = None
     rope_theta: float = 10000.0
+    rope_interleaved: bool = (
+        False  # The previous behaviour was interleaved RoPE (True); loading Llama3 checkpoints requires False.
+    )
     tie_word_embeddings: bool = False
     use_cache: bool = True
     vocab_size: int = 32000
diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 49ea86e6..4c563bc8 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """PyTorch LLaMa model."""
 
-from typing import Dict, Optional, Union, List
+from typing import Dict, List, Optional, Union
 
 import torch
 from torch import nn
@@ -74,9 +74,10 @@ def init_rotary_embeddings(self):
         self.freqs_cis = self.freqs_cis.to(torch.float)
         assert self.freqs_cis.dtype == torch.float
         freqs = 1.0 / (
-            self.theta
-            ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cuda")[: (self.dim // 2)] / self.dim)
-        )
+            self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cpu")[: (self.dim // 2)] / self.dim)
+        ).to(
+            "cuda"
+        )  # must be computed on CPU, otherwise the results differ from Transformers.
         t = torch.arange(self.end, device="cuda")
         freqs = torch.outer(t, freqs).float()
         complex_freqs = torch.polar(torch.ones_like(freqs), freqs)
@@ -118,6 +119,78 @@ def forward(
         return x_out.type(dtype)
 
 
+# Copied from transformers. Non-interleaved version of RoPE. Will be refactored later.
+class LlamaRotaryEmbedding(nn.Module):
+    def __init__(self, dim: int, end: int, theta: float = 500000.0):
+        super().__init__()
+        self.dim = dim
+        self.end = end
+        self.theta = theta
+        self.init_rotary_embeddings()
+
+    def init_rotary_embeddings(self):
+        inv_freq = 1.0 / (
+            self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cpu") / self.dim)
+        )  # important to compute on CPU
+        self.register_buffer(
+            "inv_freq", torch.empty(self.dim // 2, dtype=torch.float, device="cuda"), persistent=False
+        )
+        self.inv_freq = self.inv_freq.to(
+            torch.float
+        )  # make it float32 before copy to avoid precision loss during copy_
+        self.inv_freq.copy_(inv_freq)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        x: torch.Tensor,  # [batch_size, seq_length, num_heads, d_qk]
+        position_ids: Optional[torch.LongTensor],  # [batch_size, seq_length]
+    ):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+    def rotate_half(self, x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=2):
+        """Applies Rotary Position Embedding to the query and key tensors.
+
+        Args:
+            q (`torch.Tensor`): The query tensor.
+            k (`torch.Tensor`): The key tensor.
+            cos (`torch.Tensor`): The cosine part of the rotary embedding.
+            sin (`torch.Tensor`): The sine part of the rotary embedding.
+            unsqueeze_dim (`int`, *optional*, defaults to 2):
+                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+        Returns:
+            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+        """
+        cos = cos.unsqueeze(unsqueeze_dim)
+        sin = sin.unsqueeze(unsqueeze_dim)
+        q_embed = (q * cos) + (self.rotate_half(q) * sin)
+        k_embed = (k * cos) + (self.rotate_half(k) * sin)
+        return q_embed, k_embed
+
+
 class GLUActivation(nn.Module):
     def __init__(self, act_fn_name: str):
         super().__init__()
@@ -319,14 +392,24 @@ def __init__(
             tp_recompute_allgather=parallel_config.tp_recompute_allgather,
         )
         # TODO(kunhao): We want to have only one version per device and not one version per layer.
-        self.rotary_embedding = RotaryEmbedding(
-            dim=self.d_qk,
-            end=config.max_position_embeddings,
-            theta=config.rope_theta,
-        )
+        if config.rope_interleaved:
+            self.rotary_embedding = RotaryEmbedding(
+                dim=self.d_qk,
+                end=config.max_position_embeddings,
+                theta=config.rope_theta,
+            )
+        else:
+            self.rotary_embedding = LlamaRotaryEmbedding(
+                dim=self.d_qk,
+                end=config.max_position_embeddings,
+                theta=config.rope_theta,
+            )
+        self.rope_interleaved = config.rope_interleaved
 
         # NOTE: Only supported for training (TODO(fmom): position_ids not supported yet)
-        self.flash_rotary_embedding = FlashRotaryEmbedding(dim=self.d_qk, base=config.rope_theta, interleaved=True)
+        self.flash_rotary_embedding = FlashRotaryEmbedding(
+            dim=self.d_qk, base=config.rope_theta, interleaved=config.rope_interleaved
+        )
 
         self.o_proj = TensorParallelRowLinear(
             config.num_attention_heads * self.d_qk,
@@ -405,8 +488,16 @@ def forward(
             # Compute rotary embeddings
             # Note: keep track of old rotary embedding end to check if we need to enlarge k_cache and v_cache
             old_rotary_embed_end = self.rotary_embedding.end
-            query_states = self.rotary_embedding(query_states, position_ids=position_ids)
-            key_states = self.rotary_embedding(key_states, position_ids=position_ids)
+            # Interleaved version.
+            if self.rope_interleaved:
+                query_states = self.rotary_embedding(query_states, position_ids=position_ids)
+                key_states = self.rotary_embedding(key_states, position_ids=position_ids)
+            # Non-interleaved (Transformers-style) version.
+            else:
+                cos, sin = self.rotary_embedding(value_states, position_ids)
+                query_states, key_states = self.rotary_embedding.apply_rotary_pos_emb(
+                    query_states, key_states, cos, sin
+                )
 
             if "key" not in store:
                 # First inference iteration (Prefill)
@@ -545,7 +636,7 @@ def forward(
                     cache_seqlens=position_offsets.contiguous(),
                     softmax_scale=softmax_scale,
                     causal=True,
-                    rotary_interleaved=False,  # GPT-NeoX style
+                    rotary_interleaved=False,  # not used unless rotary_cos/sin is provided; see https://github.com/Dao-AILab/flash-attention
                 )
 
             store.update(
@@ -620,9 +711,9 @@ def __init__(
 
         self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)
-        
+
         self.recompute_layer = parallel_config.recompute_layer
-        
+
     def _core_forward(
         self,
         hidden_states: Union[torch.Tensor, TensorPointer],
@@ -641,12 +732,12 @@ def _core_forward(
         hidden_states = hidden_states + residual
 
         return hidden_states, output["sequence_mask"]
-        
+
     def _checkpointed_forward(
         self,
         hidden_states: torch.Tensor,
         sequence_mask: torch.Tensor,
-    ) -> List[torch.Tensor]:    
+    ) -> List[torch.Tensor]:
        return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask)
 
     def forward(
@@ -654,7 +745,7 @@ def forward(
         hidden_states: Union[torch.Tensor, TensorPointer],
         sequence_mask: Union[torch.Tensor, TensorPointer],
     ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
-        
+
         if self.recompute_layer and not isinstance(hidden_states, TensorPointer):
             hidden_states, sequence_mask = self._checkpointed_forward(hidden_states, sequence_mask)
         else:
@@ -665,6 +756,7 @@ def forward(
             "sequence_mask": sequence_mask,
         }
 
+
 class Embedding(nn.Module, AttachableStore):
     def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]):
         super().__init__()
@@ -727,7 +819,14 @@ def __init__(
             module_input_keys={"input_ids", "input_mask"},
             module_output_keys={"input_embeds"},
         )
-        
+        log_rank(f"Initializing RoPE with theta = {config.rope_theta}", logger=logger, level=logging.INFO, rank=0)
+        if config.rope_interleaved:
+            log_rank(
+                "The interleaved RoPE implementation differs from the one in Transformers. Set rope_interleaved=False if you need to convert the weights to Transformers.",
+                logger=logger,
+                level=logging.INFO,
+                rank=0,
+            )
         self.decoder = nn.ModuleList(
             [
                 PipelineBlock(
diff --git a/src/nanotron/nn/layer_norm.py b/src/nanotron/nn/layer_norm.py
index 688eaa78..ef3b4c50 100644
--- a/src/nanotron/nn/layer_norm.py
+++ b/src/nanotron/nn/layer_norm.py
@@ -22,6 +22,8 @@ def forward(
         )
 
 
+# This is equivalent to the LLaMA RMSNorm in Transformers:
+# https://github.com/huggingface/transformers/blob/28952248b19db29ca25ccf34a5eec413376494a9/src/transformers/models/llama/modeling_llama.py#L112
 class TritonRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
         factory_kwargs = {"device": device, "dtype": dtype}
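Note: the following is a minimal, hypothetical sketch (not part of the patch) of how the new non-interleaved RoPE path could be exercised in isolation once the patch is applied. It assumes a CUDA device (LlamaRotaryEmbedding registers its inv_freq buffer on "cuda") and nanotron's [batch_size, seq_length, num_heads, d_qk] tensor layout; the shapes and values below are illustrative only.

# Hypothetical usage sketch; assumes the patch above is applied and a GPU is available.
import torch
from nanotron.models.llama import LlamaRotaryEmbedding

batch_size, seq_length, num_heads, d_qk = 2, 16, 4, 128  # illustrative shapes

rotary = LlamaRotaryEmbedding(dim=d_qk, end=4096, theta=500000.0)

q = torch.randn(batch_size, seq_length, num_heads, d_qk, dtype=torch.bfloat16, device="cuda")
k = torch.randn(batch_size, seq_length, num_heads, d_qk, dtype=torch.bfloat16, device="cuda")
position_ids = torch.arange(seq_length, device="cuda").unsqueeze(0).expand(batch_size, -1)

# cos/sin come back as [batch_size, seq_length, d_qk]; the default unsqueeze_dim=2
# broadcasts them over the heads dimension of the [batch, seq, heads, d_qk] layout.
cos, sin = rotary(q, position_ids)
q_rot, k_rot = rotary.apply_rotary_pos_emb(q, k, cos, sin)
assert q_rot.shape == q.shape and k_rot.shape == k.shape

Since rope_interleaved now defaults to False, configs that relied on the previous interleaved behaviour would presumably need to set it to True explicitly, while Llama3 checkpoint conversion can keep the new default.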