
Conversion for Padding related to dynamic input shape failed #1921

Open · xorange opened this issue Jul 25, 2023 · 4 comments · May be fixed by #2050
Labels
bug (Unexpected behaviour that should be corrected), PyTorch (traced), triaged (Reviewed and examined, release has been assigned if applicable)

Comments

xorange commented Jul 25, 2023

🐞Describing the bug

If a model contains a padding op whose pad amounts depend on a dynamic input shape, the conversion fails.

Stack Trace

scikit-learn version 1.2.2 is not supported. Minimum required version: 0.17. Maximum required version: 1.1.2. Disabling scikit-learn conversion API.
Torch version 2.0.1 has not been tested with coremltools. You may run into unexpected errors. Torch 2.0.0 is the most recent version that has been tested.
Testing with coremltools version: 7.0b1 ...
Converting PyTorch Frontend ==> MIL Ops:  91%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████              | 10/11 [00:00<00:00, 6644.97 ops/s]
Traceback (most recent call last):
  File "/Users/oliverxu/Workspace/coremltools_issues/v7.0b_torch_dynamicshape.py", line 24, in <module>
    converted = ct.convert(
  File "/Users/oliverxu/Workspace/coremltools/coremltools/converters/_converters_entry.py", line 530, in convert
    mlmodel = mil_convert(
  File "/Users/oliverxu/Workspace/coremltools/coremltools/converters/mil/converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
  File "/Users/oliverxu/Workspace/coremltools/coremltools/converters/mil/converter.py", line 212, in _mil_convert
    proto, mil_program = mil_convert_to_proto(
  File "/Users/oliverxu/Workspace/coremltools/coremltools/converters/mil/converter.py", line 286, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File "/Users/oliverxu/Workspace/coremltools/coremltools/converters/mil/converter.py", line 108, in __call__
    return load(*args, **kwargs)
  File "/Users/oliverxu/Workspace/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 63, in load
    return _perform_torch_convert(converter, debug)
  File "/Users/oliverxu/Workspace/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 102, in _perform_torch_convert
    prog = converter.convert()
  File "/Users/oliverxu/Workspace/coremltools/coremltools/converters/mil/frontend/torch/converter.py", line 439, in convert
    convert_nodes(self.context, self.graph)
  File "/Users/oliverxu/Workspace/coremltools/coremltools/converters/mil/frontend/torch/ops.py", line 92, in convert_nodes
    add_op(context, node)
  File "/Users/oliverxu/Workspace/coremltools/coremltools/converters/mil/frontend/torch/ops.py", line 1569, in pad
    if pad.val is not None:
AttributeError: 'list' object has no attribute 'val'

To Reproduce

import torch
import torch.nn as nn
import torch.nn.functional as F

import coremltools as ct

class TryPad(nn.Module):
    def forward(self, x):
        pad_length = x.size(2)
        y = F.pad(x, [0, 0, pad_length, pad_length])
        return y
torch_model = TryPad().eval()

length = 10
x = torch.Tensor([[1] * length * 192]).reshape(1, 192, length)

y = torch_model(x)

print('Testing with coremltools version:', ct.__version__, '...')

traced_model = torch.jit.trace(torch_model, x)
converted = ct.convert(
    traced_model,
    inputs=[ct.TensorType(
        name='x',
        shape=ct.Shape(shape=(1, 192,
            # 10,  # static shape: conversion succeeds
            ct.RangeDim(lower_bound=2, upper_bound=1024, default=length)  # dynamic shape: conversion fails
    )))],
    outputs=[ct.TensorType(name='y')],
    # convert_to="mlprogram",  # fails the same way with or without this
)

Throwing:
  File "/Users/oliverxu/Workspace/coremltools/coremltools/converters/mil/frontend/torch/ops.py", line 1569, in pad
    if pad.val is not None:
AttributeError: 'list' object has no attribute 'val'
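
For what it's worth (my own reading of the traced graph, not of the converter source): under torch.jit.trace, x.size(2) is recorded as an aten::size node and the pad amounts reach the pad op through a prim::ListConstruct rather than as constants, which would explain why pad arrives as a list with no .val. This can be seen from the repro above:

# Inspect the traced graph (torch_model and x are defined in the repro above).
# The pad amounts should appear as aten::size / prim::ListConstruct nodes feeding
# the pad op instead of baked-in constants (node names may vary by torch version).
print(traced_model.graph)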

My locally installed coremltools has no local modifications:

(base) 7.0b1clean ~/Workspace/coremltools
 $ git status
On branch 7.0b1
nothing to commit, working tree clean
(base) 7.0b1clean ~/Workspace/coremltools
 $ git log -1
commit b5ba7e1d19310dafc10fb26a2e8ef525214b91ff (HEAD -> 7.0b1, tag: 7.0b1)
Author: Toby Roseman <troseman@apple.com>
Date:   Mon Jun 5 15:11:04 2023 -0700

    7.0b1 release (#1874)

System environment (please complete the following information):

  • coremltools version: 7.0b1
  • OS: macOS Ventura 13.4.1 (22F2083)
  • torch: 2.0.1

Additional context

A more meaningful motivation for this usage comes from Self-Attention with Relative Position Representations:

  • there is a predefined window_size in the network, regardless of the input
  • the input shape is variable (to support variable-length input text)
  • during inference, the gap between the variable input length and the predefined window_size is filled with padding and slicing, so that a tensor of the expected length is constructed for each position in attention, and then the matmul is applied (see the sketch below)
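
A minimal sketch of that pattern, with an assumed window_size and relative-embedding table (illustrative only, not the actual model from the paper):

import torch
import torch.nn as nn
import torch.nn.functional as F

class RelativeWindow(nn.Module):
    # Illustrative only: align a relative-position table of fixed window_size
    # to the runtime sequence length via padding or slicing.
    def __init__(self, window_size: int, dim: int):
        super().__init__()
        self.window_size = window_size
        # one row per relative offset in [-window_size, window_size]
        self.rel_emb = nn.Parameter(torch.randn(2 * window_size + 1, dim))

    def forward(self, x):
        # x: (batch, length, dim); length varies at inference time
        length = x.size(1)
        gap = length - 1 - self.window_size
        if gap > 0:
            # input longer than the window: pad the table on both ends
            table = F.pad(self.rel_emb, [0, 0, gap, gap])
        else:
            # input shorter than (or equal to) the window: slice away unused offsets
            table = self.rel_emb[-gap : 2 * self.window_size + 1 + gap]
        # table now has 2 * length - 1 rows, one per reachable relative offset
        return torch.matmul(x, table.transpose(0, 1))

The pad amount here is derived from x.size(1), which is exactly the kind of dynamic pad value that trips the converter above.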

Also, I'm not sure whether this is a bug or expected behavior (i.e., a known unsupported case).

xorange added the bug label on Jul 25, 2023
YifanShenSZ added the triaged label on Jul 27, 2023
@chophilip21

Has this issue been fixed? It seems like the error remains.

@hadiidbouk

I am facing the same issue here:

#1991 (comment)

@harimohanraj

I'm also running into this issue. Same thing: I have a padding scheme that relies directly on the input shape, and during conversion the pad op does not seem to know how to propagate the dynamic shape.

xorange (Author) commented Nov 9, 2023

I've opened a fixing PR based on 7.0b2. Until it is reviewed and merged, I recommend it as a temporary workaround.

Test snippet, which uses two different dynamic shape dimensions:

import torch
import torch.nn as nn
import torch.nn.functional as F

import coremltools as ct

def forward_coreml(model, x):
    spec = model.get_spec()
    input_name = spec.description.input[0].name
    output_name = spec.description.output[0].name

    input_dict = { input_name : x }
    coreml_out = model.predict(input_dict)
    cy = coreml_out[ output_name ]
    return cy

class TryPad(nn.Module):
    def forward(self, x):
        y = F.pad(x, [0, 0, x.size(1), x.size(2)])
        return y
torch_model = TryPad().eval()

x = torch.Tensor([[1] * 3 * 4]).reshape(1, 4, 3)

y = torch_model(x)

print('Testing with coremltools version:', ct.__version__, '...')

traced_model = torch.jit.trace(torch_model, x)
converted = ct.convert(
    traced_model,
    inputs=[ct.TensorType(
        name='x',
        shape=ct.Shape(shape=(1,
            ct.RangeDim(lower_bound=3, upper_bound=24, default=4),    # matches x.size(1)
            ct.RangeDim(lower_bound=2, upper_bound=1024, default=3)   # matches x.size(2)
    )))],
    outputs=[ct.TensorType(name='y')],
    # convert_to="mlprogram",  # same result with or without this
)

print('y', y)
print('cy', forward_coreml(converted, x))

Output:

 $ python v7.0b2_torch_dynamicshape.py 
scikit-learn version 1.2.2 is not supported. Minimum required version: 0.17. Maximum required version: 1.1.2. Disabling scikit-learn conversion API.
Torch version 2.0.1 has not been tested with coremltools. You may run into unexpected errors. Torch 2.0.0 is the most recent version that has been tested.
Testing with coremltools version: 7.0b2 ...
When both 'convert_to' and 'minimum_deployment_target' not specified, 'convert_to' is set to "mlprogram" and 'minimum_deployment_targer' is set to ct.target.iOS15 (which is same as ct.target.macOS12). Note: the model will not run on systems older than iOS15/macOS12/watchOS8/tvOS15. In order to make your model run on older system, please set the 'minimum_deployment_target' to iOS14/iOS13. Details please see the link: https://coremltools.readme.io/docs/unified-conversion-api#target-conversion-formats
Converting PyTorch Frontend ==> MIL Ops:  93%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████           | 13/14 [00:00<00:00, 3205.71 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 10126.28 passes/s]
Running MIL default pipeline: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 66/66 [00:00<00:00, 4587.21 passes/s]
Running MIL backend_mlprogram pipeline: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 15409.93 passes/s]
y tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]])
cy [[[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]]
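
For a stricter check than eyeballing the printed tensors, the two outputs can also be compared numerically (a small, assumed addition to the snippet above; the tolerances are a guess and may need loosening for fp16 mlprogram precision):

import numpy as np

# Element-wise comparison between the PyTorch output and the Core ML prediction.
np.testing.assert_allclose(y.numpy(), forward_coreml(converted, x), rtol=1e-3, atol=1e-4)
print('PyTorch and Core ML outputs match within tolerance.')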
