Commit d660f0d (1 parent: a7253d0)

Co-authored-by: phildangamd <[email protected]>
Co-authored-by: Eliot Li <[email protected]>

Showing 23 changed files with 1,845 additions and 1 deletion.
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Advanced Micro Devices, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
8 changes: 8 additions & 0 deletions
blogs/artificial-intelligence/nanoGPT-JAX/config/eval_gpt2.py
@@ -0,0 +1,8 @@
# evaluate the base gpt2
# n_layer=12, n_head=12, n_embd=768
# 124M parameters
batch_size = 8
eval_iters = 500 # use more iterations to get good estimate
eval_only = True
wandb_log = False
init_from = 'gpt2'
8 changes: 8 additions & 0 deletions
blogs/artificial-intelligence/nanoGPT-JAX/config/eval_gpt2_large.py
@@ -0,0 +1,8 @@
# evaluate the base gpt2
# n_layer=36, n_head=20, n_embd=1280
# 774M parameters
batch_size = 8
eval_iters = 500 # use more iterations to get good estimate
eval_only = True
wandb_log = False
init_from = 'gpt2-large'
8 changes: 8 additions & 0 deletions
blogs/artificial-intelligence/nanoGPT-JAX/config/eval_gpt2_medium.py
@@ -0,0 +1,8 @@
# evaluate the base gpt2
# n_layer=24, n_head=16, n_embd=1024
# 350M parameters
batch_size = 8
eval_iters = 500 # use more iterations to get good estimate
eval_only = True
wandb_log = False
init_from = 'gpt2-medium'
8 changes: 8 additions & 0 deletions
blogs/artificial-intelligence/nanoGPT-JAX/config/eval_gpt2_xl.py
@@ -0,0 +1,8 @@
# evaluate the base gpt2
# n_layer=48, n_head=25, n_embd=1600
# 1558M parameters
batch_size = 8
eval_iters = 500 # use more iterations to get good estimate
eval_only = True
wandb_log = False
init_from = 'gpt2-xl'
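The four eval_gpt2*.py files above differ only in which pretrained checkpoint they load. They are plain Python that the training script pulls in through the configurator shown later in this diff, so, assuming the entry point is train.py as in upstream nanoGPT (the training script itself is not part of the rendered diff), an evaluation run would look roughly like:

$ python train.py config/eval_gpt2.py
$ python train.py config/eval_gpt2_xl.py --eval_iters=200   # any key can still be overridden on the command line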
26 changes: 26 additions & 0 deletions
blogs/artificial-intelligence/nanoGPT-JAX/config/finetune_shakespeare.py
@@ -0,0 +1,26 @@
import time

out_dir = 'out-shakespeare'
eval_interval = 50
eval_iters = 40
wandb_log = False # feel free to turn on
wandb_project = 'shakespeare'
wandb_run_name = 'ft-' + str(time.time())

dataset = 'shakespeare'
init_from = 'gpt2-medium' # start from the pretrained gpt2-medium checkpoint

# only save checkpoints if the validation loss improves
always_save_checkpoint = False

# the number of examples per iter:
# 2 batch_size * 1 grad_accum * 1024 tokens = 2,048 tokens/iter
# shakespeare has 301,966 tokens, so 1 epoch ~= 147 iters
batch_size = 2
gradient_accumulation_steps = 1
max_iters = 5000
dropout = 0.1

# finetune with a small learning rate
learning_rate = 3e-5
decay_lr = True
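Under the same assumption about the entry point, kicking off the finetune would look roughly like this; the Shakespeare prepare.py script it depends on appears later in this diff:

$ python data/shakespeare/prepare.py          # build train.bin / val.bin first
$ python train.py config/finetune_shakespeare.py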
9 changes: 9 additions & 0 deletions
blogs/artificial-intelligence/nanoGPT-JAX/config/sample_shake_finetune.py
@@ -0,0 +1,9 @@
# -----------------------------------------------------------------------------
init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
out_dir = 'out-shakespeare' # ignored if init_from is not 'resume'
start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
num_samples = 5 # number of samples to draw
max_new_tokens = 50 # number of tokens generated in each sample
temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 1337
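Sampling from the finetuned checkpoint is a separate step. Assuming the repository follows nanoGPT's layout with a sample.py that also reads configs through the configurator (that script is not part of the rendered diff), it would look roughly like:

$ python sample.py config/sample_shake_finetune.py
$ python sample.py config/sample_shake_finetune.py --start="FILE:prompt.txt" --num_samples=3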
25 changes: 25 additions & 0 deletions
blogs/artificial-intelligence/nanoGPT-JAX/config/train_gpt2.py
@@ -0,0 +1,25 @@
# config for training GPT-2 (124M) down to very nice loss of ~2.85 on 1 node of 8X A100 40GB
# launch as the following (e.g. in a screen session) and wait ~5 days:
# $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py

wandb_log = True
wandb_project = 'owt'
wandb_run_name = 'gpt2-124M'

# these make the total batch size be ~0.5M
# 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520
batch_size = 12
block_size = 1024
gradient_accumulation_steps = 5 * 8

# this makes total number of tokens be 300B
max_iters = 600000
lr_decay_iters = 600000

# eval stuff
eval_interval = 1000
eval_iters = 200
log_interval = 10

# weight decay
weight_decay = 1e-1
33 changes: 33 additions & 0 deletions
blogs/artificial-intelligence/nanoGPT-JAX/config/train_shakespeare_char.py
@@ -0,0 +1,33 @@
# train a miniature character-level shakespeare model
# good for debugging and playing on macbooks and such

out_dir = 'out-shakespeare-char'
eval_interval = 250 # keep frequent because we'll overfit
eval_iters = 100
log_interval = 10 # don't print too too often

# we expect to overfit on this small dataset, so only save when val improves
always_save_checkpoint = False

wandb_log = False # override via command line if you like
wandb_project = 'shakespeare-char'
wandb_run_name = 'mini-gpt'

dataset = 'shakespeare_char'
batch_size = 64
block_size = 256 # context of up to 256 previous characters

# baby GPT model :)
n_layer = 6
n_head = 6
n_embd = 384
dropout = 0.2

learning_rate = 1e-3 # with baby networks can afford to go a bit higher
max_iters = 5000
lr_decay_iters = 5000 # make equal to max_iters usually
min_lr = 1e-4 # learning_rate / 10 usually
beta2 = 0.99 # make a bit bigger because number of tokens per iter is small

warmup_iters = 100 # not super necessary potentially
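For a self-contained quick start with this character-level config, the usual nanoGPT sequence is prepare, train, sample. Note that data/shakespeare_char/prepare.py is assumed to exist as in upstream nanoGPT; it is not shown in this rendered diff:

$ python data/shakespeare_char/prepare.py
$ python train.py config/train_shakespeare_char.py
$ python sample.py --out_dir=out-shakespeare-char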
@@ -0,0 +1,47 @@
"""
Poor Man's Configurator. Probably a terrible idea. Example usage:
$ python train.py config/override_file.py --batch_size=32
this will first run config/override_file.py, then override batch_size to 32
The code in this file will be run as follows from e.g. train.py:
>>> exec(open('configurator.py').read())
So it's not a Python module, it's just shuttling this code away from train.py
The code in this script then overrides the globals()
I know people are not going to love this, I just really dislike configuration
complexity and having to prepend config. to every single variable. If someone
comes up with a better simple Python solution I am all ears.
"""

import sys
from ast import literal_eval

for arg in sys.argv[1:]:
    if '=' not in arg:
        # assume it's the name of a config file
        assert not arg.startswith('--')
        config_file = arg
        print(f"Overriding config with {config_file}:")
        with open(config_file) as f:
            print(f.read())
        exec(open(config_file).read())
    else:
        # assume it's a --key=value argument
        assert arg.startswith('--')
        key, val = arg.split('=')
        key = key[2:]
        if key in globals():
            try:
                # attempt to eval it (e.g. if bool, number, etc.)
                attempt = literal_eval(val)
            except (SyntaxError, ValueError):
                # if that goes wrong, just use the string
                attempt = val
            # ensure the types match ok
            assert type(attempt) == type(globals()[key])
            # cross fingers
            print(f"Overriding: {key} = {attempt}")
            globals()[key] = attempt
        else:
            raise ValueError(f"Unknown config key: {key}")
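For context, here is a minimal sketch of how a training script consumes this configurator, following the exec pattern described in the docstring; the default values are illustrative, not taken from this commit:

# train.py (sketch): defaults are plain module-level globals...
batch_size = 12
learning_rate = 6e-4
eval_only = False

# ...which configurator.py may overwrite from a config file and/or --key=value args
exec(open('configurator.py').read())

print(batch_size, learning_rate, eval_only)  # reflects any overrides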
81 changes: 81 additions & 0 deletions
blogs/artificial-intelligence/nanoGPT-JAX/data/openwebtext/prepare.py
@@ -0,0 +1,81 @@
# saves the openwebtext dataset to a binary file for training. following was helpful:
# https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py

import os
from tqdm import tqdm
import numpy as np
import tiktoken
from datasets import load_dataset # huggingface datasets

# number of workers in .map() call
# good number to use is ~order number of cpu cores // 2
num_proc = 8

# number of workers in load_dataset() call
# best number might be different from num_proc above as it also depends on NW speed.
# it is better than 1 usually though
num_proc_load_dataset = num_proc

enc = tiktoken.get_encoding("gpt2")

if __name__ == '__main__':
    # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769)
    dataset = load_dataset("openwebtext", num_proc=num_proc_load_dataset)

    # owt by default only contains the 'train' split, so create a test split
    split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
    split_dataset['val'] = split_dataset.pop('test') # rename the test split to val

    # this results in:
    # >>> split_dataset
    # DatasetDict({
    #     train: Dataset({
    #         features: ['text'],
    #         num_rows: 8009762
    #     })
    #     val: Dataset({
    #         features: ['text'],
    #         num_rows: 4007
    #     })
    # })

    # we now want to tokenize the dataset. first define the encoding function (gpt2 bpe)
    def process(example):
        ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens
        ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe
        # note: I think eot should be prepended not appended... hmm. it's called "eot" though...
        out = {'ids': ids, 'len': len(ids)}
        return out

    # tokenize the dataset
    tokenized = split_dataset.map(
        process,
        remove_columns=['text'],
        desc="tokenizing the splits",
        num_proc=num_proc,
    )

    # concatenate all the ids in each dataset into one large file we can use for training
    for split, dset in tokenized.items():
        arr_len = np.sum(dset['len'], dtype=np.uint64)
        filename = os.path.join(os.path.dirname(__file__), f'{split}.bin')
        dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
        arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
        total_batches = 1024

        idx = 0
        for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
            # Batch together samples for faster write
            batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
            arr_batch = np.concatenate(batch['ids'])
            # Write into mmap
            arr[idx : idx + len(arr_batch)] = arr_batch
            idx += len(arr_batch)
        arr.flush()

    # train.bin is ~17GB, val.bin ~8.5MB
    # train has ~9B tokens (9,035,582,198)
    # val has ~4M tokens (4,434,897)

    # to read the bin files later, e.g. with numpy:
    # m = np.memmap('train.bin', dtype=np.uint16, mode='r')
15 changes: 15 additions & 0 deletions
blogs/artificial-intelligence/nanoGPT-JAX/data/openwebtext/readme.md
@@ -0,0 +1,15 @@
## openwebtext dataset

after running `prepare.py` (preprocess) we get:

- train.bin is ~17GB, val.bin ~8.5MB
- train has ~9B tokens (9,035,582,198)
- val has ~4M tokens (4,434,897)

this came from 8,013,769 documents in total.

references:

- OpenAI's WebText dataset is discussed in [GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)
- [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) dataset
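The final comment in prepare.py hints at how these bins are consumed during training; here is a minimal sketch of drawing a random batch from the memmapped tokens (the batch and context sizes are illustrative, not taken from this commit):

import numpy as np

data = np.memmap('train.bin', dtype=np.uint16, mode='r')
block_size, batch_size = 1024, 12
ix = np.random.randint(len(data) - block_size, size=batch_size)
x = np.stack([data[i:i+block_size].astype(np.int64) for i in ix])      # input tokens
y = np.stack([data[i+1:i+1+block_size].astype(np.int64) for i in ix])  # next-token targets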
33 changes: 33 additions & 0 deletions
blogs/artificial-intelligence/nanoGPT-JAX/data/shakespeare/prepare.py
@@ -0,0 +1,33 @@
import os
import requests
import tiktoken
import numpy as np

# download the tiny shakespeare dataset
input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt')
if not os.path.exists(input_file_path):
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    with open(input_file_path, 'w', encoding='utf-8') as f:
        f.write(requests.get(data_url).text)

with open(input_file_path, 'r', encoding='utf-8') as f:
    data = f.read()
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

# encode with tiktoken gpt2 bpe
enc = tiktoken.get_encoding("gpt2")
train_ids = enc.encode_ordinary(train_data)
val_ids = enc.encode_ordinary(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

# export to bin files
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin'))
val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin'))

# train.bin has 301,966 tokens
# val.bin has 36,059 tokens
9 changes: 9 additions & 0 deletions
blogs/artificial-intelligence/nanoGPT-JAX/data/shakespeare/readme.md
@@ -0,0 +1,9 @@
# tiny shakespeare

Tiny shakespeare, of the good old char-rnn fame :)

After running `prepare.py`:

- train.bin has 301,966 tokens
- val.bin has 36,059 tokens
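As a quick sanity check on the exported bins (a sketch, not part of the commit), the token IDs decode back to text with the same tiktoken encoding used in prepare.py:

import numpy as np
import tiktoken

enc = tiktoken.get_encoding("gpt2")
val = np.memmap('val.bin', dtype=np.uint16, mode='r')
print(f"{len(val):,} tokens")          # expected: 36,059 per the readme above
print(enc.decode(val[:100].tolist()))  # first few lines of held-out Shakespeare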