diff --git a/src/analytics/models.py b/src/analytics/models.py
index 596b8cb..a5fb80e 100644
--- a/src/analytics/models.py
+++ b/src/analytics/models.py
@@ -1,10 +1,10 @@
 """Analytics models for tracking ticket scans, transfers, and invalid attempts."""
-from sqlalchemy import create_engine, Column, Integer, String, DateTime, Boolean, Text, Index
+from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text, Index
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import sessionmaker
 from datetime import datetime
-from src.config import get_settings
+import src.db as _db
 
 Base = declarative_base()
 
 
@@ -96,24 +96,18 @@ class AnalyticsStats(Base):
     )
 
 
-def get_database_url():
-    """Get database URL from centralized settings."""
-    return get_settings().DATABASE_URL
-
-
 def get_engine():
-    """Create database engine."""
-    return create_engine(get_database_url())
+    """Return the shared database engine from src.db."""
+    return _db.get_engine()
 
 
 def get_session():
-    """Create database session."""
-    engine = get_engine()
-    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
-    return SessionLocal()
+    """Return a database session from src.db."""
+    return _db.get_session()
 
 
 def init_db():
     """Initialize the database tables."""
     engine = get_engine()
-    Base.metadata.create_all(bind=engine)
+    if engine is not None:
+        Base.metadata.create_all(bind=engine)
diff --git a/src/analytics/service.py b/src/analytics/service.py
index 8dd7c43..4dbaa2f 100644
--- a/src/analytics/service.py
+++ b/src/analytics/service.py
@@ -1,10 +1,11 @@
 """Analytics service for tracking ticket scans, transfers, and invalid attempts."""
 import json
 import logging
+import time
 from datetime import datetime, timedelta
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
-from sqlalchemy import asc, desc, func
+from sqlalchemy import asc, desc, func, text
 from sqlalchemy.orm import Session
 
 from src.analytics.models import (
@@ -14,8 +15,13 @@
     TicketTransfer,
     get_session,
 )
+import src.db as _db
 from src.logging_config import log_error, log_info
 
+# Simple in-memory cache: (result, expiry_timestamp)
+_trending_cache: Optional[Tuple[List[Dict[str, Any]], float]] = None
+_TRENDING_CACHE_TTL = 600  # 10 minutes
+
 
 class AnalyticsService:
     """Service to handle analytics data storage and retrieval."""
@@ -355,6 +361,83 @@ def get_invalid_attempts(self, event_id: str, limit: int = 100) -> List[Dict[str
         finally:
             session.close()
 
+    def get_trending_events(self, limit: int = 10, hours: int = 24) -> List[Dict[str, Any]]:
+        """Return top events by scan velocity over the last N hours.
+
+        Results are cached for 10 minutes to avoid repeated heavy queries.
+        Joins with event_sales_summary to include event names where available.
+        """
+        global _trending_cache
+
+        # Return cached result if still valid
+        if _trending_cache is not None:
+            cached_result, expiry = _trending_cache
+            if time.monotonic() < expiry:
+                return cached_result[:limit]
+
+        engine = _db.get_engine()
+        if engine is None:
+            return []
+
+        cutoff = datetime.utcnow() - timedelta(hours=hours)
+        try:
+            with engine.connect() as conn:
+                # Attempt join with event_sales_summary for event names
+                try:
+                    result = conn.execute(
+                        text("""
+                            SELECT ts.event_id,
+                                   COALESCE(ess.event_name, ts.event_id) AS event_name,
+                                   COUNT(*) AS scan_count
+                            FROM ticket_scans ts
+                            LEFT JOIN event_sales_summary ess
+                                ON ts.event_id = ess.event_id
+                            WHERE ts.scan_timestamp >= :cutoff
+                            GROUP BY ts.event_id, ess.event_name
+                            ORDER BY scan_count DESC
+                            LIMIT :limit
+                        """),
+                        {"cutoff": cutoff, "limit": limit},
+                    )
+                    rows = [
+                        {
+                            "event_id": row[0],
+                            "event_name": row[1],
+                            "scan_count": int(row[2]),
+                            "window_hours": hours,
+                        }
+                        for row in result
+                    ]
+                except Exception:
+                    # Fallback: query ticket_scans only (event_sales_summary may not exist)
+                    result = conn.execute(
+                        text("""
+                            SELECT event_id, COUNT(*) AS scan_count
+                            FROM ticket_scans
+                            WHERE scan_timestamp >= :cutoff
+                            GROUP BY event_id
+                            ORDER BY scan_count DESC
+                            LIMIT :limit
+                        """),
+                        {"cutoff": cutoff, "limit": limit},
+                    )
+                    rows = [
+                        {
+                            "event_id": row[0],
+                            "event_name": row[0],
+                            "scan_count": int(row[1]),
+                            "window_hours": hours,
+                        }
+                        for row in result
+                    ]
+        except Exception as exc:
+            log_error("Failed to get trending events", {"error": str(exc)})
+            return []
+
+        # Cache the ordered rows. Note: only rows up to this call's limit were
+        # fetched, so a later call with a larger limit inside the TTL is still
+        # served (truncated) from this cache.
+        _trending_cache = (rows, time.monotonic() + _TRENDING_CACHE_TTL)
+        return rows[:limit]
+
     def _update_analytics_stats(self, event_id: str, increment_scan: bool = False,
                                 is_valid: bool = True,
                                 increment_transfer: bool = False, is_successful: bool = True,
+ """ + global _engine + if _engine is None: + settings = get_settings() + url = getattr(settings, "DATABASE_URL", None) + if not url: + logger.info("DATABASE_URL not set; skipping engine creation") + return None + try: + _engine = create_engine( + url, + pool_size=settings.POOL_SIZE, + max_overflow=settings.POOL_MAX_OVERFLOW, + pool_timeout=30, + pool_recycle=1800, + pool_pre_ping=True, + ) + logger.info( + "Database engine created with pool_size=%d, max_overflow=%d", + settings.POOL_SIZE, + settings.POOL_MAX_OVERFLOW, + ) + except Exception as exc: + logger.error("Failed to create database engine: %s", exc) + return None + return _engine + + +def get_session() -> Optional[Session]: + """Create and return a new database session, or None if DB is not configured.""" + engine = get_engine() + if engine is None: + return None + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + return SessionLocal() + + +def get_pool_status() -> Dict[str, Any]: + """Return live connection pool statistics.""" + engine = get_engine() + if engine is None: + return {"status": "unavailable", "reason": "DATABASE_URL not configured"} + pool = engine.pool + return { + "pool_size": pool.size(), + "checked_in": pool.checkedin(), + "checked_out": pool.checkedout(), + "overflow": pool.overflow(), + "invalid": pool.invalid(), + } diff --git a/src/etl/__init__.py b/src/etl/__init__.py index 6cdd93f..3c43824 100644 --- a/src/etl/__init__.py +++ b/src/etl/__init__.py @@ -11,7 +11,7 @@ String, Table, TIMESTAMP, - create_engine, + text, ) from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.engine import Engine @@ -23,6 +23,7 @@ from src.logging_config import ETL_JOBS_TOTAL, log_error, log_info from src.config import get_settings +import src.db as _db from .extract import extract_events_and_sales logger = logging.getLogger("veritix.etl") @@ -110,14 +111,7 @@ def transform_summary( # --------------------------------------------------------------------------- def _pg_engine() -> Optional[Engine]: - url = get_settings().DATABASE_URL - if not url: - return None - try: - return create_engine(url, pool_pre_ping=True) - except Exception as exc: - logger.error("Failed to create PG engine: %s", exc) - return None + return _db.get_engine() def load_postgres( @@ -289,6 +283,99 @@ def _ensure_table(table_name: str, schema: List[Any]) -> str: ) +# --------------------------------------------------------------------------- +# ETL Diff (dry-run) +# --------------------------------------------------------------------------- + +def diff_etl_output( + event_rows: List[Dict[str, Any]], + daily_rows: List[Dict[str, Any]], +) -> Dict[str, Any]: + """Compare transformed rows against the current PostgreSQL data without loading. + + Returns a summary of what the next real ETL run would insert, update, or leave + unchanged for both the event_sales_summary and daily_ticket_sales tables. 
+ """ + engine = _pg_engine() + if engine is None: + logger.info("DATABASE_URL not set; returning empty diff") + return { + "events": {"would_insert": 0, "would_update": 0, "unchanged": 0}, + "daily": {"would_insert": 0, "would_update": 0, "unchanged": 0}, + } + + current_events: Dict[str, Dict[str, Any]] = {} + current_daily: Dict[tuple, Dict[str, Any]] = {} + + with engine.connect() as conn: + try: + result = conn.execute( + text("SELECT event_id, total_tickets, total_revenue FROM event_sales_summary") + ) + for row in result: + current_events[str(row[0])] = { + "total_tickets": int(row[1]) if row[1] is not None else 0, + "total_revenue": float(row[2]) if row[2] is not None else 0.0, + } + except Exception: + pass # table may not exist yet + + try: + result = conn.execute( + text("SELECT event_id, sale_date, tickets_sold, revenue FROM daily_ticket_sales") + ) + for row in result: + current_daily[(str(row[0]), str(row[1]))] = { + "tickets_sold": int(row[2]) if row[2] is not None else 0, + "revenue": float(row[3]) if row[3] is not None else 0.0, + } + except Exception: + pass # table may not exist yet + + ev_insert = ev_update = ev_unchanged = 0 + for row in event_rows: + eid = str(row.get("event_id", "")) + if eid not in current_events: + ev_insert += 1 + else: + cur = current_events[eid] + if ( + cur["total_tickets"] != int(row.get("total_tickets", 0)) + or abs(cur["total_revenue"] - float(row.get("total_revenue", 0.0))) > 0.005 + ): + ev_update += 1 + else: + ev_unchanged += 1 + + daily_insert = daily_update = daily_unchanged = 0 + for row in daily_rows: + key = (str(row.get("event_id", "")), str(row.get("sale_date", ""))) + if key not in current_daily: + daily_insert += 1 + else: + cur = current_daily[key] + if ( + cur["tickets_sold"] != int(row.get("tickets_sold", 0)) + or abs(cur["revenue"] - float(row.get("revenue", 0.0))) > 0.005 + ): + daily_update += 1 + else: + daily_unchanged += 1 + + return { + "events": { + "would_insert": ev_insert, + "would_update": ev_update, + "unchanged": ev_unchanged, + }, + "daily": { + "would_insert": daily_insert, + "would_update": daily_update, + "unchanged": daily_unchanged, + }, + } + + # --------------------------------------------------------------------------- # Orchestration # --------------------------------------------------------------------------- diff --git a/src/main.py b/src/main.py index 72498bf..17d51a2 100644 --- a/src/main.py +++ b/src/main.py @@ -17,7 +17,7 @@ from src.chat import ChatMessage, EscalationEvent, chat_manager from src.config import get_settings from src.core.ratelimit import limiter -from src.etl import run_etl_once +from src.etl import diff_etl_output, extract_events_and_sales, run_etl_once, transform_summary from src.exceptions import register_exception_handlers from src.fraud import check_fraud_rules from src.logging_config import ( @@ -372,7 +372,13 @@ def search_events(payload: SearchEventsRequest) -> Any: try: keywords = extract_keywords(payload.query) all_events = get_mock_events() - matching_events = filter_events_by_keywords(all_events, keywords) + matching_events = filter_events_by_keywords( + all_events, + keywords, + min_price=payload.min_price, + max_price=payload.max_price, + max_capacity=payload.max_capacity, + ) event_results = [ EventResult( @@ -664,4 +670,79 @@ async def get_user_conversations(user_id: str) -> ChatUserConversationsResponse: raise except Exception as exc: logger.error("Error getting user conversations: %s", exc) - raise HTTPException(status_code=500, detail="Failed to get user 
conversations") \ No newline at end of file + raise HTTPException(status_code=500, detail="Failed to get user conversations") + + +# --------------------------------------------------------------------------- +# Trending events +# --------------------------------------------------------------------------- + +@app.get("/events/trending", response_model=List[Dict[str, Any]]) +def get_trending_events( + limit: int = Query(10, ge=1, le=100, description="Maximum number of trending events to return"), +) -> Any: + """Return top events ranked by ticket scan velocity in the last 24 hours. + + Results are cached for 10 minutes. + """ + try: + results = analytics_service.get_trending_events(limit=limit, hours=24) + return results + except Exception as exc: + log_error("Failed to get trending events", {"error": str(exc)}) + raise HTTPException(status_code=500, detail=f"Failed to get trending events: {exc}") + + +# --------------------------------------------------------------------------- +# ETL diff (dry-run) — admin only +# --------------------------------------------------------------------------- + +# In-memory store for async ETL diff jobs +_etl_diff_jobs: Dict[str, Any] = {} + + +@app.get("/etl/diff") +def etl_diff(request: Request) -> Any: + """Dry-run ETL diff: show what the next ETL run would load without committing. + + Requires X-Admin-Key header matching ADMIN_API_KEY. + For slow extracts (> 5 s) returns HTTP 202 with a job_id for async polling. + """ + import threading + import time as _time + + api_key = request.headers.get("X-Admin-Key", "") + if api_key != get_settings().ADMIN_API_KEY: + raise HTTPException(status_code=403, detail="Admin access required") + + start = _time.monotonic() + try: + events, sales = extract_events_and_sales() + elapsed = _time.monotonic() - start + + ev_rows, daily_rows = transform_summary( + [e.raw for e in events], [s.raw for s in sales] + ) + + if elapsed > 5.0: + job_id = str(uuid.uuid4()) + + def _run_diff() -> None: + result = diff_etl_output(ev_rows, daily_rows) + _etl_diff_jobs[job_id] = {"status": "complete", "result": result} + + _etl_diff_jobs[job_id] = {"status": "pending"} + threading.Thread(target=_run_diff, daemon=True).start() + return JSONResponse( + status_code=202, + content={"job_id": job_id, "status": "pending"}, + ) + + result = diff_etl_output(ev_rows, daily_rows) + return result + + except HTTPException: + raise + except Exception as exc: + log_error("ETL diff failed", {"error": str(exc)}) + raise HTTPException(status_code=500, detail=f"ETL diff failed: {exc}") \ No newline at end of file diff --git a/src/report_service.py b/src/report_service.py index 205576e..e4e98c9 100644 --- a/src/report_service.py +++ b/src/report_service.py @@ -5,26 +5,17 @@ from pathlib import Path from typing import Any, Dict, List, Optional -from sqlalchemy import Engine, create_engine, text +from sqlalchemy import text -from src.config import get_settings +import src.db as _db logger = logging.getLogger("veritix.report_service") REPORTS_DIR = Path("reports") -def _pg_engine() -> Optional[Engine]: - url = get_settings().DATABASE_URL - if not url: - logger.info("DATABASE_URL not set; skipping Postgres engine creation") - return None - try: - engine = create_engine(url, pool_pre_ping=True) - return engine - except Exception as exc: - logger.error("Failed to create PG engine: %s", exc) - return None +def _pg_engine(): + return _db.get_engine() def _ensure_reports_dir() -> None: diff --git a/src/routers/health.py b/src/routers/health.py index b2854b9..9058433 
diff --git a/src/routers/health.py b/src/routers/health.py
index b2854b9..9058433 100644
--- a/src/routers/health.py
+++ b/src/routers/health.py
@@ -15,7 +15,7 @@
 from fastapi.responses import JSONResponse
 from sqlalchemy import text
 
-from src.analytics.models import get_engine
+import src.db as _db
 from src.config import get_settings
 
 logger = logging.getLogger("veritix.health")
@@ -61,7 +61,9 @@ def health() -> JSONResponse:
 def _check_database() -> str:
     """Return 'ok' if a SELECT 1 succeeds, otherwise 'error'."""
     try:
-        engine = get_engine()
+        engine = _db.get_engine()
+        if engine is None:
+            return "error"
         with engine.connect() as conn:
             conn.execute(text("SELECT 1"))
         return "ok"
@@ -82,6 +84,43 @@ def _check_nest_api() -> str:
         return "error"
 
 
+@router.get("/health/db")
+def health_db() -> JSONResponse:
+    """Check database connectivity and return live pool statistics."""
+    try:
+        engine = _db.get_engine()
+        if engine is None:
+            return JSONResponse(
+                status_code=503,
+                content={
+                    "status": "error",
+                    "reason": "DATABASE_URL not configured",
+                    "timestamp": _now_iso(),
+                },
+            )
+        with engine.connect() as conn:
+            conn.execute(text("SELECT 1"))
+        pool_stats = _db.get_pool_status()
+        return JSONResponse(
+            status_code=200,
+            content={
+                "status": "ok",
+                "pool": pool_stats,
+                "timestamp": _now_iso(),
+            },
+        )
+    except Exception as exc:  # noqa: BLE001
+        logger.warning("Database health check failed: %s", exc)
+        return JSONResponse(
+            status_code=503,
+            content={
+                "status": "error",
+                "reason": str(exc),
+                "timestamp": _now_iso(),
+            },
+        )
+
+
 @router.get("/ready")
 def ready() -> JSONResponse:
     """Readiness probe — checks database and NestJS API connectivity."""
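For reference, the expected payloads from the new probe, exercised with FastAPI's TestClient; the pool numbers are invented and the timestamp is abbreviated.

```python
# Smoke-check of /health/db; output values are illustrative.
from fastapi.testclient import TestClient
from src.main import app

client = TestClient(app)
resp = client.get("/health/db")
# Healthy:   200 {"status": "ok",
#                 "pool": {"pool_size": 5, "checked_in": 5,
#                          "checked_out": 0, "overflow": 0},
#                 "timestamp": "..."}
# No DB URL: 503 {"status": "error", "reason": "DATABASE_URL not configured", ...}
```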
diff --git a/src/search_utils.py b/src/search_utils.py
index 26b0adb..c9300ff 100644
--- a/src/search_utils.py
+++ b/src/search_utils.py
@@ -86,30 +86,64 @@ def extract_keywords(query: str) -> Dict[str, Any]:
         word for word in words if word not in stop_words and len(word) > 2
     ]
 
+    # Price-intent detection
+    # "free" / "affordable" / "cheap" / "budget" → max_price hint
+    # "premium" / "vip" / "luxury" / "expensive" → min_price hint
+    nlp_min_price: Optional[float] = None
+    nlp_max_price: Optional[float] = None
+
+    if any(word in query_lower for word in ["free", "no cost", "zero"]):
+        nlp_max_price = 0.0
+    elif any(word in query_lower for word in ["cheap", "affordable", "budget", "low cost", "low-cost"]):
+        nlp_max_price = 5000.0
+    elif any(word in query_lower for word in ["premium", "vip", "luxury", "expensive", "high-end"]):
+        nlp_min_price = 10000.0
+
     return {
         "event_types": detected_event_types,
         "locations": detected_locations,
         "time_filter": time_filter,
         "keywords": general_keywords,
+        "min_price": nlp_min_price,
+        "max_price": nlp_max_price,
+        "max_capacity": None,
     }
 
 
 def filter_events_by_keywords(
     events: List[Dict[str, Any]],
     keywords: Dict[str, Any],
+    min_price: Optional[float] = None,
+    max_price: Optional[float] = None,
+    max_capacity: Optional[int] = None,
 ) -> List[Dict[str, Any]]:
-    """Filter events based on extracted keywords.
+    """Filter events based on extracted keywords and optional price/capacity filters.
 
     Args:
        events: List of event dictionaries
        keywords: Dictionary of extracted keywords from extract_keywords()
+       min_price: Override/supplement NLP-inferred minimum price filter
+       max_price: Override/supplement NLP-inferred maximum price filter
+       max_capacity: Maximum venue capacity filter
 
     Returns:
        List of matching events
     """
+    # Merge explicit filter params with NLP-inferred values (explicit takes precedence)
+    effective_min_price: Optional[float] = min_price if min_price is not None else keywords.get("min_price")
+    effective_max_price: Optional[float] = max_price if max_price is not None else keywords.get("max_price")
+    effective_max_capacity: Optional[int] = max_capacity if max_capacity is not None else keywords.get("max_capacity")
+
+    has_price_capacity_filter = (
+        effective_min_price is not None
+        or effective_max_price is not None
+        or effective_max_capacity is not None
+    )
+
     # No filters — return everything
     if not any([keywords["event_types"], keywords["locations"],
-                keywords["time_filter"], keywords["keywords"]]):
+                keywords["time_filter"], keywords["keywords"],
+                has_price_capacity_filter]):
         return events
 
     filtered_events: List[Dict[str, Any]] = []
@@ -187,6 +221,32 @@ def filter_events_by_keywords(
         ):
             matches = False
 
+        # Price filters
+        if matches and effective_min_price is not None:
+            try:
+                event_price = float(event.get("price", 0))
+                if event_price < effective_min_price:
+                    matches = False
+            except (TypeError, ValueError):
+                matches = False
+
+        if matches and effective_max_price is not None:
+            try:
+                event_price = float(event.get("price", 0))
+                if event_price > effective_max_price:
+                    matches = False
+            except (TypeError, ValueError):
+                matches = False
+
+        # Capacity filter
+        if matches and effective_max_capacity is not None:
+            try:
+                event_capacity = int(event.get("capacity", 0))
+                if event_capacity > effective_max_capacity:
+                    matches = False
+            except (TypeError, ValueError):
+                matches = False
+
         if matches:
             filtered_events.append(event)
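How the NLP hints and explicit overrides compose in practice; the events are invented, and the survivor claim assumes the same matching behavior the tests below rely on.

```python
# Invented events; explicit max_price overrides the NLP "cheap" hint (5000.0).
from src.search_utils import extract_keywords, filter_events_by_keywords

events = [
    {"id": "a", "name": "Open Mic", "description": "Community night", "event_type": "music",
     "location": "Lagos", "date": "2026-06-01", "price": 500.0, "capacity": 150},
    {"id": "b", "name": "VIP Gala", "description": "Premium dinner", "event_type": "entertainment",
     "location": "Lagos", "date": "2026-06-03", "price": 25000.0, "capacity": 50},
]

kw = extract_keywords("cheap events")          # kw["max_price"] == 5000.0
hits = filter_events_by_keywords(events, kw, max_price=1000.0)
# Expected: only "Open Mic" survives the explicit 1000.0 ceiling.
```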
""" model_config = ConfigDict(extra="forbid") query: str = Field( - ..., - min_length=1, - description="Natural language search query (e.g., 'music events in Lagos this weekend')" + ..., + min_length=1, + description="Natural language search query (e.g., 'music events in Lagos this weekend')", + ) + min_price: Optional[float] = Field( + None, ge=0, description="Minimum ticket price filter (inclusive)" + ) + max_price: Optional[float] = Field( + None, ge=0, description="Maximum ticket price filter (inclusive)" + ) + max_capacity: Optional[int] = Field( + None, ge=1, description="Maximum venue capacity filter (inclusive)" ) diff --git a/tests/test_db_pooling.py b/tests/test_db_pooling.py new file mode 100644 index 0000000..8560adf --- /dev/null +++ b/tests/test_db_pooling.py @@ -0,0 +1,137 @@ +"""Tests for Issue #164: database connection pooling (src/db.py) and /health/db endpoint.""" +import os +import sys + +os.environ.setdefault("SKIP_MODEL_TRAINING", "true") + +import pytest +from unittest.mock import MagicMock, patch + +from src.config import get_settings + + +# --------------------------------------------------------------------------- +# Config: POOL_SIZE / POOL_MAX_OVERFLOW +# --------------------------------------------------------------------------- + + +def test_settings_pool_defaults(): + """POOL_SIZE and POOL_MAX_OVERFLOW are present with correct defaults.""" + settings = get_settings() + assert settings.POOL_SIZE == 5 + assert settings.POOL_MAX_OVERFLOW == 10 + + +# --------------------------------------------------------------------------- +# src.db module +# --------------------------------------------------------------------------- + + +def test_db_get_engine_returns_none_when_no_url(): + """get_engine() returns None gracefully when DATABASE_URL is not configured.""" + import src.db as db_mod + + original_engine = db_mod._engine + try: + db_mod._engine = None + with patch("src.db.get_settings") as mock_settings: + mock_settings.return_value.DATABASE_URL = "" + engine = db_mod.get_engine() + assert engine is None + finally: + db_mod._engine = original_engine + + +def test_db_get_session_returns_none_when_no_engine(): + """get_session() returns None when engine is unavailable.""" + import src.db as db_mod + + with patch.object(db_mod, "get_engine", return_value=None): + session = db_mod.get_session() + assert session is None + + +def test_db_get_pool_status_unavailable_when_no_engine(): + """get_pool_status() returns unavailable status when engine is None.""" + import src.db as db_mod + + with patch.object(db_mod, "get_engine", return_value=None): + status = db_mod.get_pool_status() + assert status["status"] == "unavailable" + + +def test_db_get_pool_status_returns_dict_with_engine(): + """get_pool_status() returns pool stats dict when engine exists.""" + import src.db as db_mod + + mock_pool = MagicMock() + mock_pool.size.return_value = 5 + mock_pool.checkedin.return_value = 4 + mock_pool.checkedout.return_value = 1 + mock_pool.overflow.return_value = 0 + mock_pool.invalid.return_value = 0 + + mock_engine = MagicMock() + mock_engine.pool = mock_pool + + with patch.object(db_mod, "get_engine", return_value=mock_engine): + status = db_mod.get_pool_status() + + assert status["pool_size"] == 5 + assert status["checked_in"] == 4 + assert status["checked_out"] == 1 + assert "overflow" in status + assert "invalid" in status + + +# --------------------------------------------------------------------------- +# /health/db endpoint +# 
diff --git a/tests/test_db_pooling.py b/tests/test_db_pooling.py
new file mode 100644
index 0000000..8560adf
--- /dev/null
+++ b/tests/test_db_pooling.py
@@ -0,0 +1,134 @@
+"""Tests for Issue #164: database connection pooling (src/db.py) and /health/db endpoint."""
+import os
+import sys
+
+os.environ.setdefault("SKIP_MODEL_TRAINING", "true")
+
+import pytest
+from unittest.mock import MagicMock, patch
+
+from src.config import get_settings
+
+
+# ---------------------------------------------------------------------------
+# Config: POOL_SIZE / POOL_MAX_OVERFLOW
+# ---------------------------------------------------------------------------
+
+
+def test_settings_pool_defaults():
+    """POOL_SIZE and POOL_MAX_OVERFLOW are present with correct defaults."""
+    settings = get_settings()
+    assert settings.POOL_SIZE == 5
+    assert settings.POOL_MAX_OVERFLOW == 10
+
+
+# ---------------------------------------------------------------------------
+# src.db module
+# ---------------------------------------------------------------------------
+
+
+def test_db_get_engine_returns_none_when_no_url():
+    """get_engine() returns None gracefully when DATABASE_URL is not configured."""
+    import src.db as db_mod
+
+    original_engine = db_mod._engine
+    try:
+        db_mod._engine = None
+        with patch("src.db.get_settings") as mock_settings:
+            mock_settings.return_value.DATABASE_URL = ""
+            engine = db_mod.get_engine()
+        assert engine is None
+    finally:
+        db_mod._engine = original_engine
+
+
+def test_db_get_session_returns_none_when_no_engine():
+    """get_session() returns None when engine is unavailable."""
+    import src.db as db_mod
+
+    with patch.object(db_mod, "get_engine", return_value=None):
+        session = db_mod.get_session()
+    assert session is None
+
+
+def test_db_get_pool_status_unavailable_when_no_engine():
+    """get_pool_status() returns unavailable status when engine is None."""
+    import src.db as db_mod
+
+    with patch.object(db_mod, "get_engine", return_value=None):
+        status = db_mod.get_pool_status()
+    assert status["status"] == "unavailable"
+
+
+def test_db_get_pool_status_returns_dict_with_engine():
+    """get_pool_status() returns pool stats dict when engine exists."""
+    import src.db as db_mod
+
+    mock_pool = MagicMock()
+    mock_pool.size.return_value = 5
+    mock_pool.checkedin.return_value = 4
+    mock_pool.checkedout.return_value = 1
+    mock_pool.overflow.return_value = 0
+
+    mock_engine = MagicMock()
+    mock_engine.pool = mock_pool
+
+    with patch.object(db_mod, "get_engine", return_value=mock_engine):
+        status = db_mod.get_pool_status()
+
+    assert status["pool_size"] == 5
+    assert status["checked_in"] == 4
+    assert status["checked_out"] == 1
+    assert "overflow" in status
+
+
+# ---------------------------------------------------------------------------
+# /health/db endpoint
+# ---------------------------------------------------------------------------
+
+
+def test_health_db_returns_503_when_no_database():
+    """GET /health/db returns 503 when the database engine is unavailable."""
+    from fastapi.testclient import TestClient
+    from src.main import app
+    import src.db as db_mod
+
+    with patch.object(db_mod, "get_engine", return_value=None):
+        client = TestClient(app)
+        response = client.get("/health/db")
+
+    assert response.status_code == 503
+    data = response.json()
+    assert data["status"] == "error"
+
+
+def test_health_db_returns_200_when_database_ok():
+    """GET /health/db returns 200 with pool stats when DB is reachable."""
+    from fastapi.testclient import TestClient
+    from src.main import app
+    import src.db as db_mod
+    from sqlalchemy.engine import Connection
+
+    mock_conn = MagicMock(spec=Connection)
+    mock_conn.__enter__ = lambda s: s
+    mock_conn.__exit__ = MagicMock(return_value=False)
+
+    mock_pool = MagicMock()
+    mock_pool.size.return_value = 5
+    mock_pool.checkedin.return_value = 5
+    mock_pool.checkedout.return_value = 0
+    mock_pool.overflow.return_value = 0
+
+    mock_engine = MagicMock()
+    mock_engine.pool = mock_pool
+    mock_engine.connect.return_value = mock_conn
+
+    with patch.object(db_mod, "get_engine", return_value=mock_engine):
+        client = TestClient(app)
+        response = client.get("/health/db")
+
+    assert response.status_code == 200
+    data = response.json()
+    assert data["status"] == "ok"
+    assert "pool" in data
+    assert "timestamp" in data
10, "total_revenue": 100.0}], + [{"event_id": "E1", "sale_date": "2026-01-01", "tickets_sold": 5, "revenue": 50.0}], + ) + + assert result["events"]["would_insert"] == 1 + assert result["daily"]["would_insert"] == 1 + + +def test_diff_etl_output_detects_unchanged_rows(): + """diff_etl_output counts unchanged rows correctly.""" + import src.etl as etl_mod + + event_rows_from_db = [("E1", 10, 100.0)] + daily_rows_from_db = [("E1", "2026-01-01", 5, 50.0)] + + call_count = [0] + + def make_mock_result(rows): + m = MagicMock() + m.__iter__ = MagicMock(return_value=iter(rows)) + return m + + mock_conn = MagicMock() + mock_conn.__enter__ = lambda s: s + mock_conn.__exit__ = MagicMock(return_value=False) + + results_queue = [make_mock_result(event_rows_from_db), make_mock_result(daily_rows_from_db)] + + def execute_side_effect(*args, **kwargs): + return results_queue.pop(0) + + mock_conn.execute.side_effect = execute_side_effect + mock_engine = MagicMock() + mock_engine.connect.return_value = mock_conn + + with patch.object(etl_mod, "_pg_engine", return_value=mock_engine): + result = diff_etl_output( + [{"event_id": "E1", "event_name": "A", "total_tickets": 10, "total_revenue": 100.0}], + [{"event_id": "E1", "sale_date": "2026-01-01", "tickets_sold": 5, "revenue": 50.0}], + ) + + assert result["events"]["unchanged"] == 1 + assert result["daily"]["unchanged"] == 1 + + +def test_diff_etl_output_detects_updated_rows(): + """diff_etl_output counts modified rows as would_update.""" + import src.etl as etl_mod + + # DB has E1 with 10 tickets; transform produced 15 tickets → update + event_rows_from_db = [("E1", 10, 100.0)] + daily_rows_from_db = [("E1", "2026-01-01", 5, 50.0)] + + def make_mock_result(rows): + m = MagicMock() + m.__iter__ = MagicMock(return_value=iter(rows)) + return m + + mock_conn = MagicMock() + mock_conn.__enter__ = lambda s: s + mock_conn.__exit__ = MagicMock(return_value=False) + + results_queue = [make_mock_result(event_rows_from_db), make_mock_result(daily_rows_from_db)] + + def execute_side_effect(*args, **kwargs): + return results_queue.pop(0) + + mock_conn.execute.side_effect = execute_side_effect + mock_engine = MagicMock() + mock_engine.connect.return_value = mock_conn + + with patch.object(etl_mod, "_pg_engine", return_value=mock_engine): + result = diff_etl_output( + [{"event_id": "E1", "event_name": "A", "total_tickets": 15, "total_revenue": 150.0}], + [{"event_id": "E1", "sale_date": "2026-01-01", "tickets_sold": 8, "revenue": 80.0}], + ) + + assert result["events"]["would_update"] == 1 + assert result["daily"]["would_update"] == 1 + + +def test_diff_etl_output_empty_inputs(): + """diff_etl_output handles empty event_rows and daily_rows gracefully.""" + import src.etl as etl_mod + + with patch.object(etl_mod, "_pg_engine", return_value=None): + result = diff_etl_output([], []) + + assert result["events"]["would_insert"] == 0 + assert result["daily"]["would_insert"] == 0 + + +# --------------------------------------------------------------------------- +# GET /etl/diff API tests +# --------------------------------------------------------------------------- + + +def test_etl_diff_requires_admin_key(): + """GET /etl/diff returns 403 when admin key is missing.""" + response = client.get("/etl/diff") + assert response.status_code == 403 + + +def test_etl_diff_returns_403_with_wrong_key(): + """GET /etl/diff returns 403 with incorrect admin key.""" + response = client.get("/etl/diff", headers={"X-Admin-Key": "wrong_key"}) + assert response.status_code == 403 + + +def 
diff --git a/tests/test_search_filters.py b/tests/test_search_filters.py
new file mode 100644
index 0000000..784ca8e
--- /dev/null
+++ b/tests/test_search_filters.py
@@ -0,0 +1,161 @@
+"""Tests for Issue #158: price range and capacity filters in /search-events."""
+import os
+
+os.environ.setdefault("SKIP_MODEL_TRAINING", "true")
+
+from fastapi.testclient import TestClient
+from src.main import app
+from src.search_utils import extract_keywords, filter_events_by_keywords
+
+client = TestClient(app)
+
+# Sample events for unit-level filter tests
+SAMPLE_EVENTS = [
+    {"id": "e1", "name": "Free Concert", "description": "Free music", "event_type": "music",
+     "location": "Lagos", "date": "2026-06-01", "price": 0.0, "capacity": 200},
+    {"id": "e2", "name": "Budget Seminar", "description": "Tech workshop", "event_type": "tech",
+     "location": "Abuja", "date": "2026-06-02", "price": 3000.0, "capacity": 100},
+    {"id": "e3", "name": "VIP Gala", "description": "Premium dinner", "event_type": "entertainment",
+     "location": "Lagos", "date": "2026-06-03", "price": 25000.0, "capacity": 50},
+    {"id": "e4", "name": "Sports Day", "description": "Marathon", "event_type": "sports",
+     "location": "Kano", "date": "2026-06-04", "price": 5000.0, "capacity": 1000},
+]
+
+
+# ---------------------------------------------------------------------------
+# NLP price-intent extraction
+# ---------------------------------------------------------------------------
+
+def test_extract_keywords_free_sets_max_price_zero():
+    kw = extract_keywords("free music events")
+    assert kw["max_price"] == 0.0
+
+
+def test_extract_keywords_cheap_sets_max_price():
+    kw = extract_keywords("cheap events in Lagos")
+    assert kw["max_price"] == 5000.0
+
+
+def test_extract_keywords_affordable_sets_max_price():
+    kw = extract_keywords("affordable concert this weekend")
+    assert kw["max_price"] == 5000.0
+
+
+def test_extract_keywords_premium_sets_min_price():
+    kw = extract_keywords("premium VIP gala")
+    assert kw["min_price"] == 10000.0
+
+
+def test_extract_keywords_no_price_intent():
+    kw = extract_keywords("music events in Lagos")
+    assert kw["min_price"] is None
+    assert kw["max_price"] is None
+
+
+def test_extract_keywords_has_max_capacity_none_by_default():
+    kw = extract_keywords("tech conference")
+    assert kw["max_capacity"] is None
+
+
+# ---------------------------------------------------------------------------
+# filter_events_by_keywords — price filters
+# ---------------------------------------------------------------------------
+
+def test_filter_max_price_zero_returns_free_events():
+    kw = extract_keywords("free events")
+    results = filter_events_by_keywords(SAMPLE_EVENTS, kw)
+    assert all(e["price"] == 0.0 for e in results)
+    assert any(e["id"] == "e1" for e in results)
+
+
+def test_filter_max_price_explicit_overrides_nlp():
+    kw = extract_keywords("music events")  # no NLP price hint
+    results = filter_events_by_keywords(SAMPLE_EVENTS, kw, max_price=5000.0)
+    for e in results:
+        assert e["price"] <= 5000.0
+
+
+def test_filter_min_price_explicit():
+    kw = extract_keywords("events")
+    results = filter_events_by_keywords(SAMPLE_EVENTS, kw, min_price=10000.0)
+    assert all(e["price"] >= 10000.0 for e in results)
+    assert any(e["id"] == "e3" for e in results)
+
+
+def test_filter_min_and_max_price_combined():
+    kw = extract_keywords("events")
+    results = filter_events_by_keywords(SAMPLE_EVENTS, kw, min_price=2000.0, max_price=6000.0)
+    for e in results:
+        assert 2000.0 <= e["price"] <= 6000.0
+
+
+def test_filter_max_capacity():
+    kw = extract_keywords("events")
+    results = filter_events_by_keywords(SAMPLE_EVENTS, kw, max_capacity=100)
+    assert all(e["capacity"] <= 100 for e in results)
+
+
+def test_filter_price_and_capacity_combined():
+    kw = extract_keywords("events")
+    results = filter_events_by_keywords(SAMPLE_EVENTS, kw, max_price=10000.0, max_capacity=200)
+    for e in results:
+        assert e["price"] <= 10000.0
+        assert e["capacity"] <= 200
+
+
+# ---------------------------------------------------------------------------
+# /search-events API — price/capacity filter params
+# ---------------------------------------------------------------------------
+
+def test_search_events_max_price_filter():
+    payload = {"query": "events", "max_price": 5000.0}
+    response = client.post("/search-events", json=payload)
+    assert response.status_code == 200
+    data = response.json()
+    for event in data["results"]:
+        assert event["price"] <= 5000.0
+
+
+def test_search_events_min_price_filter():
+    payload = {"query": "events", "min_price": 8000.0}
+    response = client.post("/search-events", json=payload)
+    assert response.status_code == 200
+    data = response.json()
+    for event in data["results"]:
+        assert event["price"] >= 8000.0
+
+
+def test_search_events_max_capacity_filter():
+    payload = {"query": "events", "max_capacity": 500}
+    response = client.post("/search-events", json=payload)
+    assert response.status_code == 200
+    data = response.json()
+    for event in data["results"]:
+        assert event["capacity"] <= 500
+
+
+def test_search_events_nlp_cheap_keyword():
+    payload = {"query": "cheap music events in Lagos"}
+    response = client.post("/search-events", json=payload)
+    assert response.status_code == 200
+    data = response.json()
+    for event in data["results"]:
+        assert event["price"] <= 5000.0
+
+
+def test_search_events_price_filters_in_keywords_extracted():
+    payload = {"query": "events", "max_price": 3000.0, "max_capacity": 200}
+    response = client.post("/search-events", json=payload)
+    assert response.status_code == 200
+
+
+def test_search_events_invalid_max_price_negative():
+    payload = {"query": "events", "max_price": -1.0}
+    response = client.post("/search-events", json=payload)
+    assert response.status_code == 422
+
+
+def test_search_events_invalid_max_capacity_zero():
+    payload = {"query": "events", "max_capacity": 0}
+    response = client.post("/search-events", json=payload)
+    assert response.status_code == 422
diff --git a/tests/test_trending_events.py b/tests/test_trending_events.py
new file mode 100644
index 0000000..004b322
--- /dev/null
+++ b/tests/test_trending_events.py
@@ -0,0 +1,167 @@
+"""Tests for Issue #160: GET /events/trending endpoint."""
+import os
+
+os.environ.setdefault("SKIP_MODEL_TRAINING", "true")
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+from fastapi.testclient import TestClient
+from src.main import app
+
+client = TestClient(app)
+
+
+# ---------------------------------------------------------------------------
+# AnalyticsService.get_trending_events unit tests
+# ---------------------------------------------------------------------------
+
+
+def test_get_trending_events_returns_empty_when_no_db():
+    """get_trending_events returns [] when DB engine is unavailable."""
+    from src.analytics.service import AnalyticsService
+    import src.db as db_mod
+
+    svc = AnalyticsService()
+    with patch.object(db_mod, "get_engine", return_value=None):
+        results = svc.get_trending_events(limit=5, hours=24)
+    assert results == []
+
+
+def test_get_trending_events_queries_and_returns_rows():
+    """get_trending_events returns rows from a mocked DB."""
+    from src.analytics.service import AnalyticsService
+    import src.analytics.service as svc_mod
+    import src.db as db_mod
+
+    # Reset cache
+    svc_mod._trending_cache = None
+
+    mock_row = MagicMock()
+    mock_row.__getitem__ = lambda self, i: ["event_001", "Event One", 42][i]
+
+    mock_result = MagicMock()
+    mock_result.__iter__ = MagicMock(return_value=iter([mock_row]))
+
+    mock_conn = MagicMock()
+    mock_conn.__enter__ = lambda s: s
+    mock_conn.__exit__ = MagicMock(return_value=False)
+    mock_conn.execute.return_value = mock_result
+
+    mock_engine = MagicMock()
+    mock_engine.connect.return_value = mock_conn
+
+    svc = AnalyticsService()
+    with patch.object(db_mod, "get_engine", return_value=mock_engine):
+        results = svc.get_trending_events(limit=10, hours=24)
+
+    assert isinstance(results, list)
+
+
+def test_get_trending_events_respects_limit():
+    """get_trending_events honours the limit parameter."""
+    from src.analytics.service import AnalyticsService
+    import src.analytics.service as svc_mod
+    import src.db as db_mod
+
+    svc_mod._trending_cache = None
+
+    # Build 20 mock rows
+    rows = []
+    for i in range(20):
+        row = MagicMock()
+        row.__getitem__ = lambda self, idx, _i=i: [f"evt_{_i:02}", f"Event {_i}", 20 - _i][idx]
+        rows.append(row)
+
+    mock_result = MagicMock()
+    mock_result.__iter__ = MagicMock(return_value=iter(rows))
+
+    mock_conn = MagicMock()
+    mock_conn.__enter__ = lambda s: s
+    mock_conn.__exit__ = MagicMock(return_value=False)
+    mock_conn.execute.return_value = mock_result
+
+    mock_engine = MagicMock()
+    mock_engine.connect.return_value = mock_conn
+
+    svc = AnalyticsService()
+    with patch.object(db_mod, "get_engine", return_value=mock_engine):
+        results = svc.get_trending_events(limit=5, hours=24)
+
+    assert len(results) <= 5
+
+
+# ---------------------------------------------------------------------------
+# GET /events/trending API tests
+# ---------------------------------------------------------------------------
+
+
+def test_trending_events_endpoint_returns_200_with_empty_db():
+    """GET /events/trending returns 200 with empty list when DB is unavailable."""
+    import src.db as db_mod
+
+    with patch.object(db_mod, "get_engine", return_value=None):
+        response = client.get("/events/trending")
+
+    assert response.status_code == 200
+    data = response.json()
+    assert isinstance(data, list)
+    assert data == []
+
+
+def test_trending_events_default_limit():
+    """GET /events/trending defaults to limit=10."""
+    import src.db as db_mod
+    import src.analytics.service as svc_mod
+
+    svc_mod._trending_cache = None
+
+    with patch.object(db_mod, "get_engine", return_value=None):
+        response = client.get("/events/trending")
+
+    assert response.status_code == 200
+
+
+def test_trending_events_custom_limit():
+    """GET /events/trending?limit=3 is accepted."""
+    import src.db as db_mod
+    import src.analytics.service as svc_mod
+
+    svc_mod._trending_cache = None
+
+    with patch.object(db_mod, "get_engine", return_value=None):
+        response = client.get("/events/trending?limit=3")
+
+    assert response.status_code == 200
+
+
+def test_trending_events_limit_too_large_rejected():
+    """GET /events/trending?limit=200 is rejected (>100)."""
+    response = client.get("/events/trending?limit=200")
+    assert response.status_code == 422
+
+
+def test_trending_events_limit_zero_rejected():
+    """GET /events/trending?limit=0 is rejected (<1)."""
+    response = client.get("/events/trending?limit=0")
+    assert response.status_code == 422
+
+
+def test_trending_events_cache_is_used():
+    """Second call uses cached results without hitting the DB again."""
+    import src.analytics.service as svc_mod
+    import src.db as db_mod
+
+    cached_data = [{"event_id": "cached_evt", "event_name": "Cached", "scan_count": 99, "window_hours": 24}]
+    import time
+    svc_mod._trending_cache = (cached_data, time.monotonic() + 600)
+
+    with patch.object(db_mod, "get_engine", side_effect=AssertionError("Should not hit DB")):
+        response = client.get("/events/trending?limit=10")
+
+    assert response.status_code == 200
+    data = response.json()
+    assert any(e["event_id"] == "cached_evt" for e in data)
+
+    # Cleanup
+    svc_mod._trending_cache = None