From 55c7f8a24206c3768009ed147e05723e2d69581f Mon Sep 17 00:00:00 2001
From: lichen225 <161898702+lichen225@users.noreply.github.com>
Date: Thu, 23 May 2024 20:43:37 -0700
Subject: [PATCH] [Example] Add comments to example codes (#36)

In this PR, we add comments explaining VeScale APIs in the nanoGPT example.
---
 examples/nanogpt_4D_finetune/finetune_4D.py | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/examples/nanogpt_4D_finetune/finetune_4D.py b/examples/nanogpt_4D_finetune/finetune_4D.py
index e0ad58c..e750652 100644
--- a/examples/nanogpt_4D_finetune/finetune_4D.py
+++ b/examples/nanogpt_4D_finetune/finetune_4D.py
@@ -115,9 +115,10 @@ def main():
         device = f"cuda:{rank}"
         torch.cuda.set_device(device)
         init_process_group(backend=backend, world_size=world_size, rank=rank)
-
+        # + + + VeScale API below
         VESCALE_DEVICE_MESH.init_device_mesh(device, (dp_size, tp_size), mesh_dim_names=["DP", "TP"])
         mesh = VESCALE_DEVICE_MESH.get()
+        # + + + VeScale API above
         ddp_rank = get_rank() // tp_size
     else:
         rank = 0
@@ -137,7 +138,9 @@ def main():
     if master_process:
         os.makedirs(out_dir, exist_ok=True)
     torch.manual_seed(1337)
+    # + + + VeScale API below
     manual_seed(1337, mesh)
+    # + + + VeScale API above
     torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
     torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
     device_type = "cuda" if "cuda" in device else "cpu" # for later use in torch.autocast
@@ -147,7 +150,13 @@ def main():
 
     # poor man's data loader
     data_dir = os.path.join("data", dataset)
-    # + + + Support larger batch size when running evaluation and only the master process do the random sampling
+    """
+    Deterministic data loader for loss match:
+    This data loader ensures that the mini-batch sampling has identical behavior no matter how many GPUs are used.
+    In particular, at each training iteration, each rank samples a batch of indices under the identical RNG state.
+    Then, each Data Parallelism (DP) rank takes the corresponding subset of indices and fetches the corresponding sequences from the dataset.
+    """
+
     def get_batch(split, bsz=batch_size, lbsz=local_batch_size):
         # We recreate np.memmap every batch to avoid a memory leak, as per
         # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
@@ -166,9 +175,11 @@ def get_batch(split, bsz=batch_size, lbsz=local_batch_size):
             x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
         else:
             x, y = x.to(device), y.to(device)
+        # + + + VeScale API below
         if ddp:
             x = distribute_tensor(x, VESCALE_DEVICE_MESH["TP"], [Replicate()])
             y = distribute_tensor(y, VESCALE_DEVICE_MESH["TP"], [Replicate()])
+        # + + + VeScale API above
         return x, y
 
     # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
@@ -335,9 +346,11 @@ def get_lr(it):
         wandb.init(project=wandb_project, name=wandb_run_name, config=config)
 
     # Load checkpoint
+    # + + + VeScale Load checkpoint
    if load_checkpoint_path:
        checkpoint_state = {"model": model, "optimizer": optimizer}
        vescale.checkpoint.load(load_checkpoint_path, checkpoint_state)
+    # + + + VeScale API above
     # training loop
     X, Y = get_batch("train") # fetch the very first batch
     t0 = time.time()
@@ -369,14 +382,18 @@ def get_lr(it):
             if iter_num > 0:
                 # When iter_num == 0, the training does not start so the optimizer state is empty,
                 # Don't save checkpoint
+                # + + + VeScale API below
                 checkpoint_state = {"model": model, "optimizer": optimizer}
                 vescale.checkpoint.save(os.path.join(save_checkpoint_path, f"iter_{iter_num}"), checkpoint_state)
+                # + + + VeScale API above
         if iter_num == 0 and eval_only:
             break
 
         # forward backward update, with optional gradient accumulation to simulate larger batch size
+        # + + + VeScale API below
         if ddp:
             model.zero_grad_buffer()
+        # + + + VeScale API above
         for micro_step in range(gradient_accumulation_steps):
             # with ctx:
             logits, loss = model(X, Y)
@@ -385,8 +402,10 @@ def get_lr(it):
             X, Y = get_batch("train")
             # backward pass
             loss.backward()
+        # + + + VeScale API below
         if ddp:
             model.finish_grad_sync()
+        # + + + VeScale API above
         optimizer.step()
         # flush the gradients as soon as we can, no need for this memory anymore
         optimizer.zero_grad(set_to_none=True)
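
Note: for readers skimming the patch, the sketch below gathers the VeScale calls that these comments annotate (device mesh setup, mesh-aware seeding, replicating the input batch over the tensor-parallel submesh, DDP gradient sync, and distributed checkpointing) into one place. It is only an illustration: the import paths, helper function names, and the standalone training step are assumptions, not code from finetune_4D.py.

# Illustrative sketch only; import paths below are assumed, not copied from the example.
import os
import torch
import vescale.checkpoint                                # assumed import for vescale.checkpoint.save/load
from vescale.devicemesh_api import VESCALE_DEVICE_MESH   # assumed import path
from vescale.dtensor.random import manual_seed           # assumed import path
from vescale.dtensor.api import distribute_tensor        # assumed import path
from vescale.dtensor.placement_types import Replicate    # assumed import path


def setup_mesh_sketch(device, dp_size=2, tp_size=2):
    # Build a 2D (DP, TP) device mesh as main() does; dp_size * tp_size must equal the world size.
    VESCALE_DEVICE_MESH.init_device_mesh(device, (dp_size, tp_size), mesh_dim_names=["DP", "TP"])
    mesh = VESCALE_DEVICE_MESH.get()
    # Seed both the plain torch RNG and the mesh-aware DTensor RNG, mirroring the patch.
    torch.manual_seed(1337)
    manual_seed(1337, mesh)
    return mesh


def train_step_sketch(model, optimizer, x, y, save_dir, iter_num):
    # Replicate the input batch over the tensor-parallel ("TP") submesh, as get_batch() does.
    x = distribute_tensor(x, VESCALE_DEVICE_MESH["TP"], [Replicate()])
    y = distribute_tensor(y, VESCALE_DEVICE_MESH["TP"], [Replicate()])

    # Clear the DDP gradient buffer, run forward/backward, then finish gradient synchronization
    # before the optimizer step, matching the order annotated in the training loop.
    model.zero_grad_buffer()
    _, loss = model(x, y)
    loss.backward()
    model.finish_grad_sync()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    # Distributed checkpointing: the state dict maps names to the live model/optimizer objects.
    checkpoint_state = {"model": model, "optimizer": optimizer}
    vescale.checkpoint.save(os.path.join(save_dir, f"iter_{iter_num}"), checkpoint_state)

In the real script the mesh is built once at startup and the data loader/training loop perform these calls in place, so this sketch is meant only as a reading aid for the annotated hunks above.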