Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 47 additions & 5 deletions packages/dojozero/src/dojozero/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1547,16 +1547,58 @@ async def _resolve_event_files(
if not cache_path.exists():
LOGGER.info("Downloading events from %s", url)
try:
async with httpx.AsyncClient(timeout=120.0) as client:
resp = await client.get(url)
resp.raise_for_status()
cache_path.parent.mkdir(parents=True, exist_ok=True)
cache_path.write_bytes(resp.content)
async with httpx.AsyncClient(timeout=60.0) as client:
delay = 3.0
last_spans = -1
stall_count = 0
max_stalls = 20 # give up after ~5 min of no progress
attempt = 0
while True:
resp = await client.get(url)
if resp.status_code == 200:
cache_path.parent.mkdir(parents=True, exist_ok=True)
cache_path.write_bytes(resp.content)
break
elif resp.status_code == 202:
attempt += 1
body = resp.json()
spans = body.get("spans_fetched", 0)
total = body.get("spans_total", 0)
elapsed = body.get("elapsed_seconds", "?")
pct = f" {100 * spans // total}%" if total > 0 else ""
LOGGER.info(
"Server is materializing events "
"(%s/%s spans%s, %ss elapsed). "
"Retrying in %.0fs... [attempt %d]",
spans,
total or "?",
pct,
elapsed,
delay,
attempt,
)
if spans > last_spans:
last_spans = spans
stall_count = 0
else:
stall_count += 1
if stall_count >= max_stalls:
raise DojoZeroCLIError(
f"Server materialization stalled for "
f"{url} (no progress after "
f"{max_stalls} retries)"
)
await asyncio.sleep(delay)
delay = min(delay * 1.5, 15.0)
else:
resp.raise_for_status()
except httpx.HTTPStatusError as e:
raise DojoZeroCLIError(
f"Failed to download events from {url}: "
f"HTTP {e.response.status_code}"
) from e
except DojoZeroCLIError:
raise
except Exception as e:
raise DojoZeroCLIError(
f"Failed to download events from {url}: {e}"
Expand Down
45 changes: 45 additions & 0 deletions packages/dojozero/src/dojozero/core/_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from collections.abc import Callable
from typing import Any, Protocol
from uuid import uuid4
import asyncio
Expand Down Expand Up @@ -764,11 +765,45 @@ async def list_trials(
LOGGER.error("Failed to parse SLS response for trials: %s", e)
return []

async def _get_count(
self,
resource: str,
url: str,
query: str,
from_time: datetime,
to_time: datetime,
) -> int:
"""Get total log count via SLS SQL count query.

Returns 0 if the query fails (non-fatal).
"""
count_query = f"{query} | select count(1) as cnt"
params = {
"type": "log",
"from": str(int(from_time.timestamp())),
"to": str(int(to_time.timestamp())),
"query": count_query,
"line": "1",
"offset": "0",
}
headers = self._sign_request("GET", resource, params)
try:
response = await self._client.get(url, params=params, headers=headers)
response.raise_for_status()
data = response.json()
rows = data if isinstance(data, list) else data.get("data", [])
if rows and isinstance(rows[0], dict):
return int(rows[0].get("cnt", 0))
except Exception as e:
LOGGER.debug("SLS count query failed (non-fatal): %s", e)
return 0

async def get_spans(
self,
trial_id: str,
start_time: datetime | None = None,
operation_names: list[str] | None = None,
progress_callback: Callable[[int, int], None] | None = None,
) -> list[SpanData]:
"""Get spans for a trial from SLS.

Expand All @@ -777,6 +812,9 @@ async def get_spans(
start_time: If provided, only return spans after this time.
operation_names: If provided, only return spans with operation_name in this list.
Exact match with OR logic. None means no filtering.
progress_callback: If provided, called with ``(fetched, total)``
after each page. ``total`` is obtained via a count query
before pagination begins.

Returns:
List of SpanData sorted by start time.
Expand Down Expand Up @@ -805,6 +843,11 @@ async def get_spans(
resource = f"/logstores/{self._logstore}"
url = f"{self._get_base_url()}{resource}"

# If a progress callback is provided, get total count first via SQL
total_count = 0
if progress_callback is not None:
total_count = await self._get_count(resource, url, query, from_time, now)

# Pagination: SLS GetLogs API limits to 100 rows per request in search mode
# We paginate using offset parameter to get all data
page_size = self._page_size
Expand Down Expand Up @@ -858,6 +901,8 @@ async def get_spans(
break # No more data

all_rows.extend(rows)
if progress_callback is not None:
progress_callback(len(all_rows), total_count)

if len(rows) < page_size:
break # Last page (less than full page means no more data)
Expand Down
143 changes: 128 additions & 15 deletions packages/dojozero/src/dojozero/dashboard_server/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import logging
import os
import platform
import time
from contextlib import asynccontextmanager
from dataclasses import asdict, dataclass, field
from pathlib import Path
Expand Down Expand Up @@ -143,6 +144,85 @@ class TrialSourceRequest(BaseModel):
config: TrialSourceConfigRequest


@dataclass
class _MaterializationEntry:
"""Tracks a single in-flight SLS materialization task."""

task: asyncio.Task[Any]
trial_id: str
run_id: str | None
started_at: float # time.monotonic()
spans_fetched: int = 0 # updated by progress callback
spans_total: int = 0 # total from SLS count query


class MaterializationTracker:
    """Track in-flight SLS materialization tasks.

    Keyed by ``(trial_id, run_id)`` so concurrent requests for the same
    trial share a single background task.
    """

    def __init__(self) -> None:
        # (trial_id, run_id) -> entry for the task currently in flight
        # (or most recently finished, until the caller calls remove()).
        self._tasks: dict[tuple[str, str | None], _MaterializationEntry] = {}

    def get(self, trial_id: str, run_id: str | None) -> _MaterializationEntry | None:
        """Return the tracked entry for ``(trial_id, run_id)``, if any."""
        return self._tasks.get((trial_id, run_id))

    def start(
        self,
        trial_id: str,
        run_id: str | None,
        cached_path: Path,
    ) -> _MaterializationEntry:
        """Start a background SLS materialization, or return an existing one.

        Must be called from a running event loop. If a task for the same
        key is still in flight it is returned unchanged; a finished (done
        or failed) entry is replaced by a fresh task.
        """
        key = (trial_id, run_id)
        existing = self._tasks.get(key)
        if existing is not None and not existing.task.done():
            return existing  # already in flight

        # `entry` is assigned only after the real task is created; the
        # closures below late-bind it. That is safe because create_task()
        # merely schedules the coroutine — it cannot run (and so cannot
        # invoke the callbacks) before this synchronous function returns.
        # This avoids the need for a throwaway placeholder task.
        def _on_progress(rows_fetched: int, total: int) -> None:
            entry.spans_fetched = rows_fetched
            entry.spans_total = total

        async def _do_materialize() -> Path:
            # Local import; presumably deferred to avoid an import cycle
            # at module load time — TODO confirm.
            from dojozero.data import SLSEventSource

            await SLSEventSource().materialize_jsonl(
                trial_id,
                cached_path,
                run_id=run_id or None,
                progress_callback=_on_progress,
            )
            # Upload to OSS so other hosts can serve the cached file.
            upload_trial_to_oss(trial_id, cached_path)
            return cached_path

        task = asyncio.create_task(_do_materialize(), name=f"materialize-{trial_id}")

        def _on_done(t: asyncio.Task[Any]) -> None:
            # Surface background failures in the log; the HTTP handler
            # also reports the exception to clients when it polls the
            # finished task.
            if not t.cancelled() and t.exception():
                LOGGER.error(
                    "Background materialization failed for '%s': %s",
                    trial_id,
                    t.exception(),
                )

        task.add_done_callback(_on_done)
        entry = _MaterializationEntry(
            task=task,
            trial_id=trial_id,
            run_id=run_id,
            started_at=time.monotonic(),
        )
        self._tasks[key] = entry
        return entry

    def remove(self, trial_id: str, run_id: str | None) -> None:
        """Forget the entry for ``(trial_id, run_id)``; no-op if absent."""
        self._tasks.pop((trial_id, run_id), None)


@dataclass
class DashboardServerState:
"""Shared state for the Dashboard Server."""
Expand All @@ -163,6 +243,9 @@ class DashboardServerState:
scheduler_store: SchedulerStore | None = None
http_session: Any = None # aiohttp.ClientSession for cross-server forwarding
http_client: Any = None # httpx.AsyncClient for cross-server forwarding
materialization_tracker: MaterializationTracker = field(
default_factory=MaterializationTracker
)


def get_server_state(request: Request) -> DashboardServerState:
Expand Down Expand Up @@ -1343,36 +1426,66 @@ async def download_trial_events(
except Exception:
pass # fall through to SLS fallback

# 5. SLS materialization (last resort)
# 5. SLS materialization (async, non-blocking via 202 polling)
if event_file is None:
try:
from dojozero.data import SLSEventSource
tracker = state.materialization_tracker
entry = tracker.get(trial_id, run_id)

await SLSEventSource().materialize_jsonl(
trial_id,
cached_path,
run_id=run_id or None,
if entry is None:
# No in-flight task — kick one off and return 202
entry = tracker.start(trial_id, run_id, cached_path)
return JSONResponse(
content={
"status": "materializing",
"trial_id": trial_id,
"spans_fetched": 0,
"spans_total": 0,
"message": (
"Materializing events from SLS. "
"Poll this endpoint to check progress."
),
},
status_code=202,
)
event_file = cached_path
# Upload to OSS for cross-host caching
upload_trial_to_oss(trial_id, cached_path)
except Exception as e:

if not entry.task.done():
# Still running — return 202 with progress
elapsed = time.monotonic() - entry.started_at
return JSONResponse(
content={
"status": "materializing",
"trial_id": trial_id,
"elapsed_seconds": round(elapsed, 1),
"spans_fetched": entry.spans_fetched,
"spans_total": entry.spans_total,
"message": "Materialization in progress.",
},
status_code=202,
)

# Task finished — check result
tracker.remove(trial_id, run_id)
exc = entry.task.exception() if not entry.task.cancelled() else None
if exc is not None:
LOGGER.error(
"SLS materialization failed for trial '%s': %s",
trial_id,
e,
exc_info=True,
exc,
exc_info=exc,
)
return JSONResponse(
content={
"status": "failed",
"error": (
f"Event file for trial '{trial_id}' not available "
f"locally and SLS fallback failed: {e}"
)
f"locally and SLS fallback failed: {exc}"
),
},
status_code=424,
)

event_file = cached_path

# Guard against serving empty files (e.g. from a previous failed
# materialization). Remove the empty file so future requests retry.
if event_file.stat().st_size == 0:
Expand Down
12 changes: 11 additions & 1 deletion packages/dojozero/src/dojozero/data/_sls_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
)

if TYPE_CHECKING:
from collections.abc import Callable

from dojozero.core._tracing import TraceReader
from dojozero.data._models import DataEvent

Expand Down Expand Up @@ -108,6 +110,7 @@ async def materialize_jsonl(
*,
run_id: str | None = None,
overwrite: bool = True,
progress_callback: "Callable[[int, int], None] | None" = None,
) -> MaterializeResult:
"""Write events for ``trial_id`` to ``dest`` as JSONL.

Expand Down Expand Up @@ -137,7 +140,14 @@ async def materialize_jsonl(
reader = self._reader or _make_reader()
owns_reader = self._reader is None
try:
spans = await reader.get_spans(trial_id)
if isinstance(reader, SLSTraceReader) and progress_callback is not None:
# SLSTraceReader extends the protocol with progress_callback
sls_reader: SLSTraceReader = reader
spans = await sls_reader.get_spans(
trial_id, progress_callback=progress_callback
)
else:
spans = await reader.get_spans(trial_id)
finally:
if owns_reader:
await reader.close()
Expand Down