Thanks for the great work! I am curious about the time complexity of pre-RoPE quantization.
Specifically, I assume the operations occur in the following order during inference with pre-RoPE quantization: qkv_projection_matmul -> quantize_k -> write_cache_k -> load_cache_k -> dequantize_k -> rope_k -> transpose_k. However, in the decode phase the sequence length grows by one per step, so rope_k has to be applied to all previous token features at every step. Summed over a generation of length m, that is O(m*m) work (O(m) RoPE applications per step), where m is the sequence length.
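To make the assumed ordering concrete, here is a toy PyTorch sketch of one decode step. The quantize/dequantize/apply_rope helpers are illustrative stand-ins I made up for this example, not the repo's actual kernels or API:

```python
import torch

def quantize(k, n_bits=4):
    # Toy per-tensor uniform quantizer standing in for the INT4/INT2 codec
    # (hypothetical helper, not the repo's kernel).
    scale = k.abs().max().clamp(min=1e-8) / (2 ** (n_bits - 1) - 1)
    return torch.round(k / scale).to(torch.int8), scale

def dequantize(k_q, scale):
    return k_q.to(torch.float32) * scale

def apply_rope(keys, pos):
    # Element-wise RoPE: rotate each consecutive dim pair by a
    # position-dependent angle.
    d = keys.shape[-1]
    theta = 10000.0 ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)
    ang = pos[:, None] * theta[None, :]          # (m, d/2) angles
    cos, sin = ang.cos(), ang.sin()
    k1, k2 = keys[..., 0::2], keys[..., 1::2]
    out = torch.empty_like(keys)
    out[..., 0::2] = k1 * cos - k2 * sin
    out[..., 1::2] = k1 * sin + k2 * cos
    return out

def decode_step(q, k_new, cache):
    # cache: list of (quantized_key, scale) pairs, one per previous token,
    # stored pre-RoPE.
    cache.append(quantize(k_new))                             # quantize_k -> write_cache_k
    keys = torch.stack([dequantize(kq, s) for kq, s in cache])  # load_cache_k -> dequantize_k
    m = keys.shape[0]
    keys = apply_rope(keys, torch.arange(m, dtype=torch.float32))  # rope_k on ALL m cached keys
    return q @ keys.T                                         # transpose_k -> attention scores
```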
This differs from the post-RoPE case: there, the cache already holds post-RoPE quantized keys, so the RoPE cost is O(1) per step, i.e., O(m) overall.
One workaround is to save the RoPE results in a separate cache, which restores O(m) time but costs substantially more storage. Another option I can think of is to overwrite the cache with the post-RoPE keys (bfloat16/float16), but that conflicts with the default cache dtype (INT4/INT2).
Please correct me if anything above is wrong. Looking forward to your reply. Thanks!
Thank you for your interest in our work! It is true that we need to apply RoPE to all of the previous token features at each time step. Applying RoPE is an O(m*d) operation, where m is the sequence length and d is the hidden dimension (as we are using the element-wise formulation highlighted in Appendix A of our preprint). Although we still need to apply this operation at each time step, loading the key cache for the matrix-vector multiplication is memory-bandwidth-bound, so we find that we can overlap the added computation with the data loads from memory and avoid latency overheads. Let me know if you have any further questions!
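For reference, here is a plain-Python sketch that makes the O(m*d) count explicit: one pass over the m cached tokens, d/2 rotations per token. This only illustrates the element-wise formulation; it is not the repository's fused kernel code, and the loop nest is what would be hidden under the memory-bound cache load in an actual kernel:

```python
import math

def fused_dequant_rope(k_cache_q, scale, d):
    """Illustrative fused pass: for each of the m cached keys, dequantize and
    rotate each of the d/2 dim pairs -- exactly m * d/2 rotations, i.e.
    O(m*d) multiply-adds. In a real kernel this arithmetic happens while the
    quantized cache is streamed from memory, so it hides behind the
    memory-bandwidth-bound load of the matrix-vector product."""
    m = len(k_cache_q)
    out = [[0.0] * d for _ in range(m)]
    for pos in range(m):                      # one pass over the m cached tokens
        for i in range(d // 2):               # d/2 dim pairs per token
            theta = pos * (10000.0 ** (-2.0 * i / d))
            c, s = math.cos(theta), math.sin(theta)
            a = k_cache_q[pos][2 * i] * scale      # dequantize_k
            b = k_cache_q[pos][2 * i + 1] * scale
            out[pos][2 * i] = a * c - b * s        # rope_k, element-wise
            out[pos][2 * i + 1] = a * s + b * c
    return out
```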