
QMM, QMV, and QVM using FP32 for Mul and Acc even when running F16, BF16 variants, etc #963

Open
arpan-dhatt opened this issue Apr 5, 2024 · 1 comment

arpan-dhatt commented Apr 5, 2024

Question regarding the QMM, QMV, and QVM kernels. When inspecting them in the Metal Debugger I noticed that no matter the data type used for the operation itself (e.g. float16), all the simdgroup_matrix operations happen in FP32. FP16 utilization is effectively 0.
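For anyone unfamiliar, the element type of the simdgroup_matrix tiles is what decides which pipeline the multiply-accumulate runs on. A minimal illustration (a kernel fragment, not MLX's actual code):

```metal
#include <metal_simdgroup_matrix>
using namespace metal;

// FP32 path: float tiles, so the multiply-accumulate executes in FP32
// even when the tensors in memory are float16.
simdgroup_float8x8 Af, Bf, Cf;
simdgroup_multiply_accumulate(Cf, Af, Bf, Cf);

// FP16 path: half tiles make the same operation run on the FP16 pipeline.
simdgroup_half8x8 Ah, Bh, Ch;
simdgroup_multiply_accumulate(Ch, Ah, Bh, Ch);
```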

I went ahead and made the fairly trivial changes to the quantized.metal file to use fp16 for multiplication and accumulation (set the type template param for BlockMMA to half in qmm_t, plus misc changes to static_cast expressions using float constants that wouldn't convert implicitly). For the QMV and QVM kernels I got basically no speedup (memory bound for sure), but for QMM I'm getting roughly a 5-6% speedup on M2 and M1 Max.
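The gist of the change, as a toy stand-in (this is not the real BlockMMA signature, which takes several more template parameters):

```cpp
// Toy stand-in: the accumulator element type is a template parameter,
// and qmm_t was instantiating it with float even when T = half.
template <typename T, typename AccT>
struct ToyBlockMMA {
  AccT accum = AccT(0);
  void mma(T a, T b) { accum += AccT(a) * AccT(b); }
};

// Before: ToyBlockMMA<half, float>  -> simdgroup MMA stays on the FP32 path
// After:  ToyBlockMMA<half, half>   -> multiply and accumulate in FP16
```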

I added a benchmark to benchmarks/cpp/single_ops.cpp to check, and those are the numbers I'm getting. I also looked at the error between the QMM/QMV/QVM outputs and their un-quantized counterparts, and it was negligible. Finally, I compiled the change and ran a language model with it (mlx-community/Starling-LM-7B-beta): the outputs were also great and prompt processing speed increased.
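The error check was along these lines (a sketch against the mlx C++ ops API; the shapes, group size, and bit width here are placeholders, not what the benchmark actually uses):

```cpp
#include <iostream>
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto x = random::normal({32, 4096});
  auto w = random::normal({4096, 4096});

  // Quantize w, run the quantized matmul, and compare against a plain
  // matmul with the dequantized weights.
  auto [wq, scales, biases] = quantize(w, /*group_size=*/64, /*bits=*/4);
  auto y_q = quantized_matmul(
      x, wq, scales, biases, /*transpose=*/true, /*group_size=*/64, /*bits=*/4);
  auto w_dq = dequantize(wq, scales, biases, /*group_size=*/64, /*bits=*/4);
  auto y_ref = matmul(x, transpose(w_dq));

  auto err = max(abs(y_q - y_ref));
  std::cout << "max abs error: " << err << std::endl;
  return 0;
}
```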

I haven't run any rigorous perplexity measurements to make sure there isn't a problem, though. I can make a PR for this, but I'm a bit hesitant since it was a very simple change, and accumulating FP16 into FP16 can of course have overflow problems, so I assumed that was why it wasn't done. After trying it, it seems fine though? I haven't looked at activation distributions in many models (or even LLMs specifically) to know if this is a good idea in general. Perhaps it would be best to do the multiply in FP16/BF16 and accumulate into FP32, assuming there's enough register space for it.
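At the scalar level that mixed-precision scheme is just the following (a minimal sketch with a hypothetical helper; inside the actual simdgroup MMA it would mean keeping the accumulator tile in float32 and converting around the multiply):

```metal
// Hypothetical helper: multiply in FP16, accumulate in FP32. The FP32
// running sum can't overflow at FP16's ~65504 max the way an FP16
// accumulator can over a long reduction.
inline half dot_mixed(const device half* a, const device half* b, int K) {
  float acc = 0.0f;
  for (int k = 0; k < K; ++k) {
    acc += float(a[k] * b[k]);  // product computed in FP16, widened for the add
  }
  return static_cast<half>(acc);  // narrow once at the end
}
```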

HEAD on this fork: https://github.com/arpan-dhatt/mlx

angeloskath (Member) commented

Thanks, that's good to have. For QMM we have much bigger low-hanging fruit that we should explore first (the dequantization needs to be updated like in qmv/qvm). Accumulating in float16 is risky and, as you mentioned, it may be better to accumulate in float32 but multiply in float16; we'll see.

For now I would not PR that, but if we do decide to go that way in the future I'll let you know and you can PR it then.
