Fix typos #559

Open
wants to merge 3 commits into base: main

2 changes: 1 addition & 1 deletion MaxText/accelerator_to_spec_map.py
@@ -17,7 +17,7 @@
""" Static map of TPU names such as v4-8 to properties such as chip layout."""

""" !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-IF YOU MODIFY THIS FILE YOU SHOULD ALSO ADD CORRESPONDING MODICATIONS TO
+IF YOU MODIFY THIS FILE YOU SHOULD ALSO ADD CORRESPONDING MODIFICATIONS TO
UserFacingNameToSystemCharacteristics in xpk/xpk.py !!!!! """

from dataclasses import dataclass
6 changes: 3 additions & 3 deletions MaxText/configs/base.yml
@@ -83,7 +83,7 @@ num_experts_per_tok: 1
mlp_activations: ["silu", "linear"]
dropout_rate: 0
logits_via_embedding: False
-normalize_embedding_logits: True # whether to normlize pre-softmax logits if logits_via_embedding is true
+normalize_embedding_logits: True # whether to normalize pre-softmax logits if logits_via_embedding is true
logits_dot_in_fp32: True # whether to use fp32 in logits_dense or shared_embedding dot product for stability

# Choose 'remat_policy' between 'minimal', 'save_dot_except_mlpwi', 'save_dot_except_mlp', 'save_qkv_proj', 'qkv_proj_offloaded', 'minimal_offloaded' and 'full'.
@@ -142,7 +142,7 @@ dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: 1
dcn_fsdp_transpose_parallelism: 1
dcn_sequence_parallelism: 1 # never recommended
-dcn_tensor_parallelism: 1 # never recommeneded
+dcn_tensor_parallelism: 1 # never recommended
dcn_autoregressive_parallelism: 1 # never recommended
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
@@ -196,7 +196,7 @@ prefill_cache_dir: "" # If set and load_from_prefill_dir, decode.py reads from d
autoregressive_decode_assert: ""

enable_profiler: False
-# If set to true, upload all profiler xplane results from all hosts. Otherwise, only upload the xplane reuslt from the first host.
+# If set to true, upload all profiler xplane results from all hosts. Otherwise, only upload the xplane result from the first host.
upload_all_profiler_results: False
# Skip first n steps for profiling, to omit things like compilation and to give
# the iteration time a chance to stabilize.
2 changes: 1 addition & 1 deletion MaxText/generate_param_only_checkpoint.py
@@ -15,7 +15,7 @@
"""

# pylint: disable=g-bad-todo, abstract-method, consider-using-with, ungrouped-imports
"""Trasforms a "full state" including optimzer state to a bfloat16 "parameter state" without optimizer state.
"""Transforms a "full state" including optimizer state to a bfloat16 "parameter state" without optimizer state.
This typically used for turning a state output by training.py into a state than can be consumed by decode.py.

The input "fullstate" is passed in via:
6 changes: 3 additions & 3 deletions MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py
@@ -91,7 +91,7 @@ def reduce_concat_tokens(dataset,
):
"""Token-preprocessor to concatenate multiple unrelated documents.
If we want to generate examples of exactly the right length,
-(to avoid wasting space on padding), then we use this function, folowed by
+(to avoid wasting space on padding), then we use this function, followed by
split_tokens.
Args:
dataset: a tf.data.Dataset with dictionaries containing the key feature_key.
@@ -208,7 +208,7 @@ def get_datasets(
train_ds = rekey(train_ds, {'inputs': None, 'targets': 'text'})

eval_ds = eval_ds.shard(num_shards = jax.process_count(), index = jax.process_index())
-# note validation_tokenized_5662seqs split is pre tokenized, reduce_concated and splitted to target_length
+# note validation_tokenized_5662seqs split is pre tokenized, reduce_concated and split to target_length
# mainly to avoid eval sequences change depending on the number of hosts
eval_ds = rekey(eval_ds, {'inputs': None, 'targets': 'ids'})

@@ -229,7 +229,7 @@ def preprocess_dataset(config: ml_collections.ConfigDict,
train_ds = split_tokens_to_targets_length(train_ds, config.max_target_length)
train_ds = train_ds.shuffle(shuffle_buffer_size, seed=data_shuffle_seed)

-# note eval_ds is pre tokenized, reduce_concated and splitted to target_length
+# note eval_ds is pre tokenized, reduce_concated and split to target_length
# mainly to avoid eval sequences change depending on the number of hosts
train_ds = sequence_packing.pack_dataset(train_ds, config.max_target_length)
eval_ds = sequence_packing.pack_dataset(eval_ds, config.max_target_length)
2 changes: 1 addition & 1 deletion MaxText/input_pipeline/input_pipeline_interface.py
@@ -105,7 +105,7 @@ def __next__(self):

@staticmethod
def raw_generate_synthetic_data(config):
"""Generates a single batch of syntehtic data"""
"""Generates a single batch of synthetic data"""
output = {}
output['inputs'] = jax.numpy.zeros( (config.global_batch_size_to_load, config.max_target_length),
dtype=jax.numpy.int32)
2 changes: 1 addition & 1 deletion MaxText/layers/attentions.py
@@ -902,7 +902,7 @@ def __call__(self,
Projects the inputs into multi-headed query, key, and value vectors,
applies dot-product attention and project the results to an output vector.

-There are three modes: training, prefill and autoregression. During training, the KV cahce
+There are three modes: training, prefill and autoregression. During training, the KV cache
is ignored. During prefill, the cache is filled. During autoregression the cache is used.

In the cache initialization call, `inputs_q` has a shape [batch, length,
2 changes: 1 addition & 1 deletion MaxText/layers/linears.py
@@ -119,7 +119,7 @@ def compute_dot_general(inputs, kernel, axis, contract_ind):
kernel_out_axis = np.arange(len(axis), len(axis) + len(features))
if quantizations.in_serve_mode(self.quant):
# During aqt convert state we delete kernel weight from params to save memory.
-# Instead they are retreived from the tensors stored in the 'aqt' collection.
+# Instead they are retrieved from the tensors stored in the 'aqt' collection.
kernel = jnp.zeros(kernel_shape)
else:
kernel = self.param(
6 changes: 3 additions & 3 deletions MaxText/max_utils.py
@@ -484,15 +484,15 @@ def cross_entropy_with_logits(logits: jnp.ndarray, targets: jnp.ndarray,
logits: [batch, length, num_classes] float array.
targets: categorical one-hot targets [batch, length, num_classes] float
array.
-z_loss: coefficient for auxilliary z-loss loss term.
+z_loss: coefficient for auxiliary z-loss loss term.
Returns:
tuple with the total loss and the z_loss, both
float arrays with shape [batch, length].
"""
logits_sum = jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
log_softmax = logits - logits_sum
loss = -jnp.sum(targets * log_softmax, axis=-1)
-# Add auxilliary z-loss term.
+# Add auxiliary z-loss term.
log_z = jnp.squeeze(logits_sum, axis=-1)
total_z_loss = z_loss * jax.lax.square(log_z)
loss += total_z_loss
@@ -513,7 +513,7 @@ def _cross_entropy_with_logits_fwd(
sum_exp = jnp.sum(exp_shifted, axis=-1, keepdims=True)
log_softmax = shifted - jnp.log(sum_exp)
loss = -jnp.sum(targets * log_softmax, axis=-1)
-# Add auxilliary z-loss term.
+# Add auxiliary z-loss term.
log_z = jnp.squeeze(jnp.log(sum_exp) + max_logit, axis=-1)
total_z_loss = z_loss * jax.lax.square(log_z)
loss += total_z_loss
2 changes: 1 addition & 1 deletion MaxText/optimizers.py
@@ -58,7 +58,7 @@ def adam_pax(
) -> optax.GradientTransformation:
"""Standard Adam optimizer that supports weight decay.

-Follows the implemenation in pax/praxis sharded_adam
+Follows the implementation in pax/praxis sharded_adam
https://github.com/google/praxis/blob/545e00ab126b823265d70c715950d39333484f38/praxis/optimizers.py#L621

Args:
4 changes: 2 additions & 2 deletions MaxText/tests/gpt3_test.py
@@ -36,7 +36,7 @@


def init_random_model_vars(model, rng, example_batch):
"""initialze random model vars."""
"""initialize random model vars."""
model_vars = model.init(
{'params': rng, 'aqt': rng},
example_batch['inputs'],
@@ -89,7 +89,7 @@ def test_logits_numerically(self):
# ground truth values are calculated from paxml after loading above model_vars
# note we expect all xents are the same except the padding one since:
# paxml applies padding in mlp layer
-# while maxtext implementaiton applies padding in attention mask instead
+# while maxtext implementation applies padding in attention mask instead
# the two implementation are equivalent in valid non-padding tokens
per_example_xent_truth = jnp.array([[31.976467, 25.806253, 17.311134, 45.362663, 0.]], dtype=jnp.float32)
logits, _ = self.model.apply(self.model_vars,
2 changes: 1 addition & 1 deletion MaxText/tests/llama_test.py
@@ -52,7 +52,7 @@ def apply_rotary_emb(
freqs_cis: jnp.ndarray,
dtype: jnp.dtype = jnp.bfloat16,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
""" Apply the computed Rotary Postional Embedding"""
""" Apply the computed Rotary Positional Embedding"""
reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)

2 changes: 1 addition & 1 deletion MaxText/train.py
@@ -69,7 +69,7 @@ def validate_train_config(config):
max_logging.log("WARNING: 'dataset_path' might be pointing your local file system")
if not config.base_output_directory.startswith('gs://'):
max_logging.log("WARNING: 'base_output_directory' might be pointing your local file system")
-assert config.steps > 0, "You must set steps or learning_rate_schedule_steps to a positive interger."
+assert config.steps > 0, "You must set steps or learning_rate_schedule_steps to a positive integer."



2 changes: 1 addition & 1 deletion MaxText/train_compile.py
@@ -67,7 +67,7 @@ def get_topology_mesh(config):

def get_shaped_inputs(topology_mesh, config):
""" Get shaped abstractions of inputs to train_step: state, batch and rng """
-# Construct the model and optimizier to get shaped versions of the state
+# Construct the model and optimizer to get shaped versions of the state
quant = quantizations.configure_quantization(config)
model = Transformer(config, topology_mesh, quant=quant)
# The learning_rate_schedule is baked into the compiled object.
2 changes: 1 addition & 1 deletion getting_started/Run_Gemma.md
@@ -20,7 +20,7 @@ Gemma is a family of lightweight, state-of-the art open models built from resear

Following commands will let you download Gemma-2B model weights along with its tokenizer, convert the orbax checkpoints to be compatible with MaxText and upload it to a GCS bucket. \
Values for environment variables $KAGGLE_USERNAME and $KAGGLE_KEY can be set using your kaggle account's [API credentials](https://github.com/Kaggle/kaggle-api?tab=readme-ov-file#api-credentials). \
-Please use seperate GCS buckets for uploading model weights from kaggle ($MODEL_BUCKET) and MaxText compatible weights ($CHKPT_BUCKET).
+Please use separate GCS buckets for uploading model weights from kaggle ($MODEL_BUCKET) and MaxText compatible weights ($CHKPT_BUCKET).
```
wget https://www.kaggle.com/api/v1/models/google/gemma/maxtext/2b/1/download --user=$KAGGLE_USERNAME --password=$KAGGLE_KEY --auth-no-challenge
# Extract downloaded model
2 changes: 1 addition & 1 deletion maxtext_gpu_dependencies.Dockerfile
@@ -16,7 +16,7 @@ RUN apt-get update && apt-get install -y google-cloud-sdk
# Set environment variables for Google Cloud SDK
ENV PATH="/usr/local/google-cloud-sdk/bin:${PATH}"

-# Upgrade libcusprase to work with Jax
+# Upgrade libcusparse to work with Jax
RUN apt-get update && apt-get install -y libcusparse-12-3

ARG MODE