diff --git a/README.md b/README.md index 99dbff9..0aa3d7e 100644 --- a/README.md +++ b/README.md @@ -191,31 +191,104 @@ gradient agent evaluate \ ``` -## Trace Capture +## Tracing -The ADK runtime automatically captures detailed traces: +The ADK provides comprehensive tracing capabilities to capture and analyze your agent's execution. You can use **decorators** for wrapping functions or **programmatic functions** for manual span creation. -### What Gets Traced +### What Gets Traced Automatically - **LangGraph Nodes**: All node executions, state transitions, and edges (including LLM calls, tool calls, and DigitalOcean Knowledge Base calls) -- **LLM Calls**: Function decorated with `@trace_llm` -- **Tool Calls**: Functions decorated with `@trace_tool` -- **Retriever Calls**: Functions decorated with `@trace_retriever` - **HTTP Requests**: Request/response payloads for LLM API calls - **Errors**: Full exception details and stack traces - **Streaming Responses**: Individual chunks and aggregated outputs -### Available Decorators +### Tracing Decorators + +Use decorators to automatically trace function executions: + +```python +from gradient_adk import entrypoint, trace_llm, trace_tool, trace_retriever + +@trace_llm("model_call") +async def call_model(prompt: str): + """LLM spans capture model calls with token usage.""" + response = await llm.generate(prompt) + return response + +@trace_tool("calculator") +async def calculate(x: int, y: int): + """Tool spans capture function/tool execution.""" + return x + y + +@trace_retriever("vector_search") +async def search_docs(query: str): + """Retriever spans capture search/lookup operations.""" + results = await vector_db.search(query) + return results + +@entrypoint +async def main(input: dict, context: dict): + docs = await search_docs(input["query"]) + result = await calculate(5, 10) + response = await call_model(f"Context: {docs}") + return response +``` + +### Programmatic Span Functions + +For more control over span creation, use the programmatic functions. These are useful when you can't use decorators or need to add spans for code you don't control: + ```python -from gradient_adk import trace_llm, trace_tool, trace_retriever +from gradient_adk import entrypoint, add_llm_span, add_tool_span, add_agent_span -@trace_llm("model_call") # For LLM/model invocations -@trace_tool("calculator") # For tool/function calls -@trace_retriever("db_search") # For retrieval/search operations +@entrypoint +async def main(input: dict, context: dict): + # Add an LLM span with detailed metadata + response = await external_llm_call(input["query"]) + add_llm_span( + name="external_llm_call", + input={"messages": [{"role": "user", "content": input["query"]}]}, + output={"response": response}, + model="gpt-4", + num_input_tokens=100, + num_output_tokens=50, + temperature=0.7, + ) + + # Add a tool span + tool_result = await run_tool(input["data"]) + add_tool_span( + name="data_processor", + input={"data": input["data"]}, + output={"result": tool_result}, + tool_call_id="call_abc123", + metadata={"tool_version": "1.0"}, + ) + + # Add an agent span for sub-agent calls + agent_result = await call_sub_agent(input["task"]) + add_agent_span( + name="research_agent", + input={"task": input["task"]}, + output={"result": agent_result}, + metadata={"agent_type": "research"}, + tags=["sub-agent", "research"], + ) + + return {"response": response, "tool_result": tool_result, "agent_result": agent_result} ``` -These decorators are used to log steps or spans of your agent workflow that are not automatically captured. These will log things like the input, output, and step duration and make them available in your agent's traces and for use in agent evaluations. +#### Available Span Functions + +| Function | Description | Key Optional Fields | +|----------|-------------|---------------------| +| `add_llm_span()` | Record LLM/model calls | `model`, `temperature`, `num_input_tokens`, `num_output_tokens`, `total_tokens`, `tools`, `time_to_first_token_ns` | +| `add_tool_span()` | Record tool/function executions | `tool_call_id` | +| `add_agent_span()` | Record agent/sub-agent executions | — | + +**Common optional fields for all span functions:** `duration_ns`, `metadata`, `tags`, `status_code` ### Viewing Traces + Traces are: - Automatically sent to DigitalOcean's Gradient Platform - Available in real-time through the web console diff --git a/gradient_adk/__init__.py b/gradient_adk/__init__.py index fdfa6c5..70ba0bf 100644 --- a/gradient_adk/__init__.py +++ b/gradient_adk/__init__.py @@ -4,18 +4,28 @@ """ from .decorator import entrypoint, RequestContext -from .tracing import ( # manual tracing decorators +from .tracing import ( + # Decorators trace_llm, trace_retriever, trace_tool, + # Programmatic span functions + add_llm_span, + add_tool_span, + add_agent_span, ) __all__ = [ "entrypoint", "RequestContext", + # Decorators "trace_llm", "trace_retriever", "trace_tool", + # Programmatic span functions + "add_llm_span", + "add_tool_span", + "add_agent_span", ] __version__ = "0.0.5" diff --git a/gradient_adk/runtime/digitalocean_tracker.py b/gradient_adk/runtime/digitalocean_tracker.py index 9588942..0f91ea6 100644 --- a/gradient_adk/runtime/digitalocean_tracker.py +++ b/gradient_adk/runtime/digitalocean_tracker.py @@ -375,17 +375,21 @@ def _to_span(self, ex: NodeExecution) -> Span: elif metadata.get("is_llm_call"): span_type = TraceSpanType.TRACE_SPAN_TYPE_LLM - # Calculate duration - duration_ns = None - if ex.start_time and ex.end_time: - duration_ns = int( - (ex.end_time - ex.start_time).total_seconds() * 1_000_000_000 - ) + # For programmatic API, only use user-provided duration_ns + # For decorators/automatic instrumentation, auto-calculate from start/end times + duration_ns = metadata.get("duration_ns") + if duration_ns is None and not metadata.get("is_programmatic"): + if ex.start_time and ex.end_time: + duration_ns = int( + (ex.end_time - ex.start_time).total_seconds() * 1_000_000_000 + ) # Build LLM-specific details llm_common = SpanCommon( duration_ns=duration_ns, - status_code=200 if ex.error is None else 500, + metadata=metadata.get("custom_metadata"), + tags=metadata.get("tags"), + status_code=metadata.get("status_code", 200 if ex.error is None else 500), ) # Extract LLM-specific fields from captured API payloads @@ -474,8 +478,69 @@ def _to_span(self, ex: NodeExecution) -> Span: type=span_type, retriever=retriever_details, ) + elif metadata.get("is_agent_call"): + span_type = TraceSpanType.TRACE_SPAN_TYPE_AGENT + + # For programmatic API, only use user-provided duration_ns + # For decorators/automatic instrumentation, auto-calculate from start/end times + duration_ns = metadata.get("duration_ns") + if duration_ns is None and not metadata.get("is_programmatic"): + if ex.start_time and ex.end_time: + duration_ns = int( + (ex.end_time - ex.start_time).total_seconds() * 1_000_000_000 + ) + + # Build agent-specific details + agent_common = SpanCommon( + duration_ns=duration_ns, + metadata=metadata.get("custom_metadata"), + tags=metadata.get("tags"), + status_code=metadata.get("status_code", 200 if ex.error is None else 500), + ) + + return Span( + created_at=_utc(ex.start_time), + name=ex.node_name, + input=inp, + output=out, + type=span_type, + common=agent_common, + ) + elif metadata.get("is_tool_call"): + span_type = TraceSpanType.TRACE_SPAN_TYPE_TOOL + + # For programmatic API, only use user-provided duration_ns + # For decorators/automatic instrumentation, auto-calculate from start/end times + duration_ns = metadata.get("duration_ns") + if duration_ns is None and not metadata.get("is_programmatic"): + if ex.start_time and ex.end_time: + duration_ns = int( + (ex.end_time - ex.start_time).total_seconds() * 1_000_000_000 + ) + + # Build tool-specific details + tool_common = SpanCommon( + duration_ns=duration_ns, + metadata=metadata.get("custom_metadata"), + tags=metadata.get("tags"), + status_code=metadata.get("status_code", 200 if ex.error is None else 500), + ) + + tool_details = ToolSpanDetails( + common=tool_common, + tool_call_id=metadata.get("tool_call_id"), + ) + + return Span( + created_at=_utc(ex.start_time), + name=ex.node_name, + input=inp, + output=out, + type=span_type, + tool=tool_details, + ) else: - # Default to tool span + # Default to tool span (for backward compatibility) span_type = TraceSpanType.TRACE_SPAN_TYPE_TOOL # Calculate duration diff --git a/gradient_adk/tracing.py b/gradient_adk/tracing.py index 189c94a..494f528 100644 --- a/gradient_adk/tracing.py +++ b/gradient_adk/tracing.py @@ -38,7 +38,7 @@ async def my_agent(input: dict, context: dict): from copy import deepcopy from datetime import datetime, timezone from enum import Enum -from typing import Any, Callable, Dict, Optional, Tuple, TypeVar +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar from .runtime.interfaces import NodeExecution from .runtime.helpers import get_tracker @@ -475,4 +475,203 @@ async def search(query: str) -> list: results = await db.search(query) return results """ - return _trace_base(name, span_type=SpanType.TOOL) \ No newline at end of file + return _trace_base(name, span_type=SpanType.TOOL) + + +# ============================================================================= +# Programmatic Span Functions +# ============================================================================= + + +def add_llm_span( + name: str, + input: Any, + output: Any, + *, + model: Optional[str] = None, + tools: Optional[List[Dict[str, Any]]] = None, + num_input_tokens: Optional[int] = None, + num_output_tokens: Optional[int] = None, + total_tokens: Optional[int] = None, + temperature: Optional[float] = None, + time_to_first_token_ns: Optional[int] = None, + duration_ns: Optional[int] = None, + metadata: Optional[Dict[str, Any]] = None, + tags: Optional[List[str]] = None, + status_code: Optional[int] = None, +) -> None: + """ + Add an LLM span to the current trace. + + Args: + name: Name for the span (e.g., "call_gpt", "embedding_request") + input: The input to the LLM call (e.g., messages, prompt) + output: The output from the LLM call (e.g., response, completion) + model: Model name (e.g., "gpt-4", "claude-3") + tools: Tool definitions passed to the model + num_input_tokens: Number of input/prompt tokens + num_output_tokens: Number of output/completion tokens + total_tokens: Total tokens used + temperature: Temperature setting used + time_to_first_token_ns: Time to first token in nanoseconds (for streaming) + duration_ns: Duration of the call in nanoseconds + metadata: Additional custom metadata + tags: Tags for the span + status_code: HTTP status code if applicable + + Example: + add_llm_span( + name="call_gpt", + input={"messages": [{"role": "user", "content": "Hello"}]}, + output={"response": "Hi there!"}, + model="gpt-4", + num_input_tokens=10, + num_output_tokens=5, + ) + """ + tracker = get_tracker() + if not tracker: + return + + span = _create_span(name, _freeze(input)) + meta = _ensure_meta(span) + meta["is_llm_call"] = True + meta["is_programmatic"] = True # Mark as programmatic to skip auto-duration calculation + + if model is not None: + meta["model_name"] = model + if tools is not None: + meta["llm_request_payload"] = {"tools": tools} + if temperature is not None: + if "llm_request_payload" not in meta: + meta["llm_request_payload"] = {} + meta["llm_request_payload"]["temperature"] = temperature + if time_to_first_token_ns is not None: + meta["time_to_first_token_ns"] = time_to_first_token_ns + if num_input_tokens is not None or num_output_tokens is not None or total_tokens is not None: + if "llm_response_payload" not in meta: + meta["llm_response_payload"] = {} + meta["llm_response_payload"]["usage"] = { + "prompt_tokens": num_input_tokens, + "completion_tokens": num_output_tokens, + "total_tokens": total_tokens, + } + if tags is not None: + meta["tags"] = tags + if status_code is not None: + meta["status_code"] = status_code + if metadata is not None: + meta["custom_metadata"] = metadata + if duration_ns is not None: + meta["duration_ns"] = duration_ns + + tracker.on_node_start(span) + tracker.on_node_end(span, _freeze(output)) + + +def add_tool_span( + name: str, + input: Any, + output: Any, + *, + tool_call_id: Optional[str] = None, + duration_ns: Optional[int] = None, + metadata: Optional[Dict[str, Any]] = None, + tags: Optional[List[str]] = None, + status_code: Optional[int] = None, +) -> None: + """ + Add a tool span to the current trace. + + Args: + name: Name for the span (e.g., "calculator", "web_search") + input: The input to the tool (e.g., function arguments) + output: The output from the tool (e.g., result) + tool_call_id: Tool call identifier (from LLM tool calling) + duration_ns: Duration of the call in nanoseconds + metadata: Additional custom metadata + tags: Tags for the span + status_code: HTTP status code if applicable + + Example: + add_tool_span( + name="calculator", + input={"operation": "add", "x": 5, "y": 3}, + output={"result": 8}, + tool_call_id="call_abc123", + ) + """ + tracker = get_tracker() + if not tracker: + return + + span = _create_span(name, _freeze(input)) + meta = _ensure_meta(span) + meta["is_tool_call"] = True + meta["is_programmatic"] = True # Mark as programmatic to skip auto-duration calculation + + if tool_call_id is not None: + meta["tool_call_id"] = tool_call_id + if tags is not None: + meta["tags"] = tags + if status_code is not None: + meta["status_code"] = status_code + if metadata is not None: + meta["custom_metadata"] = metadata + if duration_ns is not None: + meta["duration_ns"] = duration_ns + + tracker.on_node_start(span) + tracker.on_node_end(span, _freeze(output)) + + +def add_agent_span( + name: str, + input: Any, + output: Any, + *, + duration_ns: Optional[int] = None, + metadata: Optional[Dict[str, Any]] = None, + tags: Optional[List[str]] = None, + status_code: Optional[int] = None, +) -> None: + """ + Add an agent span to the current trace. + + Args: + name: Name for the span (e.g., "research_agent", "planning_agent") + input: The input to the agent (e.g., query, task) + output: The output from the agent (e.g., response, result) + duration_ns: Duration of the agent execution in nanoseconds + metadata: Additional custom metadata + tags: Tags for the span + status_code: HTTP status code if applicable + + Example: + add_agent_span( + name="research_agent", + input={"query": "What is machine learning?"}, + output={"answer": "Machine learning is..."}, + metadata={"model": "gpt-4"}, + ) + """ + tracker = get_tracker() + if not tracker: + return + + span = _create_span(name, _freeze(input)) + meta = _ensure_meta(span) + meta["is_agent_call"] = True + meta["is_programmatic"] = True # Mark as programmatic to skip auto-duration calculation + + if tags is not None: + meta["tags"] = tags + if status_code is not None: + meta["status_code"] = status_code + if metadata is not None: + meta["custom_metadata"] = metadata + if duration_ns is not None: + meta["duration_ns"] = duration_ns + + tracker.on_node_start(span) + tracker.on_node_end(span, _freeze(output)) diff --git a/integration_tests/example_agents/programmatic_spans_agent/main.py b/integration_tests/example_agents/programmatic_spans_agent/main.py new file mode 100644 index 0000000..09a0551 --- /dev/null +++ b/integration_tests/example_agents/programmatic_spans_agent/main.py @@ -0,0 +1,49 @@ +""" +Agent that tests programmatic span functions. +Uses add_llm_span, add_tool_span, and add_agent_span to manually create spans. +""" + +from gradient_adk import entrypoint, RequestContext, add_llm_span, add_tool_span, add_agent_span + + +@entrypoint +async def main(query, context: RequestContext): + """Test all programmatic span functions.""" + prompt = query.get("prompt", "no prompt provided") + + # Add an LLM span + add_llm_span( + name="test_llm_call", + input={"messages": [{"role": "user", "content": prompt}]}, + output={"response": f"Mock response to: {prompt}"}, + model="test-model", + num_input_tokens=10, + num_output_tokens=20, + total_tokens=30, + temperature=0.7, + ) + + # Add a tool span + add_tool_span( + name="test_tool_call", + input={"query": prompt}, + output={"result": "tool result"}, + tool_call_id="test_call_123", + metadata={"tool_type": "search"}, + ) + + # Add an agent span + add_agent_span( + name="test_agent_call", + input={"task": prompt}, + output={"answer": "agent answer"}, + metadata={"agent_version": "1.0"}, + tags=["test", "integration"], + ) + + return { + "success": True, + "message": "All programmatic spans created successfully", + "prompt_received": prompt, + "session_id": context.session_id if context else None, + } \ No newline at end of file diff --git a/integration_tests/programmatic_spans/test_programmatic_spans.py b/integration_tests/programmatic_spans/test_programmatic_spans.py new file mode 100644 index 0000000..541855a --- /dev/null +++ b/integration_tests/programmatic_spans/test_programmatic_spans.py @@ -0,0 +1,243 @@ +""" +Integration tests for the programmatic span functions (add_llm_span, add_tool_span, add_agent_span). +""" + +import logging +import os +import shutil +import signal +import socket +import subprocess +import tempfile +import time +from pathlib import Path + +import pytest +import requests +import yaml + + +def find_free_port(): + """Find an available port on the local machine.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + +def wait_for_server(port: int, timeout: int = 30) -> bool: + """Wait for server to be ready on the given port.""" + start_time = time.time() + while time.time() - start_time < timeout: + try: + response = requests.get(f"http://localhost:{port}/health", timeout=2) + if response.status_code == 200: + return True + except (requests.ConnectionError, requests.Timeout): + pass + time.sleep(0.5) + return False + + +def cleanup_process(process): + """Clean up a process and its entire process group.""" + if process and process.poll() is None: + try: + os.killpg(process.pid, signal.SIGTERM) + process.wait(timeout=5) + except subprocess.TimeoutExpired: + os.killpg(process.pid, signal.SIGKILL) + except (ProcessLookupError, OSError): + pass + + +class TestProgrammaticSpans: + """Integration tests for programmatic span functions.""" + + @pytest.fixture + def programmatic_spans_agent_dir(self): + """Get the path to the programmatic spans agent directory.""" + return Path(__file__).parent.parent / "example_agents" / "programmatic_spans_agent" + + @pytest.fixture + def setup_agent_in_temp(self, programmatic_spans_agent_dir): + """ + Setup a temporary directory with the programmatic spans agent and proper configuration. + Yields the temp directory path and cleans up after. + """ + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Copy the agent main.py + shutil.copy(programmatic_spans_agent_dir / "main.py", temp_path / "main.py") + + # Create .gradient directory and config + gradient_dir = temp_path / ".gradient" + gradient_dir.mkdir() + + config = { + "agent_name": "test-programmatic-spans-agent", + "agent_environment": "main", + "entrypoint_file": "main.py", + } + + with open(gradient_dir / "agent.yml", "w") as f: + yaml.safe_dump(config, f) + + yield temp_path + + @pytest.mark.cli + def test_programmatic_spans_agent_runs_successfully(self, setup_agent_in_temp): + """ + Test that an agent using add_llm_span, add_tool_span, and add_agent_span + can start and respond without errors. + + Verifies: + - Server starts successfully + - Health endpoint responds + - /run endpoint works with programmatic span functions + - Server can be cleanly terminated + """ + logger = logging.getLogger(__name__) + temp_dir = setup_agent_in_temp + port = find_free_port() + process = None + + try: + logger.info(f"Starting programmatic spans agent on port {port} in {temp_dir}") + + # Start the agent server + process = subprocess.Popen( + [ + "gradient", + "agent", + "run", + "--port", + str(port), + "--no-dev", + ], + cwd=temp_dir, + start_new_session=True, + ) + + # Wait for server to be ready + server_ready = wait_for_server(port, timeout=30) + assert server_ready, "Server did not start within timeout" + + # Test health endpoint + health_response = requests.get(f"http://localhost:{port}/health", timeout=5) + assert health_response.status_code == 200 + health_data = health_response.json() + assert health_data["status"] == "healthy" + logger.info(f"Health check passed: {health_data}") + + # Test /run endpoint - this will exercise all three programmatic span functions + run_response = requests.post( + f"http://localhost:{port}/run", + json={"prompt": "Test prompt for programmatic spans"}, + timeout=10, + ) + assert run_response.status_code == 200 + run_data = run_response.json() + + # Verify the response + assert run_data["success"] is True + assert run_data["message"] == "All programmatic spans created successfully" + assert run_data["prompt_received"] == "Test prompt for programmatic spans" + logger.info(f"Run endpoint test passed: {run_data}") + + finally: + cleanup_process(process) + + @pytest.mark.cli + def test_programmatic_spans_with_empty_input(self, setup_agent_in_temp): + """ + Test programmatic spans with empty input. + """ + logger = logging.getLogger(__name__) + temp_dir = setup_agent_in_temp + port = find_free_port() + process = None + + try: + logger.info(f"Starting agent on port {port}") + + process = subprocess.Popen( + [ + "gradient", + "agent", + "run", + "--port", + str(port), + "--no-dev", + ], + cwd=temp_dir, + start_new_session=True, + ) + + server_ready = wait_for_server(port, timeout=30) + assert server_ready, "Server did not start within timeout" + + # Test with empty object + run_response = requests.post( + f"http://localhost:{port}/run", + json={}, + timeout=10, + ) + assert run_response.status_code == 200 + run_data = run_response.json() + + assert run_data["success"] is True + assert run_data["prompt_received"] == "no prompt provided" + logger.info(f"Empty input test passed: {run_data}") + + finally: + cleanup_process(process) + + @pytest.mark.cli + def test_programmatic_spans_multiple_requests(self, setup_agent_in_temp): + """ + Test that multiple requests with programmatic spans work correctly. + """ + logger = logging.getLogger(__name__) + temp_dir = setup_agent_in_temp + port = find_free_port() + process = None + + try: + logger.info(f"Starting agent on port {port}") + + process = subprocess.Popen( + [ + "gradient", + "agent", + "run", + "--port", + str(port), + "--no-dev", + ], + cwd=temp_dir, + start_new_session=True, + ) + + server_ready = wait_for_server(port, timeout=30) + assert server_ready, "Server did not start within timeout" + + # Make multiple requests to ensure spans work repeatedly + for i in range(3): + run_response = requests.post( + f"http://localhost:{port}/run", + json={"prompt": f"Request {i + 1}"}, + timeout=10, + ) + assert run_response.status_code == 200 + run_data = run_response.json() + + assert run_data["success"] is True + assert run_data["prompt_received"] == f"Request {i + 1}" + logger.info(f"Request {i + 1} passed: {run_data}") + + logger.info("Multiple requests test passed") + + finally: + cleanup_process(process) \ No newline at end of file diff --git a/tests/tracing_test.py b/tests/tracing_test.py new file mode 100644 index 0000000..c86c9b4 --- /dev/null +++ b/tests/tracing_test.py @@ -0,0 +1,322 @@ +"""Tests for the tracing module's programmatic span functions.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from gradient_adk.tracing import add_llm_span, add_tool_span, add_agent_span + + +class TestAddLlmSpan: + """Tests for add_llm_span function.""" + + def test_no_tracker_does_not_raise(self): + """Should not raise when no tracker is available.""" + with patch("gradient_adk.tracing.get_tracker", return_value=None): + add_llm_span( + name="test_llm", + input={"prompt": "Hello"}, + output={"response": "Hi"}, + ) + + def test_basic_span_creation(self): + """Should create and submit LLM span with basic fields.""" + mock_tracker = MagicMock() + + with patch("gradient_adk.tracing.get_tracker", return_value=mock_tracker): + add_llm_span( + name="test_llm", + input={"prompt": "Hello"}, + output={"response": "Hi"}, + ) + + assert mock_tracker.on_node_start.call_count == 1 + span = mock_tracker.on_node_start.call_args[0][0] + assert span.node_name == "test_llm" + assert span.inputs == {"prompt": "Hello"} + assert span.metadata["is_llm_call"] is True + + assert mock_tracker.on_node_end.call_count == 1 + + def test_with_model(self): + """Should include model name in metadata.""" + mock_tracker = MagicMock() + + with patch("gradient_adk.tracing.get_tracker", return_value=mock_tracker): + add_llm_span( + name="test_llm", + input={"prompt": "Hello"}, + output={"response": "Hi"}, + model="gpt-4", + ) + + span = mock_tracker.on_node_start.call_args[0][0] + assert span.metadata["model_name"] == "gpt-4" + + def test_with_tokens(self): + """Should include token counts in metadata.""" + mock_tracker = MagicMock() + + with patch("gradient_adk.tracing.get_tracker", return_value=mock_tracker): + add_llm_span( + name="test_llm", + input={"prompt": "Hello"}, + output={"response": "Hi"}, + num_input_tokens=10, + num_output_tokens=5, + total_tokens=15, + ) + + span = mock_tracker.on_node_start.call_args[0][0] + usage = span.metadata["llm_response_payload"]["usage"] + assert usage["prompt_tokens"] == 10 + assert usage["completion_tokens"] == 5 + assert usage["total_tokens"] == 15 + + def test_with_temperature(self): + """Should include temperature in request payload.""" + mock_tracker = MagicMock() + + with patch("gradient_adk.tracing.get_tracker", return_value=mock_tracker): + add_llm_span( + name="test_llm", + input={"prompt": "Hello"}, + output={"response": "Hi"}, + temperature=0.7, + ) + + span = mock_tracker.on_node_start.call_args[0][0] + assert span.metadata["llm_request_payload"]["temperature"] == 0.7 + + def test_with_all_optional_fields(self): + """Should handle all optional fields.""" + mock_tracker = MagicMock() + + with patch("gradient_adk.tracing.get_tracker", return_value=mock_tracker): + add_llm_span( + name="test_llm", + input={"prompt": "Hello"}, + output={"response": "Hi"}, + model="gpt-4", + tools=[{"type": "function"}], + num_input_tokens=10, + num_output_tokens=5, + total_tokens=15, + temperature=0.7, + time_to_first_token_ns=100000000, + duration_ns=500000000, + metadata={"custom": "data"}, + tags=["production", "test"], + status_code=200, + ) + + span = mock_tracker.on_node_start.call_args[0][0] + meta = span.metadata + + assert meta["is_llm_call"] is True + assert meta["model_name"] == "gpt-4" + assert meta["time_to_first_token_ns"] == 100000000 + assert meta["duration_ns"] == 500000000 + assert meta["custom_metadata"] == {"custom": "data"} + assert meta["tags"] == ["production", "test"] + assert meta["status_code"] == 200 + + +class TestAddToolSpan: + """Tests for add_tool_span function.""" + + def test_no_tracker_does_not_raise(self): + """Should not raise when no tracker is available.""" + with patch("gradient_adk.tracing.get_tracker", return_value=None): + add_tool_span( + name="calculator", + input={"x": 5, "y": 3}, + output={"result": 8}, + ) + + def test_basic_span_creation(self): + """Should create and submit tool span with basic fields.""" + mock_tracker = MagicMock() + + with patch("gradient_adk.tracing.get_tracker", return_value=mock_tracker): + add_tool_span( + name="calculator", + input={"x": 5, "y": 3}, + output={"result": 8}, + ) + + assert mock_tracker.on_node_start.call_count == 1 + span = mock_tracker.on_node_start.call_args[0][0] + assert span.node_name == "calculator" + assert span.inputs == {"x": 5, "y": 3} + assert span.metadata["is_tool_call"] is True + + assert mock_tracker.on_node_end.call_count == 1 + + def test_with_tool_call_id(self): + """Should include tool_call_id in metadata.""" + mock_tracker = MagicMock() + + with patch("gradient_adk.tracing.get_tracker", return_value=mock_tracker): + add_tool_span( + name="calculator", + input={"x": 5, "y": 3}, + output={"result": 8}, + tool_call_id="call_abc123", + ) + + span = mock_tracker.on_node_start.call_args[0][0] + assert span.metadata["tool_call_id"] == "call_abc123" + + def test_with_all_optional_fields(self): + """Should handle all optional fields.""" + mock_tracker = MagicMock() + + with patch("gradient_adk.tracing.get_tracker", return_value=mock_tracker): + add_tool_span( + name="calculator", + input={"x": 5, "y": 3}, + output={"result": 8}, + tool_call_id="call_abc123", + duration_ns=1000000, + metadata={"function": "add"}, + tags=["math"], + status_code=200, + ) + + span = mock_tracker.on_node_start.call_args[0][0] + meta = span.metadata + + assert meta["is_tool_call"] is True + assert meta["tool_call_id"] == "call_abc123" + assert meta["duration_ns"] == 1000000 + assert meta["custom_metadata"] == {"function": "add"} + assert meta["tags"] == ["math"] + assert meta["status_code"] == 200 + + +class TestAddAgentSpan: + """Tests for add_agent_span function.""" + + def test_no_tracker_does_not_raise(self): + """Should not raise when no tracker is available.""" + with patch("gradient_adk.tracing.get_tracker", return_value=None): + add_agent_span( + name="research_agent", + input={"query": "What is AI?"}, + output={"answer": "AI is..."}, + ) + + def test_basic_span_creation(self): + """Should create and submit agent span with basic fields.""" + mock_tracker = MagicMock() + + with patch("gradient_adk.tracing.get_tracker", return_value=mock_tracker): + add_agent_span( + name="research_agent", + input={"query": "What is AI?"}, + output={"answer": "AI is..."}, + ) + + assert mock_tracker.on_node_start.call_count == 1 + span = mock_tracker.on_node_start.call_args[0][0] + assert span.node_name == "research_agent" + assert span.inputs == {"query": "What is AI?"} + assert span.metadata["is_agent_call"] is True + + assert mock_tracker.on_node_end.call_count == 1 + + def test_with_all_optional_fields(self): + """Should handle all optional fields.""" + mock_tracker = MagicMock() + + with patch("gradient_adk.tracing.get_tracker", return_value=mock_tracker): + add_agent_span( + name="research_agent", + input={"query": "What is AI?"}, + output={"answer": "AI is..."}, + duration_ns=5000000000, + metadata={"model": "gpt-4"}, + tags=["research"], + status_code=200, + ) + + span = mock_tracker.on_node_start.call_args[0][0] + meta = span.metadata + + assert meta["is_agent_call"] is True + assert meta["duration_ns"] == 5000000000 + assert meta["custom_metadata"] == {"model": "gpt-4"} + assert meta["tags"] == ["research"] + assert meta["status_code"] == 200 + + +class TestInputOutputSerialization: + """Tests for input/output serialization handling.""" + + def test_complex_input_is_frozen(self): + """Should properly serialize complex input objects.""" + mock_tracker = MagicMock() + + class CustomObject: + def __repr__(self): + return "CustomObject()" + + with patch("gradient_adk.tracing.get_tracker", return_value=mock_tracker): + add_llm_span( + name="test", + input=CustomObject(), + output="response", + ) + + span = mock_tracker.on_node_start.call_args[0][0] + assert span.inputs == "CustomObject()" + + def test_nested_dict_input(self): + """Should handle nested dictionaries.""" + mock_tracker = MagicMock() + + with patch("gradient_adk.tracing.get_tracker", return_value=mock_tracker): + add_llm_span( + name="test", + input={ + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + ], + "config": {"temperature": 0.7}, + }, + output={"response": "result"}, + ) + + span = mock_tracker.on_node_start.call_args[0][0] + assert span.inputs["messages"][0]["role"] == "user" + assert span.inputs["config"]["temperature"] == 0.7 + + def test_list_input(self): + """Should handle list inputs.""" + mock_tracker = MagicMock() + + with patch("gradient_adk.tracing.get_tracker", return_value=mock_tracker): + add_tool_span( + name="batch_process", + input=[1, 2, 3, 4, 5], + output=[2, 4, 6, 8, 10], + ) + + span = mock_tracker.on_node_start.call_args[0][0] + assert span.inputs == [1, 2, 3, 4, 5] + + def test_none_input_output(self): + """Should handle None inputs and outputs.""" + mock_tracker = MagicMock() + + with patch("gradient_adk.tracing.get_tracker", return_value=mock_tracker): + add_agent_span( + name="test", + input=None, + output=None, + ) + + span = mock_tracker.on_node_start.call_args[0][0] + assert span.inputs is None \ No newline at end of file