Skip to content

Commit

Permalink
Adding nanoGPT-JAX blog
Browse files Browse the repository at this point in the history
---------

Co-authored-by: phildangamd <[email protected]>
Co-authored-by: Eliot Li <[email protected]>
  • Loading branch information
3 people committed Jul 2, 2024
1 parent a7253d0 commit d660f0d
Show file tree
Hide file tree
Showing 23 changed files with 1,845 additions and 1 deletion.
2 changes: 1 addition & 1 deletion blogs/artificial-intelligence/mamba/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ myst:

# Mamba on AMD GPUs with ROCm

<span style="font-size:0.7em;">28, Jun 2024 by Sean Song, Jassani Adeem and Moskvichev Arseny</span>
<span style="font-size:0.7em;">28, Jun 2024 by {hoverxref}`Sean Song<seansong>`, {hoverxref}`Jassani Adeem<jassadee>`, {hoverxref}`Moskvichev Arseny<moskarse>`. </span>

Recently, [Mamba](https://arxiv.org/abs/2312.00752) introduced a novel architecture that not only surpasses the Transformers in modeling effectiveness but also achieves linear scaling to the input sequence length.

Expand Down
21 changes: 21 additions & 0 deletions blogs/artificial-intelligence/nanoGPT-JAX/LICENSE.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Advanced Micro Devices, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
878 changes: 878 additions & 0 deletions blogs/artificial-intelligence/nanoGPT-JAX/README.md

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions blogs/artificial-intelligence/nanoGPT-JAX/config/eval_gpt2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# evaluate the base gpt2
# n_layer=12, n_head=12, n_embd=768
# 124M parameters
batch_size = 8
eval_iters = 500 # use more iterations to get good estimate
eval_only = True
wandb_log = False
init_from = 'gpt2'
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# evaluate the base gpt2
# n_layer=36, n_head=20, n_embd=1280
# 774M parameters
batch_size = 8
eval_iters = 500 # use more iterations to get good estimate
eval_only = True
wandb_log = False
init_from = 'gpt2-large'
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# evaluate the base gpt2
# n_layer=24, n_head=16, n_embd=1024
# 350M parameters
batch_size = 8
eval_iters = 500 # use more iterations to get good estimate
eval_only = True
wandb_log = False
init_from = 'gpt2-medium'
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# evaluate the base gpt2
# n_layer=48, n_head=25, n_embd=1600
# 1558M parameters
batch_size = 8
eval_iters = 500 # use more iterations to get good estimate
eval_only = True
wandb_log = False
init_from = 'gpt2-xl'
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import time

out_dir = 'out-shakespeare'
eval_interval = 50
eval_iters = 40
wandb_log = False # feel free to turn on
wandb_project = 'shakespeare'
wandb_run_name = 'ft-' + str(time.time())

dataset = 'shakespeare'
init_from = 'gpt2-medium' # this is the largest GPT-2 model

# only save checkpoints if the validation loss improves
always_save_checkpoint = False

# the number of examples per iter:
# 1 batch_size * 32 grad_accum * 1024 tokens = 32,768 tokens/iter
# shakespeare has 301,966 tokens, so 1 epoch ~= 9.2 iters
batch_size = 2
gradient_accumulation_steps = 1
max_iters = 5000
dropout = 0.1

# finetune at constant LR
learning_rate = 3e-5
decay_lr = True
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# -----------------------------------------------------------------------------
init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
out_dir = 'out-shakespeare' # ignored if init_from is not 'resume'
start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
num_samples = 5 # number of samples to draw
max_new_tokens = 50 # number of tokens generated in each sample
temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 1337
25 changes: 25 additions & 0 deletions blogs/artificial-intelligence/nanoGPT-JAX/config/train_gpt2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# config for training GPT-2 (124M) down to very nice loss of ~2.85 on 1 node of 8X A100 40GB
# launch as the following (e.g. in a screen session) and wait ~5 days:
# $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py

wandb_log = True
wandb_project = 'owt'
wandb_run_name='gpt2-124M'

# these make the total batch size be ~0.5M
# 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520
batch_size = 12
block_size = 1024
gradient_accumulation_steps = 5 * 8

# this makes total number of tokens be 300B
max_iters = 600000
lr_decay_iters = 600000

# eval stuff
eval_interval = 1000
eval_iters = 200
log_interval = 10

# weight decay
weight_decay = 1e-1
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# train a miniature character-level shakespeare model
# good for debugging and playing on macbooks and such

out_dir = 'out-shakespeare-char'
eval_interval = 250 # keep frequent because we'll overfit
eval_iters = 100
log_interval = 10 # don't print too too often

# we expect to overfit on this small dataset, so only save when val improves
always_save_checkpoint = False

wandb_log = False # override via command line if you like
wandb_project = 'shakespeare-char'
wandb_run_name = 'mini-gpt'

dataset = 'shakespeare_char'
batch_size = 64
block_size = 256 # context of up to 256 previous characters

# baby GPT model :)
n_layer = 6
n_head = 6
n_embd = 384
dropout = 0.2

learning_rate = 1e-3 # with baby networks can afford to go a bit higher
max_iters = 5000
lr_decay_iters = 5000 # make equal to max_iters usually
min_lr = 1e-4 # learning_rate / 10 usually
beta2 = 0.99 # make a bit bigger because number of tokens per iter is small

warmup_iters = 100 # not super necessary potentially

47 changes: 47 additions & 0 deletions blogs/artificial-intelligence/nanoGPT-JAX/configurator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
Poor Man's Configurator. Probably a terrible idea. Example usage:
$ python train.py config/override_file.py --batch_size=32
this will first run config/override_file.py, then override batch_size to 32
The code in this file will be run as follows from e.g. train.py:
>>> exec(open('configurator.py').read())
So it's not a Python module, it's just shuttling this code away from train.py
The code in this script then overrides the globals()
I know people are not going to love this, I just really dislike configuration
complexity and having to prepend config. to every single variable. If someone
comes up with a better simple Python solution I am all ears.
"""

import sys
from ast import literal_eval

for arg in sys.argv[1:]:
if '=' not in arg:
# assume it's the name of a config file
assert not arg.startswith('--')
config_file = arg
print(f"Overriding config with {config_file}:")
with open(config_file) as f:
print(f.read())
exec(open(config_file).read())
else:
# assume it's a --key=value argument
assert arg.startswith('--')
key, val = arg.split('=')
key = key[2:]
if key in globals():
try:
# attempt to eval it it (e.g. if bool, number, or etc)
attempt = literal_eval(val)
except (SyntaxError, ValueError):
# if that goes wrong, just use the string
attempt = val
# ensure the types match ok
assert type(attempt) == type(globals()[key])
# cross fingers
print(f"Overriding: {key} = {attempt}")
globals()[key] = attempt
else:
raise ValueError(f"Unknown config key: {key}")
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# saves the openwebtext dataset to a binary file for training. following was helpful:
# https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py

import os
from tqdm import tqdm
import numpy as np
import tiktoken
from datasets import load_dataset # huggingface datasets

# number of workers in .map() call
# good number to use is ~order number of cpu cores // 2
num_proc = 8

# number of workers in load_dataset() call
# best number might be different from num_proc above as it also depends on NW speed.
# it is better than 1 usually though
num_proc_load_dataset = num_proc

enc = tiktoken.get_encoding("gpt2")

if __name__ == '__main__':
# takes 54GB in huggingface .cache dir, about 8M documents (8,013,769)
dataset = load_dataset("openwebtext", num_proc=num_proc_load_dataset)

# owt by default only contains the 'train' split, so create a test split
split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
split_dataset['val'] = split_dataset.pop('test') # rename the test split to val

# this results in:
# >>> split_dataset
# DatasetDict({
# train: Dataset({
# features: ['text'],
# num_rows: 8009762
# })
# val: Dataset({
# features: ['text'],
# num_rows: 4007
# })
# })

# we now want to tokenize the dataset. first define the encoding function (gpt2 bpe)
def process(example):
ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens
ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe
# note: I think eot should be prepended not appended... hmm. it's called "eot" though...
out = {'ids': ids, 'len': len(ids)}
return out

# tokenize the dataset
tokenized = split_dataset.map(
process,
remove_columns=['text'],
desc="tokenizing the splits",
num_proc=num_proc,
)

# concatenate all the ids in each dataset into one large file we can use for training
for split, dset in tokenized.items():
arr_len = np.sum(dset['len'], dtype=np.uint64)
filename = os.path.join(os.path.dirname(__file__), f'{split}.bin')
dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
total_batches = 1024

idx = 0
for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
# Batch together samples for faster write
batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
arr_batch = np.concatenate(batch['ids'])
# Write into mmap
arr[idx : idx + len(arr_batch)] = arr_batch
idx += len(arr_batch)
arr.flush()

# train.bin is ~17GB, val.bin ~8.5MB
# train has ~9B tokens (9,035,582,198)
# val has ~4M tokens (4,434,897)

# to read the bin files later, e.g. with numpy:
# m = np.memmap('train.bin', dtype=np.uint16, mode='r')
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

## openwebtext dataset

after running `prepare.py` (preprocess) we get:

- train.bin is ~17GB, val.bin ~8.5MB
- train has ~9B tokens (9,035,582,198)
- val has ~4M tokens (4,434,897)

this came from 8,013,769 documents in total.

references:

- OpenAI's WebText dataset is discussed in [GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)
- [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) dataset
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import os
import requests
import tiktoken
import numpy as np

# download the tiny shakespeare dataset
input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt')
if not os.path.exists(input_file_path):
data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
with open(input_file_path, 'w', encoding='utf-8') as f:
f.write(requests.get(data_url).text)

with open(input_file_path, 'r', encoding='utf-8') as f:
data = f.read()
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

# encode with tiktoken gpt2 bpe
enc = tiktoken.get_encoding("gpt2")
train_ids = enc.encode_ordinary(train_data)
val_ids = enc.encode_ordinary(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

# export to bin files
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin'))
val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin'))

# train.bin has 301,966 tokens
# val.bin has 36,059 tokens
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@

# tiny shakespeare

Tiny shakespeare, of the good old char-rnn fame :)

After running `prepare.py`:

- train.bin has 301,966 tokens
- val.bin has 36,059 tokens
Loading

0 comments on commit d660f0d

Please sign in to comment.