support for torch.cdist #2198
Can you give us a minimal code example to reproduce this issue?
Sure, here goes... (macOS 14.4.1):

```python
import torch
from torch import nn

import coremltools as ct


class BugModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, z):
        A = torch.ones(576, 4)
        B = torch.ones(8192, 4)
        _, min_encoding_indices = torch.min(torch.cdist(A, B), dim=1)
        return min_encoding_indices


bug_model = BugModule().eval()
any_shape = (1, 4, 24, 24)
traced_model = torch.jit.trace(bug_model, torch.rand(*any_shape))
coreml_model = ct.convert(traced_model,
                          inputs=[ct.TensorType(shape=any_shape)],
                          debug=True,
                          convert_to="mlprogram",
                          minimum_deployment_target=ct.target.macOS13)
```
Added note: it also breaks for me under a released PyTorch (2.1.2).
@SpiraMira please try this:

```python
import torch
from torch import nn

import coremltools as ct


def my_cdist(x1: torch.Tensor, x2: torch.Tensor, p=2.0):
    # Euclidean-only replacement for torch.cdist: computes the whole
    # pairwise distance matrix with a single matmul over augmented matrices.
    assert p == 2.0
    x1_norm = x1.pow(2).sum(-1, keepdim=True)
    x1_pad = torch.ones_like(x1_norm)
    x2_norm = x2.pow(2).sum(-1, keepdim=True)
    x2_pad = torch.ones_like(x2_norm)
    x1_ = torch.cat([x1.mul(-2), x1_norm, x1_pad], dim=-1)
    x2_ = torch.cat([x2, x2_pad, x2_norm], dim=-1)
    result = x1_.matmul(x2_.transpose(0, 1))
    result = result.clamp_min_(0.0).sqrt_()
    return result


# test cdist implementation with random values
def test_cdist():
    A = torch.rand(576, 4)
    B = torch.rand(8192, 4)
    result1 = torch.cdist(A, B)
    result2 = my_cdist(A, B)
    print(result1)
    print(result2)
    assert result1.shape == result2.shape
    # exact elementwise equality is fragile for floats;
    # compare within a tolerance instead
    assert torch.allclose(result1, result2, atol=1e-5)
    print("test_cdist: OK")


test_cdist()


class BugModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, z):
        A = torch.ones(576, 4)
        B = torch.ones(8192, 4)
        # _, min_encoding_indices = torch.min(torch.cdist(A, B), dim=1)
        _, min_encoding_indices = torch.min(my_cdist(A, B), dim=1)
        return min_encoding_indices


bug_model = BugModule().eval()
any_shape = (1, 4, 24, 24)
traced_model = torch.jit.trace(bug_model, torch.rand(*any_shape))
coreml_model = ct.convert(traced_model,
                          inputs=[ct.TensorType(shape=any_shape)],
                          debug=True,
                          convert_to="mlprogram",
                          minimum_deployment_target=ct.target.macOS13)
```

If you use the default parameters (p=2), it should work well. See the reference implementation at https://pytorch.org/docs/stable/_modules/torch/functional.html#cdist:

```python
def cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'):
```
Thank you! Works well as a drop-in replacement.