Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 74 additions & 13 deletions gradient_adk/runtime/langgraph/langgraph_instrumentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@

from ..interfaces import NodeExecution
from ..digitalocean_tracker import DigitalOceanTracesTracker
from ..network_interceptor import get_network_interceptor
from ..network_interceptor import (
get_network_interceptor,
is_inference_url,
is_kbaas_url,
)


WRAPPED_FLAG = "__do_wrapped__"
Expand Down Expand Up @@ -232,17 +236,50 @@ def _had_hits_since(intr, token) -> bool:
return False


def _get_captured_payloads_with_type(intr, token) -> tuple:
    """Fetch captured API payloads since *token* and classify the call.

    Best-effort: any failure while reading from the interceptor falls back
    to the empty result instead of propagating into the traced node.

    Returns:
        (request_payload, response_payload, is_llm, is_retriever)
    """
    empty = (None, None, False, False)
    try:
        calls = intr.get_captured_requests_since(token)
        if not calls:
            return empty
        # Only the first captured call is inspected (most common case).
        first = calls[0]
        return (
            first.request_payload,
            first.response_payload,
            is_inference_url(first.url),
            is_kbaas_url(first.url),
        )
    except Exception:
        # Deliberate best-effort: tracing must never break execution.
        return empty


def _transform_kbaas_response(response: Optional[Dict[str, Any]]) -> Optional[list]:
"""Transform KBaaS response to standard retriever format.

Extracts results and converts 'text_content' to 'page_content'.
Returns a list of dicts as expected for retriever spans.
"""
if not isinstance(response, dict):
return response

results = response.get("results", [])
if not isinstance(results, list):
return response

transformed_results = []
for item in results:
if isinstance(item, dict) and "text_content" in item:
new_item = dict(item)
new_item["page_content"] = new_item.pop("text_content")
transformed_results.append(new_item)
else:
transformed_results.append(item)

# Return just the array of results
return transformed_results


class LangGraphInstrumentor:
Expand Down Expand Up @@ -280,20 +317,33 @@ def _finish_ok(
# (_wrap_async_func, _wrap_sync_func, etc.) BEFORE calling _finish_ok.
# The wrappers collect streamed content and pass {"content": "..."} here.

# Check if this node made any tracked API calls (e.g., LLM inference)
# Check if this node made any tracked API calls (e.g., LLM inference or KBaaS retrieval)
if _had_hits_since(intr, tok):
_ensure_meta(rec)["is_llm_call"] = True
# Get captured payloads and classify the call type
api_request, api_response, is_llm, is_retriever = (
_get_captured_payloads_with_type(intr, tok)
)

# Try to get actual API request/response payloads (for LLM calls)
api_request, api_response = _get_captured_payloads(intr, tok)
# Set metadata based on call type
meta = _ensure_meta(rec)
if is_llm:
meta["is_llm_call"] = True
elif is_retriever:
meta["is_retriever_call"] = True
else:
# Fallback: assume LLM call for backward compatibility
meta["is_llm_call"] = True

if api_request or api_response:
# Use actual API payloads instead of function args
if api_request:
rec.inputs = _freeze(api_request)

# Use actual API response as output (e.g., LLM completion)
# Use actual API response as output
if api_response:
# Transform KBaaS response to standard retriever format
if is_retriever:
api_response = _transform_kbaas_response(api_response)
out_payload = _freeze(api_response)
else:
out_payload = _canonical_output(inputs_snapshot, a, kw, ret)
Expand All @@ -306,10 +356,21 @@ def _finish_ok(

def _finish_err(rec: NodeExecution, intr, tok, e: BaseException):
if _had_hits_since(intr, tok):
_ensure_meta(rec)["is_llm_call"] = True
# Get captured payloads and classify the call type
api_request, _, is_llm, is_retriever = _get_captured_payloads_with_type(
intr, tok
)

# Set metadata based on call type
meta = _ensure_meta(rec)
if is_llm:
meta["is_llm_call"] = True
elif is_retriever:
meta["is_retriever_call"] = True
else:
# Fallback: assume LLM call for backward compatibility
meta["is_llm_call"] = True

# Try to get actual API request payload even on error
api_request, _ = _get_captured_payloads(intr, tok)
if api_request:
rec.inputs = _freeze(api_request)

Expand Down
42 changes: 35 additions & 7 deletions gradient_adk/runtime/network_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ class CapturedRequest:

def __init__(
    self,
    url: Optional[str] = None,
    request_payload: Optional[Dict[str, Any]] = None,
    response_payload: Optional[Dict[str, Any]] = None,
):
    """Record one intercepted API call.

    Args:
        url: Full request URL of the intercepted call; used downstream to
            classify the service the call targeted.
        request_payload: Parsed request body, if one was captured.
        response_payload: Parsed response body, if one was captured.
    """
    self.url = url
    self.request_payload = request_payload
    self.response_payload = response_payload

Expand Down Expand Up @@ -287,9 +289,9 @@ def _record_request(
with self._lock:
if self._is_tracked_url(url):
self._hit_count += 1
# Create a new captured request record
# Create a new captured request record with URL
self._captured_requests.append(
CapturedRequest(request_payload=request_payload)
CapturedRequest(url=url, request_payload=request_payload)
)

def _record_response(
Expand Down Expand Up @@ -401,6 +403,25 @@ def hook(url: str, headers: Dict[str, str]) -> Dict[str, str]:
return hook


# URL classification helpers for the different DigitalOcean services.
INFERENCE_URL_PATTERNS = ["inference.do-ai.run", "inference.do-ai-test.run"]
KBAAS_URL_PATTERNS = ["kbaas.do-ai.run", "kbaas.do-ai-test.run"]


def _url_matches_any(url: Optional[str], patterns: list) -> bool:
    """Substring-match *url* against *patterns*; falsy URLs never match."""
    return bool(url) and any(pattern in url for pattern in patterns)


def is_inference_url(url: Optional[str]) -> bool:
    """Return True if *url* targets a DigitalOcean inference (LLM) endpoint."""
    return _url_matches_any(url, INFERENCE_URL_PATTERNS)


def is_kbaas_url(url: Optional[str]) -> bool:
    """Return True if *url* targets a DigitalOcean KBaaS (Knowledge Base) endpoint."""
    return _url_matches_any(url, KBAAS_URL_PATTERNS)


# Global instance
_global_interceptor = NetworkInterceptor()

Expand All @@ -411,14 +432,21 @@ def get_network_interceptor() -> NetworkInterceptor:

def setup_digitalocean_interception() -> None:
    """Wire the global network interceptor up to all DigitalOcean AI endpoints."""
    intr = get_network_interceptor()

    # Track both inference (LLM) and KBaaS (Knowledge Base) endpoints.
    all_patterns = INFERENCE_URL_PATTERNS + KBAAS_URL_PATTERNS
    for pattern in all_patterns:
        intr.add_endpoint_pattern(pattern)

    # Identify the ADK via a User-Agent hook on every DO endpoint.
    intr.add_request_hook(
        create_adk_user_agent_hook(
            version=_get_adk_version(),
            url_patterns=all_patterns,
        )
    )

    intr.start_intercepting()
Loading