Skip to content

Commit

Permalink
incorporate tool calling (#131)
Browse files Browse the repository at this point in the history
  • Loading branch information
ccurme committed Apr 27, 2024
1 parent c72c87f commit cc01bbd
Show file tree
Hide file tree
Showing 6 changed files with 1,200 additions and 824 deletions.
1,947 changes: 1,147 additions & 800 deletions backend/poetry.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion backend/pyproject.toml
Expand Up @@ -14,7 +14,7 @@ fastapi = "^0.109.2"
langserve = "^0.0.45"
uvicorn = "^0.27.1"
pydantic = "^1.10"
langchain-openai = "^0.0.8"
langchain-openai = "^0.1.3"
jsonschema = "^4.21.1"
sse-starlette = "^2.0.0"
alembic = "^1.13.1"
Expand All @@ -26,6 +26,8 @@ lxml = "^5.1.0"
faiss-cpu = "^1.7.4"
python-multipart = "^0.0.9"
langchain-fireworks = "^0.1.1"
langchain-anthropic = "^0.1.11"
langchain-groq = "^0.1.3"

[tool.poetry.group.dev.dependencies]
jupyterlab = "^3.6.1"
Expand Down
25 changes: 12 additions & 13 deletions backend/server/extraction_runnable.py
@@ -1,12 +1,13 @@
from __future__ import annotations

import json
import uuid
from typing import Any, Dict, List, Optional, Sequence

from fastapi import HTTPException
from jsonschema import Draft202012Validator, exceptions
from langchain.text_splitter import TokenTextSplitter
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import chain
from langserve import CustomUserType
Expand Down Expand Up @@ -97,19 +98,18 @@ def _make_prompt_template(
# TODO: We'll need to refactor this at some point to
# support other encoding strategies. The function calling logic here
# has some hard-coded assumptions (e.g., name of parameters like `data`).
function_call = {
"arguments": json.dumps(
{
"data": example.output,
}
),
_id = uuid.uuid4().hex[:]
tool_call = {
"args": {"data": example.output},
"name": function_name,
"id": _id,
}
few_shot_prompt.extend(
[
HumanMessage(content=example.text),
AIMessage(
content="", additional_kwargs={"function_call": function_call}
AIMessage(content="", tool_calls=[tool_call]),
ToolMessage(
content="You have correctly called this tool.", tool_call_id=_id
),
]
)
Expand Down Expand Up @@ -172,10 +172,9 @@ async def extraction_runnable(extraction_request: ExtractRequest) -> ExtractResp
schema["title"],
)
model = get_model(extraction_request.model_name)
# N.B. method must be consistent with examples in _make_prompt_template
runnable = (
prompt | model.with_structured_output(schema=schema, method="function_calling")
).with_config({"run_name": "extraction"})
runnable = (prompt | model.with_structured_output(schema=schema)).with_config(
{"run_name": "extraction"}
)

return await runnable.ainvoke({"text": extraction_request.text})

Expand Down
17 changes: 17 additions & 0 deletions backend/server/models.py
@@ -1,8 +1,10 @@
import os
from typing import Optional

from langchain_anthropic import ChatAnthropic
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_fireworks import ChatFireworks
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI


Expand Down Expand Up @@ -37,6 +39,21 @@ def get_supported_models():
),
"description": "Mixtral 8x7B Instruct v0.1 (Together AI)",
}
if "ANTHROPIC_API_KEY" in os.environ:
models["claude-3-sonnet-20240229"] = {
"chat_model": ChatAnthropic(
model="claude-3-sonnet-20240229", temperature=0
),
"description": "Claude 3 Sonnet",
}
if "GROQ_API_KEY" in os.environ:
models["groq-llama3-8b-8192"] = {
"chat_model": ChatGroq(
model="llama3-8b-8192",
temperature=0,
),
"description": "GROQ Llama 3 8B",
}

return models

Expand Down
20 changes: 14 additions & 6 deletions backend/tests/unit_tests/fake/test_fake_chat_model.py
Expand Up @@ -6,25 +6,33 @@
from tests.unit_tests.fake.chat_model import GenericFakeChatModel


class AnyStr(str):
    """A ``str`` subclass that compares equal to ANY string.

    Used in test assertions to match fields whose exact value is
    unpredictable — e.g. auto-generated message ``id`` values — while
    still requiring the field to be a string.
    """

    def __eq__(self, other: object) -> bool:
        # Match any string, regardless of content.
        return isinstance(other, str)

    # Defining __eq__ implicitly sets __hash__ to None (unhashable);
    # restore str's hash so instances stay usable in sets / as dict keys.
    __hash__ = str.__hash__


def test_generic_fake_chat_model_invoke() -> None:
    """The fake model cycles through its canned messages on each invoke."""
    # Alternates forever between "hello" and "goodbye".
    canned = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
    model = GenericFakeChatModel(messages=canned)
    # Three invocations wrap around the two-message cycle.
    prompts_and_replies = [("meow", "hello"), ("kitty", "goodbye"), ("meow", "hello")]
    for prompt, reply in prompts_and_replies:
        # id is auto-generated, so match it with AnyStr.
        assert model.invoke(prompt) == AIMessage(content=reply, id=AnyStr())


async def test_generic_fake_chat_model_ainvoke() -> None:
    """Async invocation cycles through the canned messages identically."""
    # Alternates forever between "hello" and "goodbye".
    canned = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
    model = GenericFakeChatModel(messages=canned)
    # Three awaited invocations wrap around the two-message cycle.
    for prompt, reply in [("meow", "hello"), ("kitty", "goodbye"), ("meow", "hello")]:
        # id is auto-generated, so match it with AnyStr.
        assert await model.ainvoke(prompt) == AIMessage(content=reply, id=AnyStr())
11 changes: 7 additions & 4 deletions backend/tests/unit_tests/test_utils.py
@@ -1,4 +1,5 @@
from langchain.pydantic_v1 import BaseModel, Field
from langchain_core.messages import AIMessage

from extraction.utils import update_json_schema
from server.extraction_runnable import ExtractionExample, _make_prompt_template
Expand Down Expand Up @@ -82,19 +83,21 @@ def test_make_prompt_template() -> None:
)
prompt = _make_prompt_template(instructions, examples, "name")
messages = prompt.messages
assert 4 == len(messages)
assert 5 == len(messages)
system = messages[0].prompt.template
assert system.startswith(prefix)
assert system.endswith(instructions)

example_input = messages[1]
assert example_input.content == "Test text."
example_output = messages[2]
assert "function_call" in example_output.additional_kwargs
assert example_output.additional_kwargs["function_call"]["name"] == "name"
assert isinstance(example_output, AIMessage)
assert example_output.tool_calls
assert len(example_output.tool_calls) == 1
assert example_output.tool_calls[0]["name"] == "name"

prompt = _make_prompt_template(instructions, None, "name")
assert 2 == len(prompt.messages)

prompt = _make_prompt_template(None, examples, "name")
assert 4 == len(prompt.messages)
assert 5 == len(prompt.messages)

0 comments on commit cc01bbd

Please sign in to comment.