From b10ba3febdcda574608ac38378b052d50699068c Mon Sep 17 00:00:00 2001 From: MatthewBarghout Date: Wed, 28 Jan 2026 21:09:50 -0500 Subject: [PATCH] fixes --- backend/src/core/query_utils.py | 306 ++++++++++++++ backend/src/modules/party/party_router.py | 68 ++-- backend/src/modules/party/party_service.py | 181 ++++++--- backend/test/modules/party/party_utils.py | 13 +- .../modules/party/test_party_list_features.py | 375 ++++++++++++++++++ 5 files changed, 844 insertions(+), 99 deletions(-) create mode 100644 backend/src/core/query_utils.py create mode 100644 backend/test/modules/party/test_party_list_features.py diff --git a/backend/src/core/query_utils.py b/backend/src/core/query_utils.py new file mode 100644 index 0000000..d7f4030 --- /dev/null +++ b/backend/src/core/query_utils.py @@ -0,0 +1,306 @@ +""" +Core utilities for server-side pagination, sorting, and filtering. + +This module provides reusable functions to apply pagination, sorting, and filtering +to SQLAlchemy queries in a type-safe and flexible manner. +""" + +from enum import Enum +from typing import Any + +from pydantic import BaseModel, Field, field_validator +from sqlalchemy import Select, asc, desc, func, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import DeclarativeMeta + + +class SortOrder(str, Enum): + """Sort order enumeration.""" + + ASC = "asc" + DESC = "desc" + + +class FilterOperator(str, Enum): + """Filter operators for different comparison types.""" + + EQUALS = "eq" + NOT_EQUALS = "ne" + GREATER_THAN = "gt" + GREATER_THAN_OR_EQUAL = "gte" + LESS_THAN = "lt" + LESS_THAN_OR_EQUAL = "lte" + CONTAINS = "contains" # For string fields (case-insensitive LIKE) + IN = "in" # For checking if value is in a list + NOT_IN = "not_in" + IS_NULL = "is_null" + IS_NOT_NULL = "is_not_null" + + +class PaginationParams(BaseModel): + """Parameters for pagination.""" + + page_number: int = Field(default=1, ge=1, description="Page number (1-indexed)") + page_size: int | None = Field( + default=None, ge=1, le=100, description="Items per page (None = all items)" + ) + + @property + def skip(self) -> int: + """Calculate offset from page_number and page_size.""" + if self.page_size is None: + return 0 + return (self.page_number - 1) * self.page_size + + @property + def limit(self) -> int | None: + """Alias for page_size.""" + return self.page_size + + +class SortParam(BaseModel): + """Single sorting parameter.""" + + field: str = Field(..., description="Field name to sort by") + order: SortOrder = Field(default=SortOrder.ASC, description="Sort order") + + +class FilterParam(BaseModel): + """Single filter parameter.""" + + field: str = Field(..., description="Field name to filter on") + operator: FilterOperator = Field(..., description="Comparison operator") + value: Any | None = Field(default=None, description="Value to compare against") + + @field_validator("value") + @classmethod + def validate_value_for_operator(cls, v: Any, info: Any) -> Any: + """Validate that value is appropriate for the operator.""" + operator = info.data.get("operator") + + if operator in [FilterOperator.IS_NULL, FilterOperator.IS_NOT_NULL]: + # These operators don't need a value + return None + + if operator in [FilterOperator.IN, FilterOperator.NOT_IN] and not isinstance(v, list): + # These operators need a list + raise ValueError(f"Operator {operator} requires a list value") + + return v + + +class QueryParams(BaseModel): + """Combined parameters for pagination, sorting, and filtering.""" + + pagination: PaginationParams | None = Field(default=None) + sort: list[SortParam] | None = Field(default=None) + filters: list[FilterParam] | None = Field(default=None) + + +def apply_pagination(query: Select, params: PaginationParams | None = None) -> Select: + """ + Apply pagination to a SQLAlchemy query. + + Args: + query: The base SQLAlchemy query + params: Pagination parameters (page_number, page_size) + + Returns: + Modified query with LIMIT and OFFSET applied + """ + if params is None: + return query + + query = query.offset(params.skip) + if params.limit is not None: + query = query.limit(params.limit) + + return query + + +def apply_sorting[ModelType: DeclarativeMeta]( + query: Select, + model: type[ModelType], + sort_params: list[SortParam] | None = None, + allowed_fields: list[str] | None = None, +) -> Select: + """ + Apply sorting to a SQLAlchemy query. + + Args: + query: The base SQLAlchemy query + model: The SQLAlchemy model class + sort_params: List of sorting parameters + allowed_fields: Optional list of fields that can be sorted on + + Returns: + Modified query with ORDER BY applied + + Raises: + ValueError: If attempting to sort on a disallowed or non-existent field + """ + if not sort_params: + return query + + for sort_param in sort_params: + # Validate field exists and is allowed + if not hasattr(model, sort_param.field): + raise ValueError(f"Field '{sort_param.field}' does not exist on model") + + if allowed_fields and sort_param.field not in allowed_fields: + raise ValueError(f"Sorting on field '{sort_param.field}' is not allowed") + + # Get the model attribute + field = getattr(model, sort_param.field) + + # Apply sort order + if sort_param.order == SortOrder.DESC: + query = query.order_by(desc(field)) + else: + query = query.order_by(asc(field)) + + return query + + +def apply_filters[ModelType: DeclarativeMeta]( + query: Select, + model: type[ModelType], + filter_params: list[FilterParam] | None = None, + allowed_fields: list[str] | None = None, +) -> Select: + """ + Apply filters to a SQLAlchemy query. + + Args: + query: The base SQLAlchemy query + model: The SQLAlchemy model class + filter_params: List of filter parameters + allowed_fields: Optional list of fields that can be filtered on + + Returns: + Modified query with WHERE clauses applied + + Raises: + ValueError: If attempting to filter on a disallowed or non-existent field + """ + if not filter_params: + return query + + for filter_param in filter_params: + # Validate field exists and is allowed + if not hasattr(model, filter_param.field): + raise ValueError(f"Field '{filter_param.field}' does not exist on model") + + if allowed_fields and filter_param.field not in allowed_fields: + raise ValueError(f"Filtering on field '{filter_param.field}' is not allowed") + + # Get the model attribute + field = getattr(model, filter_param.field) + + # Apply the appropriate filter based on operator + if filter_param.operator == FilterOperator.EQUALS: + query = query.where(field == filter_param.value) + elif filter_param.operator == FilterOperator.NOT_EQUALS: + query = query.where(field != filter_param.value) + elif filter_param.operator == FilterOperator.GREATER_THAN: + query = query.where(field > filter_param.value) + elif filter_param.operator == FilterOperator.GREATER_THAN_OR_EQUAL: + query = query.where(field >= filter_param.value) + elif filter_param.operator == FilterOperator.LESS_THAN: + query = query.where(field < filter_param.value) + elif filter_param.operator == FilterOperator.LESS_THAN_OR_EQUAL: + query = query.where(field <= filter_param.value) + elif filter_param.operator == FilterOperator.CONTAINS: + # Case-insensitive LIKE for string fields + query = query.where(field.ilike(f"%{filter_param.value}%")) + elif filter_param.operator == FilterOperator.IN: + query = query.where(field.in_(filter_param.value)) + elif filter_param.operator == FilterOperator.NOT_IN: + query = query.where(~field.in_(filter_param.value)) + elif filter_param.operator == FilterOperator.IS_NULL: + query = query.where(field.is_(None)) + elif filter_param.operator == FilterOperator.IS_NOT_NULL: + query = query.where(field.is_not(None)) + + return query + + +def apply_query_params[ModelType: DeclarativeMeta]( + query: Select, + model: type[ModelType], + params: QueryParams | None = None, + allowed_sort_fields: list[str] | None = None, + allowed_filter_fields: list[str] | None = None, +) -> tuple[Select, PaginationParams | None]: + """ + Apply all query parameters (filtering, sorting, pagination) to a query. + + This is the main utility function that combines all operations. + + Args: + query: The base SQLAlchemy query + model: The SQLAlchemy model class + params: Combined query parameters + allowed_sort_fields: Optional list of fields that can be sorted on + allowed_filter_fields: Optional list of fields that can be filtered on + + Returns: + Tuple of (modified query, pagination params used) + """ + if params is None: + return query, None + + # Apply filters first (narrows down the dataset) + if params.filters: + query = apply_filters(query, model, params.filters, allowed_filter_fields) + + # Apply sorting (before pagination to ensure consistent ordering) + if params.sort: + query = apply_sorting(query, model, params.sort, allowed_sort_fields) + + # Apply pagination last + pagination_params = params.pagination + if pagination_params: + query = apply_pagination(query, pagination_params) + + return query, pagination_params + + +async def get_total_count( + session: AsyncSession, + base_query: Select, +) -> int: + """ + Get the total count of results for a query (before pagination). + + Args: + session: SQLAlchemy async session + base_query: The base query (with filters but before pagination) + + Returns: + Total number of results + """ + # Use select(func.count()).select_from(base_query.subquery()) for proper counting + count_query = select(func.count()).select_from(base_query.subquery()) + result = await session.execute(count_query) + return result.scalar() or 0 + + +class PaginatedResponse[ModelType](BaseModel): + """Generic paginated response wrapper matching existing PaginatedResponse.""" + + items: list[ModelType] + total_records: int + page_size: int + page_number: int + total_pages: int + + @property + def has_next(self) -> bool: + """Check if there's a next page.""" + return self.page_number < self.total_pages + + @property + def has_prev(self) -> bool: + """Check if there's a previous page.""" + return self.page_number > 1 diff --git a/backend/src/modules/party/party_router.py b/backend/src/modules/party/party_router.py index 3f7a8e9..c60d4b3 100644 --- a/backend/src/modules/party/party_router.py +++ b/backend/src/modules/party/party_router.py @@ -68,52 +68,50 @@ async def create_party( async def list_parties( page_number: int = Query(1, ge=1, description="Page number (1-indexed)"), page_size: int | None = Query(None, ge=1, le=100, description="Items per page (default: all)"), + sort_by: str | None = Query(None, description="Field to sort by (e.g., 'party_datetime')"), + sort_order: str = Query("asc", pattern="^(asc|desc)$", description="Sort order: asc or desc"), + location_id: int | None = Query(None, description="Filter by location ID"), + contact_one_id: int | None = Query(None, description="Filter by contact one (student) ID"), party_service: PartyService = Depends(), _=Depends(authenticate_by_role("admin", "staff", "police")), ) -> PaginatedPartiesResponse: """ - Returns all party registrations in the database with optional pagination. + Returns all party registrations in the database with optional pagination, sorting, and filtering Query Parameters: - - page_number: The page number to retrieve (1-indexed) + - page_number: The page number to retrieve (1-indexed, default: 1) - page_size: Number of items per page (max 100, default: returns all parties) + - sort_by: Field to sort by (allowed: party_datetime, location_id, contact_one_id, id) + - sort_order: Sort order (asc or desc, default: asc) + - location_id: Filter by location ID (optional) + - contact_one_id: Filter by contact one (student) ID (optional) + + Features: + - **Opt-in**: All features have sensible defaults - no parameters returns all parties + - **Server-side**: All sorting, filtering, and pagination happens in the database + - **Performant**: Scales well with large datasets Returns: - - items: List of party registrations - - total_records: Total number of records in the database - - page_size: Requested page size (or total_records if not specified) - - page_number: Requested page number - - total_pages: Total number of pages based on page size + - items: List of party registrations for the current page + - total_records: Total number of records matching filters (not just current page) + - page_size: Items per page (equals total_records when page_size is None) + - page_number: Current page number + - total_pages: Total number of pages based on page size and total records + + Examples: + - Get all parties: GET /api/parties/ + - Get first page of 10: GET /api/parties/?page_size=10 + - Sort by date descending: GET /api/parties/?sort_by=party_datetime&sort_order=desc + - Filter by location: GET /api/parties/?location_id=5 + - Combined: GET /api/parties/?location_id=5&sort_by=party_datetime&page_size=20 """ - # Get total count first - total_records = await party_service.get_party_count() - - # If page_size is None, return all parties - if page_size is None: - parties = await party_service.get_parties(skip=0, limit=None) - return PaginatedPartiesResponse( - items=parties, - total_records=total_records, - page_size=total_records, - page_number=1, - total_pages=1, - ) - - # Calculate skip and limit for pagination - skip = (page_number - 1) * page_size - - # Get parties with pagination - parties = await party_service.get_parties(skip=skip, limit=page_size) - - # Calculate total pages (ceiling division) - total_pages = (total_records + page_size - 1) // page_size if total_records > 0 else 0 - - return PaginatedPartiesResponse( - items=parties, - total_records=total_records, - page_size=page_size, + return await party_service.get_parties_paginated( page_number=page_number, - total_pages=total_pages, + page_size=page_size, + sort_by=sort_by, + sort_order=sort_order, + location_id=location_id, + contact_one_id=contact_one_id, ) diff --git a/backend/src/modules/party/party_service.py b/backend/src/modules/party/party_service.py index cf5d952..10a3c7a 100644 --- a/backend/src/modules/party/party_service.py +++ b/backend/src/modules/party/party_service.py @@ -11,6 +11,16 @@ from src.core.config import env from src.core.database import get_session from src.core.exceptions import BadRequestException, ConflictException, NotFoundException +from src.core.query_utils import ( + FilterOperator, + FilterParam, + PaginationParams, + QueryParams, + SortOrder, + SortParam, + apply_query_params, + get_total_count, +) from src.modules.location.location_model import LocationDto from src.modules.student.student_service import StudentNotFoundException, StudentService @@ -18,7 +28,13 @@ from ..location.location_service import LocationService from ..student.student_entity import StudentEntity from .party_entity import PartyEntity -from .party_model import AdminCreatePartyDto, PartyData, PartyDto, StudentCreatePartyDto +from .party_model import ( + AdminCreatePartyDto, + PaginatedPartiesResponse, + PartyData, + PartyDto, + StudentCreatePartyDto, +) class PartyNotFoundException(NotFoundException): @@ -74,7 +90,6 @@ def _calculate_business_days_ahead(self, target_date: datetime) -> int: """Calculate the number of business days between now and target date.""" # Ensure both datetimes are timezone-aware (use UTC) current_date = datetime.now(UTC).replace(hour=0, minute=0, second=0, microsecond=0) - # If target_date is naive, make it UTC-aware; otherwise keep its timezone if target_date.tzinfo is None: target_date_only = target_date.replace( @@ -82,16 +97,13 @@ def _calculate_business_days_ahead(self, target_date: datetime) -> int: ) else: target_date_only = target_date.replace(hour=0, minute=0, second=0, microsecond=0) - business_days = 0 current = current_date - while current < target_date_only: # Skip weekends (Saturday=5, Sunday=6) if current.weekday() < 5: business_days += 1 current += timedelta(days=1) - return business_days def _validate_party_date(self, party_datetime: datetime) -> None: @@ -103,24 +115,19 @@ def _validate_party_date(self, party_datetime: datetime) -> None: async def _validate_party_smart_attendance(self, student_id: int) -> None: """Validate that student has completed Party Smart after the most recent August 1st.""" student = await self.student_service.get_student_by_id(student_id) - if student.last_registered is None: raise PartySmartNotCompletedException(student_id) - # Calculate the most recent August 1st now = datetime.now(UTC) current_year = now.year - # August 1st of the current year (UTC) august_first_this_year = datetime(current_year, 8, 1, 0, 0, 0, tzinfo=UTC) - # If today is before August 1st, use last year's August 1st # Otherwise, use this year's August 1st if now < august_first_this_year: most_recent_august_first = datetime(current_year - 1, 8, 1, 0, 0, 0, tzinfo=UTC) else: most_recent_august_first = august_first_this_year - # Check if last_registered is after the most recent August 1st if student.last_registered < most_recent_august_first: raise PartySmartNotCompletedException(student_id) @@ -145,7 +152,6 @@ async def _get_student_by_email(self, email: str) -> StudentEntity: account = await self.account_service.get_account_by_email(email) except AccountByEmailNotFoundException as e: raise StudentNotFoundException(email=email) from e - # Then get the student entity result = await self.session.execute( select(StudentEntity) @@ -172,6 +178,111 @@ async def get_parties(self, skip: int = 0, limit: int | None = None) -> list[Par parties = result.scalars().all() return [party.to_dto() for party in parties] + async def get_parties_paginated( + self, + page_number: int = 1, + page_size: int | None = None, + sort_by: str | None = None, + sort_order: str = "asc", + location_id: int | None = None, + contact_one_id: int | None = None, + ) -> PaginatedPartiesResponse: + """ + Get parties with server-side pagination, sorting, and filtering. + + Args: + page_number: Page number (1-indexed) + page_size: Items per page (None = all items) + sort_by: Field to sort by + sort_order: Sort order ('asc' or 'desc') + location_id: Filter by location ID + contact_one_id: Filter by contact one (student) ID + + Returns: + PaginatedPartiesResponse with items and metadata + """ + # Build base query with eager loading + base_query = select(PartyEntity).options( + selectinload(PartyEntity.location), + selectinload(PartyEntity.contact_one).selectinload(StudentEntity.account), + ) + + # Build query params + filters: list[FilterParam] = [] + if location_id is not None: + filters.append( + FilterParam(field="location_id", operator=FilterOperator.EQUALS, value=location_id) + ) + if contact_one_id is not None: + filters.append( + FilterParam( + field="contact_one_id", operator=FilterOperator.EQUALS, value=contact_one_id + ) + ) + + sort_params: list[SortParam] | None = None + if sort_by: + sort_params = [SortParam(field=sort_by, order=SortOrder(sort_order))] + + pagination_params = PaginationParams(page_number=page_number, page_size=page_size) + + query_params = QueryParams( + pagination=pagination_params, + sort=sort_params, + filters=filters if filters else None, + ) + + # Define allowed fields for sorting (security measure) + allowed_sort_fields = ["id", "party_datetime", "location_id", "contact_one_id"] + + # Define allowed fields for filtering (security measure) + allowed_filter_fields = ["location_id", "contact_one_id"] + + # Apply filters and sorting (but not pagination yet - need count first) + filtered_query, _ = apply_query_params( + base_query, + PartyEntity, # type: ignore + QueryParams(filters=query_params.filters, sort=query_params.sort), + allowed_sort_fields=allowed_sort_fields, + allowed_filter_fields=allowed_filter_fields, + ) + + # Get total count after filters but before pagination + total_records = await get_total_count(self.session, filtered_query) + + # Now apply pagination + paginated_query, _ = apply_query_params( + filtered_query, + PartyEntity, # type: ignore + QueryParams(pagination=query_params.pagination), + ) + + # Execute query + result = await self.session.execute(paginated_query) + parties = result.scalars().all() + + # Convert to DTOs + party_dtos = [party.to_dto() for party in parties] + + # Calculate metadata + if page_size is None: + # When no page_size specified, return all items on page 1 + actual_page_size = total_records + total_pages = 1 + actual_page_number = 1 + else: + actual_page_size = page_size + actual_page_number = page_number + total_pages = (total_records + page_size - 1) // page_size if total_records > 0 else 0 + + return PaginatedPartiesResponse( + items=party_dtos, + total_records=total_records, + page_size=actual_page_size, + page_number=actual_page_number, + total_pages=total_pages, + ) + async def get_party_by_id(self, party_id: int) -> PartyDto: party_entity = await self._get_party_entity_by_id(party_id) return party_entity.to_dto() @@ -221,7 +332,6 @@ async def create_party(self, data: PartyData) -> PartyDto: # Validate that referenced resources exist await self.location_service.assert_location_exists(data.location_id) await self.student_service.assert_student_exists(data.contact_one_id) - new_party = PartyEntity.from_data(data) try: self.session.add(new_party) @@ -232,19 +342,15 @@ async def create_party(self, data: PartyData) -> PartyDto: async def update_party(self, party_id: int, data: PartyData) -> PartyDto: party_entity = await self._get_party_entity_by_id(party_id) - # Validate that referenced resources exist await self.location_service.assert_location_exists(data.location_id) await self.student_service.assert_student_exists(data.contact_one_id) - for key, value in data.model_dump().items(): if key == "id": continue if hasattr(party_entity, key): setattr(party_entity, key, value) - party_entity.set_contact_two(data.contact_two) - try: self.session.add(party_entity) await self.session.commit() @@ -258,10 +364,8 @@ async def create_party_from_student_dto( """Create a party registration from a student. contact_one is auto-filled.""" # Validate student party prerequisites (date and Party Smart) await self._validate_student_party_prerequisites(student_account_id, dto.party_datetime) - # Get/create location and validate no hold location = await self._validate_and_get_location(dto.google_place_id) - # Create party data with contact_two information directly party_data = PartyData( party_datetime=dto.party_datetime, @@ -269,7 +373,6 @@ async def create_party_from_student_dto( contact_one_id=student_account_id, contact_two=dto.contact_two, ) - # Create party new_party = PartyEntity.from_data(party_data) self.session.add(new_party) @@ -280,10 +383,8 @@ async def create_party_from_admin_dto(self, dto: AdminCreatePartyDto) -> PartyDt """Create a party registration from an admin. Both contacts must be specified.""" # Get/create location and validate no hold location = await self._validate_and_get_location(dto.google_place_id) - # Get contact_one by email contact_one = await self._get_student_by_email(dto.contact_one_email) - # Create party data with contact_two information directly party_data = PartyData( party_datetime=dto.party_datetime, @@ -291,7 +392,6 @@ async def create_party_from_admin_dto(self, dto: AdminCreatePartyDto) -> PartyDt contact_one_id=contact_one.account_id, contact_two=dto.contact_two, ) - # Create party new_party = PartyEntity.from_data(party_data) self.session.add(new_party) @@ -304,16 +404,12 @@ async def update_party_from_student_dto( """Update a party registration from a student. contact_one is auto-filled.""" # Get existing party party_entity = await self._get_party_entity_by_id(party_id) - # Validate student party prerequisites (date and Party Smart) await self._validate_student_party_prerequisites(student_account_id, dto.party_datetime) - # Get/create location and validate no hold location = await self._validate_and_get_location(dto.google_place_id) - # Validate contact_one (student) exists await self.student_service.assert_student_exists(student_account_id) - # Update party fields party_entity.party_datetime = dto.party_datetime party_entity.location_id = location.id @@ -323,7 +419,6 @@ async def update_party_from_student_dto( party_entity.contact_two_last_name = dto.contact_two.last_name party_entity.contact_two_phone_number = dto.contact_two.phone_number party_entity.contact_two_contact_preference = dto.contact_two.contact_preference - self.session.add(party_entity) await self.session.commit() return await party_entity.load_dto(self.session) @@ -334,14 +429,11 @@ async def update_party_from_admin_dto( """Update a party registration from an admin. Both contacts must be specified.""" # Get existing party party_entity = await self._get_party_entity_by_id(party_id) - # Get/create location and validate no hold location = await self._validate_and_get_location(dto.google_place_id) - # Get contact_one by email contact_one_student = await self._get_student_by_email(dto.contact_one_email) contact_one_id = contact_one_student.account_id - # Update party fields party_entity.party_datetime = dto.party_datetime party_entity.location_id = location.id @@ -351,7 +443,6 @@ async def update_party_from_admin_dto( party_entity.contact_two_last_name = dto.contact_two.last_name party_entity.contact_two_phone_number = dto.contact_two.phone_number party_entity.contact_two_contact_preference = dto.contact_two.contact_preference - self.session.add(party_entity) await self.session.commit() return await party_entity.load_dto(self.session) @@ -377,7 +468,6 @@ async def get_parties_by_student_and_date( ) -> list[PartyDto]: start_of_day = target_date.replace(hour=0, minute=0, second=0, microsecond=0) end_of_day = target_date.replace(hour=23, minute=59, second=59, microsecond=999999) - result = await self.session.execute( select(PartyEntity) .where( @@ -397,7 +487,6 @@ async def get_parties_by_radius(self, latitude: float, longitude: float) -> list current_time = datetime.now(UTC) start_time = current_time - timedelta(hours=6) end_time = current_time + timedelta(hours=12) - result = await self.session.execute( select(PartyEntity) .options( @@ -410,22 +499,18 @@ async def get_parties_by_radius(self, latitude: float, longitude: float) -> list ) ) parties = result.scalars().all() - parties_within_radius: list[PartyEntity] = [] for party in parties: if party.location is None: continue - distance = self._calculate_haversine_distance( latitude, longitude, float(party.location.latitude), float(party.location.longitude), ) - if distance <= env.PARTY_SEARCH_RADIUS_MILES: parties_within_radius.append(party) - return [party.to_dto() for party in parties_within_radius] async def get_parties_by_radius_and_date_range( @@ -437,13 +522,11 @@ async def get_parties_by_radius_and_date_range( ) -> list[PartyDto]: """ Get parties within a radius of a location within a specified date range. - Args: latitude: Latitude of the search center longitude: Longitude of the search center start_date: Start of the date range (inclusive) end_date: End of the date range (inclusive) - Returns: List of parties within the radius and date range """ @@ -459,44 +542,36 @@ async def get_parties_by_radius_and_date_range( ) ) parties = result.scalars().all() - parties_within_radius: list[PartyEntity] = [] for party in parties: if party.location is None: continue - distance = self._calculate_haversine_distance( latitude, longitude, float(party.location.latitude), float(party.location.longitude), ) - if distance <= env.PARTY_SEARCH_RADIUS_MILES: parties_within_radius.append(party) - return [party.to_dto() for party in parties_within_radius] def _calculate_haversine_distance( self, lat1: float, lon1: float, lat2: float, lon2: float ) -> float: lat1, lon1, lat2, lon2 = map(math.radians, [lat1, lon1, lat2, lon2]) - dlat = lat2 - lat1 dlon = lon2 - lon1 a = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2 c = 2 * math.asin(math.sqrt(a)) - r = 3959 return c * r async def export_parties_to_csv(self, parties: list[PartyDto]) -> str: """ Export a list of parties to CSV format. - Args: parties: List of Party models to export - Returns: CSV content as a string """ @@ -519,9 +594,7 @@ async def export_parties_to_csv(self, parties: list[PartyDto]) -> str: ] ) return output.getvalue() - party_ids = [party.id for party in parties] - result = await self.session.execute( select(PartyEntity) .options( @@ -531,12 +604,9 @@ async def export_parties_to_csv(self, parties: list[PartyDto]) -> str: .where(PartyEntity.id.in_(party_ids)) ) party_entities = result.scalars().all() - party_entity_map = {party.id: party for party in party_entities} - output = io.StringIO() writer = csv.writer(output) - writer.writerow( [ "Fully formatted address", @@ -552,21 +622,17 @@ async def export_parties_to_csv(self, parties: list[PartyDto]) -> str: "Contact Two Contact Preference", ] ) - for party in parties: party_entity = party_entity_map.get(party.id) if party_entity is None: continue - # Format address formatted_address = "" if party_entity.location: formatted_address = party_entity.location.formatted_address or "" - # Format date and time party_date = party.party_datetime.strftime("%Y-%m-%d") if party.party_datetime else "" party_time = party.party_datetime.strftime("%H:%M:%S") if party.party_datetime else "" - contact_one_full_name = "" contact_one_email = "" contact_one_phone = "" @@ -584,7 +650,6 @@ async def export_parties_to_csv(self, parties: list[PartyDto]) -> str: ) if party_entity.contact_one.account: contact_one_email = party_entity.contact_one.account.email or "" - contact_two_full_name = "" contact_two_email = "" contact_two_phone = "" @@ -599,7 +664,6 @@ async def export_parties_to_csv(self, parties: list[PartyDto]) -> str: else "" ) contact_two_email = party_entity.contact_two_email or "" - writer.writerow( [ formatted_address, @@ -615,5 +679,4 @@ async def export_parties_to_csv(self, parties: list[PartyDto]) -> str: contact_two_preference, ] ) - return output.getvalue() diff --git a/backend/test/modules/party/party_utils.py b/backend/test/modules/party/party_utils.py index 55b8806..38b5489 100644 --- a/backend/test/modules/party/party_utils.py +++ b/backend/test/modules/party/party_utils.py @@ -73,15 +73,18 @@ def generate_defaults(count: int) -> dict[str, Any]: @override async def next_dict(self, **overrides: Unpack[PartyOverrides]) -> dict: - if "location_id" not in overrides: + # Create a copy to avoid mutating the original + local_overrides = dict(overrides) + + if "location_id" not in local_overrides: location = await self.location_utils.create_one() - overrides["location_id"] = location.id + local_overrides["location_id"] = location.id - if "contact_one_id" not in overrides: + if "contact_one_id" not in local_overrides: student = await self.student_utils.create_one() - overrides["contact_one_id"] = student.account_id + local_overrides["contact_one_id"] = student.account_id - return await super().next_dict(**overrides) + return await super().next_dict(**local_overrides) def next_contact(self, **overrides: Unpack[PartyOverrides]) -> ContactDto: """Generate test contact data.""" diff --git a/backend/test/modules/party/test_party_list_features.py b/backend/test/modules/party/test_party_list_features.py new file mode 100644 index 0000000..762b874 --- /dev/null +++ b/backend/test/modules/party/test_party_list_features.py @@ -0,0 +1,375 @@ +""" +Tests for party list endpoint with pagination, sorting, and filtering. + +Add these tests to your existing party_router_test.py file or create a new test file. +""" + +from datetime import timedelta + +import pytest +from httpx import AsyncClient +from src.modules.party.party_model import PartyDto +from test.modules.party.party_utils import PartyTestUtils, get_valid_party_datetime +from test.utils.http.assertions import assert_res_paginated + + +class TestPartyListPagination: + """Tests for pagination on GET /api/parties/ endpoint.""" + + admin_client: AsyncClient + party_utils: PartyTestUtils + + @pytest.fixture(autouse=True) + def _setup(self, party_utils: PartyTestUtils, admin_client: AsyncClient): + self.party_utils = party_utils + self.admin_client = admin_client + + @pytest.mark.asyncio + async def test_debug_party_creation(self): + """Debug test to see what's happening with party creation.""" + print("\n=== Starting test ===") + + for i in range(5): + print(f"\n--- Creating party {i + 1} ---") + party = await self.party_utils.create_one() + print(f"Created party with ID: {party.id}") + print(f"Location ID: {party.location_id}, Contact ID: {party.contact_one_id}") + + print("\n=== Checking database ===") + response = await self.admin_client.get("/api/parties/") + data = response.json() + print(f"Total records from API: {data['total_records']}") + print(f"Number of items: {len(data['items'])}") + print(f"Item IDs: {[item['id'] for item in data['items']]}") + + @pytest.mark.asyncio + async def test_list_parties_default_pagination(self): + """Test listing parties with default pagination (no page_size returns all).""" + for _ in range(5): + await self.party_utils.create_one() + + response = await self.admin_client.get("/api/parties/") + paginated = assert_res_paginated( + response, PartyDto, total_records=5, page_size=5, total_pages=1 + ) + assert len(paginated.items) == 5 + assert paginated.page_number == 1 + + @pytest.mark.asyncio + async def test_list_parties_with_page_size(self): + """Test listing parties with explicit page size.""" + for _ in range(15): + await self.party_utils.create_one() + + # First page + response = await self.admin_client.get("/api/parties/?page_number=1&page_size=10") + paginated = assert_res_paginated( + response, PartyDto, total_records=15, page_size=10, total_pages=2, page_number=1 + ) + assert len(paginated.items) == 10 + assert paginated.page_number == 1 + + # Second page + response = await self.admin_client.get("/api/parties/?page_number=2&page_size=10") + paginated = assert_res_paginated( + response, PartyDto, total_records=15, page_size=10, total_pages=2, page_number=2 + ) + assert len(paginated.items) == 5 # Remaining items + assert paginated.page_number == 2 + + @pytest.mark.asyncio + async def test_list_parties_beyond_last_page(self): + """Test requesting a page beyond the last page returns empty results.""" + for _ in range(5): + await self.party_utils.create_one() + + response = await self.admin_client.get("/api/parties/?page_number=10&page_size=10") + paginated = assert_res_paginated( + response, PartyDto, total_records=5, page_size=10, total_pages=1, page_number=10 + ) + assert len(paginated.items) == 0 + assert paginated.page_number == 10 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "total_items,page_size,page_number,expected_items", + [ + (0, 10, 1, 0), # Empty database + (5, 10, 1, 5), # All items fit on one page + (10, 10, 1, 10), # Exactly one page + (11, 10, 1, 10), # First page of two + (11, 10, 2, 1), # Second page with remainder + (25, 5, 3, 5), # Middle page + (25, 5, 5, 5), # Last full page + (100, 20, 3, 20), # Larger dataset + ], + ) + async def test_list_parties_pagination_scenarios( + self, total_items: int, page_size: int, page_number: int, expected_items: int + ): + """Parameterized test for various pagination scenarios.""" + for _ in range(total_items): + await self.party_utils.create_one() + + response = await self.admin_client.get( + f"/api/parties/?page_number={page_number}&page_size={page_size}" + ) + + expected_pages = (total_items + page_size - 1) // page_size if total_items > 0 else 0 + paginated = assert_res_paginated( + response, + PartyDto, + total_records=total_items, + page_size=page_size, + total_pages=expected_pages, + page_number=page_number, + ) + assert len(paginated.items) == expected_items + assert paginated.page_number == page_number + + +class TestPartyListSorting: + """Tests for sorting on GET /api/parties/ endpoint.""" + + admin_client: AsyncClient + party_utils: PartyTestUtils + + @pytest.fixture(autouse=True) + def _setup(self, party_utils: PartyTestUtils, admin_client: AsyncClient): + self.party_utils = party_utils + self.admin_client = admin_client + + @pytest.mark.asyncio + async def test_list_parties_sort_by_datetime_asc(self): + """Test sorting parties by datetime in ascending order.""" + base_datetime = get_valid_party_datetime() + + # Create parties with different datetimes + party1 = await self.party_utils.create_one(party_datetime=base_datetime) + party2 = await self.party_utils.create_one(party_datetime=base_datetime + timedelta(days=1)) + party3 = await self.party_utils.create_one(party_datetime=base_datetime + timedelta(days=2)) + + response = await self.admin_client.get( + "/api/parties/?sort_by=party_datetime&sort_order=asc" + ) + paginated = assert_res_paginated(response, PartyDto, total_records=3) + + # Check order + assert paginated.items[0].id == party1.id + assert paginated.items[1].id == party2.id + assert paginated.items[2].id == party3.id + + @pytest.mark.asyncio + async def test_list_parties_sort_by_datetime_desc(self): + """Test sorting parties by datetime in descending order.""" + base_datetime = get_valid_party_datetime() + + # Create parties with different datetimes + party1 = await self.party_utils.create_one(party_datetime=base_datetime) + party2 = await self.party_utils.create_one(party_datetime=base_datetime + timedelta(days=1)) + party3 = await self.party_utils.create_one(party_datetime=base_datetime + timedelta(days=2)) + + response = await self.admin_client.get( + "/api/parties/?sort_by=party_datetime&sort_order=desc" + ) + paginated = assert_res_paginated(response, PartyDto, total_records=3) + + # Check order (reversed) + assert paginated.items[0].id == party3.id + assert paginated.items[1].id == party2.id + assert paginated.items[2].id == party1.id + + @pytest.mark.asyncio + async def test_list_parties_sort_by_id(self): + """Test sorting parties by ID.""" + for _ in range(5): + await self.party_utils.create_one() + + response = await self.admin_client.get("/api/parties/?sort_by=id&sort_order=asc") + paginated = assert_res_paginated(response, PartyDto, total_records=5) + + # Check IDs are in ascending order + ids = [party.id for party in paginated.items] + assert ids == sorted(ids) + + @pytest.mark.asyncio + async def test_list_parties_sort_with_pagination(self): + """Test that sorting works correctly with pagination.""" + base_datetime = get_valid_party_datetime() + + # Create 10 parties with sequential datetimes + for i in range(10): + await self.party_utils.create_one(party_datetime=base_datetime + timedelta(days=i)) + + # Get first page sorted by datetime desc + response = await self.admin_client.get( + "/api/parties/?page_number=1&page_size=5&sort_by=party_datetime&sort_order=desc" + ) + page1 = assert_res_paginated( + response, PartyDto, total_records=10, page_size=5, total_pages=2, page_number=1 + ) + + # Get second page + response = await self.admin_client.get( + "/api/parties/?page_number=2&page_size=5&sort_by=party_datetime&sort_order=desc" + ) + page2 = assert_res_paginated( + response, PartyDto, total_records=10, page_size=5, total_pages=2, page_number=2 + ) + + # Verify all items across both pages are in desc order + all_datetimes = [p.party_datetime for p in page1.items] + [ + p.party_datetime for p in page2.items + ] + assert all_datetimes == sorted(all_datetimes, reverse=True) + + +class TestPartyListFiltering: + """Tests for filtering on GET /api/parties/ endpoint.""" + + admin_client: AsyncClient + party_utils: PartyTestUtils + + @pytest.fixture(autouse=True) + def _setup(self, party_utils: PartyTestUtils, admin_client: AsyncClient): + self.party_utils = party_utils + self.admin_client = admin_client + + @pytest.mark.asyncio + async def test_list_parties_filter_by_location(self): + """Test filtering parties by location ID.""" + # Create parties at different locations + party1 = await self.party_utils.create_one() + _party2 = await self.party_utils.create_one() + party3 = await self.party_utils.create_one(location_id=party1.location_id) + + # Filter by location_id + response = await self.admin_client.get(f"/api/parties/?location_id={party1.location_id}") + paginated = assert_res_paginated(response, PartyDto, total_records=2) + + # Should return party1 and party3 + returned_ids = {p.id for p in paginated.items} + assert returned_ids == {party1.id, party3.id} + + @pytest.mark.asyncio + async def test_list_parties_filter_by_contact(self): + """Test filtering parties by contact one ID.""" + # Create parties with different contacts + party1 = await self.party_utils.create_one() + _party2 = await self.party_utils.create_one() + party3 = await self.party_utils.create_one(contact_one_id=party1.contact_one_id) + + # Filter by contact_one_id + response = await self.admin_client.get( + f"/api/parties/?contact_one_id={party1.contact_one_id}" + ) + paginated = assert_res_paginated(response, PartyDto, total_records=2) + + # Should return party1 and party3 + returned_ids = {p.id for p in paginated.items} + assert returned_ids == {party1.id, party3.id} + + @pytest.mark.asyncio + async def test_list_parties_multiple_filters(self): + """Test filtering parties with multiple filters simultaneously.""" + # Create parties with known attributes + party1 = await self.party_utils.create_one() + await self.party_utils.create_one() # Different location and contact + party3 = await self.party_utils.create_one( + location_id=party1.location_id, contact_one_id=party1.contact_one_id + ) + + # Filter by both location and contact + response = await self.admin_client.get( + f"/api/parties/?location_id={party1.location_id}&contact_one_id={party1.contact_one_id}" + ) + paginated = assert_res_paginated(response, PartyDto, total_records=2) + + # Should return party1 and party3 + returned_ids = {p.id for p in paginated.items} + assert returned_ids == {party1.id, party3.id} + + @pytest.mark.asyncio + async def test_list_parties_filter_no_matches(self): + """Test filtering with criteria that match no parties.""" + for _ in range(3): + await self.party_utils.create_one() + + # Filter by non-existent location + response = await self.admin_client.get("/api/parties/?location_id=99999") + paginated = assert_res_paginated(response, PartyDto, total_records=0) + assert len(paginated.items) == 0 + + +class TestPartyListCombined: + """Tests for combined pagination, sorting, and filtering.""" + + admin_client: AsyncClient + party_utils: PartyTestUtils + + @pytest.fixture(autouse=True) + def _setup(self, party_utils: PartyTestUtils, admin_client: AsyncClient): + self.party_utils = party_utils + self.admin_client = admin_client + + @pytest.mark.asyncio + async def test_filter_sort_paginate_together(self): + """Test using filtering, sorting, and pagination together.""" + base_datetime = get_valid_party_datetime() + + # Create multiple parties at the same location with different dates + location = await self.party_utils.create_one() + location_id = location.location_id + + for i in range(10): + await self.party_utils.create_one( + location_id=location_id, party_datetime=base_datetime + timedelta(days=i) + ) + + # Create some parties at different location (should be filtered out) + for _ in range(5): + await self.party_utils.create_one() + + # Filter by location, sort by datetime desc, paginate + response = await self.admin_client.get( + f"/api/parties/?location_id={location_id}&sort_by=party_datetime" + "&sort_order=desc&page_number=1&page_size=5" + ) + paginated = assert_res_paginated( + response, PartyDto, total_records=11, page_size=5, total_pages=3 + ) + + # Verify we got the right number of items + assert len(paginated.items) == 5 + + # Verify all items are from the filtered location + assert all(p.location.id == location_id for p in paginated.items) + + # Verify sorting (should be descending by datetime) + datetimes = [p.party_datetime for p in paginated.items] + assert datetimes == sorted(datetimes, reverse=True) + + @pytest.mark.asyncio + async def test_pagination_preserves_filter_count(self): + """Test that total_records reflects filtered count, not total count.""" + # Create 20 parties at location 1 + location1 = await self.party_utils.create_one() + for _ in range(19): + await self.party_utils.create_one(location_id=location1.location_id) + + # Create 10 parties at location 2 + location2 = await self.party_utils.create_one() + for _ in range(9): + await self.party_utils.create_one(location_id=location2.location_id) + + # Filter for location 1 with pagination + response = await self.admin_client.get( + f"/api/parties/?location_id={location1.location_id}&page_size=10" + ) + paginated = assert_res_paginated( + response, PartyDto, total_records=20, page_size=10, total_pages=2 + ) + + # Should show 20 total (filtered count), not 30 (total count) + assert paginated.total_records == 20 + assert len(paginated.items) == 10