export_model.py crashes with keypoints #5255

Open

Huxwell opened this issue Apr 12, 2024 · 9 comments

Comments
@Huxwell

Huxwell commented Apr 12, 2024

EDIT: I am discussing the export_model.py issues with keypoints in #5143, since it receives more attention.

Instructions To Reproduce the 🐛 Bug:

  1. Full runnable code or full changes you made:
  2. What exact command you run:
python detectron2/tools/deploy/export_model.py \
    --sample-image 1344x1344.jpg \
    --config-file detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml \
    --export-method tracing \
    --format onnx \
    --output ./keypoints_onnx \
    MODEL.WEIGHTS model_final_a6e10b.pkl \
    MODEL.DEVICE cuda
  3. Full logs or other relevant observations:
python detectron2/tools/deploy/export_model.py --sample-image 1344x1344.jpg --config-file detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml --export-method tracing --format onnx --output ./keypoints_onnx MODEL.WEIGHTS /home/ubuntu/onnx_trtmodel_final_a6e10b.pkl MODEL.DEVICE cuda
[04/12 12:18:05 detectron2]: Command line arguments: Namespace(config_file='detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml', export_method='tracing', format='onnx', opts=['MODEL.WEIGHTS', '/home/ubuntu/onnx_trtmodel_final_a6e10b.pkl', 'MODEL.DEVICE', 'cuda'], output='./keypoints_onnx', run_eval=False, sample_image='1344x1344.jpg')
[W init.cpp:833] Warning: Use _jit_set_fusion_strategy, bailout depth is deprecated. Setting to (STATIC, 1) (function operator())
/home/ubuntu/detectron2/env_perception/lib/python3.8/site-packages/torch/cuda/__init__.py:546: UserWarning: Can't initialize NVML
  warnings.warn("Can't initialize NVML")
[04/12 12:18:07 d2.checkpoint.detection_checkpoint]: [DetectionCheckpointer] Loading from /home/ubuntu/onnx_trtmodel_final_a6e10b.pkl ...
/home/ubuntu/detectron2/detectron2/detectron2/structures/image_list.py:85: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert t.shape[:-2] == tensors[0].shape[:-2], t.shape
/home/ubuntu/detectron2/env_perception/lib/python3.8/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
/home/ubuntu/detectron2/detectron2/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/home/ubuntu/detectron2/detectron2/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/home/ubuntu/detectron2/detectron2/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/home/ubuntu/detectron2/detectron2/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/home/ubuntu/detectron2/detectron2/detectron2/modeling/proposal_generator/proposal_utils.py:106: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if not valid_mask.all():
/home/ubuntu/detectron2/detectron2/detectron2/structures/boxes.py:191: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!"
/home/ubuntu/detectron2/detectron2/detectron2/structures/boxes.py:192: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  h, w = box_size
/home/ubuntu/detectron2/detectron2/detectron2/layers/nms.py:15: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert boxes.shape[-1] == 4
/home/ubuntu/detectron2/env_perception/lib/python3.8/site-packages/torch/__init__.py:1209: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert condition, message
/home/ubuntu/detectron2/detectron2/detectron2/layers/roi_align.py:55: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert rois.dim() == 2 and rois.size(1) == 5
/home/ubuntu/detectron2/detectron2/detectron2/modeling/roi_heads/fast_rcnn.py:138: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if not valid_mask.all():
/home/ubuntu/detectron2/detectron2/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/home/ubuntu/detectron2/detectron2/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/home/ubuntu/detectron2/detectron2/detectron2/structures/boxes.py:191: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!"
/home/ubuntu/detectron2/detectron2/detectron2/structures/boxes.py:192: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  h, w = box_size
/home/ubuntu/detectron2/detectron2/detectron2/modeling/roi_heads/fast_rcnn.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if num_bbox_reg_classes == 1:
/home/ubuntu/detectron2/detectron2/detectron2/layers/nms.py:15: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert boxes.shape[-1] == 4
/home/ubuntu/detectron2/env_perception/lib/python3.8/site-packages/torch/__init__.py:1209: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert condition, message
/home/ubuntu/detectron2/detectron2/detectron2/layers/roi_align.py:55: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert rois.dim() == 2 and rois.size(1) == 5
/home/ubuntu/detectron2/env_perception/lib/python3.8/site-packages/torch/jit/_trace.py:1223: UserWarning: operator() sees varying value in profiling, ignoring and this should be handled by GUARD logic (Triggered internally at ../third_party/nvfuser/csrc/parser.cpp:3777.)
  return compiled_fn(*args, **kwargs)
/home/ubuntu/detectron2/env_perception/lib/python3.8/site-packages/torch/onnx/symbolic_opset9.py:5589: UserWarning: Exporting aten::index operator of advanced indexing in opset 11 is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.
  warnings.warn(
/home/ubuntu/detectron2/env_perception/lib/python3.8/site-packages/torchvision/ops/_register_onnx_ops.py:59: UserWarning: ROIAlign with aligned=True is only supported in opset >= 16. Please export with opset 16 or higher, or use aligned=False.
  warnings.warn(
============= Diagnostic Run torch.onnx.export version 2.0.0+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

Traceback (most recent call last):
  File "detectron2/tools/deploy/export_model.py", line 225, in <module>
    exported_model = export_tracing(torch_model, sample_inputs)
  File "detectron2/tools/deploy/export_model.py", line 132, in export_tracing
    torch.onnx.export(traceable_model, (image,), f, opset_version=STABLE_ONNX_OPSET_VERSION)
  File "/home/ubuntu/detectron2/env_perception/lib/python3.8/site-packages/torch/onnx/utils.py", line 506, in export
    _export(
  File "/home/ubuntu/detectron2/env_perception/lib/python3.8/site-packages/torch/onnx/utils.py", line 1548, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/ubuntu/detectron2/env_perception/lib/python3.8/site-packages/torch/onnx/utils.py", line 1117, in _model_to_graph
    graph = _optimize_graph(
  File "/home/ubuntu/detectron2/env_perception/lib/python3.8/site-packages/torch/onnx/utils.py", line 665, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/home/ubuntu/detectron2/env_perception/lib/python3.8/site-packages/torch/onnx/utils.py", line 1891, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
  File "/home/ubuntu/detectron2/env_perception/lib/python3.8/site-packages/torch/onnx/symbolic_opset9.py", line 6709, in prim_loop
    torch._C._jit_pass_onnx_block(
  File "/home/ubuntu/detectron2/env_perception/lib/python3.8/site-packages/torch/onnx/utils.py", line 1891, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
  File "/home/ubuntu/detectron2/env_perception/lib/python3.8/site-packages/torch/onnx/symbolic_opset11.py", line 1063, in index
    return opset9.index(g, self, index)
  File "/home/ubuntu/detectron2/env_perception/lib/python3.8/site-packages/torch/onnx/symbolic_opset9.py", line 5580, in index
    return symbolic_helper._unimplemented(
  File "/home/ubuntu/detectron2/env_perception/lib/python3.8/site-packages/torch/onnx/symbolic_helper.py", line 607, in _unimplemented
    _onnx_unsupported(f"{op}, {msg}", value)
  File "/home/ubuntu/detectron2/env_perception/lib/python3.8/site-packages/torch/onnx/symbolic_helper.py", line 618, in _onnx_unsupported
    raise errors.SymbolicValueError(
torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of operator aten::index, operator of advanced indexing on tensor of unknown rank. Try turning on shape inference during export: torch.onnx._export(..., onnx_shape_inference=True).. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues  [Caused by the value 'roi_map.3 defined in (%roi_map.3 : Tensor = onnx::Reshape(%roi_map, %2727) # /home/ubuntu/detectron2/detectron2/detectron2/structures/keypoints.py:205:18
)' (type 'Tensor') in the TorchScript graph. The containing node has kind 'onnx::Reshape'.] 
    (node defined in   File "/home/ubuntu/detectron2/detectron2/detectron2/structures/keypoints.py", line 205
        # Although semantically equivalent, `reshape` is used instead of `squeeze` due
        # to limitation during ONNX export of `squeeze` in scripting mode
        roi_map = roi_map.reshape(roi_map.shape[1:])  # keypoints x H x W
                  ~~~~~~~~~~~~~~~ <--- HERE

        # softmax over the spatial region
)

    Inputs:
        #0: roi_map defined in (%roi_map : Tensor = onnx::Resize[coordinate_transformation_mode="half_pixel", cubic_coeff_a=-0.75, mode="cubic", nearest_mode="floor"](%2710, %2719, %2720, %2718) # /home/ubuntu/detectron2/detectron2/detectron2/structures/keypoints.py:201:18
    )  (type 'Tensor')
        #1: 2727 defined in (%2727 : LongTensor(device=cpu)[] = onnx::Slice(%2722, %2724, %2725, %2723, %2726) # /home/ubuntu/detectron2/detectron2/detectron2/structures/keypoints.py:205:34
    )  (type 'List[Tensor]')
    Outputs:
        #0: roi_map.3 defined in (%roi_map.3 : Tensor = onnx::Reshape(%roi_map, %2727) # /home/ubuntu/detectron2/detectron2/detectron2/structures/keypoints.py:205:18
    )  (type 'Tensor')
  4. Please simplify the steps as much as possible so they do not require additional resources to run, such as a private dataset.

Expected behavior:

An onnx file generated that I can:

  1. run with onnxruntime to verify correctness
  2. export to TensorRT, but this is probably beyond the scope of detectron2 maintenance

Environment:

Provide your environment information using the following command:

wget -nc -q https://github.com/facebookresearch/detectron2/raw/main/detectron2/utils/collect_env.py && python collect_env.py

/home/ubuntu/detectron2/env_perception/lib/python3.8/site-packages/torch/cuda/__init__.py:546: UserWarning: Can't initialize NVML
  warnings.warn("Can't initialize NVML")
-------------------------------  ------------------------------------------------------------------------------------------------------------
sys.platform                     linux
Python                           3.8.10 (default, Nov 22 2023, 10:22:35) [GCC 9.4.0]
numpy                            1.24.4
detectron2                       0.6 @/home/ubuntu/detectron2/detectron2/detectron2
detectron2._C                    not built correctly: No module named 'detectron2._C'
Compiler ($CXX)                  c++ (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
DETECTRON2_ENV_MODULE            <not set>
PyTorch                          2.0.0+cu117 @/home/ubuntu/detectron2/env_perception/lib/python3.8/site-packages/torch
PyTorch debug build              False
torch._C._GLIBCXX_USE_CXX11_ABI  False
GPU available                    Yes
GPU 0                            NVIDIA T600 Laptop GPU (arch=7.5)
Driver version
CUDA_HOME                        None - invalid!
Pillow                           10.0.0
torchvision                      0.15.1+cu117 @/home/ubuntu/detectron2/env_perception/lib/python3.8/site-packages/torchvision
torchvision arch flags           /home/ubuntu/detectron2/env_perception/lib/python3.8/site-packages/torchvision/_C.so
fvcore                           0.1.5.post20221221
iopath                           0.1.10
cv2                              4.7.0
-------------------------------  ------------------------------------------------------------------------------------------------------------
PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.7.3 (Git Hash 6dbeffbae1f23cbbeae17adb7b5b13f1f37c080e)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.7
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86
  - CuDNN 8.5
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.7, CUDNN_VERSION=8.5.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.0.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, 
@RajUpadhyay

@Huxwell I see, it really is weird. It hasn't been long since I ran a keypoint ONNX model with onnxruntime-gpu.
The error mentioned above is due to the code in the onnx package; you can try using one of the following operator export types:

  1. ONNX_FALLTHROUGH
  2. ONNX_ATEN
  3. ONNX_ATEN_FALLBACK

I do not remember much, but I think I had success with ONNX_ATEN_FALLBACK.
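
A rough, untested sketch of where that flag would go in export_tracing() in export_model.py; traceable_model, image, f, and STABLE_ONNX_OPSET_VERSION are the names that function already uses:

import torch

# pass an operator_export_type so unsupported ops are handled differently
torch.onnx.export(
    traceable_model,
    (image,),
    f,
    opset_version=STABLE_ONNX_OPSET_VERSION,
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
    # or torch.onnx.OperatorExportTypes.ONNX_ATEN / ONNX_FALLTHROUGH
)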

@Huxwell

Huxwell commented Apr 24, 2024

No luck so far: ONNX_ATEN_FALLBACK does not change anything for me, and ONNX_ATEN makes the export crash with

aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a):
Expected a value of type 'Tensor' for argument 'self' but instead found type 'List[Tensor]'.
Empty lists default to List[Tensor]. Add a variable annotation to the assignment to create an empty list of another type (torch.jit.annotate(List[T, []]) where T is the type of elements in the list for Python 2)

I will read the comments in my '/Apr2024Detectron_venv/lib/python3.8/site-packages/torch/onnx/utils.py', run detectron2's export unit tests, and analyze the logs, in the hope of finding a clue there.

In the meantime, do you happen to have a converted .onnx file from the vanilla keypoint detector that you could share? It would let me verify whether my problem is with the exporter (the loops in the keypoint-related code) or rather with my onnxruntime usage/version.

@Huxwell

Huxwell commented Apr 29, 2024

Just to clarify: ONNX_FALLTHROUGH successfully generates an ONNX file, but onnxruntime crashes when reading that file with

Type parameter (T) of Optype (SequenceConstruct) bound to different types (tensor(int64) and tensor(float) in node (SequenceConstruct_2862)
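
For context, the crash happens while the file is being read, i.e. already when the session is created, before any inference. A minimal reproduction (the model path is an assumption; it is whatever export_model.py wrote into the --output directory):

import onnxruntime as ort

# raises the SequenceConstruct type error above while loading the graph
sess = ort.InferenceSession("keypoints_onnx/model.onnx")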

@Huxwell

Huxwell commented Apr 30, 2024

Managed to bump STABLE_ONNX_OPSET_VERSION from 11 to 16 and then 17 (eliminating the warning about ROIAlign). The silly issue was that I had been changing the value in my cloned copy of detectron2 rather than in the one installed by pip in the venv.
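
A quick way to check that the edit landed in the copy that actually gets imported (assuming the constant is exposed by detectron2.export, which is where export_model.py imports it from in my checkout):

import detectron2.export as d2_export

print(d2_export.__file__)                    # which detectron2 copy is on sys.path
print(d2_export.STABLE_ONNX_OPSET_VERSION)   # should print 16 (or 17) after the bump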

I mocked out the keypoint computation and succeeded in running the model in onnxruntime.
At first I thought the problem was in keypoint_rcnn_inference() in detectron2/modeling/roi_heads/keypoint_head.py; now I am certain it is in the heatmaps_to_keypoints() it calls:
https://github.com/facebookresearch/detectron2/blob/main/detectron2/structures/keypoints.py
It seems I must rewrite the loop inside heatmaps_to_keypoints(), shown below,

    for i in range(num_rois):
        outsize = (int(heights_ceil[i]), int(widths_ceil[i]))
        roi_map = F.interpolate(maps[i].unsqueeze(0), size=outsize, mode="bicubic", align_corners=False).squeeze(0)

        max_score, _ = roi_map.view(num_keypoints, -1).max(1)
        max_score = max_score.view(num_keypoints, 1, 1)
        roi_map = roi_map - max_score  # Normalize heatmap for stability
        exp_map = roi_map.exp_()
        score = exp_map / exp_map.view(num_keypoints, -1).sum(1).view(num_keypoints, 1, 1)

        pos = exp_map.view(num_keypoints, -1).argmax(1)
        x_int = pos % outsize[1]
        y_int = pos // outsize[1]

        x = (x_int.float() + 0.5) * width_corrections[i]
        y = (y_int.float() + 0.5) * height_corrections[i]

        xy_preds[i, :, 0] = x + offset_x[i]
        xy_preds[i, :, 1] = y + offset_y[i]
        xy_preds[i, :, 2] = roi_map.view(num_keypoints, -1).gather(1, pos.unsqueeze(1)).squeeze(1)
        xy_preds[i, :, 3] = score.view(num_keypoints, -1).gather(1, pos.unsqueeze(1)).squeeze(1)

so that it is ONNX compatible; working on that.

@RajUpadhyay

@Huxwell Glad to hear you could solve it!
The number of version dependencies between Detectron2 and TensorRT is insane.

I tried to run Detectron2 as-is on my Ubuntu 22.04; it runs on CPU but won't run with CUDA, haha.

Let me know if something else happens.

@Huxwell

Huxwell commented May 1, 2024

OK, now I am able to run the model correctly in onnxruntime, with reasonable predictions (even with custom models: a custom number of keypoints instead of 17, an R18/R34 backbone instead of R50, my own weights rather than the pretrained ones, etc.).
The only changes are in the aforementioned loop (apart from ONNX -> ONNX_FALLTHROUGH and STABLE_ONNX_OPSET_VERSION 11 -> 16):
https://github.com/facebookresearch/detectron2/blob/main/detectron2/structures/keypoints.py

@torch.jit.script_if_tracing
def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
    """
    Extract predicted keypoint locations from heatmaps.

    Args:
        maps (Tensor): (#ROIs, #keypoints, POOL_H, POOL_W). The predicted heatmap of logits for
            each ROI and each keypoint.
        rois (Tensor): (#ROIs, 4). The box of each ROI.

    Returns:
        Tensor of shape (#ROIs, #keypoints, 4) with the last dimension corresponding to
        (x, y, logit, score) for each keypoint.

    When converting discrete pixel indices in an NxN image to a continuous keypoint coordinate,
    we maintain consistency with :meth:`Keypoints.to_heatmap` by using the conversion from
    Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
    """

    offset_x = rois[:, 0]
    offset_y = rois[:, 1]

    widths = (rois[:, 2] - rois[:, 0]).clamp(min=1)
    heights = (rois[:, 3] - rois[:, 1]).clamp(min=1)
    widths_ceil = widths.ceil()
    heights_ceil = heights.ceil()

    num_rois, num_keypoints = maps.shape[:2]
    xy_preds = maps.new_zeros(rois.shape[0], num_keypoints, 4)

    width_corrections = widths / widths_ceil
    height_corrections = heights / heights_ceil

    keypoints_idx = torch.arange(num_keypoints, device=maps.device)

    for i in range(num_rois):
        outsize = (int(heights_ceil[i]), int(widths_ceil[i]))
        roi_map = F.interpolate(maps[[i]], size=outsize, mode="bicubic", align_corners=False)[0]

        max_score, _ = roi_map.view(num_keypoints, -1).max(1)
        max_score = max_score.view(num_keypoints, 1, 1)
        roi_map = roi_map - max_score
        exp_map = roi_map.exp()
        total_exp = exp_map.view(num_keypoints, -1).sum(dim=1).view(num_keypoints, 1, 1)
        roi_map_scores = exp_map / total_exp

        w = roi_map.shape[2]
        pos = roi_map.view(num_keypoints, -1).argmax(1)

        x_int = pos % w
        y_int = (pos - x_int) // w

        assert (
            roi_map_scores[keypoints_idx, y_int, x_int]
            == roi_map_scores.view(num_keypoints, -1).max(1)[0]
        ).all()

        x = (x_int.float() + 0.5) * width_corrections[i]
        y = (y_int.float() + 0.5) * height_corrections[i]

        xy_preds[i, :, 0] = x + offset_x[i]
        xy_preds[i, :, 1] = y + offset_y[i]
        xy_preds[i, :, 2] = roi_map.view(num_keypoints, -1).gather(1, pos.unsqueeze(1)).squeeze(1)
        xy_preds[i, :, 3] = roi_map_scores.view(num_keypoints, -1).gather(1, pos.unsqueeze(1)).squeeze(1)

    return xy_preds
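
Roughly, the onnxruntime check then looks like this (a minimal sketch: the model path, the CPU provider, and feeding the raw float32 CHW image without the ResizeShortestEdge preprocessing are simplifying assumptions):

import cv2
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("keypoints_onnx/model.onnx", providers=["CPUExecutionProvider"])

img = cv2.imread("1344x1344.jpg")                  # BGR, HWC, uint8
chw = img.transpose(2, 0, 1).astype(np.float32)    # float32 CHW, as built in export_model.py

input_name = sess.get_inputs()[0].name
outputs = sess.run(None, {input_name: chw})
for out, meta in zip(outputs, sess.get_outputs()):
    print(meta.name, getattr(out, "shape", type(out)))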
    

@Huxwell
Copy link
Author

Huxwell commented May 27, 2024

@RajUpadhyay FYI, I am now able to run keypoint prediction in TensorRT (the heatmaps -> keypoints conversion and repositioning happen in numpy during postprocessing); I describe my process in a bit more detail in the TensorRT issue: NVIDIA/TensorRT#3792
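
The numpy postprocessing is roughly along these lines (a simplified sketch: it skips the per-ROI bicubic upsampling that the original heatmaps_to_keypoints() performs, so peak localization is coarser, and the shapes are assumptions):

import numpy as np

def heatmaps_to_keypoints_np(maps: np.ndarray, rois: np.ndarray) -> np.ndarray:
    """maps: (#ROIs, #keypoints, H, W) logits; rois: (#ROIs, 4) boxes.
    Returns (#ROIs, #keypoints, 3) as (x, y, score)."""
    num_rois, num_kps, h, w = maps.shape
    widths = np.clip(rois[:, 2] - rois[:, 0], 1, None)
    heights = np.clip(rois[:, 3] - rois[:, 1], 1, None)

    flat = maps.reshape(num_rois, num_kps, -1)
    pos = flat.argmax(axis=2)          # flat index of the heatmap peak per keypoint
    x_int = pos % w
    y_int = pos // w

    # map heatmap-grid coordinates back into each ROI box (Heckbert: c = d + 0.5)
    x = (x_int + 0.5) * (widths / w)[:, None] + rois[:, 0:1]
    y = (y_int + 0.5) * (heights / h)[:, None] + rois[:, 1:2]

    # softmax probability at the peak, per keypoint
    exp = np.exp(flat - flat.max(axis=2, keepdims=True))
    score = exp.max(axis=2) / exp.sum(axis=2)

    return np.stack([x, y, score], axis=2)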

@RajUpadhyay

@Huxwell Wow, congrats! So glad you could do it!
Thanks for letting me know.
I wonder if this is too much to ask, but could I request the create_onnx.py you modified?
Thanks!

@Huxwell

Huxwell commented May 29, 2024

Sorry, I asked, and apparently my company policy doesn't allow me to share it, but I think the snippets from the questions I asked in these issues (mostly the roi_head() function) should be enough for you to reproduce the result relatively easily.
