Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
306 changes: 306 additions & 0 deletions backend/src/core/query_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
"""
Core utilities for server-side pagination, sorting, and filtering.

This module provides reusable functions to apply pagination, sorting, and filtering
to SQLAlchemy queries in a type-safe and flexible manner.
"""

from enum import Enum
from typing import Any

from pydantic import BaseModel, Field, field_validator
from sqlalchemy import Select, asc, desc, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeMeta


class SortOrder(str, Enum):
"""Sort order enumeration."""

ASC = "asc"
DESC = "desc"


class FilterOperator(str, Enum):
"""Filter operators for different comparison types."""

EQUALS = "eq"
NOT_EQUALS = "ne"
GREATER_THAN = "gt"
GREATER_THAN_OR_EQUAL = "gte"
LESS_THAN = "lt"
LESS_THAN_OR_EQUAL = "lte"
CONTAINS = "contains" # For string fields (case-insensitive LIKE)
IN = "in" # For checking if value is in a list
NOT_IN = "not_in"
IS_NULL = "is_null"
IS_NOT_NULL = "is_not_null"


class PaginationParams(BaseModel):
"""Parameters for pagination."""

page_number: int = Field(default=1, ge=1, description="Page number (1-indexed)")
page_size: int | None = Field(
default=None, ge=1, le=100, description="Items per page (None = all items)"
)

@property
def skip(self) -> int:
"""Calculate offset from page_number and page_size."""
if self.page_size is None:
return 0
return (self.page_number - 1) * self.page_size

@property
def limit(self) -> int | None:
"""Alias for page_size."""
return self.page_size


class SortParam(BaseModel):
"""Single sorting parameter."""

field: str = Field(..., description="Field name to sort by")
order: SortOrder = Field(default=SortOrder.ASC, description="Sort order")


class FilterParam(BaseModel):
"""Single filter parameter."""

field: str = Field(..., description="Field name to filter on")
operator: FilterOperator = Field(..., description="Comparison operator")
value: Any | None = Field(default=None, description="Value to compare against")

@field_validator("value")
@classmethod
def validate_value_for_operator(cls, v: Any, info: Any) -> Any:
"""Validate that value is appropriate for the operator."""
operator = info.data.get("operator")

if operator in [FilterOperator.IS_NULL, FilterOperator.IS_NOT_NULL]:
# These operators don't need a value
return None

if operator in [FilterOperator.IN, FilterOperator.NOT_IN] and not isinstance(v, list):
# These operators need a list
raise ValueError(f"Operator {operator} requires a list value")

return v


class QueryParams(BaseModel):
"""Combined parameters for pagination, sorting, and filtering."""

pagination: PaginationParams | None = Field(default=None)
sort: list[SortParam] | None = Field(default=None)
filters: list[FilterParam] | None = Field(default=None)


def apply_pagination(query: Select, params: PaginationParams | None = None) -> Select:
"""
Apply pagination to a SQLAlchemy query.

Args:
query: The base SQLAlchemy query
params: Pagination parameters (page_number, page_size)

Returns:
Modified query with LIMIT and OFFSET applied
"""
if params is None:
return query

query = query.offset(params.skip)
if params.limit is not None:
query = query.limit(params.limit)

return query


def apply_sorting[ModelType: DeclarativeMeta](
query: Select,
model: type[ModelType],
sort_params: list[SortParam] | None = None,
allowed_fields: list[str] | None = None,
) -> Select:
"""
Apply sorting to a SQLAlchemy query.

Args:
query: The base SQLAlchemy query
model: The SQLAlchemy model class
sort_params: List of sorting parameters
allowed_fields: Optional list of fields that can be sorted on

Returns:
Modified query with ORDER BY applied

Raises:
ValueError: If attempting to sort on a disallowed or non-existent field
"""
if not sort_params:
return query

for sort_param in sort_params:
# Validate field exists and is allowed
if not hasattr(model, sort_param.field):
raise ValueError(f"Field '{sort_param.field}' does not exist on model")

if allowed_fields and sort_param.field not in allowed_fields:
raise ValueError(f"Sorting on field '{sort_param.field}' is not allowed")

# Get the model attribute
field = getattr(model, sort_param.field)

# Apply sort order
if sort_param.order == SortOrder.DESC:
query = query.order_by(desc(field))
else:
query = query.order_by(asc(field))

return query


def apply_filters[ModelType: DeclarativeMeta](
query: Select,
model: type[ModelType],
filter_params: list[FilterParam] | None = None,
allowed_fields: list[str] | None = None,
) -> Select:
"""
Apply filters to a SQLAlchemy query.

Args:
query: The base SQLAlchemy query
model: The SQLAlchemy model class
filter_params: List of filter parameters
allowed_fields: Optional list of fields that can be filtered on

Returns:
Modified query with WHERE clauses applied

Raises:
ValueError: If attempting to filter on a disallowed or non-existent field
"""
if not filter_params:
return query

for filter_param in filter_params:
# Validate field exists and is allowed
if not hasattr(model, filter_param.field):
raise ValueError(f"Field '{filter_param.field}' does not exist on model")

if allowed_fields and filter_param.field not in allowed_fields:
raise ValueError(f"Filtering on field '{filter_param.field}' is not allowed")

# Get the model attribute
field = getattr(model, filter_param.field)

# Apply the appropriate filter based on operator
if filter_param.operator == FilterOperator.EQUALS:
query = query.where(field == filter_param.value)
elif filter_param.operator == FilterOperator.NOT_EQUALS:
query = query.where(field != filter_param.value)
elif filter_param.operator == FilterOperator.GREATER_THAN:
query = query.where(field > filter_param.value)
elif filter_param.operator == FilterOperator.GREATER_THAN_OR_EQUAL:
query = query.where(field >= filter_param.value)
elif filter_param.operator == FilterOperator.LESS_THAN:
query = query.where(field < filter_param.value)
elif filter_param.operator == FilterOperator.LESS_THAN_OR_EQUAL:
query = query.where(field <= filter_param.value)
elif filter_param.operator == FilterOperator.CONTAINS:
# Case-insensitive LIKE for string fields
query = query.where(field.ilike(f"%{filter_param.value}%"))
elif filter_param.operator == FilterOperator.IN:
query = query.where(field.in_(filter_param.value))
elif filter_param.operator == FilterOperator.NOT_IN:
query = query.where(~field.in_(filter_param.value))
elif filter_param.operator == FilterOperator.IS_NULL:
query = query.where(field.is_(None))
elif filter_param.operator == FilterOperator.IS_NOT_NULL:
query = query.where(field.is_not(None))

return query


def apply_query_params[ModelType: DeclarativeMeta](
query: Select,
model: type[ModelType],
params: QueryParams | None = None,
allowed_sort_fields: list[str] | None = None,
allowed_filter_fields: list[str] | None = None,
) -> tuple[Select, PaginationParams | None]:
"""
Apply all query parameters (filtering, sorting, pagination) to a query.

This is the main utility function that combines all operations.

Args:
query: The base SQLAlchemy query
model: The SQLAlchemy model class
params: Combined query parameters
allowed_sort_fields: Optional list of fields that can be sorted on
allowed_filter_fields: Optional list of fields that can be filtered on

Returns:
Tuple of (modified query, pagination params used)
"""
if params is None:
return query, None

# Apply filters first (narrows down the dataset)
if params.filters:
query = apply_filters(query, model, params.filters, allowed_filter_fields)

# Apply sorting (before pagination to ensure consistent ordering)
if params.sort:
query = apply_sorting(query, model, params.sort, allowed_sort_fields)

# Apply pagination last
pagination_params = params.pagination
if pagination_params:
query = apply_pagination(query, pagination_params)

return query, pagination_params


async def get_total_count(
session: AsyncSession,
base_query: Select,
) -> int:
"""
Get the total count of results for a query (before pagination).

Args:
session: SQLAlchemy async session
base_query: The base query (with filters but before pagination)

Returns:
Total number of results
"""
# Use select(func.count()).select_from(base_query.subquery()) for proper counting
count_query = select(func.count()).select_from(base_query.subquery())
result = await session.execute(count_query)
return result.scalar() or 0


class PaginatedResponse[ModelType](BaseModel):
"""Generic paginated response wrapper matching existing PaginatedResponse."""

items: list[ModelType]
total_records: int
page_size: int
page_number: int
total_pages: int

@property
def has_next(self) -> bool:
"""Check if there's a next page."""
return self.page_number < self.total_pages

@property
def has_prev(self) -> bool:
"""Check if there's a previous page."""
return self.page_number > 1
68 changes: 33 additions & 35 deletions backend/src/modules/party/party_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,52 +68,50 @@ async def create_party(
async def list_parties(
page_number: int = Query(1, ge=1, description="Page number (1-indexed)"),
page_size: int | None = Query(None, ge=1, le=100, description="Items per page (default: all)"),
sort_by: str | None = Query(None, description="Field to sort by (e.g., 'party_datetime')"),
sort_order: str = Query("asc", pattern="^(asc|desc)$", description="Sort order: asc or desc"),
location_id: int | None = Query(None, description="Filter by location ID"),
contact_one_id: int | None = Query(None, description="Filter by contact one (student) ID"),
party_service: PartyService = Depends(),
_=Depends(authenticate_by_role("admin", "staff", "police")),
) -> PaginatedPartiesResponse:
"""
Returns all party registrations in the database with optional pagination.
Returns all party registrations in the database with optional pagination, sorting, and filtering

Query Parameters:
- page_number: The page number to retrieve (1-indexed)
- page_number: The page number to retrieve (1-indexed, default: 1)
- page_size: Number of items per page (max 100, default: returns all parties)
- sort_by: Field to sort by (allowed: party_datetime, location_id, contact_one_id, id)
- sort_order: Sort order (asc or desc, default: asc)
- location_id: Filter by location ID (optional)
- contact_one_id: Filter by contact one (student) ID (optional)

Features:
- **Opt-in**: All features have sensible defaults - no parameters returns all parties
- **Server-side**: All sorting, filtering, and pagination happens in the database
- **Performant**: Scales well with large datasets

Returns:
- items: List of party registrations
- total_records: Total number of records in the database
- page_size: Requested page size (or total_records if not specified)
- page_number: Requested page number
- total_pages: Total number of pages based on page size
- items: List of party registrations for the current page
- total_records: Total number of records matching filters (not just current page)
- page_size: Items per page (equals total_records when page_size is None)
- page_number: Current page number
- total_pages: Total number of pages based on page size and total records

Examples:
- Get all parties: GET /api/parties/
- Get first page of 10: GET /api/parties/?page_size=10
- Sort by date descending: GET /api/parties/?sort_by=party_datetime&sort_order=desc
- Filter by location: GET /api/parties/?location_id=5
- Combined: GET /api/parties/?location_id=5&sort_by=party_datetime&page_size=20
"""
# Get total count first
total_records = await party_service.get_party_count()

# If page_size is None, return all parties
if page_size is None:
parties = await party_service.get_parties(skip=0, limit=None)
return PaginatedPartiesResponse(
items=parties,
total_records=total_records,
page_size=total_records,
page_number=1,
total_pages=1,
)

# Calculate skip and limit for pagination
skip = (page_number - 1) * page_size

# Get parties with pagination
parties = await party_service.get_parties(skip=skip, limit=page_size)

# Calculate total pages (ceiling division)
total_pages = (total_records + page_size - 1) // page_size if total_records > 0 else 0

return PaginatedPartiesResponse(
items=parties,
total_records=total_records,
page_size=page_size,
return await party_service.get_parties_paginated(
page_number=page_number,
total_pages=total_pages,
page_size=page_size,
sort_by=sort_by,
sort_order=sort_order,
location_id=location_id,
contact_one_id=contact_one_id,
)


Expand Down
Loading