Refactor tpu e2e test files #1186
base: main
Conversation
@@ -0,0 +1,35 @@
#!/bin/bash

# This file runs once a day on a CPU and has follows:
nit: "has follows" and "as follows" in consecutive statements seem redundant
# `SCANNED_CHECKPOINT` is the path to the GCS bucket where we want to save our converted (Orbax) checkpoint. Non-Googlers please remember to point `SCANNED_CHECKPOINT` to a GCS bucket that you own
export SCANNED_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}/scanned
export UNSCANNED_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}
export HF_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}/huggingface
I think we should be doing the "Convert MaxText checkpoint to HF" step here.
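A minimal sketch of how that conversion step could slot in here, following the produce-locally-then-copy pattern the other scripts in this PR use. The conversion entry point and its flags below are placeholders/assumptions, not confirmed MaxText scripts or options:

# Hypothetical step: write an HF-format copy of the unscanned checkpoint locally, then upload it.
# <orbax_to_hf_converter> is a placeholder for whatever conversion script the repo provides.
JAX_PLATFORMS=cpu python3 MaxText/<orbax_to_hf_converter>.py MaxText/configs/base.yml \
    load_parameters_path=${UNSCANNED_CHECKPOINT} model_name=${MODEL} \
    hf_model_path=/tmp/hf_${MODEL}
gcloud storage cp -r /tmp/hf_${MODEL} ${HF_CHECKPOINT}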
# We also test whether the forward pass logits match the golden logits for Gemma-2b
python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_gemma2b per_device_batch_size=1 model_name=gemma-2b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false attention=dot_product --max_kl_div=0.01
python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_gemma2b per_device_batch_size=1 model_name=${MODEL} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false attention=dot_product --max_kl_div=0.01
We should put in the forward pass test for the HF_CHECKPOINT too
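A rough sketch of a forward-pass sanity check against the HF checkpoint, using the transformers library directly (assuming transformers and torch are installed; the local path and prompt are illustrative, and comparing against golden logits would still need the project's own tooling):

# Pull the HF checkpoint locally and run a single forward pass with transformers.
mkdir -p /tmp/hf_ckpt
gcloud storage cp -r ${HF_CHECKPOINT}/* /tmp/hf_ckpt/
python3 - <<'EOF'
from transformers import AutoModelForCausalLM, AutoTokenizer
tok = AutoTokenizer.from_pretrained("/tmp/hf_ckpt")
model = AutoModelForCausalLM.from_pretrained("/tmp/hf_ckpt")
inputs = tok("I love to", return_tensors="pt")
logits = model(**inputs).logits  # these could be compared against the golden logits
print(logits.shape)
EOF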
# `SCANNED_CHECKPOINT` is the path to the GCS bucket where we want to save our converted (Orbax) checkpoint. Non-Googlers please remember to point `SCANNED_CHECKPOINT` to a GCS bucket that you own
export SCANNED_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}/scanned
export UNSCANNED_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}
export HF_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}/huggingface
Same
export ASYNC_CHECKPOINTING=false
export UNSCANNED_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}/unscanned/checkpoints/0/items
export SCANNED_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}/scanned/0/items
export HF_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}/huggingface
Same
gcloud storage cp -r /tmp/hf_llama2 ${HF_CHECKPOINT}

echo "All Checkpoints saved with RUN_ID=${RUN_ID}"
nit: let's put a newline
# gcloud storage cp -r /tmp/hf_llama ${HF_CHECKPOINT}

echo "All Checkpoints saved with RUN_ID=${RUN_ID}"
nit: let's put a newline
# We also test whether the forward pass logits match the golden logits for LLama3.1-8B
python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} tokenizer_path=assets/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CHECKPOINT} run_name=forward_pass_test per_device_batch_size=1 model_name=${MODEL} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=false --max_kl_div=1e-4

# TODO(b/391634569): converting to HF checkpoint OOMs
Similarly let's skip llama3.1-70b from the PR description then
# Generate unscanned ckpt for efficient decoding test
JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml async_checkpointing=false base_output_directory=${UNSCANNED_CHECKPOINT} load_parameters_path=${SCANNED_CHECKPOINT} run_name=unscanned model_name='mistral-7b' force_unroll=true
no hf conversion?
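If one were added, it could mirror the llama2 script's pattern of producing the HF files locally and then uploading them; only the upload step is taken from the existing scripts, the conversion entry point below is a placeholder/assumption:

# Hypothetical Orbax-to-HF conversion for mistral-7b, followed by upload (placeholder script name):
# python3 MaxText/<orbax_to_hf_converter>.py MaxText/configs/base.yml \
#     load_parameters_path=${SCANNED_CHECKPOINT} model_name='mistral-7b' hf_model_path=/tmp/hf_mistral
# gcloud storage cp -r /tmp/hf_mistral ${HF_CHECKPOINT}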
python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${SCANNED_CHECKPOINT} run_name=scanned_decoding per_device_batch_size=1 model_name=mistral-7b async_checkpointing=false tokenizer_path=assets/tokenizer.mistral-v1 max_prefill_predict_length=11 max_target_length=16 prompt="[INST] I love to [/INST]" attention=dot_product megablox=False sparse_matmul=False

# Test whether the forward pass logits match the golden logits - matmul implementation
python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${SCANNED_CHECKPOINT} run_name=matmul_forward_pass_test per_device_batch_size=1 model_name=mistral-7b tokenizer_path=assets/tokenizer.mistral-v1 max_prefill_predict_length=11 max_target_length=11 dataset_type=synthetic dtype=float32 megablox=False sparse_matmul=False --atol=3 --rtol=1 --token_size=4
let's put forward pass logit checking with hf too
# 2. Create MaxText compatible unscanned orbax checkpoint

set -ex
RUN_ID=$(date +%Y-%m-%d-%H-%M)
Can we use BASE_OUTPUT_PATH like here? It is much simpler to use this way for our nightly tests.
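A sketch of the suggested simplification, assuming the nightly job exports BASE_OUTPUT_PATH before invoking the script (the default bucket and path layout below are illustrative):

set -ex
# Use a caller-provided output root instead of deriving a fresh RUN_ID from the date.
export BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH:-gs://example-bucket/$(date +%Y-%m-%d)}
export SCANNED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL}/scanned
export UNSCANNED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL}/unscanned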
Description
This PR refactors all TPU e2e model test files.
The first bash script runs on a CPU and contains all checkpoint-related tasks.
The second bash script runs on a TPU and contains all model-related tests such as pre-training, full fine-tuning, and decoding.
This PR removes the direct dependency of model tests on checkpoint creation: each model test picks up the most recent checkpoint to run against (see the sketch below), and checkpoint-related tests run on a separate cadence.
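A minimal sketch of how a TPU-side test could locate the most recent checkpoint written by the CPU script, assuming RUN_IDs stay date-based and therefore sort lexicographically (bucket layout taken from the scripts above):

# Pick the newest RUN_ID directory under the model's checkpoint bucket.
export RUN_ID=$(gcloud storage ls ${CKPT_BUCKET}/${MODEL}/ | sort | tail -n 1 | xargs basename)
export UNSCANNED_CHECKPOINT=${CKPT_BUCKET}/${MODEL}/${RUN_ID}/unscanned/checkpoints/0/items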
The following model test files are updated in this PR:
FIXES: b/376935929
Tests
Please describe how you tested this change, and include any instructions and/or commands to reproduce.
Checklist
Before submitting this PR, please make sure (put X in square brackets):