Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .devcontainer/post_create.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ npm i --verbose

cd ../backend
python -m script.create_db
python -m script.create_test_db
python -m script.reset_dev
24 changes: 23 additions & 1 deletion .github/workflows/backend-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,31 @@ name: Run Backend Tests

on:
pull_request:
branches: [main]
paths:
- "backend/**"
- ".github/workflows/backend-test.yml"
push:
branches: [main]

jobs:
backend-test:
runs-on: ubuntu-latest

services:
postgres:
image: postgres:latest
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: ocsl_test
ports:
- 5432:5432
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5

steps:
- name: Checkout code
uses: actions/checkout@v4
Expand All @@ -32,6 +49,11 @@ jobs:
PYTHONPATH=backend:backend/src pytest backend/test -v --tb=long --junitxml=backend/test-results/junit.xml
env:
GOOGLE_MAPS_API_KEY: "test_api_key_not_used"
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_HOST: localhost
POSTGRES_PORT: 5432
POSTGRES_DB: ocsl_test

- name: Publish Test Results
uses: EnricoMi/publish-unit-test-result-action@v2
Expand Down
3 changes: 2 additions & 1 deletion backend/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ python_files = *_test.py
python_classes = Test*
python_functions = test_*
asyncio_mode = auto
asyncio_default_fixture_loop_scope = function
asyncio_default_fixture_loop_scope = session
asyncio_default_test_loop_scope = session
Comment on lines +8 to +9
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Necessary for async session fixture shenanigans

addopts = --color=yes
14 changes: 10 additions & 4 deletions backend/script/create_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
engine = create_engine(server_url(sync=True), isolation_level="AUTOCOMMIT")

with engine.connect() as connection:
print("Creating database...")
connection.execute(text(f"CREATE DATABASE {env.POSTGRES_DATABASE}"))

print("Database successfully created")
# Check if database already exists
result = connection.execute(
text(f"SELECT 1 FROM pg_database WHERE datname = '{env.POSTGRES_DATABASE}'")
)
if result.fetchone():
print(f"Database '{env.POSTGRES_DATABASE}' already exists")
else:
print(f"Creating database '{env.POSTGRES_DATABASE}'...")
connection.execute(text(f"CREATE DATABASE {env.POSTGRES_DATABASE}"))
print("Database successfully created")
16 changes: 16 additions & 0 deletions backend/script/create_test_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from src.core.database import server_url
from sqlalchemy import create_engine, text

engine = create_engine(server_url(sync=True), isolation_level="AUTOCOMMIT")

with engine.connect() as connection:
# Check if test database already exists
result = connection.execute(
text("SELECT 1 FROM pg_database WHERE datname = 'ocsl_test'")
)
if result.fetchone():
print("Test database 'ocsl_test' already exists")
else:
print("Creating test database 'ocsl_test'...")
connection.execute(text("CREATE DATABASE ocsl_test"))
print("Test database successfully created")
5 changes: 5 additions & 0 deletions backend/src/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ def __init__(self, detail: str):
super().__init__(status_code=400, detail=detail)


class UnprocessableEntityException(HTTPException):
def __init__(self, detail: str):
super().__init__(status_code=422, detail=detail)


class CredentialsException(HTTPException):
def __init__(self):
super().__init__(
Expand Down
2 changes: 1 addition & 1 deletion backend/src/modules/account/account_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class AccountEntity(MappedAsDataclass, EntityBase):
pid: Mapped[str] = mapped_column(
String(9),
CheckConstraint(
"length(pid) = 9",
"length(pid) = 9 AND pid ~ '^[0-9]{9}$'",
name="check_pid_format",
),
nullable=False,
Expand Down
2 changes: 1 addition & 1 deletion backend/src/modules/complaint/complaint_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class ComplaintEntity(MappedAsDataclass, EntityBase):
location_id: Mapped[int] = mapped_column(
Integer, ForeignKey("locations.id", ondelete="CASCADE"), nullable=False
)
complaint_datetime: Mapped[datetime] = mapped_column(DateTime, nullable=False)
complaint_datetime: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
description: Mapped[str] = mapped_column(String, nullable=False, default="")

# Relationships
Expand Down
58 changes: 19 additions & 39 deletions backend/src/modules/party/party_router.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime
from datetime import datetime, timezone

from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, Query
from fastapi.responses import Response
from src.core.authentication import (
authenticate_admin,
Expand All @@ -9,7 +9,7 @@
authenticate_staff_or_admin,
authenticate_user,
)
from src.core.exceptions import BadRequestException, ForbiddenException
from src.core.exceptions import BadRequestException, ForbiddenException, UnprocessableEntityException
from src.modules.account.account_model import Account, AccountRole
from src.modules.location.location_service import LocationService

Expand Down Expand Up @@ -52,9 +52,7 @@ async def create_party(
return await party_service.create_party_from_student_dto(party_data, user.id)
elif isinstance(party_data, AdminCreatePartyDTO):
if user.role != AccountRole.ADMIN:
raise ForbiddenException(
detail="Only admins can use the admin party creation endpoint"
)
raise ForbiddenException(detail="Only admins can use the admin party creation endpoint")
return await party_service.create_party_from_admin_dto(party_data)
else:
raise ForbiddenException(detail="Invalid request type")
Expand All @@ -63,9 +61,7 @@ async def create_party(
@party_router.get("/")
async def list_parties(
page_number: int = Query(1, ge=1, description="Page number (1-indexed)"),
page_size: int | None = Query(
None, ge=1, le=100, description="Items per page (default: all)"
),
page_size: int | None = Query(None, ge=1, le=100, description="Items per page (default: all)"),
party_service: PartyService = Depends(),
_=Depends(authenticate_by_role("admin", "staff", "police")),
) -> PaginatedPartiesResponse:
Expand Down Expand Up @@ -104,9 +100,7 @@ async def list_parties(
parties = await party_service.get_parties(skip=skip, limit=page_size)

# Calculate total pages (ceiling division)
total_pages = (
(total_records + page_size - 1) // page_size if total_records > 0 else 0
)
total_pages = (total_records + page_size - 1) // page_size if total_records > 0 else 0

return PaginatedPartiesResponse(
items=parties,
Expand All @@ -120,8 +114,8 @@ async def list_parties(
@party_router.get("/nearby")
async def get_parties_nearby(
place_id: str = Query(..., description="Google Maps place ID"),
start_date: str = Query(..., description="Start date (YYYY-MM-DD format)"),
end_date: str = Query(..., description="End date (YYYY-MM-DD format)"),
start_date: str = Query(..., pattern=r"^\d{4}-\d{2}-\d{2}$", description="Start date (YYYY-MM-DD format)"),
end_date: str = Query(..., pattern=r"^\d{4}-\d{2}-\d{2}$", description="End date (YYYY-MM-DD format)"),
party_service: PartyService = Depends(),
location_service: LocationService = Depends(),
_=Depends(authenticate_police_or_admin),
Expand All @@ -145,12 +139,12 @@ async def get_parties_nearby(
"""
# Parse date strings to datetime objects
try:
start_datetime = datetime.strptime(start_date, "%Y-%m-%d")
end_datetime = datetime.strptime(end_date, "%Y-%m-%d")
start_datetime = datetime.strptime(start_date, "%Y-%m-%d").replace(tzinfo=timezone.utc)
end_datetime = datetime.strptime(end_date, "%Y-%m-%d").replace(tzinfo=timezone.utc)
# Set end_datetime to end of day (23:59:59)
end_datetime = end_datetime.replace(hour=23, minute=59, second=59)
except ValueError as e:
raise BadRequestException(f"Invalid date format. Expected YYYY-MM-DD: {str(e)}")
raise UnprocessableEntityException(f"Invalid date format. Expected YYYY-MM-DD: {str(e)}")

# Validate that start_date is not greater than end_date
if start_datetime > end_datetime:
Expand All @@ -172,8 +166,8 @@ async def get_parties_nearby(

@party_router.get("/csv")
async def get_parties_csv(
start_date: str = Query(..., description="Start date in YYYY-MM-DD format"),
end_date: str = Query(..., description="End date in YYYY-MM-DD format"),
start_date: str = Query(..., pattern=r"^\d{4}-\d{2}-\d{2}$", description="Start date in YYYY-MM-DD format"),
end_date: str = Query(..., pattern=r"^\d{4}-\d{2}-\d{2}$", description="End date in YYYY-MM-DD format"),
party_service: PartyService = Depends(),
_=Depends(authenticate_admin),
) -> Response:
Expand All @@ -194,25 +188,15 @@ async def get_parties_csv(
start_datetime = datetime.strptime(start_date, "%Y-%m-%d")
end_datetime = datetime.strptime(end_date, "%Y-%m-%d")

end_datetime = end_datetime.replace(
hour=23, minute=59, second=59, microsecond=999999
)
end_datetime = end_datetime.replace(hour=23, minute=59, second=59, microsecond=999999)
except ValueError:
raise HTTPException(
status_code=400,
detail="Invalid date format. Use YYYY-MM-DD format for dates.",
)
raise UnprocessableEntityException("Invalid date format. Use YYYY-MM-DD format for dates.")

# Validate that start_date is not greater than end_date
if start_datetime > end_datetime:
raise HTTPException(
status_code=400,
detail="Start date must be less than or equal to end date",
)
raise BadRequestException("Start date must be less than or equal to end date")

parties = await party_service.get_parties_by_date_range(
start_datetime, end_datetime
)
parties = await party_service.get_parties_by_date_range(start_datetime, end_datetime)
csv_content = await party_service.export_parties_to_csv(parties)

return Response(
Expand Down Expand Up @@ -247,14 +231,10 @@ async def update_party(
raise ForbiddenException(
detail="Only students can use the student party update endpoint"
)
return await party_service.update_party_from_student_dto(
party_id, party_data, user.id
)
return await party_service.update_party_from_student_dto(party_id, party_data, user.id)
elif isinstance(party_data, AdminCreatePartyDTO):
if user.role != AccountRole.ADMIN:
raise ForbiddenException(
detail="Only admins can use the admin party update endpoint"
)
raise ForbiddenException(detail="Only admins can use the admin party update endpoint")
return await party_service.update_party_from_admin_dto(party_id, party_data)
else:
raise ForbiddenException(detail="Invalid request type")
Expand Down
40 changes: 34 additions & 6 deletions backend/test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

from sqlalchemy import text

os.environ["GOOGLE_MAPS_API_KEY"] = "invalid_google_maps_api_key_for_tests"

from typing import Any, AsyncGenerator, Callable
Expand All @@ -12,8 +14,9 @@
import src.modules # Ensure all modules are imported so their entities are registered # noqa: F401
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.ext.asyncio.engine import AsyncEngine
from src.core.authentication import StringRole
from src.core.database import EntityBase, get_session
from src.core.database import EntityBase, database_url, get_session
from src.main import app
from src.modules.account.account_service import AccountService
from src.modules.complaint.complaint_service import ComplaintService
Expand All @@ -29,20 +32,45 @@
from test.modules.police.police_utils import PoliceTestUtils
from test.modules.student.student_utils import StudentTestUtils

DATABASE_URL = "sqlite+aiosqlite:///:memory:"
DATABASE_URL = database_url("ocsl_test")

# =================================== Database ======================================


@pytest_asyncio.fixture(scope="function")
async def test_session():
@pytest_asyncio.fixture(autouse=True, scope="session", loop_scope="session")
async def test_engine():
"""Create engine and tables once per test session."""
engine = create_async_engine(DATABASE_URL, echo=False)
async with engine.begin() as conn:
await conn.run_sync(EntityBase.metadata.drop_all)
await conn.run_sync(EntityBase.metadata.create_all)
TestAsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
yield engine
async with engine.begin() as conn:
await conn.run_sync(EntityBase.metadata.drop_all)
await engine.dispose()


@pytest_asyncio.fixture(scope="function")
async def test_session(test_engine: AsyncEngine):
"""Create a new session and truncate all tables after each test."""
TestAsyncSessionLocal = async_sessionmaker(
bind=test_engine,
expire_on_commit=False,
class_=AsyncSession,
)

async with TestAsyncSessionLocal() as session:
yield session
await engine.dispose()

# Clean up: truncate all tables and reset sequences
async with test_engine.begin() as conn:
tables = [table.name for table in EntityBase.metadata.sorted_tables]
if tables:
# Disable foreign key checks, truncate, and reset sequences
await conn.execute(text("SET session_replication_role = 'replica';"))
for table in tables:
await conn.execute(text(f"TRUNCATE TABLE {table} RESTART IDENTITY CASCADE;"))
await conn.execute(text("SET session_replication_role = 'origin';"))
Comment on lines +65 to +73
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Truncate instead of drop all tables so that tests don't take a million years



# =================================== Clients =======================================
Expand Down
58 changes: 34 additions & 24 deletions backend/test/modules/party/party_router_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,31 +390,22 @@ async def test_get_parties_nearby_with_date_range(self):
assert data[0].id == party_valid.id

@pytest.mark.asyncio
async def test_get_parties_nearby_validation_errors(self):
@pytest.mark.parametrize(
"params",
[
{"start_date": "2024-01-01", "end_date": "2024-01-02"}, # missing place_id
{"place_id": "ChIJtest123", "end_date": "2024-01-02"}, # missing start_date
{"place_id": "ChIJtest123", "start_date": "2024-01-01"}, # missing end_date
{"place_id": "ChIJtest123", "start_date": "01-01-2024", "end_date": "2024-01-02"}, # MM-DD-YYYY
{"place_id": "ChIJtest123", "start_date": "2024-01-01", "end_date": "01/02/2024"}, # MM/DD/YYYY
{"place_id": "ChIJtest123", "start_date": "2024/01/01", "end_date": "2024-01-02"}, # slashes
{"place_id": "ChIJtest123", "start_date": "2024-1-1", "end_date": "2024-01-02"}, # no leading zeros
{"place_id": "ChIJtest123", "start_date": "2024-01-01", "end_date": "24-01-02"}, # 2-digit year
{"place_id": "ChIJtest123", "start_date": "not-a-date", "end_date": "2024-01-02"}, # invalid string
],
)
async def test_get_parties_nearby_validation_errors(self, params: dict[str, str]):
"""Test validation errors for nearby search."""
now = datetime.now(timezone.utc)

# Missing place_id
params = {
"start_date": now.strftime("%Y-%m-%d"),
"end_date": (now + timedelta(days=1)).strftime("%Y-%m-%d"),
}
response = await self.admin_client.get("/api/parties/nearby", params=params)
assert_res_validation_error(response)

# Missing start_date
params = {
"place_id": "ChIJtest123",
"end_date": (now + timedelta(days=1)).strftime("%Y-%m-%d"),
}
response = await self.admin_client.get("/api/parties/nearby", params=params)
assert_res_validation_error(response)

# Missing end_date
params = {
"place_id": "ChIJtest123",
"start_date": now.strftime("%Y-%m-%d"),
}
response = await self.admin_client.get("/api/parties/nearby", params=params)
assert_res_validation_error(response)

Expand Down Expand Up @@ -470,3 +461,22 @@ async def test_get_parties_csv_with_data(self):
# Verify party IDs are in CSV
for party in parties:
assert str(party.id) in csv_content

@pytest.mark.asyncio
@pytest.mark.parametrize(
"params",
[
{"end_date": "2024-01-02"}, # missing start_date
{"start_date": "2024-01-01"}, # missing end_date
{"start_date": "01-01-2024", "end_date": "2024-01-02"}, # MM-DD-YYYY
{"start_date": "2024-01-01", "end_date": "01/02/2024"}, # MM/DD/YYYY
{"start_date": "2024/01/01", "end_date": "2024-01-02"}, # slashes
{"start_date": "2024-1-1", "end_date": "2024-01-02"}, # no leading zeros
{"start_date": "2024-01-01", "end_date": "24-01-02"}, # 2-digit year
{"start_date": "not-a-date", "end_date": "2024-01-02"}, # invalid string
],
)
async def test_get_parties_csv_validation_errors(self, params: dict[str, str]):
"""Test validation errors for CSV export."""
response = await self.admin_client.get("/api/parties/csv", params=params)
assert_res_validation_error(response)