Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ops.map_coordinates #906

Merged
merged 6 commits into from
Sep 19, 2023
Merged

Conversation

james77777778
Copy link
Contributor

@james77777778 james77777778 commented Sep 18, 2023

Related to keras-team/keras#18442

This PR has implemented ops.map_coordinates for all backends based on the PR from @mihirparadkar #784

It is challenge to obtain a jittable map_coordinates for tensorflow, but I managed to figure out the solution. The key is to use tf.unstack to separate coordinates and form a list of tensor for subsequent operations.

The unit test is borrowed from jax and has been simpified
https://github.com/google/jax/blob/bcc545a69232e983ae31b0395f4972979f2789c0/tests/scipy_ndimage_test.py#L79

The standalone script:

import math

import numpy as np

from keras_core.backend.jax.image import map_coordinates as jax_map_coordinates
from keras_core.backend.numpy.image import map_coordinates as np_map_coordinates
from keras_core.backend.tensorflow.image import map_coordinates as tf_map_coordinates
from keras_core.backend.torch.image import map_coordinates as torch_map_coordinates
import tensorflow as tf

np.random.seed(42)
shape = (3, 4, 5)
coords_shape = (2, 3, 4)
dtype = "float32"
x = np.arange(math.prod(shape), dtype=dtype).reshape(shape)
coords = [
    (size - 1) * np.random.uniform(size=coords_shape).astype(dtype)
    for size in shape
]

print("jax:")
print(jax_map_coordinates(x, coords, 1))
print("np:")
print(np_map_coordinates(x, coords, 1))
print("torch:")
print(torch_map_coordinates(x, coords, 1))
print("tf eager:")
print(tf_map_coordinates(x, coords, 1))
print("tf xla:")
print(tf.function(tf_map_coordinates, jit_compile=True)(x, coords, 1))

Results:

Using TensorFlow backend
jax:
[[[24.009495  50.545628  36.153202  34.760387 ]
  [18.884958  10.515846  13.828117  40.892403 ]
  [25.374344  43.34012   15.488769  52.22368  ]]

 [[39.421623  11.044044  20.851446  15.36548  ]
  [15.1240015 30.588694  18.357327  28.497757 ]
  [28.654016  19.465136  19.45043   23.250359 ]]]
np:
[[[24.009495  50.54563   36.153202  34.76039  ]
  [18.884958  10.515847  13.828115  40.892403 ]
  [25.374344  43.340122  15.488769  52.22368  ]]

 [[39.42162   11.044042  20.851444  15.36548  ]
  [15.1240015 30.588696  18.357325  28.497759 ]
  [28.654016  19.465137  19.450432  23.250357 ]]]
torch:
tensor([[[24.0095, 50.5456, 36.1532, 34.7604],
         [18.8850, 10.5158, 13.8281, 40.8924],
         [25.3743, 43.3401, 15.4888, 52.2237]],

        [[39.4216, 11.0440, 20.8514, 15.3655],
         [15.1240, 30.5887, 18.3573, 28.4978],
         [28.6540, 19.4651, 19.4504, 23.2504]]], device='cuda:0')
tf eager:
tf.Tensor(
[[[24.009495  50.545628  36.153202  34.760387 ]
  [18.884958  10.515846  13.828117  40.892403 ]
  [25.374344  43.34012   15.488769  52.22368  ]]

 [[39.421623  11.044044  20.851446  15.36548  ]
  [15.1240015 30.588694  18.357327  28.497757 ]
  [28.654016  19.465136  19.45043   23.250359 ]]], shape=(2, 3, 4), dtype=float32)
tf xla:
tf.Tensor(
[[[24.009495  50.545628  36.153202  34.760387 ]
  [18.884958  10.515846  13.828117  40.892403 ]
  [25.374344  43.34012   15.488769  52.22368  ]]

 [[39.421623  11.044044  20.851446  15.36548  ]
  [15.1240015 30.588694  18.357327  28.497757 ]
  [28.654016  19.465136  19.45043   23.250359 ]]], shape=(2, 3, 4), dtype=float32)

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR -- Excellent work! 👍


Note that interpolation near boundaries differs from the scipy function,
because we fixed an outstanding bug
https://github.com/scipy/scipy/issues/2640.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use markdown for links.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

@codecov
Copy link

codecov bot commented Sep 19, 2023

Codecov Report

Patch coverage: 86.06% and project coverage change: +0.01% 🎉

Comparison is base (b4019bc) 83.63% compared to head (722a9d1) 83.64%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     keras-team/keras-core#906      +/-   ##
==========================================
+ Coverage   83.63%   83.64%   +0.01%     
==========================================
  Files         318      318              
  Lines       28391    28556     +165     
  Branches     5409     5440      +31     
==========================================
+ Hits        23745    23887     +142     
- Misses       3147     3160      +13     
- Partials     1499     1509      +10     
Flag Coverage Δ
keras_core 83.54% <86.06%> (+0.01%) ⬆️
keras_core-jax 67.29% <15.75%> (-0.30%) ⬇️
keras_core-numpy 60.50% <21.21%> (-0.23%) ⬇️
keras_core-tensorflow 66.94% <43.03%> (-0.14%) ⬇️
keras_core-torch 69.32% <49.09%> (-0.12%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files Changed Coverage Δ
keras_core/backend/jax/image.py 76.00% <42.85%> (-3.42%) ⬇️
keras_core/backend/numpy/image.py 79.06% <71.42%> (-1.49%) ⬇️
keras_core/ops/image.py 76.22% <73.68%> (-0.47%) ⬇️
keras_core/backend/tensorflow/image.py 80.73% <90.47%> (+13.34%) ⬆️
keras_core/backend/torch/image.py 78.94% <93.54%> (+8.30%) ⬆️

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the great contribution!

@fchollet fchollet merged commit 956e89a into keras-team:main Sep 19, 2023
8 checks passed
@james77777778 james77777778 deleted the add-map-coordinates branch September 19, 2023 03:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants