Use pyproject.toml, pdm and ruff for improved reproducibility and cleaner code #40

Open · wants to merge 1 commit into base: master

3 changes: 2 additions & 1 deletion .gitignore
@@ -1,6 +1,7 @@
__pycache__/
.DS_Store
.venv/
models/**/*
*.pytest_cache
*.model
*.vocab
*.vocab
1 change: 1 addition & 0 deletions .pdm-python
@@ -0,0 +1 @@
.venv/bin/python
2 changes: 1 addition & 1 deletion README.md
@@ -32,7 +32,7 @@ tokenizer.save("toy")
# writes two files: toy.model (for loading) and toy.vocab (for viewing)
```

According to Wikipedia, running bpe on the input string: "aaabdaaabac" for 3 merges results in the string: "XdXac" where X=ZY, Y=ab, and Z=aa. The tricky thing to note is that minbpe always allocates the 256 individual bytes as tokens, and then merges bytes as needed from there. So for us a=97, b=98, c=99, d=100 (their [ASCII](https://www.asciitable.com) values). Then when (a,a) is merged to Z, Z will become 256. Likewise Y will become 257 and X 258. So we start with the 256 bytes, and do 3 merges to get to the result above, with the expected output of [258, 100, 258, 97, 99].
According to Wikipedia, running bpe on the input string: "aaabdaaabac" for 3 merges results in the string: "XdXac" where X=ZY, Y=ab, and Z=aa. The tricky thing to note is that minbpe always allocates the 256 individual bytes as tokens, and then merges bytes as needed from there. So for us a=97, b=98, c=99, d=100 (their [ASCII](https://www.asciitable.com) values). Then when (a,a) is merged to Z, Z will become 256. Likewise Y will become 257 and X 258. So we start with the 256 bytes, and do 3 merges to get to the result above, with the expected output of [258, 100, 258, 97, 99].
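
A minimal sketch of that walkthrough using the repo's `BasicTokenizer` (the vocab size of 256 + 3 simply requests those three merges on top of the raw bytes):

```python
from minbpe import BasicTokenizer

text = "aaabdaaabac"
tokenizer = BasicTokenizer()
tokenizer.train(text, vocab_size=256 + 3)  # 256 byte tokens plus 3 merges
ids = tokenizer.encode(text)
print(ids)                    # expected: [258, 100, 258, 97, 99]
print(tokenizer.decode(ids))  # "aaabdaaabac"
```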

## inference: GPT-4 comparison

19 changes: 9 additions & 10 deletions exercise.md
@@ -1,8 +1,8 @@
# exercise
# Exercise

Build your own GPT-4 Tokenizer!

### Step 1
## Step 1

Write the `BasicTokenizer` class, with the following three core functions:

@@ -12,20 +12,19 @@ Write the `BasicTokenizer` class, with the following three core functions:

Train your tokenizer on whatever text you like and visualize the merged tokens. Do they look reasonable? One default test you may wish to use is the text file `tests/taylorswift.txt`.
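
For instance, a quick sanity check might look like this (a sketch; `vocab_size=512` is an arbitrary choice, and the encode/decode round trip should be lossless):

```python
from minbpe import BasicTokenizer  # or your own implementation of the same API

tokenizer = BasicTokenizer()
text = open("tests/taylorswift.txt", "r", encoding="utf-8").read()
tokenizer.train(text, vocab_size=512, verbose=True)  # prints each merge as it is made
assert tokenizer.decode(tokenizer.encode(text)) == text  # must round-trip exactly
```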

### Step 2
## Step 2

Convert your `BasicTokenizer` into a `RegexTokenizer`, which takes a regex pattern and splits the text exactly as GPT-4 would. Process the parts separately as before, then concatenate the results. Retrain your tokenizer and compare the results before and after. You should now see no tokens that cross category boundaries (numbers, letters, punctuation, more than one whitespace). Use the GPT-4 pattern:

```
```python
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
```
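
To get a feel for what this pattern does before wiring it into your tokenizer, you can run it directly (a sketch; note the third-party `regex` module, since the standard `re` module does not support the `\p{L}`/`\p{N}` Unicode property classes used here):

```python
import regex as re  # the `regex` package, not the standard library `re`

GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

chunks = re.findall(GPT4_SPLIT_PATTERN, "Hello world123 can't  stop!!")
print(chunks)  # each chunk is then byte-encoded and BPE-merged independently
```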


### Step 3
## Step 3

You're now ready to load the merges from the GPT-4 tokenizer and show that your tokenizer produces the identical results for both `encode` and `decode`, matching [tiktoken](https://github.com/openai/tiktoken).

```
```python
# match this
import tiktoken
enc = tiktoken.get_encoding("cl100k_base") # this is the GPT-4 tokenizer
@@ -38,18 +37,18 @@ Unfortunately, you will run into two issues:
1. It is not trivial to recover the raw merges from the GPT-4 tokenizer. You can easily recover what we call `vocab` here, and what they call and store under `enc._mergeable_ranks`. Feel free to copy paste the `recover_merges` function in `minbpe/gpt4.py`, which takes these ranks and returns the raw merges. If you wish to know how this function works, read [this](https://github.com/openai/tiktoken/issues/60) and [this](https://github.com/karpathy/minbpe/issues/11#issuecomment-1950805306). Basically, under some conditions it is enough to only store the parent nodes (and their rank) and get rid of the precise details of which children merged up to any parent.
2. The GPT-4 tokenizer, for some reason, permutes its raw bytes. It stores this permutation in the first 256 elements of the mergeable ranks, so you can recover the byte shuffle relatively simply as `byte_shuffle = {i: enc._mergeable_ranks[bytes([i])] for i in range(256)}`. In both your encode and decode, you'll have to shuffle bytes around accordingly. If you're stuck, reference the `minbpe/gpt4.py` file for hints, or see the sketch below.
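
A sketch of the byte-shuffle bookkeeping from issue 2, mirroring what `minbpe/gpt4.py` does (variable names are illustrative):

```python
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")

# the first 256 mergeable ranks are the (permuted) single bytes
byte_shuffle = {i: enc._mergeable_ranks[bytes([i])] for i in range(256)}
inverse_byte_shuffle = {v: k for k, v in byte_shuffle.items()}

raw = "hello world".encode("utf-8")
shuffled = bytes(byte_shuffle[b] for b in raw)                # apply before your merges in encode()
recovered = bytes(inverse_byte_shuffle[b] for b in shuffled)  # undo after un-merging in decode()
assert recovered == raw
```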

### Step 4
## Step 4

(Optional, irritating, not obviously useful) Add the ability to handle special tokens. You'll then be able to match the output of tiktoken even when special tokens are present, e.g.:

```
```python
import tiktoken
enc = tiktoken.get_encoding("cl100k_base") # this is the GPT-4 tokenizer
ids = enc.encode("<|endoftext|>hello world", allowed_special="all")
```

Without `allowed_special`, tiktoken will raise an error.
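
One way to structure special-token handling in your own `encode`, along the lines of `minbpe/regex.py` (a sketch; the helper names are illustrative):

```python
import regex as re

def encode_with_specials(text, special_tokens, encode_ordinary):
    # special_tokens: dict of str -> int, e.g. {"<|endoftext|>": 100257}
    # encode_ordinary: your BPE encoder for text containing no special tokens
    special_pattern = "(" + "|".join(re.escape(k) for k in special_tokens) + ")"
    ids = []
    for chunk in re.split(special_pattern, text):  # the capture group keeps the specials
        if chunk in special_tokens:
            ids.append(special_tokens[chunk])  # map a special token straight to its reserved id
        else:
            ids.extend(encode_ordinary(chunk))
    return ids
```

With this in place, encoding `"<|endoftext|>hello world"` would emit the reserved id followed by the ordinary encoding of `"hello world"`, matching the tiktoken behavior above.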

### Step 5
## Step 5

If you've made it this far, you're now a pro at LLM tokenization! Sadly, you're not quite done yet, because many LLMs outside of OpenAI (e.g. Llama, Mistral) use [sentencepiece](https://github.com/google/sentencepiece) instead. The primary difference is that sentencepiece runs BPE directly on Unicode code points instead of on UTF-8 encoded bytes. Feel free to explore sentencepiece on your own (good luck, it's not too pretty), and, as a stretch goal if you can spare the time, re-write your BPE to operate on Unicode code points and match the Llama 2 tokenizer.
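
To make the code-points-versus-bytes distinction concrete (a small illustration, not sentencepiece itself):

```python
s = "héllo"
print(list(s))                  # code points: ['h', 'é', 'l', 'l', 'o']
print(list(s.encode("utf-8")))  # UTF-8 bytes:  [104, 195, 169, 108, 108, 111]
# sentencepiece-style BPE merges over the code-point sequence,
# while minbpe/tiktoken-style BPE merges over the byte sequence
```
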
4 changes: 3 additions & 1 deletion minbpe/__init__.py
@@ -1,4 +1,6 @@
from .base import Tokenizer
from .basic import BasicTokenizer
from .regex import RegexTokenizer
from .gpt4 import GPT4Tokenizer
from .regex import RegexTokenizer

__all__ = ["BasicTokenizer", "RegexTokenizer", "GPT4Tokenizer", "Tokenizer"]
35 changes: 20 additions & 15 deletions minbpe/base.py
@@ -10,14 +10,15 @@
# -----------------------------------------------------------------------------
# a few helper functions useful for both BasicTokenizer and RegexTokenizer


def get_stats(ids, counts=None):
"""
Given a list of integers, return a dictionary of counts of consecutive pairs
Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
Optionally allows to update an existing dictionary of counts
"""
counts = {} if counts is None else counts
for pair in zip(ids, ids[1:]): # iterate consecutive elements
for pair in zip(ids, ids[1:]): # iterate consecutive elements
counts[pair] = counts.get(pair, 0) + 1
return counts

@@ -28,17 +29,18 @@ def merge(ids, pair, idx):
of pair with the new integer token idx
Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
"""
newids = []
new_ids = []
i = 0
while i < len(ids):
# if not at the very last position AND the pair matches, replace it
if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
newids.append(idx)
if ids[i] == pair[0] and i < len(ids) - 1 and ids[i + 1] == pair[1]:
new_ids.append(idx)
i += 2
else:
newids.append(ids[i])
new_ids.append(ids[i])
i += 1
return newids
return new_ids


# first two helper functions...
def replace_control_characters(s: str) -> str:
@@ -49,29 +51,32 @@ def replace_control_characters(s: str) -> str:
chars = []
for ch in s:
if unicodedata.category(ch)[0] != "C":
chars.append(ch) # this character is ok
chars.append(ch) # this character is ok
else:
chars.append(f"\\u{ord(ch):04x}") # escape
chars.append(f"\\u{ord(ch):04x}") # escape
return "".join(chars)


def render_token(t: bytes) -> str:
# pretty print a token, escaping control characters
s = t.decode('utf-8', errors='replace')
s = t.decode("utf-8", errors="replace")
s = replace_control_characters(s)
return s


# -----------------------------------------------------------------------------
# the base Tokenizer class


class Tokenizer:
"""Base class for Tokenizers"""

def __init__(self):
# default: vocab size of 256 (all bytes), no merges, no patterns
self.merges = {} # (int, int) -> int
self.pattern = "" # str
self.special_tokens = {} # str -> int, e.g. {'<|endoftext|>': 100257}
self.vocab = self._build_vocab() # int -> bytes
self.merges = {} # (int, int) -> int
self.pattern = "" # str
self.special_tokens = {} # str -> int, e.g. {'<|endoftext|>': 100257}
self.vocab = self._build_vocab() # int -> bytes

def train(self, text, vocab_size, verbose=False):
# Tokenizer can train a vocabulary of size vocab_size from text
@@ -103,7 +108,7 @@ def save(self, file_prefix):
"""
# write the model: to be used in load() later
model_file = file_prefix + ".model"
with open(model_file, 'w') as f:
with open(model_file, "w") as f:
# write the version, pattern and merges, that's all that's needed
f.write("minbpe v1\n")
f.write(f"{self.pattern}\n")
@@ -144,7 +149,7 @@ def load(self, model_file):
merges = {}
special_tokens = {}
idx = 256
with open(model_file, 'r', encoding="utf-8") as f:
with open(model_file, "r", encoding="utf-8") as f:
# read the version
version = f.readline().strip()
assert version == "minbpe v1"
23 changes: 12 additions & 11 deletions minbpe/basic.py
@@ -13,7 +13,6 @@


class BasicTokenizer(Tokenizer):

def __init__(self):
super().__init__()

@@ -22,12 +21,12 @@ def train(self, text, vocab_size, verbose=False):
num_merges = vocab_size - 256

# input text preprocessing
text_bytes = text.encode("utf-8") # raw bytes
ids = list(text_bytes) # list of integers in range 0..255
text_bytes = text.encode("utf-8") # raw bytes
ids = list(text_bytes) # list of integers in range 0..255

# iteratively merge the most common pairs to create new tokens
merges = {} # (int, int) -> int
vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes
merges = {} # (int, int) -> int
vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes
for i in range(num_merges):
# count up the number of times every consecutive pair appears
stats = get_stats(ids)
@@ -42,11 +41,13 @@ def train(self, text, vocab_size, verbose=False):
vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
# prints
if verbose:
print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
print(
f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences"
)

# save class variables
self.merges = merges # used in encode()
self.vocab = vocab # used in decode()
self.merges = merges # used in encode()
self.vocab = vocab # used in decode()

def decode(self, ids):
# given ids (list of integers), return Python string
@@ -56,8 +57,8 @@ def decode(self, ids):

def encode(self, text):
# given a string text, return the token ids
text_bytes = text.encode("utf-8") # raw bytes
ids = list(text_bytes) # list of integers in range 0..255
text_bytes = text.encode("utf-8") # raw bytes
ids = list(text_bytes) # list of integers in range 0..255
while len(ids) >= 2:
# find the pair with the lowest merge index
stats = get_stats(ids)
@@ -67,7 +68,7 @@ def encode(self, text):
# just the first pair in the list, arbitrarily
# we can detect this terminating case by a membership check
if pair not in self.merges:
break # nothing else can be merged anymore
break # nothing else can be merged anymore
# otherwise let's merge the best pair (lowest merge index)
idx = self.merges[pair]
ids = merge(ids, pair, idx)
24 changes: 16 additions & 8 deletions minbpe/gpt4.py
@@ -5,6 +5,7 @@
"""

import tiktoken

from .regex import RegexTokenizer


@@ -22,7 +23,11 @@ def bpe(mergeable_ranks, token, max_rank):
if min_rank is None or (max_rank is not None and min_rank >= max_rank):
break
assert min_idx is not None
parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
parts = (
parts[:min_idx]
+ [parts[min_idx] + parts[min_idx + 1]]
+ parts[min_idx + 2 :]
)
return parts


@@ -35,7 +40,7 @@ def recover_merges(mergeable_ranks):
merges = {}
for token, rank in mergeable_ranks.items():
if len(token) == 1:
continue # skip raw bytes
continue # skip raw bytes
pair = tuple(bpe(mergeable_ranks, token, max_rank=rank))
assert len(pair) == 2
# recover the integer ranks of the pair
@@ -45,15 +50,17 @@ def recover_merges(mergeable_ranks):

return merges


GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
GPT4_SPECIAL_TOKENS = {
'<|endoftext|>': 100257,
'<|fim_prefix|>': 100258,
'<|fim_middle|>': 100259,
'<|fim_suffix|>': 100260,
'<|endofprompt|>': 100276
"<|endoftext|>": 100257,
"<|fim_prefix|>": 100258,
"<|fim_middle|>": 100259,
"<|fim_suffix|>": 100260,
"<|endofprompt|>": 100276,
}


class GPT4Tokenizer(RegexTokenizer):
"""Lightweight wrapper on RegexTokenizer that matches GPT-4's tokenizer."""

@@ -71,7 +78,7 @@ def __init__(self):
self.vocab = vocab
# now here is another tricky part.
# for some reason, the tokens corresponding to individual bytes
# are permuted in a different order. This is completely non-sensical
# are permuted in a different order. This is completely nonsensical
# and probably historical, but therefore we have to deal with it here.
self.byte_shuffle = {i: mergeable_ranks[bytes([i])] for i in range(256)}
self.inverse_byte_shuffle = {v: k for k, v in self.byte_shuffle.items()}
@@ -112,6 +119,7 @@ def save_vocab(self, vocab_file):
# simple run as:
# python -c "from minbpe import GPT4Tokenizer; GPT4Tokenizer().save_vocab('gpt4.vocab')"
from .base import render_token

# build vocab being mindful of the byte shuffle
vocab = {idx: bytes([self.inverse_byte_shuffle[idx]]) for idx in range(256)}
for (p0, p1), idx in self.merges.items():
27 changes: 16 additions & 11 deletions minbpe/regex.py
@@ -10,17 +10,18 @@
"""

import regex as re
from .base import Tokenizer, get_stats, merge

from .base import Tokenizer, get_stats, merge

# the main GPT text split patterns, see
# https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py
GPT2_SPLIT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
GPT2_SPLIT_PATTERN = (
r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""


class RegexTokenizer(Tokenizer):

def __init__(self, pattern=None):
"""
- pattern: optional string to override the default (GPT-4 split pattern)
@@ -44,8 +45,8 @@ def train(self, text, vocab_size, verbose=False):
ids = [list(ch.encode("utf-8")) for ch in text_chunks]

# iteratively merge the most common pairs to create new tokens
merges = {} # (int, int) -> int
vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
merges = {} # (int, int) -> int
vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
for i in range(num_merges):
# count the number of times every consecutive pair appears
stats = {}
@@ -63,11 +64,13 @@
vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
# prints
if verbose:
print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
print(
f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences"
)

# save class variables
self.merges = merges # used in encode()
self.vocab = vocab # used in decode()
self.merges = merges # used in encode()
self.vocab = vocab # used in decode()

def register_special_tokens(self, special_tokens):
# special_tokens is a dictionary of str -> int
@@ -102,7 +105,7 @@ def _encode_chunk(self, text_bytes):
# just the first pair in the list, arbitrarily
# we can detect this terminating case by a membership check
if pair not in self.merges:
break # nothing else can be merged anymore
break # nothing else can be merged anymore
# otherwise let's merge the best pair (lowest merge index)
idx = self.merges[pair]
ids = merge(ids, pair, idx)
@@ -115,7 +118,7 @@ def encode_ordinary(self, text):
# all chunks of text are encoded separately, then results are joined
ids = []
for chunk in text_chunks:
chunk_bytes = chunk.encode("utf-8") # raw bytes
chunk_bytes = chunk.encode("utf-8") # raw bytes
chunk_ids = self._encode_chunk(chunk_bytes)
ids.extend(chunk_ids)
return ids
@@ -138,7 +141,9 @@ def encode(self, text, allowed_special="none_raise"):
special = {}
assert all(token not in text for token in self.special_tokens)
elif isinstance(allowed_special, set):
special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
special = {
k: v for k, v in self.special_tokens.items() if k in allowed_special
}
else:
raise ValueError(f"allowed_special={allowed_special} not understood")
if not special: