Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[inductor] fix unbacked case in pointwise + reduction vertical fusion #125982

Closed
wants to merge 7 commits into from

Conversation

ColinPeppler
Copy link
Contributor

@ColinPeppler ColinPeppler commented May 11, 2024

$ INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 python test/inductor/test_unbacked_symints.py -k test_vertical_pointwise_reduction_fusion

  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 1953, in fuse_nodes_once
    for node1, node2 in self.get_possible_fusions():
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2010, in get_possible_fusions
    check_all_pairs(node_grouping)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 1997, in check_all_pairs
    if self.can_fuse(node1, node2):
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2252, in can_fuse
    return self.get_backend(device).can_fuse_vertical(node1, node2)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/cuda_combined_scheduling.py", line 39, in can_fuse_vertical
    return self._triton_scheduling.can_fuse_vertical(node1, node2)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3237, in can_fuse
    if not all(
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3238, in <genexpr>
    TritonKernel.is_compatible((numel2, rnumel2), n.get_ranges())
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 1543, in is_compatible
    cls._split_iteration_ranges(groups, lengths)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 1507, in _split_iteration_ranges
    while current_group < len(remaining) and sv.size_hint(remaining[current_group]) == 1:
  File "/data/users/colinpeppler/pytorch/torch/_inductor/sizevars.py", line 442, in size_hint
    return int(out)
  File "/home/colinpeppler/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/core/expr.py", line 320, in __int__
    raise TypeError("Cannot convert symbols to int")
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
TypeError: Cannot convert symbols to int

Where the unbacked symints show up at.

> /data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py(1506)_split_iteration_ranges()
(Pdb) print(groups)
(1, 512*u0)
(Pdb) print(lengths)
([u0, 32, 16], [])

Stack from ghstack (oldest at bottom):

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @desertfire @chauhang

Differential Revision: D57527131

Copy link

pytorch-bot bot commented May 11, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125982

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (6 Unrelated Failures)

As of commit 82af1bd with merge base a0429c0 (image):

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ColinPeppler added a commit that referenced this pull request May 11, 2024
ghstack-source-id: 8f1f5bf3c637a50bbbf96739945cbde956e52337
Pull Request resolved: #125982
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 amjames desertfire chauhang

[ghstack-poisoned]
ColinPeppler added a commit that referenced this pull request May 15, 2024
ghstack-source-id: c7b7fe8d7bfd2b5840ac3e7552c68af512c80756
Pull Request resolved: #125982
@ColinPeppler ColinPeppler changed the title [inductor] fix edge case in pw + reduction fusion [inductor] fix edge case in pw + reduction fusion w/ unbacked symints May 16, 2024
@ColinPeppler ColinPeppler changed the title [inductor] fix edge case in pw + reduction fusion w/ unbacked symints [inductor] fix unbacked case in pointwise + reduction vertical fusion May 16, 2024
…ical fusion"

```
$ INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 python test/inductor/test_unbacked_symints.py -k test_vertical_pointwise_reduction_fusion

  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 1953, in fuse_nodes_once
    for node1, node2 in self.get_possible_fusions():
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2010, in get_possible_fusions
    check_all_pairs(node_grouping)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 1997, in check_all_pairs
    if self.can_fuse(node1, node2):
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2252, in can_fuse
    return self.get_backend(device).can_fuse_vertical(node1, node2)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/cuda_combined_scheduling.py", line 39, in can_fuse_vertical
    return self._triton_scheduling.can_fuse_vertical(node1, node2)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3237, in can_fuse
    if not all(
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3238, in <genexpr>
    TritonKernel.is_compatible((numel2, rnumel2), n.get_ranges())
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 1543, in is_compatible
    cls._split_iteration_ranges(groups, lengths)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 1507, in _split_iteration_ranges
    while current_group < len(remaining) and sv.size_hint(remaining[current_group]) == 1:
  File "/data/users/colinpeppler/pytorch/torch/_inductor/sizevars.py", line 442, in size_hint
    return int(out)
  File "/home/colinpeppler/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/core/expr.py", line 320, in __int__
    raise TypeError("Cannot convert symbols to int")
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
TypeError: Cannot convert symbols to int
```





cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 amjames desertfire chauhang

[ghstack-poisoned]
ColinPeppler added a commit that referenced this pull request May 16, 2024
ghstack-source-id: d3310234ded31530db7c87188d0ac41c39720ccb
Pull Request resolved: #125982
…ical fusion"

```
$ INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 python test/inductor/test_unbacked_symints.py -k test_vertical_pointwise_reduction_fusion

  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 1953, in fuse_nodes_once
    for node1, node2 in self.get_possible_fusions():
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2010, in get_possible_fusions
    check_all_pairs(node_grouping)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 1997, in check_all_pairs
    if self.can_fuse(node1, node2):
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2252, in can_fuse
    return self.get_backend(device).can_fuse_vertical(node1, node2)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/cuda_combined_scheduling.py", line 39, in can_fuse_vertical
    return self._triton_scheduling.can_fuse_vertical(node1, node2)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3237, in can_fuse
    if not all(
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3238, in <genexpr>
    TritonKernel.is_compatible((numel2, rnumel2), n.get_ranges())
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 1543, in is_compatible
    cls._split_iteration_ranges(groups, lengths)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 1507, in _split_iteration_ranges
    while current_group < len(remaining) and sv.size_hint(remaining[current_group]) == 1:
  File "/data/users/colinpeppler/pytorch/torch/_inductor/sizevars.py", line 442, in size_hint
    return int(out)
  File "/home/colinpeppler/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/core/expr.py", line 320, in __int__
    raise TypeError("Cannot convert symbols to int")
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
TypeError: Cannot convert symbols to int
```





cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 amjames desertfire chauhang

[ghstack-poisoned]
ColinPeppler added a commit that referenced this pull request May 16, 2024
ghstack-source-id: 9440993388f901618d8ac6759e845caf32b5928d
Pull Request resolved: #125982
…ical fusion"

```
$ INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 python test/inductor/test_unbacked_symints.py -k test_vertical_pointwise_reduction_fusion

  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 1953, in fuse_nodes_once
    for node1, node2 in self.get_possible_fusions():
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2010, in get_possible_fusions
    check_all_pairs(node_grouping)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 1997, in check_all_pairs
    if self.can_fuse(node1, node2):
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2252, in can_fuse
    return self.get_backend(device).can_fuse_vertical(node1, node2)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/cuda_combined_scheduling.py", line 39, in can_fuse_vertical
    return self._triton_scheduling.can_fuse_vertical(node1, node2)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3237, in can_fuse
    if not all(
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3238, in <genexpr>
    TritonKernel.is_compatible((numel2, rnumel2), n.get_ranges())
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 1543, in is_compatible
    cls._split_iteration_ranges(groups, lengths)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 1507, in _split_iteration_ranges
    while current_group < len(remaining) and sv.size_hint(remaining[current_group]) == 1:
  File "/data/users/colinpeppler/pytorch/torch/_inductor/sizevars.py", line 442, in size_hint
    return int(out)
  File "/home/colinpeppler/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/core/expr.py", line 320, in __int__
    raise TypeError("Cannot convert symbols to int")
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
TypeError: Cannot convert symbols to int
```

Where the unbacked symints show up at. 
```
> /data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py(1506)_split_iteration_ranges()
(Pdb) print(groups)
(1, 512*u0)
(Pdb) print(lengths)
([u0, 32, 16], [])
```





cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 amjames desertfire chauhang

[ghstack-poisoned]
ColinPeppler added a commit that referenced this pull request May 16, 2024
ghstack-source-id: dcc801c96afed521190ebf1cbb9d953f2a81b168
Pull Request resolved: #125982
…ical fusion"

```
$ INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 python test/inductor/test_unbacked_symints.py -k test_vertical_pointwise_reduction_fusion

  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 1953, in fuse_nodes_once
    for node1, node2 in self.get_possible_fusions():
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2010, in get_possible_fusions
    check_all_pairs(node_grouping)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 1997, in check_all_pairs
    if self.can_fuse(node1, node2):
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2252, in can_fuse
    return self.get_backend(device).can_fuse_vertical(node1, node2)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/cuda_combined_scheduling.py", line 39, in can_fuse_vertical
    return self._triton_scheduling.can_fuse_vertical(node1, node2)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3237, in can_fuse
    if not all(
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3238, in <genexpr>
    TritonKernel.is_compatible((numel2, rnumel2), n.get_ranges())
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 1543, in is_compatible
    cls._split_iteration_ranges(groups, lengths)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 1507, in _split_iteration_ranges
    while current_group < len(remaining) and sv.size_hint(remaining[current_group]) == 1:
  File "/data/users/colinpeppler/pytorch/torch/_inductor/sizevars.py", line 442, in size_hint
    return int(out)
  File "/home/colinpeppler/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/core/expr.py", line 320, in __int__
    raise TypeError("Cannot convert symbols to int")
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
TypeError: Cannot convert symbols to int
```

Where the unbacked symints show up at. 
```
> /data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py(1506)_split_iteration_ranges()
(Pdb) print(groups)
(1, 512*u0)
(Pdb) print(lengths)
([u0, 32, 16], [])
```





cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 amjames desertfire chauhang

[ghstack-poisoned]
ColinPeppler added a commit that referenced this pull request May 16, 2024
ghstack-source-id: 15681bf019170af1e33f406b097a90174fef8475
Pull Request resolved: #125982
@@ -310,6 +310,14 @@ def statically_known_leq(self, left: Expr, right: Expr) -> bool:
expr = left <= right
return self.is_expr_static_and_true(expr)

# See Note - [On Statically Known]
def statically_known_geq(self, left: Expr, right: Expr) -> bool:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

left right should probably be typed as Expr | int

@ColinPeppler ColinPeppler marked this pull request as ready for review May 16, 2024 18:11
@ColinPeppler
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 16, 2024
@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team Raised by workflow job

@ColinPeppler
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

…ical fusion"

```
$ INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 python test/inductor/test_unbacked_symints.py -k test_vertical_pointwise_reduction_fusion

  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 1953, in fuse_nodes_once
    for node1, node2 in self.get_possible_fusions():
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2010, in get_possible_fusions
    check_all_pairs(node_grouping)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 1997, in check_all_pairs
    if self.can_fuse(node1, node2):
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2252, in can_fuse
    return self.get_backend(device).can_fuse_vertical(node1, node2)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/cuda_combined_scheduling.py", line 39, in can_fuse_vertical
    return self._triton_scheduling.can_fuse_vertical(node1, node2)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3237, in can_fuse
    if not all(
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3238, in <genexpr>
    TritonKernel.is_compatible((numel2, rnumel2), n.get_ranges())
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 1543, in is_compatible
    cls._split_iteration_ranges(groups, lengths)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 1507, in _split_iteration_ranges
    while current_group < len(remaining) and sv.size_hint(remaining[current_group]) == 1:
  File "/data/users/colinpeppler/pytorch/torch/_inductor/sizevars.py", line 442, in size_hint
    return int(out)
  File "/home/colinpeppler/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/core/expr.py", line 320, in __int__
    raise TypeError("Cannot convert symbols to int")
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
TypeError: Cannot convert symbols to int
```

Where the unbacked symints show up at. 
```
> /data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py(1506)_split_iteration_ranges()
(Pdb) print(groups)
(1, 512*u0)
(Pdb) print(lengths)
([u0, 32, 16], [])
```





cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 amjames desertfire chauhang

[ghstack-poisoned]
ColinPeppler added a commit that referenced this pull request May 16, 2024
ghstack-source-id: 0fed4f22b87ac7c3d5e80c8c29c0127ed46dad0a
Pull Request resolved: #125982
@ColinPeppler
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@ColinPeppler
Copy link
Contributor Author

@ColinPeppler has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

ZelboK pushed a commit to ZelboK/pytorch that referenced this pull request May 19, 2024
…pytorch#125982)

```
$ INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 python test/inductor/test_unbacked_symints.py -k test_vertical_pointwise_reduction_fusion

  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 1953, in fuse_nodes_once
    for node1, node2 in self.get_possible_fusions():
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2010, in get_possible_fusions
    check_all_pairs(node_grouping)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 1997, in check_all_pairs
    if self.can_fuse(node1, node2):
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2252, in can_fuse
    return self.get_backend(device).can_fuse_vertical(node1, node2)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/cuda_combined_scheduling.py", line 39, in can_fuse_vertical
    return self._triton_scheduling.can_fuse_vertical(node1, node2)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3237, in can_fuse
    if not all(
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3238, in <genexpr>
    TritonKernel.is_compatible((numel2, rnumel2), n.get_ranges())
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 1543, in is_compatible
    cls._split_iteration_ranges(groups, lengths)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 1507, in _split_iteration_ranges
    while current_group < len(remaining) and sv.size_hint(remaining[current_group]) == 1:
  File "/data/users/colinpeppler/pytorch/torch/_inductor/sizevars.py", line 442, in size_hint
    return int(out)
  File "/home/colinpeppler/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/core/expr.py", line 320, in __int__
    raise TypeError("Cannot convert symbols to int")
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
TypeError: Cannot convert symbols to int
```

Where the unbacked symints show up at.
```
> /data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py(1506)_split_iteration_ranges()
(Pdb) print(groups)
(1, 512*u0)
(Pdb) print(lengths)
([u0, 32, 16], [])
```

Pull Request resolved: pytorch#125982
Approved by: https://github.com/jansel
@ColinPeppler
Copy link
Contributor Author

@ColinPeppler has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

1 similar comment
@ColinPeppler
Copy link
Contributor Author

@ColinPeppler has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

pytorchmergebot pushed a commit that referenced this pull request May 20, 2024
…TemplateBuffer (#126622)

# Context
Here's a peripheral scenario causing the JIT-pass and AOT-pass to pick different fusions.
```py
# JIT -- buf3 is a MultiTemplateBuffer
V.graph.buffers = [buf0, buf1, buf2, buf3, buf4]
                                ^          ^
# JIT pass calls finalize_multi_template_buffers()
V.graph.buffers = [buf0, buf1, buf2, buf4, *buf3*]

# AOT, note proximity_score(buf2, buf4) is "better" for fusion than JIT
V.graph.buffers = [buf0, buf1, buf2, buf4, *buf3*]
                                ^    ^
```

It happens like this:
* JIT starts with the original set nodes using V.graph.buffers
* In JIT, finalize_multi_template_buffers() is called which can change the order of the buffers.
* This makes the order of buffers/scheduler nodes different.
* Now, each node's min/max-order is different than before.
* As a result, the proximity between two nodes is different. https://github.com/pytorch/pytorch/blob/ad67553c5c1672d65b810acd7a6a01e11695098b/torch/_inductor/scheduler.py#L2316-L2335

# Error
```
$ TORCH_LOGS="+fusion" python test/inductor/test_max_autotune.py -k test_jit_fusion_matches_aot_fusion
======================================================================
FAIL: test_jit_fusion_matches_aot_fusion (__main__.TestMaxAutotune)
----------------------------------------------------------------------
Traceback (most recent call last):
  ...
  File "/data/users/colinpeppler/pytorch/torch/_inductor/graph.py", line 1718, in compile_to_fn
    code, linemap = self.codegen_with_cpp_wrapper()
  File "/data/users/colinpeppler/pytorch/torch/_inductor/graph.py", line 1618, in codegen_with_cpp_wrapper
    return self.codegen()
  File "/data/users/colinpeppler/pytorch/torch/_inductor/graph.py", line 1636, in codegen
    self.scheduler.codegen()
  File "/data/users/colinpeppler/pytorch/torch/_dynamo/utils.py", line 210, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/scheduler.py", line 2602, in codegen
    self.get_backend(device).codegen_node(node)  # type: ignore[possibly-undefined]
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/cuda_combined_scheduling.py", line 66, in codegen_node
    return self._triton_scheduling.codegen_node(node)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3377, in codegen_node
    return self.codegen_node_schedule(node_schedule, buf_accesses, numel, rnumel)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3602, in codegen_node_schedule
    final_kernel.call_kernel(final_kernel.kernel_name)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 3055, in call_kernel
    grid = wrapper.generate_default_grid(name, grid)
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/cpp_wrapper_cuda.py", line 174, in generate_default_grid
    params is not None
AssertionError: cuda kernel parameters for triton_poi_fused_add_0 should already exist at this moment, only found dict_keys(['Placeholder.DESCRIPTIVE_NAME', 'triton_poi_fused_add_mul_0', 'triton_poi_fused_pow_1'])
```

Pull Request resolved: #126622
Approved by: https://github.com/chenyang78
ghstack dependencies: #125982
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants