Skip to content

Commit

Permalink
Add Llama2 13B Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
morgandu committed May 16, 2024
1 parent 618450b commit 4db4056
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 0 deletions.
49 changes: 49 additions & 0 deletions end_to_end/tpu/llama2/13b/1_test_llama2_13b.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#!/bin/bash

# This file, combined with step 2 in the same directory, demonstrates converting a Llama2-13B checkpoint from Meta and running various MaxText operations on it.
# This step is tested nightly on an ordinary CPU VM.

# The flow of this file is as follows:
# 1. Pull the checkpoint from a GCS bucket and uploads the new MaxText compatible checkpoint to destination GCS bucket.
# 2. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding.

# Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/llama2/13b/1_test_llama2_13b.sh
# Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama2/13b/2_test_llama2_13b.sh.
# Please note that in these two scripts (1_test_llama2_13b.sh and 2_test_llama2_13b.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and
# the subfolders names aka RUN_NAMEs are static. Please remember to change BASE_OUTPUT_PATH across different runs.

set -ex
MODEL_VARIATION='llama2-13b'


# We install torch CPU because the checkpoint conversion script MaxText/llama_or_mistral_ckpt.py does not need a TPU/GPU
pip install torch --index-url https://download.pytorch.org/whl/cpu

# We define a var for the path to the Meta checkpoint. Non-Googlers please remember to update the source `META_CHECKPOINT_PATH` to the GCS bucket where you have your Meta checkpoint
export META_CHECKPOINT_PATH=gs://maxtext-llama/llama2-13b/meta-ckpt

# In the following command, we are copying Meta's checkpoint into a local directory `tmp`.
# You can use a different local directory than /tmp/, if you do so, please use the same local path for `base-model-path` when running `python3 MaxText/llama_or_mistral_ckpt.py`
gcloud storage cp -r ${META_CHECKPOINT_PATH} /tmp/

if [ -z "${BASE_OUTPUT_PATH}" ]; then
# Non-Googlers please remember to point BASE_OUTPUT_PATH to GCS buckets that you own, this script uses internal buckets for testing.
# Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama2/13b/2_test_llama2_13b
export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
fi

echo "Converted checkpoints are stored at ${BASE_OUTPUT_PATH}"

#Next, run the conversion script `MaxText/llama_or_mistral_ckpt.py` to convert Meta's PyTorch checkpoint in `base-model-path` and save the new converted (Orbax) checkpoint in the `maxtext-model-path`
JAX_PLATFORMS=cpu python3 MaxText/llama_or_mistral_ckpt.py --base-model-path /tmp/meta-ckpt --maxtext-model-path ${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt --model-size ${MODEL_VARIATION}

echo "Wrote MaxText compatible checkpoint to ${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt"

# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory.
export CONVERTED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt/0/items
# Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format.
# We can do this by running `MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`.
export RUN_NAME=unscanned_chkpt
JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='llama2-13b' force_unroll=true
echo "Written MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items"
61 changes: 61 additions & 0 deletions end_to_end/tpu/llama2/13b/2_test_llama2_13b.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#!/bin/bash

# This file is both an integration test that runs once a day on a v4-128 and documentation for how to get started with Llama2-13b.
# Please make sure you have run end_to_end/tpu/llama2/13b/1_test_llama2_13b.sh before running commands from this file.

# The flow of this file is as follows:
# 1. Run decoding, finetuning of Llama2-13B with the converted checkpoint obtained from end_to_end/tpu/llama2/13b/1_test_llama2_13b.sh. Also, run pretraining of Llama2-13B
# 2. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding.
# 3. Run decoding from the finetuned checkpoint from step 1


# Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/llama2/13b/2_test_llama2_13b.sh
# Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama2/13b/1_test_llama2_13b.sh
# Please note that in these two scripts (1_test_llama2_13b.sh and 2_test_llama2_13b.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and
# the subfolders names aka RUN_NAMEs are static. Please remember to change BASE_OUTPUT_PATH across different runs.

set -ex
export MODEL_VARIATION='llama2-13b'

if [ -z "${BASE_OUTPUT_PATH}" ]; then
# Non-Googlers please remember to point `BASE_OUTPUT_PATH` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run
# Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama2/13b/1_test_llama2_13b.sh
export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
fi



# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
export DATASET_PATH=gs://maxtext-dataset


# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory. This way it is easier to use this path in the `train.py` and `decode.py` commands
export CONVERTED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt/0/items
export RUN_NAME=unscanned_chkpt
# We defined path to unscanned checkpoint created in 1_test_llama2_13b.sh
export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items

# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state.
# So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
# We compare our decoded results by asserting with golden outputs using `autoregressive_decode_assert`
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" autoregressive_decode_assert="read. I love to read books, magazines,"

# We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" autoregressive_decode_assert="read. I love to read books, magazines,"

# Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning
export FINETUNE_RUN_NAME=runner_finetune
python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} steps=10 async_checkpointing=false model_name=${MODEL_VARIATION} checkpoint_period=5

# We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from
python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.llama2 per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) steps=5 enable_checkpointing=false model_name=${MODEL_VARIATION}

# Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters.
# So, we can use the `MaxText/generate_param_only_checkpoint.py` to convert the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_PATH` from our previous finetuning run.
# `force_unroll=true` is converting the output parameter only checkpoint into an unscanned format for efficient decoding
export PARAM_RUN_NAME=param_chkpt
python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true

# Now, run decoding on the checkpoint generated from our finetune run.
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"

0 comments on commit 4db4056

Please sign in to comment.