AssertionError: Item selection is supported only on python list/tuple objects #2176

Open
yash-codiste opened this issue Mar 26, 2024 · 6 comments
Labels
bug Unexpected behaviour that should be corrected (type) PyTorch (not traced)

Comments

@yash-codiste

Hello Developers,

We are trying to convert PyTorch models to Core ML using coremltools. During conversion we used torch.jit.trace to trace the model and got a warning that tracing is not advisable when a model contains control flow and conditionals, and that the model should instead be converted to TorchScript with torch.jit.script.

However, after successfully converting the model to TorchScript, the next step (converting from TorchScript to Core ML with the coremltools Python package) fails with the error below.
The error message is so abstract that we cannot trace back where it is coming from.

AssertionError: Item selection is supported only on python list/tuple objects

@TobyRoseman
Collaborator

Our support for converting untraced PyTorch models is only experimental. You should have received a warning telling you that.

What is the full error (including stack trace) that you are getting?

Can you give us minimal, self-contained, code to reproduce this issue?

@TobyRoseman TobyRoseman added bug Unexpected behaviour that should be corrected (type) PyTorch (not traced) labels Mar 28, 2024
@yash-codiste
Author

Thank you so much @TobyRoseman and team for the quick response. Below are the code, logs, and details needed to trace and resolve the issue. The following sample code reproduces the error mentioned above:

import torch
import coremltools as ct

class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()

    def forward(self,x):

        # Error: raised at x.shape[0], where the first dimension is read from the shape.
        bs = x.shape[0]    
       
        z = torch.randn((bs,bs), dtype=torch.float16)
        x_ = x @ z
        return x_
   
i = torch.randn(3, 3, dtype=torch.float16)
print("in ",i)
mn = MyNet()

out = mn(i)
print("out ",out)

s_mn = torch.jit.script(mn)
print("scripted_mn : ",s_mn.graph)
print("scripted_mn code : ",s_mn.code)

mlmodel = ct.converters.convert(
    s_mn,
    inputs=[ct.TensorType(name="x", shape=i.shape),],
    source='pytorch',
    minimum_deployment_target=ct.target.iOS15,
)

# Save the Core ML model
mlmodel.save("mlmodels/s_mn.mlpackage")

Here is the full traceback of the error:

Support for converting Torch Script Models is experimental. If possible you should use a traced model for conversion.
Converting PyTorch Frontend ==> MIL Ops:  50%|██████████████████████████████████████████████▌                                              | 4/8 [00:00<?, ? ops/s]
Traceback (most recent call last):
  File "C:\Users\ADMIN\Desktop\yash\misc\ct.py", line 26, in <module>
    mlmodel = ct.converters.convert(
  File "C:\Users\ADMIN\Desktop\yash\misc\misc-venv\lib\site-packages\coremltools\converters\_converters_entry.py", line 574, in convert
    mlmodel = mil_convert(
  File "C:\Users\ADMIN\Desktop\yash\misc\misc-venv\lib\site-packages\coremltools\converters\mil\converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
  File "C:\Users\ADMIN\Desktop\yash\misc\misc-venv\lib\site-packages\coremltools\converters\mil\converter.py", line 212, in _mil_convert
    proto, mil_program = mil_convert_to_proto(
  File "C:\Users\ADMIN\Desktop\yash\misc\misc-venv\lib\site-packages\coremltools\converters\mil\converter.py", line 286, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File "C:\Users\ADMIN\Desktop\yash\misc\misc-venv\lib\site-packages\coremltools\converters\mil\converter.py", line 108, in __call__
    return load(*args, **kwargs)
  File "C:\Users\ADMIN\Desktop\yash\misc\misc-venv\lib\site-packages\coremltools\converters\mil\frontend\torch\load.py", line 80, in load
    return _perform_torch_convert(converter, debug)
  File "C:\Users\ADMIN\Desktop\yash\misc\misc-venv\lib\site-packages\coremltools\converters\mil\frontend\torch\load.py", line 99, in _perform_torch_convert
    prog = converter.convert()
  File "C:\Users\ADMIN\Desktop\yash\misc\misc-venv\lib\site-packages\coremltools\converters\mil\frontend\torch\converter.py", line 519, in convert
    convert_nodes(self.context, self.graph)
  File "C:\Users\ADMIN\Desktop\yash\misc\misc-venv\lib\site-packages\coremltools\converters\mil\frontend\torch\ops.py", line 88, in convert_nodes
    add_op(context, node)
  File "C:\Users\ADMIN\Desktop\yash\misc\misc-venv\lib\site-packages\coremltools\converters\mil\frontend\torch\ops.py", line 3397, in getitem
    raise AssertionError("Item selection is supported only on python list/tuple objects")
AssertionError: Item selection is supported only on python list/tuple objects

The error is raised by the getitem op function in coremltools/converters/mil/frontend/torch/ops.py.

@M-Quadra
Contributor

You can replace shape[0] with

bs = x.size(0)

Note that the output of torch.randn will be fixed after the conversion. As a personal recommendation, it might be better to generate random numbers externally.
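A minimal sketch of both suggestions, assuming the random matrix can be drawn by the caller: the model takes z as a second input, and the caller sizes it with x.size(0) instead of x.shape[0]. The class name MyNetNoRandn is hypothetical, and float32 is used because half-precision matmul may not be supported on CPU in older PyTorch releases.

```python
import torch


class MyNetNoRandn(torch.nn.Module):
    """Hypothetical rework of MyNet: the random matrix is a second input."""

    def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        # No torch.randn and no shape indexing inside the model.
        return x @ z


x = torch.randn(3, 3)
bs = x.size(0)           # size(0) rather than shape[0]
z = torch.randn(bs, bs)  # drawn outside the model, fresh on every call

model = MyNetNoRandn().eval()
# The model no longer depends on randn or shape indexing, so tracing suffices.
traced = torch.jit.trace(model, (x, z))
out = traced(x, z)
```

The traced module could then be converted with both tensors declared as inputs, for example inputs=[ct.TensorType(name="x", shape=(3, 3)), ct.TensorType(name="z", shape=(3, 3))]. Whether this avoids the getitem error in a particular coremltools version has not been verified here.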

@yash-codiste
Author

Thank you @M-Quadra, it worked.
The code above is only meant to reproduce the error. The model I am actually trying to convert is a custom model whose forward method calls a CVAE model, which requires a randomly generated latent vector as a conditional input.
For example:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        self.cvae = CVAE()  # CVAE: the user's own nn.Module subclass

    def forward(self, input1, input2):
        # some processing on input1 and input2
        processed = input1 + input2
        z = torch.randn((input1.size(0), input1.size(1)))
        out = self.cvae(z, processed)
        return out

Since you suggested generating the random numbers externally, any help on how to generate the z vector outside the model without changing its behavior would be appreciated.

@yash-codiste
Author

We sincerely appreciate your prompt response and ongoing support, @TobyRoseman.

We have encountered another issue that appears similar to one previously reported in this repository. Could you please review the comment linked below for details?
#1683 (comment)

@M-Quadra
Contributor

M-Quadra commented Apr 1, 2024

I'd try splitting the model into two parts:

import torch
import torch.nn as nn

class ShapeModel(nn.Module):
    def forward(self, input1: torch.Tensor, input2: torch.Tensor):
        processed = input1 + input2
        @torch.jit.script_if_tracing
        def _shape(input1, input2):
            return torch.LongTensor([input1.size(0), input2.size(1)]).squeeze()
        return processed, _shape(input1, input2)

class ThenModel(nn.Module):
    def __init__(self, cvae: nn.Module):
        super().__init__()
        self.cvae = cvae  # the user's CVAE instance

    def forward(self, z: torch.Tensor, processed: torch.Tensor):
        out = self.cvae(z, processed)
        return out

Next, generate the random numbers on the device and pass them to the second model as an MLMultiArray input, like this:

import CoreML

extension MLMultiArray {
    
    /// torch.randn([...]) * scale
    static func randn(
        shape: [NSNumber], dataType: MLMultiArrayDataType = .float32,
        scale: Double = 1
    ) throws -> MLMultiArray {
        let ts = try MLMultiArray(shape: shape, dataType: dataType)
        
        // Box-Muller
        let mean = 0.0, std = 1.0
        let arr = (0..<(ts.count/16 + 1) * 16).map { _ in Double.random(in: 0..<1) }
        
        var i = 0
        while i < ts.count {
            for j in i..<(i+8) {
                let u1 = 1 - arr[j]
                let u2 = arr[j + 8]
                let radius = sqrt(-2 * log(u1))
                let theta = 2 * Double.pi * u2
                
                let z0 = radius * cos(theta) * std + mean
                let z1 = radius * sin(theta) * std + mean
                if j < ts.count { ts[j] = z0 * scale as NSNumber }
                else { return ts }
                if j+8 < ts.count { ts[j+8] = z1 * scale as NSNumber }
            }
            i += 16
        }
        return ts
    }
}

// `shape` is the second output of the converted ShapeModel (an MLMultiArray)
let z = try MLMultiArray.randn(shape: [shape[0], shape[1]])
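To check the Swift extension's output, or to produce z on the host when testing the split models from Python, the same Box-Muller transform can be sketched in plain Python. This is an illustrative reference, not part of coremltools or Core ML; the function name randn_box_muller is hypothetical.

```python
import math
import random


def randn_box_muller(count, scale=1.0, mean=0.0, std=1.0, rng=None):
    """Return `count` normal samples (times `scale`) via Box-Muller.

    Hypothetical helper mirroring the Swift MLMultiArray.randn extension,
    without its 16-element blocking: each pair of uniforms (u1, u2)
    yields two independent normal samples.
    """
    if rng is None:
        rng = random.Random()
    out = []
    while len(out) < count:
        u1 = 1.0 - rng.random()  # in (0, 1], so log(u1) is finite
        u2 = rng.random()
        radius = math.sqrt(-2.0 * math.log(u1))
        theta = 2.0 * math.pi * u2
        out.append((radius * math.cos(theta) * std + mean) * scale)
        out.append((radius * math.sin(theta) * std + mean) * scale)
    return out[:count]


# Usage: a flat list of rows * cols samples for a (rows, cols) latent.
rows, cols = 2, 3
z_flat = randn_box_muller(rows * cols, rng=random.Random(0))
```

The Swift version pairs arr[j] with arr[j + 8] inside each block of 16 uniforms, while this sketch pairs consecutive draws; both apply the same transform, so the resulting distribution is the same.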
