diff --git a/README.md b/README.md index 364058f..8629362 100644 --- a/README.md +++ b/README.md @@ -213,7 +213,24 @@ docker compose --profile dev up --build ```bash cd backend -pytest -q +python -m pytest -q +``` + +### 压力测试(后端) + +无需启动 uvicorn(默认走 ASGI in-process): + +```bash +cd backend +python scripts/stress_test.py --scenario health --requests 5000 --concurrency 100 +python scripts/stress_test.py --scenario me --users 50 --requests 5000 --concurrency 100 +``` + +如需对已启动的服务做压测,可指定 `--url`: + +```bash +cd backend +python scripts/stress_test.py --url http://127.0.0.1:8000 --scenario health --requests 10000 --concurrency 200 ``` 前端可使用: diff --git a/backend/app/db/session.py b/backend/app/db/session.py index eeac1f4..54c2f51 100644 --- a/backend/app/db/session.py +++ b/backend/app/db/session.py @@ -1,16 +1,34 @@ """Database session and engine configuration.""" +from __future__ import annotations + +from functools import lru_cache + from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine from app.core.config import get_settings settings = get_settings() -engine: AsyncEngine = create_async_engine(settings.async_database_uri, echo=False) -SessionLocal = async_sessionmaker(bind=engine, expire_on_commit=False, class_=AsyncSession) + +@lru_cache +def get_engine() -> AsyncEngine: + """Create (and memoize) the async engine. + + Lazily initializing the engine keeps import-time side effects minimal and + allows test suites to override DB dependencies without requiring every + production DB driver to be installed. + """ + + return create_async_engine(settings.async_database_uri, echo=False) + + +@lru_cache +def get_sessionmaker() -> async_sessionmaker[AsyncSession]: + return async_sessionmaker(bind=get_engine(), expire_on_commit=False, class_=AsyncSession) async def get_db() -> AsyncSession: """Yield an async database session for request lifecycle management.""" - async with SessionLocal() as session: + async with get_sessionmaker()() as session: yield session diff --git a/backend/app/db/utils.py b/backend/app/db/utils.py index 3106da6..9181ff9 100644 --- a/backend/app/db/utils.py +++ b/backend/app/db/utils.py @@ -23,7 +23,7 @@ def ensure_database_exists(settings: Settings, retries: int = 5, delay: float = ) attempt = 0 - while attempt <= retries: + while True: engine = create_engine(default_uri, isolation_level="AUTOCOMMIT") try: with engine.connect() as connection: @@ -37,7 +37,7 @@ def ensure_database_exists(settings: Settings, retries: int = 5, delay: float = logger.info("Created database %s", settings.postgres_db) else: logger.debug("Database %s already present", settings.postgres_db) - return + break except SQLAlchemyError as exc: attempt += 1 logger.warning("Database availability check failed (attempt %s/%s): %s", attempt, retries, exc) diff --git a/backend/pytest.ini b/backend/pytest.ini new file mode 100644 index 0000000..2184bc5 --- /dev/null +++ b/backend/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +markers = + asyncio: async test executed via event loop diff --git a/backend/scripts/stress_test.py b/backend/scripts/stress_test.py new file mode 100644 index 0000000..d0c3273 --- /dev/null +++ b/backend/scripts/stress_test.py @@ -0,0 +1,259 @@ +"""Simple backend stress/load test runner. + +Default mode uses ASGITransport so you don't need to start uvicorn. +It is intentionally lightweight (no extra dependencies). +""" + +from __future__ import annotations + +import argparse +import asyncio +import contextlib +import json +import math +import os +import random +import shutil +import statistics +import tempfile +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, AsyncIterator +from uuid import uuid4 + +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine + +from app.db.session import get_db +from app.models.base import Base + + +@dataclass +class Result: + ok: int + errors: int + status_counts: dict[int, int] + latencies_ms: list[float] + elapsed_s: float + + +def _percentile(sorted_values: list[float], p: float) -> float: + if not sorted_values: + return 0.0 + if p <= 0: + return float(sorted_values[0]) + if p >= 100: + return float(sorted_values[-1]) + k = (len(sorted_values) - 1) * (p / 100.0) + f = math.floor(k) + c = math.ceil(k) + if f == c: + return float(sorted_values[int(k)]) + d0 = sorted_values[f] * (c - k) + d1 = sorted_values[c] * (k - f) + return float(d0 + d1) + + +def _print_summary(result: Result) -> None: + total = result.ok + result.errors + rps = total / result.elapsed_s if result.elapsed_s > 0 else 0.0 + lat_sorted = sorted(result.latencies_ms) + avg = statistics.fmean(lat_sorted) if lat_sorted else 0.0 + p50 = _percentile(lat_sorted, 50) + p95 = _percentile(lat_sorted, 95) + p99 = _percentile(lat_sorted, 99) + worst = lat_sorted[-1] if lat_sorted else 0.0 + + print(f"total={total} ok={result.ok} errors={result.errors} elapsed_s={result.elapsed_s:.3f} rps={rps:.1f}") + print(f"latency_ms avg={avg:.2f} p50={p50:.2f} p95={p95:.2f} p99={p99:.2f} max={worst:.2f}") + codes = " ".join(f"{code}:{count}" for code, count in sorted(result.status_counts.items())) + print(f"status_codes {codes}") + + +async def _setup_sqlite_db(db_path: Path) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: + # Ensure models are imported for metadata registration + import app.models # noqa: F401 + + url = f"sqlite+aiosqlite:///{db_path}" + engine = create_async_engine(url, future=True) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + return engine, async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) + + +async def _build_asgi_client(session_factory: async_sessionmaker[AsyncSession]) -> AsyncClient: + from app.main import app + + async def override_get_db() -> AsyncIterator[AsyncSession]: + async with session_factory() as session: + yield session + + app.dependency_overrides[get_db] = override_get_db + return AsyncClient(transport=ASGITransport(app=app), base_url="http://test") + + +async def _register_and_login(client: AsyncClient, email: str, password: str) -> str: + resp = await client.post("/api/v1/users", json={"email": email, "password": password}) + # allow re-runs with same user count / db + if resp.status_code not in (201, 409): + raise RuntimeError(f"register failed: {resp.status_code} {resp.text}") + + token_resp = await client.post( + "/api/v1/auth/token", + data={"username": email, "password": password}, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + if token_resp.status_code != 200: + raise RuntimeError(f"login failed: {token_resp.status_code} {token_resp.text}") + return token_resp.json()["access_token"] + + +async def _run_fixed_requests( + client: AsyncClient, + request_count: int, + concurrency: int, + make_request, +) -> Result: + sem = asyncio.Semaphore(concurrency) + latencies_ms: list[float] = [0.0] * request_count + status_counts: dict[int, int] = {} + ok = 0 + errors = 0 + lock = asyncio.Lock() + + async def one(i: int) -> None: + nonlocal ok, errors + async with sem: + start = time.perf_counter() + try: + status = await make_request(client) + except Exception: + status = 0 + elapsed_ms = (time.perf_counter() - start) * 1000 + latencies_ms[i] = elapsed_ms + async with lock: + status_counts[status] = status_counts.get(status, 0) + 1 + if 200 <= status < 400: + ok += 1 + else: + errors += 1 + + started = time.perf_counter() + async with asyncio.TaskGroup() as tg: + for i in range(request_count): + tg.create_task(one(i)) + elapsed_s = time.perf_counter() - started + return Result(ok=ok, errors=errors, status_counts=status_counts, latencies_ms=latencies_ms, elapsed_s=elapsed_s) + + +async def main() -> int: + parser = argparse.ArgumentParser(description="Backend stress/load test (ASGI in-process by default).") + parser.add_argument("--scenario", choices=("health", "me"), default="health") + parser.add_argument("--requests", type=int, default=2000) + parser.add_argument("--concurrency", type=int, default=50) + parser.add_argument("--warmup", type=int, default=50) + parser.add_argument("--users", type=int, default=50, help="Only for scenario=me") + parser.add_argument("--password", default="StrongPass1!") + parser.add_argument("--json", dest="json_output", action="store_true", help="Print JSON summary") + parser.add_argument( + "--url", + default="", + help="Live server base URL (e.g. http://127.0.0.1:8000). If set, requests go over HTTP instead of ASGI.", + ) + args = parser.parse_args() + + if args.requests <= 0 or args.concurrency <= 0: + raise SystemExit("--requests/--concurrency must be > 0") + + if args.scenario == "me" and args.users <= 0: + raise SystemExit("--users must be > 0 for scenario=me") + + # Build client + tmpdir = Path(tempfile.mkdtemp(prefix="ir-stress-")) + db_path = tmpdir / "stress.sqlite3" + client: AsyncClient | None = None + engine: AsyncEngine | None = None + using_asgi = not bool(args.url) + + try: + if args.url: + # Live server mode (requires backend already running) + client = AsyncClient(base_url=args.url.rstrip("/")) + else: + # Redirect common ML config caches to a writable directory to avoid sandbox issues. + (tmpdir / "mpl").mkdir(parents=True, exist_ok=True) + os.environ.setdefault("MPLCONFIGDIR", str(tmpdir / "mpl")) + os.environ.setdefault("YOLO_CONFIG_DIR", str(tmpdir / "ultralytics")) + + engine, session_factory = await _setup_sqlite_db(db_path) + client = await _build_asgi_client(session_factory) + + # Scenario setup + headers_list: list[dict[str, str]] = [] + if args.scenario == "me": + # pre-create tokens so hot path is only authenticated GET + for _ in range(args.users): + email = f"stress-{uuid4().hex[:12]}@example.com" + token = await _register_and_login(client, email, args.password) + headers_list.append({"Authorization": f"Bearer {token}"}) + + # Warmup + async def warmup_request(c: AsyncClient) -> int: + if args.scenario == "health": + return (await c.get("/health")).status_code + header = random.choice(headers_list) + return (await c.get("/api/v1/users/me", headers=header)).status_code + + for _ in range(args.warmup): + await warmup_request(client) + + # Main load + async def make_request(c: AsyncClient) -> int: + if args.scenario == "health": + return (await c.get("/health")).status_code + header = random.choice(headers_list) + return (await c.get("/api/v1/users/me", headers=header)).status_code + + result = await _run_fixed_requests(client, args.requests, args.concurrency, make_request) + + if args.json_output: + payload = { + "scenario": args.scenario, + "requests": args.requests, + "concurrency": args.concurrency, + "ok": result.ok, + "errors": result.errors, + "elapsed_s": result.elapsed_s, + "rps": (args.requests / result.elapsed_s) if result.elapsed_s else 0.0, + "status_counts": result.status_counts, + "latency_ms": { + "avg": statistics.fmean(result.latencies_ms) if result.latencies_ms else 0.0, + "p50": _percentile(sorted(result.latencies_ms), 50), + "p95": _percentile(sorted(result.latencies_ms), 95), + "p99": _percentile(sorted(result.latencies_ms), 99), + "max": max(result.latencies_ms) if result.latencies_ms else 0.0, + }, + } + print(json.dumps(payload, ensure_ascii=False, indent=2)) + else: + _print_summary(result) + + return 0 + finally: + if client is not None: + await client.aclose() + if engine is not None: + with contextlib.suppress(Exception): + await engine.dispose() + if using_asgi: + with contextlib.suppress(Exception): + from app.main import app + + app.dependency_overrides.clear() + with contextlib.suppress(Exception): + shutil.rmtree(tmpdir, ignore_errors=True) + + +if __name__ == "__main__": + raise SystemExit(asyncio.run(main())) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index df6c422..fd3ffc7 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -4,11 +4,11 @@ import asyncio import os import shutil +import inspect from pathlib import Path from typing import AsyncIterator, Generator import pytest -import pytest_asyncio # type: ignore[import-not-found] from httpx import ASGITransport, AsyncClient from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine @@ -32,31 +32,70 @@ def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]: loop.close() -@pytest_asyncio.fixture() -async def async_client() -> AsyncIterator[AsyncClient]: - """Yield an AsyncClient wired to an isolated SQLite database.""" +@pytest.fixture() +def async_client(event_loop: asyncio.AbstractEventLoop) -> AsyncIterator[AsyncClient]: + """Yield an AsyncClient wired to an isolated SQLite database. + + This fixture is intentionally implemented without pytest-asyncio so the + test suite can run with vanilla pytest. + """ engine = create_async_engine(TEST_DATABASE_URL, future=True, connect_args={"uri": True}) - if TEST_MEDIA_ROOT.exists(): - shutil.rmtree(TEST_MEDIA_ROOT) - TEST_MEDIA_ROOT.mkdir(parents=True, exist_ok=True) + async def setup() -> AsyncClient: + if TEST_MEDIA_ROOT.exists(): + shutil.rmtree(TEST_MEDIA_ROOT) + TEST_MEDIA_ROOT.mkdir(parents=True, exist_ok=True) - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) - session_factory = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) + session_factory = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) - async def override_get_db() -> AsyncIterator[AsyncSession]: - async with session_factory() as session: - yield session + async def override_get_db() -> AsyncIterator[AsyncSession]: + async with session_factory() as session: + yield session - app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_db] = override_get_db - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as client: - yield client + transport = ASGITransport(app=app) + return AsyncClient(transport=transport, base_url="http://test") - app.dependency_overrides.clear() - await engine.dispose() - shutil.rmtree(TEST_MEDIA_ROOT, ignore_errors=True) \ No newline at end of file + client = event_loop.run_until_complete(setup()) + + try: + yield client + finally: + async def teardown() -> None: + await client.aclose() + app.dependency_overrides.clear() + await engine.dispose() + shutil.rmtree(TEST_MEDIA_ROOT, ignore_errors=True) + + event_loop.run_until_complete(teardown()) + + +def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> bool | None: + """Run async tests without pytest-asyncio.""" + + if not inspect.iscoroutinefunction(pyfuncitem.obj): + return None + + argnames = getattr(pyfuncitem, "_fixtureinfo", None) + if argnames is not None: + names = list(pyfuncitem._fixtureinfo.argnames) # type: ignore[attr-defined] + test_kwargs = {name: pyfuncitem.funcargs[name] for name in names} + else: + test_kwargs = dict(pyfuncitem.funcargs) + + loop = pyfuncitem.funcargs.get("event_loop") + if loop is None: + loop = asyncio.new_event_loop() + try: + loop.run_until_complete(pyfuncitem.obj(**test_kwargs)) + finally: + loop.close() + return True + + loop.run_until_complete(pyfuncitem.obj(**test_kwargs)) + return True diff --git a/backend/tests/test_config.py b/backend/tests/test_config.py new file mode 100644 index 0000000..5361cca --- /dev/null +++ b/backend/tests/test_config.py @@ -0,0 +1,29 @@ +"""Unit tests for configuration helpers.""" +from __future__ import annotations + +from pathlib import Path + +from app.core.config import Settings, get_settings + + +def test_settings_redis_url() -> None: + settings = Settings(redis_host="127.0.0.1", redis_port=6380, redis_db=2) + assert settings.redis_url == "redis://127.0.0.1:6380/2" + + +def test_settings_frontend_oauth_redirect_url_normalizes_path() -> None: + settings = Settings(frontend_app_url="http://localhost:3000/", frontend_oauth_redirect_path="auth/sso") + assert settings.frontend_oauth_redirect_url == "http://localhost:3000/auth/sso" + + +def test_settings_media_path_prefers_env_override(monkeypatch) -> None: + override = Path(__file__).resolve().parent / "tmp-media" + monkeypatch.setenv("MEDIA_ROOT", str(override)) + settings = Settings(media_root="/should/not/be/used") + assert settings.media_path == override + + +def test_get_settings_appends_wildcard_cors_in_local_env() -> None: + get_settings.cache_clear() + settings = get_settings() + assert "*" in settings.cors_origins diff --git a/backend/tests/test_db_utils.py b/backend/tests/test_db_utils.py new file mode 100644 index 0000000..07f213c --- /dev/null +++ b/backend/tests/test_db_utils.py @@ -0,0 +1,224 @@ +"""Unit tests for database utility helpers.""" +from __future__ import annotations + +import pytest +from sqlalchemy.exc import SQLAlchemyError + +from app.core.config import Settings +from app.db import utils as db_utils + + +class _ScalarResult: + def __init__(self, value) -> None: + self._value = value + + def scalar(self): + return self._value + + +class _FakeConnection: + def __init__(self, exists: bool, executed: list[str]) -> None: + self._exists = exists + self._executed = executed + self._select_calls = 0 + + def execute(self, statement, params=None): # noqa: ANN001 + sql = str(statement) + self._executed.append(sql) + if "SELECT 1 FROM pg_database" in sql: + self._select_calls += 1 + return _ScalarResult(1 if self._exists else None) + return _ScalarResult(None) + + +class _ConnectCtx: + def __init__(self, connection: _FakeConnection | None, error: Exception | None) -> None: + self._connection = connection + self._error = error + + def __enter__(self) -> _FakeConnection: + if self._error is not None: + raise self._error + assert self._connection is not None + return self._connection + + def __exit__(self, exc_type, exc, tb) -> None: # noqa: ANN001 + return None + + +class _FakeEngine: + def __init__(self, connect_ctx: _ConnectCtx, disposed: list[bool]) -> None: + self._connect_ctx = connect_ctx + self._disposed = disposed + + def connect(self) -> _ConnectCtx: + return self._connect_ctx + + def dispose(self) -> None: + self._disposed.append(True) + + +class _FakeCursor: + def __init__(self, executed: list[str]) -> None: + self._executed = executed + + def execute(self, sql: str) -> None: + self._executed.append(sql) + + def __enter__(self) -> "_FakeCursor": + return self + + def __exit__(self, exc_type, exc, tb) -> None: # noqa: ANN001 + return None + + +class _FakePsycopgConn: + def __init__(self, executed: list[str]) -> None: + self.autocommit = False + self._executed = executed + + def cursor(self) -> _FakeCursor: + return _FakeCursor(self._executed) + + def __enter__(self) -> "_FakePsycopgConn": + return self + + def __exit__(self, exc_type, exc, tb) -> None: # noqa: ANN001 + return None + + +class _WarmupConn: + def __init__(self, closed: list[bool]) -> None: + self._closed = closed + + def close(self) -> None: + self._closed.append(True) + + +def _settings() -> Settings: + return Settings( + postgres_user="demo_user", + postgres_password="demo_pass", + postgres_host="localhost", + postgres_port=5432, + postgres_db="demo_db", + ) + + +def test_ensure_database_exists_skips_create_when_present(monkeypatch: pytest.MonkeyPatch) -> None: + executed: list[str] = [] + disposed: list[bool] = [] + grants: list[str] = [] + warmup_closed: list[bool] = [] + + engine = _FakeEngine(_ConnectCtx(_FakeConnection(exists=True, executed=executed), None), disposed) + + def fake_create_engine(*args, **kwargs): # noqa: ANN001 + return engine + + def fake_connect(**kwargs): # noqa: ANN001 + if kwargs.get("dbname") == "postgres": + return _FakePsycopgConn(grants) + return _WarmupConn(warmup_closed) + + monkeypatch.setattr(db_utils, "create_engine", fake_create_engine) + monkeypatch.setattr(db_utils.psycopg, "connect", fake_connect) + + db_utils.ensure_database_exists(_settings()) + + assert any("SELECT 1 FROM pg_database" in sql for sql in executed) + assert not any("CREATE DATABASE" in sql for sql in executed) + assert disposed == [True] + assert any('GRANT CONNECT ON DATABASE "demo_db"' in sql for sql in grants) + assert warmup_closed == [True] + + +def test_ensure_database_exists_creates_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + executed: list[str] = [] + disposed: list[bool] = [] + grants: list[str] = [] + warmup_closed: list[bool] = [] + + engine = _FakeEngine(_ConnectCtx(_FakeConnection(exists=False, executed=executed), None), disposed) + + def fake_create_engine(*args, **kwargs): # noqa: ANN001 + return engine + + def fake_connect(**kwargs): # noqa: ANN001 + if kwargs.get("dbname") == "postgres": + return _FakePsycopgConn(grants) + return _WarmupConn(warmup_closed) + + monkeypatch.setattr(db_utils, "create_engine", fake_create_engine) + monkeypatch.setattr(db_utils.psycopg, "connect", fake_connect) + + db_utils.ensure_database_exists(_settings()) + + assert any("CREATE DATABASE" in sql for sql in executed) + assert disposed == [True] + assert any('GRANT ALL PRIVILEGES ON DATABASE "demo_db"' in sql for sql in grants) + assert warmup_closed == [True] + + +def test_ensure_database_exists_retries_then_succeeds(monkeypatch: pytest.MonkeyPatch) -> None: + executed: list[str] = [] + disposed: list[bool] = [] + sleep_calls: list[float] = [] + grants: list[str] = [] + warmup_closed: list[bool] = [] + + attempt = {"count": 0} + + def fake_create_engine(*args, **kwargs): # noqa: ANN001 + attempt["count"] += 1 + if attempt["count"] < 3: + ctx = _ConnectCtx(None, SQLAlchemyError("not ready")) + else: + ctx = _ConnectCtx(_FakeConnection(exists=True, executed=executed), None) + return _FakeEngine(ctx, disposed) + + def fake_sleep(delay: float) -> None: + sleep_calls.append(delay) + + def fake_connect(**kwargs): # noqa: ANN001 + if kwargs.get("dbname") == "postgres": + return _FakePsycopgConn(grants) + return _WarmupConn(warmup_closed) + + monkeypatch.setattr(db_utils, "create_engine", fake_create_engine) + monkeypatch.setattr(db_utils.time, "sleep", fake_sleep) + monkeypatch.setattr(db_utils.psycopg, "connect", fake_connect) + + db_utils.ensure_database_exists(_settings(), retries=5, delay=0.25) + + assert sleep_calls == [0.25, 0.25] + assert disposed == [True, True, True] + assert warmup_closed == [True] + + +def test_ensure_database_exists_raises_after_retries(monkeypatch: pytest.MonkeyPatch) -> None: + disposed: list[bool] = [] + sleep_calls: list[float] = [] + psycopg_calls: list[bool] = [] + + def fake_create_engine(*args, **kwargs): # noqa: ANN001 + ctx = _ConnectCtx(None, SQLAlchemyError("still down")) + return _FakeEngine(ctx, disposed) + + def fake_sleep(delay: float) -> None: + sleep_calls.append(delay) + + def fake_connect(**kwargs): # noqa: ANN001 + psycopg_calls.append(True) + raise AssertionError("psycopg.connect should not be called when ensure_database_exists raises") + + monkeypatch.setattr(db_utils, "create_engine", fake_create_engine) + monkeypatch.setattr(db_utils.time, "sleep", fake_sleep) + monkeypatch.setattr(db_utils.psycopg, "connect", fake_connect) + + with pytest.raises(SQLAlchemyError): + db_utils.ensure_database_exists(_settings(), retries=1, delay=0.1) + + assert sleep_calls == [0.1] + assert disposed == [True, True] + assert psycopg_calls == [] diff --git a/backend/tests/test_security.py b/backend/tests/test_security.py new file mode 100644 index 0000000..7582d52 --- /dev/null +++ b/backend/tests/test_security.py @@ -0,0 +1,45 @@ +"""Security regression tests (auth/JWT).""" +from __future__ import annotations + +import base64 +import json +from datetime import timedelta +from time import time + +import pytest +from httpx import AsyncClient + +from app.core.auth import create_access_token + +pytestmark = pytest.mark.asyncio + + +def _b64url(data: bytes) -> str: + return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") + + +def _none_alg_token(subject: str, expires_in_seconds: int = 600) -> str: + header = {"alg": "none", "typ": "JWT"} + payload = {"sub": subject, "exp": int(time()) + expires_in_seconds} + return f"{_b64url(json.dumps(header).encode())}.{_b64url(json.dumps(payload).encode())}." + + +async def test_rejects_expired_access_token(async_client: AsyncClient) -> None: + token = create_access_token("expired@example.com", expires_delta=timedelta(minutes=-10)) + response = await async_client.get("/api/v1/users/me", headers={"Authorization": f"Bearer {token}"}) + assert response.status_code == 401 + assert response.json()["detail"] == "无法验证凭据" + + +async def test_rejects_none_alg_token(async_client: AsyncClient) -> None: + token = _none_alg_token("forged@example.com") + response = await async_client.get("/api/v1/users/me", headers={"Authorization": f"Bearer {token}"}) + assert response.status_code == 401 + assert response.json()["detail"] == "无法验证凭据" + + +async def test_rejects_token_for_missing_user(async_client: AsyncClient) -> None: + token = create_access_token("missing-user@example.com") + response = await async_client.get("/api/v1/users/me", headers={"Authorization": f"Bearer {token}"}) + assert response.status_code == 401 + assert response.json()["detail"] == "无法验证凭据" diff --git a/doc/testing.md b/doc/testing.md new file mode 100644 index 0000000..42e2bd9 --- /dev/null +++ b/doc/testing.md @@ -0,0 +1,115 @@ +# 测试执行指南(单元 / 安全 / 压力) + +本文档独立说明如何在本仓库执行三类测试:**单元测试**、**安全性测试**、**压力测试**(压测)。 + +## 0. 前置条件 + +- 进入后端目录:`cd backend` +- 使用可用的 Python 环境(推荐复用项目的 `backend/.venv`) + +如果你使用 `backend/.venv`: + +```bash +cd backend +source .venv/bin/activate +python -V +``` + +## 1. 单元测试(Unit Tests) + +用途:快速验证业务逻辑、接口行为是否正确(会跑 `backend/tests/` 下全部测试)。 + +执行: + +```bash +cd backend +python -m pytest -q +``` + +只跑某个文件/用例: + +```bash +cd backend +python -m pytest -q tests/test_config.py +python -m pytest -q -k token +``` + +## 2. 安全性测试(Security Tests) + +用途:做安全回归,确保关键安全边界不会被破坏(例如 JWT 过期/伪造 token 必须被拒绝)。 + +执行全部测试(包含安全性测试): + +```bash +cd backend +python -m pytest -q +``` + +仅执行安全性测试文件: + +```bash +cd backend +python -m pytest -q tests/test_security.py +``` + +说明: + +- 安全性测试目前主要覆盖鉴权链路(JWT)。 +- 如果后续新增其它安全回归点(如权限校验、文件上传校验、注入防护等),建议继续往 `backend/tests/test_security*.py` 扩展。 + +## 3. 压力测试(Stress / Load Tests) + +用途:模拟并发访问,观察吞吐(RPS/QPS)和延迟分布(p50/p95/p99),用于容量评估与性能回归。 + +本项目压测脚本:`backend/scripts/stress_test.py` +特点: + +- 默认使用 **ASGI in-process**(不需要启动 `uvicorn`,不走真实网络端口) +- 也支持 `--url` 对已启动服务做真实 HTTP 压测 +- 压测不纳入默认 `pytest`(避免 CI/本地跑测试时耗时、不稳定) + +### 3.1 不启动服务(默认 ASGI 模式) + +压测健康检查接口: + +```bash +cd backend +python scripts/stress_test.py --scenario health --requests 5000 --concurrency 100 +``` + +压测鉴权接口(会预先创建 `--users` 个用户并登录获取 token,压测阶段只打 `/api/v1/users/me`): + +```bash +cd backend +python scripts/stress_test.py --scenario me --users 50 --requests 5000 --concurrency 100 +``` + +输出包含: + +- `rps`:整体吞吐 +- `latency_ms`:平均与 p50/p95/p99/max +- `status_codes`:状态码分布 + +### 3.2 压测已启动服务(真实 HTTP) + +先确保后端已启动在某个地址(例如 `http://127.0.0.1:8000`),再执行: + +```bash +cd backend +python scripts/stress_test.py --url http://127.0.0.1:8000 --scenario health --requests 10000 --concurrency 200 +``` + +### 3.3 常用参数 + +- `--requests`:总请求数(默认 `2000`) +- `--concurrency`:并发数(默认 `50`) +- `--warmup`:预热请求数(默认 `50`) +- `--json`:输出 JSON 汇总(便于存档/画图) + +示例(JSON 输出): + +```bash +cd backend +python scripts/stress_test.py --scenario health --requests 2000 --concurrency 100 --json +``` +