[HUGE PR] Chat Messages for ChatModels #783

Merged
merged 7 commits on May 10, 2024
46 changes: 39 additions & 7 deletions core/cat/convo/messages.py
@@ -1,29 +1,61 @@
from typing import List, Dict
from cat.utils import BaseModelDict
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
from enum import Enum

#class WorkingMemory(BaseModelDict):
# history : List = []

class Role(Enum):
    AI = "AI"
    Human = "Human"

class MessageWhy(BaseModelDict):
    """Class for wrapping the reasoning ("why") behind a cat message.

    Variables:
        input (str): input message that triggered the answer
        intermediate_steps (List): intermediate reasoning steps
        memory (dict): memories recalled while producing the answer
    """
    input: str
    intermediate_steps: List
    memory: dict


class CatMessage(BaseModelDict):
    """Class for wrapping a cat (assistant) message.

    Variables:
        content (str): cat message content
        user_id (str): user id
        type (str): message type, defaults to "chat"
        why (MessageWhy | None): explanation of how the answer was produced
    """
    content: str
    user_id: str
    type: str = "chat"
    why: MessageWhy | None = None


class UserMessage(BaseModelDict):
    """Class for wrapping a user message.

    Variables:
        text (str): user message content
        user_id (str): user id
    """
    text: str
    user_id: str



def convert_to_Langchain_message(messages: List[UserMessage | CatMessage]) -> List[BaseMessage]:
    # build the result in a new list instead of reassigning the
    # `messages` parameter, which would shadow the input
    langchain_messages = []
    for m in messages:
        if isinstance(m, CatMessage):
            # cat (assistant) messages map to AIMessage
            langchain_messages.append(AIMessage(content=m.content, response_metadata={"userId": m.user_id}))
        else:
            # user messages map to HumanMessage
            langchain_messages.append(HumanMessage(content=m.text, response_metadata={"userId": m.user_id}))
    return langchain_messages

def convert_to_Cat_message(cat_message: AIMessage, why: MessageWhy) -> CatMessage:
    return CatMessage(
        content=cat_message.content,
        user_id=cat_message.response_metadata["userId"],
        why=why,
    )
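
For orientation, here is a minimal usage sketch (not part of the diff; the conversation values are made up) showing how the new wrappers and converters fit together:

from cat.convo.messages import (
    CatMessage, UserMessage, MessageWhy,
    convert_to_Langchain_message, convert_to_Cat_message,
)
from langchain_core.messages import AIMessage

# Wrap a short chat history so it can be fed to a LangChain chat model.
history = [
    UserMessage(text="What time is it?", user_id="user_1"),
    CatMessage(content="It is five o'clock.", user_id="user_1"),
]
langchain_history = convert_to_Langchain_message(history)

# Convert a model reply back into a CatMessage, attaching the "why" metadata.
reply = AIMessage(content="Time for tea.", response_metadata={"userId": "user_1"})
why = MessageWhy(input="What time is it?", intermediate_steps=[], memory={})
cat_reply = convert_to_Cat_message(reply, why)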


5 changes: 3 additions & 2 deletions core/cat/factory/custom_llm.py
@@ -8,8 +8,9 @@

from langchain_core.language_models.llms import LLM
from langchain_openai.llms import OpenAI
from langchain_community.llms.ollama import Ollama, OllamaEndpointNotFoundError
from langchain_openai.chat_models import ChatOpenAI
from langchain_community.chat_models.ollama import ChatOllama

from cat.log import log

@@ -91,7 +92,7 @@ def __init__(self, **kwargs):



-class CustomOllama(Ollama):
+class CustomOllama(ChatOllama):
    def __init__(self, **kwargs: Any) -> None:
        if "localhost" in kwargs["base_url"]:
            log.error(
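
The base-class swap above is the core of this file's change: CustomOllama is now a LangChain chat model rather than a completion LLM. A minimal sketch of the behavioral difference (the base_url and model values are placeholders):

from langchain_community.chat_models.ollama import ChatOllama
from langchain_core.messages import HumanMessage

# A chat model consumes a list of messages and returns an AIMessage,
# whereas the old Ollama LLM consumed a plain prompt string and returned a string.
chat_model = ChatOllama(base_url="http://ollama:11434", model="llama3")
reply = chat_model.invoke([HumanMessage(content="Hello!")])
print(type(reply).__name__)  # -> "AIMessage"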
16 changes: 5 additions & 11 deletions core/cat/factory/llm.py
@@ -8,6 +8,7 @@
)
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.chat_models.ollama import ChatOllama

from .ollama_utils import _create_stream_patch, _acreate_stream_patch
from typing import Type
@@ -18,6 +19,7 @@
from cat.mad_hatter.mad_hatter import MadHatter



# Base class to manage LLM configuration.
class LLMSettings(BaseModel):
    # class instantiating the model
@@ -228,29 +230,21 @@ class LLMHuggingFaceEndpointConfig(LLMSettings):
"link": "https://huggingface.co/inference-endpoints",
}
)


-# monkey patch to fix stops sequences
-OllamaFix: Type = CustomOllama
-OllamaFix._create_stream = _create_stream_patch
-OllamaFix._acreate_stream = _acreate_stream_patch


class LLMOllamaConfig(LLMSettings):
    base_url: str
-    model: str = "llama2"
+    model: str = "llama3"
    num_ctx: int = 2048
    repeat_last_n: int = 64
    repeat_penalty: float = 1.1
    temperature: float = 0.8

-    _pyclass: Type = OllamaFix
+    _pyclass: Type = CustomOllama

    model_config = ConfigDict(
        json_schema_extra={
            "humanReadableName": "Ollama",
            "description": "Configuration for Ollama",
-            "link": "https://ollama.ai/library",
+            "link": "https://ollama.ai/library"
        }
    )
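
To see the change end to end, here is an illustrative sketch; the instantiation line is an assumption about how the factory consumes _pyclass, not quoted code, and the base_url is a placeholder:

from cat.factory.llm import LLMOllamaConfig
from cat.factory.custom_llm import CustomOllama

# Validated settings are splatted into the class referenced by _pyclass,
# which after this PR is the chat-based CustomOllama.
settings = LLMOllamaConfig(base_url="http://ollama:11434")
llm = CustomOllama(**settings.model_dump())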
