[torch.jit.trace] Inplace Index Put Silent Error #2188

Open
YifanShenSZ opened this issue Apr 4, 2024 · 0 comments
🐞Describing the bug

This simple model converts without error but gives an incorrect result:

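        import torch

        class IndexPutModel(torch.nn.Module):
            def forward(self, x, position, val):
                # Clone so the caller's tensor is untouched,
                # then write val into column(s) `position` in place
                y = x.clone()
                y[:, position] = val
                return y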

The issue is in how we parse the TorchScript graph:

graph(
    %x : Tensor(2, 2, "<class 'coremltools.converters.mil.mil.types.type_int.make_int.<locals>.int'>"),
    %position : Tensor(1, "<class 'coremltools.converters.mil.mil.types.type_int.make_int.<locals>.int'>"),
    %val : Tensor(1, "<class 'coremltools.converters.mil.mil.types.type_int.make_int.<locals>.int'>"),
):
  %4 = constant[]()
  %y = clone[](%x, %4)
  %6 = constant[value=0]()
  %7 = constant[value=0]()
  %8 = constant[value=9223372036854775807]()
  %9 = constant[value=1]()
  %10 = slice[](%y, %6, %7, %8, %9)
  %11 = listconstruct[]()
  %12 = view[](%val, %11)
  %13 = constant[]()
  %14 = listconstruct[](%13, %position)
  %15 = constant[value=False]()
  %16 = index_put_[](%10, %14, %12, %15)
return (%y)

We do not realize that %10 is a view of %y and that %16 is a reference to %10, so we terminate translation early once we see that the output %y has been created. The converted program therefore contains no index_put at all and simply returns %x:

main[CoreML5](%x: (2, 2, int32)(Tensor),
              %position: (1,int32)(Tensor),
              %val: (1,int32)(Tensor)) {
  block0() {
    %x_tmp: (2, 2, int32)(Tensor) = identity(x=%x, name="x_tmp")
    %position_tmp: (1,int32)(Tensor) = identity(x=%position, name="position_tmp")
    %val_tmp: (1,int32)(Tensor) = identity(x=%val, name="val_tmp")
  } -> (%x)
}
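
The aliasing chain described above (%16 references %10, which is a view of %y) can be sketched in plain Python. This is only an illustration under stated assumptions: the `GRAPH` dictionary, `VIEW_OPS`, `is_inplace`, `alias_root` and `mutated_values` are hypothetical names made up here, not coremltools internals, and the set of view ops is a small hand-picked sample.

```python
# Toy model of the graph from this issue: output name -> (op kind, input names).
# Illustrative data structure only, not coremltools' internal IR.
GRAPH = {
    "%y": ("clone", ["%x", "%4"]),
    "%10": ("slice", ["%y", "%6", "%7", "%8", "%9"]),
    "%12": ("view", ["%val", "%11"]),
    "%16": ("index_put_", ["%10", "%14", "%12", "%15"]),
}

# Assumption: ops whose output shares storage with their first input.
VIEW_OPS = {"slice", "view", "select", "transpose"}


def is_inplace(kind):
    # In-place aten ops carry a single trailing underscore, e.g. index_put_.
    return kind.endswith("_") and not kind.endswith("__")


def alias_root(name, graph=GRAPH):
    """Follow view / in-place edges back to the value that owns the storage."""
    while name in graph:
        kind, inputs = graph[name]
        if kind in VIEW_OPS or is_inplace(kind):
            name = inputs[0]  # output aliases the first input
        else:
            break  # e.g. clone allocates fresh storage
    return name


def mutated_values(graph=GRAPH):
    """Storage-owning values that some in-place op writes into."""
    return {
        alias_root(out, graph)
        for out, (kind, _) in graph.items()
        if is_inplace(kind)
    }


print(alias_root("%16"))   # %y
print(mutated_values())    # {'%y'}
```

With this bookkeeping, the write performed by `index_put_` is attributed to %y, which is exactly the dependency the early termination misses.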

Potential Solution

We probably won't be able to fully fix this, since the solution is equivalent to torch's own functionalization 😂 But we should find some way to error out, rather than silently producing a wrong Core ML model.
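
One possible shape for such an "error out" guard, as a rough sketch rather than an actual coremltools change: after translation stops, scan the nodes that were never visited and raise if any of them is an in-place op. The function `check_no_skipped_inplace_ops` and the error class are hypothetical names, and the check assumes in-place ops can be recognized by a single trailing underscore in the op kind.

```python
class UntranslatedInplaceOpError(RuntimeError):
    """Hypothetical error type for in-place ops the converter never reached."""


def check_no_skipped_inplace_ops(node_kinds, num_translated):
    """Raise if an in-place op lies beyond the point where translation stopped.

    node_kinds: op kinds in graph order, e.g. ["clone", "slice", "index_put_"].
    num_translated: how many leading nodes were translated before stopping.
    """
    skipped = [
        kind
        for kind in node_kinds[num_translated:]
        # Assumption: a single trailing underscore marks an in-place op.
        if kind.endswith("_") and not kind.endswith("__")
    ]
    if skipped:
        raise UntranslatedInplaceOpError(
            "in-place ops were skipped; the converted model may be wrong: "
            + ", ".join(skipped)
        )


# Op kinds of the TorchScript graph in this issue, in order.
kinds = [
    "constant", "clone", "constant", "constant", "constant", "constant",
    "slice", "listconstruct", "view", "constant", "listconstruct",
    "constant", "index_put_",
]

try:
    # Assume translation stopped right after clone produced %y.
    check_no_skipped_inplace_ops(kinds, num_translated=2)
except UntranslatedInplaceOpError as err:
    print(err)
```

This check is deliberately coarse: it would also reject graphs where the skipped in-place op does not affect any output, so it trades false positives for never silently dropping a mutation.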

@YifanShenSZ YifanShenSZ added bug Unexpected behaviour that should be corrected (type) PyTorch (traced) labels Apr 4, 2024
@YifanShenSZ YifanShenSZ self-assigned this Apr 4, 2024