Use SQLDatabaseToolkit for Agent sql Skill #9132

Merged: 12 commits, Apr 29, 2024
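For context on the diff below: the agent's SQL skill now builds its tools with langchain's SQLDatabaseToolkit instead of a single hand-rolled Tool. A minimal sketch of what that toolkit provides, run against a placeholder SQLite database rather than MindsDB's executor (the connection string and model name are illustrative assumptions, not part of this PR):

# Standalone sketch, not MindsDB code: inspect what SQLDatabaseToolkit.get_tools() yields.
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.chat_models import ChatOpenAI
from langchain_community.utilities.sql_database import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///example.db")        # placeholder database
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)   # requires OPENAI_API_KEY
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
for tool in toolkit.get_tools():
    # Expected tool names: sql_db_query, sql_db_schema, sql_db_list_tables, sql_db_query_checker.
    print(tool.name)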
@@ -140,7 +140,7 @@ def get_columns(self, table_name) -> Response:
"""
Show details about the table
"""
- q = f"SELECT column_name, data_type, FROM \
+ q = f"SELECT column_name AS Field, data_type as Type, FROM \
`{self.connection_data['project_id']}.{self.connection_data['dataset']}.INFORMATION_SCHEMA.COLUMNS` WHERE table_name = '{table_name}'"
result = self.native_query(q)
return result
@@ -92,3 +92,4 @@
DEFAULT_MAX_TOKENS = 2048
DEFAULT_MODEL_NAME = 'gpt-4-0125-preview'
DEFAULT_USER_COLUMN = 'question'
+ DEFAULT_EMBEDDINGS_MODEL_PROVIDER = 'openai'
@@ -18,6 +18,7 @@
DEFAULT_AGENT_TIMEOUT_SECONDS,
DEFAULT_AGENT_TOOLS,
DEFAULT_AGENT_TYPE,
+ DEFAULT_EMBEDDINGS_MODEL_PROVIDER,
DEFAULT_MAX_ITERATIONS,
DEFAULT_MAX_TOKENS,
DEFAULT_MODEL_NAME,
@@ -219,15 +220,12 @@ def create_agent(self, df: pd.DataFrame, args: Dict=None, pred_args: Dict=None)

embeddings_args = args.pop('embedding_model_args', {})

- # no embedding model args provided, use same provider as llm
+ # no embedding model args provided, use default provider.
if not embeddings_args:
logger.warning("'embedding_model_args' not found in input params, "
- "Trying to use the same provider used for llm. "
- f"provider: {args['provider']}"
+ f"Trying to use default provider: {DEFAULT_EMBEDDINGS_MODEL_PROVIDER}"
)

- # get args for embeddings model
- embeddings_args['class'] = args['provider']
+ embeddings_args['class'] = DEFAULT_EMBEDDINGS_MODEL_PROVIDER

# create embeddings model
pred_args['embeddings_model'] = self._create_embeddings_model(embeddings_args)
@@ -2,7 +2,7 @@
Wrapper around MindsDB's executor and integration controller following the implementation of the original
langchain.sql_database.SQLDatabase class to partly replicate its behavior.
"""
- from typing import Iterable, List, Optional
+ from typing import Any, Iterable, List, Optional

from langchain.sql_database import SQLDatabase

@@ -18,7 +18,7 @@ def __init__(
self,
engine=None,
database: Optional[str] = 'mindsdb',
- metadata: Optional = None,
+ metadata: Optional[Any] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 3,
mindsdb/integrations/handlers/langchain_handler/tools.py (32 additions, 26 deletions)
@@ -112,20 +112,21 @@ def _setup_standard_tools(tools, llm, model_kwargs):

all_standard_tools = []
langchain_tools = []
- mdb_tool = Tool(
- name="MindsDB",
- func=get_exec_call_tool(llm, executor, model_kwargs),
- description="useful to read from databases or tables connected to the mindsdb machine learning package. the action must be a valid simple SQL query, always ending with a semicolon. For example, you can do `show databases;` to list the available data sources, and `show tables;` to list the available tables within each data source." # noqa
- )

- mdb_meta_tool = Tool(
- name="MDB-Metadata",
- func=get_exec_metadata_tool(llm, executor, model_kwargs),
- description="useful to get column names from a mindsdb table or metadata from a mindsdb data source. the command should be either 1) a data source name, to list all available tables that it exposes, or 2) a string with the format `data_source_name.table_name` (for example, `files.my_table`), to get the table name, table type, column names, data types per column, and amount of rows of the specified table." # noqa
- )
- all_standard_tools.append(mdb_tool)
- all_standard_tools.append(mdb_meta_tool)
for tool in tools:
+ if tool == 'mindsdb_read':
+ mdb_tool = Tool(
+ name="MindsDB",
+ func=get_exec_call_tool(llm, executor, model_kwargs),
+ description="useful to read from databases or tables connected to the mindsdb machine learning package. the action must be a valid simple SQL query, always ending with a semicolon. For example, you can do `show databases;` to list the available data sources, and `show tables;` to list the available tables within each data source." # noqa
+ )

+ mdb_meta_tool = Tool(
+ name="MDB-Metadata",
+ func=get_exec_metadata_tool(llm, executor, model_kwargs),
+ description="useful to get column names from a mindsdb table or metadata from a mindsdb data source. the command should be either 1) a data source name, to list all available tables that it exposes, or 2) a string with the format `data_source_name.table_name` (for example, `files.my_table`), to get the table name, table type, column names, data types per column, and amount of rows of the specified table." # noqa
+ )
+ all_standard_tools.append(mdb_tool)
+ all_standard_tools.append(mdb_meta_tool)
if tool == 'mindsdb_write':
mdb_write_tool = Tool(
name="MDB-Write",
@@ -280,20 +281,25 @@ def _build_retrieval_tool(tool: dict, pred_args: dict, skill: db.Skills):
)


- def langchain_tool_from_skill(skill, pred_args):
+ def langchain_tools_from_skill(skill, pred_args):
# Makes Langchain compatible tools from a skill
- tool = skill_tool.get_tool_from_skill(skill)
+ tools = skill_tool.get_tools_from_skill(skill)

- if tool['type'] == SkillType.RETRIEVAL.value:

- return _build_retrieval_tool(tool, pred_args, skill)

- return Tool(
- name=tool['name'],
- func=tool['func'],
- description=tool['description'],
- return_direct=True
- )
+ all_tools = []
+ for tool in tools:
+ if skill.type == SkillType.RETRIEVAL.value:
+ all_tools.append(_build_retrieval_tool(tool, pred_args, skill))
+ continue
+ if isinstance(tool, dict):
+ all_tools.append(Tool(
+ name=tool['name'],
+ func=tool['func'],
+ description=tool['description'],
+ return_direct=True
+ ))
+ continue
+ all_tools.append(tool)
+ return all_tools

def get_skills(pred_args):
return pred_args.get('skills', [])
@@ -316,7 +322,7 @@ def setup_tools(llm, model_kwargs, pred_args, default_agent_tools):
tools = []
skills = get_skills(pred_args)
for skill in skills:
- tools.append(langchain_tool_from_skill(skill, pred_args))
+ tools += langchain_tools_from_skill(skill, pred_args)

if len(tools) == 0:
tools = _setup_standard_tools(standard_tools, llm, model_kwargs)
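The shape change in tools.py above: a skill now contributes a list of tools, and each item is either a plain dict (knowledge-base and retrieval skills) or a ready-made langchain tool object (the SQL skill via SQLDatabaseToolkit). A standalone sketch of that normalization with made-up entries, mirroring langchain_tools_from_skill without calling MindsDB code:

# Illustrative only: normalize dict-described and ready-made langchain tools into one list.
from langchain.agents import Tool

def normalize(entries):
    tools = []
    for entry in entries:
        if isinstance(entry, dict):
            # Knowledge-base / retrieval skills describe their tool as a dict.
            tools.append(Tool(
                name=entry['name'],
                func=entry['func'],
                description=entry['description'],
                return_direct=True,
            ))
        else:
            # SQL skills already return langchain tool instances; pass them through.
            tools.append(entry)
    return tools

entries = [{'name': 'kb_query', 'func': lambda q: q, 'description': 'query a knowledge base'}]
print([t.name for t in normalize(entries)])   # ['kb_query']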
mindsdb/interfaces/knowledge_base/controller.py (2 additions, 2 deletions)
@@ -424,7 +424,7 @@ def _create_persistent_chroma(self, kb_name, engine="chromadb"):
self.session.integration_controller.add(vector_store_name, engine, connection_args)
return vector_store_name

- def _create_default_embedding_model(self, project_name, kb_name, engine="sentence_transformers"):
+ def _create_default_embedding_model(self, project_name, kb_name, engine="langchain_embedding"):
"""create a default embedding model for knowledge base, if not specified"""
model_name = f"{kb_name}_default_model"

@@ -437,7 +437,7 @@ def _create_default_embedding_model(self, project_name, kb_name, engine="sentenc
)
ml_handler = self.session.integration_controller.get_ml_handler(engine)

- self.session.model_controller.create_model(
+ _ = self.session.model_controller.create_model(
statement,
ml_handler
)
mindsdb/interfaces/skills/skill_tool.py (40 additions, 28 deletions)
@@ -8,10 +8,12 @@
from .sql_agent import SQLAgent

_DEFAULT_TOP_K_SIMILARITY_SEARCH = 5
+ _DEFAULT_SQL_LLM_MODEL = 'gpt-3.5-turbo'


class SkillType(enum.Enum):
- TEXT2SQL = 'text2sql'
+ TEXT2SQL_LEGACY = 'text2sql'
+ TEXT2SQL = 'sql'
KNOWLEDGE_BASE = 'knowledge_base'
RETRIEVAL = 'retrieval'

@@ -50,32 +52,42 @@ def _make_text_to_sql_tools(self, skill: db.Skills) -> dict:
'''
Uses SQLAgent to execute tool
'''

+ # To prevent dependency on Langchain unless an actual tool uses it.
+ try:
+ from mindsdb.integrations.handlers.langchain_handler.mindsdb_database_agent import MindsDBSQL
+ from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
+ from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
+ from langchain_community.chat_models import ChatOpenAI
+ except ImportError:
+ raise ImportError('To use the text-to-SQL skill, please install langchain with `pip install mindsdb[langchain]`')
database = skill.params['database']
tables = skill.params['tables']

- sql_agent = SQLAgent(
- self.get_command_executor(),
- database,
- include_tables=tables
- )

- description = (
- "Use the conversation context to decide which table to query. "
- "Input to this tool is a detailed and correct SQL query, output is a result from the database. "
- "If the query is not correct, an error message will be returned. "
- "If an error is returned, rewrite the query, check the query, and try again. "
- f"These are the available tables: {','.join(tables)}\n"
- )
- for table in tables:
- description += f'Table name: "{table}", columns {sql_agent.get_table_columns(table)}\n'

- return dict(
- name='sql_db_query',
- func=sql_agent.query_safe,
- description=description,
- type=skill.type
+ tables_to_include = [f'{database}.{table}' for table in tables]
+ db = MindsDBSQL(
+ engine=self.get_command_executor(),
+ metadata=self.get_command_executor().session.integration_controller,
+ include_tables=tables_to_include
+ )
+ # Users probably don't need to configure this for now.
+ llm = ChatOpenAI(model=_DEFAULT_SQL_LLM_MODEL, temperature=0)
+ sql_database_tools = SQLDatabaseToolkit(db=db, llm=llm).get_tools()
+ description = skill.params.get('description', '')
+ tables_list = ','.join([f'{database}.{table}' for table in tables])
+ for i, tool in enumerate(sql_database_tools):
+ if isinstance(tool, QuerySQLDataBaseTool):
+ # Add our own custom description so our agent knows when to query this table.
+ tool.description = (
+ f'Use this tool if you need data about {description}. '
+ 'Use the conversation context to decide which table to query. '
+ f'These are the available tables: {tables_list}.\n'
+ f'ALWAYS consider these special cases:\n'
+ f'- Not all SQL functions are supported. Do NOT use the following functions: INTERVAL. \n'
+ f'- For TIMESTAMP type columns, make sure you include the time portion in your query (e.g. WHERE date_column = "2020-01-01 12:00:00")\n'
+ f'Here are the rest of the instructions:\n'
+ f'{tool.description}'
+ )
+ sql_database_tools[i] = tool
+ return sql_database_tools
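The loop above only rewrites the query tool's description and leaves the toolkit's other tools untouched. Here is the same pattern in isolation, against a placeholder SQLAlchemy database rather than the MindsDBSQL wrapper (the database URI and guidance text are assumptions for illustration):

# Sketch of the description-override pattern, outside MindsDB.
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.chat_models import ChatOpenAI
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.utilities.sql_database import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///example.db")        # placeholder database
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)   # requires OPENAI_API_KEY
tools = SQLDatabaseToolkit(db=db, llm=llm).get_tools()
for i, tool in enumerate(tools):
    if isinstance(tool, QuerySQLDataBaseTool):
        # Prepend skill-specific guidance; keep the stock instructions at the end.
        tool.description = 'Use this tool for questions about the example dataset.\n' + tool.description
        tools[i] = tool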

def _make_retrieval_tools(self, skill: db.Skills) -> dict:
"""
@@ -125,7 +137,7 @@ def _make_knowledge_base_tools(self, skill: db.Skills) -> dict:
type=skill.type
)

- def get_tool_from_skill(self, skill: db.Skills) -> dict:
+ def get_tools_from_skill(self, skill: db.Skills) -> dict:
"""
Creates function for skill and metadata (name, description)
Args:
@@ -141,12 +153,12 @@ def get_tool_from_skill(self, skill: db.Skills) -> dict:
raise NotImplementedError(
f'skill of type {skill.type} is not supported as a tool, supported types are: {list(SkillType._member_names_)}')

- if skill_type == SkillType.TEXT2SQL:
+ if skill_type == SkillType.TEXT2SQL or skill_type == SkillType.TEXT2SQL_LEGACY:
return self._make_text_to_sql_tools(skill)
if skill_type == SkillType.KNOWLEDGE_BASE:
- return self._make_knowledge_base_tools(skill)
+ return [self._make_knowledge_base_tools(skill)]
if skill_type == SkillType.RETRIEVAL:
- return self._make_retrieval_tools(skill)
+ return [self._make_retrieval_tools(skill)]


skill_tool = SkillToolController()