
[Inductor] Flex attention supports dynamic shape #125994

Closed
wants to merge 6 commits

Conversation


@yanboliang (Contributor) commented on May 11, 2024:

static shapes perf

| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod   | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|-------------|----------------|
| Average |     0.692 |              |             |             |             |            |             |                |
| Max     |     0.855 |           16 |          16 |        4096 |        4096 |         64 | head_bias   | torch.bfloat16 |
| Min     |     0.419 |            8 |          16 |         512 |         512 |        256 | noop        | torch.bfloat16 |

dynamic shapes perf

| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod     | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|---------------|----------------|
| Average |     0.670 |              |             |             |             |            |               |                |
| Max     |     0.864 |           16 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |
| Min     |     0.376 |            8 |          16 |         512 |         512 |        256 | relative_bias | torch.bfloat16 |
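
To make concrete what these rows measure, here is a minimal sketch of compiling flex attention with dynamic shapes. It is illustrative only: the `torch.nn.attention.flex_attention.flex_attention` import path reflects the later public API rather than the private one this PR targeted, and `relative_bias` is a simplified stand-in for the benchmarked score_mod.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention  # later public API


def relative_bias(score, b, h, q_idx, kv_idx):
    # Bias each attention score by the relative distance between positions.
    return score + (q_idx - kv_idx)


# dynamic=True requests symbolic shapes up front, which is what the
# "dynamic shapes perf" rows above exercise.
compiled_flex = torch.compile(flex_attention, dynamic=True)

B, H, D = 8, 16, 64
for seq_len in (512, 1024):
    q, k, v = (
        torch.randn(B, H, seq_len, D, device="cuda", dtype=torch.bfloat16)
        for _ in range(3)
    )
    out = compiled_flex(q, k, v, score_mod=relative_bias)
    print(seq_len, tuple(out.shape))
```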

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang


pytorch-bot bot commented May 11, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125994

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit 5014543 with merge base d7fe3c4:

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@yanboliang marked this pull request as ready for review on May 11, 2024 05:51
@yanboliang added the topic: not user facing label on May 11, 2024
@Chillee (Contributor) left a comment:


Should run some benchmarks too.

Review threads (resolved):
- test/inductor/test_flex_attention.py (outdated)
- test/inductor/test_flex_attention.py
- torch/_inductor/kernel/flex_attention.py
@yanboliang (Contributor, Author) commented:

> Should run some benchmarks too.

Yea, benchmarking is on the way.

@yanboliang requested a review from Chillee on May 14, 2024 17:10
@yanboliang added the ciflow/trunk label (trigger trunk jobs on your pull request) on May 14, 2024
```
@@ -98,7 +99,7 @@ def generate_inputs(
    return query, key, value


def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
```
A reviewer (Contributor) asked:

Above in this file is

`torch._dynamo.config.automatic_dynamic_shapes = False`

Does compile ignore this if `dynamic=True`?

@yanboliang (Contributor, Author) replied:

Yes, `dynamic=True` forces dynamic shapes.
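
A small sketch of the distinction, not taken from the PR: `automatic_dynamic_shapes` only controls whether Dynamo promotes sizes to symbolic after it observes a recompile for a new shape, while `dynamic=True` requests symbolic shapes from the first compile, so the benchmark's config setting does not suppress dynamic compilation.

```python
import torch

# Disable the "make it dynamic after the first recompile" heuristic,
# as the benchmark script does.
torch._dynamo.config.automatic_dynamic_shapes = False


def f(x):
    return x.sin() + x.size(0)


# dynamic=True asks for symbolic shapes up front, so both calls below can
# share one dynamic-shape graph despite the config flag above.
compiled = torch.compile(f, dynamic=True)
print(compiled(torch.randn(4)).shape)
print(compiled(torch.randn(8)).shape)
```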

```
@@ -126,6 +126,19 @@ def score_mod(score, b, h, m, n):


class TestTemplatedSDPA(InductorTestCase):
    def _check_equal(self, golden_out, ref_out, compiled_out, dtype):
```
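
The helper body is not shown in this excerpt. A plausible sketch of such a check, an assumption rather than the PR's actual code, compares the compiled output's error against the eager reference's error, both measured relative to a higher-precision golden run:

```python
import torch


def check_equal_sketch(golden_out, ref_out, compiled_out, fudge_factor=1.1):
    # Measure both errors against the float64 "golden" output.
    golden = golden_out.to(torch.float64)
    compiled_error = (golden - compiled_out.to(torch.float64)).abs().max()
    ref_error = (golden - ref_out.to(torch.float64)).abs().max()
    # The compiled kernel may be at most slightly less accurate than the
    # eager reference run in the same dtype.
    assert compiled_error <= fudge_factor * ref_error, (
        f"compiled error {compiled_error:.3e} exceeds "
        f"{fudge_factor}x reference error {ref_error:.3e}"
    )
```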

```
@@ -617,3 +617,7 @@ def is_from_defaults(source: Source):
    if isinstance(source, ChainedSource):
        return is_from_defaults(source.base)
    return False


def is_cell_contents(source: Source):
```
A reviewer (Contributor) asked:

What is this doing, out of curiosity?

@yanboliang (Contributor, Author) replied:

This is part of the heuristic rules that determine whether we should wrap an int as a SymInt. Here we are saying that if the value comes from a cell closure, we do not make it dynamic, since cell closures are usually constant. We define these heuristics based on the value's Source.
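
For illustration, a minimal sketch of what such a source-based heuristic might look like, following the `is_from_defaults` pattern from the diff above; the `AttrSource` check and the `cell_contents` member name are assumptions, since the function body is not visible in this excerpt:

```python
from torch._dynamo.source import AttrSource, ChainedSource, Source


def is_cell_contents_sketch(source: Source) -> bool:
    # Hypothetical: a closure-cell read is modeled as an attribute access
    # named "cell_contents" on the cell object.
    if isinstance(source, AttrSource) and source.member == "cell_contents":
        return True
    # Unwrap chained sources the same way is_from_defaults does.
    if isinstance(source, ChainedSource):
        return is_cell_contents_sketch(source.base)
    return False
```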

@yanboliang (Contributor, Author) commented:

@pytorchbot merge -f "No space left on device"

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort; instead, consider -i/--ignore-current to continue the merge while ignoring current failures, which allows currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@yanboliang deleted the flex-dyn branch on May 15, 2024 04:43
ZelboK pushed a commit to ZelboK/pytorch that referenced this pull request May 19, 2024

Pull Request resolved: pytorch#125994
Approved by: https://github.com/Chillee