Describe the bug
I may simply have an incorrect understanding of how to use scatter. I am trying to update a 1D array.
To Reproduce
#include <stdexcept>

#include "mlx/mlx.h"

int main() {
  auto x = mlx::core::array({false, false, false}, mlx::core::bool_);
  auto i = mlx::core::array({0, 2}, mlx::core::int32);
  auto v = mlx::core::array({true, true}, mlx::core::bool_);
  if (x.ndim() != 1) {
    throw std::invalid_argument("x must be a 1D array");
  }
  if (i.ndim() != 1) {
    throw std::invalid_argument("i must be a 1D array");
  }
  if (v.ndim() != 1) {
    throw std::invalid_argument("v must be a 1D array");
  }
  // This call throws: v is 1D, but scatter expects v.ndim == i.ndim + x.ndim.
  x = mlx::core::scatter(x, i, v, 0);
}
Throws:
libc++abi: terminating due to uncaught exception of type std::invalid_argument: [scatter] Updates with 1 dimensions does not match the sum of the array and indices dimensions 2.
Expected behaviour
To get x == [True, False, True]
In Python this works:
>>> x = mlx.core.array([False, False, False])
>>> i = mlx.core.array([0, 1])
>>> v = mlx.core.array([True, True])
>>> x.shape, i.shape, v.shape
((3,), (2,), (2,))
>>> x[i] = v
>>> x
array([True, True, False], dtype=bool)
Desktop:
- Mac ARM64
It seems like the following works, but I am confused as to why the 1D case does not:
#include <stdexcept>

#include "mlx/mlx.h"

int main() {
  auto x = mlx::core::array({false, false, false}, mlx::core::bool_);
  auto i = mlx::core::array({0, 2}, mlx::core::int32);
  auto v = mlx::core::array({true, true}, mlx::core::bool_);
  if (x.ndim() != 1) {
    throw std::invalid_argument("x must be a 1D array");
  }
  if (i.ndim() != 1) {
    throw std::invalid_argument("i must be a 1D array");
  }
  if (v.ndim() != 1) {
    throw std::invalid_argument("v must be a 1D array");
  }
  // Add a trailing axis so v goes from shape (2,) to shape (2, 1).
  v = mlx::core::expand_dims(v, 1);
  x = mlx::core::scatter(x, i, v, 0);
}
This isn't a bug. The C++ API is a little hard to use and undocumented, so sorry you ran into that issue.
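For the C++ scatter API, the following must be true: v.ndim == i.ndim + x.ndim. This is so we know which part of the update corresponds to each index. In your example x.ndim and i.ndim are both 1, so v must be 2D with shape (2, 1), which is why the expand_dims version works and the 1D version throws.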
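To make the rank rule concrete, here is a minimal sketch in plain C++ that does not use MLX at all. The helpers scatter_ranks_ok and scatter_1d are hypothetical names invented for this illustration, and the nested vector only models an update of shape (number of indices, 1). It is not how MLX implements scatter, just a toy version of the rule v.ndim == i.ndim + x.ndim for the 1D case in this issue.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not part of MLX): the rank rule quoted above,
// v.ndim == i.ndim + x.ndim.
bool scatter_ranks_ok(std::size_t x_ndim, std::size_t i_ndim, std::size_t v_ndim) {
  return v_ndim == i_ndim + x_ndim;
}

// Hypothetical toy model of scattering into a 1D array along axis 0.
// `updates` stands in for an array of shape (indices.size(), 1):
// one length-1 slice per index, so each index knows which slice it owns.
std::vector<bool> scatter_1d(std::vector<bool> x,
                             const std::vector<int>& indices,
                             const std::vector<std::vector<bool>>& updates) {
  if (updates.size() != indices.size()) {
    throw std::invalid_argument("one update slice per index is required");
  }
  for (std::size_t k = 0; k < indices.size(); ++k) {
    if (updates[k].size() != 1) {
      throw std::invalid_argument("each update slice must have length 1");
    }
    const int idx = indices[k];
    if (idx < 0 || static_cast<std::size_t>(idx) >= x.size()) {
      throw std::out_of_range("index out of range");
    }
    // Write the single element of slice k at position idx.
    x[static_cast<std::size_t>(idx)] = updates[k][0];
  }
  return x;
}
```

With x and i both 1D, scatter_ranks_ok(1, 1, 1) is false, which mirrors the exception in the original report, while scatter_ranks_ok(1, 1, 2) is true, which mirrors the expand_dims version. Calling scatter_1d on {false, false, false} with indices {0, 2} and updates {{true}, {true}} sets positions 0 and 2.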
I will leave this open and mark it as documentation. I don't think we will change the C++ scatter op. We may add an operator[] which will behave more like python index updates.
Ah that's cool, feel free to close this and hopefully if others stumble into this then they will find this issue as a piece of documentation 💯 Thank you very much! I'm slowly dipping my toes trying to build a little Snake environment using MLX.
Hope you have a wonderful day!
cemlyn007 changed the title from "[BUG] Scatter C++ for 1D arrays" to "[Not BUG] Scatter C++ does not work for 1D arrays" on Feb 29, 2024