Skip to content

Commit

Permalink
Fixes for RAG Integration to Python SDK (#9098)
Browse files Browse the repository at this point in the history
* Allow 'sql' skill name instead of just 'text_to_sql'

* Fixes for RAG integration with SDK

* Kb update tweaks - DU (#9101)

* Update default params and handle missing embedding args

The default chunk size and overlap values in file_handler have been increased for better performance. In langchain_handler, checks have been added to use the same provider and model for embeddings if 'embedding_model_args' are not provided in input params. Skill_tool's retriever_config has been renamed to 'config' for clarity.

* Handle invoke errors in langchain_handler

This update modifies the langchain handler to accommodate errors during the invocation process. Specifically, it implements exception handling for the agent_executor's invoke method. If an error occurs, instead of crashing, it will now return the error message. However, if the error doesn't match a specific format, the exception will still be raised.

* Update file_handler.py

* Add logging and refactor default vector store in langchain handler

This change introduces logging for situations where 'vector_store_config' is not present in the `langchain_handler` tool configuration. It also modifies the condition that checks for the absence of this property to add a persisting directory. Furthermore, it refactors the code in `langchain_handler.py` related to the absence of 'embedding_model_args'.

* Remove default vector store from langchain handler

The default vector store setup was removed from the langchain handler. The 'vector_store_config' is no longer automatically assigned, and an alert for this missing parameter is no longer logged. The code was adjusted to operate without the previously used 'mindsdb_path'.

* Update warning message format in langchain_handler

Added double quotes around the default collection name in the warning message within langchain_handler. This is to enhance readability and make the default name stand out in the midst of the message.

* Update warning message format in langchain_handler

Added double quotes around the default collection name in the warning message within langchain_handler. This is to enhance readability and make the default name stand out in the midst of the message.

---------

Co-authored-by: Daniel Usvyat <usvyat@gmail.com>
  • Loading branch information
tmichaeldb and dusvyat committed Apr 19, 2024
1 parent 8abf1c1 commit 48f7699
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 16 deletions.
7 changes: 4 additions & 3 deletions mindsdb/integrations/handlers/email_handler/email_ingestor.py
Expand Up @@ -76,10 +76,11 @@ def ingest(self) -> pd.DataFrame:
df = pd.DataFrame(all_email_data)

# Replace "(UTC)" with empty string over a pandas DataFrame column
df['datetime'] = df['datetime'].str.replace(' (UTC)', '')
if 'datetime' in df.columns:
df['datetime'] = df['datetime'].str.replace(' (UTC)', '')

# Convert datetime string to datetime object, and normalize timezone to UTC.
df['datetime'] = pd.to_datetime(df['datetime'], utc=True, format="%a, %d %b %Y %H:%M:%S %z", errors='coerce')
# Convert datetime string to datetime object, and normalize timezone to UTC.
df['datetime'] = pd.to_datetime(df['datetime'], utc=True, format="%a, %d %b %Y %H:%M:%S %z", errors='coerce')

return df

101 changes: 97 additions & 4 deletions mindsdb/integrations/handlers/langchain_handler/tools.py
Expand Up @@ -17,8 +17,13 @@
from langchain.chains import ReduceDocumentsChain, MapReduceDocumentsChain

from mindsdb.integrations.utilities.rag.rag_pipeline_builder import RAG
from mindsdb.integrations.utilities.rag.settings import RAGPipelineModel
from mindsdb.integrations.utilities.rag.settings import RAGPipelineModel, VectorStoreType, DEFAULT_COLLECTION_NAME
from mindsdb.interfaces.skills.skill_tool import skill_tool, SkillType
from mindsdb.interfaces.storage import db
from mindsdb.utilities import log

logger = log.getLogger(__name__)
from mindsdb.interfaces.storage.db import KnowledgeBase
from mindsdb.utilities import log

logger = log.getLogger(__name__)
Expand Down Expand Up @@ -160,7 +165,62 @@ def _get_rag_params(pred_args: Dict) -> Dict:
return rag_params


def _build_retrieval_tool(tool: dict, pred_args: dict):
def _create_conn_string(connection_args: dict) -> str:
"""
Creates a PostgreSQL connection string from connection args.
"""
user = connection_args.get('user')
host = connection_args.get('host')
port = connection_args.get('port')
password = connection_args.get('password')
dbname = connection_args.get('database')

if password:
return f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
else:
return f"postgresql://{user}@{host}:{port}/{dbname}"


def _get_knowledge_base(knowledge_base_name: str, project_id, executor) -> db.KnowledgeBase:
    """Look up a knowledge base by name within a project via the session's KB controller."""
    return executor.session.kb_controller.get(knowledge_base_name, project_id)


def _build_vector_store_config_from_knowledge_base(rag_params: Dict, knowledge_base: KnowledgeBase, executor) -> Dict:
    """
    Build a vector store config for the RAG pipeline from a knowledge base's
    backing vector database.

    Starts from any 'vector_store_config' already present in rag_params and
    fills in the store type plus the store-specific connection details.
    """
    config = dict(rag_params['vector_store_config'])

    store_engine = knowledge_base.vector_database.engine
    config['vector_store_type'] = store_engine

    if store_engine == VectorStoreType.CHROMA.value:
        # ChromaDB persists to disk: resolve the stored folder to a local
        # persist_directory via the integration's handler storage.
        folder_name = knowledge_base.vector_database.data['persist_directory']
        handler = executor.session.integration_controller.get_data_handler(
            knowledge_base.vector_database.name
        )
        config['persist_directory'] = handler.handler_storage.folder_get(folder_name)

    elif store_engine == VectorStoreType.PGVECTOR.value:
        # pgvector connects over the network: derive a connection string
        # from the vector database's stored connection parameters.
        # todo requires further testing
        config['connection_string'] = _create_conn_string(knowledge_base.vector_database.data)

    else:
        raise ValueError(f"Invalid vector store type: {store_engine}. "
                         f"Only {[v.name for v in VectorStoreType]} are currently supported.")

    return config


def _build_retrieval_tool(tool: dict, pred_args: dict, skill: db.Skills):
"""
Builds a retrieval tool i.e RAG
"""
Expand All @@ -173,7 +233,40 @@ def _build_retrieval_tool(tool: dict, pred_args: dict):

rag_params = _get_rag_params(tools_config)

rag_config = RAGPipelineModel(**rag_params)
if 'vector_store_config' not in rag_params:
rag_params['vector_store_config'] = {}
logger.warning(f'No collection_name specified for the retrieval tool, '
f"using default collection_name: '{DEFAULT_COLLECTION_NAME}'"
f'\nWarning: If this collection does not exist, no data will be retrieved')

if 'source' in tool:
kb_name = tool['source']
executor = skill_tool.get_command_executor()
kb = _get_knowledge_base(kb_name, skill.project_id, executor)

if not kb:
raise ValueError(f"Knowledge base not found: {kb_name}")

rag_params['vector_store_config'] = _build_vector_store_config_from_knowledge_base(rag_params, kb, executor)

# Can run into weird validation errors when unpacking rag_params directly into constructor.
rag_config = RAGPipelineModel(
embeddings_model=rag_params['embeddings_model']
)
if 'documents' in rag_params:
rag_config.documents = rag_params['documents']
if 'vector_store_config' in rag_params:
rag_config.vector_store_config = rag_params['vector_store_config']
if 'db_connection_string' in rag_params:
rag_config.db_connection_string = rag_params['db_connection_string']
if 'table_name' in rag_params:
rag_config.table_name = rag_params['table_name']
if 'llm' in rag_params:
rag_config.llm = rag_params['llm']
if 'rag_prompt_template' in rag_params:
rag_config.rag_prompt_template = rag_params['rag_prompt_template']
if 'retriever_prompt_template' in rag_params:
rag_config.retriever_prompt_template = rag_params['retriever_prompt_template']

# build retriever
rag_pipeline = RAG(rag_config)
Expand All @@ -192,7 +285,7 @@ def langchain_tool_from_skill(skill, pred_args):

if tool['type'] == SkillType.RETRIEVAL.value:

return _build_retrieval_tool(tool, pred_args)
return _build_retrieval_tool(tool, pred_args, skill)

return Tool(
name=tool['name'],
Expand Down
6 changes: 4 additions & 2 deletions mindsdb/integrations/utilities/rag/settings.py
Expand Up @@ -11,6 +11,8 @@
from langchain.text_splitter import TextSplitter
from pydantic import BaseModel

DEFAULT_COLLECTION_NAME = 'default_collection'

# Multi retriever specific
DEFAULT_ID_KEY = "doc_id"
DEFAULT_MAX_CONCURRENCY = 5
Expand Down Expand Up @@ -69,7 +71,7 @@ class MultiVectorRetrieverMode(Enum):


class VectorStoreType(Enum):
CHROMA = 'chroma'
CHROMA = 'chromadb'
PGVECTOR = 'pgvector'


Expand All @@ -88,7 +90,7 @@ class RetrieverType(Enum):
class VectorStoreConfig(BaseModel):
vector_store_type: VectorStoreType = VectorStoreType.CHROMA
persist_directory: str = None
collection_name: str = 'default'
collection_name: str = DEFAULT_COLLECTION_NAME
connection_string: str = None

class Config:
Expand Down
20 changes: 13 additions & 7 deletions mindsdb/interfaces/knowledge_base/controller.py
Expand Up @@ -444,7 +444,7 @@ def _create_default_embedding_model(self, project_name, kb_name, engine="sentenc

return model_name

def delete(self, name: str, project_name: str, if_exists: bool = False) -> None:
def delete(self, name: str, project_name: int, if_exists: bool = False) -> None:
"""
Delete a knowledge base from the database
"""
Expand Down Expand Up @@ -473,15 +473,21 @@ def delete(self, name: str, project_name: str, if_exists: bool = False) -> None:

# drop objects if they were created automatically
if 'vector_storage' in kb.params:
self.session.integration_controller.delete(kb.params['vector_storage'])
try:
self.session.integration_controller.delete(kb.params['vector_storage'])
except EntityNotExistsError:
pass
if 'embedding_model' in kb.params:
self.session.model_controller.delete_model(kb.params['embedding_model'], project_name)
try:
self.session.model_controller.delete_model(kb.params['embedding_model'], project_name)
except EntityNotExistsError:
pass

# kb exists
db.session.delete(kb)
db.session.commit()

def get(self, name: str, project_id: str) -> db.KnowledgeBase:
def get(self, name: str, project_id: int) -> db.KnowledgeBase:
"""
Get a knowledge base from the database
by name + project_id
Expand All @@ -496,7 +502,7 @@ def get(self, name: str, project_id: str) -> db.KnowledgeBase:
)
return kb

def get_table(self, name: str, project_id: str) -> KnowledgeBaseTable:
def get_table(self, name: str, project_id: int) -> KnowledgeBaseTable:
"""
Returns kb table object
:param name: table name
Expand All @@ -507,7 +513,7 @@ def get_table(self, name: str, project_id: str) -> KnowledgeBaseTable:
if kb is not None:
return KnowledgeBaseTable(kb, self.session)

def list(self, project_id: str) -> List[db.KnowledgeBase]:
def list(self, project_id: int) -> List[db.KnowledgeBase]:
"""
List all knowledge bases from the database
belonging to a project
Expand All @@ -521,7 +527,7 @@ def list(self, project_id: str) -> List[db.KnowledgeBase]:
)
return kbs

def update(self, name: str, project_id: str, **kwargs) -> db.KnowledgeBase:
def update(self, name: str, project_id: int, **kwargs) -> db.KnowledgeBase:
"""
Update a knowledge base record
"""
Expand Down
1 change: 1 addition & 0 deletions mindsdb/interfaces/skills/skill_tool.py
Expand Up @@ -84,6 +84,7 @@ def _make_retrieval_tools(self, skill: db.Skills) -> dict:
params = skill.params
return dict(
name=params.get('name', skill.name),
source=params.get('source', None),
config=params.get('config', {}),
description=f'You must use this tool to get more context or information '
f'to answer a question about {params["description"]}. '
Expand Down

0 comments on commit 48f7699

Please sign in to comment.