From d193616e4f0cf6e0ff75d1b5200071e180c7be78 Mon Sep 17 00:00:00 2001
From: gc-fu
Date: Fri, 26 Apr 2024 17:26:54 +0800
Subject: [PATCH 1/2] fix

---
 .../src/ipex_llm/transformers/models/llama.py | 31 ++++++++++++++-----
 1 file changed, 24 insertions(+), 7 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index e542c6c5822..285d8736442 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -1346,10 +1346,30 @@ def llama_attention_forward_4_36_original(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
         # otherwise, use native attention
-        attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                               attention_mask,
-                                               bsz, q_len, kv_seq_len,
-                                               self.head_dim, self.num_heads, output_attentions)
+        if query_states.device.type == "xpu":
+            attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
+                                                   attention_mask,
+                                                   bsz, q_len, kv_seq_len,
+                                                   self.head_dim, self.num_heads, output_attentions)
+        else:
+            # CPU path
+            if not output_attentions:
+                attn_output = torch.nn.functional.scaled_dot_product_attention(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attn_mask=attention_mask,
+                    dropout_p=self.attention_dropout if self.training else 0.0,
+                    # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that
+                    # does not create a causal mask in case q_len == 1.
+                    is_causal=self.is_causal and attention_mask is None and q_len > 1,
+                )
+            else:
+                attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
+                                                       attention_mask,
+                                                       bsz, q_len, kv_seq_len,
+                                                       self.head_dim,
+                                                       self.num_heads, output_attentions)
 
     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
     if attn_output.size() != attn_output_size:
@@ -1773,9 +1793,6 @@ def llama_model_forward_4_36_internal(
     if inputs_embeds is None:
         inputs_embeds = self.embed_tokens(input_ids)
 
-    # IPEX-LLM modifications:
-    # Disable sdpa for CPU
-    self._use_sdpa = False
     if self._use_flash_attention_2:
         # 2d mask is passed through the layers
         attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) \

From 1923ebac4c9517c8f9f5143d38958b2e4849ec3a Mon Sep 17 00:00:00 2001
From: gc-fu
Date: Fri, 26 Apr 2024 17:43:01 +0800
Subject: [PATCH 2/2] fix

---
 python/llm/src/ipex_llm/transformers/models/llama.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 285d8736442..077faf50c5b 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -1360,7 +1360,8 @@ def llama_attention_forward_4_36_original(
                     value_states,
                     attn_mask=attention_mask,
                     dropout_p=self.attention_dropout if self.training else 0.0,
-                    # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that
+                    # The q_len > 1 is necessary to match with
+                    # AttentionMaskConverter.to_causal_4d that
                     # does not create a causal mask in case q_len == 1.
                     is_causal=self.is_causal and attention_mask is None and q_len > 1,
                 )