
[BUG] matmul yields different results when using concat #1082

Open
muchi674 opened this issue May 6, 2024 · 1 comment

Comments


muchi674 commented May 6, 2024

Describe the bug
matmul yields different results when multiplying vectors concatenated into the same matrix versus multiplying them separately.

To Reproduce
code:

import mlx.core as mx
import numpy as np

W, H = 2, 5


def test0():
    w = mx.random.uniform(-1, 1, (W, H), dtype=mx.float16)
    x0 = mx.random.uniform(-1, 1, (W,), dtype=mx.float16)
    x1 = mx.random.uniform(-1, 1, (W,), dtype=mx.float16)
    print(mx.array([x0 @ w, x1 @ w]))
    print(mx.array([x0, x1]) @ w)


def test1():
    np_w = np.random.uniform(-1, 1, (W, H)).astype(np.float16)
    x0 = np.random.uniform(-1, 1, (W,)).astype(np.float16)
    x1 = np.random.uniform(-1, 1, (W,)).astype(np.float16)
    print(np.array([x0 @ np_w, x1 @ np_w]))
    print(np.array([x0, x1]) @ np_w)


if __name__ == "__main__":
    print("mlx:")
    mx.random.seed(0)
    test0()

    print("numpy:")
    np.random.seed(0)
    test1()

output:

mlx:
array([[-0.256348, 0.953125, 0.179932, 0.740234, -0.149292],
       [0.234131, -0.737305, -0.0961914, -0.549805, 0.0778198]], dtype=float16)
array([[-0.256348, 0.953125, 0.179932, 0.740234, -0.149292],
       [0.234131, -0.737793, -0.0962524, -0.550293, 0.0778809]], dtype=float16)
numpy:
[[ 0.07385  0.2439   0.1653   0.10596 -0.1026 ]
 [ 0.2615  -0.04764  0.695    0.8013  -0.2192 ]]
[[ 0.07385  0.2439   0.1653   0.10596 -0.1026 ]
 [ 0.2615  -0.04764  0.695    0.8013  -0.2192 ]]

Expected behavior
The last four numbers of the MLX output should match between the two versions.

Desktop (please complete the following information):

  • OS Version: macOS 14.4.1
  • MLX Version: 0.12.2
  • NumPy Version: 1.26.4

Additional context
If this is indeed a bug, it should cause plenty of issues in your mlx_lm library as well.

awni (Member) commented May 6, 2024

I don't think this is a bug; it comes down to numerical differences. Floating-point arithmetic at finite precision is not associative, so the order of operations matters, and the two versions you wrote can evaluate in different orders. The lower precision exacerbates the effect.
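
For illustration (not from the thread), here is a minimal NumPy sketch of float16 non-associativity: summing the same three numbers in a different order gives different rounded results.

import numpy as np

# float16 has only ~3 decimal digits of precision, so summation order matters.
a, b, c = np.float16(2048), np.float16(1), np.float16(1)
print((a + b) + c)  # 2048.0 -- each +1 is lost to rounding against 2048
print(a + (b + c))  # 2050.0 -- adding the two 1s first keeps them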

If you need them to match (or at least be a lot closer), use fp32. I tried it and they were identical in that case.
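
As a rough sketch of that workaround (reusing the repro above, not code from the thread), casting everything to float32 would look like this:

import mlx.core as mx

W, H = 2, 5
mx.random.seed(0)
# Same comparison as test0, but in float32; the two printed results
# should now match (or at least be much closer).
w = mx.random.uniform(-1, 1, (W, H), dtype=mx.float32)
x0 = mx.random.uniform(-1, 1, (W,), dtype=mx.float32)
x1 = mx.random.uniform(-1, 1, (W,), dtype=mx.float32)
print(mx.array([x0 @ w, x1 @ w]))
print(mx.array([x0, x1]) @ w)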

I will let @jagrit06 comment on this before closing, to be sure. Also if you notice any instances with larger discrepancies that would be useful to share.
