-
Notifications
You must be signed in to change notification settings - Fork 175
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
perf: bf16 kernels #1664
base: main
Are you sure you want to change the base?
perf: bf16 kernels #1664
Conversation
8231900
to
1578a24
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't quite understand the instructions involved, but it looks like you are frequently casting / reinterpreting bfloat16
as normal float16
and using float16
add / subtract instructions, which I'm worried isn't valid.
32447d1
to
88bd113
Compare
For bf16, we have https://doc.rust-lang.org/core/arch/x86_64/fn._mm512_dpbf16_ps.html , should we just use rust code instead? L2 = dot(x, x) + dot(y, y) - 2 * dot(x, y) |
Cosine = 1 - dot(x, y) This would indeed simplify things a lot, however I would compare performance to see if loading from memory (e.g. with |
Why we need to load 3 times in rust? Not sure i follow. |
@eddyxu Oh sorry I misunderstood your question. I thought you were suggesting calling the dot kernel to calculate L2. |
9ba418f
to
b41de61
Compare
b41de61
to
c1c3a16
Compare
@eddyxu I've looked at using |
See #1651