
Commit

add debug for per chip sizes and bytes
morgandu committed Apr 26, 2024
1 parent 18ba1a7 commit e3c5fb3
Showing 4 changed files with 50 additions and 4 deletions.
6 changes: 4 additions & 2 deletions MaxText/generate_param_only_checkpoint.py
@@ -55,7 +55,8 @@ def new_pspec(x):
    return jax.sharding.PartitionSpec(*x[0 : config.param_scan_axis] + x[config.param_scan_axis + 1 :])

  new_per_layer_state_annotation = jax.tree_util.tree_map(new_pspec, training_state_annotations_layers)
  new_per_layer_state_sharding = jax.tree_util.tree_map(lambda x: jax.sharding.NamedSharding(mesh, x), new_per_layer_state_annotation)
  new_per_layer_state_sharding = jax.tree_util.tree_map(
      lambda x: jax.sharding.NamedSharding(mesh, x), new_per_layer_state_annotation)

  for i in range(config.num_decoder_layers):

@@ -90,7 +91,8 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh):
def _save_decode_checkpoint(config, state, checkpoint_manager):
  """Generate checkpoint for decode from the training_state."""
  with jax.spmd_mode("allow_all"):
    decode_state = max_utils.init_decode_state(None, jax.tree_util.tree_map(lambda x: x.astype(jax.numpy.bfloat16), state.params))
    decode_state = max_utils.init_decode_state(
        None, jax.tree_util.tree_map(lambda x: x.astype(jax.numpy.bfloat16), state.params))
  if checkpoint_manager is not None:
    if save_checkpoint(checkpoint_manager, 0, decode_state):
      max_logging.log(f"saved an decode checkpoint at {config.checkpoint_dir}")
25 changes: 24 additions & 1 deletion MaxText/inference_microbenchmark.py
@@ -17,6 +17,7 @@
"""Inference microbenchmark for prefill and autoregressive steps."""
import datetime
import jax
import flax
import json
import sys

@@ -27,10 +28,30 @@
import maxtext_utils
import pyconfig


_WARMUP_ITERS = 2


def debug_kv_cache(kv_cache):
  """Debug KV Cache sizing and sharding across chips."""
  singler_layer_kv_cache = kv_cache["cache"]["decoder"]["layers_0"]["self_attention"]["AttentionOp_0"]
  for cache_key in singler_layer_kv_cache.keys():
    cache_element = singler_layer_kv_cache[cache_key]
    print(f"{cache_key=}")
    if isinstance(cache_element, flax.linen.spmd.LogicallyPartitioned):
      cache_element = cache_element.value
    jax.debug.print(" shape: {}", cache_element.shape)
    jax.debug.print(" sharding: {}", cache_element.sharding)
    total_logical_sizes, total_logical_bytes, _ = max_utils.summarize_size_from_pytree(cache_element)
    total_sizes_across_chips, sizes_per_chip, num_chips = max_utils.calculate_total_params_across_chip(cache_element)
    total_bytes_across_chip, bytes_per_chip, num_chips = max_utils.calculate_total_bytes_across_chip(cache_element)
    jax.debug.print(" total_logical_sizes, total_logical_bytes: {x}, {y}",
                    x=total_logical_sizes, y=total_logical_bytes)
    jax.debug.print(" total_sizes_across_chips, sizes_per_chip, num_chips: {x}, {y}, {z}",
                    x=total_sizes_across_chips, y=sizes_per_chip, z=num_chips)
    jax.debug.print(" total_bytes_across_chip, bytes_per_chip, num_chips: {x}, {y}, {z}",
                    x=total_bytes_across_chip, y=bytes_per_chip, z=num_chips)


def prefill_benchmark_loop(engine, params, tokens, true_length, iters):
"""Inner loop for benchmarking prefill step."""
start = datetime.datetime.now()
@@ -198,6 +219,7 @@ def summarize_prefill_result(engine, params, tokens, true_length):
  print(f"Prefill result of length {tokens.size}:\n")
  prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
  jax.block_until_ready(prefill_result)
  debug_kv_cache(prefill_result)
  num_prefill_logits_params, total_prefill_logits_size, avg_prefill_logits_param_size = (
      max_utils.summarize_pytree_data(prefill_result["logits"], name="Prefill Logits", raw=True)
  )
@@ -227,6 +249,7 @@ def main(config):
  vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)

  decode_state = engine.init_decode_state()

  _, cache_size, _ = max_utils.summarize_pytree_data(decode_state["cache"], name="Cache")
  num_model_params, model_size, _ = max_utils.summarize_pytree_data(params, name="Model")

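A note on the new debug_kv_cache helper: MaxText can box KV-cache entries in flax.linen.spmd.LogicallyPartitioned, so the function unwraps .value before reading .shape and .sharding. The sketch below is illustrative only — the cache shape and logical axis names are hypothetical, not taken from the commit — and shows the same unwrapping pattern on a standalone array.

# Illustrative sketch: hypothetical KV entry shape and axis names; only the
# isinstance/.value unwrap mirrors what debug_kv_cache does.
import flax
import jax.numpy as jnp

boxed = flax.linen.spmd.LogicallyPartitioned(
    value=jnp.zeros((1, 1024, 8, 128), dtype=jnp.bfloat16),  # made-up cache entry shape
    names=("cache_batch", "cache_sequence", "cache_heads", "cache_kv"),  # made-up logical names
)

cache_element = boxed
if isinstance(cache_element, flax.linen.spmd.LogicallyPartitioned):
  cache_element = cache_element.value  # unwrap to the underlying jax.Array

print(cache_element.shape)     # (1, 1024, 8, 128)
print(cache_element.sharding)  # a single-device sharding for this unsharded example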
20 changes: 20 additions & 0 deletions MaxText/max_utils.py
@@ -87,6 +87,26 @@ def summarize_size_from_pytree(params):
  return num_params, num_bytes, num_bytes / num_params


def calculate_total_params_across_chip(params):
  def calculate_sizes_per_chip(arr):
    return [np.prod(shard.data.shape) for shard in arr.addressable_shards]
  sizes_across_chips = jax.tree_util.tree_map(calculate_sizes_per_chip, params)
  num_chips = len(sizes_across_chips)
  total_sizes_across_chips = jax.tree_util.tree_reduce(lambda x, y: x + y, sizes_across_chips)
  sizes_per_chip = total_sizes_across_chips / num_chips
  return total_sizes_across_chips, sizes_per_chip, num_chips


def calculate_total_bytes_across_chip(params):
  def calculate_bytes_across_chip(arr):
    return [shard.data.nbytes for shard in arr.addressable_shards]
  bytes_across_chips = jax.tree_util.tree_map(calculate_bytes_across_chip, params)
  num_chips = len(bytes_across_chips)
  total_bytes_across_chip = jax.tree_util.tree_reduce(lambda x, y: x + y, bytes_across_chips)
  bytes_per_chip = total_bytes_across_chip / num_chips
  return total_bytes_across_chip, bytes_per_chip, num_chips


def activate_profiler(config, optional_postfix=""):
if config.enable_profiler and (config.upload_all_profiler_results or jax.process_index() == 0):
output_path = os.path.join(config.tensorboard_dir, optional_postfix)
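The two new max_utils helpers measure physical placement rather than logical size: for each array they walk addressable_shards and add up what every chip actually holds, then divide by the shard count to get a per-chip figure. Below is a minimal sketch of that accounting on a single sharded array; it assumes at least two local devices and uses an arbitrary mesh axis name, so read it as an illustration of the idea rather than part of the commit.

# Minimal sketch (assumes >= 2 local devices; the "data" axis name is arbitrary).
# An (8, 128) float32 array sharded over two chips holds 2048 bytes per chip,
# i.e. half of its 4096-byte logical footprint.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices()[:2]).reshape(2)
mesh = Mesh(devices, ("data",))
x = jax.device_put(np.zeros((8, 128), np.float32), NamedSharding(mesh, PartitionSpec("data")))

logical_bytes = x.size * x.dtype.itemsize                        # 8 * 128 * 4 = 4096
per_shard_bytes = [s.data.nbytes for s in x.addressable_shards]  # [2048, 2048]
print(logical_bytes, sum(per_shard_bytes), sum(per_shard_bytes) / len(per_shard_bytes))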
3 changes: 2 additions & 1 deletion MaxText/maxengine.py
@@ -83,7 +83,8 @@ def load_params(self, *args, **kwargs) -> Params:
        lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), state.params
    )
    self.kv_cache_annotations = max_utils.get_kv_cache_annotations(self.model, self.config, self.rng, self._mesh)
    self.kv_cache_shardings = jax.tree_util.tree_map(lambda x: jax.sharding.NamedSharding(self._mesh, x), self.kv_cache_annotations)
    self.kv_cache_shardings = jax.tree_util.tree_map(
        lambda x: jax.sharding.NamedSharding(self._mesh, x), self.kv_cache_annotations)

    if not self.model.quant:
      self.abstract_params = jax.tree_util.tree_map(

