
scalar_tensor call with symbolic bool input does not work in inductor #125956

Open
ezyang opened this issue May 10, 2024 · 1 comment
Assignees
Labels
high priority oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@ezyang
Contributor

ezyang commented May 10, 2024

🐛 Describe the bug

Internal xref: https://fb.workplace.com/groups/469587837192818/posts/1638909336927323/

import torch

@torch.compile(fullgraph=True, dynamic=True)
def f(x):
    return torch.tensor(x.size(0) // 100 == 20)

f(torch.randn(8))

fails with

  File "/data/users/ezyang/b/pytorch/torch/_inductor/compile_fx.py", line 784, in fx_codegen_and_compile
    graph.run(*example_inputs)
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/ezyang/b/pytorch/torch/_inductor/graph.py", line 730, in run
    return super().run(*args)
  File "/data/users/ezyang/b/pytorch/torch/fx/interpreter.py", line 145, in run
    self.env[node] = self.run_node(node)
  File "/data/users/ezyang/b/pytorch/torch/_inductor/graph.py", line 1199, in run_node
    result = super().run_node(n)
  File "/data/users/ezyang/b/pytorch/torch/fx/interpreter.py", line 202, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/data/users/ezyang/b/pytorch/torch/_inductor/graph.py", line 977, in call_function
    raise LoweringException(e, target, args, kwargs).with_traceback(
  File "/data/users/ezyang/b/pytorch/torch/_inductor/graph.py", line 974, in call_function
    out = lowerings[target](*args, **kwargs)
  File "/data/users/ezyang/b/pytorch/torch/_inductor/lowering.py", line 304, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "/data/users/ezyang/b/pytorch/torch/_inductor/lowering.py", line 2299, in tensor
    elif len(data) == 0 or isinstance(data[0], (float, int)) and len(data) <= 8:
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: TypeError: object of type 'Equality' has no len()
  target: aten.scalar_tensor.default
  args[0]: Eq((s0//100), 20)
  kwargs: {'dtype': torch.bool, 'device': device(type='cpu'), 'pin_memory': False}

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
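The `TypeError` comes from the `len(data)` call in the lowering: a sympy `Equality` is a `Boolean`, not an `Expr`, so it misses the symbolic branch and falls through to the list-handling code. A minimal sympy-only sketch of the type mismatch (mirroring the `Eq((s0//100), 20)` value shown in `args[0]` above):

```python
import sympy

s0 = sympy.Symbol("s0", integer=True)
# Mirrors args[0] from the traceback: Eq((s0//100), 20)
expr = sympy.Eq(s0 // 100, 20)

# An Equality is a Boolean relational, not an Expr, so it fails the
# isinstance(data, sympy.Expr) check in the lowering...
print(isinstance(expr, sympy.Expr))                   # False
print(isinstance(expr, sympy.logic.boolalg.Boolean))  # True

# ...and it is not sized, so len(data) raises the observed error.
try:
    len(expr)
    no_len = False
except TypeError as err:
    no_len = True
    print(err)  # object of type 'Equality' has no len()
```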

Versions

main

cc @gchanan @zou3519 @kadeng @msaroufim @bdhirsh @anijain2305 @chauhang

@ezyang
Contributor Author

ezyang commented May 10, 2024

This is a little cursed. So we have

diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py
index 0e83b235edb..1b651fa10cb 100644
--- a/test/inductor/test_torchinductor_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_dynamic_shapes.py
@@ -227,6 +227,13 @@ class TestInductorDynamic(TestCase):
 
         f(torch.tensor([1, 0, 1, 1, 0, 1, 0]), torch.randn(4))
 
+    def test_scalar_tensor_bool(self, device):
+        @torch.compile(fullgraph=True, dynamic=True)
+        def f(x):
+            return torch.tensor(x.size(0) // 100 == 20, device=x.device)
+
+        f(torch.randn(8, device=device))
+
     @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_item_nobreak(self, device):
         @torch.compile(fullgraph=True)
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index a07fb4d8b1b..80679282fe9 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -2286,7 +2286,7 @@ def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False):
 
     ranges: List[sympy.Expr] = []
 
-    if isinstance(data, sympy.Expr):
+    if isinstance(data, (sympy.Expr, sympy.logic.boolalg.Boolean)):
 
         def inner_fn(index):
             return ops.index_expr(data, dtype)

as the obvious first step. But this fails with

  File "/data/users/ezyang/b/pytorch/torch/_inductor/scheduler.py", line 1290, in __init__
    self.nodes = [self.create_scheduler_node(n) for n in nodes]
  File "/data/users/ezyang/b/pytorch/torch/_inductor/scheduler.py", line 1290, in <listcomp>
    self.nodes = [self.create_scheduler_node(n) for n in nodes]
  File "/data/users/ezyang/b/pytorch/torch/_inductor/scheduler.py", line 1382, in create_scheduler_node
    return SchedulerNode(self, node)
  File "/data/users/ezyang/b/pytorch/torch/_inductor/scheduler.py", line 698, in __init__
    self._compute_attrs()
  File "/data/users/ezyang/b/pytorch/torch/_inductor/scheduler.py", line 705, in _compute_attrs
    self._sizes, self._body = self.node.simplify_and_reorder(
  File "/data/users/ezyang/b/pytorch/torch/_inductor/ir.py", line 3440, in simplify_and_reorder
    iter_ranges, iter_reindex, iter_reordering_reindex = simplify_and_reorder(
  File "/data/users/ezyang/b/pytorch/torch/_inductor/ir.py", line 3427, in simplify_and_reorder
    sizes, reindex2, prune = V.graph.sizevars._simplify_loops(
  File "/data/users/ezyang/b/pytorch/torch/_inductor/sizevars.py", line 104, in simplify_loops
    result = self._simplify_loops_impl(index_vars, sizes, index_formulas)
  File "/data/users/ezyang/b/pytorch/torch/_inductor/sizevars.py", line 195, in _simplify_loops_impl
    strides = [self.stride_vars(x, index_vars) for x in index_formulas]
  File "/data/users/ezyang/b/pytorch/torch/_inductor/sizevars.py", line 195, in <listcomp>
    strides = [self.stride_vars(x, index_vars) for x in index_formulas]
  File "/data/users/ezyang/b/pytorch/torch/_inductor/sizevars.py", line 475, in stride_vars
    return cache(index, tuple(vars), tuple(support_vars))
  File "/data/users/ezyang/b/pytorch/torch/_inductor/sizevars.py", line 461, in wrapper
    return fn_cache(*args, **kwargs)
  File "/data/users/ezyang/b/pytorch/torch/_inductor/sizevars.py", line 495, in _stride_vars
    index = index - sympy_subs(
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
TypeError: unsupported operand type(s) for -: 'Equality' and 'Equality'

The underlying problem is, roughly, that we don't know how to pass boolean arguments into kernels. I can think of some ways to fix this, but I am curious whether people have thoughts.
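The second traceback is the same type mismatch one level down: `_stride_vars` computes `index - sympy_subs(index, ...)`, and arithmetic like subtraction is only defined on `Expr`, not on `Boolean` relationals such as `Equality`. A sympy-only sketch of that failure (assuming the index formula is the `Eq` from the original repro):

```python
import sympy

s0 = sympy.Symbol("s0", integer=True)
cond = sympy.Eq(s0 // 100, 20)

# _stride_vars effectively evaluates `index - sympy_subs(index, ...)`;
# Boolean relationals do not implement __sub__, hence the TypeError.
try:
    cond - cond
    sub_fails = False
except TypeError as err:
    sub_fails = True
    print(err)  # unsupported operand type(s) for -: 'Equality' and 'Equality'
```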

@bdhirsh bdhirsh added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module and removed triage review labels May 14, 2024