Fixes Chat History and Function Calling Agent.
Maximilian-Winter committed on May 14, 2024 · commit de88f4a · parent e346167
Showing 2 changed files with 11 additions and 5 deletions.
src/llama_cpp_agent/chat_history/basic_chat_history.py (8 additions, 3 deletions)
@@ -112,9 +112,14 @@ def get_message_store(self) -> BasicChatMessageStore:
 
     def get_chat_messages(self) -> List[Dict[str, str]]:
         if self.strategy == BasicChatHistoryStrategy.last_k_messages:
-            messages = [self.message_store.get_message(0)]
-            messages.extend(self.message_store.get_last_k_messages(self.k - 1))
-            return convert_messages_to_list_of_dictionaries(messages)
+            converted_messages = convert_messages_to_list_of_dictionaries(
+                self.message_store.get_last_k_messages(self.k - 1)
+            )
+            if len(converted_messages) == self.k and converted_messages[0]["role"] != "system":
+                messages = [convert_messages_to_list_of_dictionaries(self.message_store.get_message(0))]
+                messages.extend(converted_messages[1:])
+                return messages
+            return converted_messages
         elif self.strategy == BasicChatHistoryStrategy.last_k_tokens:
             total_tokens = 0
             selected_messages = []
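What this hunk fixes: the old code always prepended message 0 to the last k-1 messages, which duplicated the system prompt whenever it was still inside the window; the new code trims first and re-inserts the first message only when the window is full and no system message leads it. A minimal sketch of that trimming behavior on plain role/content dictionaries (the function name and standalone shape here are illustrative, not the library's API):

from typing import Dict, List

def last_k_keeping_system(history: List[Dict[str, str]], k: int) -> List[Dict[str, str]]:
    # Take the k most recent messages.
    trimmed = history[-k:]
    # If truncation dropped a leading system prompt, pin it back at slot 0,
    # giving up the oldest message in the window to stay at k entries.
    if len(history) > k and history[0]["role"] == "system" and trimmed[0]["role"] != "system":
        return [history[0]] + trimmed[1:]
    return trimmed

history = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello."},
    {"role": "user", "content": "Weather?"},
]
print(last_k_keeping_system(history, 3))
# -> the system prompt, then the two most recent messages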
src/llama_cpp_agent/function_calling_agent.py (3 additions, 2 deletions)
@@ -7,6 +7,7 @@
 from llama_cpp import Llama
 from pydantic import BaseModel, Field
 
+from .chat_history.messages import Roles
 from .llm_output_settings import LlmStructuredOutputSettings, LlmStructuredOutputType
 
 from .llm_agent import LlamaCppAgent, StreamingResponse
@@ -224,7 +225,7 @@ def generate_response(
         llm_sampling_settings: LlmSamplingSettings = None,
         structured_output_settings: LlmStructuredOutputSettings = None,
     ):
-        self.llama_cpp_agent.add_message(role="user", message=message)
+        self.llama_cpp_agent.add_message(role=Roles.user, message=message)
 
         result = self.intern_get_response(llm_sampling_settings=llm_sampling_settings)
 
@@ -253,7 +254,7 @@ def generate_response(
                 else:
                     function_message += f"{count}. " + res + "\n\n"
                 self.llama_cpp_agent.add_message(
-                    role="tool", message=function_message.strip()
+                    role=Roles.tool, message=function_message.strip()
                 )
                 if agent_sent_message:
                     break
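Why the Roles import was added: function_calling_agent.py previously passed raw strings ("user", "tool") to add_message, and this commit switches both call sites to the Roles type from chat_history.messages. A hedged sketch of the pattern — the diff confirms only that Roles.user and Roles.tool exist, so the enum definition and the stand-in add_message below are assumptions:

from enum import Enum

class Roles(Enum):
    # Assumed member set; only user and tool are confirmed by the diff.
    system = "system"
    user = "user"
    assistant = "assistant"
    tool = "tool"

def add_message(role: Roles, message: str) -> dict:
    # Stand-in for LlamaCppAgent.add_message: an enum member narrows the
    # accepted roles and still yields the plain string a chat template needs.
    return {"role": role.value, "content": message}

print(add_message(role=Roles.user, message="What is the weather?"))
print(add_message(role=Roles.tool, message="1. get_weather returned: sunny"))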
