Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 125 additions & 6 deletions gradient_adk/runtime/network_interceptor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
from __future__ import annotations
import importlib
import os
import threading
import json
from typing import Set, List, Dict, Any, Optional
from typing import Set, List, Dict, Any, Optional, Callable
import httpx, requests


def _get_adk_version() -> str:
"""Get the version from package metadata."""
try:
return importlib.metadata.version("gradient-adk")
except importlib.metadata.PackageNotFoundError:
return "unknown"


# Type for request hooks: (url, headers) -> modified_headers
RequestHook = Callable[[str, Dict[str, str]], Dict[str, str]]


class CapturedRequest:
"""Represents a captured HTTP request/response."""

Expand Down Expand Up @@ -32,6 +46,7 @@ def __init__(self):
self._captured_requests: List[CapturedRequest] = (
[]
) # Capture request/response pairs
self._request_hooks: List[RequestHook] = [] # Hooks to modify outgoing requests
self._lock = threading.Lock()
self._active = False
# originals
Expand Down Expand Up @@ -73,6 +88,20 @@ def clear_hits(self) -> None:
self._hit_count = 0
self._captured_requests.clear()

def add_request_hook(self, hook: RequestHook) -> None:
"""Register a hook to modify outgoing request headers."""
self._request_hooks.append(hook)

def _apply_request_hooks(self, url: str, headers: Dict[str, str]) -> Dict[str, str]:
"""Apply all registered request hooks to headers."""
headers = dict(headers) if headers else {}
for hook in self._request_hooks:
try:
headers = hook(url, headers)
except Exception:
pass # Never break requests due to hook errors
return headers

def start_intercepting(self) -> None:
if self._active:
return
Expand All @@ -87,6 +116,19 @@ def start_intercepting(self) -> None:
# patch httpx (async)
async def intercepted_httpx_send(self_client, request, **kwargs):
url_str = str(request.url)

# Apply request hooks to modify headers
new_headers = _global_interceptor._apply_request_hooks(
url_str, dict(request.headers)
)
if new_headers != dict(request.headers):
request = httpx.Request(
request.method,
request.url,
headers=new_headers,
content=request.content,
)

request_payload = _global_interceptor._extract_request_payload(request)
_global_interceptor._record_request(url_str, request_payload)

Expand All @@ -97,12 +139,12 @@ async def intercepted_httpx_send(self_client, request, **kwargs):
# Don't read response body for streaming responses - it would buffer the entire stream!
# Check if this is a streaming response by looking at headers or response type
is_streaming = (
response.headers.get("transfer-encoding") == "chunked" or
"text/event-stream" in response.headers.get("content-type", "") or
hasattr(response, "aiter_bytes") or
hasattr(response, "aiter_lines")
response.headers.get("transfer-encoding") == "chunked"
or "text/event-stream" in response.headers.get("content-type", "")
or hasattr(response, "aiter_bytes")
or hasattr(response, "aiter_lines")
)

if not is_streaming:
response_payload = await _global_interceptor._extract_response_payload(
response
Expand All @@ -114,6 +156,12 @@ async def intercepted_httpx_send(self_client, request, **kwargs):

def intercepted_httpx_request(self_client, method, url, **kwargs):
url_str = str(url)

# Apply request hooks to modify headers
kwargs["headers"] = _global_interceptor._apply_request_hooks(
url_str, kwargs.get("headers", {})
)

request_payload = _global_interceptor._extract_request_payload_from_kwargs(
kwargs
)
Expand All @@ -130,6 +178,19 @@ def intercepted_httpx_request(self_client, method, url, **kwargs):
# patch httpx (sync)
def intercepted_httpx_sync_send(self_client, request, **kwargs):
url_str = str(request.url)

# Apply request hooks to modify headers
new_headers = _global_interceptor._apply_request_hooks(
url_str, dict(request.headers)
)
if new_headers != dict(request.headers):
request = httpx.Request(
request.method,
request.url,
headers=new_headers,
content=request.content,
)

request_payload = _global_interceptor._extract_request_payload(request)
_global_interceptor._record_request(url_str, request_payload)

Expand All @@ -146,6 +207,12 @@ def intercepted_httpx_sync_send(self_client, request, **kwargs):

def intercepted_httpx_sync_request(self_client, method, url, **kwargs):
url_str = str(url)

# Apply request hooks to modify headers
kwargs["headers"] = _global_interceptor._apply_request_hooks(
url_str, kwargs.get("headers", {})
)

request_payload = _global_interceptor._extract_request_payload_from_kwargs(
kwargs
)
Expand All @@ -160,6 +227,12 @@ def intercepted_httpx_sync_request(self_client, method, url, **kwargs):
# patch requests
def intercepted_requests_request(self_session, method, url, **kwargs):
url_str = str(url)

# Apply request hooks to modify headers
kwargs["headers"] = _global_interceptor._apply_request_hooks(
url_str, kwargs.get("headers", {})
)

request_payload = _global_interceptor._extract_request_payload_from_kwargs(
kwargs
)
Expand Down Expand Up @@ -290,6 +363,44 @@ def _extract_response_payload_from_requests(
return None


def create_adk_user_agent_hook(version: str, url_patterns: List[str]) -> RequestHook:
"""
Factory to create a User-Agent hook for specific URL patterns.

Completely replaces the User-Agent header with the Gradient ADK identifier
for requests matching the specified URL patterns.

Format: Gradient/adk/{version} or Gradient/adk/{version}/{uuid}

Args:
version: The ADK version string (e.g., "0.0.5")
url_patterns: List of URL substrings to match (e.g., ["inference.do-ai.run"])

Returns:
A request hook function that can be registered with NetworkInterceptor
"""

def hook(url: str, headers: Dict[str, str]) -> Dict[str, str]:
# Check if URL matches any pattern
if not any(pattern in url for pattern in url_patterns):
return headers

# Remove old User-Agent keys (both cases) to avoid duplicates
headers.pop("User-Agent", None)
headers.pop("user-agent", None)

# Build new User-Agent: Gradient/adk/{version} or Gradient/adk/{version}/{uuid}
user_agent = f"Gradient/adk/{version}"
deployment_uuid = os.environ.get("AGENT_WORKSPACE_DEPLOYMENT_UUID")
if deployment_uuid:
user_agent += f"/{deployment_uuid}"

headers["User-Agent"] = user_agent
return headers

return hook


# Global instance
_global_interceptor = NetworkInterceptor()

Expand All @@ -302,4 +413,12 @@ def setup_digitalocean_interception() -> None:
intr = get_network_interceptor()
intr.add_endpoint_pattern("inference.do-ai.run")
intr.add_endpoint_pattern("inference.do-ai-test.run")

# Register User-Agent hook for ADK identification
ua_hook = create_adk_user_agent_hook(
version=_get_adk_version(),
url_patterns=["inference.do-ai.run", "inference.do-ai-test.run"],
)
intr.add_request_hook(ua_hook)

intr.start_intercepting()
Loading