
Commit

refactor tpu e2e test files
khatwanimohit committed Jan 24, 2025
1 parent 1c47a6d commit 83ef498
Showing 22 changed files with 600 additions and 476 deletions.
91 changes: 65 additions & 26 deletions MaxText/convert_gemma_chkpt.py
@@ -19,14 +19,16 @@
import jax
import jax.numpy as jnp
import numpy as np
import psutil
import gc

jax.config.update("jax_platform_name", "cpu")
import argparse
import copy
from flax.training import train_state

from typing import Any
import sys
import logging
import max_logging


@@ -37,6 +39,8 @@

Params = dict[str, Any]

SIMULATED_CPU_DEVICES_COUNT = 16


def nest_params(params: Params) -> Params:
"""Nests params as a dict of dicts rather than a flat dict."""
@@ -50,18 +54,11 @@ def nest_params(params: Params) -> Params:
return nested_params


def main(raw_args=None) -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--base_model_path", type=str, required=True)
parser.add_argument("--maxtext_model_path", type=str, required=True)
parser.add_argument("--model_size", type=str, required=True)
args = parser.parse_args(raw_args)
if args.model_size not in ("2b", "7b", "9b"):
raise NotImplementedError

def convert_to_jax_weights(base_model_path, model_size):
"""Convert to MaxText compatible orbax weights."""
print("Loading checkpoint")
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
params = checkpointer.restore(args.base_model_path)
params = checkpointer.restore(base_model_path)
params = nest_params(params)
num_layers = max((int(k.split("_")[1]) for k in params["transformer"].keys() if "layer_" in k)) + 1
hidden_dim, embed_dim = params["transformer"]["layer_0"]["mlp"]["linear"]["w"].shape
@@ -103,7 +100,7 @@ def main(raw_args=None) -> None:
for layer_idx in range(num_layers):
in_layer_name = "layer_" + str(layer_idx)
# attention block
if args.model_size in ("2b", "9b"): # MQA
if model_size in ("2b", "9b"): # MQA
self_attention["query"]["kernel"].append(
params["transformer"][in_layer_name]["attn"]["q_einsum"]["w"].transpose((1, 0, 2)) * head_dim**-0.5
)
@@ -148,35 +145,77 @@ def main(raw_args=None) -> None:

layer_weight["self_attention"] = copy.deepcopy(self_attention)
jax_weights["decoder"]["layers"] = copy.deepcopy(layer_weight)
jax_weights = jax.tree_util.tree_map(jnp.array, jax_weights)

  def astype_fn(x):
    if isinstance(x, jnp.ndarray):
      return x.astype(jnp.bfloat16)
    else:
      return x

  jax_weights = jax.tree_util.tree_map(astype_fn, jax_weights)
  return jax_weights


def save_jax_weights_to_checkpoint(maxtext_model_path, jax_weights):
  """
  Function to save jax_weights ready for MaxText to a parameters checkpoint
  """
  mem_info = psutil.Process()
  logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))
  gc.collect()
  mesh = jax.sharding.Mesh(jax.devices(), "checkpoint_sharding_axis")
  s1 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("checkpoint_sharding_axis"))  # shards first axis
  s2 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None, "checkpoint_sharding_axis"))  # shards second axis
  s3 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None))  # no sharding

  def checkpoint_device_put(arr):
    if arr.shape[0] % SIMULATED_CPU_DEVICES_COUNT == 0:
      max_logging.log("sharding first axis")
      return jax.device_put(arr, device=s1)
    elif len(arr.shape) > 1 and arr.shape[1] % SIMULATED_CPU_DEVICES_COUNT == 0:
      max_logging.log("sharding second axis")
      return jax.device_put(arr, device=s2)
    else:
      max_logging.log("no sharding was possible, replicating")
      return jax.device_put(arr, device=s3)

  # convert all weights to jax.numpy with sharding if applicable
  jax_weights_flat, jax_weights_struct = jax.tree.flatten(jax_weights)
  jax_weights_new = []
  while len(jax_weights_flat) > 0:
    jax_weight = jax_weights_flat.pop(0)
    jax_weights_new.append(checkpoint_device_put(jax_weight))
    del jax_weight
    gc.collect()
    logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

  jax_weights = jax.tree.unflatten(jax_weights_struct, jax_weights_new)

# dummy configs for the checkpoint_manager
step_number_to_save_new_ckpt = 0
enable_checkpointing = True
async_checkpointing = False
save_interval_steps = 1

checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
args.maxtext_model_path, enable_checkpointing, async_checkpointing, save_interval_steps
maxtext_model_path, enable_checkpointing, async_checkpointing, save_interval_steps
)

state_new = train_state.TrainState(
step=0, apply_fn=None, params={"params": jax_weights}, tx=None, opt_state={} # type: ignore
)

logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))
if checkpoint_manager is not None:
if save_checkpoint(checkpoint_manager, 0, state_new):
max_logging.log("saved a checkpoint at step 0")
if save_checkpoint(checkpoint_manager, step_number_to_save_new_ckpt, state_new):
max_logging.log(f"saved a checkpoint at step {step_number_to_save_new_ckpt}")
# Upon preemption, exit when and only when all ongoing saves are complete.
if checkpoint_manager.reached_preemption(0):
checkpoint_manager.wait_until_finished()
sys.exit()
checkpoint_manager.wait_until_finished()


def main(raw_args=None) -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--base_model_path", type=str, required=True)
parser.add_argument("--maxtext_model_path", type=str, required=True)
parser.add_argument("--model_size", type=str, required=True)
args = parser.parse_args(raw_args)
if args.model_size not in ("2b", "7b", "9b"):
raise NotImplementedError

save_jax_weights_to_checkpoint(args.maxtext_model_path, convert_to_jax_weights(args.base_model_path, args.model_size))


if __name__ == "__main__":
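
The sharding in save_jax_weights_to_checkpoint only kicks in when jax.devices() reports SIMULATED_CPU_DEVICES_COUNT (16) devices; the diff does not show where those simulated CPU devices are requested, so the flag used below is an assumption based on the standard JAX/XLA mechanism for host-platform devices. A minimal sketch of the same first-axis sharding that checkpoint_device_put applies:

import os

# Assumption: ask XLA for 16 simulated CPU devices before JAX initializes.
# This flag is not shown anywhere in this diff.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"

import jax
import numpy as np

mesh = jax.sharding.Mesh(jax.devices(), "checkpoint_sharding_axis")
s1 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("checkpoint_sharding_axis"))

arr = np.zeros((32, 256), dtype=np.float32)  # first axis (32) is divisible by 16, so it can be sharded
sharded = jax.device_put(arr, device=s1)
print(len(sharded.sharding.device_set))  # 16: each device holds a (2, 256) slice
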
6 changes: 6 additions & 0 deletions MaxText/llama_mistral_mixtral_orbax_to_hf.py
@@ -74,13 +74,19 @@ def load_hf_model(model_size):
"""
if model_size == "llama2-7b":
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
elif model_size == "llama2-70b":
config = AutoConfig.from_pretrained("meta-llama/Llama-2-70b-hf")
model = AutoModelForCausalLM.from_config(config)
elif model_size == "mistral-7b":
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
elif model_size == "mixtral-8x7b":
model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", device_map="auto")
elif model_size == "llama3.1-8b":
config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B")
model = AutoModelForCausalLM.from_config(config)
elif model_size == "llama3.1-70b":
config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-70B")
model = AutoModelForCausalLM.from_config(config)
else:
raise NotImplementedError
return model
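
The newly added branches build the 70B-scale and Llama 3.1 models from their configs with AutoModelForCausalLM.from_config, which instantiates the architecture with freshly initialized weights instead of downloading the full pretrained checkpoint. A minimal sketch of that pattern using standard transformers APIs (access to the gated meta-llama repo via a Hugging Face token is assumed, and materializing a 70B skeleton still requires a large-memory host):

from transformers import AutoConfig, AutoModelForCausalLM

# Fetch only the small config file, then build the architecture with randomly
# initialized weights; no pretrained checkpoint is downloaded.
config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-70B")
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__, model.config.num_hidden_layers)
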
35 changes: 35 additions & 0 deletions end_to_end/tpu/gemma/2b/1_test_gemma_2b.sh
@@ -0,0 +1,35 @@
#!/bin/bash

# This file runs once a day on a CPU.

# The flow of this file is as follows:
# 1. Convert the checkpoint downloaded from Kaggle to make it compatible with MaxText
# 2. Create a MaxText compatible unscanned Orbax checkpoint

set -ex
RUN_ID=$(date +%Y-%m-%d-%H-%M)

export MODEL='gemma-2b'
export ASYNC_CHECKPOINTING=false
export CKPT_BUCKET=gs://maxtext-model-checkpoints
# `SCANNED_CHECKPOINT` is the path to the GCS bucket where we want to save our converted (Orbax) checkpoint. Non-Googlers please remember to point `SCANNED_CHECKPOINT` to a GCS bucket that you own
export SCANNED_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}/scanned
export UNSCANNED_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}
export HF_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}/huggingface

# Installing torch for deps in forward_pass_logit_checker.py
pip install torch --index-url https://download.pytorch.org/whl/cpu

# After downloading checkpoints, copy them to the GCS bucket at $CHKPT_BUCKET.
# Non-Googlers please remember to use separate GCS paths for uploading the model weights from Kaggle ($CHKPT_BUCKET) and the MaxText compatible weights ($CKPT_BUCKET).
# Non-Googlers please remember to point these variables to GCS buckets that you own; this script uses internal buckets for testing.
export CHKPT_BUCKET=gs://maxtext-gemma/flax

JAX_PLATFORMS=cpu python MaxText/convert_gemma_chkpt.py --base_model_path ${CHKPT_BUCKET}/2b --maxtext_model_path ${SCANNED_CHECKPOINT} --model_size 2b

# Redefine `SCANNED_CHECKPOINT` to point at the checkpoint items subdirectory (`0/items`) inside the converted checkpoint, so the path is easier to reuse in later commands.
export SCANNED_CHECKPOINT=${SCANNED_CHECKPOINT}/0/items

# Note that `SCANNED_CHECKPOINT` is in a `scanned` format, which is great for training, but for efficient decoding performance we want the checkpoint in an `unscanned` format.
# We can do this by running `MaxText/generate_param_only_checkpoint.py` on `SCANNED_CHECKPOINT` with `force_unroll=true`.
JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${UNSCANNED_CHECKPOINT} load_parameters_path=${SCANNED_CHECKPOINT} async_checkpointing=${ASYNC_CHECKPOINTING} run_name=unscanned model_name=${MODEL} force_unroll=true