Coalesced memory read for a slightly faster LLM inference #970
Conversation
I spent quite some time playing with it and it is really hard to figure out whether there is any performance benefit. I changed the unrolled read to a constexpr for loop (which gets rid of the static assertion) and ran 100 generations for each of: the for loop twice (in one of the two the pointer advancement is done later), this PR, and v0.9.1. I get the following graph. I believe that it doesn't really matter and that the variation has to do with whatever throttling or other operations are happening on the machine at the time. There is absolutely no reason the for loop should be faster than the unrolled loop, and even if it is, it is by a tiny amount. @youknow04 let me know what you think.
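To make the comparison concrete, here is a minimal Metal-style sketch of the two read styles under discussion (the names and the count of 4 are illustrative assumptions, not the actual quantized.metal code): the manually unrolled read hard-codes the element count and therefore needs a static assertion, while a loop with a compile-time bound can be fully unrolled by the compiler for any count.

```metal
// Illustrative only -- not the actual MLX kernel code.

// Manually unrolled read: the count (4 here) is hard-coded, so the
// kernel needs a static_assert that the element count really is 4.
template <typename T>
inline void load_unrolled(const device T* x, thread T* dst) {
  dst[0] = x[0];
  dst[1] = x[1];
  dst[2] = x[2];
  dst[3] = x[3];
}

// Loop with a compile-time bound: the compiler can fully unroll it
// for any N, so the static assertion is no longer needed.
template <typename T, int N>
inline void load_loop(const device T* x, thread T* dst) {
  for (int i = 0; i < N; i++) {
    dst[i] = x[i];
  }
}
```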
Thank you for your detailed feedback and insights. Let me address each of your points:
I think we now need the LLN (Law of Large Numbers) to test this tiny gain in the LLM.
I tried to reproduce the results using the following steps:
Used the following Python code to reproduce:

```python
import json
import os
import random
import statistics
import subprocess
import time
from dataclasses import dataclass

MLX_LM_PATH = os.getcwd()
MLX_PATH = os.path.expanduser("~/workspace/mlx")
NUM_BATCH = 32
NUM_ITER = 999
RESULT_FILE = "result.jsonl"


@dataclass
class Target:
    branch: str
    commit_hash: str


targets = [
    Target("main", "bddf23f175726a57f0e443cd45518c0757daa166"),
    Target("coalesced-mem", "ab58e3718a3e31687e3ef5e8914034855990454f"),
    Target("coalesced-mem-unroll-t", "28057ab228de7d1b1042622559b4a1ef7ba14a12"),
    # Target("coalesced-mem-for-1", "893a6327b7ec1d2b391c6e8e27822fecbf5d0e04"),
    # Target("coalesced-mem-for-2", "af9a0269a35ce68409e4f4f5e09d6f82da55c835"),
]


def run_command(command: list[str]):
    return subprocess.check_output(command).decode("utf-8")


def setup_mlx(target: Target):
    os.chdir(MLX_PATH)
    run_command(["git", "switch", target.branch])
    print(f"installing mlx[{target.branch}]")
    run_command(["pip", "install", "."])
    pip_list_result = subprocess.check_output(["pip", "list"]).decode("utf-8")
    if f"+{target.commit_hash[:8]}" not in pip_list_result:
        raise RuntimeError(f"wrong MLX version installed. {target.branch}")
    os.chdir(MLX_LM_PATH)


def show_result():
    parsed: dict[str, list[float]] = {t.branch: [] for t in targets}
    with open(RESULT_FILE, "r") as f:
        for line in f:
            result = json.loads(line)
            for r in result:
                parsed[r["branch"]].append(float(r["tps"]))
    for branch, tps_values in parsed.items():
        mean_tps = statistics.mean(tps_values)
        median_tps = statistics.median(tps_values)
        print(f"{branch}: Mean = {mean_tps}, Median = {median_tps}, Min = {min(tps_values)}, Max = {max(tps_values)}")


if __name__ == "__main__":
    os.chdir(MLX_LM_PATH)
    for i in range(NUM_ITER):
        # Run the targets in a random order each iteration so that slow
        # drifts (thermals, background load) do not bias one branch.
        r_samples = random.sample(targets, len(targets))
        print(f"{i}th iteration with {r_samples}")
        sample_result: list[dict[str, str | float]] = []
        for target in r_samples:
            setup_mlx(target)
            for b in range(NUM_BATCH):
                time.sleep(2)  # to prevent throttling
                llm_result = run_command(["python", "-m", "mlx_lm.generate", "--model", "mlx_mistral7b-8q", "--prompt", "'write a quicksort algorithm in c++'", "--max-tokens", "128"])
                tps = float(llm_result.split("Generation: ")[1].split(" tokens-per-sec")[0])
                sample_result.append({
                    "branch": target.branch,
                    "iter": i,
                    "batch": b,
                    "tps": tps,
                })
        with open(RESULT_FILE, "a") as f:
            f.write(json.dumps(sample_result) + "\n")
        show_result()
```

and got this result:
Result Summary:
In my opinion, for truly low-latency code, complex implementations are inevitable. @angeloskath let me know what you think.
It seems you have no interest in this PR.
Background
I enjoyed reading the posts related to this tweet: https://twitter.com/awnihannun/status/1776275621926375498
I'd like to contribute to the MLX side of the challenge with Llama.cpp, even if it's just a small improvement. :)
Proposed changes
This PR modifies the way `scales` and `biases` are read from memory, making the access more coalesced. Although the data in these variables are sparsely arranged from the point of view of a single SIMD thread, they can be read in a more adjacent manner when considering all SIMD threads together.
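As a rough illustration (with assumed names and sizes: `lane_id`, a simdgroup of 32 lanes, four elements per lane; this is not the actual kernel code), the difference between the two access patterns looks like this. In the strided pattern each lane reads its own contiguous run, so at any given step neighboring lanes touch addresses four elements apart; in the coalesced pattern the loop is reordered so that at each step neighboring lanes touch adjacent addresses, which the memory system can service with fewer transactions.

```metal
// Illustrative Metal-style sketch with assumed names; not the actual
// quantized.metal code.
constant int SIMD_SIZE = 32;  // lanes per simdgroup on Apple GPUs
constant int PER_LANE = 4;    // elements each lane consumes

// Strided (uncoalesced): at step j, lane k reads scales[k * PER_LANE + j],
// so neighboring lanes touch addresses PER_LANE elements apart.
inline void read_strided(const device float* scales,
                         thread float* s,
                         uint lane_id) {
  for (int j = 0; j < PER_LANE; j++) {
    s[j] = scales[lane_id * PER_LANE + j];
  }
}

// Coalesced: at step j, lane k reads scales[j * SIMD_SIZE + k], so
// neighboring lanes touch adjacent addresses in the same step. Each lane
// now ends up holding different values, so the surrounding computation
// has to be rearranged to match.
inline void read_coalesced(const device float* scales,
                           thread float* s,
                           uint lane_id) {
  for (int j = 0; j < PER_LANE; j++) {
    s[j] = scales[j * SIMD_SIZE + lane_id];
  }
}
```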
Performance Benchmarks
Performance improved by around 0.2% with this change.

```
python -m mlx_lm.generate --model mlx_model --prompt "write a quicksort algorithm in c++" --max-tokens 256
```
Token Processing Speed (tokens per second, higher is better, 15 runs):
Checklist
I have run `pre-commit run --all-files` to format my code and installed pre-commit prior to committing changes
Additional Notes
Applying `clang-format` to `quantized.metal` resulted in significant formatting changes. To avoid an overwhelming diff, I opted to bypass the linting step for this file.