Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add baseline Architecture to support auth + User sessions + Basic Auth #90

Merged
merged 16 commits into from
May 7, 2024
Merged
5 changes: 4 additions & 1 deletion .env-template
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,7 @@ AZURE_CHAT_ENDPOINT_URL=<ENDPOINT URL>
USE_EXPERIMENTAL_LANGCHAIN=False

# Community features
USE_COMMUNITY_FEATURES='True'
USE_COMMUNITY_FEATURES='True'

# Auth session
SESSION_SECRET_KEY=<GENERATE_A_SECRET_KEY>
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ reset-db:
docker volume rm cohere_toolkit_db
setup:
poetry install --only setup --verbose
poetry run python3 src/backend/cli/main.py
poetry run python3 cli/main.py
lint:
poetry run black .
poetry run isort .
Expand Down
170 changes: 157 additions & 13 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ py-expression-eval = "^0.3.14"
tavily-python = "^0.3.3"
arxiv = "^2.1.0"
xmltodict = "^0.13.0"
authlib = "^1.3.0"
itsdangerous = "^2.2.0"
bcrypt = "^4.1.2"

[tool.poetry.group.dev.dependencies]
pytest = "^7.1.2"
Expand Down
34 changes: 34 additions & 0 deletions src/backend/alembic/versions/b88f00283a27_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""empty message

Revision ID: b88f00283a27
Revises: 2853273872ca
Create Date: 2024-05-02 19:19:52.608062

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "b88f00283a27"
down_revision: Union[str, None] = "2853273872ca"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"users", sa.Column("hashed_password", sa.LargeBinary(), nullable=True)
)
op.create_unique_constraint("unique_user_email", "users", ["email"])
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("unique_user_email", "users", type_="unique")
op.drop_column("users", "hashed_password")
# ### end Alembic commands ###
26 changes: 26 additions & 0 deletions src/backend/alembic/versions/c15b848babe3_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""empty message

Revision ID: c15b848babe3
Revises: 6553b76de6ca, b88f00283a27
Create Date: 2024-05-07 15:59:05.436751

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "c15b848babe3"
down_revision: Union[str, None] = ("6553b76de6ca", "b88f00283a27")
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
pass


def downgrade() -> None:
pass
9 changes: 9 additions & 0 deletions src/backend/config/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from backend.services.auth import BasicAuthentication

# Modify this to enable auth strategies.
ENABLED_AUTH_STRATEGIES = []

# Define the mapping from Auth strategy name to class obj.
# Does not need to be manually modified.
# Ex: {"Basic": BasicAuthentication}
ENABLED_AUTH_STRATEGY_MAPPING = {cls.NAME: cls for cls in ENABLED_AUTH_STRATEGIES}
2 changes: 1 addition & 1 deletion src/backend/crud/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def update_user(db: Session, user: User, new_user: UpdateUser) -> User:
Returns:
User: Updated user.
"""
for attr, value in new_user.model_dump().items():
for attr, value in new_user.model_dump(exclude_none=True).items():
setattr(user, attr, value)
db.commit()
db.refresh(user)
Expand Down
4 changes: 4 additions & 0 deletions src/backend/database_models/user.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional

from sqlalchemy import UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column

from backend.database_models.base import Base
Expand All @@ -10,3 +11,6 @@ class User(Base):

fullname: Mapped[str] = mapped_column()
email: Mapped[Optional[str]] = mapped_column()
hashed_password: Mapped[Optional[bytes]] = mapped_column()

__table_args__ = (UniqueConstraint("email", name="unique_user_email"),)
36 changes: 32 additions & 4 deletions src/backend/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import os
from contextlib import asynccontextmanager

from alembic.command import upgrade
from alembic.config import Config
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware

from backend.config.auth import ENABLED_AUTH_STRATEGY_MAPPING
from backend.routers.auth import router as auth_router
from backend.routers.chat import router as chat_router
from backend.routers.conversation import router as conversation_router
from backend.routers.deployment import router as deployment_router
Expand All @@ -15,32 +19,56 @@

load_dotenv()

ORIGINS = ["*"]
tianjing-li marked this conversation as resolved.
Show resolved Hide resolved


@asynccontextmanager
async def lifespan(app: FastAPI):
yield


origins = ["*"]


def create_app():
app = FastAPI(lifespan=lifespan)

# Add routers
app.include_router(auth_router)
app.include_router(chat_router)
app.include_router(user_router)
app.include_router(conversation_router)
app.include_router(tool_router)
app.include_router(deployment_router)
app.include_router(experimental_feature_router)

# Add middleware
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_origins=ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

if ENABLED_AUTH_STRATEGY_MAPPING:
secret_key = os.environ.get("SESSION_SECRET_KEY", None)

if not secret_key:
raise ValueError(
"Missing SESSION_SECRET_KEY environment variable to enable Authentication."
)

# Handle User sessions and Auth
app.add_middleware(
SessionMiddleware,
secret_key=secret_key,
)

# Add auth
for auth in ENABLED_AUTH_STRATEGY_MAPPING.values():
if auth.SHOULD_ATTACH_TO_APP:
# TODO: Add app attachment logic for eg OAuth:
# https://docs.authlib.org/en/latest/client/fastapi.html
pass

return app


Expand Down
12 changes: 6 additions & 6 deletions src/backend/model_deployments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ class BaseDeployment:
@abstractmethod
def rerank_enabled(self) -> bool: ...

@staticmethod
def list_models() -> List[str]: ...

@staticmethod
def is_available() -> bool: ...

@abstractmethod
def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: ...

Expand All @@ -45,9 +51,3 @@ def invoke_rerank(

@abstractmethod
def invoke_tools(self, message: str, tools: List[Any], **kwargs: Any) -> Any: ...

@staticmethod
def list_models() -> List[str]: ...

@staticmethod
def is_available() -> bool: ...
102 changes: 102 additions & 0 deletions src/backend/routers/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from fastapi import APIRouter, Depends, HTTPException
from starlette.requests import Request

from backend.config.auth import ENABLED_AUTH_STRATEGY_MAPPING
from backend.database_models import get_session
from backend.database_models.database import DBSessionDep
from backend.schemas.auth import Login

router = APIRouter(dependencies=[Depends(get_session)])


@router.get("/session")
def get_session(request: Request):
"""
Retrievers the current session user.

Args:
request (Request): current Request object.

Returns:
session: current user session ({} if no active session)

Raises:
401 HTTPException if no user found in session.
"""

if not request.session:
raise HTTPException(status_code=401, detail="Not authenticated.")

return request.session.get("user")


@router.post("/login")
async def login(request: Request, login: Login, session: DBSessionDep):
"""
Logs user in, verifying their credentials and either setting the user session,
or redirecting to /auth endpoint.

Args:er
request (Request): current Request object.
login (Login): Login payload.
session (DBSessionDep): Database session.

Returns:
dict: On success.

Raises:
HTTPException: If the strategy or payload are invalid, or if the login fails.
"""
strategy_name = login.strategy
payload = login.payload

# Check the strategy is valid and enabled
if strategy_name not in ENABLED_AUTH_STRATEGY_MAPPING.keys():
raise HTTPException(
status_code=404, detail=f"Invalid Authentication strategy: {strategy_name}."
)

# Check that the payload required is given
strategy = ENABLED_AUTH_STRATEGY_MAPPING[strategy_name]
strategy_payload = strategy.get_required_payload()
if not set(strategy_payload).issubset(payload):
missing_keys = [key for key in strategy_payload if key not in payload.keys()]
raise HTTPException(
status_code=404,
detail=f"Missing the following keys in the payload: {missing_keys}.",
)

# Do login
user = strategy.login(session, payload)
if not user:
raise HTTPException(
status_code=401,
detail=f"Error performing {strategy_name} authentication with payload: {payload}.",
)

# Set session user
request.session["user"] = user

return {}


@router.post("/auth")
async def auth(request: Request):
# TODO: Implement for OAuth strategies
return {}


@router.get("/logout")
async def logout(request: Request):
"""
Logs out the current user session.

Args:
request (Request): current Request object.

Returns:
dict: On success.
"""
request.session.pop("user", None)

return {}
2 changes: 1 addition & 1 deletion src/backend/routers/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def create_user(user: CreateUser, session: DBSessionDep) -> User:
Returns:
User: Created user.
"""
db_user = UserModel(**user.model_dump())
db_user = UserModel(**user.model_dump(exclude_none=True))
db_user = user_crud.create_user(session, db_user)

return db_user
Expand Down
9 changes: 9 additions & 0 deletions src/backend/schemas/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from pydantic import BaseModel


class Login(BaseModel):
strategy: str
payload: dict[str, str]

class Config:
from_attributes = True
24 changes: 21 additions & 3 deletions src/backend/schemas/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from pydantic import BaseModel

from backend.services.auth import BasicAuthentication


class UserBase(BaseModel):
fullname: str
Expand All @@ -18,13 +20,29 @@ class Config:
from_attributes = True


class CreateUser(UserBase):
pass
class UserPassword(BaseModel):
password: Optional[str] = None
hashed_password: Optional[bytes] = None

def __init__(self, **data):
password = data.pop("password", None)

if password is not None:
data["hashed_password"] = BasicAuthentication.hash_and_salt_password(
password
)

super().__init__(**data)

class UpdateUser(UserBase):

class CreateUser(UserBase, UserPassword):
pass


class UpdateUser(UserPassword):
fullname: Optional[str] = None
email: Optional[str] = None


class DeleteUser(BaseModel):
pass
Empty file.
5 changes: 5 additions & 0 deletions src/backend/services/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from backend.services.auth.basic import BasicAuthentication

__all__ = [
"BasicAuthentication",
]