add debug functionality for per chip sizes and bytes #625

Open · morgandu wants to merge 1 commit into main

Conversation

morgandu (Collaborator)

No description provided.

@patemotter (Collaborator) left a comment:

Overall LGTM, some minor comments.

_WARMUP_ITERS = 2


def debug_kv_cache(kv_cache):
  singler_kv_cache = kv_cache["cache"]["decoder"]["layers_0"]["self_attention"]["AttentionOp_0"]
Collaborator:

Nit: Is this supposed to be "single" or something else?

@morgandu (Collaborator, Author) commented on Apr 26, 2024:

Added "single_layer".

  singler_kv_cache = kv_cache["cache"]["decoder"]["layers_0"]["self_attention"]["AttentionOp_0"]
  for cache_key in singler_kv_cache.keys():
    cache_element = singler_kv_cache[cache_key]
    print(f"{cache_key}:")
Collaborator:

Nit: It would be helpful to also print the variable name. You can do this in f-strings by adding an =, like this: print(f"{cache_key=}").
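For reference, a minimal sketch of what the = specifier adds (the value here is illustrative):

    cache_key = "cached_ar_key"
    print(f"{cache_key}:")   # prints: cached_ar_key:
    print(f"{cache_key=}")   # prints: cache_key='cached_ar_key'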

morgandu (Collaborator, Author):

Done.

print(f"{cache_key}:")
if type(cache_element) == flax.linen.spmd.LogicallyPartitioned:
cache_element = cache_element.value
jax.debug.print(" shape: {shape}", shape=cache_element.shape)
Collaborator:

Nit: This is a dense series of lines; some whitespace would make it more readable.

A small thing you can take or leave, related to density: in jax.debug.print() you can skip naming the variable if you are only printing one, like this: jax.debug.print(" sharding: {}", cache_element.sharding).

morgandu (Collaborator, Author):

Done.

@@ -227,6 +255,8 @@ def main(config):
  vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)

  decode_state = engine.init_decode_state()
  debug_kv_cache(decode_state)
Collaborator:

Do we want to run this twice in the script?

morgandu (Collaborator, Author):

I can make this optional; I was also checking the decode_state, which was sharded correctly.
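For reference, a minimal sketch of one way to inspect how a pytree like decode_state is sharded, assuming its leaves are plain jax.Array values (names are illustrative):

    import jax

    # Map each array leaf to its sharding; printing the resulting pytree
    # shows which mesh axes each state entry is partitioned over.
    shardings = jax.tree_util.tree_map(lambda leaf: leaf.sharding, decode_state)
    print(shardings)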

@rwitten (Collaborator) left a comment:

Some nits.

print(f"{cache_key=}")
if isinstance(cache_element, flax.linen.spmd.LogicallyPartitioned):
cache_element = cache_element.value
jax.debug.print(" shape: {}", cache_element.shape)
Collaborator:

Nit: These really shouldn't be jax.debug.print calls because you aren't running them inside a jit; you can just use print.
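For context, a minimal sketch of the distinction (function names are illustrative): outside jit a plain print sees the concrete array, while inside a jit-compiled function a plain print only fires once at trace time with abstract tracers, so jax.debug.print is what shows runtime values there.

    import jax
    import jax.numpy as jnp

    def inspect_eagerly(x):
      # Not jitted: x is a concrete array, so a plain print is enough.
      print("shape:", x.shape, "max:", jnp.max(x))

    @jax.jit
    def inspect_inside_jit(x):
      # Jitted: a plain print here would only show a tracer at trace time,
      # so jax.debug.print is used to see values at run time.
      jax.debug.print("max: {}", jnp.max(x))
      return x * 2

    inspect_eagerly(jnp.arange(6.0))
    inspect_inside_jit(jnp.arange(6.0))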

morgandu (Collaborator, Author):

Done.

@@ -87,6 +87,26 @@ def summarize_size_from_pytree(params):
  return num_params, num_bytes, num_bytes / num_params


def calculate_total_params_across_chip(params):
Collaborator:

Sorry, what does this mean? I wonder if there is a clearer name (and possibly a docstring?).

morgandu (Collaborator, Author):

Added a docstring.
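The docstring itself isn't shown in this diff; based on the values the function returns below, a possible wording (hypothetical, not necessarily the text that was committed) would be:

    def calculate_total_params_across_chip(params):
      """Sums the physical (per-shard) parameter counts over all addressable chips.

      Returns the total size summed across every chip's shards, the average
      size held per chip, and the number of chips. A replicated array shows a
      physical total of num_chips times its logical size.
      """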

@@ -87,6 +87,26 @@ def summarize_size_from_pytree(params):
  return num_params, num_bytes, num_bytes / num_params


def calculate_total_params_across_chip(params):
  def calculate_sizes_per_chip(arr):
    return [np.prod(shard.data.shape) for shard in arr.addressable_shards]
Collaborator:

Could np.prod(shard.data.shape) be shard.data.size?
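A minimal sketch of the equivalence (the array here is illustrative):

    import jax.numpy as jnp
    import numpy as np

    arr = jnp.ones((1024, 32, 1, 128))
    for shard in arr.addressable_shards:
      # Both expressions give the element count of this chip's shard.
      assert np.prod(shard.data.shape) == shard.data.size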

morgandu (Collaborator, Author):

Done.

  sizes_across_chips = jax.tree_util.tree_map(calculate_sizes_per_chip, params)
  num_chips = len(sizes_across_chips)
  total_sizes_across_chips = jax.tree_util.tree_reduce(lambda x, y: x + y, sizes_across_chips)
  sizes_per_chip = total_sizes_across_chips / num_chips
Collaborator:

This is INCREDIBLY paranoid code: since we're SPMD, calculating any one chip is adequate.

Collaborator:

But no worries if you're paranoid!

@morgandu (Collaborator, Author) commented on May 7, 2024:

I took a pass and toned it down to a normal paranoid level. But let me share some context on why I had this incredibly paranoid code in the first place.

If you recall, a couple of weeks ago I mentioned some memory issues that were affecting the JetStream serving batch size. One of those issues came down to prefill_result initializing both a prefill cache and a generate cache. The prefill cache was properly sharded, but no sharding constraint was applied to the generate cache, so the generate cache ended up as a full copy on every TPU chip (see the sketch after the output below).

This was confirmed with the utils in this PR. For example, compare ar_key's physical sizes/bytes with prefill's below:

cached_ar_key:
        shape: (1024, 32, 1, 128)
        sharding: NamedSharding(mesh=Mesh('data': 1, 'fsdp': 1, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 8, 'autoregressive': 1), spec=PartitionSpec())
        total_logical_sizes: 4194304
        total_logical_bytes: 8388608
        n_chips: 8
        total_physical_sizes_across_chips: 33554432
        total_physical_bytes_across_chip: 67108864
cached_ar_value:
        ...... (same as cached_ar_key)
cached_prefill_key:
        shape: (1024, 32, 1, 128)
        sharding: NamedSharding(mesh=Mesh('data': 1, 'fsdp': 1, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 8, 'autoregressive': 1), spec=PartitionSpec(None, 'tensor'))
        total_logical_sizes: 4194304
        total_logical_bytes: 8388608
        n_chips: 8
        total_physical_sizes_across_chips: 4194304
        total_physical_bytes_across_chip: 8388608
cached_prefill_value:
        ...... (same as cached_prefill_key)
logits:
        shape: (1, 1, 32000)
        sharding: NamedSharding(mesh=Mesh('data': 1, 'fsdp': 1, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 8, 'autoregressive': 1), spec=PartitionSpec())
        total_logical_sizes: 32000
        total_logical_bytes: 128000
        n_chips: 8
        total_physical_sizes_across_chips: 256000
        total_physical_bytes_across_chip: 1024000
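As referenced above, a minimal sketch of the effect those numbers show (mesh, shapes, and names are illustrative, not the MaxText code): an array produced under jit without a sharding constraint can come back replicated on every chip, while one with an explicit constraint is partitioned, and summing the per-shard physical sizes exposes the difference.

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # Hypothetical 1-D mesh over all local devices, with a 'tensor' axis
    # mirroring the mesh in the output above.
    mesh = Mesh(np.array(jax.devices()), axis_names=("tensor",))

    @jax.jit
    def init_caches(x):
      unconstrained = x  # no sharding constraint: may end up replicated
      constrained = jax.lax.with_sharding_constraint(
          x, NamedSharding(mesh, PartitionSpec(None, "tensor")))
      return unconstrained, constrained

    unconstrained, constrained = init_caches(jnp.ones((1024, 32)))
    for name, arr in (("unconstrained", unconstrained), ("constrained", constrained)):
      physical = sum(shard.data.size for shard in arr.addressable_shards)
      print(f"{name}: logical_size={arr.size} physical_size_across_chips={physical}")

On an 8-chip slice like the one above, the unconstrained total would be expected to come out roughly 8x the logical size, matching the cached_ar_key numbers, while the constrained total matches the logical size, as for cached_prefill_key.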

  return total_sizes_across_chips, sizes_per_chip, num_chips


def calculate_total_bytes_across_chip(params):
Collaborator:

Some of the same feedback as above applies here.

@rwitten removed their assignment on May 9, 2024.