Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update function calling api to tools api #404

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.2.1
rev: v0.2.2
hooks:
- id: ruff
13 changes: 8 additions & 5 deletions langroid/agent/chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(
self.llm_functions_map: Dict[str, LLMFunctionSpec] = {}
self.llm_functions_handled: Set[str] = set()
self.llm_functions_usable: Set[str] = set()
self.llm_function_force: Optional[Dict[str, str]] = None
self.llm_function_force: Optional[Dict[str, str | Dict[str, str]]] = None

def clone(self, i: int = 0) -> "ChatAgent":
"""Create i'th clone of this agent, ensuring tool use/handling is cloned.
Expand Down Expand Up @@ -390,7 +390,10 @@ def enable_message(
llm_function = message_class.llm_function_schema(defaults=include_defaults)
self.llm_functions_map[request] = llm_function
if force:
self.llm_function_force = dict(name=request)
self.llm_function_force = {
"type": "function",
"function": dict(name=request),
}
else:
self.llm_function_force = None

Expand Down Expand Up @@ -645,9 +648,9 @@ def _prep_llm_messages(

def _function_args(
self,
) -> Tuple[Optional[List[LLMFunctionSpec]], str | Dict[str, str]]:
) -> Tuple[Optional[List[LLMFunctionSpec]], str | Dict[str, str | Dict[str, str]]]:
functions: Optional[List[LLMFunctionSpec]] = None
fun_call: str | Dict[str, str] = "none"
fun_call: str | Dict[str, str | Dict[str, str]] = "none"
if self.config.use_functions_api and len(self.llm_functions_usable) > 0:
functions = [self.llm_functions_map[f] for f in self.llm_functions_usable]
fun_call = (
Expand Down Expand Up @@ -731,7 +734,7 @@ async def llm_response_messages_async(
assert self.config.llm is not None and self.llm is not None
output_len = output_len or self.config.llm.max_output_tokens
functions: Optional[List[LLMFunctionSpec]] = None
fun_call: str | Dict[str, str] = "none"
fun_call: str | Dict[str, str | Dict[str, str]] = "none"
if self.config.use_functions_api and len(self.llm_functions_usable) > 0:
functions = [self.llm_functions_map[f] for f in self.llm_functions_usable]
fun_call = (
Expand Down
2 changes: 1 addition & 1 deletion langroid/language_models/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class AzureConfig(OpenAIGPTConfig):

api_key: str = "" # CAUTION: set this ONLY via env var AZURE_OPENAI_API_KEY
type: str = "azure"
api_version: str = "2023-05-15"
api_version: str = "2024-02-15-preview"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good to update here; the value can still be overridden via the AZURE_OPENAI_API_VERSION variable in the .env file.

deployment_name: str = ""
model_name: str = ""
model_version: str = "" # is used to determine the cost of using the model
Expand Down
4 changes: 2 additions & 2 deletions langroid/language_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def chat(
messages: Union[str, List[LLMMessage]],
max_tokens: int = 200,
functions: Optional[List[LLMFunctionSpec]] = None,
function_call: str | Dict[str, str] = "auto",
function_call: str | Dict[str, str | Dict[str, str]] = "auto",
) -> LLMResponse:
pass

Expand All @@ -401,7 +401,7 @@ async def achat(
messages: Union[str, List[LLMMessage]],
max_tokens: int = 200,
functions: Optional[List[LLMFunctionSpec]] = None,
function_call: str | Dict[str, str] = "auto",
function_call: str | Dict[str, str | Dict[str, str]] = "auto",
) -> LLMResponse:
pass

Expand Down
21 changes: 14 additions & 7 deletions langroid/language_models/openai_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,7 +975,7 @@ def chat(
messages: Union[str, List[LLMMessage]],
max_tokens: int = 200,
functions: Optional[List[LLMFunctionSpec]] = None,
function_call: str | Dict[str, str] = "auto",
function_call: str | Dict[str, str | Dict[str, str]] = "auto",
) -> LLMResponse:
self.run_on_first_use()

Expand Down Expand Up @@ -1018,7 +1018,7 @@ async def achat(
messages: Union[str, List[LLMMessage]],
max_tokens: int = 200,
functions: Optional[List[LLMFunctionSpec]] = None,
function_call: str | Dict[str, str] = "auto",
function_call: str | Dict[str, str | Dict[str, str]] = "auto",
) -> LLMResponse:
self.run_on_first_use()

Expand Down Expand Up @@ -1123,7 +1123,7 @@ def _prep_chat_completion(
messages: Union[str, List[LLMMessage]],
max_tokens: int,
functions: Optional[List[LLMFunctionSpec]] = None,
function_call: str | Dict[str, str] = "auto",
function_call: str | Dict[str, str | Dict[str, str]] = "auto",
) -> Dict[str, Any]:
if isinstance(messages, str):
llm_messages = [
Expand Down Expand Up @@ -1152,10 +1152,17 @@ def _prep_chat_completion(
if functions is not None:
args.update(
dict(
functions=[f.dict() for f in functions],
function_call=function_call,
tools=[
{
"type": "function",
"function": f.dict(),
}
for f in functions
],
tool_choice=function_call,
)
)

return args

def _process_chat_completion_response(
Expand Down Expand Up @@ -1223,7 +1230,7 @@ def _chat(
messages: Union[str, List[LLMMessage]],
max_tokens: int,
functions: Optional[List[LLMFunctionSpec]] = None,
function_call: str | Dict[str, str] = "auto",
function_call: str | Dict[str, str | Dict[str, str]] = "auto",
) -> LLMResponse:
"""
ChatCompletion API call to OpenAI.
Expand Down Expand Up @@ -1265,7 +1272,7 @@ async def _achat(
messages: Union[str, List[LLMMessage]],
max_tokens: int,
functions: Optional[List[LLMFunctionSpec]] = None,
function_call: str | Dict[str, str] = "auto",
function_call: str | Dict[str, str | Dict[str, str]] = "auto",
) -> LLMResponse:
"""
Async version of _chat(). See that function for details.
Expand Down