Reviewed MaxText README.md #590

Open

wants to merge 4 commits into base: main

Changes from 3 commits

112 changes: 73 additions & 39 deletions README.md
@@ -19,9 +19,17 @@

# Overview

MaxText is a **high performance**, **highly scalable**, **open-source** LLM written in pure Python/Jax and targeting Google Cloud TPUs and GPUs for **training** and **inference**. MaxText achieves [high MFUs](#runtime-performance-results) and scales from single host to very large clusters while staying simple and "optimization-free" thanks to the power of Jax and the XLA compiler.
MaxText is a set of open source reference implementations that are **high performance**, **arbitrarily scalable**, **well-tested**, and written in pure Python/Jax to target Google Cloud TPUs and GPUs. MaxText typically achieves 55% to 60% model-flop utilization and scales from single host to very large clusters while staying simple and "optimization-free" thanks to the power of Jax and the XLA compiler.

MaxText aims to be a launching off point for ambitious LLM projects both in research and production. We encourage users to start by experimenting with MaxText out of the box and then fork and modify MaxText to meet their needs.
MaxText is designed to be a starting point for ambitious LLM projects both in research and production. We encourage you to start experimenting with MaxText and then fork and modify it to meet your needs.

We have used MaxText to [demonstrate high-performance, well-converging training in int8](https://cloud.google.com/blog/products/compute/accurate-quantized-training-aqt-for-tpu-v5e) and [scale training to ~51K chips](https://cloud.google.com/blog/products/compute/the-worlds-largest-distributed-llm-training-job-on-tpu-v5e).

Key supported features:

* TPUs and GPUs (in preview)
* Training and Inference (in preview)
* Models: Llama2, Mistral and Gemma

We have used MaxText to [demonstrate high-performance, well-converging training in int8](https://cloud.google.com/blog/products/compute/accurate-quantized-training-aqt-for-tpu-v5e) and [scale training to ~51K chips](https://cloud.google.com/blog/products/compute/the-worlds-largest-distributed-llm-training-job-on-tpu-v5e).

@@ -40,15 +48,21 @@ Key supported features:

# Getting Started

For your first time running MaxText, we provide specific [instructions](getting_started/First_run.md).
For instructions on running MaxText the first time, see [First run](getting_started/First_run.md).

MaxText supports training and inference of various open models. Follow user guides in the [getting started](getting_started) folder to know more.
MaxText supports training and inference of various open models. For more information, see [getting started](getting_started).

Some extra helpful guides:
* [Gemma](https://ai.google.dev/gemma): a family of open-weights Large Language Model (LLM) by [Google DeepMind](https://deepmind.google/), based on Gemini research and technology. You can run decode and finetuning using [these instructions](end_to_end/gemma/Run_Gemma.md).
* [Llama2](https://llama.meta.com/llama2/): a family of open-weights Large Language Model (LLM) by Meta. You can run decode and finetuning using [these instructions](getting_started/Run_Llama2.md).
See these links for more information about the models implemented in MaxText:

In addition to the getting started guides, there are always other MaxText capabilities that are being constantly being added! The full suite of end-to-end tests is in [end_to_end](end_to_end). We run them with a nightly cadence. They can be a good source for understanding MaxText Alternatively you can see the continuous [unit tests](.github/workflows/UnitTests.yml) which are run almost continuously.
* [Gemma](https://ai.google.dev/gemma): a family of open-weights Large Language Model (LLM) by [Google DeepMind](https://deepmind.google/), based on Gemini research and technology. For more information about decoding and fine tuning, see [Run Gemma](end_to_end/gemma/Run_Gemma.md).
* [Llama2](https://llama.meta.com/llama2/): a family of open-weights Large Language Model (LLM) by Meta. For more information about decoding and fine tuning, see [Run Llama2](getting_started/Run_Llama2.md).

In addition to the getting started guides, new content is added regularly! The full suite of end-to-end tests is in [end_to_end](end_to_end). We run them with a nightly cadence. They can be a good source for understanding MaxText. Alternatively, you can see the [unit tests](.github/workflows/UnitTests.yml), which run on a regular basis.

# Runtime Performance Results

This section describes the runtime performance of MaxText using different TPU versions as well as different numbers of parameters. The performance is measured by TFLOP/sec/chip and [Model Flops Utilization (MFU)](https://services.google.com/fh/files/blogs/tpu_v4_benchmarking.pdf).
More details on reproducing these results can be found in [MaxText/configs/README.md](MaxText/configs/README.md).

# Runtime Performance Results

@@ -84,52 +98,62 @@ For 16B, 32B, 64B, and 128B models. See full run configs in [MaxText/configs/v5e

# Comparison to Alternatives

MaxText is heavily inspired by [MinGPT](https://github.com/karpathy/minGPT)/[NanoGPT](https://github.com/karpathy/nanoGPT), elegant standalone GPT implementations written in PyTorch and targeting Nvidia GPUs. MaxText is more complex, supporting more industry standard models and scaling to tens of thousands of chips. Ultimately MaxText has an MFU more than three times the [17%](https://twitter.com/karpathy/status/1613250489097027584?cxt=HHwWgIDUhbixteMsAAAA) reported most recently with that codebase, is massively scalable and implements a key-value cache for efficient auto-regressive decoding.
MaxText is heavily inspired by [MinGPT](https://github.com/karpathy/minGPT) and [NanoGPT](https://github.com/karpathy/nanoGPT), elegant standalone GPT implementations written in PyTorch and targeting Nvidia GPUs. MaxText is more complex, supporting more industry standard models and scaling to tens of thousands of chips. Ultimately, MaxText has an [MFU](https://cloud.google.com/blog/products/compute/using-cloud-tpu-multislice-to-scale-ai-workloads) more than three times the [17%](https://twitter.com/karpathy/status/1613250489097027584?cxt=HHwWgIDUhbixteMsAAAA) reported most recently with MinGPT and NanoGPT, is massively scalable, and implements a key-value cache for efficient auto-regressive decoding.

MaxText is more similar to [Nvidia/Megatron-LM](https://github.com/NVIDIA/Megatron-LM), a very well tuned LLM implementation targeting Nvidia GPUs. The two implementations achieve comparable MFUs. The difference in the codebases highlights the different programming strategies. MaxText is pure Python, relying heavily on the XLA compiler to achieve high performance. By contrast, Megatron-LM is a mix of Python and CUDA, relying on well-optimized CUDA kernels to achieve high performance.
MaxText is more similar to [Nvidia/Megatron-LM](https://github.com/NVIDIA/Megatron-LM), a well tuned LLM implementation that targets Nvidia GPUs. MaxText and Megatron-LM implementations achieve comparable MFUs. The difference in the codebases highlights different programming strategies. MaxText is written in pure Python, relying heavily on the XLA compiler to achieve high performance. By contrast, Megatron-LM is a mix of Python and CUDA, relying on well-optimized CUDA kernels to achieve high performance.

MaxText is also comparable to [Pax](https://github.com/google/paxml). Like Pax, MaxText provides high-performance and scalable implementations of LLMs in Jax. However, Pax is a framework in which developers can inject their code or configuration. By contrast, MaxText is a reference implementation designed to be forked and edited as needed.

MaxText is also comparable to [Pax](https://github.com/google/paxml). Like Pax, MaxText provides high-performance and scalable implementations of LLMs in Jax. Pax focuses on enabling powerful configuration parameters, enabling developers to change the model by editing config parameters. By contrast, MaxText is a simple, concrete implementation of various LLMs that encourages users to extend by forking and directly editing the source code.

# Features and Diagnostics

Install the [Cloud TPU diagnostics](https://pypi.org/project/cloud-tpu-diagnostics) Python package to monitor, debug and profile jobs running on Cloud TPUs.

## Collect Stack Traces
When running a Single Program, Multiple Data (SPMD) job on accelerators, the overall process can hang if there is any error or any VM hangs/crashes for some reason. In this scenario, capturing stack traces will help to identify and troubleshoot the issues for the jobs running on TPU VMs.

The following configurations will help to debug a fault or when a program is stuck or hung somewhere by collecting stack traces. Change the parameter values accordingly in `MaxText/configs/base.yml`:
1. Set `collect_stack_trace: True` to enable collection of stack traces on faults or when the program is hung. This setting will periodically dump the traces for the program to help in debugging. To disable this, set `collect_stack_trace: False`.
2. Set `stack_trace_to_cloud: False` to display stack traces on console. `stack_trace_to_cloud: True` will create a temporary file in `/tmp/debugging` in the TPUs to store the stack traces. There is an agent running on TPU VMs that will periodically upload the traces from the temporary directory to cloud logging in the gcp project. You can view the traces in Logs Explorer on Cloud Logging using the following query:
```
logName="projects/<project_name>/logs/tpu.googleapis.com%2Fruntime_monitor"
jsonPayload.verb="stacktraceanalyzer"
```
3. `stack_trace_interval_seconds` signifies the duration in seconds between each stack trace collection event. Setting `stack_trace_interval_seconds: 600` will collect the stack traces every 600 seconds (10 minutes).
When running a Single Program, Multiple Data (SPMD) job on accelerators, the overall process can hang if any errors occur, or if a VM hangs or crashes. Capturing stack traces helps you identify and troubleshoot such issues.

The following configurations help you debug a fault or a hung program by collecting stack traces; change the parameter values accordingly in `MaxText/configs/base.yml` (a minimal example follows the list):

* Set `collect_stack_trace: True` to enable collection of stack traces on faults or when the program hangs. This setting periodically dumps stack traces. To disable this, set `collect_stack_trace: False`.
* Set `stack_trace_to_cloud: False` to display stack traces on the console. Or, set `stack_trace_to_cloud: True` to create a temporary file in `/tmp/debugging` in your TPU VMs to store stack traces. There is an agent running on TPU VMs that periodically uploads traces from the temporary directory to [Cloud Logging](https://cloud.google.com/logging/docs/overview). You can view the traces in [Logs Explorer](https://cloud.google.com/logging/docs/view/logs-explorer-interface) in Cloud Logging using the following query:

```none
logName="projects/<project_name>/logs/tpu.googleapis.com%2Fruntime_monitor"
jsonPayload.verb="stacktraceanalyzer"
```
* `stack_trace_interval_seconds` sets the duration in seconds between each stack trace collection event. For example, setting `stack_trace_interval_seconds: 600` will collect the stack traces every 600 seconds (10 minutes).
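
A minimal sketch of enabling these settings follows. It assumes the usual MaxText pattern of passing `key=value` overrides after the config file (as in the AOT examples later in this README); alternatively, set the same keys directly in `MaxText/configs/base.yml`. The run name and bucket paths are placeholders.

```bash
# Sketch only: enable periodic stack trace collection and upload to Cloud Logging.
# Assumes these keys can be overridden on the command line like the other
# config values in this README; otherwise edit MaxText/configs/base.yml.
python3 MaxText/train.py MaxText/configs/base.yml run_name=debug_stack_traces \
  collect_stack_trace=True stack_trace_to_cloud=True stack_trace_interval_seconds=600 \
  base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
```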

Here is the related PyPI package: https://pypi.org/project/cloud-tpu-diagnostics.
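
If you prefer the command line to the Logs Explorer UI, the same query can be issued with `gcloud` (a sketch; `<project_name>` is a placeholder for your Google Cloud project ID):

```bash
# Read recent stack trace entries uploaded by the TPU VM monitoring agent.
gcloud logging read \
  'logName="projects/<project_name>/logs/tpu.googleapis.com%2Fruntime_monitor" AND jsonPayload.verb="stacktraceanalyzer"' \
  --project=<project_name> --limit=20
```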
## Ahead of Time Compilation (AOT, TPU-only)

## Ahead of Time Compilation (AOT, tpu-only)
To compile your training run ahead of time, we provide a tool `train_compile.py`. This tool allows you to compile the main `train_step` in `train.py` for target hardware (e.g. a large number of v5e devices) without using the target hardware, and instead you may use only a CPU or a single VM from a different family. This compilation helps with two main goals:
To compile your training code ahead of time, we provide a tool called `train_compile.py`. This tool allows you to compile the main `train_step` in `train.py` for target hardware (for example, a large number of v5e devices) without using the target hardware; instead, you can use only a CPU or a single VM from a different family. AOT compilation helps with two main goals:

* It will flag any out of memory (OOM) information, such as when the `per_device_batch_size` is set too high, with an identical OOM stack trace as if it was compiled on the target hardware.
* It flags out-of-memory (OOM) issues. For example, if `per_device_batch_size` is set too high, the compiler generates an OOM stack trace identical to the one that would be generated on the target hardware.

* The ahead of time compilation can be saved and then loaded for fast startup and restart times on the target hardware.
* The compiled code can be saved and loaded for fast startup and restart times on the target hardware.

The tool `train_compile.py` is tightly linked to `train.py` and uses the same configuration file `configs/base.yml`. Although you don't need to run on a TPU, you do need to install `jax[tpu]` in addition to other dependencies, so we recommend running `setup.sh` to install these if you have not already done so.
The tool `train_compile.py` is tightly linked to `train.py` and uses the same configuration file `configs/base.yml`. Although you don't need to run on a TPU VM, you do need to install the `jax[tpu]` package. Run `setup.sh` to install this package and any dependencies.
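
For example, on a CPU-only machine (a sketch; `setup.sh` is the supported path, and the explicit `pip` command for the TPU-enabled JAX build is shown only as an illustration):

```bash
# Install MaxText dependencies, including jax[tpu], on a CPU-only machine.
bash setup.sh

# Roughly equivalent manual install of just the TPU-enabled JAX build
# (illustrative; prefer setup.sh, which also installs the other dependencies).
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```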

### Example AOT 1: Compile ahead of time basics
After installing the dependencies listed above, you are ready to compile ahead of time:
```
# Run the below on a single machine, e.g. a CPU

After installing the dependencies listed above, you are ready to compile your code:

```bash
# Run the below on a single machine, for example a CPU
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=v5e-256 compile_topology_num_slices=2 \
global_parameter_scale=16 per_device_batch_size=4
```

This will compile a 16B parameter MaxText model on 2 v5e pods.
This compiles a 16B parameter MaxText model on 2 v5e pods.

### Example AOT 2: Save compiled function, then load and run it
Here is an example that saves then loads the compiled `train_step`, starting with the save:

**Step 1: Run AOT and save compiled function**
```
# Run the below on a single machine, e.g. a CPU
The following example saves and then loads the compiled `train_step`.

**Step 1: Run AOT and save the compiled function**

```bash
# Run these commands on a single machine (CPU).
export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=v5e-256 \
compile_topology_num_slices=2 \
@@ -139,17 +163,27 @@ per_device_batch_size=4 steps=10000 learning_rate=1e-3

**Step 2: Run train.py and load the compiled function**

To load the compiled train_step, you just need to pass `compiled_trainstep_file=my_compiled_train.pickle` into `train.py`:
```
# Run the below on each host of the target hardware, e.g. each host on 2 slices of v5e-256
To load the compiled `train_step`, pass `compiled_trainstep_file=my_compiled_train.pickle`
into `train.py`:

```bash
# Run the following command on each host of the target hardware.
# In other words, run the command on each host of the 2 slices of v5e-256.
export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 MaxText/train.py MaxText/configs/base.yml run_name=example_load_compile \
compiled_trainstep_file=my_compiled_train.pickle \
global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3 \
base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
```

In the save step of example 2 above we included exporting the compiler flag `LIBTPU_INIT_ARGS` and `learning_rate` because those affect the compiled object `my_compiled_train.pickle.` The sizes of the model (e.g. `global_parameter_scale`, `max_sequence_length` and `per_device_batch`) are fixed when you initially compile via `compile_train.py`, you will see a size error if you try to run the saved compiled object with different sizes than you compiled with. However a subtle note is that the **learning rate schedule** is also fixed when you run `compile_train` - which is determined by both `steps` and `learning_rate`. The optimizer parameters such as `adam_b1` are passed only as shaped objects to the compiler - thus their real values are determined when you run `train.py`, not during the compilation. If you do pass in different shapes (e.g. `per_device_batch`), you will get a clear error message reporting that the compiled signature has different expected shapes than what was input. If you attempt to run on different hardware than the compilation targets requested via `compile_topology`, you will get an error saying there is a failure to map the devices from the compiled to your real devices. Using different XLA flags or a LIBTPU than what was compiled will probably run silently with the environment you compiled in without error. However there is no guaranteed behavior in this case; you should run in the same environment you compiled in.
The sizes of the model (for example, `global_parameter_scale`, `max_sequence_length`, and `per_device_batch`) are fixed when you compile ahead of time with `train_compile.py`. You must run a saved compiled `train_step` with the same model sizes it was compiled with; otherwise, a size error occurs. The **learning rate schedule** (which is determined by both `steps` and `learning_rate`) is also fixed when you run `train_compile.py`.

Optimizer parameters such as `adam_b1` are passed to the compiler only as shaped objects; their real values are determined at runtime. If you pass in different shapes (for example, a different `per_device_batch`) than you compiled with, you will get a shape error message.

If you attempt to run a compiled `train_step` on different hardware than the compilation target (specified via `compile_topology`), you will get an error saying the devices from the compiled object cannot be mapped to your real devices.

If you run AOT-compiled code with different XLA flags or a different LIBTPU version than you used during compilation, the code *may* run without error, but this is not guaranteed. Best practice is to run your code in the same environment you compiled in.
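
For example, the following hypothetical invocation fails with a shape error, because the per-device batch size differs from the value the pickled `train_step` was compiled with in Step 1:

```bash
# Hypothetical failure case: my_compiled_train.pickle was compiled with
# per_device_batch_size=4, so loading it with 8 triggers a shape mismatch error.
python3 MaxText/train.py MaxText/configs/base.yml run_name=example_shape_mismatch \
  compiled_trainstep_file=my_compiled_train.pickle \
  global_parameter_scale=16 per_device_batch_size=8 steps=10000 learning_rate=1e-3 \
  base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
```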

## Automatically Upload Logs to Vertex Tensorboard
MaxText supports automatic upload of logs collected in a directory to a Tensorboard instance in Vertex AI. Follow [user guide](getting_started/Use_Vertex_AI_Tensorboard.md) to know more.

MaxText can automatically upload logs generated while your code runs to a Tensorboard instance in Vertex AI. For more information, see the [user guide](getting_started/Use_Vertex_AI_Tensorboard.md).