[Not BUG] Scatter C++ does not work for 1D arrays #756

Open
cemlyn007 opened this issue Feb 28, 2024 · 3 comments
Labels: documentation (Improvements or additions to documentation)

@cemlyn007

cemlyn007 commented Feb 28, 2024

Describe the bug
I may simply be misunderstanding how to use scatter. I am trying to update a 1D array.

To Reproduce

#include <stdexcept>

#include "mlx/mlx.h"

int main() {
  auto x = mlx::core::array({false, false, false}, mlx::core::bool_);
  auto i = mlx::core::array({0, 2}, mlx::core::int32);
  auto v = mlx::core::array({true, true}, mlx::core::bool_);
  if (x.ndim() != 1) {
    throw std::invalid_argument("x must be a 1D array");
  }
  if (i.ndim() != 1) {
    throw std::invalid_argument("i must be a 1D array");
  }
  if (v.ndim() != 1) {
    throw std::invalid_argument("v must be a 1D array");
  }
  // Throws std::invalid_argument (see the error message below).
  x = mlx::core::scatter(x, i, v, 0);
}

Throws:

libc++abi: terminating due to uncaught exception of type std::invalid_argument: [scatter] Updates with 1 dimensions does not match the sum of the array and indices dimensions 2.
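
The numbers in this message follow a rank rule: the updates passed to scatter need as many dimensions as the array and the indices combined, here 1 + 1 = 2, while v has only 1. The sketch below models that check on plain shape vectors. `required_update_ndim` and `scatter_ranks_match` are hypothetical helpers written for illustration, not MLX functions, and representing shapes as `std::vector<int>` is an assumption of the sketch.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helpers (not part of MLX) that model the rank rule behind the
// "[scatter] Updates with N dimensions ..." error on plain shape vectors.

// Rank the updates must have: one leading dimension per index dimension,
// followed by one trailing dimension per dimension of the target array.
std::size_t required_update_ndim(const std::vector<int>& array_shape,
                                 const std::vector<int>& index_shape) {
  return array_shape.size() + index_shape.size();
}

// True when `update_shape` has the rank scatter expects.
bool scatter_ranks_match(const std::vector<int>& array_shape,
                         const std::vector<int>& index_shape,
                         const std::vector<int>& update_shape) {
  return update_shape.size() == required_update_ndim(array_shape, index_shape);
}
```

For the snippet above, `scatter_ranks_match({3}, {2}, {2})` is false, and it becomes true once v has shape (2, 1).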

Expected behaviour
I expected to get x == [True, False, True].

In Python this works:

>>> x = mlx.core.array([False, False, False])
>>> i = mlx.core.array([0, 1])
>>> v = mlx.core.array([True, True])
>>> x.shape, i.shape, v.shape
((3,), (2,), (2,))
>>> x[i] = v
>>> x
array([True, True, False], dtype=bool)

Desktop:
- Mac ARM64

cemlyn007 changed the title from "[BUG] MLX Scatter C++ for 1D arrays" to "[BUG] Scatter C++ for 1D arrays" on Feb 28, 2024
@cemlyn007
Author

It seems like this works, but I am confused as to why the 1D case doesn't.

#include <stdexcept>

#include "mlx/mlx.h"

int main() {
  auto x = mlx::core::array({false, false, false}, mlx::core::bool_);
  auto i = mlx::core::array({0, 2}, mlx::core::int32);
  auto v = mlx::core::array({true, true}, mlx::core::bool_);
  if (x.ndim() != 1) {
    throw std::invalid_argument("x must be a 1D array");
  }
  if (i.ndim() != 1) {
    throw std::invalid_argument("i must be a 1D array");
  }
  if (v.ndim() != 1) {
    throw std::invalid_argument("v must be a 1D array");
  }
  // Reshape v from (2,) to (2, 1) so that v.ndim() == x.ndim() + i.ndim().
  v = mlx::core::expand_dims(v, 1);
  x = mlx::core::scatter(x, i, v, 0);
}
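
The fix works because `expand_dims(v, 1)` inserts a size-1 axis at position 1, turning v's shape from (2,) into (2, 1). Going by the maintainer's explanation in the reply below, the leading 2 lines up with the two indices and the trailing 1 is the one-element slice of x written at each index. A shape-only sketch of that transformation follows; `expand_dims_shape` is a hypothetical helper, not the MLX implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical shape-only helper (not MLX code): returns `shape` with a new
// axis of size 1 inserted at position `axis`, mirroring what expand_dims does
// to an array's shape.
std::vector<int> expand_dims_shape(std::vector<int> shape, std::size_t axis) {
  assert(axis <= shape.size() && "axis must be in [0, ndim]");
  shape.insert(shape.begin() + static_cast<std::ptrdiff_t>(axis), 1);
  return shape;
}
```

Since only the resulting shape matters for the rank check, reshaping v to {2, 1} with `mlx::core::reshape` should presumably work just as well.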

@awni
Member

awni commented Feb 29, 2024

This isn't a bug. The C++ API is a little hard to use and undocumented, so sorry you ran into that issue.

For the C++ scatter API, the following must hold: v.ndim() == i.ndim() + x.ndim(). This is so we know which part of the update corresponds to each index.

I will leave this open and mark it as documentation. I don't think we will change the C++ scatter op, but we may add an operator[] that behaves more like Python index updates.
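
To make "which part of the update corresponds to each index" concrete, here is a plain-C++ model of a 1D scatter along axis 0, written against std::vector rather than MLX. It assumes the updates are laid out as (number of indices, slice length), with row k written into x starting at indices[k]. The slice-length-1 case is what the working expand_dims snippet does and matches the Python x[i] = v update for distinct indices; treating longer slices as consecutive writes is an assumption of this sketch, and `scatter_1d` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Plain-C++ model (not MLX's implementation) of scatter on a 1-D array along
// axis 0. `updates` has shape (indices.size(), slice_len): row k is the slice
// written into `x` starting at position indices[k]. With slice_len == 1 and
// distinct indices, this matches the Python-style update x[indices] = values.
template <typename T>
std::vector<T> scatter_1d(std::vector<T> x,
                          const std::vector<int>& indices,
                          const std::vector<std::vector<T>>& updates) {
  assert(updates.size() == indices.size() && "one update row per index");
  for (std::size_t k = 0; k < indices.size(); ++k) {
    for (std::size_t j = 0; j < updates[k].size(); ++j) {
      std::size_t pos = static_cast<std::size_t>(indices[k]) + j;
      assert(pos < x.size() && "slice must fit inside x");
      x[pos] = updates[k][j];
    }
  }
  return x;
}
```

For example, `scatter_1d(std::vector<bool>{false, false, false}, {0, 2}, {{true}, {true}})` yields {true, false, true}, the result expected in the original report.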

awni added the documentation label (Improvements or additions to documentation) on Feb 29, 2024
@cemlyn007
Author

Ah, that's cool. Feel free to close this; hopefully others who stumble into this will find this issue as a piece of documentation 💯 Thank you very much! I'm slowly dipping my toes in, trying to build a little Snake environment using MLX.

Hope you have a wonderful day!

cemlyn007 changed the title from "[BUG] Scatter C++ for 1D arrays" to "[Not BUG] Scatter C++ does not work for 1D arrays" on Feb 29, 2024