Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding types to endpoints #490

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion core/cat/routes/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,11 @@ def get_embedder_settings(request: Request, languageEmbedderName: str) -> Dict:
def upsert_embedder_setting(
request: Request,
languageEmbedderName: str,
payload: Dict = Body(examples={"openai_api_key": "your-key-here"}),
payload: Dict = Body(
examples = {
"openai_api_key": "your-key-here"
}
),
) -> Dict:
"""Upsert the Embedder setting"""

Expand Down
6 changes: 5 additions & 1 deletion core/cat/routes/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,11 @@ def get_llm_settings(request: Request, languageModelName: str) -> Dict:
def upsert_llm_setting(
request: Request,
languageModelName: str,
payload: Dict = Body(examples={"openai_api_key": "your-key-here"}),
payload: Dict = Body(
examples = {
"openai_api_key": "your-key-here"
}
),
) -> Dict:
"""Upsert the Large Language Model setting"""

Expand Down
57 changes: 36 additions & 21 deletions core/cat/routes/memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Dict
from cat.headers import check_user_id
from cat.routes.types import ResponseMemoryRecall, ResponseCollections, ResponseDelete, ResponseConversationHistory
from fastapi import Query, Request, APIRouter, HTTPException, Depends

router = APIRouter()
Expand All @@ -12,7 +13,7 @@ async def recall_memories_from_text(
text: str = Query(description="Find memories similar to this text."),
k: int = Query(default=100, description="How many memories to return."),
user_id = Depends(check_user_id)
) -> Dict:
) -> ResponseMemoryRecall:
"""Search k memories similar to given text."""

ccat = request.app.state.ccat
Expand Down Expand Up @@ -53,18 +54,20 @@ async def recall_memories_from_text(
memory_dict["vector"] = vector
recalled[c].append(memory_dict)

return {
result = {
"query": query,
"vectors": {
"embedder": str(ccat.embedder.__class__.__name__), # TODO: should be the config class name
"collections": recalled
}
}

return ResponseMemoryRecall(**result)


# GET collection list with some metadata
@router.get("/collections/")
async def get_collections(request: Request) -> Dict:
async def get_collections(request: Request) -> ResponseCollections:
"""Get list of available collections"""

ccat = request.app.state.ccat
Expand All @@ -80,16 +83,16 @@ async def get_collections(request: Request) -> Dict:
"vectors_count": coll_meta.vectors_count
}]

return {
result = {
"collections": collections_metadata
}

return ResponseCollections(**result)


# DELETE all collections
@router.delete("/collections/")
async def wipe_collections(
request: Request,
) -> Dict:
async def wipe_collections(request: Request) -> ResponseDelete:
"""Delete and create all collections"""

ccat = request.app.state.ccat
Expand All @@ -105,14 +108,16 @@ async def wipe_collections(
ccat.mad_hatter.find_plugins()
ccat.mad_hatter.embed_tools()

return {
result = {
"deleted": to_return,
}

return ResponseDelete(**result)


# DELETE one collection
@router.delete("/collections/{collection_id}/")
async def wipe_single_collection(request: Request, collection_id: str) -> Dict:
async def wipe_single_collection(request: Request, collection_id: str) -> ResponseDelete:
"""Delete and recreate a collection"""

ccat = request.app.state.ccat
Expand All @@ -135,18 +140,20 @@ async def wipe_single_collection(request: Request, collection_id: str) -> Dict:
ccat.mad_hatter.find_plugins()
ccat.mad_hatter.embed_tools()

return {
result = {
"deleted": to_return,
}

return ResponseDelete(**result)


# DELETE memories
@router.delete("/collections/{collection_id}/points/{memory_id}/")
async def wipe_memory_point(
request: Request,
collection_id: str,
memory_id: str
) -> Dict:
) -> ResponseDelete:
"""Delete a specific point in memory"""

ccat = request.app.state.ccat
Expand Down Expand Up @@ -174,17 +181,19 @@ async def wipe_memory_point(
# delete point
vector_memory.collections[collection_id].delete_points([memory_id])

return {
result = {
"deleted": memory_id
}

return ResponseDelete(**result)

@router.delete("/collections/{collection_id}/points")

@router.delete("/collections/{collection_id}/points/")
async def wipe_memory_points_by_metadata(
request: Request,
collection_id: str,
metadata: Dict = {},
) -> Dict:
) -> ResponseDelete:
"""Delete points in memory by filter"""

ccat = request.app.state.ccat
Expand All @@ -193,42 +202,48 @@ async def wipe_memory_points_by_metadata(
# delete points
vector_memory.collections[collection_id].delete_points_by_metadata_filter(metadata)

return {
"deleted": [] # TODO: Qdrant does not return deleted points?
result = {
"deleted": True # TODO: Return list of deleted points by Qdrant
}

return ResponseDelete(**result)


# DELETE conversation history from working memory
@router.delete("/conversation_history/")
async def wipe_conversation_history(
request: Request,
user_id = Depends(check_user_id),
) -> Dict:
) -> ResponseDelete:
"""Delete the specified user's conversation history from working memory"""

# TODO: Add possibility to wipe the working memory of specified user id

ccat = request.app.state.ccat
ccat.working_memory["history"] = []

return {
result = {
"deleted": True,
}

return ResponseDelete(**result)


# GET conversation history from working memory
@router.get("/conversation_history/")
async def get_conversation_history(
request: Request,
user_id = Depends(check_user_id),
) -> Dict:
) -> ResponseConversationHistory:
"""Get the specified user's conversation history from working memory"""

# TODO: Add possibility to get the working memory of specified user id

ccat = request.app.state.ccat
history = ccat.working_memory["history"]

return {
result = {
"history": history
}
}

return ResponseConversationHistory(**result)
4 changes: 0 additions & 4 deletions core/cat/routes/plugins.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import mimetypes
from copy import deepcopy
from typing import Dict
from tempfile import NamedTemporaryFile
from fastapi import Body, Request, APIRouter, HTTPException, UploadFile, BackgroundTasks
from cat.log import log
from cat.mad_hatter.registry import registry_search_plugins, registry_download_plugin
from urllib.parse import urlparse
import requests

from pydantic import ValidationError

router = APIRouter()
Expand Down
15 changes: 7 additions & 8 deletions core/cat/routes/settings.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from typing import Annotated
from fastapi import Body, Response, APIRouter, HTTPException, status
from typing import Dict
from fastapi import APIRouter, HTTPException
from cat.db import models
from cat.db import crud


router = APIRouter()


@router.get("/")
def get_settings(search: str = ""):
def get_settings(search: str = "") -> Dict:
"""Get the entire list of settings available in the database"""

settings = crud.get_settings(search=search)
Expand All @@ -19,7 +18,7 @@ def get_settings(search: str = ""):


@router.post("/")
def create_setting(payload: models.SettingBody):
def create_setting(payload: models.SettingBody) -> Dict:
"""Create a new setting in the database"""

# complete the payload with setting_id and updated_at
Expand All @@ -34,7 +33,7 @@ def create_setting(payload: models.SettingBody):


@router.get("/{settingId}")
def get_setting(settingId: str):
def get_setting(settingId: str) -> Dict:
"""Get the a specific setting from the database"""

setting = crud.get_setting_by_id(settingId)
Expand All @@ -51,7 +50,7 @@ def get_setting(settingId: str):


@router.put("/{settingId}")
def update_setting(settingId: str, payload: models.SettingBody):
def update_setting(settingId: str, payload: models.SettingBody) -> Dict:
"""Update a specific setting in the database if it exists"""

# does the setting exist?
Expand All @@ -77,7 +76,7 @@ def update_setting(settingId: str, payload: models.SettingBody):


@router.delete("/{settingId}")
def delete_setting(settingId: str):
def delete_setting(settingId: str) -> Dict:
"""Delete a specific setting in the database"""

# does the setting exist?
Expand Down
36 changes: 36 additions & 0 deletions core/cat/routes/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Any, Dict, List, Optional, TypedDict
from pydantic import BaseModel

class ResponseDelete(BaseModel):
deleted: Any

class Collection(BaseModel):
name: str
vectors_count: int

class ResponseCollections(BaseModel):
collections: List[Collection]

class RecallCollection(BaseModel):
id: str
score: float
vector: List[int]

class Vector(BaseModel):
embedder: str
collections: Dict[str, RecallCollection]

class Query(BaseModel):
text: str
vector: List[int]

class ResponseMemoryRecall(BaseModel):
query: Query
vectors: Vector

class Conversation(BaseModel):
who: str
message: str

class ResponseConversationHistory(BaseModel):
history: List[Conversation]
19 changes: 13 additions & 6 deletions core/cat/routes/upload.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import mimetypes
import requests
from typing import Dict

from fastapi import Body, Request, APIRouter, UploadFile, BackgroundTasks, HTTPException

from cat.log import log

router = APIRouter()
Expand All @@ -19,7 +17,10 @@ async def upload_file(
default=400,
description="Maximum length of each chunk after the document is split (in characters)",
),
chunk_overlap: int = Body(default=100, description="Chunk overlap (in characters)")
chunk_overlap: int = Body(
default=100,
description="Chunk overlap (in characters)"
)
) -> Dict:
"""Upload a file containing text (.txt, .md, .pdf, etc.). File content will be extracted and segmented into chunks.
Chunks will be then vectorized and stored into documents memory.
Expand Down Expand Up @@ -66,8 +67,11 @@ async def upload_url(
default=400,
description="Maximum length of each chunk after the document is split (in characters)",
),
chunk_overlap: int = Body(default=100, description="Chunk overlap (in characters)")
):
chunk_overlap: int = Body(
default=100,
description="Chunk overlap (in characters)"
)
) -> Dict:
"""Upload a url. Website content will be extracted and segmented into chunks.
Chunks will be then vectorized and stored into documents memory."""
# check that URL is valid
Expand All @@ -89,7 +93,10 @@ async def upload_url(
background_tasks.add_task(
ccat.rabbit_hole.ingest_file, url, chunk_size, chunk_overlap
)
return {"url": url, "info": "URL is being ingested asynchronously"}
return {
"url": url,
"info": "URL is being ingested asynchronously"
}
else:
raise HTTPException(
status_code=400,
Expand Down