diff --git a/src/dcv_benchmark/core/runner.py b/src/dcv_benchmark/core/runner.py index dded9c3..1df66a1 100644 --- a/src/dcv_benchmark/core/runner.py +++ b/src/dcv_benchmark/core/runner.py @@ -13,6 +13,7 @@ from dcv_benchmark.models.responses import TargetResponse from dcv_benchmark.models.traces import TraceItem from dcv_benchmark.utils.logger import ( + ExperimentProgressLogger, get_logger, print_dataset_header, print_experiment_header, @@ -73,7 +74,9 @@ def run( if limit: total_samples = min(total_samples, limit) - log_interval = max(1, total_samples // 10) + # Initialize Progress Logger + progress_logger = ExperimentProgressLogger(total_samples) + progress_logger.start() with open(traces_path, "w", encoding="utf-8") as f: for sample in dataset.samples: @@ -81,12 +84,8 @@ def run( logger.info(f"Limit of {limit} reached.") break - if (count + 1) % log_interval == 0 or (count + 1) == total_samples: - pct = ((count + 1) / total_samples) * 100 - logger.info( - f"Progress: {count + 1}/{total_samples} " - f"({pct:.0f}%) samples processed." - ) + # Update Progress + progress_logger.log_progress(count, success_count) logger.debug( f"Processing Sample {count + 1}/{total_samples} " diff --git a/src/dcv_benchmark/utils/logger.py b/src/dcv_benchmark/utils/logger.py index 0b2ce71..79acff1 100644 --- a/src/dcv_benchmark/utils/logger.py +++ b/src/dcv_benchmark/utils/logger.py @@ -1,3 +1,4 @@ +import datetime import logging import sys from typing import Any @@ -184,3 +185,74 @@ def print_run_summary(metrics: Any, duration: float, artifacts_path: str) -> Non logger.info("=" * 90) logger.info(f"Artifacts: {artifacts_path}") logger.info("=" * 90) + + +class ExperimentProgressLogger: + """ + Handles logging of experiment progress, including start messages, + step updates, and ETA calculations. + """ + + def __init__(self, total_samples: int): + self.total_samples: int = total_samples + self.start_time: datetime.datetime | None = None + self.logger: logging.Logger = get_logger(__name__) + # interval for logging progress (10%) + self.log_interval = max(1, self.total_samples // 10) + + def start(self) -> None: + """ + Logs the start of the experiment. + """ + + self.start_time = datetime.datetime.now() + self.logger.info( + f"🚀 [STARTED] Experiment started with {self.total_samples} samples." + ) + + def log_progress(self, current_count: int, success_count: int) -> None: + """ + Logs progress if the current count hits the 10% interval or is the last sample. + Calculates ETA if the elapsed time is sufficient. + """ + + # Check if we should log (10% interval or last sample) + if (current_count) % self.log_interval == 0 or ( + current_count + ) == self.total_samples: + if self.start_time is None: + self.start_time = datetime.datetime.now() + + pct = (current_count / self.total_samples) * 100 + elapsed = datetime.datetime.now() - self.start_time + + # success rate calculation + if current_count > 0: + success_rate = (success_count / current_count) * 100 + else: + success_rate = 0.0 + + msg = ( + f"🔄 [RUNNING] Progress: {current_count}/{self.total_samples} " + f"({pct:.0f}%) | Success Rate: {success_rate:.1f}%" + ) + + # ETA Calculation + # Only show ETA if we are past the first interval and it's taking some time + # This avoids ETA on super fast runs + seconds_elapsed = elapsed.total_seconds() + if seconds_elapsed > 5 and current_count < self.total_samples: + avg_time_per_sample = seconds_elapsed / current_count + remaining_samples = self.total_samples - current_count + eta_seconds = remaining_samples * avg_time_per_sample + + # Format ETA + if eta_seconds < 60: + eta_str = "< 1 min" + else: + eta_min = int(eta_seconds // 60) + eta_str = f"~{eta_min} min" + + msg += f" | ETA: {eta_str}" + + self.logger.info(msg) diff --git a/tests/unit/utils/test_logger.py b/tests/unit/utils/test_logger.py new file mode 100644 index 0000000..099663e --- /dev/null +++ b/tests/unit/utils/test_logger.py @@ -0,0 +1,99 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from dcv_benchmark.utils.logger import ExperimentProgressLogger + + +@pytest.fixture +def mock_logger(): + with patch("dcv_benchmark.utils.logger.get_logger") as mock_get_logger: + yield mock_get_logger.return_value + + +def test_start_log(mock_logger): + progress_logger = ExperimentProgressLogger(total_samples=100) + progress_logger.start() + + # helper to check if start message was logged + mock_logger.info.assert_any_call( + "🚀 [STARTED] Experiment started with 100 samples." + ) + + +def test_log_progress_logic(mock_logger): + # Total samples 100 -> Interval 10 + progress_logger = ExperimentProgressLogger(total_samples=100) + progress_logger.start() + + # Call with count 5 (should NOT log) + progress_logger.log_progress(current_count=5, success_count=4) + # Verify info was NOT called with progress message + # We need to be careful not to match the start message + # Check that no calls matching "Progress:" were made + for call in mock_logger.info.call_args_list: + assert "Progress:" not in call[0][0] + + # Call with count 10 (should log) + progress_logger.log_progress(current_count=10, success_count=8) + + found_progress = False + for call in mock_logger.info.call_args_list: + if "Progress: 10/100" in call[0][0]: + found_progress = True + assert "Success Rate: 80.0%" in call[0][0] + break + assert found_progress, "Did not find expected progress log at 10%" + + +@patch("dcv_benchmark.utils.logger.datetime") +def test_eta_calculation(mock_datetime_module, mock_logger): + # Setup start time + start_time = MagicMock() + now_time_2 = MagicMock() + + # We need now() to return start_time first, then a later time + mock_datetime_module.datetime.now.side_effect = [start_time, now_time_2] + + # Setup elapsed time + # elapsed = now - start + mock_timedelta = MagicMock() + mock_timedelta.total_seconds.return_value = 60.0 # 60 seconds elapsed + + # When (now - start) is called, return mock_timedelta + now_time_2.__sub__.return_value = mock_timedelta + + progress_logger = ExperimentProgressLogger(total_samples=100) + progress_logger.start() # This calls now() -> start_time + + # Next call to log_progress calls now() -> now_time_2 + # Current count 10. Elapsed 60s. + # Avg time = 6s/sample. Remaining 90. ETA = 540s = 9 mins. + progress_logger.log_progress(current_count=10, success_count=10) + + found_eta = False + for call in mock_logger.info.call_args_list: + msg = call[0][0] + if "ETA:" in msg: + found_eta = True + assert "~9 min" in msg + break + + assert found_eta, "ETA was not logged or incorrect" + + +def test_last_sample_always_logged(mock_logger): + progress_logger = ExperimentProgressLogger(total_samples=23) # Interval 2 + progress_logger.start() + + # 23 % 2 != 0, so it wouldn't log by interval arithmetic usually (1..23) + # But it IS the last sample, so it MUST log. + progress_logger.log_progress(current_count=23, success_count=23) + + found_completion = False + for call in mock_logger.info.call_args_list: + if "Progress: 23/23" in call[0][0]: + found_completion = True + assert "100%" in call[0][0] + + assert found_completion