
Coalesced memory read for a slightly faster LLM inference #970

Closed

Conversation

youknow04

Background

I enjoyed reading the posts related to this tweet: https://twitter.com/awnihannun/status/1776275621926375498
I'd like to contribute to the MLX side of the challenge with Llama.cpp, even if it's just a small improvement. :)

Proposed changes

This PR modifies the way scales and biases are read from memory, making the accesses more coalesced.
Although the data in these variables is laid out sparsely from the point of view of a single SIMD thread, the reads become adjacent when all SIMD threads are considered together.
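
As a minimal illustration (a made-up kernel and layout, not the actual quantized.metal code), the difference between the two access patterns looks roughly like this; groups_per_lane is a hypothetical constant used only for the sketch:

#include <metal_stdlib>
using namespace metal;

// Illustrative sketch only; not the kernel touched by this PR.
// It contrasts a per-lane strided read of quantization scales with a
// coalesced read in which the 32 lanes of a simdgroup touch consecutive
// addresses and are served by adjacent memory locations.
kernel void scales_read_demo(
    const device half* scales [[buffer(0)]],  // one scale per quantization group
    device half*       out    [[buffer(1)]],
    uint lane [[thread_index_in_simdgroup]],        // 0..31
    uint sg   [[simdgroup_index_in_threadgroup]]) {
  const int groups_per_lane = 4;  // hypothetical value, for illustration

  // Strided: lane i starts at i * groups_per_lane, so neighbouring lanes
  // read addresses 4 elements apart and the simdgroup needs several
  // memory transactions.
  half strided = scales[lane * groups_per_lane];

  // Coalesced: neighbouring lanes read neighbouring elements (in a real
  // kernel each lane's next group would sit 32 elements further ahead).
  half coalesced = scales[sg * 32 * groups_per_lane + lane];

  out[sg * 32 + lane] = strided + coalesced;  // keep both reads live
}

Either pattern reads the same data; the coalesced layout just lets the memory system serve a whole simdgroup from one contiguous region.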

Performance Benchmarks

Performance improved by around 0.2% with this change.

  • Machine: M3 Max (128 GB)
  • LLM: mistral-7b (4-bit quantized)
  • Client: https://github.com/ml-explore/mlx-examples
  • Command: python -m mlx_lm.generate --model mlx_model --prompt "write a quicksort algorithm in c++" --max-tokens 256
  • Before: main (tag: v0.9.1)
  • After: this PR

Token Processing Speed (tokens per second, higher is better, 15 runs):

Metric   Before   After    Speedup
Median   68.554   68.680   0.18%
Mean     68.557   68.734   0.26%
Min      68.210   68.592   0.56%
Max      68.800   69.087   0.42%

Checklist

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code and installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Additional Notes

  • All tests passed.
  • Applying clang-format to quantized.metal resulted in significant formatting changes. To avoid overwhelming diffs, I opted to bypass the linting step for this file.

@angeloskath
Member

I spent quite some time playing with it and it is really, really hard to figure out whether there is any performance benefit. I changed the unrolled read to a constexpr for loop (which gets rid of the static assertion) and ran 100 generations with each of four variants: the for loop in two versions (in one of which the pointer advancement is done later), this PR, and v0.9.1.
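
For reference, a minimal sketch of the two shapes being compared, using a made-up packed-weight read rather than the code in quantized.metal (values_per_thread is a stand-in for the real unroll factor):

#include <metal_stdlib>
using namespace metal;

// Illustrative sketch only; not the MLX kernel.
constant int values_per_thread = 4;  // hypothetical compile-time constant

kernel void unroll_vs_loop_demo(
    const device half* w   [[buffer(0)]],
    device float*      acc [[buffer(1)]],
    uint tid [[thread_position_in_grid]]) {
  const device half* p = w + tid * values_per_thread;

  // Variant A: manually unrolled read (the shape of a long hand-unrolled body).
  float a = float(p[0]) + float(p[1]) + float(p[2]) + float(p[3]);

  // Variant B: a loop over a compile-time bound; the compiler normally
  // unrolls this too, so both variants should lower to very similar code.
  float b = 0.0f;
  for (int i = 0; i < values_per_thread; ++i) {
    b += float(p[i]);
  }

  acc[tid] = a + b;  // keep both variants live so neither is optimized away
}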

I get the following graph.

[Plot: tokens per second for each variant across the 100 generations]

I believe that it doesn't really matter and that the variation has to do with throttling or whatever other operations are happening on the machine at the time. There is absolutely no reason the for loop should be faster than the unrolled loop, and even if it is, it is by a tiny amount.

@youknow04 let me know what you think.

@youknow04
Author

Thank you for your detailed feedback and insights. Let me address each of your points:

  • Tiny performance gain issue
    Matrix-vector multiplication is a fundamental operation and the core computational load in Transformers.
    Optimizing this primitive operation can indeed have a persistent effect on all variants.

  • Faster for loop issue
    First, I should mention that this is only my second time programming in Metal, so I don't have expertise in the Metal framework itself.
    However, I have done this kind of low-latency optimization many times, and unrolling is not always faster than a for loop, especially when the unrolled code is long.
    The code itself has to live in memory, so a larger code size means more memory overhead, such as instruction-cache misses.
    That said, I didn't expect the for loop to be faster than the unrolled version here.
    I appreciate your thorough checking.

  • Experiment issue
    I don't think it's a throttling issue, because I checked each run manually.
    I monitored the GPU while testing, only ran the LLM when GPU utilization was idle, left enough rest between runs to prevent throttling, and the cooling fan never kicked in on my M3 Max (128 GB) during any of the tests.

I think we now need the LLN (law of large numbers) to measure such a tiny gain in the LLM.
MLX is still in its early stages, and you must have a lot of work to do.
I want to support the MLX team and really don't want to take up your time for this hacky gain.
I will write a test script for a more intensive benchmark and share the results again.

@youknow04
Author

youknow04 commented Apr 14, 2024

I tried to reproduce the results using the following steps:

  • reboot my Mac
  • run sudo sysctl iogpu.wired_lwm_mb=100000
  • run the following Python script
    • the script randomizes the run order to prevent possible biases

Python code used to reproduce:
import json
import os
import random
import statistics
import subprocess
import time

from dataclasses import dataclass

MLX_LM_PATH = os.getcwd()
MLX_PATH = os.path.expanduser("~/workspace/mlx")
NUM_BATCH = 32
NUM_ITER = 999
RESULT_FILE = "result.jsonl"

@dataclass
class Target:
    branch: str
    commit_hash: str

targets = [
    Target("main", "bddf23f175726a57f0e443cd45518c0757daa166"),
    Target("coalesced-mem", "ab58e3718a3e31687e3ef5e8914034855990454f"),
    Target("coalesced-mem-unroll-t", "28057ab228de7d1b1042622559b4a1ef7ba14a12"),
    # Target("coalesced-mem-for-1", "893a6327b7ec1d2b391c6e8e27822fecbf5d0e04"),
    # Target("coalesced-mem-for-2", "af9a0269a35ce68409e4f4f5e09d6f82da55c835"),
]

def run_command(command: list[str]):
    return subprocess.check_output(command).decode("utf-8")


def setup_mlx(target: Target):
    os.chdir(MLX_PATH)
    run_command(["git", "switch", target.branch])

    print(f"installing mlx[{target.branch}]")
    run_command(["pip", "install", "."])
    pip_list_result = subprocess.check_output(["pip", "list"]).decode("utf-8")
    if f"+{target.commit_hash[:8]}" not in pip_list_result:
        raise RuntimeError(f"wrong MLX version installed. {target.branch}")
    os.chdir(MLX_LM_PATH)

def show_result():
    parsed: dict[str, list[float]] = {t.branch: [] for t in targets}
    with open(RESULT_FILE, "r") as f:
        for line in f:
            result = json.loads(line)
            for r in result:
                parsed[r["branch"]].append(float(r["tps"]))

    for branch, tps_values in parsed.items():
        mean_tps = statistics.mean(tps_values)
        median_tps = statistics.median(tps_values)
        print(f"{branch}: Mean = {mean_tps}, Median = {median_tps} Min = {min(tps_values)}, Max = {max(tps_values)}")


if __name__ == "__main__":
    os.chdir(MLX_LM_PATH)

    for i in range(NUM_ITER):
        r_samples = random.sample(targets, len(targets))
        print(f"{i}th iteration with {r_samples}")
        sample_result: list[dict[str, str|float]] = []
        for target in r_samples:
            setup_mlx(target)
            for b in range(NUM_BATCH):
                time.sleep(2)  # to prevent throttling
                llm_result = run_command(["python", "-m", "mlx_lm.generate", "--model", "mlx_mistral7b-8q", "--prompt", "'write a quicksort algorithm in c++'", "--max-tokens", "128"])
                tps = float(llm_result.split("Generation: ")[1].split(" tokens-per-sec")[0])
                sample_result.append({
                    "branch": target.branch,
                    "iter": i,
                    "batch": b,
                    "tps": tps,
                })
        with open(RESULT_FILE, "a") as f:
            f.write(json.dumps(sample_result)+"\n")
        show_result()

and got these results:

mistral7b-4bit

Configuration            Mean     Median   Min      Max
coalesced-mem-for-1      69.817   69.808   69.427   71.343
coalesced-mem-for-2      69.832   69.813   69.296   70.835
main                     69.848   69.840   69.339   71.243
coalesced-mem            69.875   69.861   69.421   70.648
coalesced-mem-unroll-t   69.966   69.941   69.500   71.985

mistral7b-8bit

Configuration            Mean     Median   Min      Max
main                     42.391   42.383   42.218   42.806
coalesced-mem            42.416   42.404   42.225   42.864
coalesced-mem-unroll-t   42.476   42.468   42.295   42.744

Result Summary:

  • After rebooting and applying the iogpu.wired_lwm_mb setting, all inference speeds increased and the gaps between configurations narrowed.
  • The rank order remained consistent.
  • The combination of this PR and Slight speed improvement for LLM inference #803 shows the best performance.
  • We now have clean results; with randomization and the LLN, throttling and other operational issues have minimal impact.
  • This PR delivers improved performance, and it is even faster when combined with Slight speed improvement for LLM inference #803.

In my opinion, complex implementations are inevitable for truly low-latency code.
We also have room for further improvement if we choose this direction.

@angeloskath let me know what you think.

@youknow04
Author

It seems you have no interest in this PR.
I will close this.

@youknow04 youknow04 closed this May 30, 2024