diff --git a/src/api/v1/auth/__init__.py b/src/api/v1/auth/__init__.py index 1f6d5ff..4885ab2 100644 --- a/src/api/v1/auth/__init__.py +++ b/src/api/v1/auth/__init__.py @@ -1,5 +1,12 @@ from src.api.v1.auth.routes import router +from src.api.v1.auth.deps import OAuth2BearerDepends, OptionalOAuth2BearerDepends, OAuth2PasswordRequestFormDepends +from src.api.v1.auth.schemas import AccessTokenResponse, JWT __all__ = [ "router", + "OAuth2BearerDepends", + "OptionalOAuth2BearerDepends", + "OAuth2PasswordRequestFormDepends", + "AccessTokenResponse", + "JWT", ] diff --git a/src/api/v1/auth/deps.py b/src/api/v1/auth/deps.py index 4b78e8d..0298b11 100644 --- a/src/api/v1/auth/deps.py +++ b/src/api/v1/auth/deps.py @@ -3,7 +3,12 @@ from fastapi import Depends from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm -__oauth2_bearer = OAuth2PasswordBearer("api/v1/auth/login") +TOKEN_URL = "api/v1/auth/login" -PasswordBearer = Annotated[str, Depends(__oauth2_bearer)] -PasswordForm = Annotated[OAuth2PasswordRequestForm, Depends()] +OAUTH2_BEARER = OAuth2PasswordBearer(TOKEN_URL) +OPTIONAL_OAUTH2_BEARER = OAuth2PasswordBearer(TOKEN_URL, auto_error=False) + +OAuth2BearerDepends = Annotated[str, Depends(OAUTH2_BEARER)] +OptionalOAuth2BearerDepends = Annotated[str | None, Depends(OPTIONAL_OAUTH2_BEARER)] + +OAuth2PasswordRequestFormDepends = Annotated[OAuth2PasswordRequestForm, Depends()] diff --git a/src/api/v1/auth/routes.py b/src/api/v1/auth/routes.py index f933b35..e1f2f5f 100644 --- a/src/api/v1/auth/routes.py +++ b/src/api/v1/auth/routes.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, HTTPException, status -from src.api.v1.auth.deps import PasswordForm +from src.api.v1.auth.deps import OAuth2PasswordRequestFormDepends from src.api.v1.auth.schemas import AccessTokenResponse from src.api.v1.otp.service import expire_otp_if_correct from src.api.v1.users import UserServiceDepends @@ -24,7 +24,7 @@ async def register(args: UserRegistrationRequest, service: UserServiceDepends) - @router.post("/login") -async def login(form: PasswordForm, service: UserServiceDepends) -> AccessTokenResponse: +async def login(form: OAuth2PasswordRequestFormDepends, service: UserServiceDepends) -> AccessTokenResponse: email = form.username # The OAuth2 spec requires the exact name `username`. user = service.get_user_by_email(email) diff --git a/src/api/v1/users/me/deps.py b/src/api/v1/users/me/deps.py index 495e005..a223b6c 100644 --- a/src/api/v1/users/me/deps.py +++ b/src/api/v1/users/me/deps.py @@ -3,14 +3,13 @@ import jwt from fastapi import Depends, HTTPException, status -from src.api.v1.auth.deps import PasswordBearer -from src.api.v1.auth.schemas import JWT +from src.api.v1.auth import JWT, OAuth2BearerDepends, OptionalOAuth2BearerDepends from src.api.v1.users.models import User from src.config import settings from src.db.deps import Session -def __get_current_user(session: Session, raw: PasswordBearer) -> User: +def get_current_user(session: Session, raw: OAuth2BearerDepends) -> User: try: data = JWT(**jwt.decode(raw, settings.jwt.secret, algorithms=[settings.jwt.algorithm])) except Exception: @@ -24,4 +23,9 @@ def __get_current_user(session: Session, raw: PasswordBearer) -> User: return user -CurrentUser = Annotated[User, Depends(__get_current_user)] +def get_current_user_or_none(session: Session, raw: OptionalOAuth2BearerDepends) -> User | None: + return get_current_user(session, raw) if raw else None + + +CurrentUser = Annotated[User, Depends(get_current_user)] +CurrentUserOrNone = Annotated[User | None, Depends(get_current_user_or_none)]