
unsqueeze_copy doesn't seem to respect dynamic shapes #125853

Closed
angelayi opened this issue May 9, 2024 · 4 comments
Labels
module: dynamic shapes, oncall: pt2, triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments


angelayi commented May 9, 2024

🐛 Describe the bug

It seems like aten.unsqueeze_copy specializes on the shape:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import (
    ShapeEnv, DimDynamic, StatelessSymbolicContext
)

shape_env = ShapeEnv()
t1 = torch.ones(2, 2, 768)
with FakeTensorMode(shape_env=shape_env) as fake_mode:
    t = fake_mode.from_tensor(
        t1,
        symbolic_context=StatelessSymbolicContext(
            dynamic_sizes=[DimDynamic.DYNAMIC, DimDynamic.STATIC, DimDynamic.STATIC],
        )
    )
print(t)  # FakeTensor(..., size=(s0, 2, 768))
print(torch.ops.aten.unsqueeze(t, 1))  # FakeTensor(..., size=(s0, 1, 2, 768))
print(torch.ops.aten.unsqueeze_copy(t, 1))  # FakeTensor(..., size=(2, 1, 2, 768))

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @tarun292 @larryliu0820

Versions

main

@malfet added the labels triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) and oncall: pt2 on May 9, 2024

malfet commented May 9, 2024

@ezyang I've asked this question in the past, but I guess I never got a reply: should dynamic-shapes issues go into oncall: pt2 or not?


ezyang commented May 10, 2024

yes please


ezyang commented May 10, 2024

This typically means that unsqueeze_copy is forcing a specialization. This can be confirmed by running with TORCH_LOGS=dynamic. https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.fh8zzonyw8ng describes how to resolve it.

avikchaudhuri commented

At the risk of speculating, maybe due to this line?

new_stride = 1 if dim >= tensor.dim() else result_sizes[dim] * result_strides[dim]

@angelayi do you have the logs handy?
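To make the suspicion above concrete: if a meta kernel eagerly converts a symbolic size to a plain Python int while computing strides, the tracer has to pin that symbol to its concrete hint, which is exactly a specialization. A toy, self-contained sketch (this `ToySymInt` is hypothetical and not PyTorch's actual `SymInt` class) of that mechanism:

```python
# Toy model (NOT the real PyTorch implementation) of how forcing a
# concrete integer out of a symbolic size "specializes" it: once the
# symbol's int value is demanded, it is pinned to its hint and the
# dimension can no longer vary.

class ToySymInt:
    """Hypothetical stand-in for a symbolic integer with a concrete hint."""

    def __init__(self, name, hint):
        self.name = name
        self.hint = hint
        self.specialized = False

    def __int__(self):
        # Demanding a concrete value pins the symbol to its hint,
        # roughly what a guard/specialization does during tracing.
        self.specialized = True
        return self.hint

s0 = ToySymInt("s0", 2)       # the dynamic batch dim, hint 2
stride = int(s0) * 768        # eager stride math forces a concrete int
print(stride, s0.specialized) # 1536 True — s0 is now specialized
```

If the stride expression instead kept `result_sizes[dim]` symbolic (never calling `int()` on it), the symbol would survive and the output size would stay `(s0, 1, 2, 768)` as it does for plain `unsqueeze`.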
