Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions src/dcv_benchmark/core/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dcv_benchmark.models.responses import TargetResponse
from dcv_benchmark.models.traces import TraceItem
from dcv_benchmark.utils.logger import (
ExperimentProgressLogger,
get_logger,
print_dataset_header,
print_experiment_header,
Expand Down Expand Up @@ -73,20 +74,18 @@ def run(
if limit:
total_samples = min(total_samples, limit)

log_interval = max(1, total_samples // 10)
# Initialize Progress Logger
progress_logger = ExperimentProgressLogger(total_samples)
progress_logger.start()

with open(traces_path, "w", encoding="utf-8") as f:
for sample in dataset.samples:
if limit and count >= limit:
logger.info(f"Limit of {limit} reached.")
break

if (count + 1) % log_interval == 0 or (count + 1) == total_samples:
pct = ((count + 1) / total_samples) * 100
logger.info(
f"Progress: {count + 1}/{total_samples} "
f"({pct:.0f}%) samples processed."
)
# Update Progress
progress_logger.log_progress(count, success_count)

logger.debug(
f"Processing Sample {count + 1}/{total_samples} "
Expand Down
72 changes: 72 additions & 0 deletions src/dcv_benchmark/utils/logger.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import logging
import sys
from typing import Any
Expand Down Expand Up @@ -184,3 +185,74 @@ def print_run_summary(metrics: Any, duration: float, artifacts_path: str) -> Non
logger.info("=" * 90)
logger.info(f"Artifacts: {artifacts_path}")
logger.info("=" * 90)


class ExperimentProgressLogger:
    """
    Handles logging of experiment progress, including start messages,
    step updates, and ETA calculations.

    A progress line is emitted whenever the reported count is a multiple of
    the logging interval (10% of ``total_samples``, at least 1) or equals
    ``total_samples``.
    """

    def __init__(self, total_samples: int):
        """
        Args:
            total_samples: Number of samples the experiment is expected to
                process.
        """
        self.total_samples: int = total_samples
        # Set by start(); lazily initialised by log_progress() if still None.
        self.start_time: datetime.datetime | None = None
        self.logger: logging.Logger = get_logger(__name__)
        # interval for logging progress (10%); clamped to 1 so the modulo in
        # log_progress() never divides by zero for small sample counts.
        self.log_interval = max(1, self.total_samples // 10)

    def start(self) -> None:
        """
        Records the start time and logs the start of the experiment.
        """

        self.start_time = datetime.datetime.now()
        self.logger.info(
            f"🚀 [STARTED] Experiment started with {self.total_samples} samples."
        )

    def log_progress(self, current_count: int, success_count: int) -> None:
        """
        Logs progress if the current count hits the 10% interval or is the last
        sample. Appends an ETA when enough time has elapsed to estimate one.

        Args:
            current_count: Number of samples processed so far.
            success_count: Number of those samples counted as successful.
        """

        # Check if we should log (10% interval or last sample)
        hits_interval = current_count % self.log_interval == 0
        is_last_sample = current_count == self.total_samples
        if not (hits_interval or is_last_sample):
            return

        if self.start_time is None:
            self.start_time = datetime.datetime.now()

        # An empty experiment has no meaningful percentage; report 0 instead
        # of dividing by zero.
        if self.total_samples != 0:
            pct = (current_count / self.total_samples) * 100
        else:
            pct = 0.0
        elapsed = datetime.datetime.now() - self.start_time

        # success rate calculation
        if current_count > 0:
            success_rate = (success_count / current_count) * 100
        else:
            success_rate = 0.0

        msg = (
            f"🔄 [RUNNING] Progress: {current_count}/{self.total_samples} "
            f"({pct:.0f}%) | Success Rate: {success_rate:.1f}%"
        )

        # ETA Calculation
        # Only show ETA once the run has taken more than a few seconds (avoids
        # noise on super fast runs) and at least one sample has been processed:
        # with zero processed samples there is no per-sample time to
        # extrapolate from, and the division below would fail.
        seconds_elapsed = elapsed.total_seconds()
        if seconds_elapsed > 5 and 0 < current_count < self.total_samples:
            avg_time_per_sample = seconds_elapsed / current_count
            remaining_samples = self.total_samples - current_count
            eta_seconds = remaining_samples * avg_time_per_sample

            # Format ETA
            if eta_seconds < 60:
                eta_str = "< 1 min"
            else:
                eta_min = int(eta_seconds // 60)
                eta_str = f"~{eta_min} min"

            msg += f" | ETA: {eta_str}"

        self.logger.info(msg)
99 changes: 99 additions & 0 deletions tests/unit/utils/test_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from unittest.mock import MagicMock, patch

import pytest

from dcv_benchmark.utils.logger import ExperimentProgressLogger


@pytest.fixture
def mock_logger():
    """Patch ``get_logger`` in the logger module and yield the mocked logger."""
    get_logger_patch = patch("dcv_benchmark.utils.logger.get_logger")
    with get_logger_patch as patched_get_logger:
        # Whatever the class obtains from get_logger() is this return value.
        yield patched_get_logger.return_value


def test_start_log(mock_logger):
    """start() emits the start banner containing the total sample count."""
    expected_banner = "🚀 [STARTED] Experiment started with 100 samples."

    tracker = ExperimentProgressLogger(total_samples=100)
    tracker.start()

    # The banner must appear among the info() calls made on the logger.
    mock_logger.info.assert_any_call(expected_banner)


def test_log_progress_logic(mock_logger):
    """Progress is logged only on interval boundaries (every 10 of 100 samples)."""

    def logged_messages():
        # First positional argument of every info() call made so far.
        return [entry[0][0] for entry in mock_logger.info.call_args_list]

    tracker = ExperimentProgressLogger(total_samples=100)
    tracker.start()

    # Count 5 is off-interval: nothing but the start message may be present.
    tracker.log_progress(current_count=5, success_count=4)
    assert all("Progress:" not in text for text in logged_messages())

    # Count 10 is on-interval: a progress line with the success rate appears.
    tracker.log_progress(current_count=10, success_count=8)
    first_hit = next(
        (text for text in logged_messages() if "Progress: 10/100" in text),
        None,
    )
    if first_hit is not None:
        assert "Success Rate: 80.0%" in first_hit
    assert first_hit is not None, "Did not find expected progress log at 10%"


@patch("dcv_benchmark.utils.logger.datetime")
def test_eta_calculation(mock_datetime_module, mock_logger):
    """With 10/100 samples done after 60 s, the logged ETA is about 9 minutes."""
    # Two successive now() results: the start timestamp, then a later one.
    begin_stamp = MagicMock()
    later_stamp = MagicMock()
    mock_datetime_module.datetime.now.side_effect = [begin_stamp, later_stamp]

    # (later - begin) yields a fake timedelta reporting 60 seconds elapsed.
    fake_elapsed = MagicMock()
    fake_elapsed.total_seconds.return_value = 60.0
    later_stamp.__sub__.return_value = fake_elapsed

    tracker = ExperimentProgressLogger(total_samples=100)
    tracker.start()  # consumes the first now() -> begin_stamp

    # log_progress consumes the second now() -> later_stamp.
    # 60 s / 10 samples = 6 s per sample; 90 remaining -> 540 s = 9 min.
    tracker.log_progress(current_count=10, success_count=10)

    eta_line = next(
        (
            entry[0][0]
            for entry in mock_logger.info.call_args_list
            if "ETA:" in entry[0][0]
        ),
        None,
    )
    if eta_line is not None:
        assert "~9 min" in eta_line

    assert eta_line is not None, "ETA was not logged or incorrect"


def test_last_sample_always_logged(mock_logger):
    """The final sample is logged even when it is not on an interval boundary."""
    tracker = ExperimentProgressLogger(total_samples=23)  # Interval 2
    tracker.start()

    # 23 is not a multiple of 2, so only the "last sample" rule can trigger
    # this log line.
    tracker.log_progress(current_count=23, success_count=23)

    completion_lines = [
        entry[0][0]
        for entry in mock_logger.info.call_args_list
        if "Progress: 23/23" in entry[0][0]
    ]
    for text in completion_lines:
        assert "100%" in text

    assert completion_lines