diff --git a/biomni/agent/a1.py b/biomni/agent/a1.py
index b6c10d4c5..293830d3b 100644
--- a/biomni/agent/a1.py
+++ b/biomni/agent/a1.py
@@ -2,6 +2,8 @@
import inspect
import os
import re
+import sys
+import time
from collections.abc import Generator
from pathlib import Path
from typing import Any, Literal, TypedDict
@@ -51,6 +53,8 @@ def __init__(
base_url: str | None = None,
api_key: str | None = None,
expected_data_lake_files: list | None = None,
+ trace_tracking: bool = False,
+ trace_output_dir: str = "evaluation_results/reasoning_traces",
):
"""Initialize the biomni agent.
@@ -62,6 +66,9 @@ def __init__(
timeout_seconds: Timeout for code execution in seconds
base_url: Base URL for custom model serving (e.g., "http://localhost:8000/v1")
api_key: API key for the custom LLM
+            expected_data_lake_files: Filenames expected to be present in the data lake, used to verify data availability
+ trace_tracking: Whether to enable trace tracking for detailed reasoning reports
+ trace_output_dir: Directory to save trace reports
"""
# Use default_config values for unspecified parameters
@@ -171,6 +178,30 @@ def __init__(
# Add timeout parameter
self.timeout_seconds = timeout_seconds # 10 minutes default timeout
+
+ # Initialize trace tracking
+ self.trace_tracking = trace_tracking
+ if self.trace_tracking:
+ # Set matplotlib backend to avoid GUI issues
+ os.environ["MPLBACKEND"] = "Agg"
+
+ # Import ReasoningTraceReporter here to avoid circular imports
+ from biomni.evaluation.reasoning_trace_reporter import ReasoningTraceReporter
+
+ self.trace_reporter = ReasoningTraceReporter(trace_output_dir)
+ else:
+ self.trace_reporter = None
+
+ # Enhanced logging for trace analysis
+ self.enhanced_log = []
+ self.current_query = None
+ self.execution_start_time = None
+
+ # Terminal output capture
+ self.terminal_output_buffer = []
+ self.original_stdout = sys.stdout
+ self.original_stderr = sys.stderr
+
self.configure()
def add_tool(self, api):
@@ -1528,24 +1559,48 @@ def go(self, prompt):
Args:
prompt: The user's query
+ Returns:
+ Tuple of (log, final_message_content)
"""
- self.critic_count = 0
- self.user_task = prompt
+ # Start trace reporting
+ if self.trace_tracking:
+ self.trace_reporter.start_trace(prompt)
+ self.current_query = prompt
+ self.execution_start_time = time.time()
+ self.enhanced_log = []
- if self.use_tool_retriever:
- selected_resources_names = self._prepare_resources_for_retrieval(prompt)
- self.update_system_prompt_with_selected_resources(selected_resources_names)
+ # Add initial query to terminal output
+ self.trace_reporter.add_terminal_output(f"Query: {prompt}", "query")
- inputs = {"messages": [HumanMessage(content=prompt)], "next_step": None}
- config = {"recursion_limit": 500, "configurable": {"thread_id": 42}}
- self.log = []
+ # Start terminal capture
+ self._start_terminal_capture()
- for s in self.app.stream(inputs, stream_mode="values", config=config):
- message = s["messages"][-1]
- out = pretty_print(message)
- self.log.append(out)
+ try:
+ self.critic_count = 0
+ self.user_task = prompt
+
+ if self.use_tool_retriever:
+ selected_resources_names = self._prepare_resources_for_retrieval(prompt)
+ self.update_system_prompt_with_selected_resources(selected_resources_names)
+
+ inputs = {"messages": [HumanMessage(content=prompt)], "next_step": None}
+ config = {"recursion_limit": 500, "configurable": {"thread_id": 42}}
+ self.log = []
+
+ for s in self.app.stream(inputs, stream_mode="values", config=config):
+ message = s["messages"][-1]
+ out = pretty_print(message)
+ self.log.append(out)
+
+ # Process trace reporting
+ if self.trace_tracking:
+ self._process_trace_reporting(self.log, message.content)
- return self.log, message.content
+ return self.log, message.content
+
+ finally:
+ # Stop terminal capture
+ self._stop_terminal_capture()
def go_stream(self, prompt) -> Generator[dict, None, None]:
"""Execute the agent with the given prompt and return a generator that yields each step.
@@ -1559,6 +1614,13 @@ def go_stream(self, prompt) -> Generator[dict, None, None]:
Yields:
dict: Each step of the agent's execution containing the current message and state
"""
+ # Start trace reporting
+ if self.trace_tracking:
+ self.trace_reporter.start_trace(prompt)
+ self.current_query = prompt
+ self.execution_start_time = time.time()
+ self.enhanced_log = []
+
self.critic_count = 0
self.user_task = prompt
@@ -1575,9 +1637,17 @@ def go_stream(self, prompt) -> Generator[dict, None, None]:
out = pretty_print(message)
self.log.append(out)
+ # Add to enhanced log for trace analysis
+ if self.trace_tracking:
+ self.enhanced_log.append(out)
+
# Yield the current step
yield {"output": out}
+ # Process trace reporting after completion
+ if self.trace_tracking:
+ self._process_trace_reporting(self.enhanced_log, message.content)
+
def update_system_prompt_with_selected_resources(self, selected_resources):
"""Update the system prompt with the selected resources."""
# Extract tool descriptions for the selected tools
@@ -1726,6 +1796,87 @@ def _inject_custom_functions_to_repl(self):
builtins._biomni_custom_functions = {}
builtins._biomni_custom_functions.update(self._custom_functions)
+ # Add custom plot saving function if trace tracking is enabled
+ if self.trace_tracking and self.trace_reporter:
+ import builtins
+ from datetime import datetime
+
+ import matplotlib.pyplot as plt
+
+ from biomni.tool.support_tools import _persistent_namespace
+
+ # Store original savefig if not already stored
+ if not hasattr(plt, "_original_savefig"):
+ plt._original_savefig = plt.savefig
+
+ # Override plt.savefig to save to query folder
+ def custom_savefig(*args, **kwargs):
+ """
+ Override plt.savefig to save to query folder when trace tracking is enabled.
+ This ensures all plots are saved in the query-specific directory.
+ """
+ if self.trace_tracking and self.trace_reporter:
+ # Determine filename
+ if args and isinstance(args[0], str):
+ filename = args[0]
+ other_args = args[1:]
+ else:
+ # Generate filename based on number of existing plots
+ plot_count = len(self.trace_reporter.trace_data["generated_plots"]) + 1
+ filename = f"plot_{plot_count}.png"
+ other_args = args
+
+ # Ensure filename has .png extension
+ if not filename.endswith(".png"):
+ filename += ".png"
+
+ # Save to query folder
+ plot_path = self.trace_reporter.query_folder / filename
+ result = plt._original_savefig(plot_path, *other_args, **kwargs)
+
+ # Add to generated plots list for final report
+ plot_info = {
+ "name": filename.replace(".png", ""),
+ "path": str(plot_path),
+ "timestamp": datetime.now().isoformat(),
+ }
+ self.trace_reporter.trace_data["generated_plots"].append(plot_info)
+
+ print(f"Plot saved to query folder: {plot_path}")
+ return result
+ else:
+ # Fall back to original savefig
+ return plt._original_savefig(*args, **kwargs)
+
+ # Replace plt.savefig with our custom version
+ plt.savefig = custom_savefig
+
+ # Inject the modified plt into the namespace
+ _persistent_namespace["plt"] = plt
+
+ # Also provide a convenience function
+ def save_plot_to_query_folder(filename=None, **kwargs):
+ """
+ Convenience function to save plot to query folder.
+                Thin wrapper around plt.savefig(); when trace tracking is enabled the plot is saved into the query folder.
+
+ Args:
+ filename: Optional filename for the plot
+ **kwargs: Additional arguments to pass to plt.savefig()
+ """
+ if filename:
+ return plt.savefig(filename, **kwargs)
+ else:
+ return plt.savefig(**kwargs)
+
+ # Inject the convenience function
+ _persistent_namespace["save_plot_to_query_folder"] = save_plot_to_query_folder
+
+ # Also make it available in builtins
+ if not hasattr(builtins, "_biomni_custom_functions"):
+ builtins._biomni_custom_functions = {}
+ builtins._biomni_custom_functions["save_plot_to_query_folder"] = save_plot_to_query_folder
+
def create_mcp_server(self, tool_modules=None):
"""
Create an MCP server object that exposes internal Biomni tools.
@@ -1877,3 +2028,271 @@ def wrapper(**kwargs) -> dict:
wrapper.__signature__ = inspect.Signature(new_params, return_annotation=dict)
return wrapper
+
+ def _start_terminal_capture(self):
+ """Start capturing terminal output."""
+ if self.trace_tracking:
+ self.terminal_output_buffer = []
+
+ # Create a custom stdout that captures output
+ class CapturingStdout:
+ def __init__(self, original_stdout, buffer, reporter):
+ self.original_stdout = original_stdout
+ self.buffer = buffer
+ self.reporter = reporter
+
+ def write(self, text):
+ self.original_stdout.write(text)
+ self.buffer.append(text)
+ if self.reporter and hasattr(self.reporter, "add_terminal_output"):
+ self.reporter.add_terminal_output(text, "stdout")
+
+ def flush(self):
+ self.original_stdout.flush()
+
+ sys.stdout = CapturingStdout(self.original_stdout, self.terminal_output_buffer, self.trace_reporter)
+
+ def _stop_terminal_capture(self):
+ """Stop capturing terminal output."""
+ if self.trace_tracking:
+ sys.stdout = self.original_stdout
+ sys.stderr = self.original_stderr
+
+ def _process_trace_reporting(self, log: list[Any], final_content: str):
+ """
+ Process the execution log to generate trace report.
+
+ Args:
+ log: The execution log from the agent
+ final_content: The final content returned by the agent
+ """
+ if not self.trace_tracking or not self.trace_reporter:
+ return
+
+ # Parse the log to extract trace information
+ self.trace_reporter.parse_agent_log(log)
+
+ # Add performance metrics
+ if self.execution_start_time:
+ execution_time = time.time() - self.execution_start_time
+ self.trace_reporter.trace_data["performance_metrics"]["total_execution_time"] = execution_time
+
+ # End the trace
+ self.trace_reporter.end_trace(final_content)
+
+ def generate_trace_report(self, filename: str | None = None) -> str:
+ """
+ Generate an HTML trace report for the last execution.
+
+ Args:
+ filename: Optional filename for the report
+
+ Returns:
+ Path to the generated HTML file
+ """
+ if not self.trace_tracking or not self.trace_reporter:
+ raise RuntimeError("Trace tracking is not enabled")
+
+ return self.trace_reporter.generate_html_report(filename)
+
+ def generate_final_user_report(self, filename: str | None = None) -> str:
+ """
+ Generate a clean, final user report with plots and evidence.
+
+ Args:
+ filename: Optional filename for the report
+
+ Returns:
+ Path to the generated HTML file
+ """
+ if not self.trace_tracking or not self.trace_reporter:
+ raise RuntimeError("Trace tracking is not enabled")
+
+ return self.trace_reporter.generate_final_user_report(filename)
+
+ def capture_plot(self, plot_name: str = None) -> str:
+ """
+ Capture the current matplotlib plot and save it.
+
+ Args:
+ plot_name: Optional name for the plot file
+
+ Returns:
+ Path to the saved plot file
+ """
+ if not self.trace_tracking or not self.trace_reporter:
+ raise RuntimeError("Trace tracking is not enabled")
+
+ return self.trace_reporter.capture_plot(plot_name)
+
+ def set_final_result(self, result: str):
+ """
+ Set the final result for the query.
+
+ Args:
+ result: The final result text
+ """
+ if not self.trace_tracking or not self.trace_reporter:
+ raise RuntimeError("Trace tracking is not enabled")
+
+ self.trace_reporter.set_final_result(result)
+
+ def capture_current_plot(self, plot_name: str = None) -> str:
+ """
+ Capture the current plot and save it to the query folder.
+ This method can be called from within the agent's execution.
+
+ Args:
+ plot_name: Optional name for the plot file
+
+ Returns:
+ Path to the saved plot file
+ """
+ if not self.trace_tracking or not self.trace_reporter:
+ return None
+
+ try:
+ import matplotlib.pyplot as plt
+
+ # Check if there's a current figure
+ if not plt.get_fignums():
+ print("Warning: No active plot to capture")
+ return None
+
+ # Capture the current figure
+ if not plot_name:
+ plot_name = f"plot_{len(self.trace_reporter.trace_data['generated_plots']) + 1}"
+
+ return self.capture_plot(plot_name)
+
+ except Exception as e:
+ print(f"Warning: Could not capture current plot: {e}")
+ return None
+
+ def save_complete_terminal_output(self, filename: str | None = None) -> str:
+ """
+ Save the complete terminal output to a text file.
+
+ Args:
+ filename: Optional filename for the output file
+
+ Returns:
+ Path to the saved text file
+ """
+ if not self.trace_tracking or not self.trace_reporter:
+ raise RuntimeError("Trace tracking is not enabled")
+
+ return self.trace_reporter.save_complete_terminal_output(filename)
+
+ def get_trace_data(self) -> dict[str, Any]:
+ """
+ Get the current trace data for analysis.
+
+ Returns:
+ Dictionary containing trace data
+ """
+ if not self.trace_tracking or not self.trace_reporter:
+ return {}
+
+ return self.trace_reporter.trace_data
+
+ def add_custom_trace_step(self, step_type: str, content: Any, metadata: dict | None = None):
+ """
+ Add a custom step to the trace for additional analysis.
+
+ Args:
+ step_type: Type of the step
+ content: Content of the step
+ metadata: Additional metadata
+ """
+ if self.trace_tracking and self.trace_reporter:
+ self.trace_reporter.add_step(step_type, content, metadata)
+
+ def analyze_tool_usage_patterns(self) -> dict[str, Any]:
+ """
+ Analyze tool usage patterns from the trace data.
+
+ Returns:
+ Dictionary containing analysis results
+ """
+ if not self.trace_tracking or not self.trace_reporter:
+ return {}
+
+ trace_data = self.trace_reporter.trace_data
+ analysis = {
+ "total_tool_calls": len(trace_data.get("tool_calls", [])),
+ "tool_usage_frequency": {},
+ "code_execution_frequency": {},
+ "reasoning_steps_breakdown": {},
+ "performance_metrics": trace_data["performance_metrics"],
+ }
+
+ # Analyze tool usage frequency
+ for tool_call in trace_data.get("tool_calls", []):
+ tool_name = tool_call["tool_name"]
+ analysis["tool_usage_frequency"][tool_name] = analysis["tool_usage_frequency"].get(tool_name, 0) + 1
+
+ # Analyze code execution patterns
+ for code_exec in trace_data["code_executions"]:
+ code_type = "generated" if code_exec["is_generated"] else "pre_written"
+ analysis["code_execution_frequency"][code_type] = analysis["code_execution_frequency"].get(code_type, 0) + 1
+
+ # Analyze reasoning steps
+ for step in trace_data["steps"]:
+ step_type = step["type"]
+ analysis["reasoning_steps_breakdown"][step_type] = (
+ analysis["reasoning_steps_breakdown"].get(step_type, 0) + 1
+ )
+
+ # Count code generation steps separately
+ if step_type == "code_generation":
+ analysis["code_execution_frequency"]["generated"] = (
+ analysis["code_execution_frequency"].get("generated", 0) + 1
+ )
+
+ return analysis
+
+ def export_trace_data(self, format: str = "json", filename: str | None = None) -> str:
+ """
+ Export trace data in various formats for further analysis.
+
+ Args:
+ format: Export format ('json', 'csv', 'pickle')
+ filename: Optional filename for export
+
+ Returns:
+ Path to the exported file
+ """
+ if not self.trace_tracking or not self.trace_reporter:
+ raise RuntimeError("Trace tracking is not enabled")
+
+ import json
+ import pickle
+ from datetime import datetime
+
+ import pandas as pd
+
+ if not filename:
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ filename = f"trace_data_{timestamp}.{format}"
+
+ # Use query_folder if available, otherwise fall back to output_dir
+ if hasattr(self.trace_reporter, "query_folder") and self.trace_reporter.query_folder:
+ filepath = self.trace_reporter.query_folder / filename
+ else:
+ filepath = self.trace_reporter.output_dir / filename
+
+ if format == "json":
+ with open(filepath, "w", encoding="utf-8") as f:
+ json.dump(self.trace_reporter.trace_data, f, indent=2, default=str)
+ elif format == "csv":
+ # Convert steps to DataFrame
+ steps_df = pd.DataFrame(self.trace_reporter.trace_data["steps"])
+ steps_df.to_csv(filepath, index=False)
+ elif format == "pickle":
+ with open(filepath, "wb") as f:
+ pickle.dump(self.trace_reporter.trace_data, f)
+ else:
+ raise ValueError(f"Unsupported format: {format}")
+
+ return str(filepath)
diff --git a/biomni/evaluation/__init__.py b/biomni/evaluation/__init__.py
new file mode 100644
index 000000000..72d8728b0
--- /dev/null
+++ b/biomni/evaluation/__init__.py
@@ -0,0 +1,10 @@
+"""
+Evaluation module for Biomni
+
+This module provides tools and utilities for evaluating biomni's performance,
+reasoning trace analysis, and detailed reporting capabilities.
+"""
+
+from .reasoning_trace_reporter import ReasoningTraceReporter
+
+__all__ = ["ReasoningTraceReporter"]
diff --git a/biomni/evaluation/reasoning_trace_reporter.py b/biomni/evaluation/reasoning_trace_reporter.py
new file mode 100644
index 000000000..79a88fa78
--- /dev/null
+++ b/biomni/evaluation/reasoning_trace_reporter.py
@@ -0,0 +1,1306 @@
+"""
+Reasoning Trace Reporter for Biomni
+
+This module provides functionality to generate detailed HTML reports of biomni's reasoning trace,
+including all tool calls, code execution, and reasoning steps for evaluation purposes.
+"""
+
+import os
+import re
+from datetime import datetime
+from pathlib import Path
+from typing import Any
+
+from jinja2 import Template
+
+
+class ReasoningTraceReporter:
+ """
+ A class to generate detailed HTML reports of biomni's reasoning trace.
+
+ This reporter captures:
+ - User queries and system responses
+ - Tool calls and their parameters
+ - Code execution (both generated and called)
+ - Reasoning steps and thought processes
+ - Execution timing and performance metrics
+ """
+
+ def __init__(self, output_dir: str = "evaluation_results/reasoning_traces"):
+ """
+ Initialize the reasoning trace reporter.
+
+ Args:
+ output_dir: Directory to save HTML reports
+ """
+ self.output_dir = Path(output_dir)
+ self.output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Trace data structure
+ self.trace_data = {
+ "query": "",
+ "start_time": None,
+ "end_time": None,
+ "steps": [],
+ "code_executions": [],
+ "performance_metrics": {},
+ "complete_terminal_output": [], # Store complete terminal output
+ "generated_plots": [], # Store generated plots
+ "final_result": "", # Store final result
+ }
+
+ # HTML template for the report
+ self.html_template = self._get_html_template()
+
+ def start_trace(self, query: str):
+ """Start tracing a new query execution."""
+ # Create query-specific subfolder
+ query_slug = re.sub(r"[^a-zA-Z0-9]", "_", query[:50])
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ self.query_folder = self.output_dir / f"query_{query_slug}_{timestamp}"
+ self.query_folder.mkdir(parents=True, exist_ok=True)
+
+ self.trace_data = {
+ "query": query,
+ "start_time": datetime.now(),
+ "end_time": None,
+ "steps": [],
+ "code_executions": [],
+ "performance_metrics": {},
+ "complete_terminal_output": [], # Initialize terminal output storage
+ "generated_plots": [], # Store generated plots
+ "final_result": "", # Store final result
+ }
+
+ def add_step(self, step_type: str, content: Any, metadata: dict | None = None):
+ """
+ Add a step to the reasoning trace.
+
+ Args:
+ step_type: Type of step (e.g., 'thinking', 'tool_call', 'code_execution', 'observation')
+ content: Content of the step
+ metadata: Additional metadata about the step
+ """
+ step = {
+ "type": step_type,
+ "content": content,
+ "metadata": metadata or {},
+ "timestamp": datetime.now().isoformat(),
+ }
+ self.trace_data["steps"].append(step)
+
+ def add_code_execution(self, code: str, result: Any, execution_time: float = None, is_generated: bool = True):
+ """
+ Add a code execution to the trace.
+
+ Args:
+ code: Code that was executed
+ result: Result of the execution
+ execution_time: Time taken for execution
+ is_generated: Whether the code was generated on-the-fly
+ """
+ code_execution = {
+ "code": code,
+ "result": result,
+ "execution_time": execution_time,
+ "is_generated": is_generated,
+ "timestamp": datetime.now().isoformat(),
+ }
+ self.trace_data["code_executions"].append(code_execution)
+ self.trace_data["steps"].append(
+ {
+ "type": "code_execution",
+ "content": code_execution,
+ "metadata": {},
+ "timestamp": datetime.now().isoformat(),
+ }
+ )
+
+ def end_trace(self, final_result: Any = None):
+ """End the current trace and calculate performance metrics."""
+ self.trace_data["end_time"] = datetime.now()
+
+ if self.trace_data["start_time"] and self.trace_data["end_time"]:
+ total_time = (self.trace_data["end_time"] - self.trace_data["start_time"]).total_seconds()
+ self.trace_data["performance_metrics"]["total_execution_time"] = total_time
+ self.trace_data["performance_metrics"]["total_steps"] = len(self.trace_data["steps"])
+ self.trace_data["performance_metrics"]["total_code_executions"] = len(self.trace_data["code_executions"])
+
+ if final_result:
+ self.trace_data["final_result"] = final_result
+
+ def parse_agent_log(self, log: list[Any]):
+ """
+ Parse an agent log to extract reasoning trace information.
+
+ Args:
+ log: List of log entries from the agent
+ """
+ for i, log_entry in enumerate(log):
+ if isinstance(log_entry, str):
+ # Parse different types of log entries
+ self._parse_log_entry(log_entry, i)
+
+ def _parse_log_entry(self, log_entry: str, step_index: int):
+ """Parse a single log entry to extract trace information."""
+
+ # Clean up the log entry - handle escaped characters and formatting
+ cleaned_entry = self._clean_log_entry(log_entry)
+
+ # First, check for structured content (think, execute, solution blocks)
+ if " Comprehensive analysis and findings from biomni reasoning system No plots were generated during this analysis. Generated at: {plot["timestamp"]}𧬠Final Analysis Report
+ š Query
+ š Analysis Summary
+ šÆ Final Results
+ š Generated Visualizations
+ {self._generate_plots_html()}
+
+
\1
", html, flags=re.MULTILINE)
+ html = re.sub(r"^## (.*?)$", r"\1
", html, flags=re.MULTILINE)
+ html = re.sub(r"^# (.*?)$", r"\1
", html, flags=re.MULTILINE)
+
+ # Bold and italic
+ html = re.sub(r"\*\*(.*?)\*\*", r"\1", html)
+ html = re.sub(r"\*(.*?)\*", r"\1", html)
+
+ # Code blocks
+ html = re.sub(r"```(.*?)```", r"
", html, flags=re.DOTALL)
+ html = re.sub(r"`(.*?)`", r"\1\1", html)
+
+ # Lists
+ html = re.sub(r"^\d+\. (.*?)$", r" or
+ html = re.sub(r"(
{m.group(0)}
", html, flags=re.DOTALL)
+
+ # Line breaks
+ html = re.sub(r"\n\n", r"
", html)
+ html = re.sub(r"\n", r"
", html)
+
+ # Wrap in paragraphs if not already wrapped
+ if not html.startswith("<"):
+ html = f"
{html}
" + + # Clean up empty paragraphs + html = re.sub(r"\s*
", "", html) + html = re.sub(r"{{ query }}
+Tool: {{ step.content.tool }}
+ {% endif %} +Context: {{ step.content.context[:200] }}{% if step.content.context|length > 200 %}...{% endif %}
+ {% endif %} +{{ step.content.content }}
+ {% else %} +{{ step.content }}
+ {% endif %} + {% if step.content.full_context and step.content.full_context != step.content.content %} +{{ step.content.content }}
+ {% if step.content.full_context and step.content.full_context != step.content.content %} +{{ step.content }}
+ {% endif %} + {% if step.content.full_context and step.content.full_context != step.content.content %} +{{ step.content.content }}
+ {% else %} +{{ step.content }}
+ {% endif %} + {% if step.content.full_context and step.content.full_context != step.content.content %} +Error: {{ step.content.error_message }}
+ {% if step.content.full_log %} +{{ step.content }}
+