Skip to content

Commit

Permalink
refactor: use user service instead of plain functions
Browse files Browse the repository at this point in the history
  • Loading branch information
zobweyt committed Nov 1, 2024
1 parent d5f63a6 commit 32367ec
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 75 deletions.
20 changes: 10 additions & 10 deletions src/api/v1/auth/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,30 @@
from src.api.v1.auth.deps import PasswordForm
from src.api.v1.auth.schemas import AccessTokenResponse
from src.api.v1.otp.service import expire_otp_if_correct
from src.api.v1.users import UserServiceDepends
from src.api.v1.users.schemas import UserPasswordResetRequest, UserRegistrationRequest
from src.api.v1.users.service import get_user_by_email, is_email_registered, register_user, update_password
from src.db.deps import Session
from src.i18n import gettext as _
from src.security import create_access_token, is_valid_password

router = APIRouter(prefix="/auth", tags=["Authentication"])


@router.post("/register", status_code=status.HTTP_201_CREATED)
async def register(args: UserRegistrationRequest, session: Session) -> AccessTokenResponse:
if is_email_registered(session, args.email):
async def register(args: UserRegistrationRequest, service: UserServiceDepends) -> AccessTokenResponse:
if service.is_email_registered(args.email):
raise HTTPException(status.HTTP_409_CONFLICT, _("Email already registered."))
if not expire_otp_if_correct(args.email, args.otp):
raise HTTPException(status.HTTP_406_NOT_ACCEPTABLE, _("The One-Time Password (OTP) is incorrect or expired."))

user = register_user(session, args)
user = service.register_user(args)

return create_access_token(user.id)


@router.post("/login")
async def login(form: PasswordForm, session: Session) -> AccessTokenResponse:
async def login(form: PasswordForm, service: UserServiceDepends) -> AccessTokenResponse:
email = form.username # The OAuth2 spec requires the exact name `username`.
user = get_user_by_email(session, email)
user = service.get_user_by_email(email)

if not user:
raise HTTPException(status.HTTP_404_NOT_FOUND, _("User not found."))
Expand All @@ -39,14 +39,14 @@ async def login(form: PasswordForm, session: Session) -> AccessTokenResponse:


@router.patch("/reset-password")
def reset_password(args: UserPasswordResetRequest, session: Session) -> AccessTokenResponse:
user = get_user_by_email(session, args.email)
def reset_password(args: UserPasswordResetRequest, service: UserServiceDepends) -> AccessTokenResponse:
user = service.get_user_by_email(args.email)

if not user:
raise HTTPException(status.HTTP_404_NOT_FOUND, _("User not found."))
if not expire_otp_if_correct(args.email, args.otp):
raise HTTPException(status.HTTP_406_NOT_ACCEPTABLE, _("The One-Time Password (OTP) is incorrect or expired."))

update_password(session, user, args.password.get_secret_value())
service.update_password(user, args.password.get_secret_value())

return create_access_token(user.id)
2 changes: 1 addition & 1 deletion src/api/v1/authors/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_authors(search_params: SearchParamsDepends, service: AuthorServiceDepend


@router.post("/", response_model=AuthorResponse)
def create_author(current_user: CurrentUser, args: AuthorCreateRequest, service: AuthorServiceDepends):
def create_author(args: AuthorCreateRequest, current_user: CurrentUser, service: AuthorServiceDepends):
if service.exists(args.name):
raise HTTPException(status.HTTP_409_CONFLICT, _("An author with the name '%s' already exists." % (args.name,)))

Expand Down
4 changes: 4 additions & 0 deletions src/api/v1/users/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from src.api.v1.users import me
from src.api.v1.users.deps import UserServiceDepends
from src.api.v1.users.models import User
from src.api.v1.users.routes import router
from src.api.v1.users.service import UserService

router.include_router(me.router)

__all__ = [
"User",
"UserService",
"UserServiceDepends",
"router",
]
7 changes: 7 additions & 0 deletions src/api/v1/users/deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from typing import Annotated

from fastapi import Depends

from src.api.v1.users.service import UserService

UserServiceDepends = Annotated[UserService, Depends(UserService)]
45 changes: 28 additions & 17 deletions src/api/v1/users/me/routes.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
from typing import Annotated

from fastapi import APIRouter, File, HTTPException, UploadFile, status
from PIL import Image

from src.api.v1.otp.service import expire_otp_if_correct
from src.api.v1.users.deps import UserServiceDepends
from src.api.v1.users.me.deps import CurrentUser
from src.api.v1.users.me.schemas import CurrentUserEmailUpdateRequest, CurrentUserResponse
from src.api.v1.users.models import User
from src.api.v1.users.schemas import UserPasswordRequest
from src.api.v1.users.service import (
delete_avatar,
is_email_registered,
update_avatar,
update_email,
update_password,
)
from src.config import settings
from src.db.deps import Session
from src.i18n import gettext as _
from src.storage import fs, mimetype

Expand All @@ -27,24 +22,38 @@ def get_current_user(current_user: CurrentUser) -> User:


@router.patch("/email", response_model=CurrentUserResponse)
def update_current_user_email(current_user: CurrentUser, args: CurrentUserEmailUpdateRequest, session: Session) -> User:
if is_email_registered(session, args.email):
def update_current_user_email(
args: CurrentUserEmailUpdateRequest,
current_user: CurrentUser,
service: UserServiceDepends,
) -> User:
if service.is_email_registered(args.email):
raise HTTPException(status.HTTP_400_BAD_REQUEST, _("Email is already taken."))
if not expire_otp_if_correct(args.email, args.otp):
raise HTTPException(status.HTTP_406_NOT_ACCEPTABLE, _("The One-Time Password (OTP) is incorrect or expired."))

update_email(session, current_user, args.email)
service.update_email(current_user, args.email)

return current_user


@router.patch("/password", response_model=CurrentUserResponse)
def update_current_user_password(current_user: CurrentUser, args: UserPasswordRequest, session: Session) -> User:
update_password(session, current_user, args.password.get_secret_value())
def update_current_user_password(
args: UserPasswordRequest,
current_user: CurrentUser,
service: UserServiceDepends,
) -> User:
service.update_password(current_user, args.password.get_secret_value())

return current_user


@router.patch("/avatar", response_model=CurrentUserResponse)
async def update_current_user_avatar(current_user: CurrentUser, session: Session, file: UploadFile = File()):
def update_current_user_avatar(
file: Annotated[UploadFile, File()],
current_user: CurrentUser,
service: UserServiceDepends,
):
if not fs.is_size_in_range(file.file, max_size=settings.api.max_avatar_size):
mb = settings.api.max_avatar_size / (1024 * 1024)
raise HTTPException(
Expand All @@ -59,14 +68,16 @@ async def update_current_user_avatar(current_user: CurrentUser, session: Session
)

image = Image.open(file.file)
update_avatar(session, current_user, image)
service.update_avatar(current_user, image)

return current_user


@router.delete("/avatar", response_model=CurrentUserResponse)
def delete_current_user_avatar(current_user: CurrentUser, session: Session):
def delete_current_user_avatar(current_user: CurrentUser, service: UserServiceDepends):
if not current_user.avatar_url:
raise HTTPException(status.HTTP_404_NOT_FOUND, _("Avatar not found."))

delete_avatar(session, current_user)
service.delete_avatar(current_user)

return current_user
7 changes: 3 additions & 4 deletions src/api/v1/users/routes.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from fastapi import APIRouter

from src.api.v1.users.deps import UserServiceDepends
from src.api.v1.users.schemas import UserEmailRequest, UserExistenceResponse
from src.api.v1.users.service import is_email_registered
from src.db.deps import Session

router = APIRouter(prefix="/users", tags=["Users"])


@router.post("/exists")
def user_exists(args: UserEmailRequest, session: Session) -> UserExistenceResponse:
exists = is_email_registered(session, args.email)
def user_exists(args: UserEmailRequest, service: UserServiceDepends) -> UserExistenceResponse:
exists = service.is_email_registered(args.email)
return UserExistenceResponse(exists=exists)
86 changes: 43 additions & 43 deletions src/api/v1/users/service.py
Original file line number Diff line number Diff line change
@@ -1,69 +1,69 @@
from PIL.ImageFile import ImageFile
from sqlalchemy.orm import Session
from sqlalchemy.sql import exists

from src.api.v1.users.models import User
from src.api.v1.users.schemas import UserRegistrationRequest
from src.db.deps import Session
from src.security import get_password_hash
from src.storage import fs
from src.storage.images import crop_image_to_square


def get_user_by_email(session: Session, email: str) -> User | None:
return session.query(User).filter(User.email == email).first()
class UserService:
def __init__(self, session: Session) -> None:
self.session = session

def get_user_by_email(self, email: str) -> User | None:
return self.session.query(User).filter(User.email == email).first()

def is_email_registered(session: Session, email: str) -> bool:
return session.query(exists().where(User.email == email)).scalar()
def is_email_registered(self, email: str) -> bool:
return self.session.query(exists().where(User.email == email)).scalar()

def register_user(self, args: UserRegistrationRequest) -> User:
user = User(
email=args.email,
password=get_password_hash(args.password.get_secret_value()),
is_verified=True,
)

def register_user(session: Session, args: UserRegistrationRequest) -> User:
user = User(
email=args.email,
password=get_password_hash(args.password.get_secret_value()),
is_verified=True,
)
self.session.add(user)
self.session.commit()
self.session.refresh(user)

session.add(user)
session.commit()
session.refresh(user)
return user

return user
def update_email(self, user: User, new_email: str):
user.is_verified = True
user.email = new_email

self.session.commit()
self.session.refresh(user)

def update_email(session: Session, user: User, new_email: str):
user.is_verified = True
user.email = new_email
session.commit()
session.refresh(user)
def update_password(self, user: User, new_password: str) -> None:
hashed_password = get_password_hash(password=new_password)
user.password = hashed_password

self.session.commit()
self.session.refresh(user)

def update_password(session: Session, user: User, new_password: str) -> None:
hashed_password = get_password_hash(password=new_password)
user.password = hashed_password
session.commit()
session.refresh(user)
def delete_avatar(self, user: User) -> None:
self.record_avatar(user)

def update_avatar(self, user: User, image: ImageFile) -> None:
cropped_image = crop_image_to_square(image)

def record_avatar(session: Session, user: User, avatar_url: str | None = None) -> None:
if user.avatar_url:
fs.remove(user.avatar_url)
extension = cropped_image.format or "PNG"
name = fs.generate_unique_file_name_from_extension(extension)
path = fs.get_system_path(name)
cropped_image.save(path, extension)

user.avatar_url = avatar_url
session.commit()
session.refresh(user)
self.record_avatar(user, name)

def record_avatar(self, user: User, avatar_url: str | None = None) -> None:
if user.avatar_url:
fs.remove(user.avatar_url)

def update_avatar(session: Session, user: User, image: ImageFile) -> None:
cropped_image = crop_image_to_square(image)
user.avatar_url = avatar_url

extension = cropped_image.format or "PNG"
name = fs.generate_unique_file_name_from_extension(extension)
path = fs.get_system_path(name)
cropped_image.save(path, extension)

record_avatar(session, user, name)


def delete_avatar(session: Session, user: User) -> None:
record_avatar(session, user)
self.session.commit()
self.session.refresh(user)

0 comments on commit 32367ec

Please sign in to comment.