Skip to content

Commit

Permalink
add debug for cache sharding and size across chips
Browse files Browse the repository at this point in the history
  • Loading branch information
morgandu committed Apr 26, 2024
1 parent 18ba1a7 commit 4bdd2c5
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions MaxText/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,26 @@ def summarize_size_from_pytree(params):
return num_params, num_bytes, num_bytes / num_params


def calculate_total_params_across_chip(params):
def calculate_sizes_per_chip(arr):
return [np.prod(shard.data.shape) for shard in arr.addressable_shards]
sizes_across_chips = jax.tree_util.tree_map(calculate_sizes_per_chip, params)
num_chips = len(sizes_across_chips)
total_sizes_across_chips = jax.tree_util.tree_reduce(lambda x, y: x + y, sizes_across_chips)
sizes_per_chip = total_sizes_across_chips / num_chips
return total_sizes_across_chips, sizes_per_chip, num_chips


def calculate_total_bytes_across_chip(params):
def calculate_bytes_across_chip(arr):
return [shard.data.nbytes for shard in arr.addressable_shards]
bytes_across_chips = jax.tree_util.tree_map(calculate_bytes_across_chip, params)
num_chips = len(bytes_across_chips)
total_bytes_across_chip = jax.tree_util.tree_reduce(lambda x, y: x + y, bytes_across_chips)
bytes_per_chip = total_bytes_across_chip / num_chips
return total_bytes_across_chip, bytes_per_chip, num_chips


def activate_profiler(config, optional_postfix=""):
if config.enable_profiler and (config.upload_all_profiler_results or jax.process_index() == 0):
output_path = os.path.join(config.tensorboard_dir, optional_postfix)
Expand Down

0 comments on commit 4bdd2c5

Please sign in to comment.