diff --git a/packages/dojozero/src/dojozero/cli.py b/packages/dojozero/src/dojozero/cli.py index 299aa7c6..a1d39a6e 100644 --- a/packages/dojozero/src/dojozero/cli.py +++ b/packages/dojozero/src/dojozero/cli.py @@ -1547,16 +1547,58 @@ async def _resolve_event_files( if not cache_path.exists(): LOGGER.info("Downloading events from %s", url) try: - async with httpx.AsyncClient(timeout=120.0) as client: - resp = await client.get(url) - resp.raise_for_status() - cache_path.parent.mkdir(parents=True, exist_ok=True) - cache_path.write_bytes(resp.content) + async with httpx.AsyncClient(timeout=60.0) as client: + delay = 3.0 + last_spans = -1 + stall_count = 0 + max_stalls = 20 # give up after ~5 min of no progress + attempt = 0 + while True: + resp = await client.get(url) + if resp.status_code == 200: + cache_path.parent.mkdir(parents=True, exist_ok=True) + cache_path.write_bytes(resp.content) + break + elif resp.status_code == 202: + attempt += 1 + body = resp.json() + spans = body.get("spans_fetched", 0) + total = body.get("spans_total", 0) + elapsed = body.get("elapsed_seconds", "?") + pct = f" {100 * spans // total}%" if total > 0 else "" + LOGGER.info( + "Server is materializing events " + "(%s/%s spans%s, %ss elapsed). " + "Retrying in %.0fs... 
[attempt %d]", + spans, + total or "?", + pct, + elapsed, + delay, + attempt, + ) + if spans > last_spans: + last_spans = spans + stall_count = 0 + else: + stall_count += 1 + if stall_count >= max_stalls: + raise DojoZeroCLIError( + f"Server materialization stalled for " + f"{url} (no progress after " + f"{max_stalls} retries)" + ) + await asyncio.sleep(delay) + delay = min(delay * 1.5, 15.0) + else: + resp.raise_for_status() except httpx.HTTPStatusError as e: raise DojoZeroCLIError( f"Failed to download events from {url}: " f"HTTP {e.response.status_code}" ) from e + except DojoZeroCLIError: + raise except Exception as e: raise DojoZeroCLIError( f"Failed to download events from {url}: {e}" diff --git a/packages/dojozero/src/dojozero/core/_tracing.py b/packages/dojozero/src/dojozero/core/_tracing.py index 708ce744..8d051790 100644 --- a/packages/dojozero/src/dojozero/core/_tracing.py +++ b/packages/dojozero/src/dojozero/core/_tracing.py @@ -16,6 +16,7 @@ import os from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone +from collections.abc import Callable from typing import Any, Protocol from uuid import uuid4 import asyncio @@ -764,11 +765,45 @@ async def list_trials( LOGGER.error("Failed to parse SLS response for trials: %s", e) return [] + async def _get_count( + self, + resource: str, + url: str, + query: str, + from_time: datetime, + to_time: datetime, + ) -> int: + """Get total log count via SLS SQL count query. + + Returns 0 if the query fails (non-fatal). 
+ """ + count_query = f"{query} | select count(1) as cnt" + params = { + "type": "log", + "from": str(int(from_time.timestamp())), + "to": str(int(to_time.timestamp())), + "query": count_query, + "line": "1", + "offset": "0", + } + headers = self._sign_request("GET", resource, params) + try: + response = await self._client.get(url, params=params, headers=headers) + response.raise_for_status() + data = response.json() + rows = data if isinstance(data, list) else data.get("data", []) + if rows and isinstance(rows[0], dict): + return int(rows[0].get("cnt", 0)) + except Exception as e: + LOGGER.debug("SLS count query failed (non-fatal): %s", e) + return 0 + async def get_spans( self, trial_id: str, start_time: datetime | None = None, operation_names: list[str] | None = None, + progress_callback: Callable[[int, int], None] | None = None, ) -> list[SpanData]: """Get spans for a trial from SLS. @@ -777,6 +812,9 @@ async def get_spans( start_time: If provided, only return spans after this time. operation_names: If provided, only return spans with operation_name in this list. Exact match with OR logic. None means no filtering. + progress_callback: If provided, called with ``(fetched, total)`` + after each page. ``total`` is obtained via a count query + before pagination begins. Returns: List of SpanData sorted by start time. 
@@ -805,6 +843,11 @@ async def get_spans( resource = f"/logstores/{self._logstore}" url = f"{self._get_base_url()}{resource}" + # If a progress callback is provided, get total count first via SQL + total_count = 0 + if progress_callback is not None: + total_count = await self._get_count(resource, url, query, from_time, now) + # Pagination: SLS GetLogs API limits to 100 rows per request in search mode # We paginate using offset parameter to get all data page_size = self._page_size @@ -858,6 +901,8 @@ async def get_spans( break # No more data all_rows.extend(rows) + if progress_callback is not None: + progress_callback(len(all_rows), total_count) if len(rows) < page_size: break # Last page (less than full page means no more data) diff --git a/packages/dojozero/src/dojozero/dashboard_server/_server.py b/packages/dojozero/src/dojozero/dashboard_server/_server.py index 0d081ace..90bf24a5 100644 --- a/packages/dojozero/src/dojozero/dashboard_server/_server.py +++ b/packages/dojozero/src/dojozero/dashboard_server/_server.py @@ -14,6 +14,7 @@ import logging import os import platform +import time from contextlib import asynccontextmanager from dataclasses import asdict, dataclass, field from pathlib import Path @@ -143,6 +144,85 @@ class TrialSourceRequest(BaseModel): config: TrialSourceConfigRequest +@dataclass +class _MaterializationEntry: + """Tracks a single in-flight SLS materialization task.""" + + task: asyncio.Task[Any] + trial_id: str + run_id: str | None + started_at: float # time.monotonic() + spans_fetched: int = 0 # updated by progress callback + spans_total: int = 0 # total from SLS count query + + +class MaterializationTracker: + """Track in-flight SLS materialization tasks. + + Keyed by ``(trial_id, run_id)`` so concurrent requests for the same + trial share a single background task. 
+ """ + + def __init__(self) -> None: + self._tasks: dict[tuple[str, str | None], _MaterializationEntry] = {} + + def get(self, trial_id: str, run_id: str | None) -> _MaterializationEntry | None: + return self._tasks.get((trial_id, run_id)) + + def start( + self, + trial_id: str, + run_id: str | None, + cached_path: Path, + ) -> _MaterializationEntry: + """Start a background SLS materialization, or return an existing one.""" + key = (trial_id, run_id) + existing = self._tasks.get(key) + if existing is not None and not existing.task.done(): + return existing # already in flight + + entry = _MaterializationEntry( + task=asyncio.ensure_future(asyncio.sleep(0)), # replaced below + trial_id=trial_id, + run_id=run_id, + started_at=time.monotonic(), + ) + + def _on_progress(rows_fetched: int, total: int) -> None: + entry.spans_fetched = rows_fetched + entry.spans_total = total + + async def _do_materialize() -> Path: + from dojozero.data import SLSEventSource + + await SLSEventSource().materialize_jsonl( + trial_id, + cached_path, + run_id=run_id or None, + progress_callback=_on_progress, + ) + upload_trial_to_oss(trial_id, cached_path) + return cached_path + + task = asyncio.create_task(_do_materialize(), name=f"materialize-{trial_id}") + + def _on_done(t: asyncio.Task[Any]) -> None: + if not t.cancelled() and t.exception(): + LOGGER.error( + "Background materialization failed for '%s': %s", + trial_id, + t.exception(), + ) + + task.add_done_callback(_on_done) + entry.task = task + self._tasks[key] = entry + return entry + + def remove(self, trial_id: str, run_id: str | None) -> None: + self._tasks.pop((trial_id, run_id), None) + + @dataclass class DashboardServerState: """Shared state for the Dashboard Server.""" @@ -163,6 +243,9 @@ class DashboardServerState: scheduler_store: SchedulerStore | None = None http_session: Any = None # aiohttp.ClientSession for cross-server forwarding http_client: Any = None # httpx.AsyncClient for cross-server forwarding + 
materialization_tracker: MaterializationTracker = field( + default_factory=MaterializationTracker + ) def get_server_state(request: Request) -> DashboardServerState: @@ -1343,36 +1426,66 @@ async def download_trial_events( except Exception: pass # fall through to SLS fallback - # 5. SLS materialization (last resort) + # 5. SLS materialization (async, non-blocking via 202 polling) if event_file is None: - try: - from dojozero.data import SLSEventSource + tracker = state.materialization_tracker + entry = tracker.get(trial_id, run_id) - await SLSEventSource().materialize_jsonl( - trial_id, - cached_path, - run_id=run_id or None, + if entry is None: + # No in-flight task — kick one off and return 202 + entry = tracker.start(trial_id, run_id, cached_path) + return JSONResponse( + content={ + "status": "materializing", + "trial_id": trial_id, + "spans_fetched": 0, + "spans_total": 0, + "message": ( + "Materializing events from SLS. " + "Poll this endpoint to check progress." + ), + }, + status_code=202, ) - event_file = cached_path - # Upload to OSS for cross-host caching - upload_trial_to_oss(trial_id, cached_path) - except Exception as e: + + if not entry.task.done(): + # Still running — return 202 with progress + elapsed = time.monotonic() - entry.started_at + return JSONResponse( + content={ + "status": "materializing", + "trial_id": trial_id, + "elapsed_seconds": round(elapsed, 1), + "spans_fetched": entry.spans_fetched, + "spans_total": entry.spans_total, + "message": "Materialization in progress.", + }, + status_code=202, + ) + + # Task finished — check result + tracker.remove(trial_id, run_id) + exc = entry.task.exception() if not entry.task.cancelled() else None + if exc is not None: LOGGER.error( "SLS materialization failed for trial '%s': %s", trial_id, - e, - exc_info=True, + exc, + exc_info=exc, ) return JSONResponse( content={ + "status": "failed", "error": ( f"Event file for trial '{trial_id}' not available " - f"locally and SLS fallback failed: {e}" - ) + 
f"locally and SLS fallback failed: {exc}" + ), }, status_code=424, ) + event_file = cached_path + # Guard against serving empty files (e.g. from a previous failed # materialization). Remove the empty file so future requests retry. if event_file.stat().st_size == 0: diff --git a/packages/dojozero/src/dojozero/data/_sls_source.py b/packages/dojozero/src/dojozero/data/_sls_source.py index bb268b6c..da70e18e 100644 --- a/packages/dojozero/src/dojozero/data/_sls_source.py +++ b/packages/dojozero/src/dojozero/data/_sls_source.py @@ -39,6 +39,8 @@ ) if TYPE_CHECKING: + from collections.abc import Callable + from dojozero.core._tracing import TraceReader from dojozero.data._models import DataEvent @@ -108,6 +110,7 @@ async def materialize_jsonl( *, run_id: str | None = None, overwrite: bool = True, + progress_callback: "Callable[[int, int], None] | None" = None, ) -> MaterializeResult: """Write events for ``trial_id`` to ``dest`` as JSONL. @@ -137,7 +140,14 @@ async def materialize_jsonl( reader = self._reader or _make_reader() owns_reader = self._reader is None try: - spans = await reader.get_spans(trial_id) + if isinstance(reader, SLSTraceReader) and progress_callback is not None: + # SLSTraceReader extends the protocol with progress_callback + sls_reader: SLSTraceReader = reader + spans = await sls_reader.get_spans( + trial_id, progress_callback=progress_callback + ) + else: + spans = await reader.get_spans(trial_id) finally: if owns_reader: await reader.close() diff --git a/packages/dojozero/tests/test_cli_sls_polling.py b/packages/dojozero/tests/test_cli_sls_polling.py new file mode 100644 index 00000000..ec801a78 --- /dev/null +++ b/packages/dojozero/tests/test_cli_sls_polling.py @@ -0,0 +1,323 @@ +"""Tests for _resolve_event_files HTTP polling (202 materialization flow). 
+ +Exercises the client side of the async SLS materialization contract: + +- 200 on first attempt: writes cache and returns immediately +- 202 -> 200: polls with exponential backoff, capped at 15s, then writes cache +- 202 with stalled progress: raises DojoZeroCLIError after max_stalls +- Non-2xx, non-202 response: raises DojoZeroCLIError (via raise_for_status) +- Progress log formatting: "?" and no percentage when spans_total is 0 +""" + +from __future__ import annotations + +import asyncio +from pathlib import Path +from typing import Any + +import httpx +import pytest + +from dojozero.cli import DojoZeroCLIError, _resolve_event_files + + +def _run(coro): + return asyncio.run(coro) + + +# --------------------------------------------------------------------------- +# Fake httpx.AsyncClient / Response +# --------------------------------------------------------------------------- + + +class _FakeHttpResponse: + def __init__( + self, + *, + status_code: int, + content: bytes = b"", + json_body: dict[str, Any] | None = None, + ) -> None: + self.status_code = status_code + self.content = content + self._json = json_body or {} + + def json(self) -> dict[str, Any]: + return self._json + + def raise_for_status(self) -> None: + if self.status_code >= 400: + req = httpx.Request("GET", "http://fake") + resp = httpx.Response(self.status_code, request=req) + raise httpx.HTTPStatusError( + f"HTTP {self.status_code}", request=req, response=resp + ) + + +class _FakeHttpClient: + """Both the AsyncClient class and instance — async context-manager enabled.""" + + def __init__(self, responses: list[_FakeHttpResponse]) -> None: + self._responses = responses + self._idx = 0 + self.call_log: list[str] = [] + + def __call__(self, *args: Any, **kwargs: Any) -> "_FakeHttpClient": + del args, kwargs + return self + + async def __aenter__(self) -> "_FakeHttpClient": + return self + + async def __aexit__(self, *args: Any) -> None: + del args + return None + + async def get(self, url: str) -> 
_FakeHttpResponse: + self.call_log.append(url) + # Saturate on the last response so overly-long tests surface as a + # stalled loop (caught by stall detection) rather than IndexError. + resp = self._responses[min(self._idx, len(self._responses) - 1)] + self._idx += 1 + return resp + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def fast_sleep(monkeypatch: pytest.MonkeyPatch) -> list[float]: + """Replace asyncio.sleep with a no-op that records requested delays.""" + delays: list[float] = [] + + async def _sleep(delay: float) -> None: + delays.append(delay) + + monkeypatch.setattr(asyncio, "sleep", _sleep) + return delays + + +def _install_fake_client( + monkeypatch: pytest.MonkeyPatch, responses: list[_FakeHttpResponse] +) -> _FakeHttpClient: + fake = _FakeHttpClient(responses) + monkeypatch.setattr(httpx, "AsyncClient", fake) + return fake + + +# --------------------------------------------------------------------------- +# 200 on first attempt — no polling +# --------------------------------------------------------------------------- + + +def test_http_url_200_on_first_attempt( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch, fast_sleep: list[float] +) -> None: + payload = b'{"event_type": "event.game_initialize"}\n' + fake = _install_fake_client( + monkeypatch, [_FakeHttpResponse(status_code=200, content=payload)] + ) + + files = _run( + _resolve_event_files( + ["http://server/api/trials/trial-A/events.jsonl"], + sls_cache_dir=tmp_path, + ) + ) + + assert len(files) == 1 + assert files[0].read_bytes() == payload + assert len(fake.call_log) == 1 + # Success on first attempt — no polling, no sleeps. 
+ assert fast_sleep == [] + + +# --------------------------------------------------------------------------- +# 202 -> 200 happy path with backoff +# --------------------------------------------------------------------------- + + +def test_http_url_polls_until_200( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch, fast_sleep: list[float] +) -> None: + payload = b'{"event_type": "event.game_initialize"}\n' + _install_fake_client( + monkeypatch, + [ + _FakeHttpResponse( + status_code=202, + json_body={"spans_fetched": 0, "spans_total": 0}, + ), + _FakeHttpResponse( + status_code=202, + json_body={"spans_fetched": 10, "spans_total": 100}, + ), + _FakeHttpResponse( + status_code=202, + json_body={"spans_fetched": 50, "spans_total": 100}, + ), + _FakeHttpResponse(status_code=200, content=payload), + ], + ) + + files = _run( + _resolve_event_files( + ["http://server/api/trials/trial-A/events.jsonl"], + sls_cache_dir=tmp_path, + ) + ) + + assert files[0].read_bytes() == payload + # Three 202 responses → three sleeps with geometric backoff (3.0, 4.5, 6.75). + assert fast_sleep == pytest.approx([3.0, 4.5, 6.75]) + + +# --------------------------------------------------------------------------- +# Backoff caps at 15s +# --------------------------------------------------------------------------- + + +def test_http_url_backoff_caps_at_15_seconds( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch, fast_sleep: list[float] +) -> None: + # Each 202 reports growing progress so stall detection stays reset. + responses: list[_FakeHttpResponse] = [ + _FakeHttpResponse( + status_code=202, + json_body={"spans_fetched": i, "spans_total": 100}, + ) + for i in range(12) + ] + responses.append(_FakeHttpResponse(status_code=200, content=b"")) + _install_fake_client(monkeypatch, responses) + + _run( + _resolve_event_files( + ["http://server/api/trials/trial-A/events.jsonl"], + sls_cache_dir=tmp_path, + ) + ) + + # Backoff grows 3.0 * 1.5^n until capped at 15.0; verify plateau. 
+ assert 15.0 in fast_sleep + assert max(fast_sleep) == 15.0 + + +# --------------------------------------------------------------------------- +# Stall detection: 20 consecutive 202s with no new spans → DojoZeroCLIError +# --------------------------------------------------------------------------- + + +@pytest.mark.usefixtures("fast_sleep") +def test_http_url_stall_detection_raises( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + # spans_fetched=5 throughout. First 202 counts as "initial progress" since + # last_spans starts at -1; subsequent 20 without growth hit max_stalls=20. + stalled = _FakeHttpResponse( + status_code=202, + json_body={"spans_fetched": 5, "spans_total": 100}, + ) + _install_fake_client(monkeypatch, [stalled] * 25) + + with pytest.raises(DojoZeroCLIError, match="stalled"): + _run( + _resolve_event_files( + ["http://server/api/trials/trial-A/events.jsonl"], + sls_cache_dir=tmp_path, + ) + ) + + +# --------------------------------------------------------------------------- +# Unexpected non-2xx status → DojoZeroCLIError via raise_for_status +# --------------------------------------------------------------------------- + + +@pytest.mark.usefixtures("fast_sleep") +def test_http_url_unexpected_status_raises( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + _install_fake_client(monkeypatch, [_FakeHttpResponse(status_code=500)]) + + with pytest.raises(DojoZeroCLIError, match="HTTP 500"): + _run( + _resolve_event_files( + ["http://server/api/trials/trial-A/events.jsonl"], + sls_cache_dir=tmp_path, + ) + ) + + +# --------------------------------------------------------------------------- +# Progress log formatting — percentage omitted when total == 0 +# --------------------------------------------------------------------------- + + +@pytest.mark.usefixtures("fast_sleep") +def test_http_url_progress_log_handles_zero_total( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +) -> None: 
+ _install_fake_client( + monkeypatch, + [ + _FakeHttpResponse( + status_code=202, + json_body={"spans_fetched": 5, "spans_total": 0}, + ), + _FakeHttpResponse(status_code=200, content=b""), + ], + ) + + with caplog.at_level("INFO", logger="dojozero.cli"): + _run( + _resolve_event_files( + ["http://server/api/trials/trial-A/events.jsonl"], + sls_cache_dir=tmp_path, + ) + ) + + materializing_msgs = [ + r.getMessage() for r in caplog.records if "materializing" in r.getMessage() + ] + assert materializing_msgs, "expected at least one materialization log line" + msg = materializing_msgs[0] + # "?" placeholder for total and no percentage when total == 0. + assert "5/?" in msg + assert "%" not in msg + + +@pytest.mark.usefixtures("fast_sleep") +def test_http_url_progress_log_shows_percent_when_total_known( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +) -> None: + _install_fake_client( + monkeypatch, + [ + _FakeHttpResponse( + status_code=202, + json_body={"spans_fetched": 25, "spans_total": 100}, + ), + _FakeHttpResponse(status_code=200, content=b""), + ], + ) + + with caplog.at_level("INFO", logger="dojozero.cli"): + _run( + _resolve_event_files( + ["http://server/api/trials/trial-A/events.jsonl"], + sls_cache_dir=tmp_path, + ) + ) + + msg = next( + r.getMessage() for r in caplog.records if "materializing" in r.getMessage() + ) + assert "25/100" in msg + assert "25%" in msg diff --git a/packages/dojozero/tests/test_materialization_tracker.py b/packages/dojozero/tests/test_materialization_tracker.py new file mode 100644 index 00000000..80c2a783 --- /dev/null +++ b/packages/dojozero/tests/test_materialization_tracker.py @@ -0,0 +1,306 @@ +"""Tests for MaterializationTracker — background SLS materialization coordinator. 
+ +Covers the new 202-polling machinery on the dashboard server: + +- start() spawns a task, forwards run_id + progress callback, and uploads on success +- concurrent requests for the same (trial_id, run_id) coalesce onto one task +- distinct run_ids do NOT coalesce +- completed entries are replaced on the next start() +- failures are surfaced via task.exception() (no upload on failure) +- remove() drops the entry; remove() on an absent key is a no-op +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Callable +from pathlib import Path + +import pytest + +from dojozero.dashboard_server._server import MaterializationTracker + + +# --------------------------------------------------------------------------- +# Fake SLSEventSource +# --------------------------------------------------------------------------- + + +class _FakeSource: + """Stand-in for SLSEventSource driving deterministic test scenarios. + + Signature matches the call site in ``_do_materialize``: + await SLSEventSource().materialize_jsonl( + trial_id, cached_path, run_id=..., progress_callback=... 
+ ) + """ + + def __init__( + self, + *, + progress_steps: list[tuple[int, int]] | None = None, + raises: Exception | None = None, + started_event: asyncio.Event | None = None, + proceed_event: asyncio.Event | None = None, + ) -> None: + self._progress_steps = progress_steps or [] + self._raises = raises + self._started_event = started_event + self._proceed_event = proceed_event + self.calls: list[tuple[str, Path, str | None]] = [] + self.saw_progress_callback: Callable[[int, int], None] | None = None + + async def materialize_jsonl( + self, + trial_id: str, + dest: Path, + *, + run_id: str | None = None, + progress_callback: Callable[[int, int], None] | None = None, + ) -> None: + self.calls.append((trial_id, dest, run_id)) + self.saw_progress_callback = progress_callback + for fetched, total in self._progress_steps: + if progress_callback is not None: + progress_callback(fetched, total) + # Yield so observers can read intermediate state between steps. + await asyncio.sleep(0) + if self._started_event is not None: + self._started_event.set() + if self._proceed_event is not None: + await self._proceed_event.wait() + if self._raises is not None: + raise self._raises + + +@pytest.fixture +def patch_backends(monkeypatch: pytest.MonkeyPatch): + """Install a fake SLSEventSource + upload_trial_to_oss. + + Returns a callable that installs a fake source for the test body and + returns ``(fake, upload_calls)``. 
+ """ + upload_calls: list[tuple[str, Path]] = [] + + import dojozero.data as data_pkg + import dojozero.dashboard_server._server as server_mod + + def _install(**kwargs) -> tuple[_FakeSource, list[tuple[str, Path]]]: + fake = _FakeSource(**kwargs) + monkeypatch.setattr(data_pkg, "SLSEventSource", lambda: fake) + monkeypatch.setattr( + server_mod, + "upload_trial_to_oss", + lambda trial_id, path: upload_calls.append((trial_id, path)), + ) + return fake, upload_calls + + return _install + + +# --------------------------------------------------------------------------- +# get() basic behavior +# --------------------------------------------------------------------------- + + +def test_get_returns_none_for_unknown_key() -> None: + tracker = MaterializationTracker() + assert tracker.get("nonexistent", None) is None + assert tracker.get("nonexistent", "some-run") is None + + +# --------------------------------------------------------------------------- +# start() — happy path +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_start_runs_materialize_and_uploads( + tmp_path: Path, patch_backends +) -> None: + fake, upload_calls = patch_backends(progress_steps=[(10, 100), (100, 100)]) + tracker = MaterializationTracker() + dest = tmp_path / "trial-A.jsonl" + + entry = tracker.start("trial-A", None, dest) + await entry.task + + assert fake.calls == [("trial-A", dest, None)] + assert upload_calls == [("trial-A", dest)] + # Entry reflects the last progress tuple after the task drains. 
+ assert entry.spans_fetched == 100 + assert entry.spans_total == 100 + assert entry.task.exception() is None + + +@pytest.mark.asyncio +async def test_start_forwards_run_id(tmp_path: Path, patch_backends) -> None: + fake, _ = patch_backends() + tracker = MaterializationTracker() + dest = tmp_path / "trial-A.jsonl" + + entry = tracker.start("trial-A", "rootBig", dest) + await entry.task + + assert fake.calls == [("trial-A", dest, "rootBig")] + + +@pytest.mark.asyncio +async def test_entry_started_at_is_monotonic(tmp_path: Path, patch_backends) -> None: + import time + + patch_backends() + tracker = MaterializationTracker() + dest = tmp_path / "trial-A.jsonl" + + before = time.monotonic() + entry = tracker.start("trial-A", None, dest) + after = time.monotonic() + await entry.task + + assert before <= entry.started_at <= after + + +# --------------------------------------------------------------------------- +# Coalescing: concurrent start() for the same key reuses the in-flight task +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_start_coalesces_concurrent_same_key( + tmp_path: Path, patch_backends +) -> None: + started = asyncio.Event() + proceed = asyncio.Event() + patch_backends(started_event=started, proceed_event=proceed) + tracker = MaterializationTracker() + dest = tmp_path / "trial-A.jsonl" + + entry1 = tracker.start("trial-A", None, dest) + await started.wait() # ensure task is actually in-flight + entry2 = tracker.start("trial-A", None, dest) + + assert entry1 is entry2 + assert entry1.task is entry2.task + + proceed.set() + await entry1.task + + +@pytest.mark.asyncio +async def test_start_does_not_coalesce_different_run_ids( + tmp_path: Path, patch_backends +) -> None: + patch_backends() + tracker = MaterializationTracker() + dest = tmp_path / "trial-A.jsonl" + + entry_a = tracker.start("trial-A", "rootA", dest) + entry_b = tracker.start("trial-A", "rootB", dest) + + assert entry_a 
is not entry_b + assert entry_a.task is not entry_b.task + + await asyncio.gather(entry_a.task, entry_b.task) + + +# --------------------------------------------------------------------------- +# Replacement: once a task is done, the next start() creates a fresh one +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_start_replaces_completed_entry(tmp_path: Path, patch_backends) -> None: + patch_backends() + tracker = MaterializationTracker() + dest = tmp_path / "trial-A.jsonl" + + first = tracker.start("trial-A", None, dest) + await first.task + assert first.task.done() + + second = tracker.start("trial-A", None, dest) + assert second is not first + assert second.task is not first.task + + await second.task + + +# --------------------------------------------------------------------------- +# Failure: exception surfaces via task.exception(); upload skipped +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_materialize_failure_captured_in_task( + tmp_path: Path, patch_backends +) -> None: + boom = RuntimeError("SLS said no") + _, upload_calls = patch_backends(raises=boom) + tracker = MaterializationTracker() + dest = tmp_path / "trial-A.jsonl" + + entry = tracker.start("trial-A", None, dest) + # Consume the exception so it doesn't surface as an unhandled-task warning. + await asyncio.gather(entry.task, return_exceptions=True) + + assert entry.task.done() + exc = entry.task.exception() + assert isinstance(exc, RuntimeError) + assert str(exc) == "SLS said no" + # Upload happens AFTER materialize; a failed materialize must not upload. 
+ assert upload_calls == [] + + +# --------------------------------------------------------------------------- +# Progress callback wiring — entry fields mutate mid-flight +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_progress_callback_updates_entry_fields_mid_flight( + tmp_path: Path, patch_backends +) -> None: + proceed = asyncio.Event() + fake, _ = patch_backends(progress_steps=[(42, 100)], proceed_event=proceed) + tracker = MaterializationTracker() + dest = tmp_path / "trial-A.jsonl" + + entry = tracker.start("trial-A", None, dest) + # Wait until at least the first progress step has been emitted. + while entry.spans_fetched == 0 and not entry.task.done(): + await asyncio.sleep(0) + + # The entry — which the endpoint reads on every poll — reflects the + # latest progress callback invocation. + assert entry.spans_fetched == 42 + assert entry.spans_total == 100 + assert fake.saw_progress_callback is not None + + proceed.set() + await entry.task + + +# --------------------------------------------------------------------------- +# remove() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_remove_drops_entry(tmp_path: Path, patch_backends) -> None: + patch_backends() + tracker = MaterializationTracker() + dest = tmp_path / "trial-A.jsonl" + + entry = tracker.start("trial-A", None, dest) + await entry.task + assert tracker.get("trial-A", None) is entry + + tracker.remove("trial-A", None) + assert tracker.get("trial-A", None) is None + + +def test_remove_absent_key_is_noop() -> None: + tracker = MaterializationTracker() + tracker.remove("nonexistent", None) + tracker.remove("nonexistent", "some-run") # also fine diff --git a/packages/dojozero/tests/test_sls_source.py b/packages/dojozero/tests/test_sls_source.py index e25b28fd..3b4c0507 100644 --- a/packages/dojozero/tests/test_sls_source.py +++ 
b/packages/dojozero/tests/test_sls_source.py @@ -11,7 +11,7 @@ import pytest -from dojozero.core._tracing import SpanData +from dojozero.core._tracing import SLSTraceReader, SpanData from dojozero.data import ( GameInitializeEvent, GameResultEvent, @@ -734,6 +734,96 @@ def test_injected_reader_not_closed(tmp_path: Path) -> None: assert reader.close_calls == 0 # Caller owns the reader. +# --------------------------------------------------------------------------- +# progress_callback passthrough +# --------------------------------------------------------------------------- + + +class _FakeSLSReader(SLSTraceReader): + """SLSTraceReader subclass that bypasses network setup for isinstance tests. + + materialize_jsonl checks ``isinstance(reader, SLSTraceReader)`` before + forwarding progress_callback, so a bare _FakeReader won't exercise that + branch. + """ + + def __init__(self, spans: list[SpanData]) -> None: + # Skip super().__init__ — it would spin up httpx/credentials which + # we neither have nor need for this test. 
+ self._spans = spans + self.received_progress_callback: Any = None + + async def get_spans( # type: ignore[override] + self, + trial_id: str, + start_time: datetime | None = None, + operation_names: list[str] | None = None, + progress_callback: Any = None, + ) -> list[SpanData]: + self.received_progress_callback = progress_callback + return list(self._spans) + + async def close(self) -> None: # pragma: no cover - not called (not owned) + pass + + +def test_progress_callback_forwarded_to_sls_reader(tmp_path: Path) -> None: + """When reader is an SLSTraceReader, progress_callback is passed through.""" + event = _make_init() + root = _root_span("trial-A") + reader = _FakeSLSReader( + [root, _event_span(event, trace_id="trial-A", parent_span_id=root.span_id)] + ) + source = SLSEventSource(reader=reader) + + calls: list[tuple[int, int]] = [] + + def _on_progress(fetched: int, total: int) -> None: + calls.append((fetched, total)) + + _run( + source.materialize_jsonl( + "trial-A", + tmp_path / "trial-A.jsonl", + progress_callback=_on_progress, + ) + ) + + # The fake reader records the callback it received — identity check. + assert reader.received_progress_callback is _on_progress + + +def test_progress_callback_dropped_for_non_sls_reader(tmp_path: Path) -> None: + """A bare TraceReader stand-in shouldn't crash on unexpected progress_callback. + + The isinstance guard skips the progress branch and calls + ``reader.get_spans(trial_id)`` without kwargs — so a reader whose + get_spans signature doesn't accept progress_callback still works. + """ + event = _make_init() + root = _root_span("trial-A") + # _FakeReader.get_spans has NO progress_callback parameter — this would + # raise TypeError if the guard were missing. 
+ reader = _FakeReader( + [root, _event_span(event, trace_id="trial-A", parent_span_id=root.span_id)] + ) + source = SLSEventSource(reader=reader) + + dest = tmp_path / "trial-A.jsonl" + _run( + source.materialize_jsonl( + "trial-A", + dest, + progress_callback=lambda fetched, total: None, + ) + ) + + # Materialization still produced output — progress_callback was silently + # ignored, not raised. + assert dest.exists() + assert len(dest.read_text().splitlines()) == 1 + + # --------------------------------------------------------------------------- # Missing env vars → clear error # --------------------------------------------------------------------------- diff --git a/packages/dojozero/tests/test_sls_trace_reader_progress.py b/packages/dojozero/tests/test_sls_trace_reader_progress.py new file mode 100644 index 00000000..9338fc90 --- /dev/null +++ b/packages/dojozero/tests/test_sls_trace_reader_progress.py @@ -0,0 +1,302 @@ +"""Tests for SLSTraceReader._get_count and progress_callback in get_spans. + +These paths were added to support async 202-polling of SLS materialization so +the dashboard server can report progress to the CLI client. Both the count +query and the per-page progress callback are new surface area. +""" + +from __future__ import annotations + +import asyncio +from datetime import datetime, timezone +from typing import Any +from unittest.mock import AsyncMock + +import httpx +import pytest + +from dojozero.core._credentials import Credentials +from dojozero.core._tracing import SLSTraceReader + + +def _run(coro): + return asyncio.run(coro) + + +class _DummyCredentialProvider: + """Returns fixed, non-empty credentials without any network lookup. + + The real CredentialProvider delegates to the alibabacloud SDK, which + probes ECS metadata at 100.100.100.200 when no local credentials are + configured — fine locally, but in CI this hangs for ~60s then fails. 
+ """ + + def get_credentials(self) -> Credentials: + return Credentials(access_key_id="test-ak", access_key_secret="test-secret") + + +@pytest.fixture(autouse=True) +def _stub_credentials(monkeypatch: pytest.MonkeyPatch) -> None: + """Stub credential lookup for every test in this module. + + SLSTraceReader.__init__ calls ``get_credential_provider().get_credentials()``. + Without this fixture, that triggers the alibabacloud SDK's full credential + chain — including an ECS metadata probe that times out in CI. + """ + from dojozero.core import _credentials as cred_mod + + monkeypatch.setattr(cred_mod, "_provider", None) + monkeypatch.setattr( + cred_mod, "get_credential_provider", lambda: _DummyCredentialProvider() + ) + + +def _make_reader(page_size: int = 100) -> SLSTraceReader: + """Construct a reader with a stubbed credential provider. + + The ``_stub_credentials`` autouse fixture must be active — otherwise + SLSTraceReader.__init__ will probe ECS metadata on hosts without local + Alibaba Cloud credentials configured. 
+ """ + return SLSTraceReader( + endpoint="fake.log.aliyuncs.com", + project="fake-project", + logstore="fake-store", + page_size=page_size, + ) + + +class _FakeResponse: + """Minimal httpx.Response stand-in sufficient for SLSTraceReader paths.""" + + def __init__( + self, + *, + status_code: int = 200, + json_body: Any = None, + text: str = "", + headers: dict[str, str] | None = None, + ) -> None: + self.status_code = status_code + self._json = json_body + self.text = text + self.headers = headers or {"x-log-progress": "Complete"} + + def json(self) -> Any: + return self._json + + def raise_for_status(self) -> None: + if self.status_code >= 400: + req = httpx.Request("GET", "http://fake") + resp = httpx.Response(self.status_code, request=req) + raise httpx.HTTPStatusError( + f"HTTP {self.status_code}", request=req, response=resp + ) + + +def _from_time() -> datetime: + return datetime(2026, 4, 1, tzinfo=timezone.utc) + + +def _to_time() -> datetime: + return datetime(2026, 4, 2, tzinfo=timezone.utc) + + +# --------------------------------------------------------------------------- +# _get_count +# --------------------------------------------------------------------------- + + +def test_get_count_parses_list_response() -> None: + reader = _make_reader() + reader._client.get = AsyncMock( # type: ignore[method-assign] + return_value=_FakeResponse(json_body=[{"cnt": "1234"}]) + ) + + count = _run( + reader._get_count( + resource="/logstores/fake", + url="http://fake", + query='_service:"dojozero" AND _trace_id:"trial-A"', + from_time=_from_time(), + to_time=_to_time(), + ) + ) + + assert count == 1234 + # Query is extended with a SQL count aggregation. 
+ call_kwargs = reader._client.get.call_args.kwargs + assert call_kwargs["params"]["query"].endswith("| select count(1) as cnt") + + +def test_get_count_parses_dict_response() -> None: + reader = _make_reader() + reader._client.get = AsyncMock( # type: ignore[method-assign] + return_value=_FakeResponse(json_body={"data": [{"cnt": 42}]}) + ) + + count = _run( + reader._get_count( + resource="/logstores/fake", + url="http://fake", + query="q", + from_time=_from_time(), + to_time=_to_time(), + ) + ) + + assert count == 42 + + +def test_get_count_returns_zero_on_empty_rows() -> None: + reader = _make_reader() + reader._client.get = AsyncMock( # type: ignore[method-assign] + return_value=_FakeResponse(json_body=[]) + ) + + count = _run( + reader._get_count( + resource="/logstores/fake", + url="http://fake", + query="q", + from_time=_from_time(), + to_time=_to_time(), + ) + ) + + assert count == 0 + + +def test_get_count_returns_zero_on_http_error() -> None: + reader = _make_reader() + reader._client.get = AsyncMock( # type: ignore[method-assign] + return_value=_FakeResponse(status_code=500, text="boom") + ) + + count = _run( + reader._get_count( + resource="/logstores/fake", + url="http://fake", + query="q", + from_time=_from_time(), + to_time=_to_time(), + ) + ) + + # Failure is non-fatal — returns 0 so materialization continues. + assert count == 0 + + +def test_get_count_returns_zero_on_transport_error() -> None: + reader = _make_reader() + reader._client.get = AsyncMock( # type: ignore[method-assign] + side_effect=httpx.ConnectError("boom"), + ) + + count = _run( + reader._get_count( + resource="/logstores/fake", + url="http://fake", + query="q", + from_time=_from_time(), + to_time=_to_time(), + ) + ) + + assert count == 0 + + +# --------------------------------------------------------------------------- +# get_spans(progress_callback=...) 
+# ---------------------------------------------------------------------------
+
+
+def _row(span_id: str) -> dict[str, Any]:
+    """SLS row shape — only the fields downstream code reads matter here.
+
+    We want rows to be *counted* in `all_rows`; whether they later convert
+    into SpanData doesn't affect progress-callback firing, which happens
+    on pagination boundaries before conversion.
+    """
+    return {
+        "_trace_id": "trial-A",
+        "_span_id": span_id,
+        "_operation_name": "event.test",
+        "_start_time": "1700000000000000",
+        "_duration": "0",
+    }
+
+
+def _scripted_get(responses: list[_FakeResponse]) -> AsyncMock:
+    """AsyncMock that yields responses in order, one per call.
+
+    Tests are expected to exhaust exactly the number of responses provided —
+    extra calls error out (StopIteration → RuntimeError, PEP 479) and fail.
+    """
+    return AsyncMock(side_effect=list(responses))
+
+
+def test_get_spans_calls_progress_callback_per_page() -> None:
+    """Each pagination response triggers a progress call with (fetched, total)."""
+    reader = _make_reader(page_size=2)
+
+    # Count query answers 5 total, then three pages of rows: 2, 2, 1.
+    reader._client.get = _scripted_get(  # type: ignore[method-assign]
+        [
+            _FakeResponse(json_body=[{"cnt": "5"}]),  # count query
+            _FakeResponse(json_body=[_row("a"), _row("b")]),
+            _FakeResponse(json_body=[_row("c"), _row("d")]),
+            _FakeResponse(json_body=[_row("e")]),
+        ]
+    )
+
+    progress: list[tuple[int, int]] = []
+    _run(
+        reader.get_spans(
+            "trial-A",
+            progress_callback=lambda fetched, total: progress.append((fetched, total)),
+        )
+    )
+
+    # Cumulative fetched count after each page; total comes from the count query.
+ assert progress == [(2, 5), (4, 5), (5, 5)] + + +def test_get_spans_skips_count_when_no_progress_callback() -> None: + reader = _make_reader(page_size=10) + mock_get = _scripted_get( + [_FakeResponse(json_body=[_row("a")])] # single short page + ) + reader._client.get = mock_get # type: ignore[method-assign] + + _run(reader.get_spans("trial-A")) + + # Without progress_callback, we issue one page fetch and no count query. + assert mock_get.call_count == 1 + call = mock_get.call_args + assert "select count" not in call.kwargs["params"]["query"] + + +def test_get_spans_count_query_runs_once_before_pagination() -> None: + reader = _make_reader(page_size=10) + mock_get = _scripted_get( + [ + _FakeResponse(json_body=[{"cnt": "1"}]), + _FakeResponse(json_body=[_row("a")]), + ] + ) + reader._client.get = mock_get # type: ignore[method-assign] + + progress: list[tuple[int, int]] = [] + _run( + reader.get_spans( + "trial-A", + progress_callback=lambda f, t: progress.append((f, t)), + ) + ) + + # 1 count query + 1 page. + assert mock_get.call_count == 2 + first_call = mock_get.call_args_list[0] + assert "select count" in first_call.kwargs["params"]["query"] + assert progress == [(1, 1)]