Skip to content

Commit

Permalink
Merge pull request #813 from Pingdred/chat_models_continue
Browse files Browse the repository at this point in the history
Chat models continue
  • Loading branch information
Pingdred committed May 10, 2024
2 parents 2793e78 + 2fb3b52 commit 3860a48
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 48 deletions.
5 changes: 1 addition & 4 deletions core/cat/experimental/form/cat_form.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def confirm(self) -> bool:
User said "{user_message}"
JSON:
```json
{{
"confirm": """

Expand Down Expand Up @@ -105,7 +104,6 @@ def check_exit_intent(self) -> bool:
User said "{user_message}"
JSON:
```json
{{
"exit": """

Expand Down Expand Up @@ -227,7 +225,7 @@ def extract(self):
verbose = True,
output_key = "output"
)
json_str = extraction_chain.invoke({"stop": ["```"]})["output"]
json_str = extraction_chain.invoke({})["output"] #{"stop": ["```"]}

log.debug(f"Form JSON after parser:\n{json_str}")

Expand Down Expand Up @@ -272,7 +270,6 @@ def extraction_prompt(self):
{history}
Updated JSON:
```json
"""

# TODO: convo example (optional but supported)
Expand Down
88 changes: 46 additions & 42 deletions core/cat/looking_glass/agent_manager.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,30 @@
import json
import os
import random
import time
import json
import random
import traceback
from datetime import timedelta
from typing import List, Dict

from copy import deepcopy
from typing import List, Dict
from datetime import timedelta

from cat.convo.messages import Role
from langchain_core.utils import get_colored_text
from langchain.agents import AgentExecutor
from langchain.docstore.document import Document
from langchain.prompts import ChatPromptTemplate
from langchain.chains.llm import LLMChain
from langchain.agents import AgentExecutor
from langchain_core.runnables import RunnableConfig
from langchain_core.prompts.chat import SystemMessagePromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain.callbacks.tracers import ConsoleCallbackHandler

from cat.mad_hatter.plugin import Plugin
from cat.mad_hatter.mad_hatter import MadHatter
from cat.mad_hatter.decorators.tool import CatTool
from cat.looking_glass import prompts
from cat.looking_glass.output_parser import ChooseProcedureOutputParser
from cat.utils import verbal_timedelta
from cat.log import log
from langchain_core.runnables import RunnableConfig
from cat.looking_glass.callbacks import NewTokenHandler
from cat.experimental.form import CatForm, CatFormState
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.utils import get_colored_text


class AgentManager:
Expand Down Expand Up @@ -92,7 +90,8 @@ async def execute_procedures_agent(self, agent_input, stray):
template=self.mad_hatter.execute_hook(
"agent_prompt_instructions", prompts.TOOL_PROMPT, cat=stray
)
),
),
# *(stray.langchainfy_chat_history())
]
)

Expand All @@ -114,17 +113,23 @@ def examples():
# Add example
list_examples += f"\n```json\n{example_json}\n```"

list_examples += """```json
{{
"action": "final_answer",
"action_input": null
}}
```"""
return list_examples

prompt = prompt.partial(
tools="\n".join(
f"- {tool.name}: {tool.description}" for tool in allowed_tools
f"- {tool.name}: {tool.description}" for tool in allowed_procedures.values()
),
tool_names=", ".join(allowed_procedures.keys()),
agent_scratchpad="",
chat_history=stray.stringify_chat_history(),
examples=examples(),
)
llm_with_stop = stray._llm

def scratchpad(x):
thoughts = ""
Expand All @@ -136,8 +141,7 @@ def scratchpad(x):
"""
return thoughts

def logging(x):

def logging(x):
print("\n",get_colored_text(x.to_string(),"green"))
return x

Expand All @@ -147,7 +151,7 @@ def logging(x):
)
| prompt
| RunnableLambda(lambda x: logging(x))
| llm_with_stop
| stray._llm
| ChooseProcedureOutputParser()
)

Expand Down Expand Up @@ -202,33 +206,33 @@ async def execute_form_agent(self, stray):
async def execute_memory_chain(
self, agent_input, prompt_prefix, prompt_suffix, stray
):
chat_history = []
for message in stray.working_memory.history:
if message["role"] == Role.Human:
chat_history.append(HumanMessage(content=message["message"]))
else:
chat_history.append(AIMessage(content=message["message"]))

final_prompt = ChatPromptTemplate(
messages=[
SystemMessagePromptTemplate.from_template(
template=prompt_prefix + prompt_suffix
),
*chat_history,
]
)
final_prompt = ChatPromptTemplate(
messages=[
SystemMessagePromptTemplate.from_template(
template=prompt_prefix + prompt_suffix
),
*(stray.langchainfy_chat_history())
]
)

def logging(x):
#The names are not shown in the chat history log, the model however receives the name correctly
log.info("The names are not shown in the chat history log, the model however receives the name correctly")
print("\n",get_colored_text(x.to_string(),"green"))
return x

memory_chain = LLMChain(
prompt=final_prompt,
llm=stray._llm,
verbose=self.verbose,
output_key="output",
# return_final_only=False
memory_chain = (
final_prompt
| RunnableLambda(lambda x: logging(x))
| stray._llm
)

return memory_chain.invoke(
output = memory_chain.invoke(
agent_input, config=RunnableConfig(callbacks=[NewTokenHandler(stray)])
)
agent_input["output"] = output.content

return agent_input

async def execute_agent(self, stray):
"""Instantiate the Agent with tools.
Expand Down Expand Up @@ -345,13 +349,13 @@ def format_agent_input(self, stray):
)

# format conversation history to be inserted in the prompt
conversation_history_formatted_content = stray.stringify_chat_history()
#conversation_history_formatted_content = stray.stringify_chat_history()

return {
"input": stray.working_memory.user_message_json.text, # TODO: deprecate, since it is included in chat history
"episodic_memory": episodic_memory_formatted_content,
"declarative_memory": declarative_memory_formatted_content,
"chat_history": conversation_history_formatted_content,
#"chat_history": conversation_history_formatted_content,
"tools_output": "",
}

Expand Down
2 changes: 1 addition & 1 deletion core/cat/looking_glass/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def format(self, **kwargs) -> str:
## Actions sequence used until now:
{agent_scratchpad}
## Next action:
# Next Action to perform or final_answare:
"""


Expand Down
15 changes: 14 additions & 1 deletion core/cat/looking_glass/stray_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from langchain.docstore.document import Document
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_community.llms import BaseLLM
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage

from fastapi import WebSocket

from cat.log import log
from cat.looking_glass.cheshire_cat import CheshireCat
from cat.looking_glass.callbacks import NewTokenHandler
from cat.memory.working_memory import WorkingMemory
from cat.convo.messages import CatMessage, UserMessage, MessageWhy
from cat.convo.messages import CatMessage, UserMessage, MessageWhy, Role

MSG_TYPES = Literal["notification", "chat", "error", "chat_token"]

Expand Down Expand Up @@ -503,6 +504,18 @@ def stringify_chat_history(self, latest_n: int = 5) -> str:
history_string += f"\n - {turn['who']}: {turn['message']}"

return history_string

def langchainfy_chat_history(self, latest_n: int = 5) -> List[BaseMessage]:
    """Convert the most recent conversation turns into LangChain message objects.

    Reads the last `latest_n` entries of `self.working_memory.history` and maps
    each turn to a `HumanMessage` when its role is `Role.Human`, otherwise to an
    `AIMessage`. The turn's "who" field becomes the message `name` and the
    "message" field becomes its `content`.

    Parameters
    ----------
    latest_n : int
        How many of the most recent history turns to convert (default 5).

    Returns
    -------
    List[BaseMessage]
        LangChain messages in chronological order.
    """
    recent_turns = self.working_memory.history[-latest_n:]

    converted: List[BaseMessage] = []
    for turn in recent_turns:
        # Human turns map to HumanMessage; everything else is treated as the AI.
        message_cls = HumanMessage if turn["role"] == Role.Human else AIMessage
        converted.append(message_cls(name=turn["who"], content=turn["message"]))

    return converted

@property
def user_id(self):
Expand Down

0 comments on commit 3860a48

Please sign in to comment.