Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 45 additions & 27 deletions ade_bench/utils/logger.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,36 @@
import logging
import sys
import time
import threading
from datetime import datetime
from typing import Optional, Dict
from rich.console import Console
from rich.console import Console, ConsoleOptions, RenderResult
from rich.table import Table
from rich.live import Live

from ade_bench.config import config


class _TableRenderable:
"""Renderable that builds a fresh table on each Rich refresh cycle."""

def __init__(self, logger: "RichTaskLogger"):
self._logger = logger

def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult:
yield self._logger._create_table()


class RichTaskLogger:
"""Rich-based logger that shows one row per task with live updates."""

def __init__(self):
self._lock = threading.Lock()
self._console = Console()
self._table = None
self._live = None
self._task_data: Dict[str, Dict[str, str]] = {} # task_id -> {stage, message, timestamp}
self._initialized = False
self._last_refresh_time = 0.0
self._original_console_handlers = [] # Store original console handlers

def initialize_tasks(self, task_ids: list[str]) -> None:
Expand All @@ -43,14 +54,15 @@ def initialize_tasks(self, task_ids: list[str]) -> None:
"timestamp": ""
}

# Create initial table
self._table = self._create_table()

# Start live display
self._live = Live(self._table, console=self._console, refresh_per_second=4)
# Start live display with manual refresh to avoid flicker from
# unnecessary redraws when no data has changed.
self._live = Live(_TableRenderable(self), console=self._console, auto_refresh=False)
self._live.start()
self._initialized = True

# Initial render outside the lock (refresh -> _create_table -> acquires lock)
self._live.refresh()

def _disable_console_handlers(self) -> None:
"""Disable console handlers to prevent double printing."""
# Get all existing loggers
Expand All @@ -64,8 +76,11 @@ def _disable_console_handlers(self) -> None:
self._original_console_handlers.append((logger_obj, handler))
logger_obj.removeHandler(handler)

_MIN_REFRESH_INTERVAL = 0.2 # seconds between redraws

def update_task_from_dict(self, log_data: dict) -> None:
"""Update a specific task's row from log data dictionary."""
should_refresh = False
with self._lock:
task_id = log_data["task"]

Expand All @@ -88,12 +103,21 @@ def update_task_from_dict(self, log_data: dict) -> None:
"timestamp": log_data["formatted_timestamp"]
}

# Rebuild table with updated data
self._rebuild_table()
now = time.monotonic()
if now - self._last_refresh_time >= self._MIN_REFRESH_INTERVAL:
self._last_refresh_time = now
should_refresh = True

if should_refresh and self._live:
self._live.refresh()

def _create_table(self) -> Table:
"""Create a new table with current task data."""
with self._lock:
# Shallow copy is safe: update_task_from_dict replaces inner dicts
# atomically rather than mutating them in place.
task_data_snapshot = dict(self._task_data)

# Create table with fixed column widths
table = Table(show_header=True, header_style="bold magenta")
table.add_column("Time", style="dim", width=8, no_wrap=True)
Expand All @@ -102,7 +126,7 @@ def _create_table(self) -> Table:
table.add_column("Message", style="white", width=100, no_wrap=True)

# Add rows for each task (excluding SUMMARY)
for task_id, data in self._task_data.items():
for task_id, data in task_data_snapshot.items():
if task_id != "SUMMARY": # Skip SUMMARY - it gets added separately
table.add_row(
data["timestamp"],
Expand All @@ -115,8 +139,8 @@ def _create_table(self) -> Table:
table.add_row("─" * 8, "─" * 32, "─" * 12, "─" * 100)

# Add summary row (if it exists in task data)
if "SUMMARY" in self._task_data:
summary_data = self._task_data["SUMMARY"]
if "SUMMARY" in task_data_snapshot:
summary_data = task_data_snapshot["SUMMARY"]
table.add_row(
summary_data["timestamp"],
"SUMMARY",
Expand All @@ -126,23 +150,21 @@ def _create_table(self) -> Table:

return table

def _rebuild_table(self) -> None:
"""Rebuild the table with current task data."""
if not self._live:
return

# Create new table and update live display
new_table = self._create_table()
self._live.update(new_table)


def stop(self) -> None:
"""Stop the live display and re-enable console handlers."""
live = None
with self._lock:
if self._live:
self._live.stop()
live = self._live
self._initialized = False

# Final refresh + stop outside the lock to avoid deadlock
if live:
live.refresh()
live.stop()

with self._lock:
self._live = None
# Re-enable console handlers
for logger_obj, handler in self._original_console_handlers:
logger_obj.addHandler(handler)
Expand Down Expand Up @@ -277,9 +299,5 @@ def initialize_dynamic_logging(task_ids: list[str]) -> None:
# Initialize the Rich table with tasks
rich_logger.initialize_tasks(task_ids)

# Give a moment for the table to initialize
import time
time.sleep(0.5)


logger = setup_logger(__name__)