forked from cohere-ai/cohere-toolkit
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add baseline Architecture to support auth + User sessions + Basic Auth (
cohere-ai#90) * Auth wip * Add tests * working basic auth * add test coverage: still todo add session tests * Add session tests * Add docs * fix types * merge alembic * add secret key fixture * fix tests
- Loading branch information
1 parent
968e2ef
commit 110d47b
Showing
24 changed files
with
770 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
"""empty message | ||
Revision ID: b88f00283a27 | ||
Revises: 2853273872ca | ||
Create Date: 2024-05-02 19:19:52.608062 | ||
""" | ||
|
||
from typing import Sequence, Union | ||
|
||
import sqlalchemy as sa | ||
from alembic import op | ||
|
||
# revision identifiers, used by Alembic. | ||
revision: str = "b88f00283a27" | ||
down_revision: Union[str, None] = "2853273872ca" | ||
branch_labels: Union[str, Sequence[str], None] = None | ||
depends_on: Union[str, Sequence[str], None] = None | ||
|
||
|
||
def upgrade() -> None: | ||
# ### commands auto generated by Alembic - please adjust! ### | ||
op.add_column( | ||
"users", sa.Column("hashed_password", sa.LargeBinary(), nullable=True) | ||
) | ||
op.create_unique_constraint("unique_user_email", "users", ["email"]) | ||
# ### end Alembic commands ### | ||
|
||
|
||
def downgrade() -> None: | ||
# ### commands auto generated by Alembic - please adjust! ### | ||
op.drop_constraint("unique_user_email", "users", type_="unique") | ||
op.drop_column("users", "hashed_password") | ||
# ### end Alembic commands ### |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
"""empty message | ||
Revision ID: c15b848babe3 | ||
Revises: 6553b76de6ca, b88f00283a27 | ||
Create Date: 2024-05-07 15:59:05.436751 | ||
""" | ||
|
||
from typing import Sequence, Union | ||
|
||
import sqlalchemy as sa | ||
from alembic import op | ||
|
||
# revision identifiers, used by Alembic. | ||
revision: str = "c15b848babe3" | ||
down_revision: Union[str, None] = ("6553b76de6ca", "b88f00283a27") | ||
branch_labels: Union[str, Sequence[str], None] = None | ||
depends_on: Union[str, Sequence[str], None] = None | ||
|
||
|
||
def upgrade() -> None: | ||
pass | ||
|
||
|
||
def downgrade() -> None: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from backend.services.auth import BasicAuthentication | ||
|
||
# Modify this to enable auth strategies. | ||
ENABLED_AUTH_STRATEGIES = [] | ||
|
||
# Define the mapping from Auth strategy name to class obj. | ||
# Does not need to be manually modified. | ||
# Ex: {"Basic": BasicAuthentication} | ||
ENABLED_AUTH_STRATEGY_MAPPING = {cls.NAME: cls for cls in ENABLED_AUTH_STRATEGIES} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
from fastapi import APIRouter, Depends, HTTPException | ||
from starlette.requests import Request | ||
|
||
from backend.config.auth import ENABLED_AUTH_STRATEGY_MAPPING | ||
from backend.database_models import get_session | ||
from backend.database_models.database import DBSessionDep | ||
from backend.schemas.auth import Login | ||
|
||
router = APIRouter(dependencies=[Depends(get_session)]) | ||
|
||
|
||
@router.get("/session") | ||
def get_session(request: Request): | ||
""" | ||
Retrievers the current session user. | ||
Args: | ||
request (Request): current Request object. | ||
Returns: | ||
session: current user session ({} if no active session) | ||
Raises: | ||
401 HTTPException if no user found in session. | ||
""" | ||
|
||
if not request.session: | ||
raise HTTPException(status_code=401, detail="Not authenticated.") | ||
|
||
return request.session.get("user") | ||
|
||
|
||
@router.post("/login") | ||
async def login(request: Request, login: Login, session: DBSessionDep): | ||
""" | ||
Logs user in, verifying their credentials and either setting the user session, | ||
or redirecting to /auth endpoint. | ||
Args:er | ||
request (Request): current Request object. | ||
login (Login): Login payload. | ||
session (DBSessionDep): Database session. | ||
Returns: | ||
dict: On success. | ||
Raises: | ||
HTTPException: If the strategy or payload are invalid, or if the login fails. | ||
""" | ||
strategy_name = login.strategy | ||
payload = login.payload | ||
|
||
# Check the strategy is valid and enabled | ||
if strategy_name not in ENABLED_AUTH_STRATEGY_MAPPING.keys(): | ||
raise HTTPException( | ||
status_code=404, detail=f"Invalid Authentication strategy: {strategy_name}." | ||
) | ||
|
||
# Check that the payload required is given | ||
strategy = ENABLED_AUTH_STRATEGY_MAPPING[strategy_name] | ||
strategy_payload = strategy.get_required_payload() | ||
if not set(strategy_payload).issubset(payload): | ||
missing_keys = [key for key in strategy_payload if key not in payload.keys()] | ||
raise HTTPException( | ||
status_code=404, | ||
detail=f"Missing the following keys in the payload: {missing_keys}.", | ||
) | ||
|
||
# Do login | ||
user = strategy.login(session, payload) | ||
if not user: | ||
raise HTTPException( | ||
status_code=401, | ||
detail=f"Error performing {strategy_name} authentication with payload: {payload}.", | ||
) | ||
|
||
# Set session user | ||
request.session["user"] = user | ||
|
||
return {} | ||
|
||
|
||
@router.post("/auth") | ||
async def auth(request: Request): | ||
# TODO: Implement for OAuth strategies | ||
return {} | ||
|
||
|
||
@router.get("/logout") | ||
async def logout(request: Request): | ||
""" | ||
Logs out the current user session. | ||
Args: | ||
request (Request): current Request object. | ||
Returns: | ||
dict: On success. | ||
""" | ||
request.session.pop("user", None) | ||
|
||
return {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from pydantic import BaseModel | ||
|
||
|
||
class Login(BaseModel): | ||
strategy: str | ||
payload: dict[str, str] | ||
|
||
class Config: | ||
from_attributes = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from backend.services.auth.basic import BasicAuthentication | ||
|
||
__all__ = [ | ||
"BasicAuthentication", | ||
] |
Oops, something went wrong.