2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -27,7 +27,7 @@ jobs:
- name: Install package and test deps
run: |
pip install -e .
pip install pytest pytest-cov
pip install pytest pytest-cov pydantic-ai

- name: Run unit tests with coverage
run: |
7 changes: 5 additions & 2 deletions gradient_adk/decorator.py
@@ -18,9 +18,12 @@

logger = get_logger(__name__)

from gradient_adk.runtime.langgraph.helpers import capture_graph, get_tracker
# Initialize framework instrumentation using the centralized registry
# This is idempotent and will only install instrumentation once
# Each instrumentor checks for its own environment variable to allow disabling
from gradient_adk.runtime.helpers import capture_all, get_tracker

capture_graph()
capture_all()


class _StreamingIteratorWithTracking:
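Note: the new comment describes `capture_all()` as an idempotent entry point into a centralized instrumentation registry, with each instrumentor gated by its own environment variable. As a rough illustration only (the registry contents and the `GRADIENT_ADK_DISABLE_*` variable names below are assumptions, not the package's actual API), such a helper could look like:

```python
import os
from typing import Callable, Dict

# Hypothetical registry of framework instrumentors; real names differ.
_INSTRUMENTORS: Dict[str, Callable[[], None]] = {
    "LANGGRAPH": lambda: print("instrumenting langgraph"),
    "PYDANTIC_AI": lambda: print("instrumenting pydantic-ai"),
}

_installed = False


def capture_all() -> None:
    """Install every registered instrumentor exactly once (idempotent)."""
    global _installed
    if _installed:  # repeated imports of decorator.py become no-ops
        return
    _installed = True
    for name, install in _INSTRUMENTORS.items():
        # Each instrumentor honors its own opt-out environment variable.
        if os.getenv(f"GRADIENT_ADK_DISABLE_{name}") == "1":
            continue
        install()
```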
12 changes: 11 additions & 1 deletion gradient_adk/digital_ocean_api/__init__.py
@@ -1,5 +1,10 @@
from .models import (
TraceSpanType,
SpanCommon,
LLMSpanDetails,
ToolSpanDetails,
RetrieverSpanDetails,
WorkflowSpanDetails,
Span,
Trace,
CreateTracesInput,
@@ -38,6 +43,11 @@

__all__ = [
"TraceSpanType",
"SpanCommon",
"LLMSpanDetails",
"ToolSpanDetails",
"RetrieverSpanDetails",
"WorkflowSpanDetails",
"Span",
"Trace",
"CreateTracesInput",
@@ -70,4 +80,4 @@
"DOAPINetworkError",
"DOAPIValidationError",
"AsyncDigitalOceanGenAI",
]
]
76 changes: 74 additions & 2 deletions gradient_adk/digital_ocean_api/models.py
@@ -12,13 +12,68 @@ class TraceSpanType(str, Enum):
TRACE_SPAN_TYPE_LLM = "TRACE_SPAN_TYPE_LLM"
TRACE_SPAN_TYPE_RETRIEVER = "TRACE_SPAN_TYPE_RETRIEVER"
TRACE_SPAN_TYPE_TOOL = "TRACE_SPAN_TYPE_TOOL"
TRACE_SPAN_TYPE_WORKFLOW = "TRACE_SPAN_TYPE_WORKFLOW"
TRACE_SPAN_TYPE_AGENT = "TRACE_SPAN_TYPE_AGENT"


class SpanCommon(BaseModel):
"""Common fields for all span types."""

model_config = ConfigDict(populate_by_name=True, extra="allow")

duration_ns: Optional[int] = Field(None, description="Duration in nanoseconds")
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
tags: Optional[List[str]] = Field(None, description="Tags for the span")
status_code: Optional[int] = Field(None, description="HTTP status code if applicable")


class LLMSpanDetails(BaseModel):
"""LLM-specific span details."""

model_config = ConfigDict(populate_by_name=True, extra="allow")

common: Optional[SpanCommon] = None
model: Optional[str] = Field(None, description="Model name")
num_input_tokens: Optional[int] = Field(None, description="Number of input tokens")
num_output_tokens: Optional[int] = Field(None, description="Number of output tokens")
total_tokens: Optional[int] = Field(None, description="Total tokens")
temperature: Optional[float] = Field(None, description="Temperature setting")
time_to_first_token_ns: Optional[int] = Field(
None, description="Time to first token in nanoseconds"
)


class ToolSpanDetails(BaseModel):
"""Tool-specific span details."""

model_config = ConfigDict(populate_by_name=True, extra="allow")

common: Optional[SpanCommon] = None
tool_call_id: Optional[str] = Field(None, description="Tool call identifier")


class RetrieverSpanDetails(BaseModel):
"""Retriever-specific span details."""

model_config = ConfigDict(populate_by_name=True, extra="allow")

common: Optional[SpanCommon] = None


class WorkflowSpanDetails(BaseModel):
"""Workflow span containing nested sub-spans."""

model_config = ConfigDict(populate_by_name=True, extra="allow")

spans: List["Span"] = Field(default_factory=list, description="Nested sub-spans")


class Span(BaseModel):
"""
Represents a span within a trace (e.g., LLM call, retriever, tool).
Represents a span within a trace (e.g., LLM call, retriever, tool, workflow).
- created_at: RFC3339 timestamp (protobuf Timestamp)
- input/output: json
- For workflow spans, contains nested sub-spans in the 'workflow' field
"""

model_config = ConfigDict(populate_by_name=True, extra="allow")
@@ -29,6 +84,23 @@ class Span(BaseModel):
output: Dict[str, Any]
type: TraceSpanType = Field(default=TraceSpanType.TRACE_SPAN_TYPE_UNKNOWN)

# Common fields for all span types
common: Optional[SpanCommon] = Field(None, description="Common span metadata")

# Type-specific fields
llm: Optional[LLMSpanDetails] = Field(None, description="LLM-specific details")
tool: Optional[ToolSpanDetails] = Field(None, description="Tool-specific details")
retriever: Optional[RetrieverSpanDetails] = Field(
None, description="Retriever-specific details"
)
workflow: Optional[WorkflowSpanDetails] = Field(
None, description="Workflow span with nested sub-spans"
)


# Update forward reference for WorkflowSpanDetails
WorkflowSpanDetails.model_rebuild()


class Trace(BaseModel):
"""
@@ -919,4 +991,4 @@ class ListEvaluationMetricsOutput(BaseModel):

metrics: List[EvaluationMetric] = Field(
default_factory=list, description="List of evaluation metrics"
)
)
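Note: to see how the new models compose, here is an illustrative construction of a workflow span that nests an LLM sub-span. Field names follow the hunks above; `created_at` is written as an RFC3339 string per the docstring, and the model name and token counts are made-up values.

```python
from gradient_adk.digital_ocean_api import (
    LLMSpanDetails,
    Span,
    SpanCommon,
    TraceSpanType,
    WorkflowSpanDetails,
)

# Leaf LLM span carrying its type-specific details.
llm_span = Span(
    created_at="2024-01-01T00:00:00Z",
    name="llm:gpt-4o",
    input={"prompt": "hello"},
    output={"completion": "hi"},
    type=TraceSpanType.TRACE_SPAN_TYPE_LLM,
    llm=LLMSpanDetails(
        common=SpanCommon(duration_ns=1_200_000, status_code=200),
        model="gpt-4o",
        num_input_tokens=3,
        num_output_tokens=2,
    ),
)

# Workflow span nesting the LLM span via the forward-referenced spans list.
workflow_span = Span(
    created_at="2024-01-01T00:00:00Z",
    name="agent-run",
    input={"query": "hello"},
    output={"answer": "hi"},
    type=TraceSpanType.TRACE_SPAN_TYPE_WORKFLOW,
    workflow=WorkflowSpanDetails(spans=[llm_span]),
)
```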
127 changes: 118 additions & 9 deletions gradient_adk/runtime/digitalocean_tracker.py
@@ -12,6 +12,11 @@
Trace,
Span,
TraceSpanType,
SpanCommon,
LLMSpanDetails,
ToolSpanDetails,
RetrieverSpanDetails,
WorkflowSpanDetails,
)
from .interfaces import NodeExecution

@@ -324,22 +329,126 @@ def _to_span(self, ex: NodeExecution) -> Span:
out = dict(out)
out["_llm_endpoints"] = list(ex.metadata["llm_endpoints"])

# classify LLM/tool/retriever via metadata set by the instrumentor
# classify span type via metadata set by the instrumentor
metadata = ex.metadata or {}
if metadata.get("is_llm_call"):

# Check if this is a workflow span
if metadata.get("is_workflow"):
span_type = TraceSpanType.TRACE_SPAN_TYPE_WORKFLOW

# Build sub-spans from the workflow's collected spans
sub_spans_list = metadata.get("sub_spans", [])
sub_spans = [self._to_span(sub) for sub in sub_spans_list]

# Calculate duration from start to end
duration_ns = None
if ex.start_time and ex.end_time:
duration_ns = int(
(ex.end_time - ex.start_time).total_seconds() * 1_000_000_000
)

# Build common fields
common = SpanCommon(
duration_ns=duration_ns,
metadata={"agent_name": metadata.get("agent_name")},
status_code=200 if ex.error is None else 500,
)

# Build workflow details with nested sub-spans
workflow_details = WorkflowSpanDetails(spans=sub_spans)

return Span(
created_at=_utc(ex.start_time),
name=ex.node_name,
input=inp,
output=out,
type=span_type,
common=common,
workflow=workflow_details,
)
elif metadata.get("is_llm_call"):
span_type = TraceSpanType.TRACE_SPAN_TYPE_LLM

# Calculate duration
duration_ns = None
if ex.start_time and ex.end_time:
duration_ns = int(
(ex.end_time - ex.start_time).total_seconds() * 1_000_000_000
)

# Build LLM-specific details
llm_common = SpanCommon(
duration_ns=duration_ns,
status_code=200 if ex.error is None else 500,
)

# Extract LLM-specific fields from output if available
llm_details = LLMSpanDetails(
common=llm_common,
model=metadata.get("model_name") or ex.node_name.replace("llm:", ""),
)

return Span(
created_at=_utc(ex.start_time),
name=ex.node_name,
input=inp,
output=out,
type=span_type,
llm=llm_details,
)
elif metadata.get("is_retriever_call"):
span_type = TraceSpanType.TRACE_SPAN_TYPE_RETRIEVER

# Calculate duration
duration_ns = None
if ex.start_time and ex.end_time:
duration_ns = int(
(ex.end_time - ex.start_time).total_seconds() * 1_000_000_000
)

# Build retriever-specific details
retriever_common = SpanCommon(
duration_ns=duration_ns,
status_code=200 if ex.error is None else 500,
)

retriever_details = RetrieverSpanDetails(common=retriever_common)

return Span(
created_at=_utc(ex.start_time),
name=ex.node_name,
input=inp,
output=out,
type=span_type,
retriever=retriever_details,
)
else:
# Default to tool span
span_type = TraceSpanType.TRACE_SPAN_TYPE_TOOL

return Span(
created_at=_utc(ex.start_time),
name=ex.node_name,
input=inp,
output=out,
type=span_type,
)
# Calculate duration
duration_ns = None
if ex.start_time and ex.end_time:
duration_ns = int(
(ex.end_time - ex.start_time).total_seconds() * 1_000_000_000
)

# Build tool-specific details
tool_common = SpanCommon(
duration_ns=duration_ns,
status_code=200 if ex.error is None else 500,
)

tool_details = ToolSpanDetails(common=tool_common)

return Span(
created_at=_utc(ex.start_time),
name=ex.node_name,
input=inp,
output=out,
type=span_type,
tool=tool_details,
)

def _coerce_top(self, val: Any, kind: str) -> Dict[str, Any]:
"""
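Note: every branch of the updated `_to_span` recomputes `duration_ns` and the 200/500 status code before building its `SpanCommon`. A hypothetical helper (not part of this PR) could centralize that repeated pattern:

```python
from datetime import datetime
from typing import Any, Dict, Optional

from gradient_adk.digital_ocean_api import SpanCommon


def build_common(
    start_time: Optional[datetime],
    end_time: Optional[datetime],
    error: Optional[BaseException],
    metadata: Optional[Dict[str, Any]] = None,
) -> SpanCommon:
    """Compute duration_ns and status_code once for any span branch."""
    duration_ns = None
    if start_time and end_time:
        duration_ns = int((end_time - start_time).total_seconds() * 1_000_000_000)
    return SpanCommon(
        duration_ns=duration_ns,
        metadata=metadata,
        status_code=200 if error is None else 500,
    )
```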