Improve reasoning for size oblivious equations involving max() #125914

Open
ezyang opened this issue May 10, 2024 · 3 comments
Labels: high priority, module: dynamic shapes, oncall: pt2, triaged


@ezyang
Contributor

ezyang commented May 10, 2024

🐛 Describe the bug

Internal xref: https://fb.workplace.com/groups/6829516587176185/posts/7172183699576137/

This program fails to compile:

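```python
import torch
import torch._dynamo
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

# Assumed setting: lets x.tolist() produce unbacked SymInts (u0, u1) instead of a graph break.
torch._dynamo.config.capture_scalar_outputs = True

def run_test(fn):
    # Assumed stand-in for the test-harness decorator: run the test immediately.
    fn()
    return fn

@torch.compile(fullgraph=True, backend="eager")
def cf(x):
    u0, u1 = x.tolist()
    torch._check_is_size(u0)
    torch._check_is_size(u1)
    torch._check(u0 + u1 == 20)
    if guard_size_oblivious(torch.sym_max(1, u0 + u1) == 20):
        return torch.tensor(True)
    else:
        return torch.tensor(False)

@run_test
def test_symmax():
    assert cf(torch.tensor([10, 10])).item()
```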

Actually, we should be able to make the inference here: u0 and u1 are size-like, so in a size-oblivious test we assume they are >= 2. That gives u0 + u1 >= 4 > 1, so Max(1, u0 + u1) should evaporate to u0 + u1 and the guard reduces to u0 + u1 == 20, which was already checked. Currently, we are unable to do this.
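To make the expected reasoning concrete, here is a toy sketch of the value-range argument. It is not PyTorch's implementation: `Range`, `SIZE_OBLIVIOUS`, and `simplify_max` are hypothetical names, and the sketch assumes, as described above, that a size-like unbacked symbol is treated as ranging over [2, inf) in a size-oblivious test.

```python
import math
from dataclasses import dataclass


# Hypothetical, simplified value-range model (not PyTorch's ValueRanges class).
@dataclass(frozen=True)
class Range:
    lo: float
    hi: float

    def __add__(self, other):
        # Interval addition: bounds add component-wise.
        return Range(self.lo + other.lo, self.hi + other.hi)


# Assumption: in a size-oblivious test, a size-like unbacked SymInt is treated
# as >= 2, so the 0/1 special cases are ignored.
SIZE_OBLIVIOUS = Range(2, math.inf)
# Ordinary bounds for a size-like symbol outside size-oblivious reasoning.
SIZE_LIKE = Range(0, math.inf)


def simplify_max(const, expr_range):
    """Decide which argument Max(const, expr) reduces to, or None if undecidable."""
    if expr_range.lo >= const:
        return "expr"   # Max(const, expr) == expr
    if expr_range.hi <= const:
        return "const"  # Max(const, expr) == const
    return None


u0 = u1 = SIZE_OBLIVIOUS
print(simplify_max(1, u0 + u1))  # expr: u0 + u1 ranges over [4, inf)
print(simplify_max(1, SIZE_LIKE + SIZE_LIKE))  # None: u0 + u1 could be 0
```

With the ordinary [0, inf) range the Max cannot be decided, which is why the size-oblivious assumption is what lets it evaporate.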

cc @gchanan @zou3519 @kadeng @msaroufim @bdhirsh @anijain2305 @chauhang @lezcano

Versions

main

@ezyang
Contributor Author

ezyang commented May 10, 2024

cc @shazqadeer

@ezyang
Contributor Author

ezyang commented May 10, 2024

The harder version of this is when 20 is replaced with a (typically backed) symbolic variable s0.

@lezcano
Collaborator

lezcano commented May 10, 2024

I would argue that, since you are literally checking u0+u1 == 20 in the line before, we should be able to pattern match this and use it. At the very least, we should pattern match facts of the form expression == constant.
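A toy sketch of the pattern match suggested above, under stated assumptions: a fact of the form expression == constant, recorded by a check, is used to rewrite later expressions before they are evaluated. Expressions here are plain nested tuples, and `check_eq`, `substitute`, and `fold` are hypothetical helpers, not ShapeEnv APIs; a real implementation would presumably work on canonicalized sympy expressions so that u0 + u1 and u1 + u0 match.

```python
# Hypothetical toy model, not PyTorch's ShapeEnv.
# An expression is an int, a symbol name (str), or a tuple (op, *args)
# with op in {"add", "max"}.


def fold(op, args):
    """Evaluate an operator whose arguments are all integer constants."""
    if op == "add":
        return sum(args)
    if op == "max":
        return max(args)
    raise ValueError(f"unknown op: {op}")


def check_eq(facts, expr, const):
    """Record `expr == const`, a loose analogue of torch._check(expr == const)."""
    facts[expr] = const


def substitute(expr, facts):
    """Rewrite subexpressions that have a recorded `expr == constant` fact.

    Matching is purely syntactic: ("add", "u0", "u1") does not match
    ("add", "u1", "u0").
    """
    if isinstance(expr, tuple):
        if expr in facts:
            return facts[expr]
        op, *args = expr
        args = [substitute(a, facts) for a in args]
        if all(isinstance(a, int) for a in args):
            return fold(op, args)
        return (op, *args)
    # Symbols and constants: replace a symbol only if it has a recorded value.
    return facts.get(expr, expr)


facts = {}
total = ("add", "u0", "u1")
check_eq(facts, total, 20)  # like torch._check(u0 + u1 == 20)
print(substitute(("max", 1, total), facts))  # 20, so Max(1, u0 + u1) == 20 holds
```

Without the recorded fact, `substitute` returns the Max expression unchanged, so the guard stays undecided.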

@bdhirsh added the triaged label and removed the triage review label on May 14, 2024