diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9992b77..d5fdccf 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -27,7 +27,7 @@ jobs: - name: Install package and test deps run: | pip install -e . - pip install pytest pytest-cov + pip install pytest pytest-cov pydantic-ai - name: Run unit tests with coverage run: | diff --git a/gradient_adk/decorator.py b/gradient_adk/decorator.py index 8031eda..50e4aa6 100644 --- a/gradient_adk/decorator.py +++ b/gradient_adk/decorator.py @@ -18,9 +18,12 @@ logger = get_logger(__name__) -from gradient_adk.runtime.langgraph.helpers import capture_graph, get_tracker +# Initialize framework instrumentation using the centralized registry +# This is idempotent and will only install instrumentation once +# Each instrumentor checks for its own environment variable to allow disabling +from gradient_adk.runtime.helpers import capture_all, get_tracker -capture_graph() +capture_all() class _StreamingIteratorWithTracking: diff --git a/gradient_adk/digital_ocean_api/__init__.py b/gradient_adk/digital_ocean_api/__init__.py index ad5d403..43c31b9 100644 --- a/gradient_adk/digital_ocean_api/__init__.py +++ b/gradient_adk/digital_ocean_api/__init__.py @@ -1,5 +1,10 @@ from .models import ( TraceSpanType, + SpanCommon, + LLMSpanDetails, + ToolSpanDetails, + RetrieverSpanDetails, + WorkflowSpanDetails, Span, Trace, CreateTracesInput, @@ -38,6 +43,11 @@ __all__ = [ "TraceSpanType", + "SpanCommon", + "LLMSpanDetails", + "ToolSpanDetails", + "RetrieverSpanDetails", + "WorkflowSpanDetails", "Span", "Trace", "CreateTracesInput", @@ -70,4 +80,4 @@ "DOAPINetworkError", "DOAPIValidationError", "AsyncDigitalOceanGenAI", -] +] \ No newline at end of file diff --git a/gradient_adk/digital_ocean_api/models.py b/gradient_adk/digital_ocean_api/models.py index a0bcad2..3f963a4 100644 --- a/gradient_adk/digital_ocean_api/models.py +++ b/gradient_adk/digital_ocean_api/models.py @@ -12,13 +12,71 @@ class TraceSpanType(str, Enum): TRACE_SPAN_TYPE_LLM = "TRACE_SPAN_TYPE_LLM" TRACE_SPAN_TYPE_RETRIEVER = "TRACE_SPAN_TYPE_RETRIEVER" TRACE_SPAN_TYPE_TOOL = "TRACE_SPAN_TYPE_TOOL" + TRACE_SPAN_TYPE_WORKFLOW = "TRACE_SPAN_TYPE_WORKFLOW" + TRACE_SPAN_TYPE_AGENT = "TRACE_SPAN_TYPE_AGENT" + + +class SpanCommon(BaseModel): + """Common fields for all span types.""" + + model_config = ConfigDict(populate_by_name=True, extra="allow") + + duration_ns: Optional[int] = Field(None, description="Duration in nanoseconds") + metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata") + tags: Optional[List[str]] = Field(None, description="Tags for the span") + status_code: Optional[int] = Field(None, description="HTTP status code if applicable") + + +class LLMSpanDetails(BaseModel): + """LLM-specific span details.""" + + model_config = ConfigDict(populate_by_name=True, extra="allow") + + common: Optional[SpanCommon] = None + model: Optional[str] = Field(None, description="Model name") + tools: Optional[List[Dict[str, Any]]] = Field( + None, description="Tool definitions passed to the model" + ) + num_input_tokens: Optional[int] = Field(None, description="Number of input tokens") + num_output_tokens: Optional[int] = Field(None, description="Number of output tokens") + total_tokens: Optional[int] = Field(None, description="Total tokens") + temperature: Optional[float] = Field(None, description="Temperature setting") + time_to_first_token_ns: Optional[int] = Field( + None, description="Time to first token in nanoseconds" + ) + + +class 
ToolSpanDetails(BaseModel): + """Tool-specific span details.""" + + model_config = ConfigDict(populate_by_name=True, extra="allow") + + common: Optional[SpanCommon] = None + tool_call_id: Optional[str] = Field(None, description="Tool call identifier") + + +class RetrieverSpanDetails(BaseModel): + """Retriever-specific span details.""" + + model_config = ConfigDict(populate_by_name=True, extra="allow") + + common: Optional[SpanCommon] = None + + +class WorkflowSpanDetails(BaseModel): + """Workflow span containing nested sub-spans.""" + + model_config = ConfigDict(populate_by_name=True, extra="allow") + + spans: List["Span"] = Field(default_factory=list, description="Nested sub-spans") class Span(BaseModel): """ - Represents a span within a trace (e.g., LLM call, retriever, tool). + Represents a span within a trace (e.g., LLM call, retriever, tool, workflow). - created_at: RFC3339 timestamp (protobuf Timestamp) - - input/output: json + - input/output: json (must be dict for protobuf Struct compatibility) + - For workflow spans, contains nested sub-spans in the 'workflow' field """ model_config = ConfigDict(populate_by_name=True, extra="allow") @@ -29,6 +87,23 @@ class Span(BaseModel): output: Dict[str, Any] type: TraceSpanType = Field(default=TraceSpanType.TRACE_SPAN_TYPE_UNKNOWN) + # Common fields for all span types + common: Optional[SpanCommon] = Field(None, description="Common span metadata") + + # Type-specific fields + llm: Optional[LLMSpanDetails] = Field(None, description="LLM-specific details") + tool: Optional[ToolSpanDetails] = Field(None, description="Tool-specific details") + retriever: Optional[RetrieverSpanDetails] = Field( + None, description="Retriever-specific details" + ) + workflow: Optional[WorkflowSpanDetails] = Field( + None, description="Workflow span with nested sub-spans" + ) + + +# Update forward reference for WorkflowSpanDetails +WorkflowSpanDetails.model_rebuild() + class Trace(BaseModel): """ @@ -919,4 +994,4 @@ class ListEvaluationMetricsOutput(BaseModel): metrics: List[EvaluationMetric] = Field( default_factory=list, description="List of evaluation metrics" - ) + ) \ No newline at end of file diff --git a/gradient_adk/runtime/digitalocean_tracker.py b/gradient_adk/runtime/digitalocean_tracker.py index 6411de2..402b33e 100644 --- a/gradient_adk/runtime/digitalocean_tracker.py +++ b/gradient_adk/runtime/digitalocean_tracker.py @@ -12,6 +12,11 @@ Trace, Span, TraceSpanType, + SpanCommon, + LLMSpanDetails, + ToolSpanDetails, + RetrieverSpanDetails, + WorkflowSpanDetails, ) from .interfaces import NodeExecution @@ -324,22 +329,172 @@ def _to_span(self, ex: NodeExecution) -> Span: out = dict(out) out["_llm_endpoints"] = list(ex.metadata["llm_endpoints"]) - # classify LLM/tool/retriever via metadata set by the instrumentor + # classify span type via metadata set by the instrumentor metadata = ex.metadata or {} - if metadata.get("is_llm_call"): + + # Check if this is a workflow span + if metadata.get("is_workflow"): + span_type = TraceSpanType.TRACE_SPAN_TYPE_WORKFLOW + + # Build sub-spans from the workflow's collected spans + sub_spans_list = metadata.get("sub_spans", []) + sub_spans = [self._to_span(sub) for sub in sub_spans_list] + + # Calculate duration from start to end + duration_ns = None + if ex.start_time and ex.end_time: + duration_ns = int( + (ex.end_time - ex.start_time).total_seconds() * 1_000_000_000 + ) + + # Build common fields + common = SpanCommon( + duration_ns=duration_ns, + metadata={"agent_name": metadata.get("agent_name")}, + status_code=200 if 
ex.error is None else 500, + ) + + # Build workflow details with nested sub-spans + workflow_details = WorkflowSpanDetails(spans=sub_spans) + + return Span( + created_at=_utc(ex.start_time), + name=ex.node_name, + input=inp, + output=out, + type=span_type, + common=common, + workflow=workflow_details, + ) + elif metadata.get("is_llm_call"): span_type = TraceSpanType.TRACE_SPAN_TYPE_LLM + + # Calculate duration + duration_ns = None + if ex.start_time and ex.end_time: + duration_ns = int( + (ex.end_time - ex.start_time).total_seconds() * 1_000_000_000 + ) + + # Build LLM-specific details + llm_common = SpanCommon( + duration_ns=duration_ns, + status_code=200 if ex.error is None else 500, + ) + + # Extract LLM-specific fields from captured API payloads + llm_request = metadata.get("llm_request_payload", {}) or {} + llm_response = metadata.get("llm_response_payload", {}) or {} + + # For LLM spans, use just the messages as input (not the full request payload) + # Must be a dict (not array) because protobuf Struct requires key-value pairs + if isinstance(llm_request, dict) and "messages" in llm_request: + inp = {"messages": llm_request.get("messages")} + + # For LLM spans, use just the choices as output (not the full response payload) + # Must be a dict (not array) because protobuf Struct requires key-value pairs + if isinstance(llm_response, dict) and "choices" in llm_response: + out = {"choices": llm_response.get("choices")} + + # Extract model from request payload, fallback to metadata or node name + model = ( + llm_request.get("model") + or metadata.get("model_name") + or ex.node_name.replace("llm:", "") + ) + + # Extract tools from request payload + tools = llm_request.get("tools") if isinstance(llm_request, dict) else None + + # Extract temperature from request payload + temperature = llm_request.get("temperature") if isinstance(llm_request, dict) else None + + # Extract token counts from response payload + num_input_tokens = None + num_output_tokens = None + total_tokens = None + if isinstance(llm_response, dict): + usage = llm_response.get("usage", {}) + if isinstance(usage, dict): + num_input_tokens = usage.get("prompt_tokens") + num_output_tokens = usage.get("completion_tokens") + total_tokens = usage.get("total_tokens") + + # Get time-to-first-token for streaming calls + time_to_first_token_ns = metadata.get("time_to_first_token_ns") + + llm_details = LLMSpanDetails( + common=llm_common, + model=model, + tools=tools, + temperature=temperature, + num_input_tokens=num_input_tokens, + num_output_tokens=num_output_tokens, + total_tokens=total_tokens, + time_to_first_token_ns=time_to_first_token_ns, + ) + + return Span( + created_at=_utc(ex.start_time), + name=ex.node_name, + input=inp, + output=out, + type=span_type, + llm=llm_details, + ) elif metadata.get("is_retriever_call"): span_type = TraceSpanType.TRACE_SPAN_TYPE_RETRIEVER + + # Calculate duration + duration_ns = None + if ex.start_time and ex.end_time: + duration_ns = int( + (ex.end_time - ex.start_time).total_seconds() * 1_000_000_000 + ) + + # Build retriever-specific details + retriever_common = SpanCommon( + duration_ns=duration_ns, + status_code=200 if ex.error is None else 500, + ) + + retriever_details = RetrieverSpanDetails(common=retriever_common) + + return Span( + created_at=_utc(ex.start_time), + name=ex.node_name, + input=inp, + output=out, + type=span_type, + retriever=retriever_details, + ) else: + # Default to tool span span_type = TraceSpanType.TRACE_SPAN_TYPE_TOOL - return Span( - created_at=_utc(ex.start_time), 
- name=ex.node_name, - input=inp, - output=out, - type=span_type, - ) + # Calculate duration + duration_ns = None + if ex.start_time and ex.end_time: + duration_ns = int( + (ex.end_time - ex.start_time).total_seconds() * 1_000_000_000 + ) + + # Build tool-specific details + tool_common = SpanCommon( + duration_ns=duration_ns, + status_code=200 if ex.error is None else 500, + ) + + tool_details = ToolSpanDetails(common=tool_common) + + return Span( + created_at=_utc(ex.start_time), + name=ex.node_name, + input=inp, + output=out, + type=span_type, + tool=tool_details, + ) def _coerce_top(self, val: Any, kind: str) -> Dict[str, Any]: """ @@ -374,4 +529,4 @@ def _build_trace(self) -> Trace: output=outputs, spans=spans, ) - return trace + return trace \ No newline at end of file diff --git a/gradient_adk/runtime/helpers.py b/gradient_adk/runtime/helpers.py new file mode 100644 index 0000000..99d036d --- /dev/null +++ b/gradient_adk/runtime/helpers.py @@ -0,0 +1,287 @@ +""" +Centralized instrumentation helpers for Gradient ADK. + +Provides a registry pattern for optional framework instrumentors (PydanticAI, LangGraph, etc.) +that handles: +- Tracker creation and lifecycle management +- Environment variable disable checks +- Framework availability checks +- Automatic installation at import time +""" + +from __future__ import annotations +import os +from typing import Optional, Callable, Dict, Any, Protocol +from gradient_adk.cli.config.yaml_agent_config_manager import YamlAgentConfigManager +from gradient_adk.runtime.digitalocean_tracker import DigitalOceanTracesTracker +from gradient_adk.digital_ocean_api import AsyncDigitalOceanGenAI +from gradient_adk.runtime.network_interceptor import setup_digitalocean_interception + + +class InstrumentorProtocol(Protocol): + """Protocol for instrumentor classes.""" + + def install(self, tracker: DigitalOceanTracesTracker) -> None: + """Install the instrumentor with the given tracker.""" + ... + + def uninstall(self) -> None: + """Uninstall the instrumentor and restore original behavior.""" + ... + + def is_installed(self) -> bool: + """Check if the instrumentor is currently installed.""" + ... + + +class InstrumentorRegistry: + """ + Registry for optional framework instrumentors. + + Provides centralized management for: + - Tracker creation (single shared tracker for all instrumentors) + - Framework availability checks + - Environment variable disable flags + - Installation lifecycle + + Usage: + # In instrumentor module: + registry.register( + name="pydanticai", + env_disable_var="GRADIENT_DISABLE_PYDANTICAI_INSTRUMENTOR", + availability_check=lambda: _has_pydantic_ai(), + instrumentor_factory=lambda: PydanticAIInstrumentor() + ) + + # In decorator or main module: + registry.install_all() + tracker = registry.get_tracker() + """ + + def __init__(self): + self._tracker: Optional[DigitalOceanTracesTracker] = None + self._instrumentors: Dict[str, InstrumentorProtocol] = {} + self._registrations: Dict[str, Dict[str, Any]] = {} + self._config_reader = YamlAgentConfigManager() + self._tracker_initialized = False + + def _is_env_disabled(self, env_var: str) -> bool: + """Check if instrumentation is disabled via environment variable.""" + val = os.environ.get(env_var, "").lower() + return val in ("true", "1", "yes") + + def _ensure_tracker(self) -> Optional[DigitalOceanTracesTracker]: + """ + Create the shared tracker if not already created. 
+ + Returns None if: + - No API token is available + - Tracker already failed to initialize + """ + if self._tracker is not None: + return self._tracker + + if self._tracker_initialized: + # Already tried and failed + return None + + self._tracker_initialized = True + + try: + api_token = os.environ.get("DIGITALOCEAN_API_TOKEN") + if not api_token: + return None + + ws = self._config_reader.get_agent_name() + dep = self._config_reader.get_agent_environment() + + self._tracker = DigitalOceanTracesTracker( + client=AsyncDigitalOceanGenAI(api_token=api_token), + agent_workspace_name=ws, + agent_deployment_name=dep, + ) + setup_digitalocean_interception() + return self._tracker + + except Exception: + return None + + def register( + self, + name: str, + env_disable_var: str, + availability_check: Callable[[], bool], + instrumentor_factory: Callable[[], InstrumentorProtocol], + ) -> None: + """ + Register an instrumentor for later installation. + + Args: + name: Unique name for this instrumentor (e.g., "pydanticai", "langgraph") + env_disable_var: Environment variable name to disable this instrumentor + availability_check: Callable that returns True if the framework is available + instrumentor_factory: Callable that creates the instrumentor instance + """ + self._registrations[name] = { + "env_disable_var": env_disable_var, + "availability_check": availability_check, + "instrumentor_factory": instrumentor_factory, + } + + def install(self, name: str) -> Optional[DigitalOceanTracesTracker]: + """ + Install a specific registered instrumentor. + + Returns the tracker if installation succeeded, None otherwise. + """ + if name in self._instrumentors: + # Already installed + return self._tracker + + if name not in self._registrations: + return None + + reg = self._registrations[name] + + # Check if disabled via env var + if self._is_env_disabled(reg["env_disable_var"]): + return None + + # Check if framework is available + if not reg["availability_check"](): + return None + + # Ensure we have a tracker + tracker = self._ensure_tracker() + if tracker is None: + return None + + # Create and install instrumentor + try: + instrumentor = reg["instrumentor_factory"]() + instrumentor.install(tracker) + self._instrumentors[name] = instrumentor + return tracker + except Exception: + return None + + def install_all(self) -> Optional[DigitalOceanTracesTracker]: + """ + Install all registered instrumentors that are available. + + Returns the tracker if at least one instrumentor was installed. 
+ """ + for name in self._registrations: + self.install(name) + return self._tracker + + def uninstall(self, name: str) -> None: + """Uninstall a specific instrumentor.""" + if name in self._instrumentors: + try: + self._instrumentors[name].uninstall() + except Exception: + pass + del self._instrumentors[name] + + def uninstall_all(self) -> None: + """Uninstall all instrumentors.""" + for name in list(self._instrumentors.keys()): + self.uninstall(name) + + def get_tracker(self) -> Optional[DigitalOceanTracesTracker]: + """Get the shared tracker instance.""" + return self._tracker + + def is_installed(self, name: str) -> bool: + """Check if a specific instrumentor is installed.""" + return name in self._instrumentors + + def get_installed_names(self) -> list[str]: + """Get list of installed instrumentor names.""" + return list(self._instrumentors.keys()) + + +# Global registry instance +registry = InstrumentorRegistry() + + +def get_tracker() -> Optional[DigitalOceanTracesTracker]: + """Get the shared tracker from the global registry.""" + return registry.get_tracker() + + +# ---- Auto-registration functions for known instrumentors ---- +# These are called to register instrumentors without importing their heavy dependencies + + +def _register_langgraph() -> None: + """Register LangGraph instrumentor if available.""" + + def is_available() -> bool: + try: + from langgraph.graph import StateGraph + + return True + except ImportError: + return False + + def factory(): + from gradient_adk.runtime.langgraph.langgraph_instrumentor import ( + LangGraphInstrumentor, + ) + + return LangGraphInstrumentor() + + registry.register( + name="langgraph", + env_disable_var="GRADIENT_DISABLE_LANGGRAPH_INSTRUMENTOR", + availability_check=is_available, + instrumentor_factory=factory, + ) + + +def _register_pydanticai() -> None: + """Register PydanticAI instrumentor if available.""" + + def is_available() -> bool: + try: + from pydantic_ai import Agent + + return True + except ImportError: + return False + + def factory(): + from gradient_adk.runtime.pydanticai.pydanticai_instrumentor import ( + PydanticAIInstrumentor, + ) + + return PydanticAIInstrumentor() + + registry.register( + name="pydanticai", + env_disable_var="GRADIENT_DISABLE_PYDANTICAI_INSTRUMENTOR", + availability_check=is_available, + instrumentor_factory=factory, + ) + + +def register_all_instrumentors() -> None: + """Register all known instrumentors with the registry.""" + _register_langgraph() + _register_pydanticai() + + +def capture_all() -> Optional[DigitalOceanTracesTracker]: + """ + Register and install all available instrumentors. + + This is the main entry point for the decorator module. + Call this once at startup to automatically instrument all available frameworks. + + Returns: + The shared tracker if at least one instrumentor was installed, None otherwise. 
+ """ + register_all_instrumentors() + return registry.install_all() diff --git a/gradient_adk/runtime/langgraph/helpers.py b/gradient_adk/runtime/langgraph/helpers.py deleted file mode 100644 index a5d5614..0000000 --- a/gradient_adk/runtime/langgraph/helpers.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import annotations -import os -from typing import Optional -from gradient_adk.cli.config.yaml_agent_config_manager import YamlAgentConfigManager -from gradient_adk.runtime.langgraph.langgraph_instrumentor import LangGraphInstrumentor -from gradient_adk.runtime.digitalocean_tracker import DigitalOceanTracesTracker -from gradient_adk.digital_ocean_api import AsyncDigitalOceanGenAI -from gradient_adk.runtime.network_interceptor import setup_digitalocean_interception - -_TRACKER: Optional[DigitalOceanTracesTracker] = None -_INSTALLED = False - -config_reader = YamlAgentConfigManager() - - -def capture_graph() -> None: - """Install DO tracing for LangGraph exactly once. - Must be called BEFORE graph.add_node/compile to capture spans. - """ - global _TRACKER, _INSTALLED - if _INSTALLED and _TRACKER: - return _TRACKER - - try: - api_token = os.environ["DIGITALOCEAN_API_TOKEN"] - except Exception as e: - # Only enable DO tracing if we have an API token - return - ws = config_reader.get_agent_name() - dep = config_reader.get_agent_environment() - - _TRACKER = DigitalOceanTracesTracker( - client=AsyncDigitalOceanGenAI(api_token=api_token), - agent_workspace_name=ws, - agent_deployment_name=dep, - ) - setup_digitalocean_interception() - LangGraphInstrumentor().install(_TRACKER) - _INSTALLED = True - - -def get_tracker() -> Optional[DigitalOceanTracesTracker]: - return _TRACKER diff --git a/gradient_adk/runtime/langgraph/langgraph_instrumentor.py b/gradient_adk/runtime/langgraph/langgraph_instrumentor.py index 8e5acf8..7a2ba29 100644 --- a/gradient_adk/runtime/langgraph/langgraph_instrumentor.py +++ b/gradient_adk/runtime/langgraph/langgraph_instrumentor.py @@ -2,7 +2,9 @@ import functools import inspect +import json import os +import time import uuid from copy import deepcopy from datetime import datetime, timezone @@ -330,6 +332,7 @@ def _finish_ok( ret: Any, intr, tok, + time_to_first_token_ns: Optional[int] = None, ): # NOTE: Async generators should be handled by the wrapper functions # (_wrap_async_func, _wrap_sync_func, etc.) BEFORE calling _finish_ok. 
@@ -346,11 +349,23 @@ def _finish_ok( meta = _ensure_meta(rec) if is_llm: meta["is_llm_call"] = True + # Store raw API payloads for LLM field extraction in tracker + if api_request: + meta["llm_request_payload"] = api_request + if api_response: + meta["llm_response_payload"] = api_response + # Store time-to-first-token if this was a streaming call + if time_to_first_token_ns is not None: + meta["time_to_first_token_ns"] = time_to_first_token_ns elif is_retriever: meta["is_retriever_call"] = True else: # Fallback: assume LLM call for backward compatibility meta["is_llm_call"] = True + if api_request: + meta["llm_request_payload"] = api_request + if api_response: + meta["llm_response_payload"] = api_response if api_request or api_response: # Use actual API payloads instead of function args @@ -398,6 +413,8 @@ def _wrap_async_func(node_name: str, func): @functools.wraps(func) async def _wrapped(*a, **kw): rec, snap, intr, tok = _start(node_name, a, kw) + # Record start time for TTFT calculation + start_time_ns = time.perf_counter_ns() try: ret = await func(*a, **kw) @@ -408,11 +425,16 @@ async def _wrapped(*a, **kw): ): async def _streaming_wrapper(gen): - import json - collected: list[str] = [] + ttft_ns: Optional[int] = None + first_chunk_received = False try: async for chunk in gen: + # Record time to first token + if not first_chunk_received: + first_chunk_received = True + ttft_ns = time.perf_counter_ns() - start_time_ns + # Convert chunk to string for collection if isinstance(chunk, bytes): chunk_str = chunk.decode( @@ -428,7 +450,7 @@ async def _streaming_wrapper(gen): collected.append(chunk_str) yield chunk - # Stream complete - finalize with collected content + # Stream complete - finalize with collected content and TTFT _finish_ok( rec, snap, @@ -437,6 +459,7 @@ async def _streaming_wrapper(gen): {"content": "".join(collected)}, intr, tok, + time_to_first_token_ns=ttft_ns, ) except BaseException as e: _finish_err(rec, intr, tok, e) @@ -458,6 +481,8 @@ def _wrap_sync_func(node_name: str, func): @functools.wraps(func) def _wrapped(*a, **kw): rec, snap, intr, tok = _start(node_name, a, kw) + # Record start time for TTFT calculation + start_time_ns = time.perf_counter_ns() try: ret = func(*a, **kw) @@ -468,11 +493,16 @@ def _wrapped(*a, **kw): ): async def _streaming_wrapper(gen): - import json - collected: list[str] = [] + ttft_ns: Optional[int] = None + first_chunk_received = False try: async for chunk in gen: + # Record time to first token + if not first_chunk_received: + first_chunk_received = True + ttft_ns = time.perf_counter_ns() - start_time_ns + # Convert chunk to string for collection if isinstance(chunk, bytes): chunk_str = chunk.decode( @@ -488,7 +518,7 @@ async def _streaming_wrapper(gen): collected.append(chunk_str) yield chunk - # Stream complete - finalize with collected content + # Stream complete - finalize with collected content and TTFT _finish_ok( rec, snap, @@ -497,6 +527,7 @@ async def _streaming_wrapper(gen): {"content": "".join(collected)}, intr, tok, + time_to_first_token_ns=ttft_ns, ) except BaseException as e: _finish_err(rec, intr, tok, e) @@ -518,12 +549,21 @@ def _wrap_async_gen(node_name: str, func): @functools.wraps(func) async def _wrapped(*a, **kw): rec, snap, intr, tok = _start(node_name, a, kw) + # Record start time for TTFT calculation + start_time_ns = time.perf_counter_ns() + ttft_ns: Optional[int] = None + first_chunk_received = False try: # Accumulate a compact, canonical final payload # (string: concatenate; list: extend; else: last write wins) 
acc: Dict[str, Any] = {} async for chunk in func(*a, **kw): + # Record time to first token + if not first_chunk_received: + first_chunk_received = True + ttft_ns = time.perf_counter_ns() - start_time_ns + # Merge into acc for the final on_node_end payload for k, v in chunk.items(): if isinstance(v, str): @@ -538,8 +578,8 @@ async def _wrapped(*a, **kw): # Pass the live chunk downstream unchanged yield chunk - # Finish the span with the aggregated mapping - _finish_ok(rec, snap, a, kw, acc, intr, tok) + # Finish the span with the aggregated mapping and TTFT + _finish_ok(rec, snap, a, kw, acc, intr, tok, time_to_first_token_ns=ttft_ns) except BaseException as e: _finish_err(rec, intr, tok, e) raise @@ -550,6 +590,8 @@ async def _wrapped(*a, **kw): def _wrap_runnable_ainvoke(node_name: str, runnable): async def _wrapped(*a, **kw): rec, snap, intr, tok = _start(node_name, a, kw) + # Record start time for TTFT calculation + start_time_ns = time.perf_counter_ns() try: ret = await runnable.ainvoke(*a, **kw) @@ -559,11 +601,16 @@ async def _wrapped(*a, **kw): ): async def _streaming_wrapper(gen): - import json - collected: list[str] = [] + ttft_ns: Optional[int] = None + first_chunk_received = False try: async for chunk in gen: + # Record time to first token + if not first_chunk_received: + first_chunk_received = True + ttft_ns = time.perf_counter_ns() - start_time_ns + if isinstance(chunk, bytes): chunk_str = chunk.decode( "utf-8", errors="replace" @@ -586,6 +633,7 @@ async def _streaming_wrapper(gen): {"content": "".join(collected)}, intr, tok, + time_to_first_token_ns=ttft_ns, ) except BaseException as e: _finish_err(rec, intr, tok, e) @@ -605,6 +653,8 @@ async def _streaming_wrapper(gen): def _wrap_runnable_invoke(node_name: str, runnable): def _wrapped(*a, **kw): rec, snap, intr, tok = _start(node_name, a, kw) + # Record start time for TTFT calculation + start_time_ns = time.perf_counter_ns() try: ret = runnable.invoke(*a, **kw) @@ -614,11 +664,16 @@ def _wrapped(*a, **kw): ): async def _streaming_wrapper(gen): - import json - collected: list[str] = [] + ttft_ns: Optional[int] = None + first_chunk_received = False try: async for chunk in gen: + # Record time to first token + if not first_chunk_received: + first_chunk_received = True + ttft_ns = time.perf_counter_ns() - start_time_ns + if isinstance(chunk, bytes): chunk_str = chunk.decode( "utf-8", errors="replace" @@ -641,6 +696,7 @@ async def _streaming_wrapper(gen): {"content": "".join(collected)}, intr, tok, + time_to_first_token_ns=ttft_ns, ) except BaseException as e: _finish_err(rec, intr, tok, e) diff --git a/gradient_adk/runtime/network_interceptor.py b/gradient_adk/runtime/network_interceptor.py index 728252c..3874caa 100644 --- a/gradient_adk/runtime/network_interceptor.py +++ b/gradient_adk/runtime/network_interceptor.py @@ -139,13 +139,10 @@ async def intercepted_httpx_send(self_client, request, **kwargs): ) # Don't read response body for streaming responses - it would buffer the entire stream! 
- # Check if this is a streaming response by looking at headers or response type - is_streaming = ( - response.headers.get("transfer-encoding") == "chunked" - or "text/event-stream" in response.headers.get("content-type", "") - or hasattr(response, "aiter_bytes") - or hasattr(response, "aiter_lines") - ) + # Check if this is a streaming response by looking at response headers + # Note: Only check content-type for SSE, not transfer-encoding (chunked is common for both) + content_type = response.headers.get("content-type", "") + is_streaming = "text/event-stream" in content_type if not is_streaming: response_payload = await _global_interceptor._extract_response_payload( diff --git a/gradient_adk/runtime/pydanticai/__init__.py b/gradient_adk/runtime/pydanticai/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gradient_adk/runtime/pydanticai/pydanticai_instrumentor.py b/gradient_adk/runtime/pydanticai/pydanticai_instrumentor.py new file mode 100644 index 0000000..ba0049b --- /dev/null +++ b/gradient_adk/runtime/pydanticai/pydanticai_instrumentor.py @@ -0,0 +1,819 @@ +from __future__ import annotations + +import time +import uuid +from contextvars import ContextVar +from contextlib import asynccontextmanager +from copy import deepcopy +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Mapping, Optional, Tuple, Dict, List + +from ..interfaces import NodeExecution +from ..digitalocean_tracker import DigitalOceanTracesTracker +from ..network_interceptor import ( + get_network_interceptor, + is_inference_url, + is_kbaas_url, +) + + +def _utc() -> datetime: + return datetime.now(timezone.utc) + + +def _mk_exec(name: str, inputs: Any, framework: str = "pydanticai") -> NodeExecution: + return NodeExecution( + node_id=str(uuid.uuid4()), + node_name=name, + framework=framework, + start_time=_utc(), + inputs=inputs, + ) + + +def _ensure_meta(rec: NodeExecution) -> dict: + md = getattr(rec, "metadata", None) + if not isinstance(md, dict): + md = {} + try: + rec.metadata = md + except Exception: + pass + return md + + +_MAX_DEPTH = 3 +_MAX_ITEMS = 100 # keep payloads bounded + + +def _freeze(obj: Any, depth: int = _MAX_DEPTH) -> Any: + """Mutation-safe, JSON-ish snapshot for arbitrary Python objects.""" + if obj is None or isinstance(obj, (str, int, float, bool)): + return obj + + # dict-like + if isinstance(obj, Mapping): + out: Dict[str, Any] = {} + for i, (k, v) in enumerate(obj.items()): + if i >= _MAX_ITEMS: + out[""] = True + break + out[str(k)] = _freeze(v, depth - 1) + return out + + # sequences + if isinstance(obj, (list, tuple, set)): + seq = list(obj) + out = [] + for i, v in enumerate(seq): + if i >= _MAX_ITEMS: + out.append("") + break + out.append(_freeze(v, depth - 1)) + return out + + # pydantic + try: + from pydantic import BaseModel + + if isinstance(obj, BaseModel): + return _freeze(obj.model_dump(), depth - 1) + except Exception: + pass + + # dataclass + try: + import dataclasses + + if dataclasses.is_dataclass(obj): + return _freeze(dataclasses.asdict(obj), depth - 1) + except Exception: + pass + + # PydanticAI result types + try: + if hasattr(obj, "output"): + return _freeze(obj.output, depth - 1) + if hasattr(obj, "data"): + return _freeze(obj.data, depth - 1) + except Exception: + pass + + # fallback + return repr(obj) + + +def _snapshot_args_kwargs(a: Tuple[Any, ...], kw: Dict[str, Any]) -> Any: + """Deepcopy then freeze to avoid mutation surprises.""" + try: + a_copy = deepcopy(a) + kw_copy = deepcopy(kw) 
+ except Exception: + a_copy, kw_copy = a, kw # best-effort + + # If there's exactly one arg and no kwargs, return just that arg + if len(a_copy) == 1 and not kw_copy: + return _freeze(a_copy[0]) + + # If there are kwargs but no args, return just the kwargs + if not a_copy and kw_copy: + return _freeze(kw_copy) + + # If there are multiple args or both args and kwargs, return a dict + if a_copy and kw_copy: + return {"args": _freeze(a_copy), "kwargs": _freeze(kw_copy)} + elif len(a_copy) > 1: + return _freeze(a_copy) + + # Fallback + return _freeze(a_copy) + + +def _snap(): + intr = get_network_interceptor() + try: + tok = intr.snapshot_token() + except Exception: + tok = 0 + return intr, tok + + +def _had_hits_since(intr, token) -> bool: + try: + return intr.hits_since(token) > 0 + except Exception: + return False + + +def _get_captured_payloads_with_type(intr, token) -> tuple: + """Get captured API request/response payloads and classify the call type. + + Returns: + (request_payload, response_payload, is_llm, is_retriever) + """ + try: + captured = intr.get_captured_requests_since(token) + if captured: + # Use the first captured request (most common case) + call = captured[0] + url = call.url + is_llm = is_inference_url(url) + is_retriever = is_kbaas_url(url) + return call.request_payload, call.response_payload, is_llm, is_retriever + except Exception: + pass + return None, None, False, False + + +def _transform_kbaas_response(response: Optional[Dict[str, Any]]) -> Optional[list]: + """Transform KBaaS response to standard retriever format.""" + if not isinstance(response, dict): + return response + + results = response.get("results", []) + if not isinstance(results, list): + return response + + transformed_results = [] + for item in results: + if isinstance(item, dict): + new_item = dict(item) + + if "parent_chunk_text" in new_item: + new_item["page_content"] = new_item.pop("parent_chunk_text") + if "text_content" in new_item: + new_item["embedded_content"] = new_item.pop("text_content") + elif "text_content" in new_item: + new_item["page_content"] = new_item.pop("text_content") + + transformed_results.append(new_item) + else: + transformed_results.append(item) + + return transformed_results + + +def _extract_messages_input(messages: List[Any]) -> Any: + """Extract a clean representation of the messages sent to the LLM.""" + try: + result = [] + for msg in messages: + if hasattr(msg, "parts"): + # ModelRequest or ModelResponse + msg_data = {"kind": msg.__class__.__name__, "parts": []} + for part in msg.parts: + part_data = _freeze(part) + msg_data["parts"].append(part_data) + if hasattr(msg, "instructions") and msg.instructions: + msg_data["instructions"] = msg.instructions + result.append(msg_data) + else: + result.append(_freeze(msg)) + return result + except Exception: + return _freeze(messages) + + +def _extract_model_response_output(response: Any) -> Any: + """Extract a clean representation of the model response.""" + try: + if hasattr(response, "parts"): + result = {"parts": []} + for part in response.parts: + part_data = _freeze(part) + result["parts"].append(part_data) + if hasattr(response, "usage") and response.usage: + result["usage"] = _freeze(response.usage) + if hasattr(response, "model_name") and response.model_name: + result["model_name"] = response.model_name + return result + return _freeze(response) + except Exception: + return _freeze(response) + + +# ---- Workflow Context Management ---- + + +@dataclass +class WorkflowContext: + """Context for tracking a workflow 
(Agent.run) and its sub-spans.""" + + node: NodeExecution + sub_spans: List[NodeExecution] = field(default_factory=list) + agent_name: str = "" + + +# Context variable to track the current workflow +_current_workflow: ContextVar[Optional[WorkflowContext]] = ContextVar( + "pydanticai_workflow", default=None +) + + +def _get_current_workflow() -> Optional[WorkflowContext]: + """Get the current workflow context, if any.""" + return _current_workflow.get() + + +def _set_current_workflow(ctx: Optional[WorkflowContext]) -> None: + """Set the current workflow context.""" + _current_workflow.set(ctx) + + +class PydanticAIInstrumentor: + """Wraps PydanticAI agents with tracing using workflow spans.""" + + def __init__(self) -> None: + self._installed = False + self._tracker: Optional[DigitalOceanTracesTracker] = None + self._original_call_tool = None + self._original_model_requests: Dict[type, Any] = {} + self._original_model_request_streams: Dict[type, Any] = {} + self._original_agent_run: Any = None + self._original_agent_run_sync: Any = None + self._original_agent_run_stream: Any = None + + def install(self, tracker: DigitalOceanTracesTracker) -> None: + if self._installed: + return + self._tracker = tracker + + try: + from pydantic_ai import Agent + from pydantic_ai.models import Model + except ImportError: + # PydanticAI not installed, skip instrumentation + return + + t = tracker # close over + + def _start_sub_span(node_name: str, inputs: Any): + """Start a sub-span that will be nested inside the current workflow.""" + inputs_snapshot = _freeze(inputs) + rec = _mk_exec(node_name, inputs_snapshot) + intr, tok = _snap() + + # Check if we're inside a workflow context + workflow = _get_current_workflow() + if workflow is not None: + # Don't call tracker.on_node_start - we'll batch these with the workflow + pass + else: + # No workflow context - fall back to flat spans + t.on_node_start(rec) + + return rec, inputs_snapshot, intr, tok + + def _finish_sub_span_ok( + rec: NodeExecution, + inputs_snapshot: Any, + ret: Any, + intr, + tok, + time_to_first_token_ns: Optional[int] = None, + ): + """Finish a sub-span successfully.""" + # Check if this node made any tracked API calls + if _had_hits_since(intr, tok): + api_request, api_response, is_llm, is_retriever = ( + _get_captured_payloads_with_type(intr, tok) + ) + + meta = _ensure_meta(rec) + if is_llm: + meta["is_llm_call"] = True + # Store raw API payloads for LLM field extraction in tracker + if api_request: + meta["llm_request_payload"] = api_request + if api_response: + meta["llm_response_payload"] = api_response + # Store time-to-first-token if this was a streaming call + if time_to_first_token_ns is not None: + meta["time_to_first_token_ns"] = time_to_first_token_ns + elif is_retriever: + meta["is_retriever_call"] = True + else: + meta["is_llm_call"] = True + if api_request: + meta["llm_request_payload"] = api_request + if api_response: + meta["llm_response_payload"] = api_response + + if api_request or api_response: + if api_request: + rec.inputs = _freeze(api_request) + + if api_response: + if is_retriever: + api_response = _transform_kbaas_response(api_response) + out_payload = _freeze(api_response) + else: + out_payload = _freeze(ret) + else: + out_payload = _freeze(ret) + else: + out_payload = _freeze(ret) + + rec.end_time = _utc() + rec.outputs = out_payload + + # Check if we're inside a workflow context + workflow = _get_current_workflow() + if workflow is not None: + # Add to workflow's sub-spans + workflow.sub_spans.append(rec) + 
else: + # No workflow context - call tracker directly + t.on_node_end(rec, out_payload) + + def _finish_sub_span_err(rec: NodeExecution, intr, tok, e: BaseException): + """Finish a sub-span with an error.""" + if _had_hits_since(intr, tok): + api_request, api_response, is_llm, is_retriever = _get_captured_payloads_with_type( + intr, tok + ) + + meta = _ensure_meta(rec) + if is_llm: + meta["is_llm_call"] = True + # Store raw API payloads for LLM field extraction in tracker + if api_request: + meta["llm_request_payload"] = api_request + if api_response: + meta["llm_response_payload"] = api_response + elif is_retriever: + meta["is_retriever_call"] = True + else: + meta["is_llm_call"] = True + if api_request: + meta["llm_request_payload"] = api_request + + if api_request: + rec.inputs = _freeze(api_request) + + rec.end_time = _utc() + rec.error = str(e) + + # Check if we're inside a workflow context + workflow = _get_current_workflow() + if workflow is not None: + # Add to workflow's sub-spans + workflow.sub_spans.append(rec) + else: + # No workflow context - call tracker directly + t.on_node_error(rec, e) + + # Import FunctionToolset for tool call instrumentation + try: + from pydantic_ai.toolsets.function import FunctionToolset + + self._original_call_tool = FunctionToolset.call_tool + except ImportError: + self._original_call_tool = None + + # Get all concrete model classes that need patching + model_classes: List[type] = [] + try: + from pydantic_ai.models.openai import OpenAIChatModel, OpenAIResponsesModel + + model_classes.extend([OpenAIChatModel, OpenAIResponsesModel]) + except ImportError: + pass + try: + from pydantic_ai.models.anthropic import AnthropicModel + + model_classes.append(AnthropicModel) + except ImportError: + pass + try: + from pydantic_ai.models.google import GoogleModel + + model_classes.append(GoogleModel) + except ImportError: + pass + try: + from pydantic_ai.models.gemini import GeminiModel + + model_classes.append(GeminiModel) + except ImportError: + pass + try: + from pydantic_ai.models.groq import GroqModel + + model_classes.append(GroqModel) + except ImportError: + pass + try: + from pydantic_ai.models.mistral import MistralModel + + model_classes.append(MistralModel) + except ImportError: + pass + try: + from pydantic_ai.models.cohere import CohereModel + + model_classes.append(CohereModel) + except ImportError: + pass + try: + from pydantic_ai.models.bedrock import BedrockConverseModel + + model_classes.append(BedrockConverseModel) + except ImportError: + pass + try: + from pydantic_ai.models.huggingface import HuggingFaceModel + + model_classes.append(HuggingFaceModel) + except ImportError: + pass + try: + from pydantic_ai.models.test import TestModel + + model_classes.append(TestModel) + except ImportError: + pass + try: + from pydantic_ai.models.function import FunctionModel + + model_classes.append(FunctionModel) + except ImportError: + pass + + # Create wrapper factory for each model class + def make_wrapped_request(original_request): + async def wrapped_model_request( + model_self, messages, model_settings, model_request_parameters + ): + """Wrapped Model.request that traces each LLM call as a sub-span.""" + model_name = getattr( + model_self, "model_name", model_self.__class__.__name__ + ) + node_name = f"llm:{model_name}" + + # Extract input from messages + inputs = _extract_messages_input(messages) + rec, snap, intr, tok = _start_sub_span(node_name, inputs) + + try: + response = await original_request( + model_self, messages, model_settings, 
model_request_parameters + ) + + # Extract output from the response + output = _extract_model_response_output(response) + _finish_sub_span_ok(rec, snap, output, intr, tok) + return response + except BaseException as e: + _finish_sub_span_err(rec, intr, tok, e) + raise + + return wrapped_model_request + + def make_wrapped_request_stream(original_request_stream): + @asynccontextmanager + async def wrapped_model_request_stream( + model_self, + messages, + model_settings, + model_request_parameters, + run_context=None, + ): + """Wrapped Model.request_stream that traces each streaming LLM call as a sub-span.""" + model_name = getattr( + model_self, "model_name", model_self.__class__.__name__ + ) + node_name = f"llm:{model_name}" + + # Extract input from messages + inputs = _extract_messages_input(messages) + rec, snap, intr, tok = _start_sub_span(node_name, inputs) + + try: + async with original_request_stream( + model_self, + messages, + model_settings, + model_request_parameters, + run_context, + ) as stream: + yield stream + + # After the stream is consumed, get the response + response = stream.get() + output = _extract_model_response_output(response) + _finish_sub_span_ok(rec, snap, output, intr, tok) + except BaseException as e: + _finish_sub_span_err(rec, intr, tok, e) + raise + + return wrapped_model_request_stream + + # Patch all concrete model classes + for model_cls in model_classes: + # Store originals + self._original_model_requests[model_cls] = model_cls.request + self._original_model_request_streams[model_cls] = model_cls.request_stream + + # Apply patches + model_cls.request = make_wrapped_request(model_cls.request) + model_cls.request_stream = make_wrapped_request_stream( + model_cls.request_stream + ) + + # Wrap FunctionToolset.call_tool to instrument tool calls as sub-spans + if self._original_call_tool is not None: + from pydantic_ai.toolsets.function import FunctionToolset + + original_call_tool = self._original_call_tool + + async def wrapped_call_tool(toolset_self, name, tool_args, ctx, tool): + """Wrapped call_tool that traces tool execution as a sub-span.""" + rec, snap, intr, tok = _start_sub_span(name, _freeze(tool_args)) + try: + result = await original_call_tool( + toolset_self, name, tool_args, ctx, tool + ) + _finish_sub_span_ok(rec, snap, result, intr, tok) + return result + except BaseException as e: + _finish_sub_span_err(rec, intr, tok, e) + raise + + FunctionToolset.call_tool = wrapped_call_tool + + # Wrap Agent.run, Agent.run_sync, Agent.run_stream to create workflow spans + self._original_agent_run = Agent.run + self._original_agent_run_sync = Agent.run_sync + self._original_agent_run_stream = Agent.run_stream + + async def wrapped_agent_run(agent_self, user_prompt, **kwargs): + """Wrapped Agent.run that creates a workflow span containing all sub-spans.""" + agent_name = ( + getattr(agent_self, "name", None) or agent_self.__class__.__name__ + ) + + # Create workflow node + inputs_snapshot = _freeze(user_prompt) + workflow_node = _mk_exec(agent_name, inputs_snapshot) + meta = _ensure_meta(workflow_node) + meta["is_workflow"] = True + meta["agent_name"] = agent_name + + # Create workflow context + workflow_ctx = WorkflowContext( + node=workflow_node, + agent_name=agent_name, + ) + + # Set the workflow context + prev_workflow = _get_current_workflow() + _set_current_workflow(workflow_ctx) + + try: + # Call original run method + result = await self._original_agent_run( + agent_self, user_prompt, **kwargs + ) + + # Finish workflow node + workflow_node.end_time = 
_utc() + # Extract just the output from the result, not the full state + if hasattr(result, "output"): + workflow_node.outputs = {"output": _freeze(result.output)} + elif hasattr(result, "data"): + workflow_node.outputs = {"output": _freeze(result.data)} + else: + workflow_node.outputs = _freeze(result) + + # Store sub-spans in metadata for the tracker to handle + meta["sub_spans"] = workflow_ctx.sub_spans + + # Report the workflow span to the tracker + t.on_node_start(workflow_node) + t.on_node_end(workflow_node, workflow_node.outputs) + + return result + except BaseException as e: + workflow_node.end_time = _utc() + workflow_node.error = str(e) + meta["sub_spans"] = workflow_ctx.sub_spans + + t.on_node_start(workflow_node) + t.on_node_error(workflow_node, e) + raise + finally: + # Restore previous workflow context + _set_current_workflow(prev_workflow) + + def wrapped_agent_run_sync(agent_self, user_prompt, **kwargs): + """Wrapped Agent.run_sync that creates a workflow span containing all sub-spans.""" + agent_name = ( + getattr(agent_self, "name", None) or agent_self.__class__.__name__ + ) + + # Create workflow node + inputs_snapshot = _freeze(user_prompt) + workflow_node = _mk_exec(agent_name, inputs_snapshot) + meta = _ensure_meta(workflow_node) + meta["is_workflow"] = True + meta["agent_name"] = agent_name + + # Create workflow context + workflow_ctx = WorkflowContext( + node=workflow_node, + agent_name=agent_name, + ) + + # Set the workflow context + prev_workflow = _get_current_workflow() + _set_current_workflow(workflow_ctx) + + try: + # Call original run_sync method + result = self._original_agent_run_sync( + agent_self, user_prompt, **kwargs + ) + + # Finish workflow node + workflow_node.end_time = _utc() + # Extract just the output from the result, not the full state + if hasattr(result, "output"): + workflow_node.outputs = {"output": _freeze(result.output)} + elif hasattr(result, "data"): + workflow_node.outputs = {"output": _freeze(result.data)} + else: + workflow_node.outputs = _freeze(result) + + # Store sub-spans in metadata for the tracker to handle + meta["sub_spans"] = workflow_ctx.sub_spans + + # Report the workflow span to the tracker + t.on_node_start(workflow_node) + t.on_node_end(workflow_node, workflow_node.outputs) + + return result + except BaseException as e: + workflow_node.end_time = _utc() + workflow_node.error = str(e) + meta["sub_spans"] = workflow_ctx.sub_spans + + t.on_node_start(workflow_node) + t.on_node_error(workflow_node, e) + raise + finally: + # Restore previous workflow context + _set_current_workflow(prev_workflow) + + @asynccontextmanager + async def wrapped_agent_run_stream(agent_self, user_prompt, **kwargs): + """Wrapped Agent.run_stream that creates a workflow span containing all sub-spans.""" + agent_name = ( + getattr(agent_self, "name", None) or agent_self.__class__.__name__ + ) + + # Create workflow node + inputs_snapshot = _freeze(user_prompt) + workflow_node = _mk_exec(agent_name, inputs_snapshot) + meta = _ensure_meta(workflow_node) + meta["is_workflow"] = True + meta["agent_name"] = agent_name + + # Create workflow context + workflow_ctx = WorkflowContext( + node=workflow_node, + agent_name=agent_name, + ) + + # Set the workflow context + prev_workflow = _get_current_workflow() + _set_current_workflow(workflow_ctx) + + try: + async with self._original_agent_run_stream( + agent_self, user_prompt, **kwargs + ) as stream: + yield stream + + # Finish workflow node - get the result from the stream + try: + result = stream.result + # 
Extract just the output from the result, not the full state + if hasattr(result, "output"): + workflow_node.outputs = {"output": _freeze(result.output)} + elif hasattr(result, "data"): + workflow_node.outputs = {"output": _freeze(result.data)} + else: + workflow_node.outputs = _freeze(result) + except Exception: + workflow_node.outputs = {"streaming": True} + + workflow_node.end_time = _utc() + + # Store sub-spans in metadata for the tracker to handle + meta["sub_spans"] = workflow_ctx.sub_spans + + # Report the workflow span to the tracker + t.on_node_start(workflow_node) + t.on_node_end(workflow_node, workflow_node.outputs) + except BaseException as e: + workflow_node.end_time = _utc() + workflow_node.error = str(e) + meta["sub_spans"] = workflow_ctx.sub_spans + + t.on_node_start(workflow_node) + t.on_node_error(workflow_node, e) + raise + finally: + # Restore previous workflow context + _set_current_workflow(prev_workflow) + + Agent.run = wrapped_agent_run + Agent.run_sync = wrapped_agent_run_sync + Agent.run_stream = wrapped_agent_run_stream + + self._installed = True + + def uninstall(self) -> None: + """Remove instrumentation hooks.""" + if not self._installed: + return + + # Restore all patched model classes + for model_cls, original_request in self._original_model_requests.items(): + model_cls.request = original_request + for model_cls, original_stream in self._original_model_request_streams.items(): + model_cls.request_stream = original_stream + + # Clear the stored originals + self._original_model_requests.clear() + self._original_model_request_streams.clear() + + # Restore FunctionToolset.call_tool + if self._original_call_tool is not None: + try: + from pydantic_ai.toolsets.function import FunctionToolset + + FunctionToolset.call_tool = self._original_call_tool + except ImportError: + pass + + # Restore Agent methods + try: + from pydantic_ai import Agent + + if self._original_agent_run is not None: + Agent.run = self._original_agent_run + if self._original_agent_run_sync is not None: + Agent.run_sync = self._original_agent_run_sync + if self._original_agent_run_stream is not None: + Agent.run_stream = self._original_agent_run_stream + except ImportError: + pass + + self._installed = False + + def is_installed(self) -> bool: + """Check if instrumentation is currently installed.""" + return self._installed \ No newline at end of file diff --git a/gradient_adk/tracing.py b/gradient_adk/tracing.py index 6d70b49..189c94a 100644 --- a/gradient_adk/tracing.py +++ b/gradient_adk/tracing.py @@ -41,7 +41,7 @@ async def my_agent(input: dict, context: dict): from typing import Any, Callable, Dict, Optional, Tuple, TypeVar from .runtime.interfaces import NodeExecution -from .runtime.langgraph.helpers import get_tracker +from .runtime.helpers import get_tracker from .runtime.network_interceptor import get_network_interceptor F = TypeVar("F", bound=Callable[..., Any]) @@ -231,13 +231,24 @@ async def async_gen_wrapper(*args, **kwargs): collected.append(chunk_str) yield chunk - # Check for network activity (LLM calls) - only if not already marked - if span_type is None: - try: - if interceptor.hits_since(network_token) > 0: - _ensure_meta(span)["is_llm_call"] = True - except Exception: - pass + # Check for network activity and capture LLM payloads + try: + has_network_hits = interceptor.hits_since(network_token) > 0 + # For explicitly marked LLM spans OR auto-detected network activity + if has_network_hits or span_type == SpanType.LLM: + meta = _ensure_meta(span) + if span_type is None and 
has_network_hits: + meta["is_llm_call"] = True + # Get captured request/response payloads for LLM metadata extraction + captured = interceptor.get_captured_requests_since(network_token) + if captured: + call = captured[0] + if call.request_payload: + meta["llm_request_payload"] = call.request_payload + if call.response_payload: + meta["llm_response_payload"] = call.response_payload + except Exception: + pass # Stream complete - finalize span with collected content tracker.on_node_end(span, {"content": "".join(collected)}) @@ -282,13 +293,24 @@ async def async_wrapper(*args, **kwargs): try: result = await func(*args, **kwargs) - # Check for network activity (LLM calls) - only if not already marked - if span_type is None: - try: - if interceptor.hits_since(network_token) > 0: - _ensure_meta(span)["is_llm_call"] = True - except Exception: - pass + # Check for network activity and capture LLM payloads + try: + has_network_hits = interceptor.hits_since(network_token) > 0 + # For explicitly marked LLM spans OR auto-detected network activity + if has_network_hits or span_type == SpanType.LLM: + meta = _ensure_meta(span) + if span_type is None and has_network_hits: + meta["is_llm_call"] = True + # Get captured request/response payloads for LLM metadata extraction + captured = interceptor.get_captured_requests_since(network_token) + if captured: + call = captured[0] + if call.request_payload: + meta["llm_request_payload"] = call.request_payload + if call.response_payload: + meta["llm_response_payload"] = call.response_payload + except Exception: + pass # If the result is an async generator, wrap it so we can collect output # without double-iterating. We delay on_node_end until the stream is consumed. @@ -370,13 +392,24 @@ def sync_wrapper(*args, **kwargs): try: result = func(*args, **kwargs) - # Check for network activity (LLM calls) - only if not already marked - if span_type is None: - try: - if interceptor.hits_since(network_token) > 0: - _ensure_meta(span)["is_llm_call"] = True - except Exception: - pass + # Check for network activity and capture LLM payloads + try: + has_network_hits = interceptor.hits_since(network_token) > 0 + # For explicitly marked LLM spans OR auto-detected network activity + if has_network_hits or span_type == SpanType.LLM: + meta = _ensure_meta(span) + if span_type is None and has_network_hits: + meta["is_llm_call"] = True + # Get captured request/response payloads for LLM metadata extraction + captured = interceptor.get_captured_requests_since(network_token) + if captured: + call = captured[0] + if call.request_payload: + meta["llm_request_payload"] = call.request_payload + if call.response_payload: + meta["llm_response_payload"] = call.response_payload + except Exception: + pass # Check if result is an async generator - pass directly without snapshotting if result is not None and ( @@ -442,4 +475,4 @@ async def search(query: str) -> list: results = await db.search(query) return results """ - return _trace_base(name, span_type=SpanType.TOOL) + return _trace_base(name, span_type=SpanType.TOOL) \ No newline at end of file diff --git a/tests/runtime/helpers_test.py b/tests/runtime/helpers_test.py new file mode 100644 index 0000000..a583111 --- /dev/null +++ b/tests/runtime/helpers_test.py @@ -0,0 +1,293 @@ +"""Tests for the centralized InstrumentorRegistry in gradient_adk.runtime.helpers.""" + +import pytest +import os +from unittest.mock import MagicMock, patch + +from gradient_adk.runtime.helpers import ( + InstrumentorRegistry, + registry, + capture_all, + get_tracker, 
+ register_all_instrumentors, + _register_langgraph, + _register_pydanticai, +) + + +# ----------------------------- +# Fixtures +# ----------------------------- + + +@pytest.fixture +def fresh_registry(): + """Create a fresh registry instance for testing.""" + return InstrumentorRegistry() + + +@pytest.fixture +def mock_instrumentor(): + """Create a mock instrumentor that follows the protocol.""" + inst = MagicMock() + inst.install = MagicMock() + inst.uninstall = MagicMock() + inst.is_installed = MagicMock(return_value=False) + return inst + + +# ----------------------------- +# Registry Basic Tests +# ----------------------------- + + +def test_registry_initial_state(fresh_registry): + """Test that a fresh registry starts with no tracker and no instrumentors.""" + assert fresh_registry.get_tracker() is None + assert fresh_registry.get_installed_names() == [] + assert not fresh_registry.is_installed("anything") + + +def test_register_adds_instrumentor(fresh_registry, mock_instrumentor): + """Test that register() adds an instrumentor to the registry.""" + fresh_registry.register( + name="test", + env_disable_var="TEST_DISABLE", + availability_check=lambda: True, + instrumentor_factory=lambda: mock_instrumentor, + ) + + assert "test" in fresh_registry._registrations + + +def test_install_without_api_token_returns_none(fresh_registry, mock_instrumentor): + """Test that install returns None when no API token is available.""" + fresh_registry.register( + name="test", + env_disable_var="TEST_DISABLE", + availability_check=lambda: True, + instrumentor_factory=lambda: mock_instrumentor, + ) + + # Ensure no API token + with patch.dict(os.environ, {}, clear=True): + result = fresh_registry.install("test") + + assert result is None + mock_instrumentor.install.assert_not_called() + + +def test_install_when_disabled_returns_none(fresh_registry, mock_instrumentor): + """Test that install returns None when disabled via env var.""" + fresh_registry.register( + name="test", + env_disable_var="TEST_DISABLE", + availability_check=lambda: True, + instrumentor_factory=lambda: mock_instrumentor, + ) + + with patch.dict(os.environ, {"TEST_DISABLE": "true"}): + result = fresh_registry.install("test") + + assert result is None + mock_instrumentor.install.assert_not_called() + + +def test_install_when_unavailable_returns_none(fresh_registry, mock_instrumentor): + """Test that install returns None when framework is unavailable.""" + fresh_registry.register( + name="test", + env_disable_var="TEST_DISABLE", + availability_check=lambda: False, # Not available + instrumentor_factory=lambda: mock_instrumentor, + ) + + result = fresh_registry.install("test") + + assert result is None + mock_instrumentor.install.assert_not_called() + + +def test_install_nonexistent_returns_none(fresh_registry): + """Test that install returns None for unregistered instrumentor.""" + result = fresh_registry.install("nonexistent") + assert result is None + + +def test_is_env_disabled_variations(fresh_registry): + """Test environment variable disable check with various values.""" + # True values + for val in ["true", "TRUE", "1", "yes", "YES"]: + with patch.dict(os.environ, {"TEST_VAR": val}): + assert fresh_registry._is_env_disabled("TEST_VAR") is True + + # False values + for val in ["false", "FALSE", "0", "no", "", "anything"]: + with patch.dict(os.environ, {"TEST_VAR": val}): + assert fresh_registry._is_env_disabled("TEST_VAR") is False + + # Missing var + with patch.dict(os.environ, {}, clear=True): + assert 
fresh_registry._is_env_disabled("MISSING_VAR") is False + + +def test_is_installed_returns_correct_state(fresh_registry): + """Test is_installed returns correct boolean state.""" + assert not fresh_registry.is_installed("test") + + # Simulate installation by adding to _instrumentors directly + mock_inst = MagicMock() + fresh_registry._instrumentors["test"] = mock_inst + + assert fresh_registry.is_installed("test") + + +def test_get_installed_names_returns_list(fresh_registry): + """Test get_installed_names returns list of installed instrumentor names.""" + assert fresh_registry.get_installed_names() == [] + + # Simulate installations + fresh_registry._instrumentors["test1"] = MagicMock() + fresh_registry._instrumentors["test2"] = MagicMock() + + names = fresh_registry.get_installed_names() + assert sorted(names) == ["test1", "test2"] + + +def test_uninstall_calls_uninstall_method(fresh_registry): + """Test that uninstall calls the instrumentor's uninstall method.""" + mock_inst = MagicMock() + fresh_registry._instrumentors["test"] = mock_inst + + fresh_registry.uninstall("test") + + mock_inst.uninstall.assert_called_once() + assert "test" not in fresh_registry._instrumentors + + +def test_uninstall_nonexistent_is_safe(fresh_registry): + """Test that uninstalling non-existent instrumentor doesn't raise.""" + # Should not raise + fresh_registry.uninstall("nonexistent") + + +def test_uninstall_all_clears_all_instrumentors(fresh_registry): + """Test that uninstall_all clears all instrumentors.""" + mock1 = MagicMock() + mock2 = MagicMock() + fresh_registry._instrumentors["test1"] = mock1 + fresh_registry._instrumentors["test2"] = mock2 + + fresh_registry.uninstall_all() + + mock1.uninstall.assert_called_once() + mock2.uninstall.assert_called_once() + assert fresh_registry.get_installed_names() == [] + + +def test_install_is_idempotent(fresh_registry, mock_instrumentor): + """Test that installing the same instrumentor twice is a no-op.""" + # Simulate already installed + fresh_registry._instrumentors["test"] = mock_instrumentor + + fresh_registry.register( + name="test", + env_disable_var="TEST_DISABLE", + availability_check=lambda: True, + instrumentor_factory=lambda: mock_instrumentor, + ) + + # Second install should not call install again + result = fresh_registry.install("test") + + # Should return the tracker (which may be None), but not call install + mock_instrumentor.install.assert_not_called() + + +# ----------------------------- +# Registration Function Tests +# ----------------------------- + + +def test_register_langgraph_adds_to_registry(): + """Test that _register_langgraph adds langgraph to registry.""" + test_registry = InstrumentorRegistry() + + # Patch the global registry temporarily + with patch("gradient_adk.runtime.helpers.registry", test_registry): + _register_langgraph() + + assert "langgraph" in test_registry._registrations + assert test_registry._registrations["langgraph"]["env_disable_var"] == "GRADIENT_DISABLE_LANGGRAPH_INSTRUMENTOR" + + +def test_register_pydanticai_adds_to_registry(): + """Test that _register_pydanticai adds pydanticai to registry.""" + test_registry = InstrumentorRegistry() + + # Patch the global registry temporarily + with patch("gradient_adk.runtime.helpers.registry", test_registry): + _register_pydanticai() + + assert "pydanticai" in test_registry._registrations + assert test_registry._registrations["pydanticai"]["env_disable_var"] == "GRADIENT_DISABLE_PYDANTICAI_INSTRUMENTOR" + + +def test_register_all_instrumentors_registers_both(): + """Test 
that register_all_instrumentors registers both frameworks.""" + test_registry = InstrumentorRegistry() + + with patch("gradient_adk.runtime.helpers.registry", test_registry): + register_all_instrumentors() + + assert "langgraph" in test_registry._registrations + assert "pydanticai" in test_registry._registrations + + +# ----------------------------- +# Global Registry Tests +# ----------------------------- + + +def test_global_registry_exists(): + """Test that the global registry instance exists.""" + assert registry is not None + assert isinstance(registry, InstrumentorRegistry) + + +def test_get_tracker_returns_global_tracker(): + """Test that get_tracker returns the global registry's tracker.""" + # Note: tracker may be None if not initialized, but function should work + result = get_tracker() + assert result == registry.get_tracker() + + +# ----------------------------- +# Availability Check Tests +# ----------------------------- + + +def test_langgraph_availability_check(): + """Test langgraph availability check function.""" + test_registry = InstrumentorRegistry() + + with patch("gradient_adk.runtime.helpers.registry", test_registry): + _register_langgraph() + + check = test_registry._registrations["langgraph"]["availability_check"] + + # Since we're running tests with langgraph installed, it should be available + assert check() is True + + +def test_pydanticai_availability_check(): + """Test pydanticai availability check function.""" + test_registry = InstrumentorRegistry() + + with patch("gradient_adk.runtime.helpers.registry", test_registry): + _register_pydanticai() + + check = test_registry._registrations["pydanticai"]["availability_check"] + + # Since we're running tests with pydantic-ai installed, it should be available + assert check() is True \ No newline at end of file diff --git a/tests/runtime/langgraph/langgraph_instrumentor_test.py b/tests/runtime/langgraph/langgraph_instrumentor_test.py index 9af210c..9a915bd 100644 --- a/tests/runtime/langgraph/langgraph_instrumentor_test.py +++ b/tests/runtime/langgraph/langgraph_instrumentor_test.py @@ -1,4 +1,5 @@ import pytest +import os from unittest.mock import MagicMock, patch from langgraph.graph import StateGraph @@ -9,7 +10,6 @@ _get_captured_payloads_with_type, ) - # ----------------------------- # Fixtures # ----------------------------- @@ -212,14 +212,14 @@ def test_transform_kbaas_response_converts_text_content_to_page_content(): "results": [ { "metadata": {"source": "doc1.pdf", "page": 1}, - "text_content": "This is the document content." + "text_content": "This is the document content.", }, { "metadata": {"source": "doc2.pdf", "page": 2}, - "text_content": "Another document chunk." - } + "text_content": "Another document chunk.", + }, ], - "total_results": 2 + "total_results": 2, } transformed = _transform_kbaas_response(response) @@ -241,10 +241,7 @@ def test_transform_kbaas_response_converts_text_content_to_page_content(): def test_transform_kbaas_response_handles_empty_results(): """Test that empty results list is handled correctly.""" - response = { - "results": [], - "total_results": 0 - } + response = {"results": [], "total_results": 0} transformed = _transform_kbaas_response(response) @@ -257,20 +254,17 @@ def test_transform_kbaas_response_preserves_items_without_text_content(): """Test that items without text_content are preserved unchanged.""" response = { "results": [ - { - "metadata": {"source": "doc1.pdf"}, - "text_content": "Has text content." 
- }, + {"metadata": {"source": "doc1.pdf"}, "text_content": "Has text content."}, { "metadata": {"source": "doc2.pdf"}, - "page_content": "Already has page_content." + "page_content": "Already has page_content.", }, { "metadata": {"source": "doc3.pdf"} # No text_content or page_content - } + }, ], - "total_results": 3 + "total_results": 3, } transformed = _transform_kbaas_response(response) @@ -318,15 +312,15 @@ def test_transform_kbaas_response_hierarchical_kb_with_parent_chunk(): { "metadata": {"source": "doc1.pdf", "page": 1}, "text_content": "This is the embedded chunk.", - "parent_chunk_text": "This is the full parent context with more information." + "parent_chunk_text": "This is the full parent context with more information.", }, { "metadata": {"source": "doc2.pdf", "page": 2}, "text_content": "Another embedded chunk.", - "parent_chunk_text": "Another parent context." - } + "parent_chunk_text": "Another parent context.", + }, ], - "total_results": 2 + "total_results": 2, } transformed = _transform_kbaas_response(response) @@ -336,7 +330,10 @@ def test_transform_kbaas_response_hierarchical_kb_with_parent_chunk(): assert len(transformed) == 2 # parent_chunk_text should become page_content - assert transformed[0]["page_content"] == "This is the full parent context with more information." + assert ( + transformed[0]["page_content"] + == "This is the full parent context with more information." + ) assert transformed[1]["page_content"] == "Another parent context." # text_content should become embedded_content @@ -359,10 +356,10 @@ def test_transform_kbaas_response_hierarchical_kb_parent_only(): "results": [ { "metadata": {"source": "doc1.pdf"}, - "parent_chunk_text": "Parent context only." + "parent_chunk_text": "Parent context only.", } ], - "total_results": 1 + "total_results": 1, } transformed = _transform_kbaas_response(response) @@ -381,18 +378,18 @@ def test_transform_kbaas_response_mixed_results(): { "metadata": {"source": "hierarchical.pdf"}, "text_content": "Embedded chunk.", - "parent_chunk_text": "Full parent context." + "parent_chunk_text": "Full parent context.", }, { "metadata": {"source": "standard.pdf"}, - "text_content": "Standard KB chunk." 
+ "text_content": "Standard KB chunk.", }, { "metadata": {"source": "empty.pdf"} # No content fields - } + }, ], - "total_results": 3 + "total_results": 3, } transformed = _transform_kbaas_response(response) @@ -426,7 +423,7 @@ def test_retriever_hit_sets_metadata(tracker, interceptor): mock_captured.request_payload = {"query": "test query"} mock_captured.response_payload = { "results": [{"text_content": "doc content", "metadata": {}}], - "total_results": 1 + "total_results": 1, } interceptor.hits_since.return_value = 1 @@ -447,7 +444,10 @@ def node(state: dict): # NodeExecution record is arg0 to on_node_end exec_rec = tracker.on_node_end.call_args[0][0] assert exec_rec.metadata.get("is_retriever_call") is True - assert exec_rec.metadata.get("is_llm_call") is None or exec_rec.metadata.get("is_llm_call") is False + assert ( + exec_rec.metadata.get("is_llm_call") is None + or exec_rec.metadata.get("is_llm_call") is False + ) def test_retriever_response_is_transformed(tracker, interceptor): @@ -458,9 +458,12 @@ def test_retriever_response_is_transformed(tracker, interceptor): mock_captured.request_payload = {"query": "test query"} mock_captured.response_payload = { "results": [ - {"text_content": "Document content here", "metadata": {"source": "test.pdf"}} + { + "text_content": "Document content here", + "metadata": {"source": "test.pdf"}, + } ], - "total_results": 1 + "total_results": 1, } interceptor.hits_since.return_value = 1 @@ -514,7 +517,10 @@ def node(state: dict): # NodeExecution record is arg0 to on_node_end exec_rec = tracker.on_node_end.call_args[0][0] assert exec_rec.metadata.get("is_llm_call") is True - assert exec_rec.metadata.get("is_retriever_call") is None or exec_rec.metadata.get("is_retriever_call") is False + assert ( + exec_rec.metadata.get("is_retriever_call") is None + or exec_rec.metadata.get("is_retriever_call") is False + ) def test_get_captured_payloads_with_type_inference_url(): @@ -563,4 +569,4 @@ def test_get_captured_payloads_with_type_no_captures(): assert req is None assert resp is None assert is_llm is False - assert is_retriever is False \ No newline at end of file + assert is_retriever is False diff --git a/tests/runtime/llm_field_extraction_test.py b/tests/runtime/llm_field_extraction_test.py new file mode 100644 index 0000000..326a572 --- /dev/null +++ b/tests/runtime/llm_field_extraction_test.py @@ -0,0 +1,515 @@ +"""Tests for LLM field extraction in spans.""" + +import pytest +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock + +from gradient_adk.runtime.interfaces import NodeExecution +from gradient_adk.runtime.digitalocean_tracker import DigitalOceanTracesTracker +from gradient_adk.digital_ocean_api.models import TraceSpanType + + +def _utc() -> datetime: + return datetime.now(timezone.utc) + + +def create_mock_client(): + """Create a mock AsyncDigitalOceanGenAI client.""" + client = AsyncMock() + client.create_traces = AsyncMock(return_value=MagicMock(trace_uuids=["test-uuid"])) + client.aclose = AsyncMock() + return client + + +class TestLLMFieldExtraction: + """Tests for extracting LLM-specific fields from captured API payloads.""" + + def test_extract_model_from_request_payload(self): + """Test that model is extracted from the request payload.""" + client = create_mock_client() + tracker = DigitalOceanTracesTracker( + client=client, + agent_workspace_name="test-workspace", + agent_deployment_name="test-deployment", + ) + + # Create a node execution with LLM request payload in metadata + node = NodeExecution( 
+ node_id="test-node-1", + node_name="llm:test-model", + framework="langgraph", + start_time=_utc(), + end_time=_utc(), + inputs={"messages": [{"role": "user", "content": "Hello"}]}, + outputs={"content": "Hi there!"}, + metadata={ + "is_llm_call": True, + "llm_request_payload": { + "model": "gpt-4o-mini", + "messages": [{"role": "user", "content": "Hello"}], + }, + "llm_response_payload": { + "choices": [{"message": {"content": "Hi there!"}}], + }, + }, + ) + + span = tracker._to_span(node) + + assert span.type == TraceSpanType.TRACE_SPAN_TYPE_LLM + assert span.llm is not None + assert span.llm.model == "gpt-4o-mini" + + def test_llm_input_only_contains_messages(self): + """Test that LLM span input only contains messages, not the full request payload.""" + client = create_mock_client() + tracker = DigitalOceanTracesTracker( + client=client, + agent_workspace_name="test-workspace", + agent_deployment_name="test-deployment", + ) + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}, + ] + + node = NodeExecution( + node_id="test-node-messages", + node_name="llm:gpt-4", + framework="langgraph", + start_time=_utc(), + end_time=_utc(), + inputs={}, # Original inputs + outputs={"content": "Hi!"}, + metadata={ + "is_llm_call": True, + "llm_request_payload": { + "model": "gpt-4", + "messages": messages, + "temperature": 0.7, + "stream": False, # Should NOT appear in span input + "max_tokens": 100, # Should NOT appear in span input + }, + }, + ) + + span = tracker._to_span(node) + + # Span input should be {"messages": [...]} (wrapped in dict for protobuf Struct) + # Should NOT include model, temperature, stream, max_tokens, etc. + assert span.input == {"messages": messages} + assert "model" not in span.input + assert "temperature" not in span.input + assert "stream" not in span.input + assert "max_tokens" not in span.input + + def test_llm_output_only_contains_choices(self): + """Test that LLM span output only contains choices, not the full response payload.""" + client = create_mock_client() + tracker = DigitalOceanTracesTracker( + client=client, + agent_workspace_name="test-workspace", + agent_deployment_name="test-deployment", + ) + + choices = [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! How can I help you today?", + }, + "finish_reason": "stop", + } + ] + + node = NodeExecution( + node_id="test-node-output", + node_name="llm:gpt-4", + framework="langgraph", + start_time=_utc(), + end_time=_utc(), + inputs={}, + outputs={}, # Original outputs + metadata={ + "is_llm_call": True, + "llm_request_payload": {"model": "gpt-4", "messages": []}, + "llm_response_payload": { + "id": "chatcmpl-123", # Should NOT appear in span output + "object": "chat.completion", # Should NOT appear in span output + "created": 1234567890, # Should NOT appear in span output + "model": "gpt-4", # Should NOT appear in span output + "choices": choices, + "usage": { # Should NOT appear in span output + "prompt_tokens": 10, + "completion_tokens": 15, + "total_tokens": 25, + }, + }, + }, + ) + + span = tracker._to_span(node) + + # Span output should be {"choices": [...]} (wrapped in dict for protobuf Struct) + # Should NOT include id, object, created, model, usage, etc. 
+ assert span.output == {"choices": choices} + assert "id" not in span.output + assert "object" not in span.output + assert "created" not in span.output + assert "model" not in span.output + assert "usage" not in span.output + + def test_extract_tools_from_request_payload(self): + """Test that tools are extracted from the request payload.""" + client = create_mock_client() + tracker = DigitalOceanTracesTracker( + client=client, + agent_workspace_name="test-workspace", + agent_deployment_name="test-deployment", + ) + + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string"}, + }, + "required": ["location"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "search_web", + "description": "Search the web", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string"}, + }, + "required": ["query"], + }, + }, + }, + ] + + node = NodeExecution( + node_id="test-node-2", + node_name="llm:gpt-4", + framework="langgraph", + start_time=_utc(), + end_time=_utc(), + inputs={"messages": []}, + outputs={"content": ""}, + metadata={ + "is_llm_call": True, + "llm_request_payload": { + "model": "gpt-4", + "messages": [], + "tools": tools, + }, + }, + ) + + span = tracker._to_span(node) + + assert span.llm is not None + assert span.llm.tools is not None + assert len(span.llm.tools) == 2 + assert span.llm.tools[0]["type"] == "function" + assert span.llm.tools[0]["function"]["name"] == "get_weather" + assert span.llm.tools[1]["function"]["name"] == "search_web" + + def test_extract_temperature_from_request_payload(self): + """Test that temperature is extracted from the request payload.""" + client = create_mock_client() + tracker = DigitalOceanTracesTracker( + client=client, + agent_workspace_name="test-workspace", + agent_deployment_name="test-deployment", + ) + + node = NodeExecution( + node_id="test-node-3", + node_name="llm:test", + framework="langgraph", + start_time=_utc(), + end_time=_utc(), + inputs={}, + outputs={}, + metadata={ + "is_llm_call": True, + "llm_request_payload": { + "model": "gpt-4", + "temperature": 0.7, + "messages": [], + }, + }, + ) + + span = tracker._to_span(node) + + assert span.llm is not None + assert span.llm.temperature == 0.7 + + def test_extract_token_counts_from_response_payload(self): + """Test that token counts are extracted from the response payload.""" + client = create_mock_client() + tracker = DigitalOceanTracesTracker( + client=client, + agent_workspace_name="test-workspace", + agent_deployment_name="test-deployment", + ) + + node = NodeExecution( + node_id="test-node-4", + node_name="llm:test", + framework="langgraph", + start_time=_utc(), + end_time=_utc(), + inputs={}, + outputs={}, + metadata={ + "is_llm_call": True, + "llm_request_payload": {"model": "gpt-4", "messages": []}, + "llm_response_payload": { + "choices": [{"message": {"content": "Hello!"}}], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + }, + }, + }, + ) + + span = tracker._to_span(node) + + assert span.llm is not None + assert span.llm.num_input_tokens == 10 + assert span.llm.num_output_tokens == 5 + assert span.llm.total_tokens == 15 + + def test_extract_time_to_first_token(self): + """Test that time_to_first_token_ns is extracted from metadata.""" + client = create_mock_client() + tracker = DigitalOceanTracesTracker( + client=client, + 
agent_workspace_name="test-workspace", + agent_deployment_name="test-deployment", + ) + + ttft_ns = 123456789 # ~123ms + + node = NodeExecution( + node_id="test-node-5", + node_name="llm:streaming-test", + framework="langgraph", + start_time=_utc(), + end_time=_utc(), + inputs={}, + outputs={"content": "Streamed response"}, + metadata={ + "is_llm_call": True, + "llm_request_payload": {"model": "gpt-4", "messages": []}, + "llm_response_payload": {}, + "time_to_first_token_ns": ttft_ns, + }, + ) + + span = tracker._to_span(node) + + assert span.llm is not None + assert span.llm.time_to_first_token_ns == ttft_ns + + def test_all_fields_together(self): + """Test that all LLM fields are extracted together correctly.""" + client = create_mock_client() + tracker = DigitalOceanTracesTracker( + client=client, + agent_workspace_name="test-workspace", + agent_deployment_name="test-deployment", + ) + + tools = [ + { + "type": "function", + "function": {"name": "test_tool", "description": "A test tool"}, + } + ] + + node = NodeExecution( + node_id="test-node-6", + node_name="llm:complete-test", + framework="langgraph", + start_time=_utc(), + end_time=_utc(), + inputs={}, + outputs={}, + metadata={ + "is_llm_call": True, + "llm_request_payload": { + "model": "claude-3-opus", + "temperature": 0.5, + "tools": tools, + "messages": [{"role": "user", "content": "Test"}], + }, + "llm_response_payload": { + "content": [{"type": "text", "text": "Response"}], + "usage": { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + }, + }, + "time_to_first_token_ns": 500000000, # 500ms + }, + ) + + span = tracker._to_span(node) + + assert span.type == TraceSpanType.TRACE_SPAN_TYPE_LLM + assert span.llm is not None + assert span.llm.model == "claude-3-opus" + assert span.llm.temperature == 0.5 + assert span.llm.tools == tools + assert span.llm.num_input_tokens == 100 + assert span.llm.num_output_tokens == 50 + assert span.llm.total_tokens == 150 + assert span.llm.time_to_first_token_ns == 500000000 + + def test_missing_llm_payloads_graceful_fallback(self): + """Test graceful handling when LLM payloads are missing.""" + client = create_mock_client() + tracker = DigitalOceanTracesTracker( + client=client, + agent_workspace_name="test-workspace", + agent_deployment_name="test-deployment", + ) + + # No llm_request_payload or llm_response_payload + node = NodeExecution( + node_id="test-node-7", + node_name="llm:fallback-test", + framework="langgraph", + start_time=_utc(), + end_time=_utc(), + inputs={}, + outputs={}, + metadata={"is_llm_call": True}, + ) + + span = tracker._to_span(node) + + assert span.type == TraceSpanType.TRACE_SPAN_TYPE_LLM + assert span.llm is not None + # Model should fallback to node name + assert span.llm.model == "fallback-test" + # Optional fields should be None + assert span.llm.tools is None + assert span.llm.temperature is None + assert span.llm.num_input_tokens is None + assert span.llm.num_output_tokens is None + assert span.llm.total_tokens is None + assert span.llm.time_to_first_token_ns is None + + def test_empty_usage_object_in_response(self): + """Test handling of empty usage object in response.""" + client = create_mock_client() + tracker = DigitalOceanTracesTracker( + client=client, + agent_workspace_name="test-workspace", + agent_deployment_name="test-deployment", + ) + + node = NodeExecution( + node_id="test-node-8", + node_name="llm:empty-usage", + framework="langgraph", + start_time=_utc(), + end_time=_utc(), + inputs={}, + outputs={}, + metadata={ + 
"is_llm_call": True, + "llm_request_payload": {"model": "test"}, + "llm_response_payload": {"usage": {}}, + }, + ) + + span = tracker._to_span(node) + + assert span.llm is not None + assert span.llm.num_input_tokens is None + assert span.llm.num_output_tokens is None + assert span.llm.total_tokens is None + + def test_model_fallback_priority(self): + """Test model extraction fallback priority: request > metadata > node_name.""" + client = create_mock_client() + tracker = DigitalOceanTracesTracker( + client=client, + agent_workspace_name="test-workspace", + agent_deployment_name="test-deployment", + ) + + # Case 1: Request payload has model + node1 = NodeExecution( + node_id="test-1", + node_name="llm:node-model", + framework="langgraph", + start_time=_utc(), + end_time=_utc(), + inputs={}, + outputs={}, + metadata={ + "is_llm_call": True, + "model_name": "metadata-model", + "llm_request_payload": {"model": "request-model"}, + }, + ) + span1 = tracker._to_span(node1) + assert span1.llm.model == "request-model" + + # Case 2: No request model, use metadata + node2 = NodeExecution( + node_id="test-2", + node_name="llm:node-model", + framework="langgraph", + start_time=_utc(), + end_time=_utc(), + inputs={}, + outputs={}, + metadata={ + "is_llm_call": True, + "model_name": "metadata-model", + "llm_request_payload": {}, + }, + ) + span2 = tracker._to_span(node2) + assert span2.llm.model == "metadata-model" + + # Case 3: No request or metadata model, use node_name + node3 = NodeExecution( + node_id="test-3", + node_name="llm:node-model", + framework="langgraph", + start_time=_utc(), + end_time=_utc(), + inputs={}, + outputs={}, + metadata={ + "is_llm_call": True, + "llm_request_payload": {}, + }, + ) + span3 = tracker._to_span(node3) + assert span3.llm.model == "node-model" \ No newline at end of file diff --git a/tests/runtime/pydanticai/pydanticai_instrumentor_test.py b/tests/runtime/pydanticai/pydanticai_instrumentor_test.py new file mode 100644 index 0000000..5e7b988 --- /dev/null +++ b/tests/runtime/pydanticai/pydanticai_instrumentor_test.py @@ -0,0 +1,274 @@ +import pytest +from unittest.mock import MagicMock, patch +import os + +# Skip all tests if pydantic-ai is not installed +pytest.importorskip("pydantic_ai") + +from pydantic_ai import Agent +from pydantic_ai.models.test import TestModel + +from gradient_adk.runtime.pydanticai.pydanticai_instrumentor import ( + PydanticAIInstrumentor, + _freeze, + _snapshot_args_kwargs, + _get_captured_payloads_with_type, + _transform_kbaas_response, + _extract_messages_input, + _extract_model_response_output, +) + + +@pytest.fixture +def tracker(): + t = MagicMock() + t.on_node_start = MagicMock() + t.on_node_end = MagicMock() + t.on_node_error = MagicMock() + return t + + +@pytest.fixture +def interceptor(): + intr = MagicMock() + intr.snapshot_token.return_value = 42 + intr.hits_since.return_value = 0 + return intr + + +@pytest.fixture(autouse=True) +def patch_interceptor(interceptor): + with patch( + "gradient_adk.runtime.pydanticai.pydanticai_instrumentor.get_network_interceptor", + return_value=interceptor, + ): + yield interceptor + + +@pytest.fixture +def instrumentor(tracker): + inst = PydanticAIInstrumentor() + inst.install(tracker) + yield inst + inst.uninstall() + + +@pytest.fixture +def test_model(): + return TestModel() + + +def test_install_monkeypatches_model(tracker): + inst = PydanticAIInstrumentor() + old_request = TestModel.request + old_request_stream = TestModel.request_stream + + inst.install(tracker) + + assert TestModel.request 
is not old_request + assert TestModel.request_stream is not old_request_stream + assert inst._installed + + inst.uninstall() + + assert TestModel.request is old_request + assert TestModel.request_stream is old_request_stream + assert not inst._installed + + +def test_install_is_idempotent(tracker): + inst = PydanticAIInstrumentor() + + inst.install(tracker) + first_request = TestModel.request + + inst.install(tracker) + assert TestModel.request is first_request + + inst.uninstall() + + +def test_is_installed_property(tracker): + inst = PydanticAIInstrumentor() + + assert not inst.is_installed() + + inst.install(tracker) + assert inst.is_installed() + + inst.uninstall() + assert not inst.is_installed() + + +@pytest.mark.asyncio +async def test_async_run_creates_workflow_span(tracker, instrumentor, test_model): + """Test that agent.run() creates a workflow span containing LLM sub-spans.""" + agent = Agent(test_model, system_prompt="Test agent") + + result = await agent.run("Hello, test!") + + # Should have one workflow span reported + assert tracker.on_node_start.call_count >= 1 + assert tracker.on_node_end.call_count >= 1 + tracker.on_node_error.assert_not_called() + + # Check that we got a workflow span + workflow_span_found = False + llm_sub_span_found = False + for call in tracker.on_node_start.call_args_list: + node_exec = call[0][0] + # Check metadata indicates it's a workflow + if node_exec.metadata.get("is_workflow") is True: + workflow_span_found = True + assert node_exec.framework == "pydanticai" + # Check for sub_spans containing LLM calls + sub_spans = node_exec.metadata.get("sub_spans", []) + for sub in sub_spans: + if "llm:" in sub.node_name: + llm_sub_span_found = True + break + break + assert workflow_span_found, "No workflow span was created" + assert llm_sub_span_found, "No LLM sub-span was found in the workflow" + + +def test_sync_run_creates_workflow_span(tracker, instrumentor, test_model): + """Test that agent.run_sync() creates a workflow span containing LLM sub-spans.""" + agent = Agent(test_model, system_prompt="Test agent") + + result = agent.run_sync("Hello, sync test!") + + # Should have one workflow span reported + assert tracker.on_node_start.call_count >= 1 + assert tracker.on_node_end.call_count >= 1 + tracker.on_node_error.assert_not_called() + + # Check that we got a workflow span + workflow_span_found = False + llm_sub_span_found = False + for call in tracker.on_node_start.call_args_list: + node_exec = call[0][0] + # Check metadata indicates it's a workflow + if node_exec.metadata.get("is_workflow") is True: + workflow_span_found = True + assert node_exec.framework == "pydanticai" + # Check for sub_spans containing LLM calls + sub_spans = node_exec.metadata.get("sub_spans", []) + for sub in sub_spans: + if "llm:" in sub.node_name: + llm_sub_span_found = True + break + break + assert workflow_span_found, "No workflow span was created" + assert llm_sub_span_found, "No LLM sub-span was found in the workflow" + + +def test_freeze_handles_primitives(): + assert _freeze(None) is None + assert _freeze("string") == "string" + assert _freeze(42) == 42 + assert _freeze(3.14) == 3.14 + assert _freeze(True) is True + + +def test_freeze_handles_dict(): + result = _freeze({"key": "value", "nested": {"a": 1}}) + assert result == {"key": "value", "nested": {"a": 1}} + + +def test_freeze_handles_list(): + result = _freeze([1, 2, 3, {"x": "y"}]) + assert result == [1, 2, 3, {"x": "y"}] + + +def test_get_captured_payloads_with_type_inference_url(): + mock_intr = MagicMock() + 
mock_captured = MagicMock() + mock_captured.url = "https://inference.do-ai.run/v1/chat" + mock_captured.request_payload = {"messages": []} + mock_captured.response_payload = {"choices": []} + + mock_intr.get_captured_requests_since.return_value = [mock_captured] + + req, resp, is_llm, is_retriever = _get_captured_payloads_with_type(mock_intr, 0) + + assert req == {"messages": []} + assert resp == {"choices": []} + assert is_llm is True + assert is_retriever is False + + +def test_transform_kbaas_response_converts_text_content(): + response = { + "results": [ + {"metadata": {"source": "doc1.pdf"}, "text_content": "Document content."} + ], + "total_results": 1, + } + + transformed = _transform_kbaas_response(response) + + assert isinstance(transformed, list) + assert len(transformed) == 1 + assert transformed[0]["page_content"] == "Document content." + assert "text_content" not in transformed[0] + + +def test_extract_messages_input_with_parts(): + class MockPart: + pass + + class MockMessage: + def __init__(self, parts, instructions=None): + self.parts = parts + self.instructions = instructions + + mock_part = MockPart() + msg = MockMessage([mock_part], instructions="Test instructions") + + result = _extract_messages_input([msg]) + + assert len(result) == 1 + assert result[0]["kind"] == "MockMessage" + assert "parts" in result[0] + assert result[0]["instructions"] == "Test instructions" + + +def test_extract_model_response_output_with_parts(): + class MockPart: + pass + + class MockResponse: + def __init__(self, parts, usage=None, model_name=None): + self.parts = parts + self.usage = usage + self.model_name = model_name + + mock_part = MockPart() + response = MockResponse([mock_part], model_name="test-model") + + result = _extract_model_response_output(response) + + assert "parts" in result + assert result["model_name"] == "test-model" + + +def test_uninstall_restores_original_methods(tracker): + original_request = TestModel.request + original_request_stream = TestModel.request_stream + + inst = PydanticAIInstrumentor() + inst.install(tracker) + + assert TestModel.request is not original_request + + inst.uninstall() + + assert TestModel.request is original_request + assert TestModel.request_stream is original_request_stream + + +def test_uninstall_without_install_is_safe(): + inst = PydanticAIInstrumentor() + inst.uninstall() + assert not inst._installed \ No newline at end of file
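
The new tests in tests/runtime/llm_field_extraction_test.py pin down how captured request/response payloads should surface on an LLM span. Below is a minimal sketch of that mapping, assuming OpenAI-style chat-completion payloads; extract_llm_fields is a hypothetical helper introduced only for illustration, while the real logic lives in DigitalOceanTracesTracker._to_span.

# Illustrative sketch only (not part of the patch); assumes OpenAI-style payloads.
from typing import Any, Dict


def extract_llm_fields(metadata: Dict[str, Any], node_name: str) -> Dict[str, Any]:
    """Hypothetical helper mirroring what the new tests expect from _to_span."""
    request = metadata.get("llm_request_payload") or {}
    response = metadata.get("llm_response_payload") or {}
    usage = response.get("usage") or {}

    # Model fallback priority asserted in test_model_fallback_priority:
    # request payload > metadata["model_name"] > node name (minus the "llm:" prefix).
    model = (
        request.get("model")
        or metadata.get("model_name")
        or node_name.removeprefix("llm:")
    )

    return {
        "model": model,
        "tools": request.get("tools"),
        "temperature": request.get("temperature"),
        "num_input_tokens": usage.get("prompt_tokens"),
        "num_output_tokens": usage.get("completion_tokens"),
        "total_tokens": usage.get("total_tokens"),
        "time_to_first_token_ns": metadata.get("time_to_first_token_ns"),
        # Span input/output are narrowed to messages/choices, not the raw payloads.
        "input": {"messages": request.get("messages", [])},
        "output": {"choices": response.get("choices", [])},
    }


if __name__ == "__main__":
    fields = extract_llm_fields(
        {
            "llm_request_payload": {
                "model": "gpt-4o-mini",
                "messages": [{"role": "user", "content": "Hello"}],
                "temperature": 0.7,
            },
            "llm_response_payload": {
                "choices": [{"message": {"content": "Hi there!"}}],
                "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
            },
        },
        node_name="llm:gpt-4o-mini",
    )
    print(fields["model"], fields["total_tokens"])  # -> gpt-4o-mini 15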
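
The helpers_test.py suite documents the public surface of the new InstrumentorRegistry. The sketch below shows typical usage under the behaviour those tests assert: install() is skipped (returns None) when the disable env var is truthy ("true"/"1"/"yes", case-insensitive), when the availability check fails, or when no API token is configured. FakeInstrumentor, its install() signature, and the "example" registration are assumptions for illustration only.

# Illustrative sketch only (not part of the patch).
from gradient_adk.runtime.helpers import InstrumentorRegistry


class FakeInstrumentor:
    """Hypothetical instrumentor implementing the protocol the tests mock."""

    def __init__(self) -> None:
        self._installed = False

    def install(self, *args, **kwargs) -> None:  # exact signature assumed
        self._installed = True

    def uninstall(self) -> None:
        self._installed = False

    def is_installed(self) -> bool:
        return self._installed


reg = InstrumentorRegistry()
reg.register(
    name="example",
    env_disable_var="GRADIENT_DISABLE_EXAMPLE_INSTRUMENTOR",  # "true"/"1"/"yes" disable it
    availability_check=lambda: True,  # normally an import / find_spec check
    instrumentor_factory=FakeInstrumentor,
)

# Returns None without installing if disabled, unavailable, or no API token is set.
tracker = reg.install("example")
print(reg.is_installed("example"), reg.get_installed_names())
reg.uninstall_all()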
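
The PydanticAI instrumentor tests only assert on the shape of the NodeExecution records handed to the tracker: one workflow execution whose metadata carries is_workflow plus the nested LLM sub-executions under sub_spans. A sketch of that shape follows, with illustrative ids, names, and inputs; the NodeExecution constructor is called the same way as in the other new tests.

# Illustrative sketch only (not part of the patch); ids and names are made up.
from datetime import datetime, timezone

from gradient_adk.runtime.interfaces import NodeExecution

now = datetime.now(timezone.utc)

llm_sub = NodeExecution(
    node_id="llm-0",
    node_name="llm:test-model",  # sub-spans for model calls carry the "llm:" prefix
    framework="pydanticai",
    start_time=now,
    end_time=now,
    inputs={"messages": []},
    outputs={"content": ""},
    metadata={"is_llm_call": True},
)

workflow = NodeExecution(
    node_id="wf-0",
    node_name="agent_run",  # illustrative; the tests only inspect the metadata
    framework="pydanticai",
    start_time=now,
    end_time=now,
    inputs={"prompt": "Hello, test!"},
    outputs={},
    metadata={"is_workflow": True, "sub_spans": [llm_sub]},
)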