Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support nested pydantic models with Gemini tool calling [DRAFT] #222

Draft
wants to merge 3 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
55 changes: 41 additions & 14 deletions mirascope/gemini/tools.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
"""Classes for using tools with Google's Gemini API."""
from __future__ import annotations

from typing import Callable, Type
import pprint
from typing import Any, Callable, Type

import jsonref
from google.ai.generativelanguage import FunctionCall
from google.generativeai.types import ( # type: ignore
FunctionDeclaration,
Tool,
)
from pydantic import BaseModel, ConfigDict

from mirascope.base.tools import DEFAULT_TOOL_DOCSTRING
from mirascope.gemini.utils import remove_invalid_title_keys_from_parameters

from ..base import (
BaseTool,
BaseType,
Expand Down Expand Up @@ -68,20 +73,42 @@ def tool_schema(cls) -> Tool:
Returns:
The constructed `Tool` schema.
"""
tool_schema = super().tool_schema()
if "parameters" in tool_schema:
if "$defs" in tool_schema["parameters"]:
raise ValueError(
"Unfortunately Google's Gemini API cannot handle nested structures "
"with $defs."
super().tool_schema()
model_schema: dict[str, Any] = cls.model_json_schema()
pprint.pprint(model_schema)

# Replace all references with their values
without_refs: dict[str, Any] = jsonref.replace_refs(model_schema) # type: ignore
pprint.pprint(without_refs)

# Remove all Defs
without_refs.pop("$defs")
pprint.pprint(without_refs)

# Get the name and description, and remove them from the schema
name: str = without_refs.pop("title") # type: ignore
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

quick note:

name and description are already in super().tool_schema(), so it's really just the parameters field that needs to be modified (and thus remove duplicate code).

description: str = ( # type: ignore
without_refs.pop("description", None) or DEFAULT_TOOL_DOCSTRING
)
parameters: dict[str, Any] = without_refs

# Remove all instances of title key in each param definition
# This is careful not to delete keys that represent a field with the name title
remove_invalid_title_keys_from_parameters(parameters)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if all of the conversion of parameters was pushed inside this utility function so the main function remains clean and the utility function can be separately unit tested?


print(f"{name=}")
print(f"{description=}")
print(f"{parameters=}")

return Tool(
function_declarations=[
FunctionDeclaration(
name=name,
description=description,
parameters=parameters,
)
tool_schema["parameters"]["properties"] = {
prop: {
key: value for key, value in prop_schema.items() if key != "title"
}
for prop, prop_schema in tool_schema["parameters"]["properties"].items()
}
return Tool(function_declarations=[FunctionDeclaration(**tool_schema)])
]
)

@classmethod
def from_tool_call(cls, tool_call: FunctionCall) -> GeminiTool:
Expand Down
55 changes: 55 additions & 0 deletions mirascope/gemini/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import Any


def remove_invalid_title_keys_from_parameters(d: dict[str, Any] | Any) -> None:
"""
For each property, remove the title key. However, we make sure to only remove
the title key in each schema
Before
{
"properties": {
"books": {
"items": {
"properties": {
"author_name": {"title": "Author Name", "type": "string"},
"title": {"title": "Title", "type": "string"},
},
"required": ["author_name", "title"],
"title": "Book",
"type": "object",
},
"title": "Books",
"type": "array",
}
},
"required": ["books"],
"title": "Books",
"type": "object",
}

AFTER
{
"properties": {
"books": {
"items": {
"properties": {
"author_name": {"type": "string"},
"title": {"type": "string"},
},
"required": ["author_name", "title"],
"type": "object",
},
"type": "array",
}
},
"required": ["books"],
"type": "object",
}

"""
if isinstance(d, dict):
for key in list(d.keys()):
if key == "title" and "type" in d.keys():
del d[key]
else:
remove_invalid_title_keys_from_parameters(d[key])