diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index e542c6c5822..077faf50c5b 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -1346,10 +1346,31 @@ def llama_attention_forward_4_36_original(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
         # otherwise, use native attention
-        attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                               attention_mask,
-                                               bsz, q_len, kv_seq_len,
-                                               self.head_dim, self.num_heads, output_attentions)
+        if query_states.device.type == "xpu":
+            attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
+                                                   attention_mask,
+                                                   bsz, q_len, kv_seq_len,
+                                                   self.head_dim, self.num_heads, output_attentions)
+        else:
+            # CPU path
+            if not output_attentions:
+                attn_output = torch.nn.functional.scaled_dot_product_attention(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attn_mask=attention_mask,
+                    dropout_p=self.attention_dropout if self.training else 0.0,
+                    # The q_len > 1 is necessary to match with
+                    # AttentionMaskConverter.to_causal_4d that
+                    # does not create a causal mask in case q_len == 1.
+                    is_causal=self.is_causal and attention_mask is None and q_len > 1,
+                )
+            else:
+                attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
+                                                       attention_mask,
+                                                       bsz, q_len, kv_seq_len,
+                                                       self.head_dim,
+                                                       self.num_heads, output_attentions)
 
     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
     if attn_output.size() != attn_output_size:
@@ -1773,9 +1794,6 @@ def llama_model_forward_4_36_internal(
     if inputs_embeds is None:
         inputs_embeds = self.embed_tokens(input_ids)
 
-    # IPEX-LLM modifications:
-    # Disable sdpa for CPU
-    self._use_sdpa = False
     if self._use_flash_attention_2:
         # 2d mask is passed through the layers
         attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) \
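
For reference, the standalone sketch below (not part of the patch; tensor names and sizes are illustrative) shows why the CPU branch can take the fused torch.nn.functional.scaled_dot_product_attention path when output_attentions is False: the fused kernel computes the same softmax(Q K^T / sqrt(d)) V result as an explicit native_sdp-style implementation, it simply does not return the intermediate attention weights, which is why the explicit path is kept for output_attentions=True.

# Minimal sketch, assuming PyTorch >= 2.0 on CPU; shapes are illustrative.
import math
import torch

bsz, num_heads, q_len, head_dim = 1, 4, 8, 16
query = torch.randn(bsz, num_heads, q_len, head_dim)
key = torch.randn(bsz, num_heads, q_len, head_dim)
value = torch.randn(bsz, num_heads, q_len, head_dim)

# Fused path: causal masking is handled inside the kernel
# (is_causal=True is only used when no explicit attn_mask is
# supplied and q_len > 1, mirroring the patched CPU branch).
fused = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True)

# Manual reference, analogous to an explicit native_sdp-style fallback:
# scores, causal mask, softmax, then weighted sum over values.
scores = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
causal = torch.triu(torch.full((q_len, q_len), float("-inf")), diagonal=1)
attn_weights = torch.softmax(scores + causal, dim=-1)
manual = torch.matmul(attn_weights, value)

print(torch.allclose(fused, manual, atol=1e-5))  # expected: True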