PRE-ROPE quantization during inference #1

Closed
minghaoBD opened this issue Feb 21, 2024 · 1 comment

minghaoBD commented Feb 21, 2024

Thanks for the great work! I am curious about the time complexity of pre-RoPE quantization.

In detail, I assume the operations occur in the following order with pre-RoPE quantization during inference: qkv_projection_matmul -> quantize_k -> write_cache_k -> load_cache_k -> dequantize_k -> rope_k -> transpose_k. However, in the decode phase the sequence length grows by one each step, so rope_k has to be applied to all of the previous token features at every step. Summed over the generation, that is O(m*m) time, where m is the sequence length.
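
To make that concrete, here is a minimal PyTorch sketch of the key path I have in mind (the per-tensor 4-bit fake quantizer and all function names here are just my own illustration, not this repo's code):

```python
import torch

def quantize_k(k, scale):
    # pre-RoPE quantization: keys go into the cache *before* the rotary embedding
    return torch.clamp(torch.round(k / scale), -8, 7).to(torch.int8)

def dequantize_k(k_q, scale):
    return k_q.to(torch.float32) * scale

def rotate_half(x):
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

def rope(x, positions, base=10000.0):
    # rotary embedding in the rotate_half convention; x: [m, d]
    d = x.shape[-1]
    inv_freq = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)
    angles = positions[:, None].float() * inv_freq[None, :]   # [m, d/2]
    cos = torch.cat((angles.cos(), angles.cos()), dim=-1)     # [m, d]
    sin = torch.cat((angles.sin(), angles.sin()), dim=-1)
    return (x.float() * cos + rotate_half(x.float()) * sin).to(x.dtype)

def decode_step_keys(k_cache_q, k_new, scale):
    # write the new pre-RoPE key into the quantized cache ...
    k_cache_q = torch.cat([k_cache_q, quantize_k(k_new, scale)], dim=0)
    # ... then dequantize and re-apply RoPE to *all* m cached keys, every step
    m = k_cache_q.shape[0]
    k = rope(dequantize_k(k_cache_q, scale), torch.arange(m))
    return k_cache_q, k   # k is [m, d], ready for the K^T q matvec
```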

This differs from the post-RoPE case, where the cache already holds post-RoPE quantized keys, so the time complexity is O(m).

One workaround is to save the RoPE result in a second cache (sketched below), which brings the time complexity back to O(m) but costs much more storage. Another option, I suppose, is to overwrite the cache with the post-RoPE keys (bfloat16/float16), but that would conflict with the default cache dtype (INT4/INT2).
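
For the first workaround, a rough sketch reusing the illustrative helpers above (again my own illustration, not the repo's code):

```python
def decode_step_keys_two_caches(k_cache_q, k_cache_rope, k_new, scale, pos):
    # keep the quantized pre-RoPE cache, plus an fp16 post-RoPE cache;
    # RoPE is now applied once per new token (O(d) per step), at the cost
    # of storing the keys twice
    k_cache_q = torch.cat([k_cache_q, quantize_k(k_new, scale)], dim=0)
    k_new_rope = rope(k_new, torch.tensor([pos])).to(torch.float16)
    k_cache_rope = torch.cat([k_cache_rope, k_new_rope], dim=0)
    return k_cache_q, k_cache_rope   # attention reads k_cache_rope
```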

Please correct me if anything above is wrong. Looking forward to your reply. Thanks.

@chooper1 (Collaborator)

Thank you for your interest in our work! It is true that we need to apply RoPE to all of the previous token features at each timestep. Applying RoPE is an O(m*d) operation, where m is the sequence length and d is the hidden dimension (we use the element-wise formulation highlighted in Appendix A of our preprint). Although this operation still runs at every timestep, loading the key cache for the matrix-vector multiplication is memory-bandwidth-bound, so we find we can overlap the added computation with loading data from memory and avoid latency overheads. Let me know if you have any further questions!
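
As a rough illustration (reusing the illustrative `rotate_half` helper from the sketch above, written in the standard rotate_half convention rather than the exact form in Appendix A), the arithmetic of such a single pass over the key cache might look like the following; this only models the computation, not the actual fused implementation:

```python
def fused_dequant_rope_scores(k_cache_q, scale, q_rot, base=10000.0):
    # conceptually one pass over the key cache: dequantize, apply RoPE
    # element-wise (two multiplies + one add per element, O(m*d) total),
    # and compute the attention logits K q; since the cache load is
    # memory-bound, the extra arithmetic can be hidden behind it
    m, d = k_cache_q.shape
    inv_freq = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)
    pos = torch.arange(m, dtype=torch.float32)[:, None]
    angles = pos * inv_freq[None, :]                          # [m, d/2]
    cos = torch.cat((angles.cos(), angles.cos()), dim=-1)     # [m, d]
    sin = torch.cat((angles.sin(), angles.sin()), dim=-1)
    k = k_cache_q.to(torch.float32) * scale                   # dequantize
    k = k * cos + rotate_half(k) * sin                        # element-wise RoPE
    return k @ q_rot.float()                                  # [m] attention logits
```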
