Fix sdp logic (#10896)
* fix

* fix
gc-fu committed Apr 28, 2024
1 parent 015d07a commit c9fac8c
Showing 1 changed file with 25 additions and 7 deletions.

python/llm/src/ipex_llm/transformers/models/llama.py
@@ -1353,10 +1353,31 @@ def llama_attention_forward_4_36_original(
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
     # otherwise, use native attention
-    attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                           attention_mask,
-                                           bsz, q_len, kv_seq_len,
-                                           self.head_dim, self.num_heads, output_attentions)
+    if query_states.device.type == "xpu":
+        attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
+                                               attention_mask,
+                                               bsz, q_len, kv_seq_len,
+                                               self.head_dim, self.num_heads, output_attentions)
+    else:
+        # CPU path
+        if not output_attentions:
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+                dropout_p=self.attention_dropout if self.training else 0.0,
+                # The q_len > 1 is necessary to match with
+                # AttentionMaskConverter.to_causal_4d that
+                # does not create a causal mask in case q_len == 1.
+                is_causal=self.is_causal and attention_mask is None and q_len > 1,
+            )
+        else:
+            attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
+                                                   attention_mask,
+                                                   bsz, q_len, kv_seq_len,
+                                                   self.head_dim,
+                                                   self.num_heads, output_attentions)

     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
     if attn_output.size() != attn_output_size:
@@ -1778,9 +1799,6 @@ def llama_model_forward_4_36_internal(
     if inputs_embeds is None:
         inputs_embeds = self.embed_tokens(input_ids)

-    # IPEX-LLM modifications:
-    # Disable sdpa for CPU
-    self._use_sdpa = False
     if self._use_flash_attention_2:
         # 2d mask is passed through the layers
         attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) \
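The change routes attention by device type: on XPU the module keeps ipex-llm's native_sdp helper, while on CPU it now calls PyTorch's fused torch.nn.functional.scaled_dot_product_attention whenever attention weights are not requested (the fused kernel does not return them). The second hunk removes the line that force-disabled SDPA on CPU, so transformers' own _use_sdpa detection takes effect again. Below is a minimal, self-contained sketch of this dispatch pattern; the function name, tensor shapes, and the hand-rolled softmax fallback (standing in for native_sdp) are illustrative assumptions, not the ipex-llm implementation.

# sdp_dispatch_sketch.py -- illustrative only; names and shapes are hypothetical.
import math

import torch


def sdp_dispatch(query, key, value, attention_mask=None, output_attentions=False):
    # query/key/value: (bsz, num_heads, seq_len, head_dim)
    if query.device.type != "xpu" and not output_attentions:
        # Fused-kernel path, as on the commit's CPU branch; it cannot
        # return the attention weights, hence the output_attentions guard.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask)
        return attn_output, None
    # Explicit-softmax fallback (stand-in for ipex-llm's native_sdp),
    # used on XPU or when the caller needs the attention weights.
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if attention_mask is not None:
        scores = scores + attention_mask
    attn_weights = torch.softmax(scores, dim=-1)
    return attn_weights @ value, attn_weights


q = k = v = torch.randn(1, 8, 16, 64)
out, weights = sdp_dispatch(q, k, v, output_attentions=True)
print(out.shape, weights.shape)  # (1, 8, 16, 64) and (1, 8, 16, 16)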
