Merge pull request #783 from valentimarco/develop
[HUGE PR] Chat Messages for ChatModels
Pingdred committed May 10, 2024
2 parents 5d40d9b + 22375d4 commit 2793e78
Showing 7 changed files with 231 additions and 132 deletions.
46 changes: 39 additions & 7 deletions core/cat/convo/messages.py
@@ -1,29 +1,61 @@
from typing import List, Dict
from cat.utils import BaseModelDict
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
from enum import Enum

# class WorkingMemory(BaseModelDict):
#     history: List = []


class Role(Enum):
    AI = "AI"
    Human = "Human"


class MessageWhy(BaseModelDict):
    """Class for wrapping message why.

    Variables:
        input (str): input message
        intermediate_steps (List): intermediate steps
        memory (dict): memory
    """

    input: str
    intermediate_steps: List
    memory: dict


class CatMessage(BaseModelDict):
    """Class for wrapping cat message.

    Variables:
        content (str): cat message
        user_id (str): user id
        type (str): message type, defaults to "chat"
        why (MessageWhy): explanation of how the message was produced
    """

    content: str
    user_id: str
    type: str = "chat"
    why: MessageWhy | None = None


class UserMessage(BaseModelDict):
    """Class for wrapping user message.

    Variables:
        text (str): user message
        user_id (str): user id
    """

    text: str
    user_id: str



def convert_to_Langchain_message(messages: List[UserMessage | CatMessage]) -> List[BaseMessage]:
    langchain_messages = []
    for m in messages:
        if isinstance(m, CatMessage):
            # The cat's replies become LangChain AI messages
            langchain_messages.append(AIMessage(content=m.content, response_metadata={"userId": m.user_id}))
        else:
            # User messages become LangChain human messages
            langchain_messages.append(HumanMessage(content=m.text, response_metadata={"userId": m.user_id}))
    return langchain_messages

def convert_to_Cat_message(cat_message: AIMessage, why: MessageWhy) -> CatMessage:
    return CatMessage(
        content=cat_message.content,
        user_id=cat_message.response_metadata["userId"],
        why=why,
    )

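A minimal usage sketch (not part of the commit) showing the round trip between Cat messages and LangChain chat messages, assuming the classes and functions above are in scope; the sample texts and user id are hypothetical:

    from langchain_core.messages import AIMessage

    history = [
        UserMessage(text="Hi, who are you?", user_id="user_1"),
        CatMessage(content="I am the Cheshire Cat.", user_id="user_1"),
    ]

    # Turn the Cat conversation into typed LangChain messages for a ChatModel
    lc_messages = convert_to_Langchain_message(history)

    # Wrap a raw model reply back into a CatMessage, keeping the "why"
    why = MessageWhy(input="Hi, who are you?", intermediate_steps=[], memory={})
    reply = AIMessage(content="Still here.", response_metadata={"userId": "user_1"})
    cat_msg = convert_to_Cat_message(reply, why)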

5 changes: 3 additions & 2 deletions core/cat/factory/custom_llm.py
@@ -8,8 +8,9 @@

from langchain_core.language_models.llms import LLM
from langchain_openai.llms import OpenAI
-from langchain_community.llms.ollama import Ollama, OllamaEndpointNotFoundError
+from langchain_community.llms.ollama import Ollama,OllamaEndpointNotFoundError
from langchain_openai.chat_models import ChatOpenAI
+from langchain_community.chat_models.ollama import ChatOllama

from cat.log import log

@@ -91,7 +92,7 @@ def __init__(self, **kwargs):



-class CustomOllama(Ollama):
+class CustomOllama(ChatOllama):
    def __init__(self, **kwargs: Any) -> None:
        if "localhost" in kwargs["base_url"]:
            log.error(
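
Since CustomOllama now subclasses ChatOllama, it is invoked with a list of typed chat messages rather than a raw prompt string. A minimal sketch (not part of the commit), assuming a reachable Ollama server at the hypothetical URL below:

    from langchain_core.messages import HumanMessage, SystemMessage

    llm = CustomOllama(base_url="http://host.docker.internal:11434", model="llama3")

    # A ChatModel takes typed messages and returns an AIMessage
    reply = llm.invoke([
        SystemMessage(content="You are the Cheshire Cat."),
        HumanMessage(content="Hello!"),
    ])
    print(reply.content)
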
16 changes: 5 additions & 11 deletions core/cat/factory/llm.py
@@ -8,6 +8,7 @@
)
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_community.chat_models.ollama import ChatOllama

from .ollama_utils import _create_stream_patch, _acreate_stream_patch
from typing import Type
@@ -18,6 +19,7 @@
from cat.mad_hatter.mad_hatter import MadHatter



# Base class to manage LLM configuration.
class LLMSettings(BaseModel):
    # class instantiating the model
@@ -228,29 +230,21 @@ class LLMHuggingFaceEndpointConfig(LLMSettings):
"link": "https://huggingface.co/inference-endpoints",
}
)


-# monkey patch to fix stops sequences
-OllamaFix: Type = CustomOllama
-OllamaFix._create_stream = _create_stream_patch
-OllamaFix._acreate_stream = _acreate_stream_patch


class LLMOllamaConfig(LLMSettings):
    base_url: str
-    model: str = "llama2"
+    model: str = "llama3"
    num_ctx: int = 2048
    repeat_last_n: int = 64
    repeat_penalty: float = 1.1
    temperature: float = 0.8

-    _pyclass: Type = OllamaFix
+    _pyclass: Type = CustomOllama

    model_config = ConfigDict(
        json_schema_extra={
            "humanReadableName": "Ollama",
            "description": "Configuration for Ollama",
-            "link": "https://ollama.ai/library",
+            "link": "https://ollama.ai/library"
        }
    )

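A minimal sketch (not part of the commit) of the ChatOllama instance these settings describe, with the base_url being a hypothetical Docker host address:

    from langchain_community.chat_models.ollama import ChatOllama

    llm = ChatOllama(
        base_url="http://host.docker.internal:11434",  # hypothetical host
        model="llama3",
        num_ctx=2048,          # context window size
        repeat_last_n=64,      # lookback window for the repetition penalty
        repeat_penalty=1.1,
        temperature=0.8,
    )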
