Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add MobileViTV2 onnx export #1823

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft

[WIP] Add MobileViTV2 onnx export #1823

wants to merge 1 commit into from

Conversation

xenova
Copy link
Contributor

@xenova xenova commented Apr 22, 2024

What does this PR do?

Attempts to add support for mobilevitv2 ONNX export. However, I've run into a few issues:

e.g., running:

optimum-cli export onnx --model apple/mobilevitv2-1.0-imagenet1k-256 o
  1. Support for the aten::col2im operator was only added in version 18 (Otherwise we get torch.onnx.errors.UnsupportedOperatorError)
    torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::col2im' to ONNX opset version 12 is not supported. Support for this operator was added in version 18, try exporting with this version.
    
  2. When setting DEFAULT_ONNX_OPSET=18, I get:
    /home/codespace/.local/lib/python3.10/site-packages/torch/onnx/utils.py:1548: OnnxExporterWarning: Exporting to ONNX opset version 18 is not supported. by 'torch.onnx.export()'. The highest opset version supported is 17. To use a newer opset version, consider 'torch.onnx.dynamo_export()'. Note that dynamo_export() is in preview. Please report errors with dynamo_export() as Github issues to https://github.com/pytorch/pytorch/issues.
      warnings.warn(
    Traceback (most recent call last):
      File "/home/codespace/.python/current/bin/optimum-cli", line 8, in <module>
        sys.exit(main())
      File "/workspaces/optimum/optimum/commands/optimum_cli.py", line 163, in main
        service.run()
      File "/workspaces/optimum/optimum/commands/export/onnx.py", line 265, in run
        main_export(
      File "/workspaces/optimum/optimum/exporters/onnx/__main__.py", line 352, in main_export
        onnx_export_from_model(
      File "/workspaces/optimum/optimum/exporters/onnx/convert.py", line 1165, in onnx_export_from_model
        _, onnx_outputs = export_models(
      File "/workspaces/optimum/optimum/exporters/onnx/convert.py", line 776, in export_models
        export(
      File "/workspaces/optimum/optimum/exporters/onnx/convert.py", line 881, in export
        export_output = export_pytorch(
      File "/workspaces/optimum/optimum/exporters/onnx/convert.py", line 577, in export_pytorch
        onnx_export(
      File "/home/codespace/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 516, in export
        _export(
      File "/home/codespace/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 1613, in _export
        graph, params_dict, torch_out = _model_to_graph(
      File "/home/codespace/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 1139, in _model_to_graph
        graph = _optimize_graph(
      File "/home/codespace/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 677, in _optimize_graph
        graph = _C._jit_pass_onnx(graph, operator_export_type)
      File "/home/codespace/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 1957, in _run_symbolic_function
        return symbolic_fn(graph_context, *inputs, **attrs)
      File "/home/codespace/.local/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py", line 306, in wrapper
        return fn(g, *args, **kwargs)
      File "/home/codespace/.local/lib/python3.10/site-packages/torch/onnx/symbolic_opset18.py", line 52, in col2im
        num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]
    TypeError: 'NoneType' object is not subscriptable 
    (Occurred when translating col2im).
    

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

@fxmarty @echarlaix

@xenova
Copy link
Contributor Author

xenova commented Apr 22, 2024

Looks like this is a known issue and won't be fixed until the new torch.onnx.dynamo_export is developed. pytorch/pytorch#105134 (comment)

However, there are workarounds, which we could look into.

@xenova
Copy link
Contributor Author

xenova commented Apr 22, 2024

Patching the folding function with:

def folding(self, patches: torch.Tensor, output_size: Tuple[int, int]) -> torch.Tensor:
    batch_size, in_dim, patch_size, n_patches = patches.shape
    patches = patches.reshape(batch_size, in_dim * patch_size, n_patches)

    # Calculate the number of patches in each dimension
    n_patches_height = int(n_patches ** 0.5)
    n_patches_width = n_patches_height

    # Initialize the output feature map
    feature_map = torch.zeros((batch_size, in_dim, output_size[0], output_size[1]), device=patches.device)

    # Iterate over each patch and place it in the correct position in the feature map
    for i in range(n_patches_height):
        for j in range(n_patches_width):
            patch_idx = i * n_patches_width + j
            patch = patches[:, :, patch_idx]
            patch = patch.reshape(batch_size, in_dim, self.patch_height, self.patch_width)
            feature_map[:, :, i*self.patch_height:(i+1)*self.patch_height, j*self.patch_width:(j+1)*self.patch_width] = patch

    return feature_map

and setting opset=12 seems to give equivalent results. Doesn't support dynamic width/height though, but shouldn't be a problem since the processor resizes/crops to 256x256 anyway.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant