From 7c7164d0be29958cd378619f50cf17cbf27c0343 Mon Sep 17 00:00:00 2001 From: Tyler Gillam Date: Wed, 31 Dec 2025 08:38:39 -0600 Subject: [PATCH 1/3] Add functions for adding spans --- gradient_adk/__init__.py | 12 +- gradient_adk/tracing.py | 188 +++++++++- tests/tracing_test.py | 733 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 927 insertions(+), 6 deletions(-) create mode 100644 tests/tracing_test.py diff --git a/gradient_adk/__init__.py b/gradient_adk/__init__.py index 82989dd..adbc342 100644 --- a/gradient_adk/__init__.py +++ b/gradient_adk/__init__.py @@ -4,17 +4,25 @@ """ from .decorator import entrypoint -from .tracing import ( # manual tracing decorators +from .tracing import ( # manual tracing decorators and functions trace_llm, trace_retriever, trace_tool, + add_llm_span, + add_retriever_span, + add_tool_span, ) __all__ = [ "entrypoint", + # Decorators "trace_llm", "trace_retriever", "trace_tool", + # Functions + "add_llm_span", + "add_retriever_span", + "add_tool_span", ] -__version__ = "0.0.5" +__version__ = "0.0.5" \ No newline at end of file diff --git a/gradient_adk/tracing.py b/gradient_adk/tracing.py index 6d70b49..834fcce 100644 --- a/gradient_adk/tracing.py +++ b/gradient_adk/tracing.py @@ -1,9 +1,10 @@ -"""Tracing decorators for manual span tracking. +"""Tracing decorators and functions for manual span tracking. -These decorators allow developers to instrument their custom agent functions -with the same kind of tracing automatically provided for some other frameworks. +These decorators and functions allow developers to instrument their custom agent +functions with the same kind of tracing automatically provided for some other +frameworks. -Example usage: +Decorator-based usage: from gradient_adk import entrypoint, trace_llm, trace_tool, trace_retriever @trace_retriever("fetch_data") @@ -27,6 +28,28 @@ async def my_agent(input: dict, context: dict): result = await calculate(5, 10) response = await call_model(data["prompt"]) return {"response": response} + +Function-based usage (for manual span creation): + from gradient_adk import entrypoint, add_llm_span, add_tool_span, add_retriever_span + + @entrypoint + async def my_agent(input: dict, context: dict): + # Do work and then record spans manually + query = input["query"] + + # Record a retriever span + results = await fetch_documents(query) + add_retriever_span("fetch_docs", inputs={"query": query}, output=results) + + # Record a tool span + calculation = calculate(5, 10) + add_tool_span("calculate", inputs={"x": 5, "y": 10}, output=calculation) + + # Record an LLM span + response = await call_llm(results) + add_llm_span("generate_response", inputs={"context": results}, output=response) + + return {"response": response} """ from __future__ import annotations @@ -443,3 +466,160 @@ async def search(query: str) -> list: return results """ return _trace_base(name, span_type=SpanType.TOOL) + + +def _add_span( + name: str, + inputs: Any, + output: Any, + span_type: SpanType, + extra_metadata: Optional[Dict[str, Any]] = None, +) -> None: + """ + Internal helper to add a completed span to the tracker. + + Args: + name: Name for the span. + inputs: The inputs to record for this span. + output: The output to record for this span. + span_type: Type of span (LLM, TOOL, or RETRIEVER). + extra_metadata: Additional metadata fields to attach to the span. 
+ """ + tracker = get_tracker() + if not tracker: + # No tracker available, silently skip + return + + # Snapshot inputs and output + inputs_snapshot = _freeze(inputs) + output_snapshot = _freeze(output) + + # Create span + span = _create_span(name, inputs_snapshot) + + # Mark span type + meta = _ensure_meta(span) + if span_type == SpanType.LLM: + meta["is_llm_call"] = True + elif span_type == SpanType.TOOL: + meta["is_tool_call"] = True + elif span_type == SpanType.RETRIEVER: + meta["is_retriever_call"] = True + + # Add any extra metadata + if extra_metadata: + for key, value in extra_metadata.items(): + meta[key] = value + + # Record start and end + tracker.on_node_start(span) + tracker.on_node_end(span, output_snapshot) + + +def add_llm_span( + name: str, + inputs: Any, + output: Any, + *, + model_name: Optional[str] = None, + ttft_ms: Optional[float] = None, + **extra_metadata: Any, +) -> None: + """ + Add an LLM call span with the given name, inputs, and output. + + Use this function to manually record an LLM call span. + + Args: + name: Name for the span (e.g., "openai_call", "generate_response"). + inputs: The inputs to the LLM call (e.g., prompt, messages). + output: The output from the LLM call (e.g., response text). + model_name: Optional name of the LLM model used. + ttft_ms: Optional time to first token in milliseconds. + **extra_metadata: Any additional metadata fields to attach to the span. + + Example: + response = await call_llm(prompt) + add_llm_span( + "generate_response", + inputs={"prompt": prompt}, + output=response, + model_name="gpt-4", + ttft_ms=150.5, + temperature=0.7, + ) + """ + metadata: Dict[str, Any] = {} + if model_name is not None: + metadata["model_name"] = model_name + if ttft_ms is not None: + metadata["ttft_ms"] = ttft_ms + metadata.update(extra_metadata) + + _add_span(name, inputs, output, SpanType.LLM, metadata if metadata else None) + + +def add_retriever_span( + name: str, + inputs: Any, + output: Any, + **extra_metadata: Any, +) -> None: + """ + Add a retriever call span with the given name, inputs, and output. + + Use this function to manually record a retriever span without using a decorator. + + Args: + name: Name for the span (e.g., "vector_search", "fetch_docs"). + inputs: The inputs to the retriever (e.g., query, filters). + output: The output from the retriever (e.g., list of documents). + **extra_metadata: Any additional metadata fields to attach to the span. + + Example: + results = await fetch_documents(query) + add_retriever_span( + "fetch_docs", + inputs={"query": query}, + output=results, + num_results=len(results), + ) + """ + _add_span( + name, + inputs, + output, + SpanType.RETRIEVER, + extra_metadata if extra_metadata else None, + ) + + +def add_tool_span( + name: str, + inputs: Any, + output: Any, + **extra_metadata: Any, +) -> None: + """ + Add a tool call span with the given name, inputs, and output. + + Use this function to manually record a tool call span without using a decorator. + + Args: + name: Name for the span (e.g., "calculate", "search_database"). + inputs: The inputs to the tool (e.g., parameters). + output: The output from the tool (e.g., result). + **extra_metadata: Any additional metadata fields to attach to the span. 
+
+    Example:
+        result = calculate(5, 10)
+        add_tool_span(
+            "calculate",
+            inputs={"x": 5, "y": 10},
+            output=result,
+            execution_time_ms=12.5,
+        )
+    """
+    _add_span(
+        name, inputs, output, SpanType.TOOL, extra_metadata if extra_metadata else None
+    )
diff --git a/tests/tracing_test.py b/tests/tracing_test.py
new file mode 100644
index 0000000..8376224
--- /dev/null
+++ b/tests/tracing_test.py
@@ -0,0 +1,733 @@
+"""Tests for gradient_adk.tracing module."""
+
+import asyncio
+import dataclasses
+import pytest
+from unittest.mock import MagicMock, patch
+from datetime import datetime, timezone
+
+from gradient_adk.tracing import (
+    SpanType,
+    _freeze,
+    _snapshot_args_kwargs,
+    _snapshot_output,
+    _ensure_meta,
+    _create_span,
+    _add_span,
+    trace_llm,
+    trace_retriever,
+    trace_tool,
+    add_llm_span,
+    add_retriever_span,
+    add_tool_span,
+)
+from gradient_adk.runtime.interfaces import NodeExecution
+
+
+# ---------------------------
+# Test Doubles
+# ---------------------------
+
+
+class TrackerDouble:
+    """Mock tracker for testing span operations."""
+
+    def __init__(self):
+        self.node_starts = []
+        self.node_ends = []
+        self.node_errors = []
+
+    def on_node_start(self, node: NodeExecution):
+        self.node_starts.append(node)
+
+    def on_node_end(self, node: NodeExecution, outputs):
+        self.node_ends.append((node, outputs))
+
+    def on_node_error(self, node: NodeExecution, error: BaseException):
+        self.node_errors.append((node, error))
+
+
+# ---------------------------
+# Helper Function Tests
+# ---------------------------
+
+
+class TestFreeze:
+    """Tests for the _freeze helper function."""
+
+    def test_freeze_primitives(self):
+        """Test that primitives are returned as-is."""
+        assert _freeze(None) is None
+        assert _freeze("hello") == "hello"
+        assert _freeze(42) == 42
+        assert _freeze(3.14) == 3.14
+        assert _freeze(True) is True
+        assert _freeze(False) is False
+
+    def test_freeze_dict(self):
+        """Test dict serialization."""
+        result = _freeze({"a": 1, "b": "hello"})
+        assert result == {"a": 1, "b": "hello"}
+
+    def test_freeze_nested_dict(self):
+        """Test nested dict serialization."""
+        result = _freeze({"outer": {"inner": {"deep": 123}}})
+        assert result == {"outer": {"inner": {"deep": 123}}}
+
+    def test_freeze_list(self):
+        """Test list serialization."""
+        result = _freeze([1, 2, 3])
+        assert result == [1, 2, 3]
+
+    def test_freeze_tuple(self):
+        """Test tuple serialization (converts to list)."""
+        result = _freeze((1, 2, 3))
+        assert result == [1, 2, 3]
+
+    def test_freeze_set(self):
+        """Test set serialization (converts to list)."""
+        result = _freeze({1, 2, 3})
+        assert isinstance(result, list)
+        assert set(result) == {1, 2, 3}
+
+    def test_freeze_max_depth(self):
+        """Test that max depth is respected."""
+        deep = {"a": {"b": {"c": {"d": {"e": "deep"}}}}}
+        result = _freeze(deep, max_depth=2)
+        # max_depth=2 means: depth 2 (a), depth 1 (b), depth 0 (c -> max-depth)
+        assert result == {"a": {"b": {"c": "<max-depth>"}}}
+
+    def test_freeze_max_items(self):
+        """Test that max items is respected for dicts."""
+        large_dict = {f"key{i}": i for i in range(10)}
+        result = _freeze(large_dict, max_items=3)
+        # Should have 3 items + truncated marker
+        assert len(result) == 4
+        assert result.get("<truncated>") is True
+
+    def test_freeze_max_items_list(self):
+        """Test that max items is respected for lists."""
+        large_list = list(range(10))
+        result = _freeze(large_list, max_items=3)
+        assert len(result) == 4
+        assert result[-1] == "<truncated>"
+
+    def test_freeze_fallback_repr(self):
+        """Test that unknown types fall back
to repr().""" + + class CustomClass: + def __repr__(self): + return "CustomClass()" + + result = _freeze(CustomClass()) + assert result == "CustomClass()" + + def test_freeze_dataclass(self): + """Test dataclass serialization.""" + + @dataclasses.dataclass + class Person: + name: str + age: int + + result = _freeze(Person(name="Alice", age=30)) + assert result == {"name": "Alice", "age": 30} + + def test_freeze_pydantic_model(self): + """Test pydantic model serialization.""" + try: + from pydantic import BaseModel + + class Item(BaseModel): + name: str + value: int + + result = _freeze(Item(name="test", value=42)) + assert result == {"name": "test", "value": 42} + except ImportError: + pytest.skip("pydantic not installed") + + +class TestSnapshotArgsKwargs: + """Tests for the _snapshot_args_kwargs helper function.""" + + def test_single_arg_no_kwargs(self): + """Test with single arg and no kwargs - returns just that arg.""" + result = _snapshot_args_kwargs(("hello",), {}) + assert result == "hello" + + def test_kwargs_only(self): + """Test with no args and kwargs only - returns just the kwargs.""" + result = _snapshot_args_kwargs((), {"a": 1, "b": 2}) + assert result == {"a": 1, "b": 2} + + def test_multiple_args(self): + """Test with multiple args - returns list.""" + result = _snapshot_args_kwargs(("a", "b", "c"), {}) + assert result == ["a", "b", "c"] + + def test_args_and_kwargs(self): + """Test with both args and kwargs - returns dict with both.""" + result = _snapshot_args_kwargs(("a",), {"x": 1}) + assert result == {"args": ["a"], "kwargs": {"x": 1}} + + def test_empty_args_and_kwargs(self): + """Test with empty args and kwargs.""" + result = _snapshot_args_kwargs((), {}) + assert result == [] + + +class TestSnapshotOutput: + """Tests for the _snapshot_output helper function.""" + + def test_snapshot_output_primitive(self): + """Test output snapshotting with primitives.""" + assert _snapshot_output("result") == "result" + assert _snapshot_output(42) == 42 + + def test_snapshot_output_dict(self): + """Test output snapshotting with dict.""" + result = _snapshot_output({"key": "value"}) + assert result == {"key": "value"} + + def test_snapshot_output_list(self): + """Test output snapshotting with list.""" + result = _snapshot_output([1, 2, 3]) + assert result == [1, 2, 3] + + +class TestEnsureMeta: + """Tests for the _ensure_meta helper function.""" + + def test_ensure_meta_creates_dict(self): + """Test that _ensure_meta creates a metadata dict if none exists.""" + node = NodeExecution( + node_id="123", + node_name="test", + framework="custom", + start_time=datetime.now(timezone.utc), + ) + assert node.metadata is None + + meta = _ensure_meta(node) + assert isinstance(meta, dict) + assert node.metadata == meta + + def test_ensure_meta_returns_existing(self): + """Test that _ensure_meta returns existing metadata.""" + node = NodeExecution( + node_id="123", + node_name="test", + framework="custom", + start_time=datetime.now(timezone.utc), + metadata={"existing": "data"}, + ) + + meta = _ensure_meta(node) + assert meta == {"existing": "data"} + + +class TestCreateSpan: + """Tests for the _create_span helper function.""" + + def test_create_span_basic(self): + """Test basic span creation.""" + span = _create_span("my_span", {"input": "data"}) + + assert span.node_name == "my_span" + assert span.framework == "custom" + assert span.inputs == {"input": "data"} + assert span.start_time is not None + assert span.node_id is not None + + def test_create_span_uuid_uniqueness(self): + """Test that 
each span gets a unique ID.""" + span1 = _create_span("span1", {}) + span2 = _create_span("span2", {}) + + assert span1.node_id != span2.node_id + + +# --------------------------- +# Decorator Tests +# --------------------------- + + +class TestTraceLlmDecorator: + """Tests for the @trace_llm decorator.""" + + def test_trace_llm_async_function(self): + """Test @trace_llm with async function.""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", return_value=tracker): + with patch("gradient_adk.tracing.get_network_interceptor") as mock_interceptor: + mock_interceptor.return_value.snapshot_token.return_value = 0 + mock_interceptor.return_value.hits_since.return_value = 0 + + @trace_llm("my_llm_call") + async def call_llm(prompt: str) -> str: + return f"Response to: {prompt}" + + result = asyncio.run(call_llm("Hello")) + + assert result == "Response to: Hello" + assert len(tracker.node_starts) == 1 + assert len(tracker.node_ends) == 1 + + started_node = tracker.node_starts[0] + assert started_node.node_name == "my_llm_call" + assert started_node.metadata.get("is_llm_call") is True + + def test_trace_llm_sync_function(self): + """Test @trace_llm with sync function.""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", return_value=tracker): + with patch("gradient_adk.tracing.get_network_interceptor") as mock_interceptor: + mock_interceptor.return_value.snapshot_token.return_value = 0 + mock_interceptor.return_value.hits_since.return_value = 0 + + @trace_llm("sync_llm") + def call_llm(prompt: str) -> str: + return f"Response: {prompt}" + + result = call_llm("Test") + + assert result == "Response: Test" + assert len(tracker.node_starts) == 1 + assert tracker.node_starts[0].metadata.get("is_llm_call") is True + + def test_trace_llm_uses_function_name_when_no_name_provided(self): + """Test that function name is used when no custom name is provided.""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", return_value=tracker): + with patch("gradient_adk.tracing.get_network_interceptor") as mock_interceptor: + mock_interceptor.return_value.snapshot_token.return_value = 0 + mock_interceptor.return_value.hits_since.return_value = 0 + + @trace_llm() + async def my_custom_llm_function(prompt: str) -> str: + return "response" + + asyncio.run(my_custom_llm_function("test")) + + assert tracker.node_starts[0].node_name == "my_custom_llm_function" + + def test_trace_llm_async_generator(self): + """Test @trace_llm with async generator (streaming).""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", return_value=tracker): + with patch("gradient_adk.tracing.get_network_interceptor") as mock_interceptor: + mock_interceptor.return_value.snapshot_token.return_value = 0 + mock_interceptor.return_value.hits_since.return_value = 0 + + @trace_llm("streaming_llm") + async def stream_llm(prompt: str): + yield "Hello" + yield " " + yield "World" + + async def consume(): + chunks = [] + async for chunk in stream_llm("test"): + chunks.append(chunk) + return chunks + + result = asyncio.run(consume()) + + assert result == ["Hello", " ", "World"] + assert len(tracker.node_starts) == 1 + assert len(tracker.node_ends) == 1 + + # Check collected content + ended_node, outputs = tracker.node_ends[0] + assert outputs == {"content": "Hello World"} + + def test_trace_llm_error_handling(self): + """Test that errors are tracked correctly.""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", 
return_value=tracker): + with patch("gradient_adk.tracing.get_network_interceptor") as mock_interceptor: + mock_interceptor.return_value.snapshot_token.return_value = 0 + + @trace_llm("error_llm") + async def failing_llm(prompt: str) -> str: + raise ValueError("LLM error") + + with pytest.raises(ValueError, match="LLM error"): + asyncio.run(failing_llm("test")) + + assert len(tracker.node_starts) == 1 + assert len(tracker.node_errors) == 1 + assert isinstance(tracker.node_errors[0][1], ValueError) + + def test_trace_llm_no_tracker(self): + """Test that function works when no tracker is available.""" + with patch("gradient_adk.tracing.get_tracker", return_value=None): + + @trace_llm("no_tracker") + async def call_llm(prompt: str) -> str: + return "response" + + result = asyncio.run(call_llm("test")) + + assert result == "response" + + +class TestTraceRetrieverDecorator: + """Tests for the @trace_retriever decorator.""" + + def test_trace_retriever_async_function(self): + """Test @trace_retriever with async function.""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", return_value=tracker): + with patch("gradient_adk.tracing.get_network_interceptor") as mock_interceptor: + mock_interceptor.return_value.snapshot_token.return_value = 0 + mock_interceptor.return_value.hits_since.return_value = 0 + + @trace_retriever("vector_search") + async def search(query: str) -> list: + return [{"id": 1, "text": "result"}] + + result = asyncio.run(search("test query")) + + assert result == [{"id": 1, "text": "result"}] + assert len(tracker.node_starts) == 1 + assert tracker.node_starts[0].metadata.get("is_retriever_call") is True + + def test_trace_retriever_sync_function(self): + """Test @trace_retriever with sync function.""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", return_value=tracker): + with patch("gradient_adk.tracing.get_network_interceptor") as mock_interceptor: + mock_interceptor.return_value.snapshot_token.return_value = 0 + mock_interceptor.return_value.hits_since.return_value = 0 + + @trace_retriever("db_search") + def search(query: str) -> list: + return [{"id": 1}] + + result = search("test") + + assert result == [{"id": 1}] + assert tracker.node_starts[0].metadata.get("is_retriever_call") is True + + +class TestTraceToolDecorator: + """Tests for the @trace_tool decorator.""" + + def test_trace_tool_async_function(self): + """Test @trace_tool with async function.""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", return_value=tracker): + with patch("gradient_adk.tracing.get_network_interceptor") as mock_interceptor: + mock_interceptor.return_value.snapshot_token.return_value = 0 + mock_interceptor.return_value.hits_since.return_value = 0 + + @trace_tool("calculator") + async def add(x: int, y: int) -> int: + return x + y + + result = asyncio.run(add(5, 3)) + + assert result == 8 + assert len(tracker.node_starts) == 1 + assert tracker.node_starts[0].metadata.get("is_tool_call") is True + + def test_trace_tool_sync_function(self): + """Test @trace_tool with sync function.""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", return_value=tracker): + with patch("gradient_adk.tracing.get_network_interceptor") as mock_interceptor: + mock_interceptor.return_value.snapshot_token.return_value = 0 + mock_interceptor.return_value.hits_since.return_value = 0 + + @trace_tool("multiply") + def multiply(x: int, y: int) -> int: + return x * y + + result = multiply(4, 5) + + assert 
result == 20 + assert tracker.node_starts[0].metadata.get("is_tool_call") is True + + +class TestDecoratorInputOutputCapture: + """Tests for input/output capture in decorators.""" + + def test_captures_inputs_correctly(self): + """Test that function inputs are captured correctly.""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", return_value=tracker): + with patch("gradient_adk.tracing.get_network_interceptor") as mock_interceptor: + mock_interceptor.return_value.snapshot_token.return_value = 0 + mock_interceptor.return_value.hits_since.return_value = 0 + + @trace_tool("test_tool") + def process(data: dict) -> dict: + return {"processed": True} + + process({"key": "value"}) + + started_node = tracker.node_starts[0] + assert started_node.inputs == {"key": "value"} + + def test_captures_outputs_correctly(self): + """Test that function outputs are captured correctly.""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", return_value=tracker): + with patch("gradient_adk.tracing.get_network_interceptor") as mock_interceptor: + mock_interceptor.return_value.snapshot_token.return_value = 0 + mock_interceptor.return_value.hits_since.return_value = 0 + + @trace_tool("test_tool") + def process(data: dict) -> dict: + return {"result": 42} + + process({"input": "data"}) + + ended_node, outputs = tracker.node_ends[0] + assert outputs == {"result": 42} + + +# --------------------------- +# Function-based Span API Tests +# --------------------------- + + +class TestAddLlmSpan: + """Tests for the add_llm_span function.""" + + def test_add_llm_span_basic(self): + """Test basic add_llm_span usage.""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", return_value=tracker): + add_llm_span( + "my_llm", + inputs={"prompt": "Hello"}, + output="World", + ) + + assert len(tracker.node_starts) == 1 + assert len(tracker.node_ends) == 1 + + node = tracker.node_starts[0] + assert node.node_name == "my_llm" + assert node.inputs == {"prompt": "Hello"} + assert node.metadata.get("is_llm_call") is True + + ended_node, outputs = tracker.node_ends[0] + assert outputs == "World" + + def test_add_llm_span_with_model_name(self): + """Test add_llm_span with model_name parameter.""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", return_value=tracker): + add_llm_span( + "my_llm", + inputs={"prompt": "test"}, + output="response", + model_name="gpt-4", + ) + + node = tracker.node_starts[0] + assert node.metadata.get("model_name") == "gpt-4" + + def test_add_llm_span_with_ttft_ms(self): + """Test add_llm_span with ttft_ms parameter.""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", return_value=tracker): + add_llm_span( + "my_llm", + inputs={"prompt": "test"}, + output="response", + ttft_ms=150.5, + ) + + node = tracker.node_starts[0] + assert node.metadata.get("ttft_ms") == 150.5 + + def test_add_llm_span_with_extra_metadata(self): + """Test add_llm_span with extra metadata via kwargs.""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", return_value=tracker): + add_llm_span( + "my_llm", + inputs={"prompt": "test"}, + output="response", + model_name="gpt-4", + ttft_ms=100.0, + temperature=0.7, + max_tokens=1000, + ) + + node = tracker.node_starts[0] + assert node.metadata.get("model_name") == "gpt-4" + assert node.metadata.get("ttft_ms") == 100.0 + assert node.metadata.get("temperature") == 0.7 + assert node.metadata.get("max_tokens") == 1000 + + def 
test_add_llm_span_no_tracker(self): + """Test that add_llm_span silently skips when no tracker.""" + with patch("gradient_adk.tracing.get_tracker", return_value=None): + # Should not raise + add_llm_span("my_llm", inputs={}, output="test") + + +class TestAddRetrieverSpan: + """Tests for the add_retriever_span function.""" + + def test_add_retriever_span_basic(self): + """Test basic add_retriever_span usage.""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", return_value=tracker): + add_retriever_span( + "vector_search", + inputs={"query": "test"}, + output=[{"id": 1}, {"id": 2}], + ) + + assert len(tracker.node_starts) == 1 + node = tracker.node_starts[0] + assert node.node_name == "vector_search" + assert node.metadata.get("is_retriever_call") is True + + def test_add_retriever_span_with_extra_metadata(self): + """Test add_retriever_span with extra metadata.""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", return_value=tracker): + add_retriever_span( + "vector_search", + inputs={"query": "test"}, + output=[{"id": 1}], + num_results=1, + similarity_threshold=0.8, + ) + + node = tracker.node_starts[0] + assert node.metadata.get("num_results") == 1 + assert node.metadata.get("similarity_threshold") == 0.8 + + def test_add_retriever_span_no_tracker(self): + """Test that add_retriever_span silently skips when no tracker.""" + with patch("gradient_adk.tracing.get_tracker", return_value=None): + add_retriever_span("search", inputs={}, output=[]) + + +class TestAddToolSpan: + """Tests for the add_tool_span function.""" + + def test_add_tool_span_basic(self): + """Test basic add_tool_span usage.""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", return_value=tracker): + add_tool_span( + "calculator", + inputs={"x": 5, "y": 3}, + output=8, + ) + + assert len(tracker.node_starts) == 1 + node = tracker.node_starts[0] + assert node.node_name == "calculator" + assert node.metadata.get("is_tool_call") is True + + ended_node, outputs = tracker.node_ends[0] + assert outputs == 8 + + def test_add_tool_span_with_extra_metadata(self): + """Test add_tool_span with extra metadata.""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", return_value=tracker): + add_tool_span( + "calculator", + inputs={"x": 5, "y": 3}, + output=8, + execution_time_ms=12.5, + tool_version="1.0.0", + ) + + node = tracker.node_starts[0] + assert node.metadata.get("execution_time_ms") == 12.5 + assert node.metadata.get("tool_version") == "1.0.0" + + def test_add_tool_span_no_tracker(self): + """Test that add_tool_span silently skips when no tracker.""" + with patch("gradient_adk.tracing.get_tracker", return_value=None): + add_tool_span("tool", inputs={}, output=None) + + +class TestInternalAddSpan: + """Tests for the internal _add_span function.""" + + def test_add_span_llm_type(self): + """Test _add_span with LLM span type.""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", return_value=tracker): + _add_span("test", {"input": 1}, "output", SpanType.LLM) + + assert tracker.node_starts[0].metadata.get("is_llm_call") is True + + def test_add_span_tool_type(self): + """Test _add_span with TOOL span type.""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", return_value=tracker): + _add_span("test", {"input": 1}, "output", SpanType.TOOL) + + assert tracker.node_starts[0].metadata.get("is_tool_call") is True + + def test_add_span_retriever_type(self): + 
"""Test _add_span with RETRIEVER span type.""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", return_value=tracker): + _add_span("test", {"input": 1}, "output", SpanType.RETRIEVER) + + assert tracker.node_starts[0].metadata.get("is_retriever_call") is True + + def test_add_span_with_extra_metadata(self): + """Test _add_span with extra metadata.""" + tracker = TrackerDouble() + + with patch("gradient_adk.tracing.get_tracker", return_value=tracker): + _add_span( + "test", + {"input": 1}, + "output", + SpanType.LLM, + extra_metadata={"custom_key": "custom_value"}, + ) + + node = tracker.node_starts[0] + assert node.metadata.get("custom_key") == "custom_value" + assert node.metadata.get("is_llm_call") is True + + +class TestSpanType: + """Tests for the SpanType enum.""" + + def test_span_type_values(self): + """Test SpanType enum values.""" + assert SpanType.LLM.value == "llm" + assert SpanType.TOOL.value == "tool" + assert SpanType.RETRIEVER.value == "retriever" \ No newline at end of file From 07b8a1d9d184530c28007471cae38c82286d120f Mon Sep 17 00:00:00 2001 From: Tyler Gillam Date: Wed, 31 Dec 2025 08:39:02 -0600 Subject: [PATCH 2/3] Add functions --- gradient_adk/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gradient_adk/__init__.py b/gradient_adk/__init__.py index adbc342..0065f9a 100644 --- a/gradient_adk/__init__.py +++ b/gradient_adk/__init__.py @@ -25,4 +25,4 @@ "add_tool_span", ] -__version__ = "0.0.5" \ No newline at end of file +__version__ = "0.0.5" From e2ca3c291866d6d72b9b2dbaa5615a49128c131c Mon Sep 17 00:00:00 2001 From: Tyler Gillam Date: Wed, 31 Dec 2025 08:42:39 -0600 Subject: [PATCH 3/3] Add span functions to README --- README.md | 72 +++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 52 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 99dbff9..1e7ec09 100644 --- a/README.md +++ b/README.md @@ -102,9 +102,13 @@ async def main(input: dict, context: dict): return result["output"] ``` -### Using Custom Decorators (Any Framework) +### Manual Trace and Span Capture -For frameworks beyond LangGraph, use trace decorators to capture custom spans: +For frameworks beyond LangGraph, you can manually capture traces using either **decorators** (wrap functions) or **functions** (record spans after execution). + +#### Using Decorators + +Use `@trace_llm`, `@trace_tool`, and `@trace_retriever` decorators to automatically capture spans when functions are called: ```python from gradient_adk import entrypoint, trace_llm, trace_tool, trace_retriever @@ -117,7 +121,7 @@ async def search_knowledge_base(query: str): @trace_llm("generate_response") async def generate_response(prompt: str): - # LLM spans capture model calls with token usage + # LLM spans capture model calls response = await llm.generate(prompt) return response @@ -134,6 +138,36 @@ async def main(input: dict, context: dict): return response ``` +#### Using Functions + +Use `add_llm_span`, `add_tool_span`, and `add_retriever_span` functions to manually record spans after execution. 
This is useful when you can't wrap a function with a decorator or need more control over what gets recorded: + +```python +from gradient_adk import entrypoint, add_llm_span, add_tool_span, add_retriever_span + +@entrypoint +async def main(input: dict, context: dict): + # Perform retrieval and record the span + results = await vector_db.search(input["query"]) + add_retriever_span("vector_search", inputs={"query": input["query"]}, output=results) + + # Perform calculation and record the span + result = 5 + 10 + add_tool_span("calculate", inputs={"x": 5, "y": 10}, output=result) + + # Call LLM and record with additional metadata + response = await llm.generate(f"Context: {results}") + add_llm_span( + "generate_response", + inputs={"prompt": f"Context: {results}"}, + output=response, + model_name="gpt-4", # Optional: model name + ttft_ms=150.5, # Optional: time to first token + ) + + return response +``` + ### Streaming Responses The runtime supports streaming responses with automatic trace capture: @@ -204,7 +238,9 @@ The ADK runtime automatically captures detailed traces: - **Errors**: Full exception details and stack traces - **Streaming Responses**: Individual chunks and aggregated outputs -### Available Decorators +### Available Decorators and Functions + +**Decorators** - Wrap functions to automatically capture spans: ```python from gradient_adk import trace_llm, trace_tool, trace_retriever @@ -213,7 +249,17 @@ from gradient_adk import trace_llm, trace_tool, trace_retriever @trace_retriever("db_search") # For retrieval/search operations ``` -These decorators are used to log steps or spans of your agent workflow that are not automatically captured. These will log things like the input, output, and step duration and make them available in your agent's traces and for use in agent evaluations. +**Functions** - Manually record spans after execution: +```python +from gradient_adk import add_llm_span, add_tool_span, add_retriever_span + +# Record spans with name, inputs, and output +add_llm_span("model_call", inputs={"prompt": "..."}, output="response", model_name="gpt-4") +add_tool_span("calculator", inputs={"x": 5, "y": 10}, output=15) +add_retriever_span("db_search", inputs={"query": "..."}, output=[...]) +``` + +These are used to log steps or spans of your agent workflow that are not automatically captured. They will log things like the input, output, and step duration and make them available in your agent's traces and for use in agent evaluations. ### Viewing Traces Traces are: @@ -234,20 +280,6 @@ export GRADIENT_MODEL_ACCESS_KEY=your_gradient_key export GRADIENT_VERBOSE=1 ``` -## Project Structure - -``` -my-agent/ -├── main.py # Agent entrypoint with @entrypoint decorator -├── .gradient/agent.yml # Agent configuration (auto-generated) -├── requirements.txt # Python dependencies -├── .env # Environment variables (not committed) -├── agents/ # Agent implementations -│ └── my_agent.py -└── tools/ # Custom tools - └── my_tool.py -``` - ## Framework Compatibility The Gradient ADK is designed to work with any Python-based AI agent framework: @@ -267,4 +299,4 @@ The Gradient ADK is designed to work with any Python-based AI agent framework: ## License -Licensed under the Apache License 2.0. See [LICENSE](./LICENSE) +Licensed under the Apache License 2.0. See [LICENSE](./LICENSE) \ No newline at end of file
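
A note on span timing with the function-based API: `_add_span` calls `tracker.on_node_start` and `tracker.on_node_end` back-to-back, so spans recorded via `add_llm_span`/`add_tool_span`/`add_retriever_span` carry essentially zero duration of their own. If step duration matters, measure it around the actual work and attach it as metadata, as the `execution_time_ms` and `ttft_ms` fields in the docstring examples suggest. A minimal sketch of that pattern, assuming a hypothetical `run_expensive_tool` in place of your own tool call:

```python
import time

from gradient_adk import entrypoint, add_tool_span


@entrypoint
async def main(input: dict, context: dict):
    # add_tool_span records start and end back-to-back, so the span itself
    # has ~0 ms duration; time the real work explicitly instead.
    started = time.perf_counter()
    result = run_expensive_tool(input["query"])  # hypothetical tool call
    elapsed_ms = (time.perf_counter() - started) * 1000.0

    add_tool_span(
        "expensive_tool",
        inputs={"query": input["query"]},
        output=result,
        execution_time_ms=elapsed_ms,  # surfaced as extra span metadata
    )
    return {"result": result}
```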