Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
- name: Install package and test deps
run: |
pip install -e .
pip install pytest pytest-cov
pip install pytest pytest-cov pydantic-ai

- name: Run unit tests with coverage
run: |
Expand Down
7 changes: 5 additions & 2 deletions gradient_adk/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@

logger = get_logger(__name__)

from gradient_adk.runtime.langgraph.helpers import capture_graph, get_tracker
# Initialize framework instrumentation using the centralized registry
# This is idempotent and will only install instrumentation once
# Each instrumentor checks for its own environment variable to allow disabling
from gradient_adk.runtime.helpers import capture_all, get_tracker

capture_graph()
capture_all()


class _StreamingIteratorWithTracking:
Expand Down
12 changes: 11 additions & 1 deletion gradient_adk/digital_ocean_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from .models import (
TraceSpanType,
SpanCommon,
LLMSpanDetails,
ToolSpanDetails,
RetrieverSpanDetails,
WorkflowSpanDetails,
Span,
Trace,
CreateTracesInput,
Expand Down Expand Up @@ -38,6 +43,11 @@

__all__ = [
"TraceSpanType",
"SpanCommon",
"LLMSpanDetails",
"ToolSpanDetails",
"RetrieverSpanDetails",
"WorkflowSpanDetails",
"Span",
"Trace",
"CreateTracesInput",
Expand Down Expand Up @@ -70,4 +80,4 @@
"DOAPINetworkError",
"DOAPIValidationError",
"AsyncDigitalOceanGenAI",
]
]
81 changes: 78 additions & 3 deletions gradient_adk/digital_ocean_api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,71 @@ class TraceSpanType(str, Enum):
TRACE_SPAN_TYPE_LLM = "TRACE_SPAN_TYPE_LLM"
TRACE_SPAN_TYPE_RETRIEVER = "TRACE_SPAN_TYPE_RETRIEVER"
TRACE_SPAN_TYPE_TOOL = "TRACE_SPAN_TYPE_TOOL"
TRACE_SPAN_TYPE_WORKFLOW = "TRACE_SPAN_TYPE_WORKFLOW"
TRACE_SPAN_TYPE_AGENT = "TRACE_SPAN_TYPE_AGENT"


class SpanCommon(BaseModel):
    """Metadata shared by every kind of span.

    Every field is optional and defaults to ``None``. The model accepts
    population by field name and keeps unknown extra fields
    (``extra="allow"``).
    """

    model_config = ConfigDict(populate_by_name=True, extra="allow")

    duration_ns: Optional[int] = Field(
        default=None,
        description="Duration in nanoseconds",
    )
    metadata: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Additional metadata",
    )
    tags: Optional[List[str]] = Field(
        default=None,
        description="Tags for the span",
    )
    status_code: Optional[int] = Field(
        default=None,
        description="HTTP status code if applicable",
    )


class LLMSpanDetails(BaseModel):
    """Details attached to a span that records an LLM call.

    Carries the optional shared ``common`` block plus the model name, tool
    definitions, token counts, sampling temperature and time-to-first-token.
    Every field is optional and defaults to ``None``; unknown extra fields
    are kept (``extra="allow"``).
    """

    model_config = ConfigDict(populate_by_name=True, extra="allow")

    common: Optional[SpanCommon] = None
    model: Optional[str] = Field(default=None, description="Model name")
    tools: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description="Tool definitions passed to the model",
    )
    num_input_tokens: Optional[int] = Field(
        default=None,
        description="Number of input tokens",
    )
    num_output_tokens: Optional[int] = Field(
        default=None,
        description="Number of output tokens",
    )
    total_tokens: Optional[int] = Field(
        default=None,
        description="Total tokens",
    )
    temperature: Optional[float] = Field(
        default=None,
        description="Temperature setting",
    )
    time_to_first_token_ns: Optional[int] = Field(
        default=None,
        description="Time to first token in nanoseconds",
    )


class ToolSpanDetails(BaseModel):
    """Details attached to a span that records a tool invocation.

    Holds the optional shared ``common`` block and the tool call identifier.
    Both default to ``None``; unknown extra fields are kept
    (``extra="allow"``).
    """

    model_config = ConfigDict(populate_by_name=True, extra="allow")

    common: Optional[SpanCommon] = None
    tool_call_id: Optional[str] = Field(
        default=None,
        description="Tool call identifier",
    )


class RetrieverSpanDetails(BaseModel):
    """Details attached to a span that records a retriever call.

    Declares only the optional shared ``common`` block (default ``None``);
    unknown extra fields are kept (``extra="allow"``).
    """

    model_config = ConfigDict(populate_by_name=True, extra="allow")

    common: Optional[SpanCommon] = None


class WorkflowSpanDetails(BaseModel):
    """Details for a workflow span, which groups nested child spans.

    ``spans`` is annotated with the string forward reference ``"Span"`` and
    defaults to a fresh empty list per instance.
    """

    model_config = ConfigDict(populate_by_name=True, extra="allow")

    spans: List["Span"] = Field(
        default_factory=list,
        description="Nested sub-spans",
    )


class Span(BaseModel):
"""
Represents a span within a trace (e.g., LLM call, retriever, tool).
Represents a span within a trace (e.g., LLM call, retriever, tool, workflow).
- created_at: RFC3339 timestamp (protobuf Timestamp)
- input/output: json
- input/output: json (must be dict for protobuf Struct compatibility)
- For workflow spans, contains nested sub-spans in the 'workflow' field
"""

model_config = ConfigDict(populate_by_name=True, extra="allow")
Expand All @@ -29,6 +87,23 @@ class Span(BaseModel):
output: Dict[str, Any]
type: TraceSpanType = Field(default=TraceSpanType.TRACE_SPAN_TYPE_UNKNOWN)

# Common fields for all span types
common: Optional[SpanCommon] = Field(None, description="Common span metadata")

# Type-specific fields
llm: Optional[LLMSpanDetails] = Field(None, description="LLM-specific details")
tool: Optional[ToolSpanDetails] = Field(None, description="Tool-specific details")
retriever: Optional[RetrieverSpanDetails] = Field(
None, description="Retriever-specific details"
)
workflow: Optional[WorkflowSpanDetails] = Field(
None, description="Workflow span with nested sub-spans"
)


# WorkflowSpanDetails.spans is annotated with the string forward reference
# "Span", a name that did not exist yet when WorkflowSpanDetails was defined.
# Rebuild the model now that Span is defined so the reference can be resolved.
# This call must stay below the Span class definition.
WorkflowSpanDetails.model_rebuild()


class Trace(BaseModel):
"""
Expand Down Expand Up @@ -919,4 +994,4 @@ class ListEvaluationMetricsOutput(BaseModel):

metrics: List[EvaluationMetric] = Field(
default_factory=list, description="List of evaluation metrics"
)
)
175 changes: 165 additions & 10 deletions gradient_adk/runtime/digitalocean_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
Trace,
Span,
TraceSpanType,
SpanCommon,
LLMSpanDetails,
ToolSpanDetails,
RetrieverSpanDetails,
WorkflowSpanDetails,
)
from .interfaces import NodeExecution

Expand Down Expand Up @@ -324,22 +329,172 @@ def _to_span(self, ex: NodeExecution) -> Span:
out = dict(out)
out["_llm_endpoints"] = list(ex.metadata["llm_endpoints"])

# classify LLM/tool/retriever via metadata set by the instrumentor
# classify span type via metadata set by the instrumentor
metadata = ex.metadata or {}
if metadata.get("is_llm_call"):

# Check if this is a workflow span
if metadata.get("is_workflow"):
span_type = TraceSpanType.TRACE_SPAN_TYPE_WORKFLOW

# Build sub-spans from the workflow's collected spans
sub_spans_list = metadata.get("sub_spans", [])
sub_spans = [self._to_span(sub) for sub in sub_spans_list]

# Calculate duration from start to end
duration_ns = None
if ex.start_time and ex.end_time:
duration_ns = int(
(ex.end_time - ex.start_time).total_seconds() * 1_000_000_000
)

# Build common fields
common = SpanCommon(
duration_ns=duration_ns,
metadata={"agent_name": metadata.get("agent_name")},
status_code=200 if ex.error is None else 500,
)

# Build workflow details with nested sub-spans
workflow_details = WorkflowSpanDetails(spans=sub_spans)

return Span(
created_at=_utc(ex.start_time),
name=ex.node_name,
input=inp,
output=out,
type=span_type,
common=common,
workflow=workflow_details,
)
elif metadata.get("is_llm_call"):
span_type = TraceSpanType.TRACE_SPAN_TYPE_LLM

# Calculate duration
duration_ns = None
if ex.start_time and ex.end_time:
duration_ns = int(
(ex.end_time - ex.start_time).total_seconds() * 1_000_000_000
)

# Build LLM-specific details
llm_common = SpanCommon(
duration_ns=duration_ns,
status_code=200 if ex.error is None else 500,
)

# Extract LLM-specific fields from captured API payloads
llm_request = metadata.get("llm_request_payload", {}) or {}
llm_response = metadata.get("llm_response_payload", {}) or {}

# For LLM spans, use just the messages as input (not the full request payload)
# Must be a dict (not array) because protobuf Struct requires key-value pairs
if isinstance(llm_request, dict) and "messages" in llm_request:
inp = {"messages": llm_request.get("messages")}

# For LLM spans, use just the choices as output (not the full response payload)
# Must be a dict (not array) because protobuf Struct requires key-value pairs
if isinstance(llm_response, dict) and "choices" in llm_response:
out = {"choices": llm_response.get("choices")}

# Extract model from request payload, fallback to metadata or node name
model = (
llm_request.get("model")
or metadata.get("model_name")
or ex.node_name.replace("llm:", "")
)

# Extract tools from request payload
tools = llm_request.get("tools") if isinstance(llm_request, dict) else None

# Extract temperature from request payload
temperature = llm_request.get("temperature") if isinstance(llm_request, dict) else None

# Extract token counts from response payload
num_input_tokens = None
num_output_tokens = None
total_tokens = None
if isinstance(llm_response, dict):
usage = llm_response.get("usage", {})
if isinstance(usage, dict):
num_input_tokens = usage.get("prompt_tokens")
num_output_tokens = usage.get("completion_tokens")
total_tokens = usage.get("total_tokens")

# Get time-to-first-token for streaming calls
time_to_first_token_ns = metadata.get("time_to_first_token_ns")

llm_details = LLMSpanDetails(
common=llm_common,
model=model,
tools=tools,
temperature=temperature,
num_input_tokens=num_input_tokens,
num_output_tokens=num_output_tokens,
total_tokens=total_tokens,
time_to_first_token_ns=time_to_first_token_ns,
)

return Span(
created_at=_utc(ex.start_time),
name=ex.node_name,
input=inp,
output=out,
type=span_type,
llm=llm_details,
)
elif metadata.get("is_retriever_call"):
span_type = TraceSpanType.TRACE_SPAN_TYPE_RETRIEVER

# Calculate duration
duration_ns = None
if ex.start_time and ex.end_time:
duration_ns = int(
(ex.end_time - ex.start_time).total_seconds() * 1_000_000_000
)

# Build retriever-specific details
retriever_common = SpanCommon(
duration_ns=duration_ns,
status_code=200 if ex.error is None else 500,
)

retriever_details = RetrieverSpanDetails(common=retriever_common)

return Span(
created_at=_utc(ex.start_time),
name=ex.node_name,
input=inp,
output=out,
type=span_type,
retriever=retriever_details,
)
else:
# Default to tool span
span_type = TraceSpanType.TRACE_SPAN_TYPE_TOOL

return Span(
created_at=_utc(ex.start_time),
name=ex.node_name,
input=inp,
output=out,
type=span_type,
)
# Calculate duration
duration_ns = None
if ex.start_time and ex.end_time:
duration_ns = int(
(ex.end_time - ex.start_time).total_seconds() * 1_000_000_000
)

# Build tool-specific details
tool_common = SpanCommon(
duration_ns=duration_ns,
status_code=200 if ex.error is None else 500,
)

tool_details = ToolSpanDetails(common=tool_common)

return Span(
created_at=_utc(ex.start_time),
name=ex.node_name,
input=inp,
output=out,
type=span_type,
tool=tool_details,
)

def _coerce_top(self, val: Any, kind: str) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -374,4 +529,4 @@ def _build_trace(self) -> Trace:
output=outputs,
spans=spans,
)
return trace
return trace
Loading