Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add baseline Architecture to support auth + User sessions + Basic Auth #90

Merged
merged 16 commits into from
May 7, 2024
Merged
170 changes: 157 additions & 13 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ py-expression-eval = "^0.3.14"
tavily-python = "^0.3.3"
arxiv = "^2.1.0"
xmltodict = "^0.13.0"
authlib = "^1.3.0"
itsdangerous = "^2.2.0"
bcrypt = "^4.1.2"

[tool.poetry.group.dev.dependencies]
pytest = "^7.1.2"
Expand Down
34 changes: 34 additions & 0 deletions src/backend/alembic/versions/b88f00283a27_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""empty message

Revision ID: b88f00283a27
Revises: 2853273872ca
Create Date: 2024-05-02 19:19:52.608062

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "b88f00283a27"
down_revision: Union[str, None] = "2853273872ca"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"users", sa.Column("hashed_password", sa.LargeBinary(), nullable=True)
)
op.create_unique_constraint("unique_user_email", "users", ["email"])
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("unique_user_email", "users", type_="unique")
op.drop_column("users", "hashed_password")
# ### end Alembic commands ###
12 changes: 6 additions & 6 deletions src/backend/chat/custom/model_deployments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ class BaseDeployment:
@abstractmethod
def rerank_enabled(self) -> bool: ...

@staticmethod
def list_models() -> List[str]: ...

@staticmethod
def is_available() -> bool: ...

@abstractmethod
def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: ...

Expand All @@ -45,9 +51,3 @@ def invoke_rerank(

@abstractmethod
def invoke_tools(self, message: str, tools: List[Any], **kwargs: Any) -> Any: ...

@staticmethod
def list_models() -> List[str]: ...

@staticmethod
def is_available() -> bool: ...
5 changes: 5 additions & 0 deletions src/backend/config/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from backend.services.auth import BasicAuthentication

ENABLED_AUTH_STRATEGIES = [
BasicAuthentication,
]
tianjing-li marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion src/backend/crud/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def update_user(db: Session, user: User, new_user: UpdateUser) -> User:
Returns:
User: Updated user.
"""
for attr, value in new_user.model_dump().items():
for attr, value in new_user.model_dump(exclude_none=True).items():
setattr(user, attr, value)
db.commit()
db.refresh(user)
Expand Down
26 changes: 22 additions & 4 deletions src/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware

from backend.config.auth import ENABLED_AUTH_STRATEGIES
from backend.routers.auth import router as auth_router
from backend.routers.chat import router as chat_router
from backend.routers.conversation import router as conversation_router
from backend.routers.deployment import router as deployment_router
Expand All @@ -15,32 +18,47 @@

load_dotenv()

ORIGINS = ["*"]
tianjing-li marked this conversation as resolved.
Show resolved Hide resolved


@asynccontextmanager
async def lifespan(app: FastAPI):
yield


origins = ["*"]


def create_app():
app = FastAPI(lifespan=lifespan)

# Add routers
app.include_router(auth_router)
app.include_router(chat_router)
app.include_router(user_router)
app.include_router(conversation_router)
app.include_router(tool_router)
app.include_router(deployment_router)
app.include_router(experimental_feature_router)

# Add auth
for auth in ENABLED_AUTH_STRATEGIES:
if auth.SHOULD_ATTACH_TO_APP:
tianjing-li marked this conversation as resolved.
Show resolved Hide resolved
# TODO: Add app attachment logic for eg OAuth:
# https://docs.authlib.org/en/latest/client/fastapi.html
pass

# Add middleware
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_origins=ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

app.add_middleware(
SessionMiddleware,
secret_key="abcd", # TODO: Replace with os.env crypto key
)

return app


Expand Down
4 changes: 4 additions & 0 deletions src/backend/models/user.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional

from sqlalchemy import UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column

from backend.models.base import Base
Expand All @@ -10,3 +11,6 @@ class User(Base):

fullname: Mapped[str] = mapped_column()
email: Mapped[Optional[str]] = mapped_column()
hashed_password: Mapped[Optional[bytes]] = mapped_column()

__table_args__ = (UniqueConstraint("email", name="unique_user_email"),)
99 changes: 99 additions & 0 deletions src/backend/routers/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from fastapi import APIRouter, Depends, HTTPException
from starlette.requests import Request

from backend.config.auth import ENABLED_AUTH_STRATEGIES
from backend.models import get_session
from backend.models.database import DBSessionDep
from backend.schemas.auth import Login

# Define the mapping from Auth strategy name to class obj
# Ex: {"Basic": BasicAuthentication}
ENABLED_AUTH_STRATEGY_MAPPING = {cls.NAME: cls for cls in ENABLED_AUTH_STRATEGIES}

router = APIRouter(dependencies=[Depends(get_session)])


@router.get("/session")
def get_session(request: Request):
"""
Retrievers the current session user.

Args:
request (Request): current Request object.

Returns:
session: current user session ({} if no active session)
"""
return request.session


@router.post("/login")
async def login(request: Request, login: Login, session: DBSessionDep):
"""
Logs user in, verifying their credentials and either setting the user session,
or redirecting to /auth endpoint.

Args:
request (Request): current Request object.
login (Login): Login payload.
session (DBSessionDep): Database session.

Returns:
dict: On success.

Raises:
HTTPException: If the strategy or payload are invalid, or if the login fails.
"""
strategy_name = login.strategy
payload = login.payload

# Check the strategy is valid and enabled
if strategy_name not in ENABLED_AUTH_STRATEGY_MAPPING.keys():
raise HTTPException(
status_code=404, detail=f"Invalid Authentication strategy: {strategy_name}."
)

# Check that the payload required is given
strategy = ENABLED_AUTH_STRATEGY_MAPPING[strategy_name]
strategy_payload = strategy.get_required_payload()
if not set(strategy_payload).issubset(payload):
missing_keys = [key for key in strategy_payload if key not in payload.keys()]
raise HTTPException(
status_code=404,
detail=f"Missing the following keys in the payload: {missing_keys}.",
)

# Do login
user = strategy.login(session, payload)
if not user:
raise HTTPException(
status_code=401,
detail=f"Error performing {strategy_name} authentication with payload: {payload}.",
)

# Set session user
request.session["user"] = user

return {}


@router.post("/auth")
async def auth(request: Request):
# TODO: Implement for OAuth strategies
return {}


@router.get("/logout")
async def logout(request: Request):
"""
Logs out the current user session.

Args:
request (Request): current Request object.

Returns:
dict: On success.
"""
request.session.pop("user", None)

return {}
2 changes: 1 addition & 1 deletion src/backend/routers/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def create_user(user: CreateUser, session: DBSessionDep) -> User:
Returns:
User: Created user.
"""
db_user = UserModel(**user.model_dump())
db_user = UserModel(**user.model_dump(exclude_none=True))
db_user = user_crud.create_user(session, db_user)

return db_user
Expand Down
9 changes: 9 additions & 0 deletions src/backend/schemas/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from pydantic import BaseModel


class Login(BaseModel):
strategy: str
payload: dict[str, str]

class Config:
from_attributes = True
24 changes: 21 additions & 3 deletions src/backend/schemas/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from pydantic import BaseModel

from backend.services.auth import BasicAuthentication


class UserBase(BaseModel):
fullname: str
Expand All @@ -18,13 +20,29 @@ class Config:
from_attributes = True


class CreateUser(UserBase):
pass
class UserPassword(BaseModel):
password: Optional[str] = None
hashed_password: Optional[bytes] = None

def __init__(self, **data):
password = data.pop("password", None)

if password is not None:
data["hashed_password"] = BasicAuthentication.hash_and_salt_password(
password
)

super().__init__(**data)

class UpdateUser(UserBase):

class CreateUser(UserBase, UserPassword):
pass


class UpdateUser(UserPassword):
fullname: Optional[str] = None
email: Optional[str] = None


class DeleteUser(BaseModel):
pass
Empty file.
5 changes: 5 additions & 0 deletions src/backend/services/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from backend.services.auth.basic import BasicAuthentication

__all__ = [
"BasicAuthentication",
]
39 changes: 39 additions & 0 deletions src/backend/services/auth/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from abc import abstractmethod
from typing import Any, List


class BaseAuthenticationStrategy:
"""
Base strategy for authentication, abstract class that should be inherited from.

Attributes:
NAME (str): The name of the strategy.
SHOULD_ATTACH_TO_APP (str): Whether the strategy needs to be attached to the FastAPI application.
SHOULD_AUTH_REDIRECT (str): Whether the strategy requires a redirect to the /auth endpoint after login.
"""

NAME = "Base"
SHOULD_ATTACH_TO_APP = False
SHOULD_AUTH_REDIRECT = False

@staticmethod
def get_required_payload(self) -> List[str]:
"""
The required /login payload for the Auth strategy
"""
...

@classmethod
def login(cls, **kwargs: Any):
"""
Login logic: dealing with checking credentials.
"""
...

@classmethod
def authenticate(cls, **kwargs: Any):
"""
Authentication logic: dealing with user data and returning it
to set the current user session.
"""
...