Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update API Endpoints Base and Embedder to Version v2 #804

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions core/cat/factory/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,8 @@ def get_embedders_schemas():
EMBEDDER_SCHEMAS[schema["title"]] = schema

return EMBEDDER_SCHEMAS

def get_embedders_class():
# Provide a dictionary containing the name of the embedder and its corresponding class.
return {config_class.model_json_schema()["title"]: config_class
for config_class in get_allowed_embedder_models()}
10 changes: 7 additions & 3 deletions core/cat/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,20 @@ def custom_generate_unique_id(route: APIRoute):
allow_headers=["*"],
)

# Add routers to the middleware stack.
cheshire_cat_api.include_router(base.router, tags=["Status"], dependencies=[Depends(check_api_key)])
# Add routers to the middleware stack v1.
cheshire_cat_api.include_router(base.router_v1, tags=["Status"], dependencies=[Depends(check_api_key)])
cheshire_cat_api.include_router(settings.router, tags=["Settings"], prefix="/settings", dependencies=[Depends(check_api_key)])
cheshire_cat_api.include_router(llm.router, tags=["Large Language Model"], prefix="/llm", dependencies=[Depends(check_api_key)])
cheshire_cat_api.include_router(embedder.router, tags=["Embedder"], prefix="/embedder", dependencies=[Depends(check_api_key)])
cheshire_cat_api.include_router(embedder.router_v1, tags=["Embedder"], prefix="/embedder", dependencies=[Depends(check_api_key)])
cheshire_cat_api.include_router(plugins.router, tags=["Plugins"], prefix="/plugins", dependencies=[Depends(check_api_key)])
cheshire_cat_api.include_router(memory.router, tags=["Memory"], prefix="/memory", dependencies=[Depends(check_api_key)])
cheshire_cat_api.include_router(upload.router, tags=["Rabbit Hole"], prefix="/rabbithole", dependencies=[Depends(check_api_key)])
cheshire_cat_api.include_router(websocket.router, tags=["WebSocket"])

# Add routers to the middleware stack v2.
cheshire_cat_api.include_router(base.router_v2, tags=["Status"], dependencies=[Depends(check_api_key)], prefix="/v2")
cheshire_cat_api.include_router(embedder.router_v2, tags=["Embedder"], prefix="/v2/embedder", dependencies=[Depends(check_api_key)])

# mount static files
# this cannot be done via fastapi.APIrouter:
# https://github.com/tiangolo/fastapi/discussions/9070
Expand Down
81 changes: 72 additions & 9 deletions core/cat/routes/base.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,43 @@
from fastapi import APIRouter, Depends, Request, Body, Query
from typing import Dict
from fastapi import APIRouter, Depends, Body
from typing import Dict, List
import tomli

from cat.headers import session
from cat.log import log
from cat.looking_glass.stray_cat import StrayCat
from pydantic import BaseModel

# Default router
router_v1 = APIRouter()

router = APIRouter()
# Router v2
router_v2 = APIRouter()

class HomeResponse(BaseModel):
status: str
version: str

# server status
@router.get("/")
async def home() -> Dict:
"""Server status"""

class MemoryResponse(BaseModel):
page_content: str
type: str
score: float
id: str
metadata: Dict[str,str|int|float]

class MessageWhyResponse(BaseModel):
input: str
intermediate_steps: List
memory: Dict[str,List[MemoryResponse]]

class CatResponse(BaseModel):
type: str
content: str
user_id: str
why: MessageWhyResponse


def home() -> Dict:
with open("pyproject.toml", "rb") as f:
project_toml = tomli.load(f)["project"]

Expand All @@ -20,13 +47,49 @@ async def home() -> Dict:
}


@router.post("/message")
# server status
@router_v1.get("/", deprecated=True)
async def home_v1() -> Dict:
"""Server status"""
log.warning("Deprecated: This endpoint will be removed in the next major version.")
return home()

@router_v2.get("/", response_model=HomeResponse, response_model_exclude_none=True)
async def home_v2():
"""Server status"""

return home()


async def message_with_cat(
payload: Dict,
stray: StrayCat
) -> Dict:

answer = await stray(payload)

return answer


@router_v1.post("/message", deprecated=True)
async def message_with_cat_v1(
payload: Dict = Body({"text": "hello!"}),
stray = Depends(session),
) -> Dict:
"""Get a response from the Cat"""

answer = await stray(payload)
answer = await message_with_cat(payload, stray)
log.warning("Deprecated: This endpoint will be removed in the next major version.")

return answer

@router_v2.post("/message", response_model=CatResponse, response_model_exclude_none=True)
async def message_with_cat_v2(
payload: Dict = Body({"text": "hello!"}),
stray = Depends(session),
) -> Dict:
"""Get a response from the Cat"""

answer = await message_with_cat(payload, stray)

return answer