Refactor tpu e2e test files #1186

Open · khatwanimohit wants to merge 2 commits into main from mohit/ckpt_reorg

Conversation

khatwanimohit (Collaborator) commented Jan 22, 2025

Description

This PR refactors all TPU e2e model test files.
The first bash script runs on a CPU and performs all checkpoint-related tasks:

  • Convert the parent model checkpoint to a MaxText-compatible Orbax checkpoint
  • Convert the scanned checkpoint to an unscanned one for efficient decoding
  • Convert the MaxText checkpoint to HF

The second bash script runs on a TPU and contains all model-related tests, such as pre-training, full fine-tuning, and decoding.

This PR removes the direct dependency of the model tests on checkpoint creation: each model test picks up the most recent checkpoint to run against (see the sketch below), while the checkpoint-creation tests run on a separate cadence.
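
For illustration, a minimal sketch of how a TPU-side test might discover the most recent checkpoint. The lookup below is an assumption (the PR text does not spell it out); it relies on RUN_ID being a sortable date stamp such as 2025-01-22-10-00 and on the bucket layout used in the scripts further down:

# Hypothetical: pick the newest RUN_ID under ${CKPT_BUCKET}/${MODEL}.
# A lexicographic sort is also chronological because RUN_ID is YYYY-MM-DD-HH-MM.
export RUN_ID=$(gcloud storage ls ${CKPT_BUCKET}/${MODEL}/ | sort | tail -n 1 | xargs basename)
export SCANNED_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}/scanned/0/items
export UNSCANNED_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}/unscanned/checkpoints/0/items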

The following model test files are updated in this PR:

  • Llama2-7B
  • Llama2-70B
  • Llama3.1-8B
  • Llama3.1-70B
  • Gemma-2B
  • Gemma-7B
  • Mistral-7B

FIXES: b/376935929

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have added the necessary comments to my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@khatwanimohit force-pushed the mohit/ckpt_reorg branch 7 times, most recently from 4a3263c to 15d9d53 on January 23, 2025 18:20
@khatwanimohit marked this pull request as ready for review on January 23, 2025 18:29
@khatwanimohit force-pushed the mohit/ckpt_reorg branch 3 times, most recently from 83ef498 to 676bcaf on January 24, 2025 21:02
@@ -0,0 +1,35 @@
#!/bin/bash

# This file runs once a day on a CPU and has follows:
Collaborator:

nit: "has follows" and "as follows" in consecutive statements seem redundant

# `SCANNED_CHECKPOINT` is the path to the GCS bucket where we want to save our converted (Orbax) checkpoint. Non-Googlers please remember to point `SCANNED_CHECKPOINT` to a GCS bucket that you own
export SCANNED_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}/scanned
export UNSCANNED_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}
export HF_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}/huggingface
Collaborator:

I think we should be doing Convert MaxText checkpoint to HF here
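
For illustration, a hedged sketch of the step being suggested. The converter script name and flags below follow the pattern of MaxText's llama2 test scripts (see the /tmp/hf_llama2 copy later in this PR) and are assumptions here; the converter varies per model family:

# Hypothetical: convert the MaxText (Orbax) checkpoint to HuggingFace format
# on CPU, then copy the result to ${HF_CHECKPOINT}. Script name, flags, and
# the /tmp/hf_model staging path are assumptions.
JAX_PLATFORMS=cpu python3 MaxText/llama_mistral_mixtral_orbax_to_hf.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${SCANNED_CHECKPOINT} run_name=convert_to_hf model_size=${MODEL} hf_model_path=/tmp/hf_model
gcloud storage cp -r /tmp/hf_model ${HF_CHECKPOINT}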


# We also test whether the forward pass logits match the golden logits for Gemma-2b
- python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_gemma2b per_device_batch_size=1 model_name=gemma-2b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false attention=dot_product --max_kl_div=0.01
+ python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_gemma2b per_device_batch_size=1 model_name=${MODEL} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false attention=dot_product --max_kl_div=0.01
Collaborator:

We should put in the forward pass test for the HF_CHECKPOINT too
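
For illustration, a hedged sketch of what an HF-side forward-pass check could look like; this snippet is not part of the PR, and the local path and prompt are made up:

# Hypothetical: copy the converted HF checkpoint locally and run one forward
# pass with transformers; the logits could then be compared against goldens.
gcloud storage cp -r ${HF_CHECKPOINT} /tmp/hf_ckpt
python3 - <<'EOF'
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tok = AutoTokenizer.from_pretrained("/tmp/hf_ckpt")
model = AutoModelForCausalLM.from_pretrained("/tmp/hf_ckpt", torch_dtype=torch.float32)
with torch.no_grad():
    logits = model(**tok("I love to", return_tensors="pt")).logits
print(logits.shape)  # compare against golden logits offline
EOF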

# `SCANNED_CHECKPOINT` is the path to the GCS bucket where we want to save our converted (Orbax) checkpoint. Non-Googlers please remember to point `SCANNED_CHECKPOINT` to a GCS bucket that you own
export SCANNED_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}/scanned
export UNSCANNED_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}
export HF_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}/huggingface
Collaborator:

Same

export ASYNC_CHECKPOINTING=false
export UNSCANNED_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}/unscanned/checkpoints/0/items
export SCANNED_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}/scanned/0/items
export HF_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}/huggingface
Collaborator:

Same


gcloud storage cp -r /tmp/hf_llama2 ${HF_CHECKPOINT}

echo "All Checkpoints saved with RUN_ID=${RUN_ID}"
Collaborator:

nit: let's put a newline


# gcloud storage cp -r /tmp/hf_llama ${HF_CHECKPOINT}

echo "All Checkpoints saved with RUN_ID=${RUN_ID}"
Collaborator:

nit: let's put a newline

# We also test whether the forward pass logits match the golden logits for Llama3.1-8B
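# Float32 everywhere (dtype, activations, matmul precision) keeps numerical
# noise low, so the comparison against the golden logits can use the tight
# 1e-4 KL-divergence bound below.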
python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} tokenizer_path=assets/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CHECKPOINT} run_name=forward_pass_test per_device_batch_size=1 model_name=${MODEL} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=false --max_kl_div=1e-4

# TODO(b/391634569): converting to HF checkpoint OOMs
Collaborator:

Similarly let's skip llama3.1-70b from the PR description then


# Generate unscanned ckpt for efficient decoding test
JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml async_checkpointing=false base_output_directory=${UNSCANNED_CHECKPOINT} load_parameters_path=${SCANNED_CHECKPOINT} run_name=unscanned model_name='mistral-7b' force_unroll=true

Collaborator:

no hf conversion?

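# Smoke-test decoding with the scanned checkpoint on a short fixed prompt.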
python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${SCANNED_CHECKPOINT} run_name=scanned_decoding per_device_batch_size=1 model_name=mistral-7b async_checkpointing=false tokenizer_path=assets/tokenizer.mistral-v1 max_prefill_predict_length=11 max_target_length=16 prompt="[INST] I love to [/INST]" attention=dot_product megablox=False sparse_matmul=False

# Test whether the forward pass logits match the golden logits - matmul implementation
python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${SCANNED_CHECKPOINT} run_name=matmul_forward_pass_test per_device_batch_size=1 model_name=mistral-7b tokenizer_path=assets/tokenizer.mistral-v1 max_prefill_predict_length=11 max_target_length=11 dataset_type=synthetic dtype=float32 megablox=False sparse_matmul=False --atol=3 --rtol=1 --token_size=4
Collaborator:

let's put forward pass logit checking with hf too

# 2. Create MaxText compatible unscanned orbax checkpoint

set -ex
RUN_ID=$(date +%Y-%m-%d-%H-%M)
Collaborator:

can we use BASE_OUTPUT_PATH like here? It is much simpler to use this way for our nightly tests
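
For illustration, a minimal sketch of the pattern being suggested, following the style of other MaxText nightly scripts; the fallback bucket name below is an assumption:

# Hypothetical: honor a caller-supplied BASE_OUTPUT_PATH and fall back to a
# fresh dated run directory only when none is provided.
if [ -z "${BASE_OUTPUT_PATH}" ]; then
  export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)/
fi
echo "Using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"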
