Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support more general inference case that query length > 1 #730

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

yidong72
Copy link

@yidong72 yidong72 commented Mar 12, 2024

Currently the inference requires the query tensor has length 1. However, there are use cases that the query tensor length > 1.
Note, this fix requires

  1. TE to use the flash_atten > 2.1 so the correct attention mask is applied according to this document.
  2. TE to enable flash attention by removing this line https://github.com/NVIDIA/TransformerEngine/blob/0fbc76af3733ae997394eaf82b78ff9c0498fe91/transformer_engine/pytorch/attention.py#L2732
  3. Latest PyTorch version that has the following bug fix so the batch size 1 value tensor has correct stride after transpose ops, which is required for flash_attn. Or there is a work around by using tensor copy.
--- a/megatron/core/transformer/custom_layers/transformer_engine.py
+++ b/megatron/core/transformer/custom_layers/transformer_engine.py
@@ -457,6 +457,10 @@ class TEDotProductAttention(te.pytorch.DotProductAttention):
 
         if self.config.apply_rope_fusion and qkv_format == 'bshd':
             query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]
+        
+        new_value = torch.zeros_like(key)
+        new_value[:] = value
+        value = new_value
 
         if self.te_forward_mask_type:
             core_attn_out = super().forward(

Signed-off-by: Yi Dong <yidong@nvidia.com>
Copy link

Marking as stale. No activity in 60 days.

@github-actions github-actions bot added the stale No activity in 60 days on issue or PR label May 11, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
stale No activity in 60 days on issue or PR
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant