Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,24 @@ docker compose --profile dev up --build

```bash
cd backend
pytest -q
python -m pytest -q
```

### 压力测试(后端)

无需启动 uvicorn(默认走 ASGI in-process):

```bash
cd backend
python scripts/stress_test.py --scenario health --requests 5000 --concurrency 100
python scripts/stress_test.py --scenario me --users 50 --requests 5000 --concurrency 100
```

如需对已启动的服务做压测,可指定 `--url`:

```bash
cd backend
python scripts/stress_test.py --url http://127.0.0.1:8000 --scenario health --requests 10000 --concurrency 200
```

前端可使用:
Expand Down
24 changes: 21 additions & 3 deletions backend/app/db/session.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,34 @@
"""Database session and engine configuration."""
from __future__ import annotations

from functools import lru_cache

from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine

from app.core.config import get_settings

settings = get_settings()

engine: AsyncEngine = create_async_engine(settings.async_database_uri, echo=False)
SessionLocal = async_sessionmaker(bind=engine, expire_on_commit=False, class_=AsyncSession)

@lru_cache
def get_engine() -> AsyncEngine:
    """Return the process-wide async engine, creating it on first use.

    Building the engine lazily (instead of at import time) keeps module
    import side-effect free and lets test suites override DB dependencies
    without every production DB driver being installed.
    """
    uri = settings.async_database_uri
    return create_async_engine(uri, echo=False)
Comment on lines +13 to +22
Copy link

Copilot AI Dec 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using '@lru_cache' on 'get_engine' and 'get_sessionmaker' may cause issues in multi-threaded or multi-process environments. The SQLAlchemy async engine is not inherently thread-safe across process boundaries (e.g., when using multiprocessing or after fork). Consider using '@lru_cache(maxsize=1)' for clarity, or document that this service should not be used with multiprocessing/forking servers without proper engine disposal between forks.

Copilot uses AI. Check for mistakes.


@lru_cache
def get_sessionmaker() -> async_sessionmaker[AsyncSession]:
    """Return the memoized session factory bound to the lazily-created engine."""
    factory = async_sessionmaker(bind=get_engine(), expire_on_commit=False, class_=AsyncSession)
    return factory


async def get_db() -> AsyncSession:
    """Yield an async database session for request lifecycle management.

    Used as a FastAPI-style dependency with ``yield``: a session is opened from
    the memoized sessionmaker on entry and closed automatically when the
    request scope ends, even if the handler raises.
    """
    # NOTE: the old module-level ``SessionLocal`` was removed in favor of the
    # lazily-memoized ``get_sessionmaker()`` factory; the extra ``()`` call
    # creates the actual AsyncSession from the factory.
    async with get_sessionmaker()() as session:
        yield session
4 changes: 2 additions & 2 deletions backend/app/db/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def ensure_database_exists(settings: Settings, retries: int = 5, delay: float =
)

attempt = 0
while attempt <= retries:
while True:
engine = create_engine(default_uri, isolation_level="AUTOCOMMIT")
try:
with engine.connect() as connection:
Expand All @@ -37,7 +37,7 @@ def ensure_database_exists(settings: Settings, retries: int = 5, delay: float =
logger.info("Created database %s", settings.postgres_db)
else:
logger.debug("Database %s already present", settings.postgres_db)
return
break
except SQLAlchemyError as exc:
attempt += 1
logger.warning("Database availability check failed (attempt %s/%s): %s", attempt, retries, exc)
Expand Down
3 changes: 3 additions & 0 deletions backend/pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[pytest]
markers =
asyncio: async test executed via event loop
259 changes: 259 additions & 0 deletions backend/scripts/stress_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
"""Simple backend stress/load test runner.

Default mode uses ASGITransport so you don't need to start uvicorn.
It is intentionally lightweight (no extra dependencies).
"""

from __future__ import annotations

import argparse
import asyncio
import contextlib
import json
import math
import os
import random
import shutil
import statistics
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, AsyncIterator
Copy link

Copilot AI Dec 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 'Any' import on line 22 is unused. Consider removing it to keep imports clean.

Suggested change
from typing import Any, AsyncIterator
from typing import AsyncIterator

Copilot uses AI. Check for mistakes.
from uuid import uuid4

from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine

from app.db.session import get_db
from app.models.base import Base


@dataclass
class Result:
    """Aggregated outcome of a single load-test run."""

    # Count of responses with a 2xx/3xx status code.
    ok: int
    # Count of non-2xx/3xx responses, including raised exceptions (recorded as status 0).
    errors: int
    # Mapping of HTTP status code -> occurrences; key 0 marks request exceptions.
    status_counts: dict[int, int]
    # Wall-clock latency per request in milliseconds, index-aligned with request order.
    latencies_ms: list[float]
    # Total wall time of the whole run in seconds.
    elapsed_s: float


def _percentile(sorted_values: list[float], p: float) -> float:
if not sorted_values:
return 0.0
if p <= 0:
return float(sorted_values[0])
if p >= 100:
return float(sorted_values[-1])
k = (len(sorted_values) - 1) * (p / 100.0)
f = math.floor(k)
c = math.ceil(k)
if f == c:
return float(sorted_values[int(k)])
d0 = sorted_values[f] * (c - k)
d1 = sorted_values[c] * (k - f)
return float(d0 + d1)


def _print_summary(result: Result) -> None:
    """Print a human-readable summary: throughput, latency percentiles, status codes."""
    total = result.ok + result.errors
    throughput = total / result.elapsed_s if result.elapsed_s > 0 else 0.0
    ordered = sorted(result.latencies_ms)
    mean = statistics.fmean(ordered) if ordered else 0.0
    worst = ordered[-1] if ordered else 0.0
    p50, p95, p99 = (_percentile(ordered, p) for p in (50, 95, 99))

    print(f"total={total} ok={result.ok} errors={result.errors} elapsed_s={result.elapsed_s:.3f} rps={throughput:.1f}")
    print(f"latency_ms avg={mean:.2f} p50={p50:.2f} p95={p95:.2f} p99={p99:.2f} max={worst:.2f}")
    codes = " ".join(f"{code}:{count}" for code, count in sorted(result.status_counts.items()))
    print(f"status_codes {codes}")


async def _setup_sqlite_db(db_path: Path) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]:
    """Create a throwaway SQLite database with all tables and return engine + session factory."""
    # Importing the models package registers every table on Base.metadata.
    import app.models  # noqa: F401

    engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", future=True)
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    session_factory = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
    return engine, session_factory


async def _build_asgi_client(session_factory: async_sessionmaker[AsyncSession]) -> AsyncClient:
    """Return an httpx client wired straight to the FastAPI app (no network).

    The app's DB dependency is overridden so every request uses sessions from
    the supplied factory instead of the configured production engine.
    """
    from app.main import app

    async def _session_override() -> AsyncIterator[AsyncSession]:
        async with session_factory() as session:
            yield session

    app.dependency_overrides[get_db] = _session_override
    return AsyncClient(transport=ASGITransport(app=app), base_url="http://test")


async def _register_and_login(client: AsyncClient, email: str, password: str) -> str:
    """Create a user (idempotently) and return a bearer access token.

    Raises RuntimeError if either registration or login fails unexpectedly.
    """
    register = await client.post("/api/v1/users", json={"email": email, "password": password})
    # 409 means the user already exists, which is fine for repeated runs.
    if register.status_code not in (201, 409):
        raise RuntimeError(f"register failed: {register.status_code} {register.text}")

    login = await client.post(
        "/api/v1/auth/token",
        data={"username": email, "password": password},
        headers={"Content-Type": "application/x-www-form-urlencoded"},
    )
    if login.status_code != 200:
        raise RuntimeError(f"login failed: {login.status_code} {login.text}")
    return login.json()["access_token"]


async def _run_fixed_requests(
    client: AsyncClient,
    request_count: int,
    concurrency: int,
    make_request,
) -> Result:
    """Fire ``request_count`` requests with at most ``concurrency`` in flight.

    ``make_request`` is an async callable that takes the client and returns an
    HTTP status code; any exception it raises is recorded as status 0.
    """
    gate = asyncio.Semaphore(concurrency)
    latencies_ms: list[float] = [0.0] * request_count
    status_counts: dict[int, int] = {}
    ok = 0
    errors = 0
    counters_lock = asyncio.Lock()

    async def issue(index: int) -> None:
        nonlocal ok, errors
        async with gate:
            begin = time.perf_counter()
            try:
                status = await make_request(client)
            except Exception:
                status = 0
            # Safe without the lock: every task writes to its own distinct index.
            latencies_ms[index] = (time.perf_counter() - begin) * 1000
            async with counters_lock:
                status_counts[status] = status_counts.get(status, 0) + 1
                if 200 <= status < 400:
                    ok += 1
                else:
                    errors += 1

    run_start = time.perf_counter()
    async with asyncio.TaskGroup() as tg:
        for index in range(request_count):
            tg.create_task(issue(index))
    return Result(
        ok=ok,
        errors=errors,
        status_counts=status_counts,
        latencies_ms=latencies_ms,
        elapsed_s=time.perf_counter() - run_start,
    )


async def main() -> int:
    """CLI entry point: parse args, run the chosen scenario, print a summary.

    Returns the process exit code (0 on success). By default the app is
    exercised in-process via ASGITransport against a throwaway SQLite DB;
    with ``--url`` set, requests go over HTTP to an already-running server.
    """
    parser = argparse.ArgumentParser(description="Backend stress/load test (ASGI in-process by default).")
    parser.add_argument("--scenario", choices=("health", "me"), default="health")
    parser.add_argument("--requests", type=int, default=2000)
    parser.add_argument("--concurrency", type=int, default=50)
    parser.add_argument("--warmup", type=int, default=50)
    parser.add_argument("--users", type=int, default=50, help="Only for scenario=me")
    parser.add_argument("--password", default="StrongPass1!")
    parser.add_argument("--json", dest="json_output", action="store_true", help="Print JSON summary")
    parser.add_argument(
        "--url",
        default="",
        help="Live server base URL (e.g. http://127.0.0.1:8000). If set, requests go over HTTP instead of ASGI.",
    )
    args = parser.parse_args()

    # Validate the numeric knobs before doing any setup work.
    if args.requests <= 0 or args.concurrency <= 0:
        raise SystemExit("--requests/--concurrency must be > 0")

    if args.scenario == "me" and args.users <= 0:
        raise SystemExit("--users must be > 0 for scenario=me")

    # Build client. The temp dir is created unconditionally so the finally
    # block can always clean it up, even in live-server (--url) mode.
    tmpdir = Path(tempfile.mkdtemp(prefix="ir-stress-"))
    db_path = tmpdir / "stress.sqlite3"
    client: AsyncClient | None = None
    engine: AsyncEngine | None = None
    using_asgi = not bool(args.url)

    try:
        if args.url:
            # Live server mode (requires backend already running)
            client = AsyncClient(base_url=args.url.rstrip("/"))
        else:
            # Redirect common ML config caches to a writable directory to avoid sandbox issues.
            (tmpdir / "mpl").mkdir(parents=True, exist_ok=True)
            os.environ.setdefault("MPLCONFIGDIR", str(tmpdir / "mpl"))
            os.environ.setdefault("YOLO_CONFIG_DIR", str(tmpdir / "ultralytics"))

            engine, session_factory = await _setup_sqlite_db(db_path)
            client = await _build_asgi_client(session_factory)

        # Scenario setup
        headers_list: list[dict[str, str]] = []
        if args.scenario == "me":
            # pre-create tokens so hot path is only authenticated GET
            for _ in range(args.users):
                email = f"stress-{uuid4().hex[:12]}@example.com"
                token = await _register_and_login(client, email, args.password)
                headers_list.append({"Authorization": f"Bearer {token}"})

        # Warmup: prime app caches / connections so they don't skew latency stats.
        async def warmup_request(c: AsyncClient) -> int:
            if args.scenario == "health":
                return (await c.get("/health")).status_code
            header = random.choice(headers_list)
            return (await c.get("/api/v1/users/me", headers=header)).status_code

        for _ in range(args.warmup):
            await warmup_request(client)

        # Main load
        async def make_request(c: AsyncClient) -> int:
            if args.scenario == "health":
                return (await c.get("/health")).status_code
            header = random.choice(headers_list)
            return (await c.get("/api/v1/users/me", headers=header)).status_code

        result = await _run_fixed_requests(client, args.requests, args.concurrency, make_request)

        if args.json_output:
            # Machine-readable summary mirroring what _print_summary prints.
            payload = {
                "scenario": args.scenario,
                "requests": args.requests,
                "concurrency": args.concurrency,
                "ok": result.ok,
                "errors": result.errors,
                "elapsed_s": result.elapsed_s,
                "rps": (args.requests / result.elapsed_s) if result.elapsed_s else 0.0,
                "status_counts": result.status_counts,
                "latency_ms": {
                    "avg": statistics.fmean(result.latencies_ms) if result.latencies_ms else 0.0,
                    "p50": _percentile(sorted(result.latencies_ms), 50),
                    "p95": _percentile(sorted(result.latencies_ms), 95),
                    "p99": _percentile(sorted(result.latencies_ms), 99),
                    "max": max(result.latencies_ms) if result.latencies_ms else 0.0,
                },
            }
            print(json.dumps(payload, ensure_ascii=False, indent=2))
        else:
            _print_summary(result)

        return 0
    finally:
        # Best-effort teardown: close the HTTP client, dispose the engine,
        # undo the app's dependency override, then remove the temp dir.
        if client is not None:
            await client.aclose()
        if engine is not None:
            with contextlib.suppress(Exception):
                await engine.dispose()
        if using_asgi:
            with contextlib.suppress(Exception):
                from app.main import app

                app.dependency_overrides.clear()
        with contextlib.suppress(Exception):
            shutil.rmtree(tmpdir, ignore_errors=True)


if __name__ == "__main__":
    # asyncio.run drives the async entry point; SystemExit carries the exit code.
    raise SystemExit(asyncio.run(main()))
Loading
Loading