-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathv2.py
149 lines (114 loc) · 4.88 KB
/
v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""Bigram model implementation."""
from enum import Enum
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from data import fetch_input_data
class DataSplit(Enum):
    """Which partition of the dataset a batch should be drawn from."""

    TRAIN = "train"
    VAL = "val"
# Hyperparameters
batch_size = 32 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?
max_epochs = 3000  # total number of optimizer steps
learning_rate = 1e-2  # AdamW step size
eval_interval = 300  # evaluate train/val loss every this many steps
eval_iters = 200  # number of batches averaged per loss estimate
m_embd = 32  # embedding dimension shared by token and position tables
# ---------------
mx.random.seed(1337)  # fix the RNG so runs are reproducible
# Data
INPUT_DATA_URL = ("https://raw.githubusercontent.com/karpathy/char-rnn/"
                  "master/data/tinyshakespeare/input.txt")
input_text = fetch_input_data(INPUT_DATA_URL, "input.txt")
# ----
# here are all the unique characters that occur in this text
chars = sorted(set(input_text))
vocab_size = len(chars)
# create a mapping from characters to integers (and back)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s: str) -> list[int]:
    """Encoder: take a string, output a list of integers."""
    return [stoi[c] for c in s]


def decode(l: list[int]) -> str:
    """Decoder: take a list of integers, output a string."""
    return ''.join(itos[i] for i in l)


# train and test splits
data = mx.array(encode(input_text), dtype=mx.int64)
n = int(0.9 * len(data))  # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]
def get_batch(split: DataSplit) -> tuple[mx.array, mx.array]:
    """Generate a small batch of data of inputs x and targets y."""
    source = train_data if split == DataSplit.TRAIN else val_data
    # random starting offsets for `batch_size` context windows
    starts = mx.random.randint(0, len(source) - block_size, [batch_size])
    offsets = [s.item() for s in starts]
    # inputs: `batch_size` blocks stacked into one (B, T) array
    x = mx.stack([source[o:o + block_size] for o in offsets])
    # targets: the same windows shifted right by one token
    y = mx.stack([source[o + 1:o + block_size + 1] for o in offsets])
    return x, y
class Head(nn.Module):
    """One self-attention head (work in progress: the forward pass is a stub)."""

    def __init__(self, head_size: int):
        super().__init__()
        # key/query/value projections from the m_embd embedding space, no biases
        self.key = nn.Linear(m_embd, head_size, bias=False)
        self.query = nn.Linear(m_embd, head_size, bias=False)
        self.value = nn.Linear(m_embd, head_size, bias=False)
        # Lower-triangular causal mask, stored as a plain (non-parameter) array.
        # This should be equivalent to `register_buffer` in PyTorch.
        # Thanks https://github.com/awni for the tip!
        self._tril = mx.tril(mx.ones((block_size, block_size)))

    def __call__(self, x: mx.array) -> mx.array:
        # TODO: Continue here
        # NOTE(review): attention is not implemented yet; identity pass-through.
        return x
class BigramLanguageModel(nn.Module):
    """Super-simple Bigram model with token and position embeddings."""

    def __init__(self):
        super().__init__()
        # maps each token id to a learned m_embd-dimensional vector
        self.token_embedding_table = nn.Embedding(vocab_size, m_embd)
        # maps each position 0..block_size-1 to a learned vector
        self.position_embedding_table = nn.Embedding(block_size, m_embd)
        # projects the combined embedding back to vocabulary logits
        self.lm_head = nn.Linear(m_embd, vocab_size)

    def __call__(self, idx: mx.array) -> mx.array:
        """Return next-token logits for a (B, T) batch of token ids."""
        B, T = idx.shape
        token_embeddings = self.token_embedding_table(idx)  # (B, T, m_embd)
        position_embeddings = self.position_embedding_table(
            mx.arange(T, dtype=mx.int64))  # (T, m_embd)
        x = token_embeddings + position_embeddings  # broadcasts over batch
        return self.lm_head(x)  # (B, T, vocab_size)

    def generate(self, idx: mx.array, max_new_tokens: int) -> mx.array:
        """Autoregressively sample max_new_tokens tokens.

        idx is a (B, T) array of indices in the current context. The returned
        array includes the input context, so it has
        idx.shape[1] + max_new_tokens columns.
        """
        for _ in range(max_new_tokens):
            # BUGFIX: crop the context to the last block_size tokens —
            # position_embedding_table only has block_size entries, so feeding
            # a longer sequence would index past the table once the generated
            # sequence grows beyond block_size.
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :]
            # sample from the distribution
            idx_next = mx.random.categorical(logits, num_samples=1, axis=-1)
            # append sampled index to the running sequence
            idx = mx.concatenate([idx, idx_next], axis=1)
        return idx
def loss_fn(model: nn.Module, x: mx.array, y: mx.array) -> mx.array:
    """Mean cross-entropy between the model's logits on x and the targets y."""
    logits = model(x)
    per_token_loss = nn.losses.cross_entropy(logits, y)
    return mx.mean(per_token_loss)
def estimate_loss():
    """Average loss_fn over eval_iters fresh batches for each data split.

    Reads the module-level `model`; returns a dict keyed by DataSplit.
    """
    results = {}
    for split in (DataSplit.TRAIN, DataSplit.VAL):
        batch_losses = mx.zeros(eval_iters)
        for step in range(eval_iters):
            inputs, targets = get_batch(split)
            batch_losses[step] = loss_fn(model, inputs, targets).item()
        results[split] = batch_losses.mean()
    return results
# Build the model and training machinery: value_and_grad returns the loss and
# the gradients in a single call; AdamW applies the parameter updates.
model = BigramLanguageModel()
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
optimizer = optim.AdamW(learning_rate=learning_rate)

for epoch in range(max_epochs):
    # every once in a while evaluate the loss on train and val sets
    if epoch % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {epoch}: train loss {losses[DataSplit.TRAIN]}, val loss {losses[DataSplit.VAL]}")
    # one optimization step on a fresh training batch
    xb, yb = get_batch(DataSplit.TRAIN)
    loss, grads = loss_and_grad_fn(model, xb, yb)
    optimizer.update(model, grads)
    # force evaluation of MLX's lazy computation graph each step
    mx.eval(model.parameters(), optimizer.state)

# generate from the model, starting from a single token with id 0
context = mx.zeros((1, 1), dtype=mx.int64)
# this is actually going to return 501 tokens since the input context counts
print(decode(model.generate(context, max_new_tokens=500)[0].tolist()))