From 41bf8e09d70b830564a3e403c50282935928401c Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Fri, 2 Aug 2024 21:27:05 +0200 Subject: [PATCH 01/36] Equivalent with nano llama 3 --- train_gpt2.py | 749 ++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 662 insertions(+), 87 deletions(-) diff --git a/train_gpt2.py b/train_gpt2.py index b9dee8701..46b97bcd9 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -23,6 +23,21 @@ import inspect from contextlib import nullcontext from dataclasses import dataclass +from pathlib import Path + +from typing import ( + AbstractSet, + cast, + Collection, + Dict, + Iterator, + List, + Literal, + Optional, + Sequence, + Tuple, + Union, +) import numpy as np import torch @@ -33,10 +48,217 @@ from torch.distributed import init_process_group, destroy_process_group from torch.distributed.optim import ZeroRedundancyOptimizer import torch.distributed as dist +from tiktoken.load import load_tiktoken_bpe # ----------------------------------------------------------------------------- # PyTorch nn.Module definitions for the GPT-2 model +# The tiktoken tokenizer can handle <=400k chars without +# pyo3_runtime.PanicException. +TIKTOKEN_MAX_ENCODE_CHARS = 400_000 + +# https://github.com/openai/tiktoken/issues/195 +# Here we iterate over subsequences and split if we exceed the limit +# of max consecutive non-whitespace or whitespace characters. +MAX_NO_WHITESPACES_CHARS = 25_000 + + +class Tokenizer: + """ + Tokenizing and encoding/decoding text using the Tiktoken tokenizer. + """ + + special_tokens: Dict[str, int] + + num_reserved_special_tokens = 256 + + pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 + + def __init__(self, model_path: str): + """ + Initializes the Tokenizer with a Tiktoken model. + + Args: + model_path (str): The path to the Tiktoken model file. 
+ """ + assert os.path.isfile(model_path), model_path + + mergeable_ranks = load_tiktoken_bpe(model_path) + num_base_tokens = len(mergeable_ranks) + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|finetune_right_pad_id|>", + "<|step_id|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|eom_id|>", # end of message + "<|eot_id|>", # end of turn + "<|python_tag|>", + ] + reserved_tokens = [ + f"<|reserved_special_token_{2 + i}|>" + for i in range(self.num_reserved_special_tokens - len(special_tokens)) + ] + special_tokens = special_tokens + reserved_tokens + + self.special_tokens = { + token: num_base_tokens + i for i, token in enumerate(special_tokens) + } + self.model = tiktoken.Encoding( + name=Path(model_path).name, + pat_str=self.pat_str, + mergeable_ranks=mergeable_ranks, + special_tokens=self.special_tokens, + ) + + self.n_words: int = num_base_tokens + len(special_tokens) + # BOS / EOS token IDs + self.bos_id: int = self.special_tokens["<|begin_of_text|>"] + self.eos_id: int = self.special_tokens["<|end_of_text|>"] + self.eot_id: int = self.special_tokens["<|eot_id|>"] + self.eom_id: int = self.special_tokens["<|eom_id|>"] + self.python_tag_id = self.special_tokens["<|python_tag|>"] + self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"] + self.stop_tokens = [ + self.special_tokens["<|eom_id|>"], + self.special_tokens["<|eot_id|>"], + ] + + def encode( + self, + s: str, + *, + bos: bool, + eos: bool, + allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None, + disallowed_special: Union[Literal["all"], Collection[str]] = (), + ) -> List[int]: + """ + Encodes a string into a list of token IDs. + + Args: + s (str): The input string to be encoded. + bos (bool): Whether to prepend the beginning-of-sequence token. + eos (bool): Whether to append the end-of-sequence token. + allowed_tokens ("all"|set[str]): allowed special tokens in string + disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string + + Returns: + list[int]: A list of token IDs. + + By default, setting disallowed_special=() encodes a string by ignoring + special tokens. Specifically: + - Setting `disallowed_special` to () will cause all text corresponding + to special tokens to be encoded as natural text (insteading of raising + an error). + - Setting `allowed_special` to "all" will treat all text corresponding + to special tokens to be encoded as special tokens. + """ + if allowed_special is None: + allowed_special = set() + assert type(s) is str + + substrs = ( + substr + for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) + for substr in self._split_whitespaces_or_nonwhitespaces( + s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS + ) + ) + t: List[int] = [] + for substr in substrs: + t.extend( + self.model.encode( + substr, + allowed_special=allowed_special, + disallowed_special=disallowed_special, + ) + ) + if bos: + t.insert(0, self.bos_id) + if eos: + t.append(self.eos_id) + return t + + def decode(self, t: Sequence[int]) -> str: + """ + Decodes a list of token IDs into a string. + + Args: + t (List[int]): The list of token IDs to be decoded. + + Returns: + str: The decoded string. + """ + # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. 
+ return self.model.decode(cast(List[int], t)) + + @staticmethod + def _split_whitespaces_or_nonwhitespaces( + s: str, max_consecutive_slice_len: int + ) -> Iterator[str]: + """ + Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` + consecutive whitespaces or consecutive non-whitespaces. + """ + current_slice_len = 0 + current_slice_is_space = s[0].isspace() if len(s) > 0 else False + slice_start = 0 + + for i in range(len(s)): + is_now_space = s[i].isspace() + + if current_slice_is_space ^ is_now_space: + current_slice_len = 1 + current_slice_is_space = is_now_space + else: + current_slice_len += 1 + if current_slice_len > max_consecutive_slice_len: + yield s[slice_start:i] + slice_start = i + current_slice_len = 1 + yield s[slice_start:] + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + +def sample_top_p(probs, p): + """ + Perform top-p (nucleus) sampling on a probability distribution. + + Args: + probs (torch.Tensor): Probability distribution tensor. + p (float): Probability threshold for top-p sampling. + + Returns: + torch.Tensor: Sampled token indices. + + Note: + Top-p sampling selects the smallest set of tokens whose cumulative probability mass + exceeds the threshold p. The distribution is renormalized based on the selected tokens. + """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token + class NewGELU(nn.Module): """Careful there are a few versions of GeLU, this one is the exact one used by OpenAI""" def forward(self, input): @@ -45,72 +267,197 @@ def forward(self, input): # using a global to toggle flash-attention FLASH = 0 +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + +def apply_scaling(freqs: torch.Tensor): + # Values obtained from grid search + scale_factor = 8 + low_freq_factor = 1 + high_freq_factor = 4 + old_context_len = 8192 # original llama3 length + + low_freq_wavelen = 
old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + new_freqs = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scale_factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) + return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False +): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + if use_scaled: + freqs = apply_scaling(freqs) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + class CausalSelfAttention(nn.Module): def __init__(self, config): super().__init__() assert config.n_embd % config.n_head == 0 - # key, query, value projections for all heads, but in a batch - self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) - # output projection - self.c_proj = nn.Linear(config.n_embd, config.n_embd) - self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 - # regularization - self.n_head = config.n_head - self.n_embd = config.n_embd + is_llama = config.is_llama + self.is_llama = is_llama + if not is_llama: + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd) + self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 + # regularization + self.n_head = config.n_head + self.n_embd = config.n_embd + else: + self.n_kv_heads = config.n_kv_head + self.n_local_heads = config.n_head + self.n_local_kv_heads = self.n_kv_heads + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = config.n_embd // config.n_head + + self.wq = nn.Linear(config.n_embd, config.n_head * self.head_dim, bias=False) + self.wk = nn.Linear(config.n_embd, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(config.n_embd, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(config.n_head * self.head_dim, config.n_embd, bias=False) + # not really a 'bias', more of a mask, but following the OpenAI/HF naming though self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) - .view(1, 1, config.block_size, config.block_size)) - - def forward(self, x): - B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) - # calculate query, key, values for all heads in batch and move head forward to be the batch dim - qkv = self.c_attn(x) - q, k, v = qkv.split(self.n_embd, dim=2) - k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - if FLASH: - # flashattention - y = F.scaled_dot_product_attention(q, k, v, is_causal=True) + .view(1, 1, config.block_size, config.block_size)) + + def forward(self, x, freqs_cis=None): + if not self.is_llama: + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + # calculate query, key, values for all heads in batch and move head forward to 
be the batch dim + qkv = self.c_attn(x) + q, k, v = qkv.split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + if FLASH: + # flashattention + y = F.scaled_dot_product_attention(q, k, v, is_causal=True) + else: + # manual implementation of attention + # this materializes the large (T,T) matrix for all the queries and keys + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + # output projection + y = self.c_proj(y) + return y else: - # manual implementation of attention - # this materializes the large (T,T) matrix for all the queries and keys - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) - att = F.softmax(att, dim=-1) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side - # output projection - y = self.c_proj(y) - return y + bsz, seqlen, _ = x.shape + + # QKV + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + # rotate QK (rope) + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + # repeat k/v heads if n_kv_heads < n_heads (GQA) + keys = repeat_kv(xk, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim) + values = repeat_kv(xv, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim) + + # attention + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) + values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + + scores = scores.masked_fill(self.bias[:,:,:seqlen,:seqlen] == 0, float('-inf')) + + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + return self.wo(output) class MLP(nn.Module): def __init__(self, config): super().__init__() - self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd) - self.gelu = NewGELU() - self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd) - self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 + is_llama = config.is_llama + self.is_llama = is_llama + if not is_llama: + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd) + self.gelu = NewGELU() + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd) + self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 + else: + hidden_dim = 4 * config.n_embd + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if config.ffn_dim_multiplier is not None: + hidden_dim = int(config.ffn_dim_multiplier * hidden_dim) + hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of) + self.w1 = nn.Linear(config.n_embd, hidden_dim, bias=False) + self.w2 = 
nn.Linear(hidden_dim, config.n_embd, bias=False) + self.w3 = nn.Linear(config.n_embd, hidden_dim, bias=False) def forward(self, x): - x = self.c_fc(x) - x = self.gelu(x) - x = self.c_proj(x) - return x + if not self.is_llama: + x = self.c_fc(x) + x = self.gelu(x) + x = self.c_proj(x) + return x + else: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) class Block(nn.Module): def __init__(self, config): super().__init__() - self.ln_1 = nn.LayerNorm(config.n_embd) + is_llama = config.is_llama + self.ln_1 = RMSNorm(config.n_embd, config.norm_eps) if is_llama else nn.LayerNorm(config.n_embd) self.attn = CausalSelfAttention(config) - self.ln_2 = nn.LayerNorm(config.n_embd) + self.ln_2 = RMSNorm(config.n_embd, config.norm_eps) if is_llama else nn.LayerNorm(config.n_embd) self.mlp = MLP(config) - def forward(self, x): - x = x + self.attn(self.ln_1(x)) + def forward(self, x, freqs_cis=None): + x = x + self.attn(self.ln_1(x), freqs_cis) x = x + self.mlp(self.ln_2(x)) return x @@ -119,32 +466,60 @@ def forward(self, x): @dataclass class GPTConfig: + is_llama = False block_size: int = 1024 vocab_size: int = 50257 n_layer: int = 12 n_head: int = 12 n_embd: int = 768 +@dataclass +class Llama31Config: + is_llama = True + block_size: int = 1024 + vocab_size: int = 128256 + n_layer: int = 32 + n_head: int = 32 + n_kv_head: int = 8 + n_embd: int = 4096 + ffn_dim_multiplier: float = 1.3 + multiple_of: int = 1024 + norm_eps: float = 1e-5 + rope_theta: float = 500000.0 + use_scaled_rope: bool = True + class GPT(nn.Module): def __init__(self, config): super().__init__() self.config = config + is_llama = config.is_llama + self.is_llama = is_llama self.transformer = nn.ModuleDict(dict( wte = nn.Embedding(config.vocab_size, config.n_embd), - wpe = nn.Embedding(config.block_size, config.n_embd), + **({'wpe': nn.Embedding(config.block_size, config.n_embd)} if not is_llama else {}), h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), - ln_f = nn.LayerNorm(config.n_embd), + ln_f = RMSNorm(config.n_embd, config.norm_eps) if is_llama else nn.LayerNorm(config.n_embd), )) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.lm_head.LLMC_SKIP_INIT = 1 # don't init this one, we will tie weights - self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying + if not is_llama: + self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying # init all weights, use a torch rng object to be very careful self.init_rng = torch.Generator() self.init_rng.manual_seed(42) - self.apply(self._init_weights) + if not is_llama: + self.apply(self._init_weights) + + if is_llama: + self.freqs_cis = precompute_freqs_cis( + config.n_embd // config.n_head, + config.block_size * 2, + config.rope_theta, + config.use_scaled_rope, + ) def _init_weights(self, module): if isinstance(module, nn.Linear): @@ -167,20 +542,25 @@ def forward(self, idx, targets=None, return_logits=True): # forward the GPT model itself tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) - pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) - x = tok_emb + pos_emb + freqs_cis = None + if not self.is_llama: + pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) + x = tok_emb + pos_emb + else: + x = tok_emb + freqs_cis = self.freqs_cis[:t] - for block in self.transformer.h: - x = block(x) + for i, block in enumerate(self.transformer.h): + x = block(x, freqs_cis) x = 
self.transformer.ln_f(x) if targets is not None: # if we are given some desired targets also calculate the loss - logits = self.lm_head(x) + logits = self.lm_head(x).float() loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) else: # inference-time mini-optimization: only forward the lm_head on the very last position - logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim + logits = self.lm_head(x[:, [-1], :]).float() # note: using list [-1] to preserve the time dim loss = None # there are performance reasons why not returning logits is prudent, if not needed @@ -189,6 +569,64 @@ def forward(self, idx, targets=None, return_logits=True): return logits, loss + @staticmethod + def modify_llama_keys(checkpoint, config: Llama31Config): + # 1) rename key tok_embeddings.weight to transformer.wte.weight + checkpoint['transformer.wte.weight'] = checkpoint.pop('tok_embeddings.weight') + + # layers.0.attention_norm.weight -> transformer.h.0.ln_1.weight + # layers.0.ffn_norm.weight -> transformer.h.0.ln_2.weight + # loop over all layers + for i in range(config.n_layer): + for name in ['attention_norm', 'ffn_norm']: + for suffix in ['weight']: + old_key = f'layers.{i}.{name}.{suffix}' + new_key = f'transformer.h.{i}.ln_{1 if name == "attention_norm" else 2}.{suffix}' + checkpoint[new_key] = checkpoint.pop(old_key) + + # layers.0.attention.wq.weight -> transformer.h.0.attn.wq.weight + # layers.0.attention.wk.weight -> transformer.h.0.attn.wk.weight + # layers.0.attention.wv.weight -> transformer.h.0.attn.wv.weight + # layers.0.attention.wo.weight -> transformer.h.0.attn.wo.weight + # loop over all layers + for i in range(config.n_layer): + for name in ['attention.wq', 'attention.wk', 'attention.wv', 'attention.wo']: + for suffix in ['weight']: + old_key = f'layers.{i}.{name}.{suffix}' + new_key = f'transformer.h.{i}.attn.{name.split(".")[-1]}.{suffix}' + checkpoint[new_key] = checkpoint.pop(old_key) + + # layers.0.feed_forward.w1.weight -> transformer.h.0.mlp.w1.weight + # layers.0.feed_forward.w2.weight -> transformer.h.0.mlp.w2.weight + # layers.0.feed_forward.w3.weight -> transformer.h.0.mlp.w3.weight + # loop over all layers + for i in range(config.n_layer): + for name in ['feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3']: + for suffix in ['weight']: + old_key = f'layers.{i}.{name}.{suffix}' + new_key = f'transformer.h.{i}.mlp.{name.split(".")[-1]}.{suffix}' + checkpoint[new_key] = checkpoint.pop(old_key) + + # norm.weight -> transformer.ln_f.weight + # output.weight -> lm_head.weight + checkpoint['transformer.ln_f.weight'] = checkpoint.pop('norm.weight') + checkpoint['lm_head.weight'] = checkpoint.pop('output.weight') + + return checkpoint + + @classmethod + def from_pretrained_llama3_1(cls): + ckpt_dir = "/home/aleksa/Documents/eureka/nano-llama31/llama-models/models/llama3_1/Meta-Llama-3.1-8B" + ckpt_path = sorted(Path(ckpt_dir).glob("*.pth"))[0] + checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True) + checkpoint = GPT.modify_llama_keys(checkpoint, Llama31Config()) + model_args = Llama31Config() + + torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) + model = GPT(model_args) + model.load_state_dict(checkpoint, strict=False) + return model + @classmethod def from_pretrained(cls, model_type): """Loads pretrained GPT-2 model weights from huggingface""" @@ -296,6 +734,116 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): return idx + @torch.inference_mode() + def generate2( 
+ self, + prompt_tokens: List[List[int]], + max_gen_len: int, + temperature: float = 0.6, + top_p: float = 0.9, + logprobs: bool = False, + echo: bool = False, + ) -> Tuple[List[List[int]], Optional[List[List[float]]]]: + """ + Generate text sequences based on provided prompts using the language generation model. + + Args: + prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers. + max_gen_len (int): Maximum length of the generated text sequence. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. + echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + + Returns: + Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities. + + Note: + This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness. + If logprobs is True, token log probabilities are computed for each generated token. + + """ + params = self.config + bsz = len(prompt_tokens) + # assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) + + min_prompt_len = min(len(t) for t in prompt_tokens) + max_prompt_len = max(len(t) for t in prompt_tokens) + assert max_prompt_len <= params.block_size + total_len = min(params.block_size, max_gen_len + max_prompt_len) + + pad_id = self.tokenizer.pad_id + tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda") + for k, t in enumerate(prompt_tokens): + tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") + if logprobs: + token_logprobs = torch.zeros_like(tokens, dtype=torch.float) + + prev_pos = 0 + eos_reached = torch.tensor([False] * bsz, device="cuda") + input_text_mask = tokens != pad_id + + if min_prompt_len == total_len: + logits, _ = self.forward(tokens) + token_logprobs = -F.cross_entropy( + input=logits.transpose(1, 2), + target=tokens, + reduction="none", + ignore_index=pad_id, + ) + + stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens)) + + for cur_pos in range(min_prompt_len, total_len): + logits, _ = self.forward(tokens[:, :cur_pos]) + if temperature > 0: + probs = torch.softmax(logits[:, -1] / temperature, dim=-1) + next_token = sample_top_p(probs, top_p) + else: + next_token = torch.argmax(logits[:, -1], dim=-1) + + next_token = next_token.reshape(-1) + # only replace token if prompt has already been generated + next_token = torch.where( + input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token + ) + tokens[:, cur_pos] = next_token + if logprobs: + token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy( + input=logits.transpose(1, 2), + target=tokens[:, prev_pos + 1 : cur_pos + 1], + reduction="none", + ignore_index=pad_id, + ) + eos_reached |= (~input_text_mask[:, cur_pos]) & ( + torch.isin(next_token, stop_tokens) + ) + prev_pos = cur_pos + if all(eos_reached): + break + + if logprobs: + token_logprobs = token_logprobs.tolist() + out_tokens, out_logprobs = [], [] + for i, toks in enumerate(tokens.tolist()): + # cut to max gen len + start = 0 if echo else len(prompt_tokens[i]) + toks = toks[start : len(prompt_tokens[i]) + max_gen_len] + probs = None + if 
logprobs: + probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len] + # cut to after eos tok if any + for stop_token in self.tokenizer.stop_tokens: + try: + eos_idx = toks.index(stop_token) + toks = toks[:eos_idx] + probs = probs[:eos_idx] if logprobs else None + except ValueError: + pass + out_tokens.append(toks) + out_logprobs.append(probs) + return (out_tokens, out_logprobs if logprobs else None) + # ----------------------------------------------------------------------------- # Our own simple Distributed Data Loader @@ -541,6 +1089,7 @@ def print0(*args, **kwargs): # default settings will overfit a tiny batch of data # and save model weights and debug state to disk on the first iteration parser = argparse.ArgumentParser() + parser.add_argument("--use_llama3", type=int, default=1, help="use llama3 model") # file system input / output parser.add_argument("--input_bin", type=str, default="dev/data/tinyshakespeare/tiny_shakespeare_val.bin", help="input .bin to train on") parser.add_argument("--input_val_bin", type=str, default="", help="input .bin to eval validation loss on") @@ -627,7 +1176,7 @@ def print0(*args, **kwargs): # set up a context manager following the desired dtype and device ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype] - ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() + ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if (device_type == "cuda" and not args.use_llama3) else nullcontext() # rng / reproducibility torch.manual_seed(42) @@ -659,8 +1208,12 @@ def print0(*args, **kwargs): }[args.model] model = GPT(model_config) else: - # load the GPT-2 model weights - model = GPT.from_pretrained(args.model) + if args.use_llama3: + model = GPT.from_pretrained_llama3_1() + else: + # load the GPT-2 model weights + model = GPT.from_pretrained(args.model) + model.train() model.to(device) if args.compile: @@ -682,7 +1235,7 @@ def print0(*args, **kwargs): # PyTorch -> C bridge: save some weights and state for C to load later as reference # do one forward pass to generate ground truth for our C tests - if master_process and args.write_tensors and (not args.inference_only): + if not args.use_llama3 and master_process and args.write_tensors and (not args.inference_only): x, y = train_loader.next_batch() x, y = x.to(device), y.to(device) logits, loss = model(x, y) @@ -745,41 +1298,63 @@ def get_lr(it): last_step = (step == args.num_iterations) # once in a while evaluate the validation dataset - if (args.val_loss_every > 0 \ - and (step % args.val_loss_every == 0 or last_step)) \ - and (val_loader is not None): - model.eval() - val_loader.reset() - with torch.no_grad(): - val_loss = 0.0 - for _ in range(args.val_max_steps): - x, y = val_loader.next_batch() - x, y = x.to(device), y.to(device) - _, loss = model(x, y, return_logits=False) - val_loss += loss.item() - val_loss /= args.val_max_steps - # log to console and to file - print0(f"val loss {val_loss}") - if master_process and logfile is not None: - with open(logfile, "a") as f: - f.write("s:%d tel:%f\n" % (step, val_loss)) + # if (args.val_loss_every > 0 \ + # and (step % args.val_loss_every == 0 or last_step)) \ + # and (val_loader is not None): + # model.eval() + # val_loader.reset() + # with torch.no_grad(): + # val_loss = 0.0 + # for _ in range(args.val_max_steps): + # x, y = val_loader.next_batch() + # x, y = x.to(device), y.to(device) + # _, loss = model(x, y, return_logits=False) + # 
val_loss += loss.item() + # val_loss /= args.val_max_steps + # # log to console and to file + # print0(f"val loss {val_loss}") + # if master_process and logfile is not None: + # with open(logfile, "a") as f: + # f.write("s:%d tel:%f\n" % (step, val_loss)) # once in a while perform model inference on the master process - if (args.sample_every > 0 \ - and (step % args.sample_every == 0 or last_step)) \ - and master_process: + if True: + # (args.sample_every > 0 \ + # and (step % args.sample_every == 0 or last_step)) \ + # and master_process: model.eval() # before we end, let's also do one round of inference # we'll kick off the generation with "<|endoftext|>", which designates the start of a new sequence - start_ids = [enc.eot_token] - xg = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]) - max_new_tokens = 32 - temperature = 1.0 - top_k = 40 - yg = raw_model.generate(xg, max_new_tokens, temperature=temperature, top_k=top_k) - print0('---------------') - print0(enc.decode(yg[0].tolist())) - print0('---------------') + # start_ids = [enc.eot_token] + # xg = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]) + # max_new_tokens = 32 + # temperature = 1.0 + # top_k = 40 + tokenizer_path = "/home/aleksa/Documents/eureka/nano-llama31/llama-models/models/llama3_1/Meta-Llama-3.1-8B/tokenizer.model" + tokenizer = Tokenizer(model_path=tokenizer_path) + raw_model.tokenizer = tokenizer + prompts: List[str] = [ + # For these prompts, the expected answer is the natural continuation of the prompt + "Clearly, the meaning of life is", + "Simply put, the theory of relativity states that", + """The repo llm.c on GitHub is""", + # Few shot prompt (providing a few examples before asking model to complete more); + """Translate English to French: + + sea otter => loutre de mer + peppermint => menthe poivrée + plush girafe => girafe peluche + cheese =>""", + ] + + prompt_tokens = [tokenizer.encode(x, bos=True, eos=False) for x in prompts] + + generation_tokens, _ = raw_model.generate2(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, logprobs=False, echo=False) + results = [{"generation": tokenizer.decode(t)} for t in generation_tokens] + for prompt, result in zip(prompts, results): + print(prompt, end="") # AK: change end="\n" to end="" + print(f"{result['generation']}") + print("\n==================================\n") # bit confusing: we want to make sure to eval and sample on 0th iteration # but also after the very last iteration. 
so we loop for step <= num_iterations From 838cd13c4f491d804d52cdc27bf796c61d271533 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Fri, 2 Aug 2024 22:38:11 +0200 Subject: [PATCH 02/36] Refactor --- llmc_py/rope.py | 59 +++++ llmc_py/tokenizer.py | 189 ++++++++++++++++ llmc_py/utils.py | 53 +++++ train_gpt2.py | 505 ++++++++++--------------------------------- 4 files changed, 421 insertions(+), 385 deletions(-) create mode 100644 llmc_py/rope.py create mode 100644 llmc_py/tokenizer.py create mode 100644 llmc_py/utils.py diff --git a/llmc_py/rope.py b/llmc_py/rope.py new file mode 100644 index 000000000..3caf58073 --- /dev/null +++ b/llmc_py/rope.py @@ -0,0 +1,59 @@ +# From: https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py + +import math +from typing import Tuple +import torch + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + +def apply_scaling(freqs: torch.Tensor): + # Values obtained from grid search + scale_factor = 8 + low_freq_factor = 1 + high_freq_factor = 4 + old_context_len = 8192 # original llama3 length + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + new_freqs = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scale_factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) + return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False +): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + if use_scaled: + freqs = apply_scaling(freqs) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis \ No newline at end of file diff --git a/llmc_py/tokenizer.py b/llmc_py/tokenizer.py new file mode 100644 index 000000000..528de113c --- /dev/null +++ b/llmc_py/tokenizer.py @@ -0,0 +1,189 @@ +# From: https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/tokenizer.py + +import os +from pathlib import Path +from typing import ( + AbstractSet, + Callable, + Collection, + Dict, + Iterator, + List, + Literal, + Optional, + Sequence, + Union, + cast, +) + +import tiktoken +from tiktoken.load import load_tiktoken_bpe + +# The tiktoken tokenizer can handle <=400k chars without +# pyo3_runtime.PanicException. 
+TIKTOKEN_MAX_ENCODE_CHARS = 400_000 + +# https://github.com/openai/tiktoken/issues/195 +# Here we iterate over subsequences and split if we exceed the limit +# of max consecutive non-whitespace or whitespace characters. +MAX_NO_WHITESPACES_CHARS = 25_000 + + +class Tokenizer: + """ + Tokenizing and encoding/decoding text using the Tiktoken tokenizer. + """ + + special_tokens: Dict[str, int] + + num_reserved_special_tokens = 256 + + pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 + + def __init__(self, model_path: str): + """ + Initializes the Tokenizer with a Tiktoken model. + + Args: + model_path (str): The path to the Tiktoken model file. + """ + assert os.path.isfile(model_path), model_path + + mergeable_ranks = load_tiktoken_bpe(model_path) + num_base_tokens = len(mergeable_ranks) + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|finetune_right_pad_id|>", + "<|step_id|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|eom_id|>", # end of message + "<|eot_id|>", # end of turn + "<|python_tag|>", + ] + reserved_tokens = [ + f"<|reserved_special_token_{2 + i}|>" + for i in range(self.num_reserved_special_tokens - len(special_tokens)) + ] + special_tokens = special_tokens + reserved_tokens + + self.special_tokens = { + token: num_base_tokens + i for i, token in enumerate(special_tokens) + } + self.model = tiktoken.Encoding( + name=Path(model_path).name, + pat_str=self.pat_str, + mergeable_ranks=mergeable_ranks, + special_tokens=self.special_tokens, + ) + + self.n_words: int = num_base_tokens + len(special_tokens) + # BOS / EOS token IDs + self.bos_id: int = self.special_tokens["<|begin_of_text|>"] + self.eos_id: int = self.special_tokens["<|end_of_text|>"] + self.eot_id: int = self.special_tokens["<|eot_id|>"] + self.eom_id: int = self.special_tokens["<|eom_id|>"] + self.python_tag_id = self.special_tokens["<|python_tag|>"] + self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"] + self.stop_tokens = [ + self.special_tokens["<|eom_id|>"], + self.special_tokens["<|eot_id|>"], + ] + + def encode( + self, + s: str, + *, + bos: bool, + eos: bool, + allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None, + disallowed_special: Union[Literal["all"], Collection[str]] = (), + ) -> List[int]: + """ + Encodes a string into a list of token IDs. + + Args: + s (str): The input string to be encoded. + bos (bool): Whether to prepend the beginning-of-sequence token. + eos (bool): Whether to append the end-of-sequence token. + allowed_tokens ("all"|set[str]): allowed special tokens in string + disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string + + Returns: + list[int]: A list of token IDs. + + By default, setting disallowed_special=() encodes a string by ignoring + special tokens. Specifically: + - Setting `disallowed_special` to () will cause all text corresponding + to special tokens to be encoded as natural text (insteading of raising + an error). + - Setting `allowed_special` to "all" will treat all text corresponding + to special tokens to be encoded as special tokens. 
+ """ + if allowed_special is None: + allowed_special = set() + assert type(s) is str + + substrs = ( + substr + for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) + for substr in self._split_whitespaces_or_nonwhitespaces( + s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS + ) + ) + t: List[int] = [] + for substr in substrs: + t.extend( + self.model.encode( + substr, + allowed_special=allowed_special, + disallowed_special=disallowed_special, + ) + ) + if bos: + t.insert(0, self.bos_id) + if eos: + t.append(self.eos_id) + return t + + def decode(self, t: Sequence[int]) -> str: + """ + Decodes a list of token IDs into a string. + + Args: + t (List[int]): The list of token IDs to be decoded. + + Returns: + str: The decoded string. + """ + # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. + return self.model.decode(cast(List[int], t)) + + @staticmethod + def _split_whitespaces_or_nonwhitespaces( + s: str, max_consecutive_slice_len: int + ) -> Iterator[str]: + """ + Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` + consecutive whitespaces or consecutive non-whitespaces. + """ + current_slice_len = 0 + current_slice_is_space = s[0].isspace() if len(s) > 0 else False + slice_start = 0 + + for i in range(len(s)): + is_now_space = s[i].isspace() + + if current_slice_is_space ^ is_now_space: + current_slice_len = 1 + current_slice_is_space = is_now_space + else: + current_slice_len += 1 + if current_slice_len > max_consecutive_slice_len: + yield s[slice_start:i] + slice_start = i + current_slice_len = 1 + yield s[slice_start:] \ No newline at end of file diff --git a/llmc_py/utils.py b/llmc_py/utils.py new file mode 100644 index 000000000..66ff7a42e --- /dev/null +++ b/llmc_py/utils.py @@ -0,0 +1,53 @@ +import torch +from torch import nn + +# Special modules +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + +# Sampling +def sample_top_p(probs, p): + """ + Perform top-p (nucleus) sampling on a probability distribution. + + Args: + probs (torch.Tensor): Probability distribution tensor. + p (float): Probability threshold for top-p sampling. + + Returns: + torch.Tensor: Sampled token indices. + + Note: + Top-p sampling selects the smallest set of tokens whose cumulative probability mass + exceeds the threshold p. The distribution is renormalized based on the selected tokens. 
+ """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token + +# GQA +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) \ No newline at end of file diff --git a/train_gpt2.py b/train_gpt2.py index 46b97bcd9..46a264701 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -26,17 +26,9 @@ from pathlib import Path from typing import ( - AbstractSet, - cast, - Collection, - Dict, - Iterator, List, - Literal, Optional, - Sequence, Tuple, - Union, ) import numpy as np @@ -48,217 +40,14 @@ from torch.distributed import init_process_group, destroy_process_group from torch.distributed.optim import ZeroRedundancyOptimizer import torch.distributed as dist -from tiktoken.load import load_tiktoken_bpe + +from llmc_py.tokenizer import Tokenizer +from llmc_py.rope import precompute_freqs_cis, apply_rotary_emb +from llmc_py.utils import repeat_kv, sample_top_p, RMSNorm # ----------------------------------------------------------------------------- # PyTorch nn.Module definitions for the GPT-2 model -# The tiktoken tokenizer can handle <=400k chars without -# pyo3_runtime.PanicException. -TIKTOKEN_MAX_ENCODE_CHARS = 400_000 - -# https://github.com/openai/tiktoken/issues/195 -# Here we iterate over subsequences and split if we exceed the limit -# of max consecutive non-whitespace or whitespace characters. -MAX_NO_WHITESPACES_CHARS = 25_000 - - -class Tokenizer: - """ - Tokenizing and encoding/decoding text using the Tiktoken tokenizer. - """ - - special_tokens: Dict[str, int] - - num_reserved_special_tokens = 256 - - pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 - - def __init__(self, model_path: str): - """ - Initializes the Tokenizer with a Tiktoken model. - - Args: - model_path (str): The path to the Tiktoken model file. 
- """ - assert os.path.isfile(model_path), model_path - - mergeable_ranks = load_tiktoken_bpe(model_path) - num_base_tokens = len(mergeable_ranks) - special_tokens = [ - "<|begin_of_text|>", - "<|end_of_text|>", - "<|reserved_special_token_0|>", - "<|reserved_special_token_1|>", - "<|finetune_right_pad_id|>", - "<|step_id|>", - "<|start_header_id|>", - "<|end_header_id|>", - "<|eom_id|>", # end of message - "<|eot_id|>", # end of turn - "<|python_tag|>", - ] - reserved_tokens = [ - f"<|reserved_special_token_{2 + i}|>" - for i in range(self.num_reserved_special_tokens - len(special_tokens)) - ] - special_tokens = special_tokens + reserved_tokens - - self.special_tokens = { - token: num_base_tokens + i for i, token in enumerate(special_tokens) - } - self.model = tiktoken.Encoding( - name=Path(model_path).name, - pat_str=self.pat_str, - mergeable_ranks=mergeable_ranks, - special_tokens=self.special_tokens, - ) - - self.n_words: int = num_base_tokens + len(special_tokens) - # BOS / EOS token IDs - self.bos_id: int = self.special_tokens["<|begin_of_text|>"] - self.eos_id: int = self.special_tokens["<|end_of_text|>"] - self.eot_id: int = self.special_tokens["<|eot_id|>"] - self.eom_id: int = self.special_tokens["<|eom_id|>"] - self.python_tag_id = self.special_tokens["<|python_tag|>"] - self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"] - self.stop_tokens = [ - self.special_tokens["<|eom_id|>"], - self.special_tokens["<|eot_id|>"], - ] - - def encode( - self, - s: str, - *, - bos: bool, - eos: bool, - allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None, - disallowed_special: Union[Literal["all"], Collection[str]] = (), - ) -> List[int]: - """ - Encodes a string into a list of token IDs. - - Args: - s (str): The input string to be encoded. - bos (bool): Whether to prepend the beginning-of-sequence token. - eos (bool): Whether to append the end-of-sequence token. - allowed_tokens ("all"|set[str]): allowed special tokens in string - disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string - - Returns: - list[int]: A list of token IDs. - - By default, setting disallowed_special=() encodes a string by ignoring - special tokens. Specifically: - - Setting `disallowed_special` to () will cause all text corresponding - to special tokens to be encoded as natural text (insteading of raising - an error). - - Setting `allowed_special` to "all" will treat all text corresponding - to special tokens to be encoded as special tokens. - """ - if allowed_special is None: - allowed_special = set() - assert type(s) is str - - substrs = ( - substr - for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) - for substr in self._split_whitespaces_or_nonwhitespaces( - s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS - ) - ) - t: List[int] = [] - for substr in substrs: - t.extend( - self.model.encode( - substr, - allowed_special=allowed_special, - disallowed_special=disallowed_special, - ) - ) - if bos: - t.insert(0, self.bos_id) - if eos: - t.append(self.eos_id) - return t - - def decode(self, t: Sequence[int]) -> str: - """ - Decodes a list of token IDs into a string. - - Args: - t (List[int]): The list of token IDs to be decoded. - - Returns: - str: The decoded string. - """ - # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. 
- return self.model.decode(cast(List[int], t)) - - @staticmethod - def _split_whitespaces_or_nonwhitespaces( - s: str, max_consecutive_slice_len: int - ) -> Iterator[str]: - """ - Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` - consecutive whitespaces or consecutive non-whitespaces. - """ - current_slice_len = 0 - current_slice_is_space = s[0].isspace() if len(s) > 0 else False - slice_start = 0 - - for i in range(len(s)): - is_now_space = s[i].isspace() - - if current_slice_is_space ^ is_now_space: - current_slice_len = 1 - current_slice_is_space = is_now_space - else: - current_slice_len += 1 - if current_slice_len > max_consecutive_slice_len: - yield s[slice_start:i] - slice_start = i - current_slice_len = 1 - yield s[slice_start:] - -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()).type_as(x) - return output * self.weight - -def sample_top_p(probs, p): - """ - Perform top-p (nucleus) sampling on a probability distribution. - - Args: - probs (torch.Tensor): Probability distribution tensor. - p (float): Probability threshold for top-p sampling. - - Returns: - torch.Tensor: Sampled token indices. - - Note: - Top-p sampling selects the smallest set of tokens whose cumulative probability mass - exceeds the threshold p. The distribution is renormalized based on the selected tokens. - """ - probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) - probs_sum = torch.cumsum(probs_sort, dim=-1) - mask = probs_sum - probs_sort > p - probs_sort[mask] = 0.0 - probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) - next_token = torch.multinomial(probs_sort, num_samples=1) - next_token = torch.gather(probs_idx, -1, next_token) - return next_token - class NewGELU(nn.Module): """Careful there are a few versions of GeLU, this one is the exact one used by OpenAI""" def forward(self, input): @@ -267,79 +56,13 @@ def forward(self, input): # using a global to toggle flash-attention FLASH = 0 -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): - ndim = x.ndim - assert 0 <= 1 < ndim - assert freqs_cis.shape == (x.shape[1], x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - -def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" - bs, slen, n_kv_heads, head_dim = x.shape - if n_rep == 1: - return x - return ( - x[:, :, :, None, :] - .expand(bs, slen, n_kv_heads, n_rep, head_dim) - .reshape(bs, slen, n_kv_heads * n_rep, head_dim) - ) - -def apply_scaling(freqs: torch.Tensor): - # Values obtained from grid search - scale_factor = 8 - low_freq_factor = 1 - high_freq_factor = 4 - old_context_len = 8192 # original llama3 length - - low_freq_wavelen = 
old_context_len / low_freq_factor - high_freq_wavelen = old_context_len / high_freq_factor - new_freqs = [] - for freq in freqs: - wavelen = 2 * math.pi / freq - if wavelen < high_freq_wavelen: - new_freqs.append(freq) - elif wavelen > low_freq_wavelen: - new_freqs.append(freq / scale_factor) - else: - assert low_freq_wavelen != high_freq_wavelen - smooth = (old_context_len / wavelen - low_freq_factor) / ( - high_freq_factor - low_freq_factor - ) - new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) - return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) - -def precompute_freqs_cis( - dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False -): - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device, dtype=torch.float32) - if use_scaled: - freqs = apply_scaling(freqs) - freqs = torch.outer(t, freqs) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - class CausalSelfAttention(nn.Module): def __init__(self, config): super().__init__() assert config.n_embd % config.n_head == 0 - is_llama = config.is_llama - self.is_llama = is_llama - if not is_llama: + self.is_llama = config.is_llama + if not self.is_llama: # key, query, value projections for all heads, but in a batch self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) # output projection @@ -349,20 +72,20 @@ def __init__(self, config): self.n_head = config.n_head self.n_embd = config.n_embd else: - self.n_kv_heads = config.n_kv_head - self.n_local_heads = config.n_head - self.n_local_kv_heads = self.n_kv_heads - self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.n_head = config.n_head + self.n_kv_head = config.n_kv_head + self.n_rep = self.n_head // self.n_kv_head self.head_dim = config.n_embd // config.n_head + # TODO(gordicaleksa): this can be easily made the same as the above (c_attn, c_proj) self.wq = nn.Linear(config.n_embd, config.n_head * self.head_dim, bias=False) - self.wk = nn.Linear(config.n_embd, self.n_kv_heads * self.head_dim, bias=False) - self.wv = nn.Linear(config.n_embd, self.n_kv_heads * self.head_dim, bias=False) + self.wk = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=False) + self.wv = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=False) self.wo = nn.Linear(config.n_head * self.head_dim, config.n_embd, bias=False) # not really a 'bias', more of a mask, but following the OpenAI/HF naming though self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) - .view(1, 1, config.block_size, config.block_size)) + .view(1, 1, config.block_size, config.block_size)) def forward(self, x, freqs_cis=None): if not self.is_llama: @@ -388,13 +111,14 @@ def forward(self, x, freqs_cis=None): y = self.c_proj(y) return y else: + # TODO(gordicaleksa): this can be easily merged with the if branch above bsz, seqlen, _ = x.shape # QKV xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) - xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xq = xq.view(bsz, seqlen, self.n_head, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_kv_head, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_kv_head, self.head_dim) # rotate QK (rope) xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) @@ -419,9 +143,8 @@ class MLP(nn.Module): def __init__(self, config): super().__init__() - 
is_llama = config.is_llama - self.is_llama = is_llama - if not is_llama: + self.is_llama = config.is_llama + if not self.is_llama: self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd) self.gelu = NewGELU() self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd) @@ -450,10 +173,9 @@ class Block(nn.Module): def __init__(self, config): super().__init__() - is_llama = config.is_llama - self.ln_1 = RMSNorm(config.n_embd, config.norm_eps) if is_llama else nn.LayerNorm(config.n_embd) + self.ln_1 = RMSNorm(config.n_embd, config.norm_eps) if config.is_llama else nn.LayerNorm(config.n_embd) self.attn = CausalSelfAttention(config) - self.ln_2 = RMSNorm(config.n_embd, config.norm_eps) if is_llama else nn.LayerNorm(config.n_embd) + self.ln_2 = RMSNorm(config.n_embd, config.norm_eps) if config.is_llama else nn.LayerNorm(config.n_embd) self.mlp = MLP(config) def forward(self, x, freqs_cis=None): @@ -474,8 +196,9 @@ class GPTConfig: n_embd: int = 768 @dataclass -class Llama31Config: +class LlamaConfig: is_llama = True + version: str = "3.1" block_size: int = 1024 vocab_size: int = 128256 n_layer: int = 32 @@ -493,27 +216,26 @@ class GPT(nn.Module): def __init__(self, config): super().__init__() self.config = config - is_llama = config.is_llama - self.is_llama = is_llama + self.is_llama = config.is_llama self.transformer = nn.ModuleDict(dict( wte = nn.Embedding(config.vocab_size, config.n_embd), - **({'wpe': nn.Embedding(config.block_size, config.n_embd)} if not is_llama else {}), + **({} if self.is_llama else {'wpe': nn.Embedding(config.block_size, config.n_embd)}), h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), - ln_f = RMSNorm(config.n_embd, config.norm_eps) if is_llama else nn.LayerNorm(config.n_embd), + ln_f = RMSNorm(config.n_embd, config.norm_eps) if self.is_llama else nn.LayerNorm(config.n_embd), )) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.lm_head.LLMC_SKIP_INIT = 1 # don't init this one, we will tie weights - if not is_llama: + if not self.is_llama: self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying # init all weights, use a torch rng object to be very careful self.init_rng = torch.Generator() self.init_rng.manual_seed(42) - if not is_llama: + if not self.is_llama: self.apply(self._init_weights) - if is_llama: + if self.is_llama: self.freqs_cis = precompute_freqs_cis( config.n_embd // config.n_head, config.block_size * 2, @@ -556,11 +278,15 @@ def forward(self, idx, targets=None, return_logits=True): if targets is not None: # if we are given some desired targets also calculate the loss - logits = self.lm_head(x).float() + logits = self.lm_head(x) + if self.is_llama: + logits = logits.float() loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) else: # inference-time mini-optimization: only forward the lm_head on the very last position - logits = self.lm_head(x[:, [-1], :]).float() # note: using list [-1] to preserve the time dim + logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim + if self.is_llama: + logits = logits.float() loss = None # there are performance reasons why not returning logits is prudent, if not needed @@ -570,13 +296,12 @@ def forward(self, idx, targets=None, return_logits=True): return logits, loss @staticmethod - def modify_llama_keys(checkpoint, config: Llama31Config): - # 1) rename key tok_embeddings.weight to transformer.wte.weight + def adapt_llama_state_dict_keys(checkpoint, config: 
LlamaConfig): + # rename key tok_embeddings.weight to transformer.wte.weight checkpoint['transformer.wte.weight'] = checkpoint.pop('tok_embeddings.weight') - # layers.0.attention_norm.weight -> transformer.h.0.ln_1.weight - # layers.0.ffn_norm.weight -> transformer.h.0.ln_2.weight - # loop over all layers + # layers.x.attention_norm.weight -> transformer.h.x.ln_1.weight + # layers.x.ffn_norm.weight -> transformer.h.x.ln_2.weight for i in range(config.n_layer): for name in ['attention_norm', 'ffn_norm']: for suffix in ['weight']: @@ -584,11 +309,10 @@ def modify_llama_keys(checkpoint, config: Llama31Config): new_key = f'transformer.h.{i}.ln_{1 if name == "attention_norm" else 2}.{suffix}' checkpoint[new_key] = checkpoint.pop(old_key) - # layers.0.attention.wq.weight -> transformer.h.0.attn.wq.weight - # layers.0.attention.wk.weight -> transformer.h.0.attn.wk.weight - # layers.0.attention.wv.weight -> transformer.h.0.attn.wv.weight - # layers.0.attention.wo.weight -> transformer.h.0.attn.wo.weight - # loop over all layers + # layers.x.attention.wq.weight -> transformer.h.x.attn.wq.weight + # layers.x.attention.wk.weight -> transformer.h.x.attn.wk.weight + # layers.x.attention.wv.weight -> transformer.h.x.attn.wv.weight + # layers.x.attention.wo.weight -> transformer.h.x.attn.wo.weight for i in range(config.n_layer): for name in ['attention.wq', 'attention.wk', 'attention.wv', 'attention.wo']: for suffix in ['weight']: @@ -596,10 +320,9 @@ def modify_llama_keys(checkpoint, config: Llama31Config): new_key = f'transformer.h.{i}.attn.{name.split(".")[-1]}.{suffix}' checkpoint[new_key] = checkpoint.pop(old_key) - # layers.0.feed_forward.w1.weight -> transformer.h.0.mlp.w1.weight - # layers.0.feed_forward.w2.weight -> transformer.h.0.mlp.w2.weight - # layers.0.feed_forward.w3.weight -> transformer.h.0.mlp.w3.weight - # loop over all layers + # layers.x.feed_forward.w1.weight -> transformer.h.x.mlp.w1.weight + # layers.x.feed_forward.w2.weight -> transformer.h.x.mlp.w2.weight + # layers.x.feed_forward.w3.weight -> transformer.h.x.mlp.w3.weight for i in range(config.n_layer): for name in ['feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3']: for suffix in ['weight']: @@ -615,16 +338,19 @@ def modify_llama_keys(checkpoint, config: Llama31Config): return checkpoint @classmethod - def from_pretrained_llama3_1(cls): - ckpt_dir = "/home/aleksa/Documents/eureka/nano-llama31/llama-models/models/llama3_1/Meta-Llama-3.1-8B" + def from_pretrained_llama3(cls, ckpt_dir, tokenizer_path): + model_args = LlamaConfig() + ckpt_path = sorted(Path(ckpt_dir).glob("*.pth"))[0] checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True) - checkpoint = GPT.modify_llama_keys(checkpoint, Llama31Config()) - model_args = Llama31Config() + checkpoint = GPT.adapt_llama_state_dict_keys(checkpoint, model_args) torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) model = GPT(model_args) model.load_state_dict(checkpoint, strict=False) + + tokenizer = Tokenizer(model_path=tokenizer_path) + model.tokenizer = tokenizer return model @classmethod @@ -735,7 +461,7 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): return idx @torch.inference_mode() - def generate2( + def generate_llama( self, prompt_tokens: List[List[int]], max_gen_len: int, @@ -1089,7 +815,9 @@ def print0(*args, **kwargs): # default settings will overfit a tiny batch of data # and save model weights and debug state to disk on the first iteration parser = argparse.ArgumentParser() - parser.add_argument("--use_llama3", type=int, 
default=1, help="use llama3 model") + parser.add_argument("--llama3", type=int, default=1, help="use llama3 model") + parser.add_argument("--llama3_ckpt_dir", type=str, default=None, help="path to llama3 model checkpoint") + parser.add_argument("--llama3_tokenizer_path", type=str, default=None, help="path to llama3 tokenizer") # file system input / output parser.add_argument("--input_bin", type=str, default="dev/data/tinyshakespeare/tiny_shakespeare_val.bin", help="input .bin to train on") parser.add_argument("--input_val_bin", type=str, default="", help="input .bin to eval validation loss on") @@ -1176,7 +904,7 @@ def print0(*args, **kwargs): # set up a context manager following the desired dtype and device ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype] - ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if (device_type == "cuda" and not args.use_llama3) else nullcontext() + ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if (device_type == "cuda" and not args.llama3) else nullcontext() # rng / reproducibility torch.manual_seed(42) @@ -1208,8 +936,10 @@ def print0(*args, **kwargs): }[args.model] model = GPT(model_config) else: - if args.use_llama3: - model = GPT.from_pretrained_llama3_1() + if args.llama3: + assert args.llama3_ckpt_dir is not None and os.path.exists(args.llama3_ckpt_dir), f"llama3 ckpt dir {args.llama3_ckpt_dir} does not exist" + assert args.llama3_tokenizer_path is not None and os.path.exists(args.llama3_tokenizer_path), f"llama3 tokenizer path {args.llama3_tokenizer_path} does not exist" + model = GPT.from_pretrained_llama3(args.llama3_ckpt_dir, args.llama3_tokenizer_path) else: # load the GPT-2 model weights model = GPT.from_pretrained(args.model) @@ -1231,11 +961,38 @@ def print0(*args, **kwargs): if args.input_val_bin: val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size) + # ------------------------------------------------------------------------- + # LLaMA 3 inference + if args.llama3: + model.eval() + prompts: List[str] = [ + # For these prompts, the expected answer is the natural continuation of the prompt + "Clearly, the meaning of life is", + "Simply put, the theory of relativity states that", + """The repo llm.c on GitHub is""", + # Few shot prompt (providing a few examples before asking model to complete more); + """Translate English to French: + + sea otter => loutre de mer + peppermint => menthe poivrée + plush girafe => girafe peluche + cheese =>""", + ] + + prompt_tokens = [model.tokenizer.encode(x, bos=True, eos=False) for x in prompts] + + generation_tokens, _ = model.generate_llama(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, logprobs=False, echo=False) + results = [{"generation": model.tokenizer.decode(t)} for t in generation_tokens] + for prompt, result in zip(prompts, results): + print(prompt, end="") + print(f"{result['generation']}") + print("\n==================================\n") + # ------------------------------------------------------------------------- # PyTorch -> C bridge: save some weights and state for C to load later as reference # do one forward pass to generate ground truth for our C tests - if not args.use_llama3 and master_process and args.write_tensors and (not args.inference_only): + if master_process and args.write_tensors and (not args.inference_only): x, y = train_loader.next_batch() x, y = x.to(device), y.to(device) logits, loss = model(x, y) @@ -1298,63 +1055,41 @@ def get_lr(it): last_step = 
(step == args.num_iterations) # once in a while evaluate the validation dataset - # if (args.val_loss_every > 0 \ - # and (step % args.val_loss_every == 0 or last_step)) \ - # and (val_loader is not None): - # model.eval() - # val_loader.reset() - # with torch.no_grad(): - # val_loss = 0.0 - # for _ in range(args.val_max_steps): - # x, y = val_loader.next_batch() - # x, y = x.to(device), y.to(device) - # _, loss = model(x, y, return_logits=False) - # val_loss += loss.item() - # val_loss /= args.val_max_steps - # # log to console and to file - # print0(f"val loss {val_loss}") - # if master_process and logfile is not None: - # with open(logfile, "a") as f: - # f.write("s:%d tel:%f\n" % (step, val_loss)) + if (args.val_loss_every > 0 \ + and (step % args.val_loss_every == 0 or last_step)) \ + and (val_loader is not None): + model.eval() + val_loader.reset() + with torch.no_grad(): + val_loss = 0.0 + for _ in range(args.val_max_steps): + x, y = val_loader.next_batch() + x, y = x.to(device), y.to(device) + _, loss = model(x, y, return_logits=False) + val_loss += loss.item() + val_loss /= args.val_max_steps + # log to console and to file + print0(f"val loss {val_loss}") + if master_process and logfile is not None: + with open(logfile, "a") as f: + f.write("s:%d tel:%f\n" % (step, val_loss)) # once in a while perform model inference on the master process - if True: - # (args.sample_every > 0 \ - # and (step % args.sample_every == 0 or last_step)) \ - # and master_process: + if (args.sample_every > 0 \ + and (step % args.sample_every == 0 or last_step)) \ + and master_process: model.eval() # before we end, let's also do one round of inference # we'll kick off the generation with "<|endoftext|>", which designates the start of a new sequence - # start_ids = [enc.eot_token] - # xg = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]) - # max_new_tokens = 32 - # temperature = 1.0 - # top_k = 40 - tokenizer_path = "/home/aleksa/Documents/eureka/nano-llama31/llama-models/models/llama3_1/Meta-Llama-3.1-8B/tokenizer.model" - tokenizer = Tokenizer(model_path=tokenizer_path) - raw_model.tokenizer = tokenizer - prompts: List[str] = [ - # For these prompts, the expected answer is the natural continuation of the prompt - "Clearly, the meaning of life is", - "Simply put, the theory of relativity states that", - """The repo llm.c on GitHub is""", - # Few shot prompt (providing a few examples before asking model to complete more); - """Translate English to French: - - sea otter => loutre de mer - peppermint => menthe poivrée - plush girafe => girafe peluche - cheese =>""", - ] - - prompt_tokens = [tokenizer.encode(x, bos=True, eos=False) for x in prompts] - - generation_tokens, _ = raw_model.generate2(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, logprobs=False, echo=False) - results = [{"generation": tokenizer.decode(t)} for t in generation_tokens] - for prompt, result in zip(prompts, results): - print(prompt, end="") # AK: change end="\n" to end="" - print(f"{result['generation']}") - print("\n==================================\n") + start_ids = [enc.eot_token] + xg = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]) + max_new_tokens = 32 + temperature = 1.0 + top_k = 40 + yg = raw_model.generate(xg, max_new_tokens, temperature=temperature, top_k=top_k) + print0('---------------') + print0(enc.decode(yg[0].tolist())) + print0('---------------') # bit confusing: we want to make sure to eval and sample on 0th iteration # but also after the very last iteration. 
so we loop for step <= num_iterations From c414d0284ae72fc4681f5ef846f6ef77d3434253 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Fri, 2 Aug 2024 23:12:49 +0200 Subject: [PATCH 03/36] Minor refactor --- llmc_py/utils.py | 4 ++++ train_gpt2.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/llmc_py/utils.py b/llmc_py/utils.py index 66ff7a42e..ed023c78a 100644 --- a/llmc_py/utils.py +++ b/llmc_py/utils.py @@ -1,3 +1,7 @@ +# Taken from: +# 1) https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py +# 2) https://github.com/meta-llama/llama3/blob/11817d47e1ba7a4959b025eb1ca308572e0e3963/llama/generation.py + import torch from torch import nn diff --git a/train_gpt2.py b/train_gpt2.py index 46a264701..b81effe44 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -24,7 +24,6 @@ from contextlib import nullcontext from dataclasses import dataclass from pathlib import Path - from typing import ( List, Optional, @@ -988,6 +987,8 @@ def print0(*args, **kwargs): print(f"{result['generation']}") print("\n==================================\n") + exit(0) # only inference supported for now + # ------------------------------------------------------------------------- # PyTorch -> C bridge: save some weights and state for C to load later as reference From 465aac4aa5bf186cb9188165512d8c2365c7e495 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Sat, 3 Aug 2024 10:03:28 +0200 Subject: [PATCH 04/36] Equivalent to nano llama 3 reference code --- train_gpt2.py | 35 +++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/train_gpt2.py b/train_gpt2.py index b81effe44..f35c130b5 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -82,11 +82,14 @@ def __init__(self, config): self.wv = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=False) self.wo = nn.Linear(config.n_head * self.head_dim, config.n_embd, bias=False) + self.cache_k = torch.zeros((4, config.block_size, config.n_kv_head, self.head_dim)) + self.cache_v = torch.zeros((4, config.block_size, config.n_kv_head, self.head_dim)) + # not really a 'bias', more of a mask, but following the OpenAI/HF naming though self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) .view(1, 1, config.block_size, config.block_size)) - def forward(self, x, freqs_cis=None): + def forward(self, x, freqs_cis=None, start_pos=None): if not self.is_llama: B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim @@ -121,9 +124,21 @@ def forward(self, x, freqs_cis=None): # rotate QK (rope) xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + # kv-caching (which we can disable by setting start_pos = -1) + if start_pos >= 0: + self.cache_k = self.cache_k.to(xq) + self.cache_v = self.cache_v.to(xq) + self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk + self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv + keys = self.cache_k[:bsz, : start_pos + seqlen] + values = self.cache_v[:bsz, : start_pos + seqlen] + else: + keys = xk + values = xv + # repeat k/v heads if n_kv_heads < n_heads (GQA) - keys = repeat_kv(xk, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim) - values = repeat_kv(xv, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim) + keys = repeat_kv(keys, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim) + values = repeat_kv(values, self.n_rep) # (bs, cache_len + seqlen, 
n_local_heads, head_dim) # attention xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) @@ -177,8 +192,8 @@ def __init__(self, config): self.ln_2 = RMSNorm(config.n_embd, config.norm_eps) if config.is_llama else nn.LayerNorm(config.n_embd) self.mlp = MLP(config) - def forward(self, x, freqs_cis=None): - x = x + self.attn(self.ln_1(x), freqs_cis) + def forward(self, x, freqs_cis=None, start_pos=None): + x = x + self.attn(self.ln_1(x), freqs_cis, start_pos) x = x + self.mlp(self.ln_2(x)) return x @@ -255,7 +270,7 @@ def _init_weights(self, module): elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02, generator=self.init_rng) - def forward(self, idx, targets=None, return_logits=True): + def forward(self, idx, targets=None, return_logits=True, start_pos=-1): device = idx.device b, t = idx.size() assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" @@ -269,10 +284,10 @@ def forward(self, idx, targets=None, return_logits=True): x = tok_emb + pos_emb else: x = tok_emb - freqs_cis = self.freqs_cis[:t] + freqs_cis = self.freqs_cis[start_pos:start_pos+t] for i, block in enumerate(self.transformer.h): - x = block(x, freqs_cis) + x = block(x, freqs_cis, start_pos) x = self.transformer.ln_f(x) if targets is not None: @@ -509,7 +524,7 @@ def generate_llama( input_text_mask = tokens != pad_id if min_prompt_len == total_len: - logits, _ = self.forward(tokens) + logits, _ = self.forward(tokens, start_pos=prev_pos) token_logprobs = -F.cross_entropy( input=logits.transpose(1, 2), target=tokens, @@ -520,7 +535,7 @@ def generate_llama( stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens)) for cur_pos in range(min_prompt_len, total_len): - logits, _ = self.forward(tokens[:, :cur_pos]) + logits, _ = self.forward(tokens[:, prev_pos:cur_pos], start_pos=prev_pos) if temperature > 0: probs = torch.softmax(logits[:, -1] / temperature, dim=-1) next_token = sample_top_p(probs, top_p) From f50f2de8a2e415ee56ccc402ba0805ef7c4c8f75 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Sat, 3 Aug 2024 10:11:26 +0200 Subject: [PATCH 05/36] Refactor attn, change numerics but equivalent --- train_gpt2.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/train_gpt2.py b/train_gpt2.py index f35c130b5..a0056d6a0 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -76,11 +76,8 @@ def __init__(self, config): self.n_rep = self.n_head // self.n_kv_head self.head_dim = config.n_embd // config.n_head - # TODO(gordicaleksa): this can be easily made the same as the above (c_attn, c_proj) - self.wq = nn.Linear(config.n_embd, config.n_head * self.head_dim, bias=False) - self.wk = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=False) - self.wv = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=False) - self.wo = nn.Linear(config.n_head * self.head_dim, config.n_embd, bias=False) + self.c_attn = nn.Linear(config.n_embd, (config.n_head + 2 * config.n_kv_head) * self.head_dim) + self.c_proj = nn.Linear(config.n_head * self.head_dim, config.n_embd, bias=False) self.cache_k = torch.zeros((4, config.block_size, config.n_kv_head, self.head_dim)) self.cache_v = torch.zeros((4, config.block_size, config.n_kv_head, self.head_dim)) @@ -117,7 +114,7 @@ def forward(self, x, freqs_cis=None, start_pos=None): bsz, seqlen, _ = x.shape # QKV - xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + xq, xk, xv = torch.split(self.c_attn(x), [self.n_head * 
self.head_dim, self.n_kv_head * self.head_dim, self.n_kv_head * self.head_dim], dim=-1) xq = xq.view(bsz, seqlen, self.n_head, self.head_dim) xk = xk.view(bsz, seqlen, self.n_kv_head, self.head_dim) xv = xv.view(bsz, seqlen, self.n_kv_head, self.head_dim) @@ -151,7 +148,7 @@ def forward(self, x, freqs_cis=None, start_pos=None): scores = F.softmax(scores.float(), dim=-1).type_as(xq) output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) - return self.wo(output) + return self.c_proj(output) class MLP(nn.Module): @@ -323,16 +320,24 @@ def adapt_llama_state_dict_keys(checkpoint, config: LlamaConfig): new_key = f'transformer.h.{i}.ln_{1 if name == "attention_norm" else 2}.{suffix}' checkpoint[new_key] = checkpoint.pop(old_key) - # layers.x.attention.wq.weight -> transformer.h.x.attn.wq.weight - # layers.x.attention.wk.weight -> transformer.h.x.attn.wk.weight - # layers.x.attention.wv.weight -> transformer.h.x.attn.wv.weight - # layers.x.attention.wo.weight -> transformer.h.x.attn.wo.weight + # we merge the following 3: + # layers.x.attention.wq.weight + # layers.x.attention.wk.weight + # layers.x.attention.wv.weight + # into transformer.h.x.attn.c_attn.weight + # layers.x.attention.wo.weight -> transformer.h.x.attn.c_proj.weight for i in range(config.n_layer): - for name in ['attention.wq', 'attention.wk', 'attention.wv', 'attention.wo']: + for name in ['attention.wq', 'attention.wk', 'attention.wv']: for suffix in ['weight']: old_key = f'layers.{i}.{name}.{suffix}' - new_key = f'transformer.h.{i}.attn.{name.split(".")[-1]}.{suffix}' - checkpoint[new_key] = checkpoint.pop(old_key) + new_key = f'transformer.h.{i}.attn.c_attn.weight' + if name == 'attention.wq': + checkpoint[new_key] = checkpoint.pop(old_key) + else: + checkpoint[new_key] = torch.cat((checkpoint[new_key], checkpoint.pop(old_key)), dim=0) + old_key = f'layers.{i}.attention.wo.weight' + new_key = f'transformer.h.{i}.attn.c_proj.weight' + checkpoint[new_key] = checkpoint.pop(old_key) # layers.x.feed_forward.w1.weight -> transformer.h.x.mlp.w1.weight # layers.x.feed_forward.w2.weight -> transformer.h.x.mlp.w2.weight From c0c08ba53d8b31a3d4d44b642f01eb0ed450690e Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Sat, 3 Aug 2024 10:21:56 +0200 Subject: [PATCH 06/36] Have prompts in a file instead of inline, prompt 4 is different --- train_gpt2.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/train_gpt2.py b/train_gpt2.py index a0056d6a0..188753fc8 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -23,6 +23,7 @@ import inspect from contextlib import nullcontext from dataclasses import dataclass +import json from pathlib import Path from typing import ( List, @@ -984,20 +985,7 @@ def print0(*args, **kwargs): # LLaMA 3 inference if args.llama3: model.eval() - prompts: List[str] = [ - # For these prompts, the expected answer is the natural continuation of the prompt - "Clearly, the meaning of life is", - "Simply put, the theory of relativity states that", - """The repo llm.c on GitHub is""", - # Few shot prompt (providing a few examples before asking model to complete more); - """Translate English to French: - - sea otter => loutre de mer - peppermint => menthe poivrée - plush girafe => girafe peluche - cheese =>""", - ] - + prompts: List[str] = json.loads(open(os.path.join(os.path.dirname(__file__), 'llmc_py', 'prompts.json')).read())['prompts'] prompt_tokens = [model.tokenizer.encode(x, bos=True, eos=False) for 
x in prompts] generation_tokens, _ = model.generate_llama(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, logprobs=False, echo=False) From de879d129bff0d01a5740825a7f0ab8f7f1a6432 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Sat, 3 Aug 2024 10:28:32 +0200 Subject: [PATCH 07/36] Refactor checkpoint state dict map func --- train_gpt2.py | 41 ++++++++++++----------------------------- 1 file changed, 12 insertions(+), 29 deletions(-) diff --git a/train_gpt2.py b/train_gpt2.py index 188753fc8..1108dadf0 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -309,49 +309,32 @@ def forward(self, idx, targets=None, return_logits=True, start_pos=-1): @staticmethod def adapt_llama_state_dict_keys(checkpoint, config: LlamaConfig): - # rename key tok_embeddings.weight to transformer.wte.weight checkpoint['transformer.wte.weight'] = checkpoint.pop('tok_embeddings.weight') - # layers.x.attention_norm.weight -> transformer.h.x.ln_1.weight - # layers.x.ffn_norm.weight -> transformer.h.x.ln_2.weight for i in range(config.n_layer): for name in ['attention_norm', 'ffn_norm']: - for suffix in ['weight']: - old_key = f'layers.{i}.{name}.{suffix}' - new_key = f'transformer.h.{i}.ln_{1 if name == "attention_norm" else 2}.{suffix}' - checkpoint[new_key] = checkpoint.pop(old_key) + old_key = f'layers.{i}.{name}.weight' # e.g. layers.x.attention_norm.weight -> transformer.h.x.ln_1.weight + new_key = f'transformer.h.{i}.ln_{1 if name == "attention_norm" else 2}.weight' + checkpoint[new_key] = checkpoint.pop(old_key) - # we merge the following 3: - # layers.x.attention.wq.weight - # layers.x.attention.wk.weight - # layers.x.attention.wv.weight - # into transformer.h.x.attn.c_attn.weight - # layers.x.attention.wo.weight -> transformer.h.x.attn.c_proj.weight for i in range(config.n_layer): for name in ['attention.wq', 'attention.wk', 'attention.wv']: - for suffix in ['weight']: - old_key = f'layers.{i}.{name}.{suffix}' - new_key = f'transformer.h.{i}.attn.c_attn.weight' - if name == 'attention.wq': - checkpoint[new_key] = checkpoint.pop(old_key) - else: - checkpoint[new_key] = torch.cat((checkpoint[new_key], checkpoint.pop(old_key)), dim=0) + old_key = f'layers.{i}.{name}.weight' + new_key = f'transformer.h.{i}.attn.c_attn.weight' + if name == 'attention.wq': + checkpoint[new_key] = checkpoint.pop(old_key) + else: # merge 3 weights into transformer.h.x.attn.c_attn.weight + checkpoint[new_key] = torch.cat((checkpoint[new_key], checkpoint.pop(old_key)), dim=0) old_key = f'layers.{i}.attention.wo.weight' new_key = f'transformer.h.{i}.attn.c_proj.weight' checkpoint[new_key] = checkpoint.pop(old_key) - # layers.x.feed_forward.w1.weight -> transformer.h.x.mlp.w1.weight - # layers.x.feed_forward.w2.weight -> transformer.h.x.mlp.w2.weight - # layers.x.feed_forward.w3.weight -> transformer.h.x.mlp.w3.weight for i in range(config.n_layer): for name in ['feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3']: - for suffix in ['weight']: - old_key = f'layers.{i}.{name}.{suffix}' - new_key = f'transformer.h.{i}.mlp.{name.split(".")[-1]}.{suffix}' - checkpoint[new_key] = checkpoint.pop(old_key) + old_key = f'layers.{i}.{name}.weight' + new_key = f'transformer.h.{i}.mlp.{name.split(".")[-1]}.weight' + checkpoint[new_key] = checkpoint.pop(old_key) - # norm.weight -> transformer.ln_f.weight - # output.weight -> lm_head.weight checkpoint['transformer.ln_f.weight'] = checkpoint.pop('norm.weight') checkpoint['lm_head.weight'] = checkpoint.pop('output.weight') From 0199e51a26e614c1eb3dba8bae741ebe09dde301 Mon Sep 17 00:00:00 
2001 From: Aleksa Gordic Date: Sat, 3 Aug 2024 10:46:15 +0200 Subject: [PATCH 08/36] Refactor MLP --- train_gpt2.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/train_gpt2.py b/train_gpt2.py index 1108dadf0..dbae3b471 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -156,21 +156,21 @@ class MLP(nn.Module): def __init__(self, config): super().__init__() self.is_llama = config.is_llama + hidden_dim = 4 * config.n_embd if not self.is_llama: - self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd) + self.c_fc = nn.Linear(config.n_embd, hidden_dim) self.gelu = NewGELU() - self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd) + self.c_proj = nn.Linear(hidden_dim, config.n_embd) self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 else: - hidden_dim = 4 * config.n_embd hidden_dim = int(2 * hidden_dim / 3) # custom dim factor multiplier if config.ffn_dim_multiplier is not None: hidden_dim = int(config.ffn_dim_multiplier * hidden_dim) hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of) - self.w1 = nn.Linear(config.n_embd, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, config.n_embd, bias=False) - self.w3 = nn.Linear(config.n_embd, hidden_dim, bias=False) + self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=False) + self.c_fc2 = nn.Linear(config.n_embd, hidden_dim, bias=False) + self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False) def forward(self, x): if not self.is_llama: @@ -179,7 +179,13 @@ def forward(self, x): x = self.c_proj(x) return x else: - return self.w2(F.silu(self.w1(x)) * self.w3(x)) + x1 = self.c_fc(x) + x2 = self.c_fc2(x) + x2 = F.silu(x2) + x = x1 * x2 + x = self.c_proj(x) + return x # SwiGLU self.c_proj(F.silu(self.c_fc2(x)) * self.c_fc(x)) + class Block(nn.Module): @@ -237,8 +243,8 @@ def __init__(self, config): ln_f = RMSNorm(config.n_embd, config.norm_eps) if self.is_llama else nn.LayerNorm(config.n_embd), )) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) - self.lm_head.LLMC_SKIP_INIT = 1 # don't init this one, we will tie weights if not self.is_llama: + self.lm_head.LLMC_SKIP_INIT = 1 # don't init this one, we will tie weights self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying # init all weights, use a torch rng object to be very careful @@ -281,7 +287,7 @@ def forward(self, idx, targets=None, return_logits=True, start_pos=-1): pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) x = tok_emb + pos_emb else: - x = tok_emb + x = tok_emb # we use RoPE in llama3 freqs_cis = self.freqs_cis[start_pos:start_pos+t] for i, block in enumerate(self.transformer.h): @@ -329,10 +335,11 @@ def adapt_llama_state_dict_keys(checkpoint, config: LlamaConfig): new_key = f'transformer.h.{i}.attn.c_proj.weight' checkpoint[new_key] = checkpoint.pop(old_key) + ffn_map = {'w1': 'c_fc2', 'w2': 'c_proj', 'w3': 'c_fc'} for i in range(config.n_layer): for name in ['feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3']: old_key = f'layers.{i}.{name}.weight' - new_key = f'transformer.h.{i}.mlp.{name.split(".")[-1]}.weight' + new_key = f'transformer.h.{i}.mlp.{ffn_map[name.split(".")[-1]]}.weight' checkpoint[new_key] = checkpoint.pop(old_key) checkpoint['transformer.ln_f.weight'] = checkpoint.pop('norm.weight') From fdd5345931fa2d5357dc33ed512bccf846c9d80e Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Sat, 3 Aug 2024 11:49:26 +0200 Subject: [PATCH 09/36] Refactor attn mechanism --- train_gpt2.py 
| 121 +++++++++++++++++++------------------------------- 1 file changed, 46 insertions(+), 75 deletions(-) diff --git a/train_gpt2.py b/train_gpt2.py index dbae3b471..9ebc5e77a 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -63,93 +63,62 @@ def __init__(self, config): assert config.n_embd % config.n_head == 0 self.is_llama = config.is_llama if not self.is_llama: - # key, query, value projections for all heads, but in a batch - self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) - # output projection - self.c_proj = nn.Linear(config.n_embd, config.n_embd) - self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 - # regularization - self.n_head = config.n_head - self.n_embd = config.n_embd - else: - self.n_head = config.n_head - self.n_kv_head = config.n_kv_head - self.n_rep = self.n_head // self.n_kv_head - self.head_dim = config.n_embd // config.n_head + assert config.n_head == config.n_kv_head, "GQA is only available for LLaMA" - self.c_attn = nn.Linear(config.n_embd, (config.n_head + 2 * config.n_kv_head) * self.head_dim) - self.c_proj = nn.Linear(config.n_head * self.head_dim, config.n_embd, bias=False) + self.n_head = config.n_head + self.n_kv_head = config.n_kv_head + self.n_rep = self.n_head // self.n_kv_head + self.hd = config.n_embd // config.n_head - self.cache_k = torch.zeros((4, config.block_size, config.n_kv_head, self.head_dim)) - self.cache_v = torch.zeros((4, config.block_size, config.n_kv_head, self.head_dim)) + self.c_attn = nn.Linear(config.n_embd, (config.n_head + 2 * config.n_kv_head) * self.hd, bias=not self.is_llama) # key, query, value projections + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=not self.is_llama) # output projection + + if not self.is_llama: + self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 + + self.cache_k = torch.zeros((config.batch_size, config.block_size, config.n_kv_head, self.hd)) + self.cache_v = torch.zeros((config.batch_size, config.block_size, config.n_kv_head, self.hd)) # not really a 'bias', more of a mask, but following the OpenAI/HF naming though self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) .view(1, 1, config.block_size, config.block_size)) def forward(self, x, freqs_cis=None, start_pos=None): - if not self.is_llama: - B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) - # calculate query, key, values for all heads in batch and move head forward to be the batch dim - qkv = self.c_attn(x) - q, k, v = qkv.split(self.n_embd, dim=2) - k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - if FLASH: - # flashattention - y = F.scaled_dot_product_attention(q, k, v, is_causal=True) - else: - # manual implementation of attention - # this materializes the large (T,T) matrix for all the queries and keys - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) - att = F.softmax(att, dim=-1) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side - # output projection - y = self.c_proj(y) - return y - else: - # TODO(gordicaleksa): this can be easily merged with the if branch above - bsz, seqlen, _ = x.shape - - # QKV - xq, xk, xv = torch.split(self.c_attn(x), [self.n_head * self.head_dim, self.n_kv_head * 
self.head_dim, self.n_kv_head * self.head_dim], dim=-1) - xq = xq.view(bsz, seqlen, self.n_head, self.head_dim) - xk = xk.view(bsz, seqlen, self.n_kv_head, self.head_dim) - xv = xv.view(bsz, seqlen, self.n_kv_head, self.head_dim) - # rotate QK (rope) - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) - - # kv-caching (which we can disable by setting start_pos = -1) - if start_pos >= 0: - self.cache_k = self.cache_k.to(xq) - self.cache_v = self.cache_v.to(xq) - self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk - self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv - keys = self.cache_k[:bsz, : start_pos + seqlen] - values = self.cache_v[:bsz, : start_pos + seqlen] - else: - keys = xk - values = xv + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + qkv = self.c_attn(x) + q, k, v = qkv.split([self.n_head * self.hd, self.n_kv_head * self.hd, self.n_kv_head * self.hd], dim=-1) + q, k, v = map(lambda t: t.view(B, T, -1, self.hd), (q, k, v)) # (B, T, NH, HD) - # repeat k/v heads if n_kv_heads < n_heads (GQA) - keys = repeat_kv(keys, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim) - values = repeat_kv(values, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim) + if self.is_llama: + q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) # rotate QK (rope) - # attention - xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) - values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) - scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + if start_pos >= 0: # kv-caching (which we can disable by setting start_pos = -1) + self.cache_k[:B, start_pos : start_pos + T] = k + self.cache_v[:B, start_pos : start_pos + T] = v + k = self.cache_k[:B, : start_pos + T] + v = self.cache_v[:B, : start_pos + T] - scores = scores.masked_fill(self.bias[:,:,:seqlen,:seqlen] == 0, float('-inf')) + if self.is_llama: + k = repeat_kv(k, self.n_rep) # GQA + v = repeat_kv(v, self.n_rep) + + q = q.transpose(1, 2) # (B, NH, T, HD) + k = k.transpose(1, 2) + v = v.transpose(1, 2) - scores = F.softmax(scores.float(), dim=-1).type_as(xq) - output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) - output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) - return self.c_proj(output) + if FLASH: + # flashattention + y = F.scaled_dot_product_attention(q, k, v, is_causal=True) + else: + # manual implementation of attention + # this materializes the large (T,T) matrix for all the queries and keys + scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.hd)) + scores = scores.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + att = F.softmax(scores.float(), dim=-1).type_as(q) + y = att @ v # (B, NH, T, T) x (B, NH, T, HD) -> (B, NH, T, HD) + y = y.transpose(1, 2).contiguous().view(B, T, C) + y = self.c_proj(y) + return y class MLP(nn.Module): @@ -211,6 +180,7 @@ class GPTConfig: vocab_size: int = 50257 n_layer: int = 12 n_head: int = 12 + n_kv_head: int = 12 n_embd: int = 768 @dataclass @@ -228,6 +198,7 @@ class LlamaConfig: norm_eps: float = 1e-5 rope_theta: float = 500000.0 use_scaled_rope: bool = True + batch_size = 4 class GPT(nn.Module): From fa7bcc3f298829bd17aab49644d690047d7ba2ac Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Sat, 3 Aug 2024 11:54:10 +0200 Subject: [PATCH 10/36] One more minor attn fix --- train_gpt2.py | 4 +--- 1 file changed, 1 insertion(+), 3 
deletions(-) diff --git a/train_gpt2.py b/train_gpt2.py index 9ebc5e77a..f9fdc7e94 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -102,9 +102,7 @@ def forward(self, x, freqs_cis=None, start_pos=None): k = repeat_kv(k, self.n_rep) # GQA v = repeat_kv(v, self.n_rep) - q = q.transpose(1, 2) # (B, NH, T, HD) - k = k.transpose(1, 2) - v = v.transpose(1, 2) + q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v)) # (B, NH, T, HD) if FLASH: # flashattention From 180215fd60cc6339c1428cca1492e6009814d35e Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Sat, 3 Aug 2024 12:38:49 +0200 Subject: [PATCH 11/36] Unify generate and generate_llama --- train_gpt2.py | 71 ++++++++++++++++----------------------------------- 1 file changed, 22 insertions(+), 49 deletions(-) diff --git a/train_gpt2.py b/train_gpt2.py index f9fdc7e94..a572a2164 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -76,8 +76,8 @@ def __init__(self, config): if not self.is_llama: self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 - self.cache_k = torch.zeros((config.batch_size, config.block_size, config.n_kv_head, self.hd)) - self.cache_v = torch.zeros((config.batch_size, config.block_size, config.n_kv_head, self.hd)) + self.cache_k = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd)) + self.cache_v = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd)) # not really a 'bias', more of a mask, but following the OpenAI/HF naming though self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) @@ -180,6 +180,7 @@ class GPTConfig: n_head: int = 12 n_kv_head: int = 12 n_embd: int = 768 + max_gen_batch_size = 4 @dataclass class LlamaConfig: @@ -196,7 +197,7 @@ class LlamaConfig: norm_eps: float = 1e-5 rope_theta: float = 500000.0 use_scaled_rope: bool = True - batch_size = 4 + max_gen_batch_size = 4 class GPT(nn.Module): @@ -412,35 +413,8 @@ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type, optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, fused=use_fused) return optimizer - @torch.no_grad() - def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): - """ - Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete - the sequence max_new_tokens times, feeding the predictions back into the model each time. - Most likely you'll want to make sure to be in model.eval() mode of operation for this. 
- """ - for _ in range(max_new_tokens): - # if the sequence context is growing too long we must crop it at block_size - idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] - # forward the model to get the logits for the index in the sequence - logits, _ = self(idx_cond) - # pluck the logits at the final step and scale by desired temperature - logits = logits[:, -1, :] / temperature - # optionally crop the logits to only the top k options - if top_k is not None: - v, _ = torch.topk(logits, min(top_k, logits.size(-1))) - logits[logits < v[:, [-1]]] = -float('Inf') - # apply softmax to convert logits to (normalized) probabilities - probs = F.softmax(logits, dim=-1) - # sample from the distribution - idx_next = torch.multinomial(probs, num_samples=1) - # append sampled index to the running sequence and continue - idx = torch.cat((idx, idx_next), dim=1) - - return idx - @torch.inference_mode() - def generate_llama( + def generate( self, prompt_tokens: List[List[int]], max_gen_len: int, @@ -468,28 +442,28 @@ def generate_llama( If logprobs is True, token log probabilities are computed for each generated token. """ - params = self.config bsz = len(prompt_tokens) - # assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) + assert bsz <= self.config.max_gen_batch_size, (bsz, self.config.max_gen_batch_size) + device = next(self.parameters()).device min_prompt_len = min(len(t) for t in prompt_tokens) max_prompt_len = max(len(t) for t in prompt_tokens) - assert max_prompt_len <= params.block_size - total_len = min(params.block_size, max_gen_len + max_prompt_len) + assert max_prompt_len <= self.config.block_size + total_len = min(self.config.block_size, max_gen_len + max_prompt_len) pad_id = self.tokenizer.pad_id - tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda") + tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device) for k, t in enumerate(prompt_tokens): - tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") + tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device) if logprobs: token_logprobs = torch.zeros_like(tokens, dtype=torch.float) prev_pos = 0 - eos_reached = torch.tensor([False] * bsz, device="cuda") + eos_reached = torch.tensor([False] * bsz, device=device) input_text_mask = tokens != pad_id if min_prompt_len == total_len: - logits, _ = self.forward(tokens, start_pos=prev_pos) + logits, _ = self.forward(tokens, start_pos=prev_pos if self.config.is_llama else -1) token_logprobs = -F.cross_entropy( input=logits.transpose(1, 2), target=tokens, @@ -497,10 +471,10 @@ def generate_llama( ignore_index=pad_id, ) - stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens)) + stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens)).to(device) for cur_pos in range(min_prompt_len, total_len): - logits, _ = self.forward(tokens[:, prev_pos:cur_pos], start_pos=prev_pos) + logits, _ = self.forward(tokens[:, prev_pos:cur_pos], start_pos=prev_pos if self.config.is_llama else -1) if temperature > 0: probs = torch.softmax(logits[:, -1] / temperature, dim=-1) next_token = sample_top_p(probs, top_p) @@ -947,7 +921,7 @@ def print0(*args, **kwargs): prompts: List[str] = json.loads(open(os.path.join(os.path.dirname(__file__), 'llmc_py', 'prompts.json')).read())['prompts'] prompt_tokens = [model.tokenizer.encode(x, bos=True, eos=False) for x in prompts] - generation_tokens, _ = model.generate_llama(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, logprobs=False, 
echo=False) + generation_tokens, _ = model.generate(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, logprobs=False, echo=False) results = [{"generation": model.tokenizer.decode(t)} for t in generation_tokens] for prompt, result in zip(prompts, results): print(prompt, end="") @@ -1049,14 +1023,13 @@ def get_lr(it): model.eval() # before we end, let's also do one round of inference # we'll kick off the generation with "<|endoftext|>", which designates the start of a new sequence - start_ids = [enc.eot_token] - xg = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]) - max_new_tokens = 32 - temperature = 1.0 - top_k = 40 - yg = raw_model.generate(xg, max_new_tokens, temperature=temperature, top_k=top_k) + start_ids = [[enc.eot_token]] + enc.pad_id = enc.eot_token # dummy value, it's a no-op + enc.stop_tokens = [-1] # dummy value, we don't stop early + raw_model.tokenizer = enc + yg = raw_model.generate(start_ids, max_gen_len=32, temperature=1.0, top_p=0.9) print0('---------------') - print0(enc.decode(yg[0].tolist())) + print0(enc.decode(yg[0][0])) print0('---------------') # bit confusing: we want to make sure to eval and sample on 0th iteration From 8919b66c3717915963b87ed4a14b758d647f03bc Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Sat, 3 Aug 2024 14:14:05 +0200 Subject: [PATCH 12/36] Fix generate for gpt-2 --- train_gpt2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_gpt2.py b/train_gpt2.py index a572a2164..454b16250 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -497,7 +497,7 @@ def generate( eos_reached |= (~input_text_mask[:, cur_pos]) & ( torch.isin(next_token, stop_tokens) ) - prev_pos = cur_pos + prev_pos = cur_pos if self.config.is_llama else 0 if all(eos_reached): break @@ -768,7 +768,7 @@ def print0(*args, **kwargs): # default settings will overfit a tiny batch of data # and save model weights and debug state to disk on the first iteration parser = argparse.ArgumentParser() - parser.add_argument("--llama3", type=int, default=1, help="use llama3 model") + parser.add_argument("--llama3", type=int, default=0, help="use llama3 model") parser.add_argument("--llama3_ckpt_dir", type=str, default=None, help="path to llama3 model checkpoint") parser.add_argument("--llama3_tokenizer_path", type=str, default=None, help="path to llama3 tokenizer") # file system input / output From ccdbdfd4ceaaad811873da9e93bd24e2161465d5 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Sat, 3 Aug 2024 21:49:45 +0200 Subject: [PATCH 13/36] Going towards pure llama 3 file - fixed attn --- train_gpt2.py | 41 +++++++++++++++-------------------------- 1 file changed, 15 insertions(+), 26 deletions(-) diff --git a/train_gpt2.py b/train_gpt2.py index 454b16250..99a635606 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -1,19 +1,18 @@ """ -Reference code for GPT-2 training and inference. +Reference code for LLaMA-3.1 training and inference. Will save the model weights into files, to be read from C as initialization. 
References: -1) the official GPT-2 TensorFlow implementation released by OpenAI: -https://github.com/openai/gpt-2/blob/master/src/model.py -2) huggingface/transformers PyTorch implementation: -https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py +# 1) https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/tokenizer.py +# 2) https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py +# 3) https://github.com/meta-llama/llama3/blob/11817d47e1ba7a4959b025eb1ca308572e0e3963/llama/generation.py Example launches to only benchmark the speed of bfloat16 compiled GPU training: 1 GPU: -python train_gpt2.py --write_tensors=0 --num_iterations=50 --sequence_length=1024 --compile=1 --tensorcores=1 --dtype=bfloat16 +python train_llama3.py --write_tensors=0 --num_iterations=50 --sequence_length=1024 --compile=1 --tensorcores=1 --dtype=bfloat16 you can also turn on flash-attention by appending --flash=1 4 GPU: -torchrun --standalone --nproc_per_node=4 train_gpt2.py --write_tensors=0 --num_iterations=50 --sequence_length=1024 --compile=1 --tensorcores=1 --dtype=bfloat16 +torchrun --standalone --nproc_per_node=4 train_llama3.py --write_tensors=0 --num_iterations=50 --sequence_length=1024 --compile=1 --tensorcores=1 --dtype=bfloat16 """ import os @@ -48,11 +47,6 @@ # ----------------------------------------------------------------------------- # PyTorch nn.Module definitions for the GPT-2 model -class NewGELU(nn.Module): - """Careful there are a few versions of GeLU, this one is the exact one used by OpenAI""" - def forward(self, input): - return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) - # using a global to toggle flash-attention FLASH = 0 @@ -61,21 +55,17 @@ class CausalSelfAttention(nn.Module): def __init__(self, config): super().__init__() assert config.n_embd % config.n_head == 0 - self.is_llama = config.is_llama - if not self.is_llama: - assert config.n_head == config.n_kv_head, "GQA is only available for LLaMA" self.n_head = config.n_head self.n_kv_head = config.n_kv_head self.n_rep = self.n_head // self.n_kv_head self.hd = config.n_embd // config.n_head - self.c_attn = nn.Linear(config.n_embd, (config.n_head + 2 * config.n_kv_head) * self.hd, bias=not self.is_llama) # key, query, value projections - self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=not self.is_llama) # output projection - - if not self.is_llama: - self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 + self.c_attn = nn.Linear(config.n_embd, (config.n_head + 2 * config.n_kv_head) * self.hd, bias=False) # key, query, value projections + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) # output projection + self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 + # static KV cache self.cache_k = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd)) self.cache_v = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd)) @@ -85,12 +75,12 @@ def __init__(self, config): def forward(self, x, freqs_cis=None, start_pos=None): B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + # calculate query, key, values for all heads in batch and move head forward to be the batch dim qkv = self.c_attn(x) q, k, v = qkv.split([self.n_head * self.hd, self.n_kv_head * self.hd, self.n_kv_head * self.hd], dim=-1) q, k, v = map(lambda t: t.view(B, T, -1, self.hd), (q, k, v)) # (B, T, NH, HD) - if self.is_llama: - q, k = 
apply_rotary_emb(q, k, freqs_cis=freqs_cis) # rotate QK (rope) + q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) # rotate QK (rope) if start_pos >= 0: # kv-caching (which we can disable by setting start_pos = -1) self.cache_k[:B, start_pos : start_pos + T] = k @@ -98,9 +88,8 @@ def forward(self, x, freqs_cis=None, start_pos=None): k = self.cache_k[:B, : start_pos + T] v = self.cache_v[:B, : start_pos + T] - if self.is_llama: - k = repeat_kv(k, self.n_rep) # GQA - v = repeat_kv(v, self.n_rep) + k = repeat_kv(k, self.n_rep) # GQA + v = repeat_kv(v, self.n_rep) q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v)) # (B, NH, T, HD) @@ -768,7 +757,7 @@ def print0(*args, **kwargs): # default settings will overfit a tiny batch of data # and save model weights and debug state to disk on the first iteration parser = argparse.ArgumentParser() - parser.add_argument("--llama3", type=int, default=0, help="use llama3 model") + parser.add_argument("--llama3", type=int, default=1, help="use llama3 model") parser.add_argument("--llama3_ckpt_dir", type=str, default=None, help="path to llama3 model checkpoint") parser.add_argument("--llama3_tokenizer_path", type=str, default=None, help="path to llama3 tokenizer") # file system input / output From 8a48df7b169882e8d7e6c9fcc272603aedf980cf Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Sat, 3 Aug 2024 21:53:26 +0200 Subject: [PATCH 14/36] MLP GPT2->LLaMA3 --- train_gpt2.py | 48 ++++++++++++++++++------------------------------ 1 file changed, 18 insertions(+), 30 deletions(-) diff --git a/train_gpt2.py b/train_gpt2.py index 99a635606..64b2e6974 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -111,45 +111,33 @@ class MLP(nn.Module): def __init__(self, config): super().__init__() - self.is_llama = config.is_llama hidden_dim = 4 * config.n_embd - if not self.is_llama: - self.c_fc = nn.Linear(config.n_embd, hidden_dim) - self.gelu = NewGELU() - self.c_proj = nn.Linear(hidden_dim, config.n_embd) - self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 - else: - hidden_dim = int(2 * hidden_dim / 3) - # custom dim factor multiplier - if config.ffn_dim_multiplier is not None: - hidden_dim = int(config.ffn_dim_multiplier * hidden_dim) - hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of) - self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=False) - self.c_fc2 = nn.Linear(config.n_embd, hidden_dim, bias=False) - self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False) + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if config.ffn_dim_multiplier is not None: + hidden_dim = int(config.ffn_dim_multiplier * hidden_dim) + hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of) + self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=False) + self.c_fc2 = nn.Linear(config.n_embd, hidden_dim, bias=False) + self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False) + self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 def forward(self, x): - if not self.is_llama: - x = self.c_fc(x) - x = self.gelu(x) - x = self.c_proj(x) - return x - else: - x1 = self.c_fc(x) - x2 = self.c_fc2(x) - x2 = F.silu(x2) - x = x1 * x2 - x = self.c_proj(x) - return x # SwiGLU self.c_proj(F.silu(self.c_fc2(x)) * self.c_fc(x)) - + # SwiGLU self.c_proj(F.silu(self.c_fc2(x)) * self.c_fc(x)) + x1 = self.c_fc(x) + x2 = self.c_fc2(x) + x2 = F.silu(x2) + x = x1 * x2 + x = self.c_proj(x) + return x class Block(nn.Module): def __init__(self, config): super().__init__() - self.ln_1 = RMSNorm(config.n_embd, 
config.norm_eps) if config.is_llama else nn.LayerNorm(config.n_embd) + self.ln_1 = RMSNorm(config.n_embd, config.norm_eps) self.attn = CausalSelfAttention(config) - self.ln_2 = RMSNorm(config.n_embd, config.norm_eps) if config.is_llama else nn.LayerNorm(config.n_embd) + self.ln_2 = RMSNorm(config.n_embd, config.norm_eps) self.mlp = MLP(config) def forward(self, x, freqs_cis=None, start_pos=None): From c1d2b7fec112da46ae7960d1a98c527d02a04bef Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Sat, 3 Aug 2024 22:02:02 +0200 Subject: [PATCH 15/36] Removed from pretrained for GPT-2 --- train_gpt2.py | 129 ++++++++++---------------------------------------- 1 file changed, 26 insertions(+), 103 deletions(-) diff --git a/train_gpt2.py b/train_gpt2.py index 64b2e6974..75adc7b21 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -146,22 +146,10 @@ def forward(self, x, freqs_cis=None, start_pos=None): return x # ----------------------------------------------------------------------------- -# The main GPT-2 model - -@dataclass -class GPTConfig: - is_llama = False - block_size: int = 1024 - vocab_size: int = 50257 - n_layer: int = 12 - n_head: int = 12 - n_kv_head: int = 12 - n_embd: int = 768 - max_gen_batch_size = 4 +# The main LLaMA 3.1 model @dataclass class LlamaConfig: - is_llama = True version: str = "3.1" block_size: int = 1024 vocab_size: int = 128256 @@ -174,39 +162,32 @@ class LlamaConfig: norm_eps: float = 1e-5 rope_theta: float = 500000.0 use_scaled_rope: bool = True - max_gen_batch_size = 4 + max_gen_batch_size: int = 4 -class GPT(nn.Module): +class LLaMA(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.is_llama = config.is_llama self.transformer = nn.ModuleDict(dict( wte = nn.Embedding(config.vocab_size, config.n_embd), - **({} if self.is_llama else {'wpe': nn.Embedding(config.block_size, config.n_embd)}), h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), - ln_f = RMSNorm(config.n_embd, config.norm_eps) if self.is_llama else nn.LayerNorm(config.n_embd), + ln_f = RMSNorm(config.n_embd, config.norm_eps), )) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) - if not self.is_llama: - self.lm_head.LLMC_SKIP_INIT = 1 # don't init this one, we will tie weights - self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying # init all weights, use a torch rng object to be very careful self.init_rng = torch.Generator() self.init_rng.manual_seed(42) - if not self.is_llama: - self.apply(self._init_weights) - - if self.is_llama: - self.freqs_cis = precompute_freqs_cis( - config.n_embd // config.n_head, - config.block_size * 2, - config.rope_theta, - config.use_scaled_rope, - ) + # self.apply(self._init_weights) + + self.freqs_cis = precompute_freqs_cis( + config.n_embd // config.n_head, + config.block_size * 2, + config.rope_theta, + config.use_scaled_rope, + ) def _init_weights(self, module): if isinstance(module, nn.Linear): @@ -227,15 +208,9 @@ def forward(self, idx, targets=None, return_logits=True, start_pos=-1): assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) - # forward the GPT model itself - tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) - freqs_cis = None - if not self.is_llama: - pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) - x = tok_emb + pos_emb - else: - x = tok_emb # 
we use RoPE in llama3 - freqs_cis = self.freqs_cis[start_pos:start_pos+t] + # forward the LLaMA model itself + x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + freqs_cis = self.freqs_cis[start_pos:start_pos+t] for i, block in enumerate(self.transformer.h): x = block(x, freqs_cis, start_pos) @@ -243,15 +218,11 @@ def forward(self, idx, targets=None, return_logits=True, start_pos=-1): if targets is not None: # if we are given some desired targets also calculate the loss - logits = self.lm_head(x) - if self.is_llama: - logits = logits.float() + logits = self.lm_head(x).float() loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) else: # inference-time mini-optimization: only forward the lm_head on the very last position - logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim - if self.is_llama: - logits = logits.float() + logits = self.lm_head(x[:, [-1], :]).float() # note: using list [-1] to preserve the time dim loss = None # there are performance reasons why not returning logits is prudent, if not needed @@ -296,69 +267,21 @@ def adapt_llama_state_dict_keys(checkpoint, config: LlamaConfig): @classmethod def from_pretrained_llama3(cls, ckpt_dir, tokenizer_path): + """Loads pretrained LLaMA model weights from a checkpoint directory""" model_args = LlamaConfig() ckpt_path = sorted(Path(ckpt_dir).glob("*.pth"))[0] checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True) - checkpoint = GPT.adapt_llama_state_dict_keys(checkpoint, model_args) + checkpoint = LLaMA.adapt_llama_state_dict_keys(checkpoint, model_args) torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) - model = GPT(model_args) + model = LLaMA(model_args) model.load_state_dict(checkpoint, strict=False) tokenizer = Tokenizer(model_path=tokenizer_path) model.tokenizer = tokenizer return model - @classmethod - def from_pretrained(cls, model_type): - """Loads pretrained GPT-2 model weights from huggingface""" - assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'} - from transformers import GPT2LMHeadModel - print("loading weights from pretrained gpt: %s" % model_type) - - # n_layer, n_head and n_embd are determined from model_type - config_args = { - 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params - 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params - 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params - 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params - }[model_type] - config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints - config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints - # create a from-scratch initialized minGPT model - config = GPTConfig(**config_args) - model = GPT(config) - sd = model.state_dict() - sd_keys = sd.keys() - sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param - - # init a huggingface/transformers model - model_hf = GPT2LMHeadModel.from_pretrained(model_type) - sd_hf = model_hf.state_dict() - - # copy while ensuring all of the parameters are aligned and match in names and shapes - sd_keys_hf = sd_hf.keys() - sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer - sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer) - transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight'] - # 
basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear - # this means that we have to transpose these weights when we import them - assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" - for k in sd_keys_hf: - if any(k.endswith(w) for w in transposed): - # special treatment for the Conv1D weights we need to transpose - assert sd_hf[k].shape[::-1] == sd[k].shape - with torch.no_grad(): - sd[k].copy_(sd_hf[k].t()) - else: - # vanilla copy over the other parameters - assert sd_hf[k].shape == sd[k].shape - with torch.no_grad(): - sd[k].copy_(sd_hf[k]) - - return model - def configure_optimizers(self, weight_decay, learning_rate, betas, device_type, zero_stage): # start with all of the candidate parameters param_dict = {pn: p for pn, p in self.named_parameters()} @@ -440,7 +363,7 @@ def generate( input_text_mask = tokens != pad_id if min_prompt_len == total_len: - logits, _ = self.forward(tokens, start_pos=prev_pos if self.config.is_llama else -1) + logits, _ = self.forward(tokens, start_pos=prev_pos) token_logprobs = -F.cross_entropy( input=logits.transpose(1, 2), target=tokens, @@ -451,7 +374,7 @@ def generate( stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens)).to(device) for cur_pos in range(min_prompt_len, total_len): - logits, _ = self.forward(tokens[:, prev_pos:cur_pos], start_pos=prev_pos if self.config.is_llama else -1) + logits, _ = self.forward(tokens[:, prev_pos:cur_pos], start_pos=prev_pos) if temperature > 0: probs = torch.softmax(logits[:, -1] / temperature, dim=-1) next_token = sample_top_p(probs, top_p) @@ -474,7 +397,7 @@ def generate( eos_reached |= (~input_text_mask[:, cur_pos]) & ( torch.isin(next_token, stop_tokens) ) - prev_pos = cur_pos if self.config.is_llama else 0 + prev_pos = cur_pos if all(eos_reached): break @@ -864,15 +787,15 @@ def print0(*args, **kwargs): "d36": GPTConfig(block_size=1024, vocab_size=50257, n_layer=36, n_head=20, n_embd=1280), "d48": GPTConfig(block_size=1024, vocab_size=50257, n_layer=48, n_head=25, n_embd=1600), }[args.model] - model = GPT(model_config) + model = LLaMA(model_config) else: if args.llama3: assert args.llama3_ckpt_dir is not None and os.path.exists(args.llama3_ckpt_dir), f"llama3 ckpt dir {args.llama3_ckpt_dir} does not exist" assert args.llama3_tokenizer_path is not None and os.path.exists(args.llama3_tokenizer_path), f"llama3 tokenizer path {args.llama3_tokenizer_path} does not exist" - model = GPT.from_pretrained_llama3(args.llama3_ckpt_dir, args.llama3_tokenizer_path) + model = LLaMA.from_pretrained_llama3(args.llama3_ckpt_dir, args.llama3_tokenizer_path) else: # load the GPT-2 model weights - model = GPT.from_pretrained(args.model) + model = LLaMA.from_pretrained(args.model) model.train() model.to(device) From d855c9695aa9f407dc848dfae9935b6369ed7689 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Sat, 3 Aug 2024 22:16:13 +0200 Subject: [PATCH 16/36] Refactoring - got to main --- train_gpt2.py | 92 +++++++++++---------------------------------------- 1 file changed, 20 insertions(+), 72 deletions(-) diff --git a/train_gpt2.py b/train_gpt2.py index 75adc7b21..e1630f394 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -520,83 +520,53 @@ def write_bf16(tensor, file): file.write(b) def write_tensors(model_tensors, L, file, dtype): - # writes the GPT-2 model's weights to a binary file + # writes LLaMA 3 model's weights to a binary file assert dtype in {"float32", "bfloat16"} write_fun = write_fp32 if dtype == "float32" else write_bf16 
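    # the parameters are streamed grouped by tensor type across all L layers (every layer's
    # ln_1, then every layer's attn.c_attn, and so on) rather than layer by layer, so a
    # reader can consume each parameter group as one contiguous block; assuming model_tensors
    # holds exactly the tensors written below, the total payload works out to
    # sum(t.numel() for t in model_tensors.values()) * (4 if dtype == "float32" else 2) bytes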
write_fun(model_tensors["transformer.wte.weight"], file) # (V, C) - write_fun(model_tensors["transformer.wpe.weight"], file) # (T, C) for i in range(L): # (L, C) write_fun(model_tensors[f"transformer.h.{i}.ln_1.weight"], file) - for i in range(L): # (L, C) - write_fun(model_tensors[f"transformer.h.{i}.ln_1.bias"], file) for i in range(L): # (L, 3C, C) write_fun(model_tensors[f"transformer.h.{i}.attn.c_attn.weight"], file) - for i in range(L): # (L, 3C) - write_fun(model_tensors[f"transformer.h.{i}.attn.c_attn.bias"], file) for i in range(L): # (L, C, C) write_fun(model_tensors[f"transformer.h.{i}.attn.c_proj.weight"], file) - for i in range(L): # (L, C) - write_fun(model_tensors[f"transformer.h.{i}.attn.c_proj.bias"], file) for i in range(L): # (L, C) write_fun(model_tensors[f"transformer.h.{i}.ln_2.weight"], file) - for i in range(L): # (L, C) - write_fun(model_tensors[f"transformer.h.{i}.ln_2.bias"], file) for i in range(L): # (L, 4C, C) write_fun(model_tensors[f"transformer.h.{i}.mlp.c_fc.weight"], file) - for i in range(L): # (L, 4C) - write_fun(model_tensors[f"transformer.h.{i}.mlp.c_fc.bias"], file) + for i in range(L): # (L, 4C, C) + write_fun(model_tensors[f"transformer.h.{i}.mlp.c_fc2.weight"], file) for i in range(L): # (L, C, 4C) write_fun(model_tensors[f"transformer.h.{i}.mlp.c_proj.weight"], file) - for i in range(L): # (L, C) - write_fun(model_tensors[f"transformer.h.{i}.mlp.c_proj.bias"], file) write_fun(model_tensors["transformer.ln_f.weight"], file) # (C, ) - write_fun(model_tensors["transformer.ln_f.bias"], file) # (C, ) - -@torch.no_grad() -def pad_vocab(tensor, multiple=128, value=0): - """ - The dimension of the vocab size in GPT-2 is 50,257 - which is unfortunately a very unfriendly number for a lot of - matrix operations on the GPU. So we pad it to the nearest - friendlier multiple, e.g. 50,304 if multiple=128 when we - export the weights into C land. This is a NOOP algorithmically - and is only done to make the tensor operations more efficient. 
- """ - assert tensor.ndim == 2 - V, C = tensor.shape - assert V == 50257, "just being defensive here" - # calculate padded vocab size by rounding up to nearest multiple - Vp = ((V + multiple - 1) // multiple) * multiple - # pad the tensor - pad_rows = Vp - V - padded = tensor if pad_rows == 0 else F.pad(tensor, (0, 0, 0, pad_rows), value=value) - assert padded.shape == (Vp, C) - return padded + write_fun(model_tensors["lm_head.weight"], file) # (V, C) def write_model(model, filename, dtype): # everything we need to instantiate the model - # 1) header is: version int, GPTConfig ints, padding to 1024 bytes + # 1) header is: version int, LLaMAConfig ints, padding to 1024 bytes assert dtype in {"float32", "bfloat16"} # float16 todo maybe later version = { "float32": 3, # 3: all tensors are fp32, padded vocab "bfloat16": 5, # 5: all tensors are bf16, padded vocab }[dtype] header = torch.zeros(256, dtype=torch.int32) - header[0] = 20240326 # magic + header[0] = 20240803 # magic header[1] = version # checkpoint version header[2] = model.config.block_size header[3] = model.config.vocab_size header[4] = model.config.n_layer header[5] = model.config.n_head - header[6] = model.config.n_embd + header[6] = model.config.n_kv_head + header[7] = model.config.n_embd + header[8] = model.config.ffn_dim_multiplier + header[9] = model.config.multiple_of + header[10] = model.config.norm_eps + header[11] = model.config.rope_theta + header[12] = model.config.use_scaled_rope + header[13] = model.config.max_gen_batch_size + header[14] = model.version # 2) the parameters follow the header params = {name: param.cpu() for name, param in model.named_parameters()} - # pad the vocab to a multiple of 128 here at export, for efficiency in C - wte = params["transformer.wte.weight"] # (V, C) - wte_padded = pad_vocab(wte) # (Vp, C) - params["transformer.wte.weight"] = wte_padded # (Vp, C) - print(f"padded vocab size from {wte.size(0)} to {wte_padded.size(0)}") - header[7] = wte_padded.size(0) # padded vocab size store in header # now write to file with open(filename, "wb") as file: file.write(header.numpy().tobytes()) # header @@ -608,16 +578,10 @@ def write_state(model, x, y, logits, loss, filename): # it contains information about the input, logits, loss, and the parameter gradients # this can be used for checking the computation correctness in C header = torch.zeros(256, dtype=torch.int32) - header[0] = 20240327 # magic - header[1] = 2 # run state version = 2 (1 -> 2 for padded vocab changes) - header[2] = x.size(0) # batch size of the batch, B - header[3] = x.size(1) # temporal extent of the batch, T + header[0] = 20240803 # magic + header[1] = x.size(0) # batch size of the batch, B + header[2] = x.size(1) # temporal extent of the batch, T grads = {name: param.grad.cpu() for name, param in model.named_parameters()} - # pad the vocab grads here as well, to mirror write_model - wte_grad = grads["transformer.wte.weight"] # (V, C) - wte_grad_padded = pad_vocab(wte_grad, value=0) # (Vp, C) # TODO later maybe pad with nan? 
- grads["transformer.wte.weight"] = wte_grad_padded # (Vp, C) - print(f"padded vocab size in reference grads from {wte_grad.size(0)} to {wte_grad_padded.size(0)}") with open(filename, "wb") as file: # header file.write(header.numpy().tobytes()) @@ -633,23 +597,6 @@ def write_state(model, x, y, logits, loss, filename): write_tensors(grads, model.config.n_layer, file, "float32") print(f"wrote {filename}") -def write_tokenizer(enc, filename): - n = enc.max_token_value + 1 - header = torch.zeros(256, dtype=torch.int32) - header[0] = 20240328 # magic - header[1] = 2 # tokenizer version = 2 (1 -> 2: includes EOT token) - header[2] = n # number of tokens - header[3] = enc.eot_token # EOT token - with open(filename, "wb") as file: - file.write(header.numpy().tobytes()) - for i in range(n): - b = enc.decode_bytes([i]) - length = len(b) - assert length < 256, f"Token length exceeds 255: {length}" - file.write(struct.pack(" Date: Sat, 3 Aug 2024 22:22:38 +0200 Subject: [PATCH 17/36] Got to llama 3 inference (end) --- train_gpt2.py | 61 ++++++++++++++++----------------------------------- 1 file changed, 19 insertions(+), 42 deletions(-) diff --git a/train_gpt2.py b/train_gpt2.py index e1630f394..4d8270808 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -609,15 +609,13 @@ def print0(*args, **kwargs): if __name__ == "__main__": import time import argparse - import tiktoken print0(f"Running pytorch {torch.version.__version__}") # default settings will overfit a tiny batch of data # and save model weights and debug state to disk on the first iteration parser = argparse.ArgumentParser() - parser.add_argument("--llama3", type=int, default=1, help="use llama3 model") - parser.add_argument("--llama3_ckpt_dir", type=str, default=None, help="path to llama3 model checkpoint") - parser.add_argument("--llama3_tokenizer_path", type=str, default=None, help="path to llama3 tokenizer") + parser.add_argument("--ckpt_dir", type=str, default=None, help="path to llama3 model checkpoint") + parser.add_argument("--tokenizer_path", type=str, default=None, help="path to llama3 tokenizer") # file system input / output parser.add_argument("--input_bin", type=str, default="dev/data/tinyshakespeare/tiny_shakespeare_val.bin", help="input .bin to train on") parser.add_argument("--input_val_bin", type=str, default="", help="input .bin to eval validation loss on") @@ -704,7 +702,7 @@ def print0(*args, **kwargs): # set up a context manager following the desired dtype and device ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype] - ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if (device_type == "cuda" and not args.llama3) else nullcontext() + ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if (device_type == "cuda") else nullcontext() # rng / reproducibility torch.manual_seed(42) @@ -720,30 +718,10 @@ def print0(*args, **kwargs): assert args.flash in {0, 1} FLASH = args.flash - # init (and write) the tokenizer - enc = tiktoken.get_encoding("gpt2") - if master_process and args.write_tensors: # tokenizer is technically not tensors but ok - # write_tokenizer(enc, "gpt2_tokenizer.bin") - pass - - # init the model, either from scratch or from OpenAI pretrained checkpoint - if args.model[0] == "d": - # from scratch (random weights) - model_config = { - "d12": GPTConfig(block_size=1024, vocab_size=50257, n_layer=12, n_head=12, n_embd=768), - "d24": GPTConfig(block_size=1024, vocab_size=50257, n_layer=24, n_head=16, n_embd=1024), - "d36": 
GPTConfig(block_size=1024, vocab_size=50257, n_layer=36, n_head=20, n_embd=1280), - "d48": GPTConfig(block_size=1024, vocab_size=50257, n_layer=48, n_head=25, n_embd=1600), - }[args.model] - model = LLaMA(model_config) - else: - if args.llama3: - assert args.llama3_ckpt_dir is not None and os.path.exists(args.llama3_ckpt_dir), f"llama3 ckpt dir {args.llama3_ckpt_dir} does not exist" - assert args.llama3_tokenizer_path is not None and os.path.exists(args.llama3_tokenizer_path), f"llama3 tokenizer path {args.llama3_tokenizer_path} does not exist" - model = LLaMA.from_pretrained_llama3(args.llama3_ckpt_dir, args.llama3_tokenizer_path) - else: - # load the GPT-2 model weights - model = LLaMA.from_pretrained(args.model) + # init the model + assert args.ckpt_dir is not None and os.path.exists(args.ckpt_dir), f"llama3 ckpt dir {args.ckpt_dir} does not exist" + assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f"llama3 tokenizer path {args.tokenizer_path} does not exist" + model = LLaMA.from_pretrained_llama3(args.ckpt_dir, args.tokenizer_path) model.train() model.to(device) @@ -764,19 +742,18 @@ def print0(*args, **kwargs): # ------------------------------------------------------------------------- # LLaMA 3 inference - if args.llama3: - model.eval() - prompts: List[str] = json.loads(open(os.path.join(os.path.dirname(__file__), 'llmc_py', 'prompts.json')).read())['prompts'] - prompt_tokens = [model.tokenizer.encode(x, bos=True, eos=False) for x in prompts] - - generation_tokens, _ = model.generate(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, logprobs=False, echo=False) - results = [{"generation": model.tokenizer.decode(t)} for t in generation_tokens] - for prompt, result in zip(prompts, results): - print(prompt, end="") - print(f"{result['generation']}") - print("\n==================================\n") - - exit(0) # only inference supported for now + model.eval() + prompts: List[str] = json.loads(open(os.path.join(os.path.dirname(__file__), 'llmc_py', 'prompts.json')).read())['prompts'] + prompt_tokens = [model.tokenizer.encode(x, bos=True, eos=False) for x in prompts] + + generation_tokens, _ = model.generate(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, logprobs=False, echo=False) + results = [{"generation": model.tokenizer.decode(t)} for t in generation_tokens] + for prompt, result in zip(prompts, results): + print(prompt, end="") + print(f"{result['generation']}") + print("\n==================================\n") + + exit(0) # only inference supported for now # ------------------------------------------------------------------------- # PyTorch -> C bridge: save some weights and state for C to load later as reference From bad7857dff4d3b1eb038d55d4fddd44c6398d33d Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Sat, 3 Aug 2024 22:37:54 +0200 Subject: [PATCH 18/36] Done - need to test train loop and saving model --- train_gpt2.py | 58 +++++++++++++++++++-------------------------------- 1 file changed, 21 insertions(+), 37 deletions(-) diff --git a/train_gpt2.py b/train_gpt2.py index 4d8270808..2961bc716 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -45,7 +45,7 @@ from llmc_py.utils import repeat_kv, sample_top_p, RMSNorm # ----------------------------------------------------------------------------- -# PyTorch nn.Module definitions for the GPT-2 model +# PyTorch nn.Module definitions for the LLaMA 3.x model # using a global to toggle flash-attention FLASH = 0 @@ -564,7 +564,8 @@ def write_model(model, filename, dtype): header[11] = 
model.config.rope_theta header[12] = model.config.use_scaled_rope header[13] = model.config.max_gen_batch_size - header[14] = model.version + header[14] = int(model.config.version.split('.')[0]) # major version + header[15] = int(model.config.version.split('.')[1]) # minor version # 2) the parameters follow the header params = {name: param.cpu() for name, param in model.named_parameters()} # now write to file @@ -620,7 +621,7 @@ def print0(*args, **kwargs): parser.add_argument("--input_bin", type=str, default="dev/data/tinyshakespeare/tiny_shakespeare_val.bin", help="input .bin to train on") parser.add_argument("--input_val_bin", type=str, default="", help="input .bin to eval validation loss on") parser.add_argument("--output_dir", type=str, default="", help="output directory to which to write logs and checkpoints") - parser.add_argument("--model", type=str, default="gpt2", help="gpt2|gpt2-medium|gpt2-large|gpt2-xl|d12|d24|d36|d48") + parser.add_argument("--model", type=str, default="llama3.1", help="llama3.1") # token layout for each step of the optimization parser.add_argument("--batch_size", type=int, default=4, help="batch size, in units of #batch dimensions") parser.add_argument("--sequence_length", type=int, default=64, help="sequence length") @@ -656,7 +657,7 @@ def print0(*args, **kwargs): B, T = args.batch_size, args.sequence_length assert 1 <= T <= 1024 assert args.dtype in {"float32", "float16", "bfloat16"} - assert args.model in {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl", "d12", "d24", "d36", "d48"} + assert args.model in {"llama3.1"} # set up DDP (distributed data parallel). torchrun sets this env variable ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? @@ -740,21 +741,6 @@ def print0(*args, **kwargs): if args.input_val_bin: val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size) - # ------------------------------------------------------------------------- - # LLaMA 3 inference - model.eval() - prompts: List[str] = json.loads(open(os.path.join(os.path.dirname(__file__), 'llmc_py', 'prompts.json')).read())['prompts'] - prompt_tokens = [model.tokenizer.encode(x, bos=True, eos=False) for x in prompts] - - generation_tokens, _ = model.generate(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, logprobs=False, echo=False) - results = [{"generation": model.tokenizer.decode(t)} for t in generation_tokens] - for prompt, result in zip(prompts, results): - print(prompt, end="") - print(f"{result['generation']}") - print("\n==================================\n") - - exit(0) # only inference supported for now - # ------------------------------------------------------------------------- # PyTorch -> C bridge: save some weights and state for C to load later as reference @@ -762,17 +748,16 @@ def print0(*args, **kwargs): if master_process and args.write_tensors and (not args.inference_only): x, y = train_loader.next_batch() x, y = x.to(device), y.to(device) - logits, loss = model(x, y) - loss.backward() + logits, loss = model(x, y, start_pos=0) + # loss.backward() # save model params, in both float32 and bfloat16 - model_to_size = {"gpt2": "124M", "gpt2-medium": "355M", "gpt2-large": "774M", "gpt2-xl": "1558M"} - model_to_size.update({f"d{d}": f"d{d}" for d in [12, 24, 36, 48]}) - model_size_str = model_to_size[args.model] # e.g. 
"124M", or "d12" - write_model(model, f"gpt2_{model_size_str}.bin", dtype="float32") - write_model(model, f"gpt2_{model_size_str}_bf16.bin", dtype="bfloat16") + model_to_size = {"llama3.1": "8B"} + model_size_str = model_to_size[args.model] # e.g. "8B" + write_model(model, f"llama3.1_{model_size_str}.bin", dtype="float32") + write_model(model, f"llama3.1_{model_size_str}_bf16.bin", dtype="bfloat16") # save x, y, logits, loss, and parameter gradients, for debugging C # always store these in fp32 to have an accurate reference (?) - write_state(model, x, y, logits, loss, f"gpt2_{model_size_str}_debug_state.bin") + write_state(model, x, y, logits, loss, f"llama3_{model_size_str}_debug_state.bin") # reset the train_loader for the optimization below train_loader.reset() @@ -846,16 +831,15 @@ def get_lr(it): and (step % args.sample_every == 0 or last_step)) \ and master_process: model.eval() - # before we end, let's also do one round of inference - # we'll kick off the generation with "<|endoftext|>", which designates the start of a new sequence - start_ids = [[enc.eot_token]] - enc.pad_id = enc.eot_token # dummy value, it's a no-op - enc.stop_tokens = [-1] # dummy value, we don't stop early - raw_model.tokenizer = enc - yg = raw_model.generate(start_ids, max_gen_len=32, temperature=1.0, top_p=0.9) - print0('---------------') - print0(enc.decode(yg[0][0])) - print0('---------------') + prompts: List[str] = json.loads(open(os.path.join(os.path.dirname(__file__), 'llmc_py', 'prompts.json')).read())['prompts'] + prompt_tokens = [model.tokenizer.encode(x, bos=True, eos=False) for x in prompts] + + generation_tokens, _ = model.generate(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, logprobs=False, echo=False) + results = [{"generation": model.tokenizer.decode(t)} for t in generation_tokens] + for prompt, result in zip(prompts, results): + print(prompt, end="") + print(f"{result['generation']}") + print("\n==================================\n") # bit confusing: we want to make sure to eval and sample on 0th iteration # but also after the very last iteration. 
so we loop for step <= num_iterations From 879cc5f46cfbbfca4ce55b8a669f1bd338e2d437 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Sun, 4 Aug 2024 08:00:44 +0000 Subject: [PATCH 19/36] Remove init weights as it's gpt-2 specific --- train_gpt2.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/train_gpt2.py b/train_gpt2.py index 2961bc716..a4c1a9ba1 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -180,7 +180,6 @@ def __init__(self, config): # init all weights, use a torch rng object to be very careful self.init_rng = torch.Generator() self.init_rng.manual_seed(42) - # self.apply(self._init_weights) self.freqs_cis = precompute_freqs_cis( config.n_embd // config.n_head, @@ -189,19 +188,6 @@ def __init__(self, config): config.use_scaled_rope, ) - def _init_weights(self, module): - if isinstance(module, nn.Linear): - # apply special scaled init to the residual projections, per GPT-2 paper - std = 0.02 if not hasattr(module, 'LLMC_RESIDUAL_SCALE_FLAG') else 0.02/math.sqrt(2 * self.config.n_layer) - # we want to skip initializing lm_head, which shares parameters with wte - # and wte was already initialized down below during the Embedding init - if not hasattr(module, 'LLMC_SKIP_INIT'): - torch.nn.init.normal_(module.weight, mean=0.0, std=std, generator=self.init_rng) - if module.bias is not None: - torch.nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02, generator=self.init_rng) - def forward(self, idx, targets=None, return_logits=True, start_pos=-1): device = idx.device b, t = idx.size() From 7768a36f37bc47efedefe7c35a11e170650b0ba5 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Sun, 4 Aug 2024 10:07:48 +0200 Subject: [PATCH 20/36] Add prompts file --- llmc_py/prompts.json | 8 ++++++++ train_gpt2.py | 1 - 2 files changed, 8 insertions(+), 1 deletion(-) create mode 100644 llmc_py/prompts.json diff --git a/llmc_py/prompts.json b/llmc_py/prompts.json new file mode 100644 index 000000000..b089bb602 --- /dev/null +++ b/llmc_py/prompts.json @@ -0,0 +1,8 @@ +{ + "prompts": [ + "Clearly, the meaning of life is", + "Simply put, the theory of relativity states that", + "The repo llm.c on GitHub is", + "Translate English to French:\n\nsea otter => loutre de mer\npeppermint => menthe poivrée\nplush girafe => girafe peluche\ncheese =>" + ] + } \ No newline at end of file diff --git a/train_gpt2.py b/train_gpt2.py index a4c1a9ba1..1c5904513 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -18,7 +18,6 @@ import os import math import glob -import struct import inspect from contextlib import nullcontext from dataclasses import dataclass From cd902735b29877cf5c567b2e5760172fe6719726 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Sun, 4 Aug 2024 08:54:21 +0000 Subject: [PATCH 21/36] Fix saving model / state logic --- train_gpt2.py | 43 ++++++++++++++++++++++--------------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/train_gpt2.py b/train_gpt2.py index 1c5904513..f84084205 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -259,9 +259,11 @@ def from_pretrained_llama3(cls, ckpt_dir, tokenizer_path): checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True) checkpoint = LLaMA.adapt_llama_state_dict_keys(checkpoint, model_args) - torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) + original_default_type = torch.get_default_dtype() # save the default type + torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) # much faster loading model = LLaMA(model_args) 
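        # with the default tensor type switched to torch.cuda.BFloat16Tensor above, the fresh
        # LLaMA module's parameters are created directly on the GPU in bf16, so the load below
        # copies the checkpoint straight into them instead of first materializing a full fp32
        # model on the CPU (hence the "much faster loading" note above); the original default
        # is restored right after so tensors created later use the usual dtype/device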
model.load_state_dict(checkpoint, strict=False) + torch.set_default_tensor_type(torch.tensor([], dtype=original_default_type, device="cpu").type()) # restore default type tokenizer = Tokenizer(model_path=tokenizer_path) model.tokenizer = tokenizer @@ -529,10 +531,10 @@ def write_tensors(model_tensors, L, file, dtype): def write_model(model, filename, dtype): # everything we need to instantiate the model # 1) header is: version int, LLaMAConfig ints, padding to 1024 bytes - assert dtype in {"float32", "bfloat16"} # float16 todo maybe later + assert dtype in {"float32", "bfloat16"} version = { - "float32": 3, # 3: all tensors are fp32, padded vocab - "bfloat16": 5, # 5: all tensors are bf16, padded vocab + "float32": 3, # 3: all tensors are fp32 + "bfloat16": 5, # 5: all tensors are bf16 }[dtype] header = torch.zeros(256, dtype=torch.int32) header[0] = 20240803 # magic @@ -632,10 +634,10 @@ def print0(*args, **kwargs): parser.add_argument("--device", type=str, default="", help="by default we autodetect, or set it here") parser.add_argument("--compile", type=int, default=0, help="torch.compile the model") parser.add_argument("--flash", type=int, default=0, help="use flash attention") - parser.add_argument("--dtype", type=str, default="float32", help="float32|float16|bfloat16") + parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|float16|bfloat16") parser.add_argument("--zero_stage", type=int, default=0, help="zero redundancy optimizer stage (0/1/2/3)") # python -> C bridge - parser.add_argument("--write_tensors", type=int, default=1, help="write tensors to disk") + parser.add_argument("--write_tensors", type=int, default=0, help="write tensors to disk") args = parser.parse_args() # args error checking and convenience variables @@ -644,6 +646,15 @@ def print0(*args, **kwargs): assert args.dtype in {"float32", "float16", "bfloat16"} assert args.model in {"llama3.1"} + # create the logging directory if it does not exist + logfile = None + if args.output_dir: + os.makedirs(args.output_dir, exist_ok=True) + logfile = os.path.join(args.output_dir, "main.log") + # create the log file "main.log" inside it, and wipe it clean + with open(logfile, "w") as f: + pass + # set up DDP (distributed data parallel). torchrun sets this env variable ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? if ddp: @@ -678,6 +689,7 @@ def print0(*args, **kwargs): device = "mps" print(f"using device: {device}") device_type = 'cuda' if 'cuda' in device else 'cpu' + assert device_type in {'cuda'} # we need to load LLaMA as bf16 on CUDA # calculate gradient accumulation from the desired total batch size and the current run configuration tokens_per_fwdbwd = B * T * ddp_world_size @@ -710,7 +722,6 @@ def print0(*args, **kwargs): model = LLaMA.from_pretrained_llama3(args.ckpt_dir, args.tokenizer_path) model.train() - model.to(device) if args.compile: if hasattr(config, "coordinate_descent_tuning"): config.coordinate_descent_tuning = True # suggested by @Chillee @@ -734,15 +745,14 @@ def print0(*args, **kwargs): x, y = train_loader.next_batch() x, y = x.to(device), y.to(device) logits, loss = model(x, y, start_pos=0) - # loss.backward() - # save model params, in both float32 and bfloat16 + loss.backward() + # save model params, in bfloat16 model_to_size = {"llama3.1": "8B"} model_size_str = model_to_size[args.model] # e.g. 
"8B" - write_model(model, f"llama3.1_{model_size_str}.bin", dtype="float32") - write_model(model, f"llama3.1_{model_size_str}_bf16.bin", dtype="bfloat16") + write_model(model, os.path.join(args.output_dir, f"llama3.1_{model_size_str}_bf16.bin"), dtype="bfloat16") # save x, y, logits, loss, and parameter gradients, for debugging C # always store these in fp32 to have an accurate reference (?) - write_state(model, x, y, logits, loss, f"llama3_{model_size_str}_debug_state.bin") + write_state(model, x, y, logits, loss, os.path.join(args.output_dir, f"llama3_{model_size_str}_debug_state.bin")) # reset the train_loader for the optimization below train_loader.reset() @@ -774,15 +784,6 @@ def get_lr(it): coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff starts at 1 and goes to 0 return min_lr + coeff * (args.learning_rate - min_lr) - # create the logging directory if it does not exist - logfile = None - if args.output_dir: - os.makedirs(args.output_dir, exist_ok=True) - logfile = os.path.join(args.output_dir, "main.log") - # create the log file "main.log" inside it, and wipe it clean - with open(logfile, "w") as f: - pass - if device == "cuda": torch.cuda.reset_peak_memory_stats() timings = [] From 4b386a2a8902804b074652dfd9c0e3f58f517162 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Sun, 4 Aug 2024 09:13:29 +0000 Subject: [PATCH 22/36] Test training loop works --- train_gpt2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_gpt2.py b/train_gpt2.py index f84084205..01bd923c2 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -81,7 +81,7 @@ def forward(self, x, freqs_cis=None, start_pos=None): q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) # rotate QK (rope) - if start_pos >= 0: # kv-caching (which we can disable by setting start_pos = -1) + if not self.training and start_pos >= 0: # use kv-caching during inference self.cache_k[:B, start_pos : start_pos + T] = k self.cache_v[:B, start_pos : start_pos + T] = v k = self.cache_k[:B, : start_pos + T] @@ -187,7 +187,7 @@ def __init__(self, config): config.use_scaled_rope, ) - def forward(self, idx, targets=None, return_logits=True, start_pos=-1): + def forward(self, idx, targets=None, return_logits=True, start_pos=0): device = idx.device b, t = idx.size() assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" @@ -744,7 +744,7 @@ def print0(*args, **kwargs): if master_process and args.write_tensors and (not args.inference_only): x, y = train_loader.next_batch() x, y = x.to(device), y.to(device) - logits, loss = model(x, y, start_pos=0) + logits, loss = model(x, y) loss.backward() # save model params, in bfloat16 model_to_size = {"llama3.1": "8B"} From 0749a4af428a8c100855d8b52413343be2bcbc69 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Sun, 4 Aug 2024 09:28:17 +0000 Subject: [PATCH 23/36] Minor refactor - remove wpe pos array from fwd --- train_gpt2.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/train_gpt2.py b/train_gpt2.py index 01bd923c2..9f16f1313 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -2,6 +2,11 @@ Reference code for LLaMA-3.1 training and inference. Will save the model weights into files, to be read from C as initialization. +This code differs from GPT-2 very slightly, there are three main differences: +1) RoPE: LLaMA uses a different positional encoding scheme called Relative Positional Encoding (RoPE). 
+2) GQA: Grouped Query Attention (GQA) is used to reduce the number of attention heads. +3) SwiGLU: Swish-Gated Linear Unit (SwiGLU) is used as the activation function in the MLP. + References: # 1) https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/tokenizer.py # 2) https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py @@ -9,10 +14,10 @@ Example launches to only benchmark the speed of bfloat16 compiled GPU training: 1 GPU: -python train_llama3.py --write_tensors=0 --num_iterations=50 --sequence_length=1024 --compile=1 --tensorcores=1 --dtype=bfloat16 +python train_llama3.py --write_tensors=0 --num_iterations=50 --sequence_length=8192 --compile=1 --tensorcores=1 --dtype=bfloat16 you can also turn on flash-attention by appending --flash=1 4 GPU: -torchrun --standalone --nproc_per_node=4 train_llama3.py --write_tensors=0 --num_iterations=50 --sequence_length=1024 --compile=1 --tensorcores=1 --dtype=bfloat16 +torchrun --standalone --nproc_per_node=4 train_llama3.py --write_tensors=0 --num_iterations=50 --sequence_length=8192 --compile=1 --tensorcores=1 --dtype=bfloat16 """ import os @@ -59,14 +64,16 @@ def __init__(self, config): self.n_kv_head = config.n_kv_head self.n_rep = self.n_head // self.n_kv_head self.hd = config.n_embd // config.n_head + self.use_kv = config.use_kv self.c_attn = nn.Linear(config.n_embd, (config.n_head + 2 * config.n_kv_head) * self.hd, bias=False) # key, query, value projections self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) # output projection self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 - # static KV cache - self.cache_k = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd)) - self.cache_v = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd)) + # static KV cache - we could alternatively allocate it outside of the model and just pass it in when needed + if self.use_kv: + self.cache_k = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd)) + self.cache_v = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd)) # not really a 'bias', more of a mask, but following the OpenAI/HF naming though self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) @@ -79,15 +86,15 @@ def forward(self, x, freqs_cis=None, start_pos=None): q, k, v = qkv.split([self.n_head * self.hd, self.n_kv_head * self.hd, self.n_kv_head * self.hd], dim=-1) q, k, v = map(lambda t: t.view(B, T, -1, self.hd), (q, k, v)) # (B, T, NH, HD) - q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) # rotate QK (rope) + q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) # rotate QK (rope) <-- 1. difference compared to GPT-2 - if not self.training and start_pos >= 0: # use kv-caching during inference + if self.use_kv and not self.training and start_pos >= 0: # use kv-caching during inference self.cache_k[:B, start_pos : start_pos + T] = k self.cache_v[:B, start_pos : start_pos + T] = v k = self.cache_k[:B, : start_pos + T] v = self.cache_v[:B, : start_pos + T] - k = repeat_kv(k, self.n_rep) # GQA + k = repeat_kv(k, self.n_rep) # GQA <-- 2. 
difference compared to GPT-2 v = repeat_kv(v, self.n_rep) q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v)) # (B, NH, T, HD) @@ -122,7 +129,7 @@ def __init__(self, config): self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 def forward(self, x): - # SwiGLU self.c_proj(F.silu(self.c_fc2(x)) * self.c_fc(x)) + # SwiGLU self.c_proj(F.silu(self.c_fc2(x)) * self.c_fc(x)) <-- 3. difference compared to GPT-2 x1 = self.c_fc(x) x2 = self.c_fc2(x) x2 = F.silu(x2) @@ -150,7 +157,7 @@ def forward(self, x, freqs_cis=None, start_pos=None): @dataclass class LlamaConfig: version: str = "3.1" - block_size: int = 1024 + block_size: int = 8192 vocab_size: int = 128256 n_layer: int = 32 n_head: int = 32 @@ -162,6 +169,7 @@ class LlamaConfig: rope_theta: float = 500000.0 use_scaled_rope: bool = True max_gen_batch_size: int = 4 + use_kv: bool = True class LLaMA(nn.Module): @@ -188,10 +196,8 @@ def __init__(self, config): ) def forward(self, idx, targets=None, return_logits=True, start_pos=0): - device = idx.device - b, t = idx.size() + _, t = idx.size() assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" - pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) # forward the LLaMA model itself x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) @@ -642,7 +648,7 @@ def print0(*args, **kwargs): # args error checking and convenience variables B, T = args.batch_size, args.sequence_length - assert 1 <= T <= 1024 + assert 1 <= T <= 8192, "sequence length must be between 1 and 8192" assert args.dtype in {"float32", "float16", "bfloat16"} assert args.model in {"llama3.1"} From 8e55d168630d150bb9c41d5cae8e1d77f1b0e7d3 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Sun, 4 Aug 2024 16:27:12 +0200 Subject: [PATCH 24/36] Support HF & Meta models --- train_gpt2.py | 85 +++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 79 insertions(+), 6 deletions(-) diff --git a/train_gpt2.py b/train_gpt2.py index 9f16f1313..1648aa3e7 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -256,8 +256,71 @@ def adapt_llama_state_dict_keys(checkpoint, config: LlamaConfig): return checkpoint + @staticmethod + def adapt_llama_state_dict_keys_hf(checkpoint, config: LlamaConfig): + checkpoint['transformer.wte.weight'] = checkpoint.pop('model.embed_tokens.weight') + + # We need to unpermute K and V because HF script permuted the original Meta-LLaMA weights + # see: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py + def unpermute(w, n_heads, dim1, dim2): + return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + for i in range(config.n_layer): + for name in ['input_layernorm', 'post_attention_layernorm']: + old_key = f'model.layers.{i}.{name}.weight' # e.g. 
layers.x.attention_norm.weight -> transformer.h.x.ln_1.weight + new_key = f'transformer.h.{i}.ln_{1 if name == "input_layernorm" else 2}.weight' + checkpoint[new_key] = checkpoint.pop(old_key) + + for i in range(config.n_layer): + for name in ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj']: + old_key = f'model.layers.{i}.{name}.weight' + new_key = f'transformer.h.{i}.attn.c_attn.weight' + if name == 'self_attn.q_proj': + checkpoint[new_key] = unpermute(checkpoint.pop(old_key), config.n_head, config.n_embd, config.n_embd) + else: # merge 3 weights into transformer.h.x.attn.c_attn.weight + tensor = checkpoint.pop(old_key) + if name == 'self_attn.k_proj': + tensor = unpermute(tensor, config.n_kv_head, config.n_kv_head * (config.n_embd // config.n_head), config.n_embd) + checkpoint[new_key] = torch.cat((checkpoint[new_key], tensor), dim=0) + old_key = f'model.layers.{i}.self_attn.o_proj.weight' + new_key = f'transformer.h.{i}.attn.c_proj.weight' + checkpoint[new_key] = checkpoint.pop(old_key) + + ffn_map = {'gate_proj': 'c_fc2', 'down_proj': 'c_proj', 'up_proj': 'c_fc'} + for i in range(config.n_layer): + for name in ['gate_proj', 'down_proj', 'up_proj']: + old_key = f'model.layers.{i}.mlp.{name}.weight' + new_key = f'transformer.h.{i}.mlp.{ffn_map[name]}.weight' + checkpoint[new_key] = checkpoint.pop(old_key) + + checkpoint['transformer.ln_f.weight'] = checkpoint.pop('model.norm.weight') + + return checkpoint + + @classmethod + def from_pretrained_llama3_hf(cls, model_id): + """Loads pretrained LLaMA model weights from HuggingFace""" + from transformers import AutoModelForCausalLM, AutoTokenizer + assert model_id == "meta-llama/Meta-Llama-3.1-8B", "Only the 8B-bae model is supported for now" + model_args = LlamaConfig() + + model = AutoModelForCausalLM.from_pretrained(model_id) + checkpoint = LLaMA.adapt_llama_state_dict_keys_hf(model.state_dict(), model_args) + + original_default_type = torch.get_default_dtype() # save the default type + torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) # much faster loading + model = LLaMA(model_args) + model.load_state_dict(checkpoint, strict=False) + torch.set_default_tensor_type(torch.tensor([], dtype=original_default_type, device="cpu").type()) # restore default type + + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_id = 128004 # this is the pad token id for LLaMA 3.1 base, we need to set this explicitly as our generate func expects it + tokenizer.stop_tokens = [tokenizer.eos_token_id] + model.tokenizer = tokenizer + return model + @classmethod - def from_pretrained_llama3(cls, ckpt_dir, tokenizer_path): + def from_pretrained_llama3_meta(cls, ckpt_dir, tokenizer_path): """Loads pretrained LLaMA model weights from a checkpoint directory""" model_args = LlamaConfig() @@ -272,6 +335,9 @@ def from_pretrained_llama3(cls, ckpt_dir, tokenizer_path): torch.set_default_tensor_type(torch.tensor([], dtype=original_default_type, device="cpu").type()) # restore default type tokenizer = Tokenizer(model_path=tokenizer_path) + # add <|end_of_text|> as the stop token for base model - this is an omission in the reference code + # the reference code only adds instruct model stop tokens... 
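+        # 128001 is the id of <|end_of_text|>: the base vocabulary has 128000 BPE tokens
+        # (ids 0..127999) and the special tokens are appended after them, starting with
+        # <|begin_of_text|> at 128000 and <|end_of_text|> at 128001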
+ tokenizer.stop_tokens = tokenizer.stop_tokens + [128001] model.tokenizer = tokenizer return model @@ -608,13 +674,14 @@ def print0(*args, **kwargs): # default settings will overfit a tiny batch of data # and save model weights and debug state to disk on the first iteration parser = argparse.ArgumentParser() + parser.add_argument("--use_hf", type=int, default=1, help="use HuggingFace (default) or use Meta's model") parser.add_argument("--ckpt_dir", type=str, default=None, help="path to llama3 model checkpoint") parser.add_argument("--tokenizer_path", type=str, default=None, help="path to llama3 tokenizer") # file system input / output parser.add_argument("--input_bin", type=str, default="dev/data/tinyshakespeare/tiny_shakespeare_val.bin", help="input .bin to train on") parser.add_argument("--input_val_bin", type=str, default="", help="input .bin to eval validation loss on") parser.add_argument("--output_dir", type=str, default="", help="output directory to which to write logs and checkpoints") - parser.add_argument("--model", type=str, default="llama3.1", help="llama3.1") + parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3.1-8B", help="chose the llama model") # token layout for each step of the optimization parser.add_argument("--batch_size", type=int, default=4, help="batch size, in units of #batch dimensions") parser.add_argument("--sequence_length", type=int, default=64, help="sequence length") @@ -650,7 +717,7 @@ def print0(*args, **kwargs): B, T = args.batch_size, args.sequence_length assert 1 <= T <= 8192, "sequence length must be between 1 and 8192" assert args.dtype in {"float32", "float16", "bfloat16"} - assert args.model in {"llama3.1"} + assert args.model in {"meta-llama/Meta-Llama-3.1-8B"} # only 8B base model supported for now # create the logging directory if it does not exist logfile = None @@ -725,7 +792,10 @@ def print0(*args, **kwargs): # init the model assert args.ckpt_dir is not None and os.path.exists(args.ckpt_dir), f"llama3 ckpt dir {args.ckpt_dir} does not exist" assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f"llama3 tokenizer path {args.tokenizer_path} does not exist" - model = LLaMA.from_pretrained_llama3(args.ckpt_dir, args.tokenizer_path) + if args.use_hf: + model = LLaMA.from_pretrained_llama3_hf(args.model) + else: # use Meta's checkpoint + model = LLaMA.from_pretrained_llama3_meta(args.ckpt_dir, args.tokenizer_path) model.train() if args.compile: @@ -753,7 +823,7 @@ def print0(*args, **kwargs): logits, loss = model(x, y) loss.backward() # save model params, in bfloat16 - model_to_size = {"llama3.1": "8B"} + model_to_size = {"meta-llama/Meta-Llama-3.1-8B": "8B"} model_size_str = model_to_size[args.model] # e.g. 
"8B" write_model(model, os.path.join(args.output_dir, f"llama3.1_{model_size_str}_bf16.bin"), dtype="bfloat16") # save x, y, logits, loss, and parameter gradients, for debugging C @@ -824,7 +894,10 @@ def get_lr(it): and master_process: model.eval() prompts: List[str] = json.loads(open(os.path.join(os.path.dirname(__file__), 'llmc_py', 'prompts.json')).read())['prompts'] - prompt_tokens = [model.tokenizer.encode(x, bos=True, eos=False) for x in prompts] + if args.use_hf: + prompt_tokens = [model.tokenizer(x).input_ids for x in prompts] + else: # Meta + prompt_tokens = [model.tokenizer.encode(x, bos=True, eos=False) for x in prompts] generation_tokens, _ = model.generate(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, logprobs=False, echo=False) results = [{"generation": model.tokenizer.decode(t)} for t in generation_tokens] From 72dcfeb404e8708c4118822aa0e416942b76d456 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Sun, 4 Aug 2024 22:20:46 +0200 Subject: [PATCH 25/36] Remove float(-inf) --- train_gpt2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_gpt2.py b/train_gpt2.py index 1648aa3e7..4eea1f46e 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -106,7 +106,7 @@ def forward(self, x, freqs_cis=None, start_pos=None): # manual implementation of attention # this materializes the large (T,T) matrix for all the queries and keys scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.hd)) - scores = scores.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + scores = scores.masked_fill(self.bias[:,:,:T,:T] == 0, torch.finfo(scores.dtype).min) att = F.softmax(scores.float(), dim=-1).type_as(q) y = att @ v # (B, NH, T, T) x (B, NH, T, HD) -> (B, NH, T, HD) y = y.transpose(1, 2).contiguous().view(B, T, C) From d4ef9c5afdd9c5b2ffcba9b03077535045ee4ad7 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Thu, 8 Aug 2024 17:09:41 +0200 Subject: [PATCH 26/36] Remove llmc_py, single file --- llmc_py/__pycache__/rope.cpython-310.pyc | Bin 0 -> 2167 bytes llmc_py/__pycache__/tokenizer.cpython-310.pyc | Bin 0 -> 5372 bytes llmc_py/__pycache__/utils.cpython-310.pyc | Bin 0 -> 2245 bytes llmc_py/rope.py | 59 ---- llmc_py/tokenizer.py | 173 ---------- llmc_py/utils.py | 57 ---- train_gpt2.py | 298 +++++++++++++++++- 7 files changed, 295 insertions(+), 292 deletions(-) create mode 100644 llmc_py/__pycache__/rope.cpython-310.pyc create mode 100644 llmc_py/__pycache__/tokenizer.cpython-310.pyc create mode 100644 llmc_py/__pycache__/utils.cpython-310.pyc diff --git a/llmc_py/__pycache__/rope.cpython-310.pyc b/llmc_py/__pycache__/rope.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8797638e54d73f9b72c780cf7c209e54840afb2c GIT binary patch literal 2167 zcmaJ?OK)5?6t;bRA2XRenx<_^t12J?8VE_%E)Y@`Dp7U8A`9qhL^Zk@doq)`Gk4;9 z(EU{vNvSZCFe8)G_2SRvsj(v}h?c?*E&y&y2)*0H* zKYic*BVp_p`gpm~>`h4V4U}SvXRJlN$oNpS1g~sCD^P(BcBGP|yHbTJdMI0=iv71x ziK^i{Qg!b&<8`&7p#^I;1=8@xi*u$9ipi;OaNBeo=Pv z;do<)LZED|vd^1F7rk_>+uhu>w@jX@PFj>}H?q8wX2k{u(^XAr8d&2(@9%1Qv^Ufy zEwz*P`uct|Fmy#jSvT{fC@e!~nOP`z?mtf=r-I5Z;+OdqzRc^~ERuWZv3(_lv{wZA z2vYn2rT7GJO@s<2fszwxr4^UhBt##ekG9z)w(J(WyD>>rSOVZBW)sZRY|TakX5O~( zp|HsjN?>;hws0WL^~z@XC3b2L;Tsem_#%fIuaoe&bK=y3o*meWTv+C&({oX&M@4ST1=?(xR?iqJG1n*nAYe^9aBR(I`TaI!bZS`W^|yPk zSmyIgwLAH!)ca*S)1%hH^ELO=dzvg~UM!lLY9n23TsTSlx8mtcR1EXH>{VspYr&Vn zEaJ}&+7w-Xv^|GVi?^Vpkeq&rNJ#TJNCUFK8=}r5E>Bj1C2m&eYGQvEGH{X6<(?UJ zsP3kX4kGb2@+}4>I|Q{4#Sw@~l*6lqy7>|$|BfsE6!r6zf5M&st4ZK}`Ey_XKfbaI 
z<5Lzz9QF@zvX8lOfWN>G;Mn5E0qub^5)y){wm0jIM31JY`cd6q)mK$hw%aWN?T6of z8hqa%vF~Vp@`uMdnlG8q=!sbdnajL0(sx(` z=3M4;k2O&mc&x?N&YXUO(OnX*U%;897gB8%MS}8bbC(181qhtz#N z6&XMaL1Z%JbSv^D@Fdt`JX?}{`%#|C)#Fdu5Dh) z#9p}&WxNx=@>A#(c}Zq;PAI9ZnnaeeXO41bUQOrlz)XXYUGkdtLSMUbQgNaEC`yY| zMUg*xJr-3JMdj}@0$Hl0leG*QR2CG>1q27YFxPn8EJ9gHIVOs zP?H(JMuO%58%dd3NYuziuma5s9q|Etv^&VkSm}m3sdyAi(Z)5~z!JA{gI7*eHBoN^ zYz#_q6bpuEC0^d4EqZ$666A6eNnc-2!)xovF)QaHIPOqJC!A*2pn>-nli0YzHk^Wew(o zmIpE8)i$HGH@Bx=?Lmw*<6kKRVE}E13-C}LsIvO33bu+#lJSWQ0~zO4mKFy=Sp;!V ziE=;QPqS2=1T2+Gr27!qWzlV!L%tRV@}ZINNZc%gEt5vwJzmICgb#zA|2ZO1W=Ho* zJ=ijl@Rp7m3ND9b#sa{*gnJ2c3F8@86H%-nS%vU0*hP~#1Klw%g0x_1!sSx6S1NvG z!Jcu(<>>)rK2MYk)Q|@50Xhf$zqir@@pQ;p-8+jj~^$T zv&CUoHdWnV4^Ch8ZIH@9a7?i{OHTp!ek`E}J`4_n>Czf713fP~5J!8E%UGs#i*cE7 zFta#-=*Ou8cuiYw5%0n10z!d6p;#*>d1P$nGHfi0g#+9v?SWl#l52;hhmp2XTmwT@ zT9kPjXWBc60g?zC+BMv3TJ>8-icr7bT*xUqz7)-A;Fhs&75rF5hLV+RGL>h20R|SU zsJaOtU&fWgxZ>L5$5mV~-N2ZKiO@3^iU+Vw zTl*6JMj0=_dC~YV8uFM)?c*nc)8BSua|vU7cptPSmiZ0K2hRo82Hm9AKjKT63Vp-) y;&13ZHD%EhE#qJ9E%I1=3|r0qm0L^^G>rL=7WMQAhLps=o$svEj_q&Ux&0p@@;XES literal 0 HcmV?d00001 diff --git a/llmc_py/rope.py b/llmc_py/rope.py index 3caf58073..e69de29bb 100644 --- a/llmc_py/rope.py +++ b/llmc_py/rope.py @@ -1,59 +0,0 @@ -# From: https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py - -import math -from typing import Tuple -import torch - -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): - ndim = x.ndim - assert 0 <= 1 < ndim - assert freqs_cis.shape == (x.shape[1], x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) - -def apply_scaling(freqs: torch.Tensor): - # Values obtained from grid search - scale_factor = 8 - low_freq_factor = 1 - high_freq_factor = 4 - old_context_len = 8192 # original llama3 length - - low_freq_wavelen = old_context_len / low_freq_factor - high_freq_wavelen = old_context_len / high_freq_factor - new_freqs = [] - for freq in freqs: - wavelen = 2 * math.pi / freq - if wavelen < high_freq_wavelen: - new_freqs.append(freq) - elif wavelen > low_freq_wavelen: - new_freqs.append(freq / scale_factor) - else: - assert low_freq_wavelen != high_freq_wavelen - smooth = (old_context_len / wavelen - low_freq_factor) / ( - high_freq_factor - low_freq_factor - ) - new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) - return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - -def precompute_freqs_cis( - dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False -): - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device, dtype=torch.float32) - if use_scaled: - freqs = apply_scaling(freqs) - freqs = torch.outer(t, freqs) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis \ No newline at end of file diff --git a/llmc_py/tokenizer.py b/llmc_py/tokenizer.py index 528de113c..2d2a3bd58 100644 --- a/llmc_py/tokenizer.py +++ 
b/llmc_py/tokenizer.py @@ -1,7 +1,5 @@ # From: https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/tokenizer.py -import os -from pathlib import Path from typing import ( AbstractSet, Callable, @@ -16,174 +14,3 @@ cast, ) -import tiktoken -from tiktoken.load import load_tiktoken_bpe - -# The tiktoken tokenizer can handle <=400k chars without -# pyo3_runtime.PanicException. -TIKTOKEN_MAX_ENCODE_CHARS = 400_000 - -# https://github.com/openai/tiktoken/issues/195 -# Here we iterate over subsequences and split if we exceed the limit -# of max consecutive non-whitespace or whitespace characters. -MAX_NO_WHITESPACES_CHARS = 25_000 - - -class Tokenizer: - """ - Tokenizing and encoding/decoding text using the Tiktoken tokenizer. - """ - - special_tokens: Dict[str, int] - - num_reserved_special_tokens = 256 - - pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 - - def __init__(self, model_path: str): - """ - Initializes the Tokenizer with a Tiktoken model. - - Args: - model_path (str): The path to the Tiktoken model file. - """ - assert os.path.isfile(model_path), model_path - - mergeable_ranks = load_tiktoken_bpe(model_path) - num_base_tokens = len(mergeable_ranks) - special_tokens = [ - "<|begin_of_text|>", - "<|end_of_text|>", - "<|reserved_special_token_0|>", - "<|reserved_special_token_1|>", - "<|finetune_right_pad_id|>", - "<|step_id|>", - "<|start_header_id|>", - "<|end_header_id|>", - "<|eom_id|>", # end of message - "<|eot_id|>", # end of turn - "<|python_tag|>", - ] - reserved_tokens = [ - f"<|reserved_special_token_{2 + i}|>" - for i in range(self.num_reserved_special_tokens - len(special_tokens)) - ] - special_tokens = special_tokens + reserved_tokens - - self.special_tokens = { - token: num_base_tokens + i for i, token in enumerate(special_tokens) - } - self.model = tiktoken.Encoding( - name=Path(model_path).name, - pat_str=self.pat_str, - mergeable_ranks=mergeable_ranks, - special_tokens=self.special_tokens, - ) - - self.n_words: int = num_base_tokens + len(special_tokens) - # BOS / EOS token IDs - self.bos_id: int = self.special_tokens["<|begin_of_text|>"] - self.eos_id: int = self.special_tokens["<|end_of_text|>"] - self.eot_id: int = self.special_tokens["<|eot_id|>"] - self.eom_id: int = self.special_tokens["<|eom_id|>"] - self.python_tag_id = self.special_tokens["<|python_tag|>"] - self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"] - self.stop_tokens = [ - self.special_tokens["<|eom_id|>"], - self.special_tokens["<|eot_id|>"], - ] - - def encode( - self, - s: str, - *, - bos: bool, - eos: bool, - allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None, - disallowed_special: Union[Literal["all"], Collection[str]] = (), - ) -> List[int]: - """ - Encodes a string into a list of token IDs. - - Args: - s (str): The input string to be encoded. - bos (bool): Whether to prepend the beginning-of-sequence token. - eos (bool): Whether to append the end-of-sequence token. - allowed_tokens ("all"|set[str]): allowed special tokens in string - disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string - - Returns: - list[int]: A list of token IDs. - - By default, setting disallowed_special=() encodes a string by ignoring - special tokens. Specifically: - - Setting `disallowed_special` to () will cause all text corresponding - to special tokens to be encoded as natural text (insteading of raising - an error). 
- - Setting `allowed_special` to "all" will treat all text corresponding - to special tokens to be encoded as special tokens. - """ - if allowed_special is None: - allowed_special = set() - assert type(s) is str - - substrs = ( - substr - for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) - for substr in self._split_whitespaces_or_nonwhitespaces( - s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS - ) - ) - t: List[int] = [] - for substr in substrs: - t.extend( - self.model.encode( - substr, - allowed_special=allowed_special, - disallowed_special=disallowed_special, - ) - ) - if bos: - t.insert(0, self.bos_id) - if eos: - t.append(self.eos_id) - return t - - def decode(self, t: Sequence[int]) -> str: - """ - Decodes a list of token IDs into a string. - - Args: - t (List[int]): The list of token IDs to be decoded. - - Returns: - str: The decoded string. - """ - # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. - return self.model.decode(cast(List[int], t)) - - @staticmethod - def _split_whitespaces_or_nonwhitespaces( - s: str, max_consecutive_slice_len: int - ) -> Iterator[str]: - """ - Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` - consecutive whitespaces or consecutive non-whitespaces. - """ - current_slice_len = 0 - current_slice_is_space = s[0].isspace() if len(s) > 0 else False - slice_start = 0 - - for i in range(len(s)): - is_now_space = s[i].isspace() - - if current_slice_is_space ^ is_now_space: - current_slice_len = 1 - current_slice_is_space = is_now_space - else: - current_slice_len += 1 - if current_slice_len > max_consecutive_slice_len: - yield s[slice_start:i] - slice_start = i - current_slice_len = 1 - yield s[slice_start:] \ No newline at end of file diff --git a/llmc_py/utils.py b/llmc_py/utils.py index ed023c78a..e69de29bb 100644 --- a/llmc_py/utils.py +++ b/llmc_py/utils.py @@ -1,57 +0,0 @@ -# Taken from: -# 1) https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py -# 2) https://github.com/meta-llama/llama3/blob/11817d47e1ba7a4959b025eb1ca308572e0e3963/llama/generation.py - -import torch -from torch import nn - -# Special modules -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()).type_as(x) - return output * self.weight - -# Sampling -def sample_top_p(probs, p): - """ - Perform top-p (nucleus) sampling on a probability distribution. - - Args: - probs (torch.Tensor): Probability distribution tensor. - p (float): Probability threshold for top-p sampling. - - Returns: - torch.Tensor: Sampled token indices. - - Note: - Top-p sampling selects the smallest set of tokens whose cumulative probability mass - exceeds the threshold p. The distribution is renormalized based on the selected tokens. 
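For intuition, here is a small standalone walk-through of the nucleus-sampling step described above, on a toy distribution (the numbers are purely illustrative and not taken from the patch):

import torch

# toy, already-normalized next-token distribution over 4 tokens
probs = torch.tensor([[0.5, 0.25, 0.15, 0.10]])
p = 0.8

probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)           # [0.50, 0.75, 0.90, 1.00]
mask = probs_sum - probs_sort > p                      # [False, False, False, True]
probs_sort[mask] = 0.0                                 # drop the 0.10 tail token
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))  # renormalize over the kept 0.90 mass
next_token = torch.gather(probs_idx, -1, torch.multinomial(probs_sort, num_samples=1))
print(next_token)  # one of token ids 0, 1 or 2; id 3 can never be sampled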
- """ - probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) - probs_sum = torch.cumsum(probs_sort, dim=-1) - mask = probs_sum - probs_sort > p - probs_sort[mask] = 0.0 - probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) - next_token = torch.multinomial(probs_sort, num_samples=1) - next_token = torch.gather(probs_idx, -1, next_token) - return next_token - -# GQA -def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" - bs, slen, n_kv_heads, head_dim = x.shape - if n_rep == 1: - return x - return ( - x[:, :, :, None, :] - .expand(bs, slen, n_kv_heads, n_rep, head_dim) - .reshape(bs, slen, n_kv_heads * n_rep, head_dim) - ) \ No newline at end of file diff --git a/train_gpt2.py b/train_gpt2.py index 4eea1f46e..4a8cc6e0e 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -29,9 +29,18 @@ import json from pathlib import Path from typing import ( + AbstractSet, + Callable, + Collection, + Dict, + Iterator, List, + Literal, Optional, + Sequence, Tuple, + Union, + cast, ) import numpy as np @@ -44,9 +53,8 @@ from torch.distributed.optim import ZeroRedundancyOptimizer import torch.distributed as dist -from llmc_py.tokenizer import Tokenizer -from llmc_py.rope import precompute_freqs_cis, apply_rotary_emb -from llmc_py.utils import repeat_kv, sample_top_p, RMSNorm +import tiktoken +from tiktoken.load import load_tiktoken_bpe # ----------------------------------------------------------------------------- # PyTorch nn.Module definitions for the LLaMA 3.x model @@ -54,6 +62,91 @@ # using a global to toggle flash-attention FLASH = 0 +# Used in Grouped Query Attention (GQA), broadcasts the key and value tensors +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + +# ----------------------------------------------------------------------------- +# RoPE related + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + +def apply_scaling(freqs: torch.Tensor): + # Values obtained from grid search + scale_factor = 8 + low_freq_factor = 1 + high_freq_factor = 4 + old_context_len = 8192 # original llama3 length + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + new_freqs = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scale_factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) + return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out 
= torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False +): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + if use_scaled: + freqs = apply_scaling(freqs) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + +# ----------------------------------------------------------------------------- +# LLaMA building blocks + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + class CausalSelfAttention(nn.Module): def __init__(self, config): @@ -482,6 +575,205 @@ def generate( out_logprobs.append(probs) return (out_tokens, out_logprobs if logprobs else None) +# ----------------------------------------------------------------------------- +# sampling utils + +def sample_top_p(probs, p): + """ + Perform top-p (nucleus) sampling on a probability distribution. + + Args: + probs (torch.Tensor): Probability distribution tensor. + p (float): Probability threshold for top-p sampling. + + Returns: + torch.Tensor: Sampled token indices. + + Note: + Top-p sampling selects the smallest set of tokens whose cumulative probability mass + exceeds the threshold p. The distribution is renormalized based on the selected tokens. + """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token + +# ----------------------------------------------------------------------------- +# Llama 3.1 Tokenizer + +# The tiktoken tokenizer can handle <=400k chars without +# pyo3_runtime.PanicException. +TIKTOKEN_MAX_ENCODE_CHARS = 400_000 + +# https://github.com/openai/tiktoken/issues/195 +# Here we iterate over subsequences and split if we exceed the limit +# of max consecutive non-whitespace or whitespace characters. +MAX_NO_WHITESPACES_CHARS = 25_000 + + +class Tokenizer: + """ + Tokenizing and encoding/decoding text using the Tiktoken tokenizer. + """ + + special_tokens: Dict[str, int] + + num_reserved_special_tokens = 256 + + pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 + + def __init__(self, model_path: str): + """ + Initializes the Tokenizer with a Tiktoken model. + + Args: + model_path (str): The path to the Tiktoken model file. 
+ """ + assert os.path.isfile(model_path), model_path + + mergeable_ranks = load_tiktoken_bpe(model_path) + num_base_tokens = len(mergeable_ranks) + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|finetune_right_pad_id|>", + "<|step_id|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|eom_id|>", # end of message + "<|eot_id|>", # end of turn + "<|python_tag|>", + ] + reserved_tokens = [ + f"<|reserved_special_token_{2 + i}|>" + for i in range(self.num_reserved_special_tokens - len(special_tokens)) + ] + special_tokens = special_tokens + reserved_tokens + + self.special_tokens = { + token: num_base_tokens + i for i, token in enumerate(special_tokens) + } + self.model = tiktoken.Encoding( + name=Path(model_path).name, + pat_str=self.pat_str, + mergeable_ranks=mergeable_ranks, + special_tokens=self.special_tokens, + ) + + self.n_words: int = num_base_tokens + len(special_tokens) + # BOS / EOS token IDs + self.bos_id: int = self.special_tokens["<|begin_of_text|>"] + self.eos_id: int = self.special_tokens["<|end_of_text|>"] + self.eot_id: int = self.special_tokens["<|eot_id|>"] + self.eom_id: int = self.special_tokens["<|eom_id|>"] + self.python_tag_id = self.special_tokens["<|python_tag|>"] + self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"] + self.stop_tokens = [ + self.special_tokens["<|eom_id|>"], + self.special_tokens["<|eot_id|>"], + ] + + def encode( + self, + s: str, + *, + bos: bool, + eos: bool, + allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None, + disallowed_special: Union[Literal["all"], Collection[str]] = (), + ) -> List[int]: + """ + Encodes a string into a list of token IDs. + + Args: + s (str): The input string to be encoded. + bos (bool): Whether to prepend the beginning-of-sequence token. + eos (bool): Whether to append the end-of-sequence token. + allowed_tokens ("all"|set[str]): allowed special tokens in string + disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string + + Returns: + list[int]: A list of token IDs. + + By default, setting disallowed_special=() encodes a string by ignoring + special tokens. Specifically: + - Setting `disallowed_special` to () will cause all text corresponding + to special tokens to be encoded as natural text (insteading of raising + an error). + - Setting `allowed_special` to "all" will treat all text corresponding + to special tokens to be encoded as special tokens. + """ + if allowed_special is None: + allowed_special = set() + assert type(s) is str + + substrs = ( + substr + for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) + for substr in self._split_whitespaces_or_nonwhitespaces( + s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS + ) + ) + t: List[int] = [] + for substr in substrs: + t.extend( + self.model.encode( + substr, + allowed_special=allowed_special, + disallowed_special=disallowed_special, + ) + ) + if bos: + t.insert(0, self.bos_id) + if eos: + t.append(self.eos_id) + return t + + def decode(self, t: Sequence[int]) -> str: + """ + Decodes a list of token IDs into a string. + + Args: + t (List[int]): The list of token IDs to be decoded. + + Returns: + str: The decoded string. + """ + # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. 
+ return self.model.decode(cast(List[int], t)) + + @staticmethod + def _split_whitespaces_or_nonwhitespaces( + s: str, max_consecutive_slice_len: int + ) -> Iterator[str]: + """ + Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` + consecutive whitespaces or consecutive non-whitespaces. + """ + current_slice_len = 0 + current_slice_is_space = s[0].isspace() if len(s) > 0 else False + slice_start = 0 + + for i in range(len(s)): + is_now_space = s[i].isspace() + + if current_slice_is_space ^ is_now_space: + current_slice_len = 1 + current_slice_is_space = is_now_space + else: + current_slice_len += 1 + if current_slice_len > max_consecutive_slice_len: + yield s[slice_start:i] + slice_start = i + current_slice_len = 1 + yield s[slice_start:] + # ----------------------------------------------------------------------------- # Our own simple Distributed Data Loader From b25e325c69fee7ae6319d1496106c29b0fb0f7f8 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Thu, 8 Aug 2024 17:55:01 +0200 Subject: [PATCH 27/36] Add explicit external mask --- train_gpt2.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/train_gpt2.py b/train_gpt2.py index 4a8cc6e0e..0921382cd 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -168,11 +168,7 @@ def __init__(self, config): self.cache_k = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd)) self.cache_v = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd)) - # not really a 'bias', more of a mask, but following the OpenAI/HF naming though - self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) - .view(1, 1, config.block_size, config.block_size)) - - def forward(self, x, freqs_cis=None, start_pos=None): + def forward(self, x, freqs_cis=None, start_pos=None, mask: Optional[torch.Tensor] = None): B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim qkv = self.c_attn(x) @@ -194,12 +190,13 @@ def forward(self, x, freqs_cis=None, start_pos=None): if FLASH: # flashattention - y = F.scaled_dot_product_attention(q, k, v, is_causal=True) + y = F.scaled_dot_product_attention(q, k, v, mask) else: # manual implementation of attention # this materializes the large (T,T) matrix for all the queries and keys scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.hd)) - scores = scores.masked_fill(self.bias[:,:,:T,:T] == 0, torch.finfo(scores.dtype).min) + if mask is not None: + scores.masked_fill_(mask, torch.finfo(scores.dtype).min) att = F.softmax(scores.float(), dim=-1).type_as(q) y = att @ v # (B, NH, T, T) x (B, NH, T, HD) -> (B, NH, T, HD) y = y.transpose(1, 2).contiguous().view(B, T, C) @@ -239,8 +236,8 @@ def __init__(self, config): self.ln_2 = RMSNorm(config.n_embd, config.norm_eps) self.mlp = MLP(config) - def forward(self, x, freqs_cis=None, start_pos=None): - x = x + self.attn(self.ln_1(x), freqs_cis, start_pos) + def forward(self, x, freqs_cis=None, start_pos=None, mask: Optional[torch.Tensor] = None): + x = x + self.attn(self.ln_1(x), freqs_cis, start_pos, mask) x = x + self.mlp(self.ln_2(x)) return x @@ -296,8 +293,10 @@ def forward(self, idx, targets=None, return_logits=True, start_pos=0): x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) freqs_cis = self.freqs_cis[start_pos:start_pos+t] + mask = torch.triu(torch.ones((t, t), 
device=next(self.parameters()).device, dtype=torch.bool), diagonal=1) + for i, block in enumerate(self.transformer.h): - x = block(x, freqs_cis, start_pos) + x = block(x, freqs_cis, start_pos, mask) x = self.transformer.ln_f(x) if targets is not None: From b7c98c93655384b054f595e82d522681dabe0919 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Thu, 8 Aug 2024 18:12:35 +0200 Subject: [PATCH 28/36] Add llama config error check --- train_gpt2.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/train_gpt2.py b/train_gpt2.py index 0921382cd..c1ccff942 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -261,6 +261,14 @@ class LlamaConfig: max_gen_batch_size: int = 4 use_kv: bool = True + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if hasattr(self, k): + setattr(self, k, v) + assert self.n_kv_head <= self.n_head + assert self.n_head % self.n_kv_head == 0 + assert self.n_embd % self.n_head == 0 + class LLaMA(nn.Module): def __init__(self, config): From 624ed3ce30ad227072e0009e90b01fd8cb7b8ace Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Thu, 8 Aug 2024 18:15:37 +0200 Subject: [PATCH 29/36] Rename the new file to train llama3 --- train_gpt2.py | 986 +++++++++++------------------------- train_llama3.py | 1284 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 1565 insertions(+), 705 deletions(-) create mode 100644 train_llama3.py diff --git a/train_gpt2.py b/train_gpt2.py index c1ccff942..b9dee8701 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -1,47 +1,28 @@ """ -Reference code for LLaMA-3.1 training and inference. +Reference code for GPT-2 training and inference. Will save the model weights into files, to be read from C as initialization. -This code differs from GPT-2 very slightly, there are three main differences: -1) RoPE: LLaMA uses a different positional encoding scheme called Relative Positional Encoding (RoPE). -2) GQA: Grouped Query Attention (GQA) is used to reduce the number of attention heads. -3) SwiGLU: Swish-Gated Linear Unit (SwiGLU) is used as the activation function in the MLP. 
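To make the third point above concrete, a minimal SwiGLU feed-forward block looks roughly like this (a standalone sketch; the 4096/14336 widths happen to match the 8B config but are only illustrative here):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUMLP(nn.Module):
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden, bias=False)  # gate branch, passed through SiLU
        self.w_up   = nn.Linear(dim, hidden, bias=False)  # plain linear branch
        self.w_down = nn.Linear(hidden, dim, bias=False)  # project back to model width

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

x = torch.randn(2, 16, 4096)
print(SwiGLUMLP(4096, 14336)(x).shape)  # torch.Size([2, 16, 4096])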
- References: -# 1) https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/tokenizer.py -# 2) https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py -# 3) https://github.com/meta-llama/llama3/blob/11817d47e1ba7a4959b025eb1ca308572e0e3963/llama/generation.py +1) the official GPT-2 TensorFlow implementation released by OpenAI: +https://github.com/openai/gpt-2/blob/master/src/model.py +2) huggingface/transformers PyTorch implementation: +https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py Example launches to only benchmark the speed of bfloat16 compiled GPU training: 1 GPU: -python train_llama3.py --write_tensors=0 --num_iterations=50 --sequence_length=8192 --compile=1 --tensorcores=1 --dtype=bfloat16 +python train_gpt2.py --write_tensors=0 --num_iterations=50 --sequence_length=1024 --compile=1 --tensorcores=1 --dtype=bfloat16 you can also turn on flash-attention by appending --flash=1 4 GPU: -torchrun --standalone --nproc_per_node=4 train_llama3.py --write_tensors=0 --num_iterations=50 --sequence_length=8192 --compile=1 --tensorcores=1 --dtype=bfloat16 +torchrun --standalone --nproc_per_node=4 train_gpt2.py --write_tensors=0 --num_iterations=50 --sequence_length=1024 --compile=1 --tensorcores=1 --dtype=bfloat16 """ import os import math import glob +import struct import inspect from contextlib import nullcontext from dataclasses import dataclass -import json -from pathlib import Path -from typing import ( - AbstractSet, - Callable, - Collection, - Dict, - Iterator, - List, - Literal, - Optional, - Sequence, - Tuple, - Union, - cast, -) import numpy as np import torch @@ -53,153 +34,54 @@ from torch.distributed.optim import ZeroRedundancyOptimizer import torch.distributed as dist -import tiktoken -from tiktoken.load import load_tiktoken_bpe - # ----------------------------------------------------------------------------- -# PyTorch nn.Module definitions for the LLaMA 3.x model +# PyTorch nn.Module definitions for the GPT-2 model + +class NewGELU(nn.Module): + """Careful there are a few versions of GeLU, this one is the exact one used by OpenAI""" + def forward(self, input): + return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) # using a global to toggle flash-attention FLASH = 0 -# Used in Grouped Query Attention (GQA), broadcasts the key and value tensors -def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" - bs, slen, n_kv_heads, head_dim = x.shape - if n_rep == 1: - return x - return ( - x[:, :, :, None, :] - .expand(bs, slen, n_kv_heads, n_rep, head_dim) - .reshape(bs, slen, n_kv_heads * n_rep, head_dim) - ) - -# ----------------------------------------------------------------------------- -# RoPE related - -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): - ndim = x.ndim - assert 0 <= 1 < ndim - assert freqs_cis.shape == (x.shape[1], x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) - -def apply_scaling(freqs: torch.Tensor): - # Values obtained from grid search - scale_factor = 8 - low_freq_factor = 1 - high_freq_factor = 4 - old_context_len = 8192 # original llama3 length - - low_freq_wavelen = old_context_len / low_freq_factor - high_freq_wavelen = old_context_len / high_freq_factor - new_freqs = [] - for freq in freqs: - wavelen = 2 * math.pi / freq - if wavelen < 
high_freq_wavelen: - new_freqs.append(freq) - elif wavelen > low_freq_wavelen: - new_freqs.append(freq / scale_factor) - else: - assert low_freq_wavelen != high_freq_wavelen - smooth = (old_context_len / wavelen - low_freq_factor) / ( - high_freq_factor - low_freq_factor - ) - new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) - return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - -def precompute_freqs_cis( - dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False -): - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device, dtype=torch.float32) - if use_scaled: - freqs = apply_scaling(freqs) - freqs = torch.outer(t, freqs) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - -# ----------------------------------------------------------------------------- -# LLaMA building blocks - -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()).type_as(x) - return output * self.weight - class CausalSelfAttention(nn.Module): def __init__(self, config): super().__init__() assert config.n_embd % config.n_head == 0 - - self.n_head = config.n_head - self.n_kv_head = config.n_kv_head - self.n_rep = self.n_head // self.n_kv_head - self.hd = config.n_embd // config.n_head - self.use_kv = config.use_kv - - self.c_attn = nn.Linear(config.n_embd, (config.n_head + 2 * config.n_kv_head) * self.hd, bias=False) # key, query, value projections - self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) # output projection + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd) self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 + # regularization + self.n_head = config.n_head + self.n_embd = config.n_embd + # not really a 'bias', more of a mask, but following the OpenAI/HF naming though + self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) + .view(1, 1, config.block_size, config.block_size)) - # static KV cache - we could alternatively allocate it outside of the model and just pass it in when needed - if self.use_kv: - self.cache_k = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd)) - self.cache_v = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd)) - - def forward(self, x, freqs_cis=None, start_pos=None, mask: Optional[torch.Tensor] = None): + def forward(self, x): B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim qkv = self.c_attn(x) - 
q, k, v = qkv.split([self.n_head * self.hd, self.n_kv_head * self.hd, self.n_kv_head * self.hd], dim=-1) - q, k, v = map(lambda t: t.view(B, T, -1, self.hd), (q, k, v)) # (B, T, NH, HD) - - q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) # rotate QK (rope) <-- 1. difference compared to GPT-2 - - if self.use_kv and not self.training and start_pos >= 0: # use kv-caching during inference - self.cache_k[:B, start_pos : start_pos + T] = k - self.cache_v[:B, start_pos : start_pos + T] = v - k = self.cache_k[:B, : start_pos + T] - v = self.cache_v[:B, : start_pos + T] - - k = repeat_kv(k, self.n_rep) # GQA <-- 2. difference compared to GPT-2 - v = repeat_kv(v, self.n_rep) - - q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v)) # (B, NH, T, HD) - + q, k, v = qkv.split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) if FLASH: # flashattention - y = F.scaled_dot_product_attention(q, k, v, mask) + y = F.scaled_dot_product_attention(q, k, v, is_causal=True) else: # manual implementation of attention # this materializes the large (T,T) matrix for all the queries and keys - scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.hd)) - if mask is not None: - scores.masked_fill_(mask, torch.finfo(scores.dtype).min) - att = F.softmax(scores.float(), dim=-1).type_as(q) - y = att @ v # (B, NH, T, T) x (B, NH, T, HD) -> (B, NH, T, HD) - y = y.transpose(1, 2).contiguous().view(B, T, C) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + # output projection y = self.c_proj(y) return y @@ -207,23 +89,14 @@ class MLP(nn.Module): def __init__(self, config): super().__init__() - hidden_dim = 4 * config.n_embd - hidden_dim = int(2 * hidden_dim / 3) - # custom dim factor multiplier - if config.ffn_dim_multiplier is not None: - hidden_dim = int(config.ffn_dim_multiplier * hidden_dim) - hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of) - self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=False) - self.c_fc2 = nn.Linear(config.n_embd, hidden_dim, bias=False) - self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False) + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd) + self.gelu = NewGELU() + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd) self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 def forward(self, x): - # SwiGLU self.c_proj(F.silu(self.c_fc2(x)) * self.c_fc(x)) <-- 3. 
difference compared to GPT-2 - x1 = self.c_fc(x) - x2 = self.c_fc2(x) - x2 = F.silu(x2) - x = x1 * x2 + x = self.c_fc(x) + x = self.gelu(x) x = self.c_proj(x) return x @@ -231,45 +104,28 @@ class Block(nn.Module): def __init__(self, config): super().__init__() - self.ln_1 = RMSNorm(config.n_embd, config.norm_eps) + self.ln_1 = nn.LayerNorm(config.n_embd) self.attn = CausalSelfAttention(config) - self.ln_2 = RMSNorm(config.n_embd, config.norm_eps) + self.ln_2 = nn.LayerNorm(config.n_embd) self.mlp = MLP(config) - def forward(self, x, freqs_cis=None, start_pos=None, mask: Optional[torch.Tensor] = None): - x = x + self.attn(self.ln_1(x), freqs_cis, start_pos, mask) + def forward(self, x): + x = x + self.attn(self.ln_1(x)) x = x + self.mlp(self.ln_2(x)) return x # ----------------------------------------------------------------------------- -# The main LLaMA 3.1 model +# The main GPT-2 model @dataclass -class LlamaConfig: - version: str = "3.1" - block_size: int = 8192 - vocab_size: int = 128256 - n_layer: int = 32 - n_head: int = 32 - n_kv_head: int = 8 - n_embd: int = 4096 - ffn_dim_multiplier: float = 1.3 - multiple_of: int = 1024 - norm_eps: float = 1e-5 - rope_theta: float = 500000.0 - use_scaled_rope: bool = True - max_gen_batch_size: int = 4 - use_kv: bool = True - - def __init__(self, **kwargs): - for k, v in kwargs.items(): - if hasattr(self, k): - setattr(self, k, v) - assert self.n_kv_head <= self.n_head - assert self.n_head % self.n_kv_head == 0 - assert self.n_embd % self.n_head == 0 - -class LLaMA(nn.Module): +class GPTConfig: + block_size: int = 1024 + vocab_size: int = 50257 + n_layer: int = 12 + n_head: int = 12 + n_embd: int = 768 + +class GPT(nn.Module): def __init__(self, config): super().__init__() @@ -277,43 +133,54 @@ def __init__(self, config): self.transformer = nn.ModuleDict(dict( wte = nn.Embedding(config.vocab_size, config.n_embd), + wpe = nn.Embedding(config.block_size, config.n_embd), h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), - ln_f = RMSNorm(config.n_embd, config.norm_eps), + ln_f = nn.LayerNorm(config.n_embd), )) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.lm_head.LLMC_SKIP_INIT = 1 # don't init this one, we will tie weights + self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying # init all weights, use a torch rng object to be very careful self.init_rng = torch.Generator() self.init_rng.manual_seed(42) - - self.freqs_cis = precompute_freqs_cis( - config.n_embd // config.n_head, - config.block_size * 2, - config.rope_theta, - config.use_scaled_rope, - ) - - def forward(self, idx, targets=None, return_logits=True, start_pos=0): - _, t = idx.size() + self.apply(self._init_weights) + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + # apply special scaled init to the residual projections, per GPT-2 paper + std = 0.02 if not hasattr(module, 'LLMC_RESIDUAL_SCALE_FLAG') else 0.02/math.sqrt(2 * self.config.n_layer) + # we want to skip initializing lm_head, which shares parameters with wte + # and wte was already initialized down below during the Embedding init + if not hasattr(module, 'LLMC_SKIP_INIT'): + torch.nn.init.normal_(module.weight, mean=0.0, std=std, generator=self.init_rng) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02, generator=self.init_rng) + + def forward(self, idx, targets=None, return_logits=True): + device = 
idx.device + b, t = idx.size() assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" + pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) - # forward the LLaMA model itself - x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) - freqs_cis = self.freqs_cis[start_pos:start_pos+t] - - mask = torch.triu(torch.ones((t, t), device=next(self.parameters()).device, dtype=torch.bool), diagonal=1) + # forward the GPT model itself + tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) + x = tok_emb + pos_emb - for i, block in enumerate(self.transformer.h): - x = block(x, freqs_cis, start_pos, mask) + for block in self.transformer.h: + x = block(x) x = self.transformer.ln_f(x) if targets is not None: # if we are given some desired targets also calculate the loss - logits = self.lm_head(x).float() + logits = self.lm_head(x) loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) else: # inference-time mini-optimization: only forward the lm_head on the very last position - logits = self.lm_head(x[:, [-1], :]).float() # note: using list [-1] to preserve the time dim + logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim loss = None # there are performance reasons why not returning logits is prudent, if not needed @@ -322,123 +189,53 @@ def forward(self, idx, targets=None, return_logits=True, start_pos=0): return logits, loss - @staticmethod - def adapt_llama_state_dict_keys(checkpoint, config: LlamaConfig): - checkpoint['transformer.wte.weight'] = checkpoint.pop('tok_embeddings.weight') - - for i in range(config.n_layer): - for name in ['attention_norm', 'ffn_norm']: - old_key = f'layers.{i}.{name}.weight' # e.g. 
layers.x.attention_norm.weight -> transformer.h.x.ln_1.weight - new_key = f'transformer.h.{i}.ln_{1 if name == "attention_norm" else 2}.weight' - checkpoint[new_key] = checkpoint.pop(old_key) - - for i in range(config.n_layer): - for name in ['attention.wq', 'attention.wk', 'attention.wv']: - old_key = f'layers.{i}.{name}.weight' - new_key = f'transformer.h.{i}.attn.c_attn.weight' - if name == 'attention.wq': - checkpoint[new_key] = checkpoint.pop(old_key) - else: # merge 3 weights into transformer.h.x.attn.c_attn.weight - checkpoint[new_key] = torch.cat((checkpoint[new_key], checkpoint.pop(old_key)), dim=0) - old_key = f'layers.{i}.attention.wo.weight' - new_key = f'transformer.h.{i}.attn.c_proj.weight' - checkpoint[new_key] = checkpoint.pop(old_key) - - ffn_map = {'w1': 'c_fc2', 'w2': 'c_proj', 'w3': 'c_fc'} - for i in range(config.n_layer): - for name in ['feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3']: - old_key = f'layers.{i}.{name}.weight' - new_key = f'transformer.h.{i}.mlp.{ffn_map[name.split(".")[-1]]}.weight' - checkpoint[new_key] = checkpoint.pop(old_key) - - checkpoint['transformer.ln_f.weight'] = checkpoint.pop('norm.weight') - checkpoint['lm_head.weight'] = checkpoint.pop('output.weight') - - return checkpoint - - @staticmethod - def adapt_llama_state_dict_keys_hf(checkpoint, config: LlamaConfig): - checkpoint['transformer.wte.weight'] = checkpoint.pop('model.embed_tokens.weight') - - # We need to unpermute K and V because HF script permuted the original Meta-LLaMA weights - # see: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py - def unpermute(w, n_heads, dim1, dim2): - return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2) - - for i in range(config.n_layer): - for name in ['input_layernorm', 'post_attention_layernorm']: - old_key = f'model.layers.{i}.{name}.weight' # e.g. 
layers.x.attention_norm.weight -> transformer.h.x.ln_1.weight - new_key = f'transformer.h.{i}.ln_{1 if name == "input_layernorm" else 2}.weight' - checkpoint[new_key] = checkpoint.pop(old_key) - - for i in range(config.n_layer): - for name in ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj']: - old_key = f'model.layers.{i}.{name}.weight' - new_key = f'transformer.h.{i}.attn.c_attn.weight' - if name == 'self_attn.q_proj': - checkpoint[new_key] = unpermute(checkpoint.pop(old_key), config.n_head, config.n_embd, config.n_embd) - else: # merge 3 weights into transformer.h.x.attn.c_attn.weight - tensor = checkpoint.pop(old_key) - if name == 'self_attn.k_proj': - tensor = unpermute(tensor, config.n_kv_head, config.n_kv_head * (config.n_embd // config.n_head), config.n_embd) - checkpoint[new_key] = torch.cat((checkpoint[new_key], tensor), dim=0) - old_key = f'model.layers.{i}.self_attn.o_proj.weight' - new_key = f'transformer.h.{i}.attn.c_proj.weight' - checkpoint[new_key] = checkpoint.pop(old_key) - - ffn_map = {'gate_proj': 'c_fc2', 'down_proj': 'c_proj', 'up_proj': 'c_fc'} - for i in range(config.n_layer): - for name in ['gate_proj', 'down_proj', 'up_proj']: - old_key = f'model.layers.{i}.mlp.{name}.weight' - new_key = f'transformer.h.{i}.mlp.{ffn_map[name]}.weight' - checkpoint[new_key] = checkpoint.pop(old_key) - - checkpoint['transformer.ln_f.weight'] = checkpoint.pop('model.norm.weight') - - return checkpoint - @classmethod - def from_pretrained_llama3_hf(cls, model_id): - """Loads pretrained LLaMA model weights from HuggingFace""" - from transformers import AutoModelForCausalLM, AutoTokenizer - assert model_id == "meta-llama/Meta-Llama-3.1-8B", "Only the 8B-bae model is supported for now" - model_args = LlamaConfig() - - model = AutoModelForCausalLM.from_pretrained(model_id) - checkpoint = LLaMA.adapt_llama_state_dict_keys_hf(model.state_dict(), model_args) - - original_default_type = torch.get_default_dtype() # save the default type - torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) # much faster loading - model = LLaMA(model_args) - model.load_state_dict(checkpoint, strict=False) - torch.set_default_tensor_type(torch.tensor([], dtype=original_default_type, device="cpu").type()) # restore default type - - tokenizer = AutoTokenizer.from_pretrained(model_id) - tokenizer.pad_id = 128004 # this is the pad token id for LLaMA 3.1 base, we need to set this explicitly as our generate func expects it - tokenizer.stop_tokens = [tokenizer.eos_token_id] - model.tokenizer = tokenizer - return model + def from_pretrained(cls, model_type): + """Loads pretrained GPT-2 model weights from huggingface""" + assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'} + from transformers import GPT2LMHeadModel + print("loading weights from pretrained gpt: %s" % model_type) + + # n_layer, n_head and n_embd are determined from model_type + config_args = { + 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params + 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params + 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params + 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params + }[model_type] + config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints + config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints + # create a from-scratch initialized minGPT model + config = GPTConfig(**config_args) + model = GPT(config) + sd = model.state_dict() + sd_keys = sd.keys() + sd_keys = [k for k in sd_keys if 
not k.endswith('.attn.bias')] # discard this mask / buffer, not a param + + # init a huggingface/transformers model + model_hf = GPT2LMHeadModel.from_pretrained(model_type) + sd_hf = model_hf.state_dict() + + # copy while ensuring all of the parameters are aligned and match in names and shapes + sd_keys_hf = sd_hf.keys() + sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer + sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer) + transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight'] + # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear + # this means that we have to transpose these weights when we import them + assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" + for k in sd_keys_hf: + if any(k.endswith(w) for w in transposed): + # special treatment for the Conv1D weights we need to transpose + assert sd_hf[k].shape[::-1] == sd[k].shape + with torch.no_grad(): + sd[k].copy_(sd_hf[k].t()) + else: + # vanilla copy over the other parameters + assert sd_hf[k].shape == sd[k].shape + with torch.no_grad(): + sd[k].copy_(sd_hf[k]) - @classmethod - def from_pretrained_llama3_meta(cls, ckpt_dir, tokenizer_path): - """Loads pretrained LLaMA model weights from a checkpoint directory""" - model_args = LlamaConfig() - - ckpt_path = sorted(Path(ckpt_dir).glob("*.pth"))[0] - checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True) - checkpoint = LLaMA.adapt_llama_state_dict_keys(checkpoint, model_args) - - original_default_type = torch.get_default_dtype() # save the default type - torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) # much faster loading - model = LLaMA(model_args) - model.load_state_dict(checkpoint, strict=False) - torch.set_default_tensor_type(torch.tensor([], dtype=original_default_type, device="cpu").type()) # restore default type - - tokenizer = Tokenizer(model_path=tokenizer_path) - # add <|end_of_text|> as the stop token for base model - this is an omission in the reference code - # the reference code only adds instruct model stop tokens... - tokenizer.stop_tokens = tokenizer.stop_tokens + [128001] - model.tokenizer = tokenizer return model def configure_optimizers(self, weight_decay, learning_rate, betas, device_type, zero_stage): @@ -472,314 +269,32 @@ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type, optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, fused=use_fused) return optimizer - @torch.inference_mode() - def generate( - self, - prompt_tokens: List[List[int]], - max_gen_len: int, - temperature: float = 0.6, - top_p: float = 0.9, - logprobs: bool = False, - echo: bool = False, - ) -> Tuple[List[List[int]], Optional[List[List[float]]]]: + @torch.no_grad() + def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): """ - Generate text sequences based on provided prompts using the language generation model. - - Args: - prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers. - max_gen_len (int): Maximum length of the generated text sequence. - temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. - top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. 
- logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. - echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. - - Returns: - Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities. - - Note: - This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness. - If logprobs is True, token log probabilities are computed for each generated token. - + Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete + the sequence max_new_tokens times, feeding the predictions back into the model each time. + Most likely you'll want to make sure to be in model.eval() mode of operation for this. """ - bsz = len(prompt_tokens) - assert bsz <= self.config.max_gen_batch_size, (bsz, self.config.max_gen_batch_size) - device = next(self.parameters()).device - - min_prompt_len = min(len(t) for t in prompt_tokens) - max_prompt_len = max(len(t) for t in prompt_tokens) - assert max_prompt_len <= self.config.block_size - total_len = min(self.config.block_size, max_gen_len + max_prompt_len) - - pad_id = self.tokenizer.pad_id - tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device) - for k, t in enumerate(prompt_tokens): - tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device) - if logprobs: - token_logprobs = torch.zeros_like(tokens, dtype=torch.float) - - prev_pos = 0 - eos_reached = torch.tensor([False] * bsz, device=device) - input_text_mask = tokens != pad_id - - if min_prompt_len == total_len: - logits, _ = self.forward(tokens, start_pos=prev_pos) - token_logprobs = -F.cross_entropy( - input=logits.transpose(1, 2), - target=tokens, - reduction="none", - ignore_index=pad_id, - ) - - stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens)).to(device) - - for cur_pos in range(min_prompt_len, total_len): - logits, _ = self.forward(tokens[:, prev_pos:cur_pos], start_pos=prev_pos) - if temperature > 0: - probs = torch.softmax(logits[:, -1] / temperature, dim=-1) - next_token = sample_top_p(probs, top_p) - else: - next_token = torch.argmax(logits[:, -1], dim=-1) - - next_token = next_token.reshape(-1) - # only replace token if prompt has already been generated - next_token = torch.where( - input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token - ) - tokens[:, cur_pos] = next_token - if logprobs: - token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy( - input=logits.transpose(1, 2), - target=tokens[:, prev_pos + 1 : cur_pos + 1], - reduction="none", - ignore_index=pad_id, - ) - eos_reached |= (~input_text_mask[:, cur_pos]) & ( - torch.isin(next_token, stop_tokens) - ) - prev_pos = cur_pos - if all(eos_reached): - break - - if logprobs: - token_logprobs = token_logprobs.tolist() - out_tokens, out_logprobs = [], [] - for i, toks in enumerate(tokens.tolist()): - # cut to max gen len - start = 0 if echo else len(prompt_tokens[i]) - toks = toks[start : len(prompt_tokens[i]) + max_gen_len] - probs = None - if logprobs: - probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len] - # cut to after eos tok if any - for stop_token in self.tokenizer.stop_tokens: - try: - eos_idx = toks.index(stop_token) - toks = toks[:eos_idx] - probs = probs[:eos_idx] if logprobs else None - except ValueError: - pass - out_tokens.append(toks) - 
out_logprobs.append(probs) - return (out_tokens, out_logprobs if logprobs else None) - -# ----------------------------------------------------------------------------- -# sampling utils - -def sample_top_p(probs, p): - """ - Perform top-p (nucleus) sampling on a probability distribution. - - Args: - probs (torch.Tensor): Probability distribution tensor. - p (float): Probability threshold for top-p sampling. - - Returns: - torch.Tensor: Sampled token indices. - - Note: - Top-p sampling selects the smallest set of tokens whose cumulative probability mass - exceeds the threshold p. The distribution is renormalized based on the selected tokens. - """ - probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) - probs_sum = torch.cumsum(probs_sort, dim=-1) - mask = probs_sum - probs_sort > p - probs_sort[mask] = 0.0 - probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) - next_token = torch.multinomial(probs_sort, num_samples=1) - next_token = torch.gather(probs_idx, -1, next_token) - return next_token - -# ----------------------------------------------------------------------------- -# Llama 3.1 Tokenizer - -# The tiktoken tokenizer can handle <=400k chars without -# pyo3_runtime.PanicException. -TIKTOKEN_MAX_ENCODE_CHARS = 400_000 - -# https://github.com/openai/tiktoken/issues/195 -# Here we iterate over subsequences and split if we exceed the limit -# of max consecutive non-whitespace or whitespace characters. -MAX_NO_WHITESPACES_CHARS = 25_000 - - -class Tokenizer: - """ - Tokenizing and encoding/decoding text using the Tiktoken tokenizer. - """ - - special_tokens: Dict[str, int] - - num_reserved_special_tokens = 256 - - pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 - - def __init__(self, model_path: str): - """ - Initializes the Tokenizer with a Tiktoken model. - - Args: - model_path (str): The path to the Tiktoken model file. 
- """ - assert os.path.isfile(model_path), model_path - - mergeable_ranks = load_tiktoken_bpe(model_path) - num_base_tokens = len(mergeable_ranks) - special_tokens = [ - "<|begin_of_text|>", - "<|end_of_text|>", - "<|reserved_special_token_0|>", - "<|reserved_special_token_1|>", - "<|finetune_right_pad_id|>", - "<|step_id|>", - "<|start_header_id|>", - "<|end_header_id|>", - "<|eom_id|>", # end of message - "<|eot_id|>", # end of turn - "<|python_tag|>", - ] - reserved_tokens = [ - f"<|reserved_special_token_{2 + i}|>" - for i in range(self.num_reserved_special_tokens - len(special_tokens)) - ] - special_tokens = special_tokens + reserved_tokens - - self.special_tokens = { - token: num_base_tokens + i for i, token in enumerate(special_tokens) - } - self.model = tiktoken.Encoding( - name=Path(model_path).name, - pat_str=self.pat_str, - mergeable_ranks=mergeable_ranks, - special_tokens=self.special_tokens, - ) - - self.n_words: int = num_base_tokens + len(special_tokens) - # BOS / EOS token IDs - self.bos_id: int = self.special_tokens["<|begin_of_text|>"] - self.eos_id: int = self.special_tokens["<|end_of_text|>"] - self.eot_id: int = self.special_tokens["<|eot_id|>"] - self.eom_id: int = self.special_tokens["<|eom_id|>"] - self.python_tag_id = self.special_tokens["<|python_tag|>"] - self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"] - self.stop_tokens = [ - self.special_tokens["<|eom_id|>"], - self.special_tokens["<|eot_id|>"], - ] - - def encode( - self, - s: str, - *, - bos: bool, - eos: bool, - allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None, - disallowed_special: Union[Literal["all"], Collection[str]] = (), - ) -> List[int]: - """ - Encodes a string into a list of token IDs. - - Args: - s (str): The input string to be encoded. - bos (bool): Whether to prepend the beginning-of-sequence token. - eos (bool): Whether to append the end-of-sequence token. - allowed_tokens ("all"|set[str]): allowed special tokens in string - disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string - - Returns: - list[int]: A list of token IDs. - - By default, setting disallowed_special=() encodes a string by ignoring - special tokens. Specifically: - - Setting `disallowed_special` to () will cause all text corresponding - to special tokens to be encoded as natural text (insteading of raising - an error). - - Setting `allowed_special` to "all" will treat all text corresponding - to special tokens to be encoded as special tokens. - """ - if allowed_special is None: - allowed_special = set() - assert type(s) is str - - substrs = ( - substr - for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) - for substr in self._split_whitespaces_or_nonwhitespaces( - s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS - ) - ) - t: List[int] = [] - for substr in substrs: - t.extend( - self.model.encode( - substr, - allowed_special=allowed_special, - disallowed_special=disallowed_special, - ) - ) - if bos: - t.insert(0, self.bos_id) - if eos: - t.append(self.eos_id) - return t - - def decode(self, t: Sequence[int]) -> str: - """ - Decodes a list of token IDs into a string. - - Args: - t (List[int]): The list of token IDs to be decoded. - - Returns: - str: The decoded string. - """ - # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. 
- return self.model.decode(cast(List[int], t)) - - @staticmethod - def _split_whitespaces_or_nonwhitespaces( - s: str, max_consecutive_slice_len: int - ) -> Iterator[str]: - """ - Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` - consecutive whitespaces or consecutive non-whitespaces. - """ - current_slice_len = 0 - current_slice_is_space = s[0].isspace() if len(s) > 0 else False - slice_start = 0 - - for i in range(len(s)): - is_now_space = s[i].isspace() - - if current_slice_is_space ^ is_now_space: - current_slice_len = 1 - current_slice_is_space = is_now_space - else: - current_slice_len += 1 - if current_slice_len > max_consecutive_slice_len: - yield s[slice_start:i] - slice_start = i - current_slice_len = 1 - yield s[slice_start:] + for _ in range(max_new_tokens): + # if the sequence context is growing too long we must crop it at block_size + idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] + # forward the model to get the logits for the index in the sequence + logits, _ = self(idx_cond) + # pluck the logits at the final step and scale by desired temperature + logits = logits[:, -1, :] / temperature + # optionally crop the logits to only the top k options + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + logits[logits < v[:, [-1]]] = -float('Inf') + # apply softmax to convert logits to (normalized) probabilities + probs = F.softmax(logits, dim=-1) + # sample from the distribution + idx_next = torch.multinomial(probs, num_samples=1) + # append sampled index to the running sequence and continue + idx = torch.cat((idx, idx_next), dim=1) + + return idx # ----------------------------------------------------------------------------- # Our own simple Distributed Data Loader @@ -878,54 +393,83 @@ def write_bf16(tensor, file): file.write(b) def write_tensors(model_tensors, L, file, dtype): - # writes LLaMA 3 model's weights to a binary file + # writes the GPT-2 model's weights to a binary file assert dtype in {"float32", "bfloat16"} write_fun = write_fp32 if dtype == "float32" else write_bf16 write_fun(model_tensors["transformer.wte.weight"], file) # (V, C) + write_fun(model_tensors["transformer.wpe.weight"], file) # (T, C) for i in range(L): # (L, C) write_fun(model_tensors[f"transformer.h.{i}.ln_1.weight"], file) + for i in range(L): # (L, C) + write_fun(model_tensors[f"transformer.h.{i}.ln_1.bias"], file) for i in range(L): # (L, 3C, C) write_fun(model_tensors[f"transformer.h.{i}.attn.c_attn.weight"], file) + for i in range(L): # (L, 3C) + write_fun(model_tensors[f"transformer.h.{i}.attn.c_attn.bias"], file) for i in range(L): # (L, C, C) write_fun(model_tensors[f"transformer.h.{i}.attn.c_proj.weight"], file) + for i in range(L): # (L, C) + write_fun(model_tensors[f"transformer.h.{i}.attn.c_proj.bias"], file) for i in range(L): # (L, C) write_fun(model_tensors[f"transformer.h.{i}.ln_2.weight"], file) + for i in range(L): # (L, C) + write_fun(model_tensors[f"transformer.h.{i}.ln_2.bias"], file) for i in range(L): # (L, 4C, C) write_fun(model_tensors[f"transformer.h.{i}.mlp.c_fc.weight"], file) - for i in range(L): # (L, 4C, C) - write_fun(model_tensors[f"transformer.h.{i}.mlp.c_fc2.weight"], file) + for i in range(L): # (L, 4C) + write_fun(model_tensors[f"transformer.h.{i}.mlp.c_fc.bias"], file) for i in range(L): # (L, C, 4C) write_fun(model_tensors[f"transformer.h.{i}.mlp.c_proj.weight"], file) + for i in range(L): # (L, C) + 
write_fun(model_tensors[f"transformer.h.{i}.mlp.c_proj.bias"], file) write_fun(model_tensors["transformer.ln_f.weight"], file) # (C, ) - write_fun(model_tensors["lm_head.weight"], file) # (V, C) + write_fun(model_tensors["transformer.ln_f.bias"], file) # (C, ) + +@torch.no_grad() +def pad_vocab(tensor, multiple=128, value=0): + """ + The dimension of the vocab size in GPT-2 is 50,257 + which is unfortunately a very unfriendly number for a lot of + matrix operations on the GPU. So we pad it to the nearest + friendlier multiple, e.g. 50,304 if multiple=128 when we + export the weights into C land. This is a NOOP algorithmically + and is only done to make the tensor operations more efficient. + """ + assert tensor.ndim == 2 + V, C = tensor.shape + assert V == 50257, "just being defensive here" + # calculate padded vocab size by rounding up to nearest multiple + Vp = ((V + multiple - 1) // multiple) * multiple + # pad the tensor + pad_rows = Vp - V + padded = tensor if pad_rows == 0 else F.pad(tensor, (0, 0, 0, pad_rows), value=value) + assert padded.shape == (Vp, C) + return padded def write_model(model, filename, dtype): # everything we need to instantiate the model - # 1) header is: version int, LLaMAConfig ints, padding to 1024 bytes - assert dtype in {"float32", "bfloat16"} + # 1) header is: version int, GPTConfig ints, padding to 1024 bytes + assert dtype in {"float32", "bfloat16"} # float16 todo maybe later version = { - "float32": 3, # 3: all tensors are fp32 - "bfloat16": 5, # 5: all tensors are bf16 + "float32": 3, # 3: all tensors are fp32, padded vocab + "bfloat16": 5, # 5: all tensors are bf16, padded vocab }[dtype] header = torch.zeros(256, dtype=torch.int32) - header[0] = 20240803 # magic + header[0] = 20240326 # magic header[1] = version # checkpoint version header[2] = model.config.block_size header[3] = model.config.vocab_size header[4] = model.config.n_layer header[5] = model.config.n_head - header[6] = model.config.n_kv_head - header[7] = model.config.n_embd - header[8] = model.config.ffn_dim_multiplier - header[9] = model.config.multiple_of - header[10] = model.config.norm_eps - header[11] = model.config.rope_theta - header[12] = model.config.use_scaled_rope - header[13] = model.config.max_gen_batch_size - header[14] = int(model.config.version.split('.')[0]) # major version - header[15] = int(model.config.version.split('.')[1]) # minor version + header[6] = model.config.n_embd # 2) the parameters follow the header params = {name: param.cpu() for name, param in model.named_parameters()} + # pad the vocab to a multiple of 128 here at export, for efficiency in C + wte = params["transformer.wte.weight"] # (V, C) + wte_padded = pad_vocab(wte) # (Vp, C) + params["transformer.wte.weight"] = wte_padded # (Vp, C) + print(f"padded vocab size from {wte.size(0)} to {wte_padded.size(0)}") + header[7] = wte_padded.size(0) # padded vocab size store in header # now write to file with open(filename, "wb") as file: file.write(header.numpy().tobytes()) # header @@ -937,10 +481,16 @@ def write_state(model, x, y, logits, loss, filename): # it contains information about the input, logits, loss, and the parameter gradients # this can be used for checking the computation correctness in C header = torch.zeros(256, dtype=torch.int32) - header[0] = 20240803 # magic - header[1] = x.size(0) # batch size of the batch, B - header[2] = x.size(1) # temporal extent of the batch, T + header[0] = 20240327 # magic + header[1] = 2 # run state version = 2 (1 -> 2 for padded vocab changes) + header[2] = 
x.size(0) # batch size of the batch, B + header[3] = x.size(1) # temporal extent of the batch, T grads = {name: param.grad.cpu() for name, param in model.named_parameters()} + # pad the vocab grads here as well, to mirror write_model + wte_grad = grads["transformer.wte.weight"] # (V, C) + wte_grad_padded = pad_vocab(wte_grad, value=0) # (Vp, C) # TODO later maybe pad with nan? + grads["transformer.wte.weight"] = wte_grad_padded # (Vp, C) + print(f"padded vocab size in reference grads from {wte_grad.size(0)} to {wte_grad_padded.size(0)}") with open(filename, "wb") as file: # header file.write(header.numpy().tobytes()) @@ -956,6 +506,23 @@ def write_state(model, x, y, logits, loss, filename): write_tensors(grads, model.config.n_layer, file, "float32") print(f"wrote {filename}") +def write_tokenizer(enc, filename): + n = enc.max_token_value + 1 + header = torch.zeros(256, dtype=torch.int32) + header[0] = 20240328 # magic + header[1] = 2 # tokenizer version = 2 (1 -> 2: includes EOT token) + header[2] = n # number of tokens + header[3] = enc.eot_token # EOT token + with open(filename, "wb") as file: + file.write(header.numpy().tobytes()) + for i in range(n): + b = enc.decode_bytes([i]) + length = len(b) + assert length < 256, f"Token length exceeds 255: {length}" + file.write(struct.pack(" C bridge - parser.add_argument("--write_tensors", type=int, default=0, help="write tensors to disk") + parser.add_argument("--write_tensors", type=int, default=1, help="write tensors to disk") args = parser.parse_args() # args error checking and convenience variables B, T = args.batch_size, args.sequence_length - assert 1 <= T <= 8192, "sequence length must be between 1 and 8192" + assert 1 <= T <= 1024 assert args.dtype in {"float32", "float16", "bfloat16"} - assert args.model in {"meta-llama/Meta-Llama-3.1-8B"} # only 8B base model supported for now - - # create the logging directory if it does not exist - logfile = None - if args.output_dir: - os.makedirs(args.output_dir, exist_ok=True) - logfile = os.path.join(args.output_dir, "main.log") - # create the log file "main.log" inside it, and wipe it clean - with open(logfile, "w") as f: - pass + assert args.model in {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl", "d12", "d24", "d36", "d48"} # set up DDP (distributed data parallel). torchrun sets this env variable ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? 
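As a sanity check on the gpt2_tokenizer.bin layout produced by write_tokenizer above, here is a minimal reader sketch (not part of the patch). It assumes the header is 256 little-endian int32s carrying the magic 20240328, version 2, the token count and the EOT token id, and that each token record is a single length byte followed by the raw token bytes; the exact struct.pack format string is cut off in the hunk above, so the record layout is an assumption inferred from the `length < 256` assertion.

import struct

def read_tokenizer_bin(filename="gpt2_tokenizer.bin"):
    # header: 256 int32s; [0]=magic, [1]=version, [2]=number of tokens, [3]=EOT token id
    with open(filename, "rb") as f:
        header = struct.unpack("<256i", f.read(256 * 4))
        assert header[0] == 20240328, "magic number mismatch"
        assert header[1] == 2, "unsupported tokenizer version"
        n_tokens, eot_token = header[2], header[3]
        token_bytes = []
        for _ in range(n_tokens):
            (length,) = struct.unpack("<B", f.read(1))  # assumed 1-byte length prefix
            token_bytes.append(f.read(length))
    return token_bytes, eot_token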
@@ -1061,7 +617,6 @@ def print0(*args, **kwargs): device = "mps" print(f"using device: {device}") device_type = 'cuda' if 'cuda' in device else 'cpu' - assert device_type in {'cuda'} # we need to load LLaMA as bf16 on CUDA # calculate gradient accumulation from the desired total batch size and the current run configuration tokens_per_fwdbwd = B * T * ddp_world_size @@ -1072,7 +627,7 @@ def print0(*args, **kwargs): # set up a context manager following the desired dtype and device ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype] - ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if (device_type == "cuda") else nullcontext() + ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() # rng / reproducibility torch.manual_seed(42) @@ -1088,15 +643,26 @@ def print0(*args, **kwargs): assert args.flash in {0, 1} FLASH = args.flash - # init the model - assert args.ckpt_dir is not None and os.path.exists(args.ckpt_dir), f"llama3 ckpt dir {args.ckpt_dir} does not exist" - assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f"llama3 tokenizer path {args.tokenizer_path} does not exist" - if args.use_hf: - model = LLaMA.from_pretrained_llama3_hf(args.model) - else: # use Meta's checkpoint - model = LLaMA.from_pretrained_llama3_meta(args.ckpt_dir, args.tokenizer_path) - + # init (and write) the tokenizer + enc = tiktoken.get_encoding("gpt2") + if master_process and args.write_tensors: # tokenizer is technically not tensors but ok + write_tokenizer(enc, "gpt2_tokenizer.bin") + + # init the model, either from scratch or from OpenAI pretrained checkpoint + if args.model[0] == "d": + # from scratch (random weights) + model_config = { + "d12": GPTConfig(block_size=1024, vocab_size=50257, n_layer=12, n_head=12, n_embd=768), + "d24": GPTConfig(block_size=1024, vocab_size=50257, n_layer=24, n_head=16, n_embd=1024), + "d36": GPTConfig(block_size=1024, vocab_size=50257, n_layer=36, n_head=20, n_embd=1280), + "d48": GPTConfig(block_size=1024, vocab_size=50257, n_layer=48, n_head=25, n_embd=1600), + }[args.model] + model = GPT(model_config) + else: + # load the GPT-2 model weights + model = GPT.from_pretrained(args.model) model.train() + model.to(device) if args.compile: if hasattr(config, "coordinate_descent_tuning"): config.coordinate_descent_tuning = True # suggested by @Chillee @@ -1121,13 +687,15 @@ def print0(*args, **kwargs): x, y = x.to(device), y.to(device) logits, loss = model(x, y) loss.backward() - # save model params, in bfloat16 - model_to_size = {"meta-llama/Meta-Llama-3.1-8B": "8B"} - model_size_str = model_to_size[args.model] # e.g. "8B" - write_model(model, os.path.join(args.output_dir, f"llama3.1_{model_size_str}_bf16.bin"), dtype="bfloat16") + # save model params, in both float32 and bfloat16 + model_to_size = {"gpt2": "124M", "gpt2-medium": "355M", "gpt2-large": "774M", "gpt2-xl": "1558M"} + model_to_size.update({f"d{d}": f"d{d}" for d in [12, 24, 36, 48]}) + model_size_str = model_to_size[args.model] # e.g. "124M", or "d12" + write_model(model, f"gpt2_{model_size_str}.bin", dtype="float32") + write_model(model, f"gpt2_{model_size_str}_bf16.bin", dtype="bfloat16") # save x, y, logits, loss, and parameter gradients, for debugging C # always store these in fp32 to have an accurate reference (?) 
- write_state(model, x, y, logits, loss, os.path.join(args.output_dir, f"llama3_{model_size_str}_debug_state.bin")) + write_state(model, x, y, logits, loss, f"gpt2_{model_size_str}_debug_state.bin") # reset the train_loader for the optimization below train_loader.reset() @@ -1159,6 +727,15 @@ def get_lr(it): coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff starts at 1 and goes to 0 return min_lr + coeff * (args.learning_rate - min_lr) + # create the logging directory if it does not exist + logfile = None + if args.output_dir: + os.makedirs(args.output_dir, exist_ok=True) + logfile = os.path.join(args.output_dir, "main.log") + # create the log file "main.log" inside it, and wipe it clean + with open(logfile, "w") as f: + pass + if device == "cuda": torch.cuda.reset_peak_memory_stats() timings = [] @@ -1192,18 +769,17 @@ def get_lr(it): and (step % args.sample_every == 0 or last_step)) \ and master_process: model.eval() - prompts: List[str] = json.loads(open(os.path.join(os.path.dirname(__file__), 'llmc_py', 'prompts.json')).read())['prompts'] - if args.use_hf: - prompt_tokens = [model.tokenizer(x).input_ids for x in prompts] - else: # Meta - prompt_tokens = [model.tokenizer.encode(x, bos=True, eos=False) for x in prompts] - - generation_tokens, _ = model.generate(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, logprobs=False, echo=False) - results = [{"generation": model.tokenizer.decode(t)} for t in generation_tokens] - for prompt, result in zip(prompts, results): - print(prompt, end="") - print(f"{result['generation']}") - print("\n==================================\n") + # before we end, let's also do one round of inference + # we'll kick off the generation with "<|endoftext|>", which designates the start of a new sequence + start_ids = [enc.eot_token] + xg = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]) + max_new_tokens = 32 + temperature = 1.0 + top_k = 40 + yg = raw_model.generate(xg, max_new_tokens, temperature=temperature, top_k=top_k) + print0('---------------') + print0(enc.decode(yg[0].tolist())) + print0('---------------') # bit confusing: we want to make sure to eval and sample on 0th iteration # but also after the very last iteration. so we loop for step <= num_iterations diff --git a/train_llama3.py b/train_llama3.py new file mode 100644 index 000000000..c1ccff942 --- /dev/null +++ b/train_llama3.py @@ -0,0 +1,1284 @@ +""" +Reference code for LLaMA-3.1 training and inference. +Will save the model weights into files, to be read from C as initialization. + +This code differs from GPT-2 very slightly, there are three main differences: +1) RoPE: LLaMA uses a different positional encoding scheme called Relative Positional Encoding (RoPE). +2) GQA: Grouped Query Attention (GQA) is used to reduce the number of attention heads. +3) SwiGLU: Swish-Gated Linear Unit (SwiGLU) is used as the activation function in the MLP. 
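To make difference 2 concrete: with n_head=32 and n_kv_head=8, every key/value head is shared by 4 query heads. The small sketch below (illustrative only, not part of the patch) checks that the expand/reshape trick used by the repeat_kv helper defined further down is equivalent to torch.repeat_interleave along the head dimension.

import torch

bs, slen, n_kv_heads, n_rep, head_dim = 2, 16, 8, 4, 64
x = torch.randn(bs, slen, n_kv_heads, head_dim)
y = (
    x[:, :, :, None, :]
    .expand(bs, slen, n_kv_heads, n_rep, head_dim)
    .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)
assert torch.equal(y, torch.repeat_interleave(x, repeats=n_rep, dim=2))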
+ +References: +# 1) https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/tokenizer.py +# 2) https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py +# 3) https://github.com/meta-llama/llama3/blob/11817d47e1ba7a4959b025eb1ca308572e0e3963/llama/generation.py + +Example launches to only benchmark the speed of bfloat16 compiled GPU training: +1 GPU: +python train_llama3.py --write_tensors=0 --num_iterations=50 --sequence_length=8192 --compile=1 --tensorcores=1 --dtype=bfloat16 +you can also turn on flash-attention by appending --flash=1 +4 GPU: +torchrun --standalone --nproc_per_node=4 train_llama3.py --write_tensors=0 --num_iterations=50 --sequence_length=8192 --compile=1 --tensorcores=1 --dtype=bfloat16 +""" + +import os +import math +import glob +import inspect +from contextlib import nullcontext +from dataclasses import dataclass +import json +from pathlib import Path +from typing import ( + AbstractSet, + Callable, + Collection, + Dict, + Iterator, + List, + Literal, + Optional, + Sequence, + Tuple, + Union, + cast, +) + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +import torch._inductor.config as config +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.distributed import init_process_group, destroy_process_group +from torch.distributed.optim import ZeroRedundancyOptimizer +import torch.distributed as dist + +import tiktoken +from tiktoken.load import load_tiktoken_bpe + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the LLaMA 3.x model + +# using a global to toggle flash-attention +FLASH = 0 + +# Used in Grouped Query Attention (GQA), broadcasts the key and value tensors +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + +# ----------------------------------------------------------------------------- +# RoPE related + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + +def apply_scaling(freqs: torch.Tensor): + # Values obtained from grid search + scale_factor = 8 + low_freq_factor = 1 + high_freq_factor = 4 + old_context_len = 8192 # original llama3 length + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + new_freqs = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scale_factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) + return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = 
torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False +): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + if use_scaled: + freqs = apply_scaling(freqs) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + +# ----------------------------------------------------------------------------- +# LLaMA building blocks + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + +class CausalSelfAttention(nn.Module): + + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + + self.n_head = config.n_head + self.n_kv_head = config.n_kv_head + self.n_rep = self.n_head // self.n_kv_head + self.hd = config.n_embd // config.n_head + self.use_kv = config.use_kv + + self.c_attn = nn.Linear(config.n_embd, (config.n_head + 2 * config.n_kv_head) * self.hd, bias=False) # key, query, value projections + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) # output projection + self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 + + # static KV cache - we could alternatively allocate it outside of the model and just pass it in when needed + if self.use_kv: + self.cache_k = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd)) + self.cache_v = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd)) + + def forward(self, x, freqs_cis=None, start_pos=None, mask: Optional[torch.Tensor] = None): + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + qkv = self.c_attn(x) + q, k, v = qkv.split([self.n_head * self.hd, self.n_kv_head * self.hd, self.n_kv_head * self.hd], dim=-1) + q, k, v = map(lambda t: t.view(B, T, -1, self.hd), (q, k, v)) # (B, T, NH, HD) + + q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) # rotate QK (rope) <-- 1. difference compared to GPT-2 + + if self.use_kv and not self.training and start_pos >= 0: # use kv-caching during inference + self.cache_k[:B, start_pos : start_pos + T] = k + self.cache_v[:B, start_pos : start_pos + T] = v + k = self.cache_k[:B, : start_pos + T] + v = self.cache_v[:B, : start_pos + T] + + k = repeat_kv(k, self.n_rep) # GQA <-- 2. 
difference compared to GPT-2 + v = repeat_kv(v, self.n_rep) + + q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v)) # (B, NH, T, HD) + + if FLASH: + # flashattention + y = F.scaled_dot_product_attention(q, k, v, mask) + else: + # manual implementation of attention + # this materializes the large (T,T) matrix for all the queries and keys + scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.hd)) + if mask is not None: + scores.masked_fill_(mask, torch.finfo(scores.dtype).min) + att = F.softmax(scores.float(), dim=-1).type_as(q) + y = att @ v # (B, NH, T, T) x (B, NH, T, HD) -> (B, NH, T, HD) + y = y.transpose(1, 2).contiguous().view(B, T, C) + y = self.c_proj(y) + return y + +class MLP(nn.Module): + + def __init__(self, config): + super().__init__() + hidden_dim = 4 * config.n_embd + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if config.ffn_dim_multiplier is not None: + hidden_dim = int(config.ffn_dim_multiplier * hidden_dim) + hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of) + self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=False) + self.c_fc2 = nn.Linear(config.n_embd, hidden_dim, bias=False) + self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False) + self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 + + def forward(self, x): + # SwiGLU self.c_proj(F.silu(self.c_fc2(x)) * self.c_fc(x)) <-- 3. difference compared to GPT-2 + x1 = self.c_fc(x) + x2 = self.c_fc2(x) + x2 = F.silu(x2) + x = x1 * x2 + x = self.c_proj(x) + return x + +class Block(nn.Module): + + def __init__(self, config): + super().__init__() + self.ln_1 = RMSNorm(config.n_embd, config.norm_eps) + self.attn = CausalSelfAttention(config) + self.ln_2 = RMSNorm(config.n_embd, config.norm_eps) + self.mlp = MLP(config) + + def forward(self, x, freqs_cis=None, start_pos=None, mask: Optional[torch.Tensor] = None): + x = x + self.attn(self.ln_1(x), freqs_cis, start_pos, mask) + x = x + self.mlp(self.ln_2(x)) + return x + +# ----------------------------------------------------------------------------- +# The main LLaMA 3.1 model + +@dataclass +class LlamaConfig: + version: str = "3.1" + block_size: int = 8192 + vocab_size: int = 128256 + n_layer: int = 32 + n_head: int = 32 + n_kv_head: int = 8 + n_embd: int = 4096 + ffn_dim_multiplier: float = 1.3 + multiple_of: int = 1024 + norm_eps: float = 1e-5 + rope_theta: float = 500000.0 + use_scaled_rope: bool = True + max_gen_batch_size: int = 4 + use_kv: bool = True + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if hasattr(self, k): + setattr(self, k, v) + assert self.n_kv_head <= self.n_head + assert self.n_head % self.n_kv_head == 0 + assert self.n_embd % self.n_head == 0 + +class LLaMA(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + + self.transformer = nn.ModuleDict(dict( + wte = nn.Embedding(config.vocab_size, config.n_embd), + h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), + ln_f = RMSNorm(config.n_embd, config.norm_eps), + )) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # init all weights, use a torch rng object to be very careful + self.init_rng = torch.Generator() + self.init_rng.manual_seed(42) + + self.freqs_cis = precompute_freqs_cis( + config.n_embd // config.n_head, + config.block_size * 2, + config.rope_theta, + config.use_scaled_rope, + ) + + def forward(self, idx, targets=None, return_logits=True, start_pos=0): + _, t = idx.size() + assert t <= self.config.block_size, f"Cannot 
forward sequence of length {t}, block size is only {self.config.block_size}" + + # forward the LLaMA model itself + x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + freqs_cis = self.freqs_cis[start_pos:start_pos+t] + + mask = torch.triu(torch.ones((t, t), device=next(self.parameters()).device, dtype=torch.bool), diagonal=1) + + for i, block in enumerate(self.transformer.h): + x = block(x, freqs_cis, start_pos, mask) + x = self.transformer.ln_f(x) + + if targets is not None: + # if we are given some desired targets also calculate the loss + logits = self.lm_head(x).float() + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) + else: + # inference-time mini-optimization: only forward the lm_head on the very last position + logits = self.lm_head(x[:, [-1], :]).float() # note: using list [-1] to preserve the time dim + loss = None + + # there are performance reasons why not returning logits is prudent, if not needed + if not return_logits: + logits = None + + return logits, loss + + @staticmethod + def adapt_llama_state_dict_keys(checkpoint, config: LlamaConfig): + checkpoint['transformer.wte.weight'] = checkpoint.pop('tok_embeddings.weight') + + for i in range(config.n_layer): + for name in ['attention_norm', 'ffn_norm']: + old_key = f'layers.{i}.{name}.weight' # e.g. layers.x.attention_norm.weight -> transformer.h.x.ln_1.weight + new_key = f'transformer.h.{i}.ln_{1 if name == "attention_norm" else 2}.weight' + checkpoint[new_key] = checkpoint.pop(old_key) + + for i in range(config.n_layer): + for name in ['attention.wq', 'attention.wk', 'attention.wv']: + old_key = f'layers.{i}.{name}.weight' + new_key = f'transformer.h.{i}.attn.c_attn.weight' + if name == 'attention.wq': + checkpoint[new_key] = checkpoint.pop(old_key) + else: # merge 3 weights into transformer.h.x.attn.c_attn.weight + checkpoint[new_key] = torch.cat((checkpoint[new_key], checkpoint.pop(old_key)), dim=0) + old_key = f'layers.{i}.attention.wo.weight' + new_key = f'transformer.h.{i}.attn.c_proj.weight' + checkpoint[new_key] = checkpoint.pop(old_key) + + ffn_map = {'w1': 'c_fc2', 'w2': 'c_proj', 'w3': 'c_fc'} + for i in range(config.n_layer): + for name in ['feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3']: + old_key = f'layers.{i}.{name}.weight' + new_key = f'transformer.h.{i}.mlp.{ffn_map[name.split(".")[-1]]}.weight' + checkpoint[new_key] = checkpoint.pop(old_key) + + checkpoint['transformer.ln_f.weight'] = checkpoint.pop('norm.weight') + checkpoint['lm_head.weight'] = checkpoint.pop('output.weight') + + return checkpoint + + @staticmethod + def adapt_llama_state_dict_keys_hf(checkpoint, config: LlamaConfig): + checkpoint['transformer.wte.weight'] = checkpoint.pop('model.embed_tokens.weight') + + # We need to unpermute K and V because HF script permuted the original Meta-LLaMA weights + # see: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py + def unpermute(w, n_heads, dim1, dim2): + return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + for i in range(config.n_layer): + for name in ['input_layernorm', 'post_attention_layernorm']: + old_key = f'model.layers.{i}.{name}.weight' # e.g. 
layers.x.attention_norm.weight -> transformer.h.x.ln_1.weight + new_key = f'transformer.h.{i}.ln_{1 if name == "input_layernorm" else 2}.weight' + checkpoint[new_key] = checkpoint.pop(old_key) + + for i in range(config.n_layer): + for name in ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj']: + old_key = f'model.layers.{i}.{name}.weight' + new_key = f'transformer.h.{i}.attn.c_attn.weight' + if name == 'self_attn.q_proj': + checkpoint[new_key] = unpermute(checkpoint.pop(old_key), config.n_head, config.n_embd, config.n_embd) + else: # merge 3 weights into transformer.h.x.attn.c_attn.weight + tensor = checkpoint.pop(old_key) + if name == 'self_attn.k_proj': + tensor = unpermute(tensor, config.n_kv_head, config.n_kv_head * (config.n_embd // config.n_head), config.n_embd) + checkpoint[new_key] = torch.cat((checkpoint[new_key], tensor), dim=0) + old_key = f'model.layers.{i}.self_attn.o_proj.weight' + new_key = f'transformer.h.{i}.attn.c_proj.weight' + checkpoint[new_key] = checkpoint.pop(old_key) + + ffn_map = {'gate_proj': 'c_fc2', 'down_proj': 'c_proj', 'up_proj': 'c_fc'} + for i in range(config.n_layer): + for name in ['gate_proj', 'down_proj', 'up_proj']: + old_key = f'model.layers.{i}.mlp.{name}.weight' + new_key = f'transformer.h.{i}.mlp.{ffn_map[name]}.weight' + checkpoint[new_key] = checkpoint.pop(old_key) + + checkpoint['transformer.ln_f.weight'] = checkpoint.pop('model.norm.weight') + + return checkpoint + + @classmethod + def from_pretrained_llama3_hf(cls, model_id): + """Loads pretrained LLaMA model weights from HuggingFace""" + from transformers import AutoModelForCausalLM, AutoTokenizer + assert model_id == "meta-llama/Meta-Llama-3.1-8B", "Only the 8B-bae model is supported for now" + model_args = LlamaConfig() + + model = AutoModelForCausalLM.from_pretrained(model_id) + checkpoint = LLaMA.adapt_llama_state_dict_keys_hf(model.state_dict(), model_args) + + original_default_type = torch.get_default_dtype() # save the default type + torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) # much faster loading + model = LLaMA(model_args) + model.load_state_dict(checkpoint, strict=False) + torch.set_default_tensor_type(torch.tensor([], dtype=original_default_type, device="cpu").type()) # restore default type + + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_id = 128004 # this is the pad token id for LLaMA 3.1 base, we need to set this explicitly as our generate func expects it + tokenizer.stop_tokens = [tokenizer.eos_token_id] + model.tokenizer = tokenizer + return model + + @classmethod + def from_pretrained_llama3_meta(cls, ckpt_dir, tokenizer_path): + """Loads pretrained LLaMA model weights from a checkpoint directory""" + model_args = LlamaConfig() + + ckpt_path = sorted(Path(ckpt_dir).glob("*.pth"))[0] + checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True) + checkpoint = LLaMA.adapt_llama_state_dict_keys(checkpoint, model_args) + + original_default_type = torch.get_default_dtype() # save the default type + torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) # much faster loading + model = LLaMA(model_args) + model.load_state_dict(checkpoint, strict=False) + torch.set_default_tensor_type(torch.tensor([], dtype=original_default_type, device="cpu").type()) # restore default type + + tokenizer = Tokenizer(model_path=tokenizer_path) + # add <|end_of_text|> as the stop token for base model - this is an omission in the reference code + # the reference code only adds instruct model stop tokens... 
+ tokenizer.stop_tokens = tokenizer.stop_tokens + [128001] + model.tokenizer = tokenizer + return model + + def configure_optimizers(self, weight_decay, learning_rate, betas, device_type, zero_stage): + # start with all of the candidate parameters + param_dict = {pn: p for pn, p in self.named_parameters()} + # filter out those that do not require grad + param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} + # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. + # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. + decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] + nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] + optim_groups = [ + {'params': decay_params, 'weight_decay': weight_decay}, + {'params': nodecay_params, 'weight_decay': 0.0} + ] + num_decay_params = sum(p.numel() for p in decay_params) + num_nodecay_params = sum(p.numel() for p in nodecay_params) + print0(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") + print0(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") + # Create AdamW optimizer and use the fused version if it is available + fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters + use_fused = fused_available and device_type == 'cuda' + print0(f"using fused AdamW: {use_fused}") + if zero_stage == 1: + print0("using ZeroRedundancyOptimizer") + optimizer = ZeroRedundancyOptimizer(**optim_groups[0], optimizer_class=torch.optim.AdamW, + lr=learning_rate, betas=betas, fused=use_fused) + optimizer.add_param_group(optim_groups[1]) + else: + print0("using regular AdamW") + optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, fused=use_fused) + return optimizer + + @torch.inference_mode() + def generate( + self, + prompt_tokens: List[List[int]], + max_gen_len: int, + temperature: float = 0.6, + top_p: float = 0.9, + logprobs: bool = False, + echo: bool = False, + ) -> Tuple[List[List[int]], Optional[List[List[float]]]]: + """ + Generate text sequences based on provided prompts using the language generation model. + + Args: + prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers. + max_gen_len (int): Maximum length of the generated text sequence. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. + echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + + Returns: + Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities. + + Note: + This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness. + If logprobs is True, token log probabilities are computed for each generated token. 
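A short usage sketch (added for illustration, not in the original docstring), assuming `model` was loaded via LLaMA.from_pretrained_llama3_meta so that model.tokenizer is attached, mirroring how the training script below calls generate:

    >>> prompt_tokens = [model.tokenizer.encode("Clearly, the meaning of life is", bos=True, eos=False)]
    >>> out_tokens, _ = model.generate(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9)
    >>> print(model.tokenizer.decode(out_tokens[0]))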
+ + """ + bsz = len(prompt_tokens) + assert bsz <= self.config.max_gen_batch_size, (bsz, self.config.max_gen_batch_size) + device = next(self.parameters()).device + + min_prompt_len = min(len(t) for t in prompt_tokens) + max_prompt_len = max(len(t) for t in prompt_tokens) + assert max_prompt_len <= self.config.block_size + total_len = min(self.config.block_size, max_gen_len + max_prompt_len) + + pad_id = self.tokenizer.pad_id + tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device) + for k, t in enumerate(prompt_tokens): + tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device) + if logprobs: + token_logprobs = torch.zeros_like(tokens, dtype=torch.float) + + prev_pos = 0 + eos_reached = torch.tensor([False] * bsz, device=device) + input_text_mask = tokens != pad_id + + if min_prompt_len == total_len: + logits, _ = self.forward(tokens, start_pos=prev_pos) + token_logprobs = -F.cross_entropy( + input=logits.transpose(1, 2), + target=tokens, + reduction="none", + ignore_index=pad_id, + ) + + stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens)).to(device) + + for cur_pos in range(min_prompt_len, total_len): + logits, _ = self.forward(tokens[:, prev_pos:cur_pos], start_pos=prev_pos) + if temperature > 0: + probs = torch.softmax(logits[:, -1] / temperature, dim=-1) + next_token = sample_top_p(probs, top_p) + else: + next_token = torch.argmax(logits[:, -1], dim=-1) + + next_token = next_token.reshape(-1) + # only replace token if prompt has already been generated + next_token = torch.where( + input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token + ) + tokens[:, cur_pos] = next_token + if logprobs: + token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy( + input=logits.transpose(1, 2), + target=tokens[:, prev_pos + 1 : cur_pos + 1], + reduction="none", + ignore_index=pad_id, + ) + eos_reached |= (~input_text_mask[:, cur_pos]) & ( + torch.isin(next_token, stop_tokens) + ) + prev_pos = cur_pos + if all(eos_reached): + break + + if logprobs: + token_logprobs = token_logprobs.tolist() + out_tokens, out_logprobs = [], [] + for i, toks in enumerate(tokens.tolist()): + # cut to max gen len + start = 0 if echo else len(prompt_tokens[i]) + toks = toks[start : len(prompt_tokens[i]) + max_gen_len] + probs = None + if logprobs: + probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len] + # cut to after eos tok if any + for stop_token in self.tokenizer.stop_tokens: + try: + eos_idx = toks.index(stop_token) + toks = toks[:eos_idx] + probs = probs[:eos_idx] if logprobs else None + except ValueError: + pass + out_tokens.append(toks) + out_logprobs.append(probs) + return (out_tokens, out_logprobs if logprobs else None) + +# ----------------------------------------------------------------------------- +# sampling utils + +def sample_top_p(probs, p): + """ + Perform top-p (nucleus) sampling on a probability distribution. + + Args: + probs (torch.Tensor): Probability distribution tensor. + p (float): Probability threshold for top-p sampling. + + Returns: + torch.Tensor: Sampled token indices. + + Note: + Top-p sampling selects the smallest set of tokens whose cumulative probability mass + exceeds the threshold p. The distribution is renormalized based on the selected tokens. 
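An illustrative numeric example (added here, not part of the original docstring), assuming torch is imported and a single batch of four token probabilities:

    >>> probs = torch.tensor([[0.6, 0.25, 0.10, 0.05]])
    >>> next_token = sample_top_p(probs, p=0.7)
    >>> # the cumulative mass of the sorted probs is [0.6, 0.85, 0.95, 1.0]; with p=0.7 only the
    >>> # two largest entries survive the mask, renormalize to roughly [0.71, 0.29], and the
    >>> # sampled index is therefore always 0 or 1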
+ """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token + +# ----------------------------------------------------------------------------- +# Llama 3.1 Tokenizer + +# The tiktoken tokenizer can handle <=400k chars without +# pyo3_runtime.PanicException. +TIKTOKEN_MAX_ENCODE_CHARS = 400_000 + +# https://github.com/openai/tiktoken/issues/195 +# Here we iterate over subsequences and split if we exceed the limit +# of max consecutive non-whitespace or whitespace characters. +MAX_NO_WHITESPACES_CHARS = 25_000 + + +class Tokenizer: + """ + Tokenizing and encoding/decoding text using the Tiktoken tokenizer. + """ + + special_tokens: Dict[str, int] + + num_reserved_special_tokens = 256 + + pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 + + def __init__(self, model_path: str): + """ + Initializes the Tokenizer with a Tiktoken model. + + Args: + model_path (str): The path to the Tiktoken model file. + """ + assert os.path.isfile(model_path), model_path + + mergeable_ranks = load_tiktoken_bpe(model_path) + num_base_tokens = len(mergeable_ranks) + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|finetune_right_pad_id|>", + "<|step_id|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|eom_id|>", # end of message + "<|eot_id|>", # end of turn + "<|python_tag|>", + ] + reserved_tokens = [ + f"<|reserved_special_token_{2 + i}|>" + for i in range(self.num_reserved_special_tokens - len(special_tokens)) + ] + special_tokens = special_tokens + reserved_tokens + + self.special_tokens = { + token: num_base_tokens + i for i, token in enumerate(special_tokens) + } + self.model = tiktoken.Encoding( + name=Path(model_path).name, + pat_str=self.pat_str, + mergeable_ranks=mergeable_ranks, + special_tokens=self.special_tokens, + ) + + self.n_words: int = num_base_tokens + len(special_tokens) + # BOS / EOS token IDs + self.bos_id: int = self.special_tokens["<|begin_of_text|>"] + self.eos_id: int = self.special_tokens["<|end_of_text|>"] + self.eot_id: int = self.special_tokens["<|eot_id|>"] + self.eom_id: int = self.special_tokens["<|eom_id|>"] + self.python_tag_id = self.special_tokens["<|python_tag|>"] + self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"] + self.stop_tokens = [ + self.special_tokens["<|eom_id|>"], + self.special_tokens["<|eot_id|>"], + ] + + def encode( + self, + s: str, + *, + bos: bool, + eos: bool, + allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None, + disallowed_special: Union[Literal["all"], Collection[str]] = (), + ) -> List[int]: + """ + Encodes a string into a list of token IDs. + + Args: + s (str): The input string to be encoded. + bos (bool): Whether to prepend the beginning-of-sequence token. + eos (bool): Whether to append the end-of-sequence token. + allowed_tokens ("all"|set[str]): allowed special tokens in string + disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string + + Returns: + list[int]: A list of token IDs. + + By default, setting disallowed_special=() encodes a string by ignoring + special tokens. 
Specifically: + - Setting `disallowed_special` to () will cause all text corresponding + to special tokens to be encoded as natural text (insteading of raising + an error). + - Setting `allowed_special` to "all" will treat all text corresponding + to special tokens to be encoded as special tokens. + """ + if allowed_special is None: + allowed_special = set() + assert type(s) is str + + substrs = ( + substr + for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) + for substr in self._split_whitespaces_or_nonwhitespaces( + s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS + ) + ) + t: List[int] = [] + for substr in substrs: + t.extend( + self.model.encode( + substr, + allowed_special=allowed_special, + disallowed_special=disallowed_special, + ) + ) + if bos: + t.insert(0, self.bos_id) + if eos: + t.append(self.eos_id) + return t + + def decode(self, t: Sequence[int]) -> str: + """ + Decodes a list of token IDs into a string. + + Args: + t (List[int]): The list of token IDs to be decoded. + + Returns: + str: The decoded string. + """ + # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. + return self.model.decode(cast(List[int], t)) + + @staticmethod + def _split_whitespaces_or_nonwhitespaces( + s: str, max_consecutive_slice_len: int + ) -> Iterator[str]: + """ + Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` + consecutive whitespaces or consecutive non-whitespaces. + """ + current_slice_len = 0 + current_slice_is_space = s[0].isspace() if len(s) > 0 else False + slice_start = 0 + + for i in range(len(s)): + is_now_space = s[i].isspace() + + if current_slice_is_space ^ is_now_space: + current_slice_len = 1 + current_slice_is_space = is_now_space + else: + current_slice_len += 1 + if current_slice_len > max_consecutive_slice_len: + yield s[slice_start:i] + slice_start = i + current_slice_len = 1 + yield s[slice_start:] + +# ----------------------------------------------------------------------------- +# Our own simple Distributed Data Loader + +def _peek_data_shard(filename): + # only reads the header, returns header data + with open(filename, "rb") as f: + # first read the header, which is 256 int32 integers (4 bytes each) + header = np.frombuffer(f.read(256*4), dtype=np.int32) + if header[0] != 20240520: + print("ERROR: magic number mismatch in the data .bin file!") + print("---> HINT: Are you passing in a correct file with --input_bin?") + print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README") + print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try") + exit(1) + assert header[1] == 1, "unsupported version" + ntok = header[2] # number of tokens (claimed) + return ntok # for now just return the number of tokens + +def _load_data_shard(filename): + with open(filename, "rb") as f: + # first read the header, which is 256 int32 integers (4 bytes each) + header = np.frombuffer(f.read(256*4), dtype=np.int32) + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + ntok = header[2] # number of tokens (claimed) + # the rest of it are tokens, stored as uint16 + tokens = np.frombuffer(f.read(), dtype=np.uint16) + assert len(tokens) == ntok, "number of tokens read does not match header?" 
+ return tokens + +class DistributedDataLoader: + def __init__(self, filename_pattern, B, T, process_rank, num_processes): + self.process_rank = process_rank + self.num_processes = num_processes + self.B = B + self.T = T + + # glob files that match the pattern + self.files = sorted(glob.glob(filename_pattern)) + assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}" + + # load and validate all data shards, count number of tokens in total + ntok_total = 0 + for fname in self.files: + shard_ntok = _peek_data_shard(fname) + assert shard_ntok >= num_processes * B * T + 1 + ntok_total += shard_ntok + self.ntok_total = ntok_total + print0(f"DataLoader: total number of tokens: {ntok_total:,} across {len(self.files)} files") + + # kick things off + self.current_shard = None + self.reset() + + def reset(self): + # we're being a bit clever here: if we already had shard 0 loaded, + # then don't do the work to reload it, just reset the pointer + if self.current_shard != 0: + self.current_shard = 0 + self.tokens = _load_data_shard(self.files[self.current_shard]) + self.current_position = self.process_rank * self.B * self.T + + def advance(self): # advance to next data shard + self.current_shard = (self.current_shard + 1) % len(self.files) + self.current_position = self.process_rank * self.B * self.T + self.tokens = _load_data_shard(self.files[self.current_shard]) + + def next_batch(self): + B = self.B + T = self.T + buf = self.tokens[self.current_position : self.current_position+B*T+1] + buf = torch.tensor(buf.astype(np.int32), dtype=torch.long) + x = (buf[:-1]).view(B, T) # inputs + y = (buf[1:]).view(B, T) # targets + # advance the start pointer in current shard + self.current_position += B * T * self.num_processes + # if loading the next batch would be out of bounds advance the shard + if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens): + self.advance() + return x, y + +# ----------------------------------------------------------------------------- +# Python -> C bridge utilities for saving params/grads/activations to .bin files + +def write_fp32(tensor, file): + t = tensor.detach().cpu().to(torch.float32) + b = t.numpy().tobytes() + file.write(b) + +def write_bf16(tensor, file): + t = tensor.detach().cpu().to(torch.bfloat16) + # numpy doesn't have bf16 datatype so we have to trick it + t = t.view(torch.int16) # trick: reinterpret as int16 + b = t.numpy().tobytes() + file.write(b) + +def write_tensors(model_tensors, L, file, dtype): + # writes LLaMA 3 model's weights to a binary file + assert dtype in {"float32", "bfloat16"} + write_fun = write_fp32 if dtype == "float32" else write_bf16 + write_fun(model_tensors["transformer.wte.weight"], file) # (V, C) + for i in range(L): # (L, C) + write_fun(model_tensors[f"transformer.h.{i}.ln_1.weight"], file) + for i in range(L): # (L, 3C, C) + write_fun(model_tensors[f"transformer.h.{i}.attn.c_attn.weight"], file) + for i in range(L): # (L, C, C) + write_fun(model_tensors[f"transformer.h.{i}.attn.c_proj.weight"], file) + for i in range(L): # (L, C) + write_fun(model_tensors[f"transformer.h.{i}.ln_2.weight"], file) + for i in range(L): # (L, 4C, C) + write_fun(model_tensors[f"transformer.h.{i}.mlp.c_fc.weight"], file) + for i in range(L): # (L, 4C, C) + write_fun(model_tensors[f"transformer.h.{i}.mlp.c_fc2.weight"], file) + for i in range(L): # (L, C, 4C) + write_fun(model_tensors[f"transformer.h.{i}.mlp.c_proj.weight"], file) + write_fun(model_tensors["transformer.ln_f.weight"], file) # (C, 
) + write_fun(model_tensors["lm_head.weight"], file) # (V, C) + +def write_model(model, filename, dtype): + # everything we need to instantiate the model + # 1) header is: version int, LLaMAConfig ints, padding to 1024 bytes + assert dtype in {"float32", "bfloat16"} + version = { + "float32": 3, # 3: all tensors are fp32 + "bfloat16": 5, # 5: all tensors are bf16 + }[dtype] + header = torch.zeros(256, dtype=torch.int32) + header[0] = 20240803 # magic + header[1] = version # checkpoint version + header[2] = model.config.block_size + header[3] = model.config.vocab_size + header[4] = model.config.n_layer + header[5] = model.config.n_head + header[6] = model.config.n_kv_head + header[7] = model.config.n_embd + header[8] = model.config.ffn_dim_multiplier + header[9] = model.config.multiple_of + header[10] = model.config.norm_eps + header[11] = model.config.rope_theta + header[12] = model.config.use_scaled_rope + header[13] = model.config.max_gen_batch_size + header[14] = int(model.config.version.split('.')[0]) # major version + header[15] = int(model.config.version.split('.')[1]) # minor version + # 2) the parameters follow the header + params = {name: param.cpu() for name, param in model.named_parameters()} + # now write to file + with open(filename, "wb") as file: + file.write(header.numpy().tobytes()) # header + write_tensors(params, model.config.n_layer, file, dtype) # params + print(f"wrote {filename}") + +def write_state(model, x, y, logits, loss, filename): + # the state is used for debugging. + # it contains information about the input, logits, loss, and the parameter gradients + # this can be used for checking the computation correctness in C + header = torch.zeros(256, dtype=torch.int32) + header[0] = 20240803 # magic + header[1] = x.size(0) # batch size of the batch, B + header[2] = x.size(1) # temporal extent of the batch, T + grads = {name: param.grad.cpu() for name, param in model.named_parameters()} + with open(filename, "wb") as file: + # header + file.write(header.numpy().tobytes()) + # input x + file.write(x.cpu().numpy().astype("int32").tobytes()) # (B, T) + # targets y + file.write(y.cpu().numpy().astype("int32").tobytes()) # (B, T) + # logits (result of the model forward pass) + write_fp32(logits.cpu(), file) + # loss (single float, result of the cross entropy loss) + write_fp32(loss.cpu(), file) + # gradients + write_tensors(grads, model.config.n_layer, file, "float32") + print(f"wrote {filename}") + +# ----------------------------------------------------------------------------- +# int main + +def print0(*args, **kwargs): + # modified print that only prints from the master process + # if this is not a distributed run, it's just a print + if int(os.environ.get("RANK", 0)) == 0: + print(*args, **kwargs) + +if __name__ == "__main__": + import time + import argparse + print0(f"Running pytorch {torch.version.__version__}") + + # default settings will overfit a tiny batch of data + # and save model weights and debug state to disk on the first iteration + parser = argparse.ArgumentParser() + parser.add_argument("--use_hf", type=int, default=1, help="use HuggingFace (default) or use Meta's model") + parser.add_argument("--ckpt_dir", type=str, default=None, help="path to llama3 model checkpoint") + parser.add_argument("--tokenizer_path", type=str, default=None, help="path to llama3 tokenizer") + # file system input / output + parser.add_argument("--input_bin", type=str, default="dev/data/tinyshakespeare/tiny_shakespeare_val.bin", help="input .bin to train on") + 
parser.add_argument("--input_val_bin", type=str, default="", help="input .bin to eval validation loss on") + parser.add_argument("--output_dir", type=str, default="", help="output directory to which to write logs and checkpoints") + parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3.1-8B", help="chose the llama model") + # token layout for each step of the optimization + parser.add_argument("--batch_size", type=int, default=4, help="batch size, in units of #batch dimensions") + parser.add_argument("--sequence_length", type=int, default=64, help="sequence length") + parser.add_argument("--total_batch_size", type=int, default=256, help="total desired batch size, in units of #tokens") + # workload (number of steps) + parser.add_argument("--num_iterations", type=int, default=10, help="number of iterations to run") + parser.add_argument("--inference_only", type=int, default=0, help="only run inference") + # optimization + parser.add_argument("--learning_rate", type=float, default=1e-4, help="learning rate warmup iterations") + parser.add_argument("--warmup_iters", type=int, default=0, help="learning rate warmup iterations") + parser.add_argument("--learning_rate_decay_frac", type=float, default=1.0, help="learning rate warmup iterations") + parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay") + parser.add_argument("--grad_clip", type=float, default=1.0, help="maximum gradient magnitude") + # evaluation + parser.add_argument("--val_loss_every", type=int, default=0, help="every how mant steps to evaluate val loss?") + parser.add_argument("--val_max_steps", type=int, default=20, help="how many batches of val to average?") + parser.add_argument("--sample_every", type=int, default=0, help="how often to sample from the model?") + # debugging + parser.add_argument("--overfit_single_batch", type=int, default=1, help="overfit just one batch of data") + # numerics + parser.add_argument("--tensorcores", type=int, default=0, help="use tensorcores") + # memory management + parser.add_argument("--device", type=str, default="", help="by default we autodetect, or set it here") + parser.add_argument("--compile", type=int, default=0, help="torch.compile the model") + parser.add_argument("--flash", type=int, default=0, help="use flash attention") + parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|float16|bfloat16") + parser.add_argument("--zero_stage", type=int, default=0, help="zero redundancy optimizer stage (0/1/2/3)") + # python -> C bridge + parser.add_argument("--write_tensors", type=int, default=0, help="write tensors to disk") + args = parser.parse_args() + + # args error checking and convenience variables + B, T = args.batch_size, args.sequence_length + assert 1 <= T <= 8192, "sequence length must be between 1 and 8192" + assert args.dtype in {"float32", "float16", "bfloat16"} + assert args.model in {"meta-llama/Meta-Llama-3.1-8B"} # only 8B base model supported for now + + # create the logging directory if it does not exist + logfile = None + if args.output_dir: + os.makedirs(args.output_dir, exist_ok=True) + logfile = os.path.join(args.output_dir, "main.log") + # create the log file "main.log" inside it, and wipe it clean + with open(logfile, "w") as f: + pass + + # set up DDP (distributed data parallel). torchrun sets this env variable + ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? 
+ if ddp: + # use of DDP atm demands CUDA, we set the device appropriately according to rank + assert torch.cuda.is_available(), "for now i think we need CUDA for DDP" + init_process_group(backend='nccl') + ddp_rank = int(os.environ['RANK']) + ddp_local_rank = int(os.environ['LOCAL_RANK']) + ddp_world_size = int(os.environ['WORLD_SIZE']) + device = f'cuda:{ddp_local_rank}' + torch.cuda.set_device(device) + master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. + seed_offset = 0 # each process gets the exact same seed + zero_stage = args.zero_stage + else: + ddp_rank = 0 + ddp_local_rank = 0 + zero_stage = 0 + ddp_world_size = 1 + master_process = True + seed_offset = 0 + # select the device + if args.device: + # provided explicitly by the user + device = args.device + else: + # attempt to autodetect the device + device = "cpu" + if torch.cuda.is_available(): + device = "cuda" + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + device = "mps" + print(f"using device: {device}") + device_type = 'cuda' if 'cuda' in device else 'cpu' + assert device_type in {'cuda'} # we need to load LLaMA as bf16 on CUDA + + # calculate gradient accumulation from the desired total batch size and the current run configuration + tokens_per_fwdbwd = B * T * ddp_world_size + assert args.total_batch_size % tokens_per_fwdbwd == 0 + grad_accum_steps = args.total_batch_size // tokens_per_fwdbwd + print0(f"total desired batch size: {args.total_batch_size}") + print0(f"=> calculated gradient accumulation steps: {grad_accum_steps}") + + # set up a context manager following the desired dtype and device + ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype] + ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if (device_type == "cuda") else nullcontext() + + # rng / reproducibility + torch.manual_seed(42) + if torch.cuda.is_available(): + torch.cuda.manual_seed(42) + + # set the torch precision mode to use TensorFloat32 (TF32) for matmuls + # docs https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html + if args.tensorcores: + torch.set_float32_matmul_precision('high') + + # turn on/off flash attention + assert args.flash in {0, 1} + FLASH = args.flash + + # init the model + assert args.ckpt_dir is not None and os.path.exists(args.ckpt_dir), f"llama3 ckpt dir {args.ckpt_dir} does not exist" + assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f"llama3 tokenizer path {args.tokenizer_path} does not exist" + if args.use_hf: + model = LLaMA.from_pretrained_llama3_hf(args.model) + else: # use Meta's checkpoint + model = LLaMA.from_pretrained_llama3_meta(args.ckpt_dir, args.tokenizer_path) + + model.train() + if args.compile: + if hasattr(config, "coordinate_descent_tuning"): + config.coordinate_descent_tuning = True # suggested by @Chillee + print0("compiling the model...") + model = torch.compile(model) + + # ------------------------------------------------------------------------- + # Our own version of a simple DistributedDataLoader + + # load tokens + train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size) + val_loader = None + if args.input_val_bin: + val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size) + + # ------------------------------------------------------------------------- + # PyTorch -> C bridge: save some weights and state for C to load later as reference + + # do one forward pass 
to generate ground truth for our C tests + if master_process and args.write_tensors and (not args.inference_only): + x, y = train_loader.next_batch() + x, y = x.to(device), y.to(device) + logits, loss = model(x, y) + loss.backward() + # save model params, in bfloat16 + model_to_size = {"meta-llama/Meta-Llama-3.1-8B": "8B"} + model_size_str = model_to_size[args.model] # e.g. "8B" + write_model(model, os.path.join(args.output_dir, f"llama3.1_{model_size_str}_bf16.bin"), dtype="bfloat16") + # save x, y, logits, loss, and parameter gradients, for debugging C + # always store these in fp32 to have an accurate reference (?) + write_state(model, x, y, logits, loss, os.path.join(args.output_dir, f"llama3_{model_size_str}_debug_state.bin")) + # reset the train_loader for the optimization below + train_loader.reset() + + # ------------------------------------------------------------------------- + # main training loop + + # here we wrap model into DDP container + if ddp: + model = DDP(model, device_ids=[ddp_local_rank]) + raw_model = model.module if ddp else model # always contains the "raw" unwrapped model + + # init the optimizer + optimizer = raw_model.configure_optimizers(weight_decay=args.weight_decay, + learning_rate=args.learning_rate, betas=(0.9, 0.95), + device_type=device, zero_stage=zero_stage) + + # learning rate decay scheduler (cosine with warmup) + def get_lr(it): + min_lr = args.learning_rate * args.learning_rate_decay_frac + # 1) linear warmup for warmup_iters steps + if it < args.warmup_iters: + return args.learning_rate * (it+1) / args.warmup_iters + # 2) if it > lr_decay_iters, return min learning rate + if it > args.num_iterations: + return min_lr + # 3) in between, use cosine decay down to min learning rate + decay_ratio = (it - args.warmup_iters) / (args.num_iterations - args.warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff starts at 1 and goes to 0 + return min_lr + coeff * (args.learning_rate - min_lr) + + if device == "cuda": + torch.cuda.reset_peak_memory_stats() + timings = [] + norm = -1.0 # dummy value to print in inference-only mode + for step in range(args.num_iterations + 1): + t0 = time.time() + last_step = (step == args.num_iterations) + + # once in a while evaluate the validation dataset + if (args.val_loss_every > 0 \ + and (step % args.val_loss_every == 0 or last_step)) \ + and (val_loader is not None): + model.eval() + val_loader.reset() + with torch.no_grad(): + val_loss = 0.0 + for _ in range(args.val_max_steps): + x, y = val_loader.next_batch() + x, y = x.to(device), y.to(device) + _, loss = model(x, y, return_logits=False) + val_loss += loss.item() + val_loss /= args.val_max_steps + # log to console and to file + print0(f"val loss {val_loss}") + if master_process and logfile is not None: + with open(logfile, "a") as f: + f.write("s:%d tel:%f\n" % (step, val_loss)) + + # once in a while perform model inference on the master process + if (args.sample_every > 0 \ + and (step % args.sample_every == 0 or last_step)) \ + and master_process: + model.eval() + prompts: List[str] = json.loads(open(os.path.join(os.path.dirname(__file__), 'llmc_py', 'prompts.json')).read())['prompts'] + if args.use_hf: + prompt_tokens = [model.tokenizer(x).input_ids for x in prompts] + else: # Meta + prompt_tokens = [model.tokenizer.encode(x, bos=True, eos=False) for x in prompts] + + generation_tokens, _ = model.generate(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, logprobs=False, echo=False) + results = 
[{"generation": model.tokenizer.decode(t)} for t in generation_tokens] + for prompt, result in zip(prompts, results): + print(prompt, end="") + print(f"{result['generation']}") + print("\n==================================\n") + + # bit confusing: we want to make sure to eval and sample on 0th iteration + # but also after the very last iteration. so we loop for step <= num_iterations + # instead of just < num_iterations (one extra due to <=), only to do + # the validation/sampling one last time, and then we break right here as we're done. + if last_step: + break + + # --------------- TRAINING SECTION BEGIN ----------------- + model.train() + optimizer.zero_grad(set_to_none=True) + # if we are trying to overfit a single batch, we reset the loader here + if args.overfit_single_batch: + train_loader.reset() + # micro-batch loop where we do gradient accumulation to reach desired total batch size + lossf = 0.0 # for getting the mean loss (as simple float) over the accumulation steps + for micro_step in range(grad_accum_steps): + # fetch a batch + x, y = train_loader.next_batch() + x, y = x.to(device), y.to(device) + if ddp: + # we want only the last micro-step to sync grads in a DDP model + # the official way to do this is with model.no_sync(), but that is a + # context manager that bloats the code, so we just toggle this variable + model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1) + # forward pass + with ctx: + _, loss = model(x, y, return_logits=False) + # we have to scale the loss to account for gradient accumulation, + # because the gradients just add on each successive backward(). + # addition of gradients corresponds to a SUM in the objective, but + # instead of a SUM we want MEAN, so we scale the loss here + loss = loss / grad_accum_steps + lossf += loss.detach() # keep track of the mean loss + # backward pass + if not args.inference_only: + loss.backward() + if ddp: + dist.all_reduce(lossf, op=dist.ReduceOp.AVG) + lossf = lossf.item() + norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) + # determine and set the learning rate for this iteration + lr = get_lr(step) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + # step the optimizer + optimizer.step() + # --------------- TRAINING SECTION END ------------------- + # everything that follows now is just diagnostics, prints, logging, etc. 
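The loss-scaling comment above (each micro-batch loss is divided by grad_accum_steps so that the accumulated gradients correspond to a MEAN over the full batch rather than a SUM) can be verified with a small self-contained sketch; the tensors and shapes here are made up for illustration and are not part of the patch.

import torch

w = torch.tensor([1.0, -2.0], requires_grad=True)
micro_batches = [torch.tensor([3.0, 1.0]), torch.tensor([0.5, 2.0])]
grad_accum_steps = len(micro_batches)

# accumulate gradients of (loss / grad_accum_steps) over the micro-batches
for x in micro_batches:
    loss = ((w * x) ** 2).mean() / grad_accum_steps
    loss.backward()
accumulated_grad = w.grad.clone()

# compare against a single backward pass over the combined batch (mean over all elements)
w.grad = None
x_all = torch.stack(micro_batches)
((w * x_all) ** 2).mean().backward()
assert torch.allclose(accumulated_grad, w.grad)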
+ + # wait on the CPU for all device work to end so we get accurate per-iteration timings below + if device == "mps": + torch.mps.synchronize() + elif device == "cuda": + torch.cuda.synchronize() + # time and print + t1 = time.time() + # the 0th iteration is often an outlier (much slower) => skip logging it + tokens_per_second = grad_accum_steps * ddp_world_size * B * T / (t1-t0) + print0(f"step {step+1:4d}/{args.num_iterations} | train loss {lossf:.6f} | norm {norm:.4f} | lr {lr:.2e} | ({(t1-t0)*1000:.2f} ms | {tokens_per_second:.0f} tok/s)") + # log to logile + if master_process and logfile is not None: + with open(logfile, "a") as f: + f.write("s:%d trl:%f\n" % (step, lossf)) + + # keep track of smooth timings, last 20 iterations + if step > 0 and step > args.num_iterations - 20: + timings.append(t1-t0) + + # print the average of the last 20 timings, to get something smooth-ish + timings = timings[-20:] + print0(f"final {len(timings)} iters avg: {np.mean(timings)*1000:.3f}ms") + print0(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") + + # ------------------------------------------------------------------------- + # clean up nice + if ddp: + destroy_process_group() From dfd459bfc4ccc15347f921187e155b863cc54990 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Thu, 8 Aug 2024 18:19:40 +0200 Subject: [PATCH 30/36] Remove prompts.json --- train_llama3.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/train_llama3.py b/train_llama3.py index c1ccff942..f20b2c343 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -1192,7 +1192,17 @@ def get_lr(it): and (step % args.sample_every == 0 or last_step)) \ and master_process: model.eval() - prompts: List[str] = json.loads(open(os.path.join(os.path.dirname(__file__), 'llmc_py', 'prompts.json')).read())['prompts'] + prompts: List[str] = [ + "Clearly, the meaning of life is", + "Simply put, the theory of relativity states that", + """The repo llm.c on GitHub is""", + """Translate English to French: + + sea otter => loutre de mer + peppermint => menthe poivrée + plush girafe => girafe peluche + cheese =>""", + ] if args.use_hf: prompt_tokens = [model.tokenizer(x).input_ids for x in prompts] else: # Meta From ac01536b99011ef7fc68bb645c310121512c547f Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Thu, 8 Aug 2024 18:20:18 +0200 Subject: [PATCH 31/36] Remove the whole llmc_py --- llmc_py/prompts.json | 8 -------- llmc_py/rope.py | 0 llmc_py/tokenizer.py | 16 ---------------- llmc_py/utils.py | 0 4 files changed, 24 deletions(-) delete mode 100644 llmc_py/prompts.json delete mode 100644 llmc_py/rope.py delete mode 100644 llmc_py/tokenizer.py delete mode 100644 llmc_py/utils.py diff --git a/llmc_py/prompts.json b/llmc_py/prompts.json deleted file mode 100644 index b089bb602..000000000 --- a/llmc_py/prompts.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "prompts": [ - "Clearly, the meaning of life is", - "Simply put, the theory of relativity states that", - "The repo llm.c on GitHub is", - "Translate English to French:\n\nsea otter => loutre de mer\npeppermint => menthe poivrée\nplush girafe => girafe peluche\ncheese =>" - ] - } \ No newline at end of file diff --git a/llmc_py/rope.py b/llmc_py/rope.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/llmc_py/tokenizer.py b/llmc_py/tokenizer.py deleted file mode 100644 index 2d2a3bd58..000000000 --- a/llmc_py/tokenizer.py +++ /dev/null @@ -1,16 +0,0 @@ -# From: 
https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/tokenizer.py
-
-from typing import (
-    AbstractSet,
-    Callable,
-    Collection,
-    Dict,
-    Iterator,
-    List,
-    Literal,
-    Optional,
-    Sequence,
-    Union,
-    cast,
-)
-
diff --git a/llmc_py/utils.py b/llmc_py/utils.py
deleted file mode 100644
index e69de29bb..000000000

From 89addd3e3e8612dfac38ebaf8ac96e23a6fec903 Mon Sep 17 00:00:00 2001
From: Aleksa Gordic
Date: Thu, 8 Aug 2024 18:20:51 +0200
Subject: [PATCH 32/36] Remove pycache

---
 llmc_py/__pycache__/rope.cpython-310.pyc      | Bin 2167 -> 0 bytes
 llmc_py/__pycache__/tokenizer.cpython-310.pyc | Bin 5372 -> 0 bytes
 llmc_py/__pycache__/utils.cpython-310.pyc     | Bin 2245 -> 0 bytes
 3 files changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 llmc_py/__pycache__/rope.cpython-310.pyc
 delete mode 100644 llmc_py/__pycache__/tokenizer.cpython-310.pyc
 delete mode 100644 llmc_py/__pycache__/utils.cpython-310.pyc

diff --git a/llmc_py/__pycache__/rope.cpython-310.pyc b/llmc_py/__pycache__/rope.cpython-310.pyc
deleted file mode 100644
index 8797638e54d73f9b72c780cf7c209e54840afb2c..0000000000000000000000000000000000000000
GIT binary patch
[base85-encoded binary payloads for the three deleted .pyc files omitted]
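The training loop patched in above appends plain-text records to its logfile: "s:<step> tel:<val_loss>" for validation and "s:<step> trl:<train_loss>" for training. A minimal sketch of reading such a logfile back, e.g. for plotting loss curves; the file name is hypothetical and only the Python standard library is assumed.

def parse_logfile(path):
    # returns two dicts mapping step -> loss, one for training, one for validation
    train, val = {}, {}
    with open(path) as f:
        for line in f:
            step_field, loss_field = line.split()
            step = int(step_field.split(":", 1)[1])
            tag, value = loss_field.split(":", 1)
            (val if tag == "tel" else train)[step] = float(value)
    return train, val

train_losses, val_losses = parse_logfile("llama3_log.txt")  # hypothetical path
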
From f1c91f8ae36e5a6b7ac2f799310a037d19d6ed78 Mon Sep 17 00:00:00 2001
From: Aleksa Gordic
Date: Thu, 8 Aug 2024 18:45:04 +0200
Subject: [PATCH 33/36] Address Andrej's PR comments

---
 train_llama3.py | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/train_llama3.py b/train_llama3.py
index f20b2c343..054c3ae0b 100644
--- a/train_llama3.py
+++ b/train_llama3.py
@@ -13,11 +13,7 @@
 # 3) https://github.com/meta-llama/llama3/blob/11817d47e1ba7a4959b025eb1ca308572e0e3963/llama/generation.py
 
 Example launches to only benchmark the speed of bfloat16 compiled GPU training:
-1 GPU:
-python train_llama3.py --write_tensors=0 --num_iterations=50 --sequence_length=8192 --compile=1 --tensorcores=1 --dtype=bfloat16
-you can also turn on flash-attention by appending --flash=1
-4 
GPU: -torchrun --standalone --nproc_per_node=4 train_llama3.py --write_tensors=0 --num_iterations=50 --sequence_length=8192 --compile=1 --tensorcores=1 --dtype=bfloat16 +TODO: add the actual commands """ import os @@ -134,6 +130,9 @@ def precompute_freqs_cis( # ----------------------------------------------------------------------------- # LLaMA building blocks +# LLaMA reference code explicitly implemented RMSNorm so we copy pasted it +# (https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py) +# we could also use nn.RMSNorm, it has slightly different numeric properties, but equivalent class RMSNorm(torch.nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() @@ -161,14 +160,13 @@ def __init__(self, config): self.c_attn = nn.Linear(config.n_embd, (config.n_head + 2 * config.n_kv_head) * self.hd, bias=False) # key, query, value projections self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) # output projection - self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 # static KV cache - we could alternatively allocate it outside of the model and just pass it in when needed if self.use_kv: self.cache_k = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd)) self.cache_v = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd)) - def forward(self, x, freqs_cis=None, start_pos=None, mask: Optional[torch.Tensor] = None): + def forward(self, x, freqs_cis=None, start_pos=None, mask=None): B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim qkv = self.c_attn(x) @@ -216,7 +214,6 @@ def __init__(self, config): self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=False) self.c_fc2 = nn.Linear(config.n_embd, hidden_dim, bias=False) self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False) - self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 def forward(self, x): # SwiGLU self.c_proj(F.silu(self.c_fc2(x)) * self.c_fc(x)) <-- 3. 
difference compared to GPT-2 @@ -236,7 +233,7 @@ def __init__(self, config): self.ln_2 = RMSNorm(config.n_embd, config.norm_eps) self.mlp = MLP(config) - def forward(self, x, freqs_cis=None, start_pos=None, mask: Optional[torch.Tensor] = None): + def forward(self, x, freqs_cis=None, start_pos=None, mask=None): x = x + self.attn(self.ln_1(x), freqs_cis, start_pos, mask) x = x + self.mlp(self.ln_2(x)) return x @@ -542,9 +539,7 @@ def generate( next_token = next_token.reshape(-1) # only replace token if prompt has already been generated - next_token = torch.where( - input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token - ) + next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token) tokens[:, cur_pos] = next_token if logprobs: token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy( From 8b672ffcb9cf8d4b8df9f0f65b3a0eab6d5e5a18 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Thu, 8 Aug 2024 18:55:17 +0200 Subject: [PATCH 34/36] Add data loader not implemented exception --- train_llama3.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/train_llama3.py b/train_llama3.py index 054c3ae0b..5cb7b9716 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -780,6 +780,7 @@ def _split_whitespaces_or_nonwhitespaces( # Our own simple Distributed Data Loader def _peek_data_shard(filename): + raise NotImplementedError("_peek_data_shard not yet implemented for llama 3") # only reads the header, returns header data with open(filename, "rb") as f: # first read the header, which is 256 int32 integers (4 bytes each) @@ -795,6 +796,7 @@ def _peek_data_shard(filename): return ntok # for now just return the number of tokens def _load_data_shard(filename): + raise NotImplementedError("_load_data_shard not yet implemented for llama 3") with open(filename, "rb") as f: # first read the header, which is 256 int32 integers (4 bytes each) header = np.frombuffer(f.read(256*4), dtype=np.int32) From c5c87fc4e7523aaaea61a8167ddb028f829bc533 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Thu, 8 Aug 2024 20:54:17 +0200 Subject: [PATCH 35/36] Add comments, fix stop tokens --- train_llama3.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/train_llama3.py b/train_llama3.py index 5cb7b9716..4013db7da 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -321,6 +321,8 @@ def forward(self, idx, targets=None, return_logits=True, start_pos=0): @staticmethod def adapt_llama_state_dict_keys(checkpoint, config: LlamaConfig): + # Modify key names from Meta's LLaMA to our LLaMA + # our key names are derived from GPT-2's key names checkpoint['transformer.wte.weight'] = checkpoint.pop('tok_embeddings.weight') for i in range(config.n_layer): @@ -355,6 +357,8 @@ def adapt_llama_state_dict_keys(checkpoint, config: LlamaConfig): @staticmethod def adapt_llama_state_dict_keys_hf(checkpoint, config: LlamaConfig): + # Modify key names from HuggingFace's LLaMA to our LLaMA + # our key names are derived from GPT-2's key names checkpoint['transformer.wte.weight'] = checkpoint.pop('model.embed_tokens.weight') # We need to unpermute K and V because HF script permuted the original Meta-LLaMA weights @@ -432,9 +436,6 @@ def from_pretrained_llama3_meta(cls, ckpt_dir, tokenizer_path): torch.set_default_tensor_type(torch.tensor([], dtype=original_default_type, device="cpu").type()) # restore default type tokenizer = Tokenizer(model_path=tokenizer_path) - # add <|end_of_text|> as the stop token for base model - this is an omission in the reference code - # the reference code 
only adds instruct model stop tokens... - tokenizer.stop_tokens = tokenizer.stop_tokens + [128001] model.tokenizer = tokenizer return model @@ -676,9 +677,10 @@ def __init__(self, model_path: str): self.eom_id: int = self.special_tokens["<|eom_id|>"] self.python_tag_id = self.special_tokens["<|python_tag|>"] self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"] + # hardcoded stop tokens for the base model self.stop_tokens = [ - self.special_tokens["<|eom_id|>"], - self.special_tokens["<|eot_id|>"], + self.special_tokens["<|begin_of_text|>"], + self.special_tokens["<|end_of_text|>"], ] def encode( From d773c88e49c68b894a4291d17815e17160afa250 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Thu, 8 Aug 2024 21:06:01 +0200 Subject: [PATCH 36/36] Remove unnecessary comment --- train_llama3.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/train_llama3.py b/train_llama3.py index 4013db7da..31596c306 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -740,15 +740,6 @@ def encode( return t def decode(self, t: Sequence[int]) -> str: - """ - Decodes a list of token IDs into a string. - - Args: - t (List[int]): The list of token IDs to be decoded. - - Returns: - str: The decoded string. - """ # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. return self.model.decode(cast(List[int], t))
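
The generate() calls exercised above sample with temperature=0.6 and top_p=0.9, i.e. nucleus sampling: logits are softened by the temperature, tokens are dropped once the probability mass of strictly more probable tokens already exceeds top_p, and the next token is drawn from the renormalized remainder. Below is a generic sketch of that technique; it is illustrative only, not necessarily the exact helper used in train_llama3.py, and the vocabulary size is just a placeholder.

import torch
import torch.nn.functional as F

def sample_top_p(logits, temperature=0.6, top_p=0.9):
    # temperature-scale the logits, then keep only the "nucleus" of most likely tokens
    probs = F.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cum_probs = torch.cumsum(sorted_probs, dim=-1)
    # drop a token if the mass of strictly more probable tokens already exceeds top_p
    sorted_probs[cum_probs - sorted_probs > top_p] = 0.0
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)   # renormalize the nucleus
    next_sorted = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_idx, -1, next_sorted)         # map back to vocabulary ids

dummy_logits = torch.randn(1, 128256)   # (batch, vocab); placeholder values and size
print(sample_top_p(dummy_logits))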