
[dynamo] Support tracing through issubclass with numpy array getattr arg #125942

Closed
williamwen42 opened this issue May 10, 2024 · 0 comments
Assignees: williamwen42
Labels: module: dynamo, oncall: pt2, triaged


williamwen42 commented May 10, 2024

Found when fixing #93624

Code:

```python
import numpy as np
import torch

def fn(x):
    if issubclass(x.__class__, np.ndarray):
        return 1
    return 0

opt_fn = torch.compile(fn, backend="eager")
opt_fn(np.ones([3, 3]))
```
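Run eagerly (without `torch.compile`), the same function works, and its results are what the compiled version is expected to reproduce. This is a minimal sketch: the extra inputs (a masked array, a list, a numpy scalar) are illustrative additions, not part of the original report.

```python
import numpy as np


def fn(x):
    # Same check as the repro: is the runtime class of x ndarray or a subclass?
    if issubclass(x.__class__, np.ndarray):
        return 1
    return 0


# Plain Python results, i.e. the behavior torch.compile should preserve.
results = {
    "ndarray": fn(np.ones([3, 3])),
    "masked_array": fn(np.ma.masked_array([1.0, 2.0, 3.0])),  # ndarray subclass
    "list": fn([1.0, 2.0, 3.0]),  # not a numpy type at all
    "numpy_scalar": fn(np.float64(1.0)),  # numpy scalar, not an ndarray
}
print(results)
```

Because the check uses `issubclass` rather than `x.__class__ is np.ndarray`, ndarray subclasses such as masked arrays also take the `return 1` branch.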

Error:

```
Traceback (most recent call last):
  File "/data/users/williamwen/pytorch2/playground3.py", line 10, in <module>
    opt_fn(np.ones([3, 3]))
  File "/data/users/williamwen/pytorch2/torch/_dynamo/eval_frame.py", line 414, in _fn
    return fn(*args, **kwargs)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 986, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 827, in _convert_frame
    result = inner_convert(
  File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 381, in _convert_frame_assert
    return _compile(
  File "/data/users/williamwen/pytorch2/torch/_utils_internal.py", line 70, in wrapper_function
    return function(*args, **kwargs)
  File "/data/users/williamwen/py310-env/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 737, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 708, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 543, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/bytecode_transformation.py", line 1167, in transform_code_object
    transformations(instructions, code_options)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 172, in _fn
    return fn(*args, **kwargs)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 490, in transform
    tracer.run()
  File "/data/users/williamwen/pytorch2/torch/_dynamo/symbolic_convert.py", line 2234, in run
    super().run()
  File "/data/users/williamwen/pytorch2/torch/_dynamo/symbolic_convert.py", line 884, in run
    while self.step():
  File "/data/users/williamwen/pytorch2/torch/_dynamo/symbolic_convert.py", line 799, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/symbolic_convert.py", line 494, in wrapper
    return inner_fn(self, inst)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/symbolic_convert.py", line 1253, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/data/users/williamwen/pytorch2/torch/_dynamo/symbolic_convert.py", line 737, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/data/users/williamwen/pytorch2/torch/_dynamo/variables/builtin.py", line 948, in call_function
    return handler(tx, args, kwargs)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/variables/builtin.py", line 823, in builtin_dipatch
    rv = handler(tx, args, kwargs)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/variables/builtin.py", line 750, in call_self_handler
    result = self_handler(tx, *args, **kwargs)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/variables/builtin.py", line 1404, in call_issubclass
    left_ty = left_ty.as_python_constant()
  File "/data/users/williamwen/pytorch2/torch/_dynamo/variables/base.py", line 203, in as_python_constant
    raise NotImplementedError(f"{self} is not a constant")
torch._dynamo.exc.InternalTorchDynamoError: GetAttrVariable(NumpyNdarrayVariable(), __class__) is not a constant

from user code:
   File "/data/users/williamwen/pytorch2/playground3.py", line 5, in fn
    if issubclass(x.__class__, np.ndarray):
```
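The last frames of the traceback show where tracing stops: `call_issubclass` calls `as_python_constant()` on its first argument, but `x.__class__` on a traced numpy array is still the symbolic `GetAttrVariable(NumpyNdarrayVariable(), __class__)`, which cannot be turned into a constant. The toy model below sketches that failure and one possible approach: fold `__class__` into a constant when the variable's Python type is known at trace time. All names here (`FakeNdarray`, `ArrayVariable`, `getattr_naive`, `getattr_fixed`, and the simplified `VariableTracker`, `ConstantVariable`, `GetAttrVariable`, `call_issubclass`) are hypothetical stand-ins that only loosely mirror the traceback; this is not Dynamo's actual code or the fix that was eventually merged.

```python
# Hypothetical toy model of the failure; these classes are NOT Dynamo's real API.


class FakeNdarray:
    """Stand-in for np.ndarray so the sketch needs only the standard library."""


class VariableTracker:
    def as_python_constant(self):
        # Default: a symbolic value has no known Python constant.
        raise NotImplementedError(f"{self!r} is not a constant")


class ConstantVariable(VariableTracker):
    def __init__(self, value):
        self.value = value

    def as_python_constant(self):
        return self.value


class ArrayVariable(VariableTracker):
    """A traced array: its contents are unknown at trace time, its type is known."""

    def python_type(self):
        return FakeNdarray

    def __repr__(self):
        return "ArrayVariable()"


class GetAttrVariable(VariableTracker):
    def __init__(self, obj, name):
        self.obj = obj
        self.name = name

    def __repr__(self):
        return f"GetAttrVariable({self.obj!r}, {self.name})"


def getattr_naive(obj, name):
    # Before: every attribute access on a traced array stays symbolic.
    return GetAttrVariable(obj, name)


def getattr_fixed(obj, name):
    # Possible approach: `__class__` of an array variable is its statically
    # known Python type, so it can be folded into a constant.
    if name == "__class__" and isinstance(obj, ArrayVariable):
        return ConstantVariable(obj.python_type())
    return GetAttrVariable(obj, name)


def call_issubclass(left, right):
    # Mirrors the failing step in the traceback: both args must be constants.
    return issubclass(left.as_python_constant(), right.as_python_constant())


x = ArrayVariable()
cls = ConstantVariable(FakeNdarray)

try:
    call_issubclass(getattr_naive(x, "__class__"), cls)
except NotImplementedError as e:
    print("naive:", e)

print("fixed:", call_issubclass(getattr_fixed(x, "__class__"), cls))
```

The sketch assumes the traced value's Python type is exactly known. If the value could be an `np.ndarray` subclass, folding `__class__` to the base type could give wrong answers for checks such as `x.__class__ is np.ndarray`, so a real implementation would need to account for that.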

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng

williamwen42 added the triaged, oncall: pt2, and module: dynamo labels May 10, 2024
williamwen42 self-assigned this May 10, 2024
williamwen42 changed the title from "[dynamo] Support tracing through numpy array getattr" to "[dynamo] Support tracing through issubclass with numpy array getattr arg" May 10, 2024
ZelboK pushed a commit to ZelboK/pytorch that referenced this issue May 19, 2024