Skip to content

Commit

Permalink
backend: Make endpoints with IO operations async by default (#145)
Browse files Browse the repository at this point in the history
Make endpoints with IO operations async by default
  • Loading branch information
tianjing-li committed May 17, 2024
1 parent f2fb766 commit a29a626
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 127 deletions.
2 changes: 1 addition & 1 deletion src/backend/routers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async def chat_stream(


@router.post("/chat", dependencies=[Depends(validate_deployment_header)])
def chat(
async def chat(
session: DBSessionDep,
chat_request: CohereChatRequest,
request: Request,
Expand Down
64 changes: 4 additions & 60 deletions src/backend/routers/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

# CONVERSATIONS
@router.get("/{conversation_id}", response_model=Conversation)
def get_conversation(
async def get_conversation(
conversation_id: str, session: DBSessionDep, request: Request
) -> Conversation:
""" "
Expand Down Expand Up @@ -57,7 +57,7 @@ def get_conversation(


@router.get("", response_model=list[ConversationWithoutMessages])
def list_conversations(
async def list_conversations(
*, offset: int = 0, limit: int = 100, session: DBSessionDep, request: Request
) -> list[ConversationWithoutMessages]:
"""
Expand All @@ -79,7 +79,7 @@ def list_conversations(


@router.put("/{conversation_id}", response_model=Conversation)
def update_conversation(
async def update_conversation(
conversation_id: str,
new_conversation: UpdateConversation,
session: DBSessionDep,
Expand Down Expand Up @@ -117,7 +117,7 @@ def update_conversation(


@router.delete("/{conversation_id}")
def delete_conversation(
async def delete_conversation(
conversation_id: str, session: DBSessionDep, request: Request
) -> DeleteConversation:
"""
Expand Down Expand Up @@ -148,62 +148,6 @@ def delete_conversation(
return DeleteConversation()


# FILES
@router.post("/{conversation_id}/upload_file", response_model=UploadFile)
async def upload_file_with_conversation(
conversation_id: str,
session: DBSessionDep,
request: Request,
file: FastAPIUploadFile = RequestFile(...),
) -> UploadFile:
"""
(TO BE DEPRECATED)
Uploads a file to a conversation.
Args:
conversation_id (str): Conversation ID.
session (DBSessionDep): Database session.
file (FastAPIUploadFile): File to be uploaded.
Returns:
UploadFile: Uploaded file.
Raises:
HTTPException: If the conversation with the given ID is not found. Status code 404.
HTTPException: If the file wasn't uploaded correctly. Status code 500.
"""
user_id = request.headers.get("User-Id")
conversation = conversation_crud.get_conversation(session, conversation_id, user_id)

if not conversation:
raise HTTPException(
status_code=404,
detail=f"Conversation with ID: {conversation_id} not found.",
)

file_path = FileService().upload_file(file)

# Raise exception if file wasn't uploaded
if not file_path.exists():
raise HTTPException(
status_code=500, detail=f"Error while uploading file {file.filename}."
)

db_file = FileModel(
user_id=conversation.user_id,
conversation_id=conversation.id,
file_name=file_path.name,
file_path=str(file_path),
file_size=file_path.stat().st_size,
)

db_file = file_crud.create_file(session, db_file)

return db_file


@router.post("/upload_file", response_model=UploadFile)
async def upload_file(
session: DBSessionDep,
Expand Down
2 changes: 1 addition & 1 deletion src/backend/routers/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def list_deployments(all: bool = False) -> list[Deployment]:


@router.post("/{name}/set_env_vars", response_class=Response)
def set_env_vars(
async def set_env_vars(
name: str, env_vars: UpdateDeploymentEnv, valid_env_vars=Depends(validate_env_vars)
):
"""
Expand Down
12 changes: 7 additions & 5 deletions src/backend/routers/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@router.post("/", response_model=User)
def create_user(user: CreateUser, session: DBSessionDep) -> User:
async def create_user(user: CreateUser, session: DBSessionDep) -> User:
"""
Create a new user.
Expand All @@ -28,7 +28,7 @@ def create_user(user: CreateUser, session: DBSessionDep) -> User:


@router.get("/", response_model=list[User])
def list_users(
async def list_users(
*, offset: int = 0, limit: int = 100, session: DBSessionDep
) -> list[User]:
"""
Expand All @@ -46,7 +46,7 @@ def list_users(


@router.get("/{user_id}", response_model=User)
def get_user(user_id: str, session: DBSessionDep) -> User:
async def get_user(user_id: str, session: DBSessionDep) -> User:
"""
Get a user by ID.
Expand All @@ -72,7 +72,9 @@ def get_user(user_id: str, session: DBSessionDep) -> User:


@router.put("/{user_id}", response_model=User)
def update_user(user_id: str, new_user: UpdateUser, session: DBSessionDep) -> User:
async def update_user(
user_id: str, new_user: UpdateUser, session: DBSessionDep
) -> User:
"""
Update a user by ID.
Expand Down Expand Up @@ -100,7 +102,7 @@ def update_user(user_id: str, new_user: UpdateUser, session: DBSessionDep) -> Us


@router.delete("/{user_id}")
def delete_user(user_id: str, session: DBSessionDep) -> DeleteUser:
async def delete_user(user_id: str, session: DBSessionDep) -> DeleteUser:
""" "
Delete a user by ID.
Expand Down
60 changes: 0 additions & 60 deletions src/backend/tests/routers/test_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,66 +379,6 @@ def test_list_files_missing_user_id(
assert response.json() == {"detail": "User-Id required in request headers."}


def test_upload_file_on_conversation(
session_client: TestClient, session: Session
) -> None:
file_path = "src/backend/tests/test_data/Mariana_Trench.pdf"
saved_file_path = "src/backend/data/Mariana_Trench.pdf"
conversation = get_factory("Conversation", session).create()
file_doc = {"file": open(file_path, "rb")}
response = session_client.post(
f"/v1/conversations/{conversation.id}/upload_file",
files=file_doc,
headers={"User-Id": conversation.user_id},
)
response_file = response.json()

assert response.status_code == 200
assert "Mariana_Trench" in response_file["file_name"]
assert response_file["conversation_id"] == conversation.id
assert response_file["user_id"] == conversation.user_id

# Clean up - remove the file from the directory
os.remove(saved_file_path)


def test_fail_upload_file_on_conversation_missing_data(
session_client: TestClient, session: Session
) -> None:
conversation = get_factory("Conversation", session).create()
response = session_client.post(
f"/v1/conversations/{conversation.id}/upload_file",
json={},
headers={"User-Id": conversation.user_id},
)
response_file = response.json()

assert response.status_code == 422
assert response_file == {
"detail": [
{
"type": "missing",
"loc": ["body", "file"],
"msg": "Field required",
"input": None,
"url": "https://errors.pydantic.dev/2.7/v/missing",
}
]
}


def test_upload_file_on_conversation_missing_user_id(
session_client: TestClient, session: Session
) -> None:
conversation = get_factory("Conversation", session).create()
response = session_client.post(
f"/v1/conversations/{conversation.id}/upload_file", json={}
)

assert response.status_code == 401
assert response.json() == {"detail": "User-Id required in request headers."}


def test_upload_file_existing_conversation(
session_client: TestClient, session: Session
) -> None:
Expand Down

0 comments on commit a29a626

Please sign in to comment.