-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
219 lines (174 loc) · 6.55 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
import os
import math
import time
import json
import mlx
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten, tree_map
from model import GPTConfig, GPT
from optimizer import AdamW
from tboard_utils import init_tensorboard, get_tensorboard
# ---------------------------------------------------------------------------
# Default training configuration. Plain int/float/bool/str globals defined
# here are collected into `config_keys` before configurator.py runs, and the
# (possibly overridden) values are then gathered into `config`.
# ---------------------------------------------------------------------------
# model architecture
n_layer = 12
n_head = 12
n_embd = 768
dropout = 0.0
bias = False
d_type = 'float32'  # name of an mlx.core dtype; main() casts all parameters to it
# optimizer / learning-rate schedule
learning_rate = 6.0e-4  # peak learning rate, reached at the end of warmup
min_lr = 6.0e-5  # floor of the cosine decay
num_iters = 600000
warmup_pct = 0.1  # NOTE(review): not referenced in the visible code; the schedule uses warmup_iters
warmup_iters = 2000
lr_decay_iters = 600000
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
# data
meta_vocab_size = None  # None -> main() falls back to a vocab size of 50304
dataset = 'shakespeare'
batch_size = 1
gradient_accumulation_steps = 512
context_size = 1024
# logging / checkpointing
save_interval = 1
eval_interval = 10  # NOTE(review): not referenced in the visible code
log_interval = 10
eval_only = False
out_dir = 'gpt2_shakespeare_pretrain_mlx'

# Snapshot the overridable keys *before* running the configurator, so only the
# defaults declared above are captured.
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
# SECURITY: configurator.py is executed as Python code with this module's
# globals; it must be trusted. The `with` block closes the file handle that the
# previous `open(...).read()` one-liner leaked.
with open('configurator.py') as _configurator_file:
    exec(_configurator_file.read())
del _configurator_file
config = {k: globals()[k] for k in config_keys}
data_dir = os.path.join('data', dataset)
def load_binary_data(file_path, dtype):
    """Read an entire binary file and reinterpret its bytes as an MLX array.

    Args:
        file_path: path of the file to read.
        dtype: a ``struct``-style format character accepted by
            ``memoryview.cast`` (e.g. ``'H'`` for native unsigned short).

    Returns:
        An ``mx.array`` built from the file's bytes viewed as ``dtype`` items.
    """
    with open(file_path, 'rb') as handle:
        payload = handle.read()
    items = memoryview(payload).cast(dtype)
    return mx.array(items)
# Token streams for each split, read as native unsigned shorts ('H').
train_data = load_binary_data(os.path.join(data_dir, 'train.bin'), 'H')
val_data = load_binary_data(os.path.join(data_dir, 'val.bin'), 'H')

# Checkpoint files are named after the final component of out_dir. Using the
# whole out_dir string as the file stem broke nested or slash-terminated paths
# (e.g. 'runs/exp1' produced 'runs/exp1/runs/exp1.npz', whose parent directory
# is never created), and an absolute out_dir made os.path.join discard the
# directory entirely.
_run_name = os.path.basename(os.path.normpath(out_dir))
save_model_path = os.path.join(out_dir, _run_name + '.npz')
save_model_config_path = os.path.join(out_dir, _run_name + '.json')
os.makedirs(out_dir, exist_ok=True)
tboard_dir = os.path.join(out_dir, "tboard_log")
init_tensorboard(tboard_dir)
def get_batch(split):
    """Sample ``batch_size`` random (input, target) token windows from a split.

    ``split == 'train'`` samples from ``train_data``; any other value samples
    from ``val_data``. Each target row is its input row shifted one token
    ahead. Both results are stacked and cast to ``mx.int64``.
    """
    source = train_data if split == 'train' else val_data
    # Upper bound passed to randint; starts are drawn below it so the shifted
    # target window still fits inside `source`.
    start_limit = len(source) - context_size
    starts = mx.random.randint(0, start_limit, shape=(batch_size,)).tolist()
    input_rows, target_rows = [], []
    for begin in starts:
        input_rows.append(mx.array(source[begin:begin + context_size]))
        target_rows.append(mx.array(source[begin + 1:begin + 1 + context_size]))
    inputs = mx.stack(input_rows).astype(mx.int64)
    targets = mx.stack(target_rows).astype(mx.int64)
    return inputs, targets
def print_loss(optimizer, iteration_count, average_loss, tic):
    """Print a one-line progress report and return the new timing reference.

    Args:
        optimizer: object whose ``learning_rate.item()`` yields the current LR.
        iteration_count: iteration number shown in the report.
        average_loss: loss value shown with three decimals.
        tic: ``time.perf_counter()`` reading from the previous report.

    Returns:
        The ``time.perf_counter()`` reading taken by this call, to be passed
        back in as ``tic`` next time.
    """
    now = time.perf_counter()
    elapsed = now - tic
    current_lr = optimizer.learning_rate.item()
    report = (
        f"iter {iteration_count}: train loss {average_loss:.3f}, "
        f"it/sec {1.0 / elapsed:.3f}, "
        f"lr {current_lr:.9f}"
    )
    print(report)
    return now
def update_learning_rate(it, *, peak_lr=None, floor_lr=None, warmup_steps=None, decay_steps=None):
    """Return the learning rate for iteration ``it``.

    The schedule is a linear warmup from 0 to ``peak_lr`` over
    ``warmup_steps`` iterations. It is followed by a cosine decay from
    ``peak_lr`` down to ``floor_lr`` that ends at ``decay_steps``. From then
    on the rate is constant at ``floor_lr``.

    The keyword-only arguments default (when ``None``) to the module-level
    ``learning_rate``, ``min_lr``, ``warmup_iters`` and ``lr_decay_iters``
    settings, so ``update_learning_rate(it)`` behaves as before.
    """
    peak_lr = learning_rate if peak_lr is None else peak_lr
    floor_lr = min_lr if floor_lr is None else floor_lr
    warmup_steps = warmup_iters if warmup_steps is None else warmup_steps
    decay_steps = lr_decay_iters if decay_steps is None else decay_steps

    if it < warmup_steps:
        return peak_lr * it / warmup_steps
    # `>=` rather than `>`: at it == decay_steps the cosine factor is already
    # exactly 0, so the value is unchanged. It also avoids a ZeroDivisionError
    # below when decay_steps == warmup_steps.
    if it >= decay_steps:
        return floor_lr
    # Here warmup_steps <= it < decay_steps, so the ratio lies in [0, 1).
    decay_ratio = (it - warmup_steps) / (decay_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # falls from 1 toward 0
    return floor_lr + coeff * (peak_lr - floor_lr)
def log_tboard_dict(log_dict, itr, pre, post=''):
    """Record every entry of ``log_dict`` as a scalar at step ``itr``.

    Uses the writer returned by ``get_tensorboard()``. Each tag has the form
    ``'{pre}/{key}{post}'``.
    """
    tb_writer = get_tensorboard()
    tagged = ((f'{pre}/{name}{post}', value) for name, value in log_dict.items())
    for tag, value in tagged:
        tb_writer.add_scalar(tag, value, itr)
def main():
    """Build the GPT model and run the training loop.

    Hyper-parameters come from the module-level configuration. The loop
    trains with AdamW and gradient accumulation, and prints progress every
    iteration. Every ``log_interval`` iterations it logs loss and learning
    rate via ``log_tboard_dict``. Every ``save_interval`` iterations it writes
    the weights to ``save_model_path`` and the model config to
    ``save_model_config_path``.

    Iterations 0..num_iters (inclusive) are run. If ``eval_only`` is set,
    the function returns after model setup without training.

    Raises:
        ValueError: if ``gradient_accumulation_steps`` is smaller than 1.
    """
    if gradient_accumulation_steps < 1:
        raise ValueError(
            f"gradient_accumulation_steps must be >= 1, got {gradient_accumulation_steps}"
        )

    # --- model ------------------------------------------------------------
    model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=context_size,
                      bias=bias, vocab_size=None, dropout=dropout)
    if meta_vocab_size is None:
        print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
    model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
    print(model)
    # Cast every parameter to the configured dtype (d_type names an mx dtype).
    weights = tree_map(lambda p: p.astype(getattr(mx, d_type)), model.parameters())
    model.update(weights)
    mx.eval(model.parameters())
    nparams = sum(
        x.size for k, x in tree_flatten(model.parameters()) if "embedding" not in k
    )
    # "M" means 10**6; dividing by 1024**2 under-reported the count by ~5%.
    print(f"Training a transformer with {nparams / 1e6:.3f} M parameters")

    def loss_fn(model, x, y, reduce=True):
        """Cross-entropy of the model's logits for ``x`` against targets ``y``.

        Returns the mean over all tokens when ``reduce`` is true. Otherwise it
        returns the mean over the last axis of ``y`` (one value per row).
        """
        logits = model(x)
        losses = nn.losses.cross_entropy(
            logits.reshape(-1, logits.shape[-1]), y.reshape(-1)
        )
        if reduce:
            return mx.mean(losses)
        # `losses` is flat here, so the previous axis=(-1, -2) reduction was
        # invalid; restore y's shape before reducing per row.
        return mx.mean(losses.reshape(y.shape), axis=-1)

    # --- optimizer --------------------------------------------------------
    optimizer = AdamW(learning_rate=learning_rate,
                      betas=[beta1, beta2],
                      weight_decay=weight_decay)
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

    def step(inputs, targets, gradient_accumulation_steps):
        """Apply one optimizer update averaged over several micro-batches.

        ``inputs``/``targets`` form the first micro-batch. Every further
        micro-batch is a fresh ``get_batch('train')`` draw. Previously all
        micro-steps reused the same batch, which repeated identical work
        without enlarging the effective batch.

        Returns the mean micro-batch loss as a scalar ``mx.array``.
        """
        scale = 1.0 / gradient_accumulation_steps
        accumulated_grads = tree_map(mx.zeros_like, model.parameters())
        accumulated_loss = 0.0
        for micro_step in range(gradient_accumulation_steps):
            if micro_step > 0:
                inputs, targets = get_batch('train')
            loss, grads = loss_and_grad_fn(model, inputs, targets)
            accumulated_grads = tree_map(
                lambda acc, new: acc + new * scale, accumulated_grads, grads
            )
            # MLX is lazy: evaluate the running sum (and this loss) every
            # micro-step so the graph does not span the whole window.
            mx.eval(accumulated_grads, loss)
            accumulated_loss += loss.item()
        optimizer.update(model, accumulated_grads)
        return mx.array(accumulated_loss / gradient_accumulation_steps)

    # --- training loop ----------------------------------------------------
    # First micro-batch of the first step; step() draws the rest itself.
    X, Y = get_batch('train')
    # Evaluated after each step to materialise the lazy parameter/optimizer
    # updates. NOTE(review): assumes AdamW mutates this same state object
    # rather than replacing it; verify.
    state = [model.state, optimizer.state]
    tic = time.perf_counter()
    if eval_only:
        # NOTE(review): no evaluation routine is visible here, so eval_only
        # just exits before training.
        return
    iter_num = 0
    while True:
        new_lr = update_learning_rate(iter_num)
        optimizer.set_learning_rate(new_lr)
        loss = step(X, Y, gradient_accumulation_steps)
        # First micro-batch for the next step.
        X, Y = get_batch('train')
        tic = print_loss(optimizer, iter_num, loss.item(), tic)
        mx.eval(state)
        if iter_num % log_interval == 0:
            log_train_dict = {
                'loss': loss.item(),
                'lr': new_lr
            }
            log_tboard_dict(log_train_dict, iter_num, 'train')
        if iter_num % save_interval == 0:
            # save model weights
            flat_params = tree_flatten(model.parameters())
            mx.savez(save_model_path, **dict(flat_params))
            # save model config
            with open(save_model_config_path, "w") as f:
                json.dump(model.config.__dict__, f)
        iter_num += 1
        if iter_num > num_iters:
            break


if __name__ == "__main__":
    main()