def _get_adk_version() -> str:
    """Return the installed ``gradient-adk`` distribution version.

    Returns:
        The version string reported by the package metadata, or ``"unknown"``
        when no ``gradient-adk`` distribution is installed.
    """
    # ``import importlib`` alone does not load the ``importlib.metadata``
    # submodule; attribute access on it only works if some other module has
    # already imported it. Import it explicitly so this never raises
    # AttributeError depending on import order.
    from importlib import metadata

    try:
        return metadata.version("gradient-adk")
    except metadata.PackageNotFoundError:
        return "unknown"
str(request.url) + + # Apply request hooks to modify headers + new_headers = _global_interceptor._apply_request_hooks( + url_str, dict(request.headers) + ) + if new_headers != dict(request.headers): + request = httpx.Request( + request.method, + request.url, + headers=new_headers, + content=request.content, + ) + request_payload = _global_interceptor._extract_request_payload(request) _global_interceptor._record_request(url_str, request_payload) @@ -97,12 +139,12 @@ async def intercepted_httpx_send(self_client, request, **kwargs): # Don't read response body for streaming responses - it would buffer the entire stream! # Check if this is a streaming response by looking at headers or response type is_streaming = ( - response.headers.get("transfer-encoding") == "chunked" or - "text/event-stream" in response.headers.get("content-type", "") or - hasattr(response, "aiter_bytes") or - hasattr(response, "aiter_lines") + response.headers.get("transfer-encoding") == "chunked" + or "text/event-stream" in response.headers.get("content-type", "") + or hasattr(response, "aiter_bytes") + or hasattr(response, "aiter_lines") ) - + if not is_streaming: response_payload = await _global_interceptor._extract_response_payload( response @@ -114,6 +156,12 @@ async def intercepted_httpx_send(self_client, request, **kwargs): def intercepted_httpx_request(self_client, method, url, **kwargs): url_str = str(url) + + # Apply request hooks to modify headers + kwargs["headers"] = _global_interceptor._apply_request_hooks( + url_str, kwargs.get("headers", {}) + ) + request_payload = _global_interceptor._extract_request_payload_from_kwargs( kwargs ) @@ -130,6 +178,19 @@ def intercepted_httpx_request(self_client, method, url, **kwargs): # patch httpx (sync) def intercepted_httpx_sync_send(self_client, request, **kwargs): url_str = str(request.url) + + # Apply request hooks to modify headers + new_headers = _global_interceptor._apply_request_hooks( + url_str, dict(request.headers) + ) + if new_headers 
def create_adk_user_agent_hook(version: str, url_patterns: List[str]) -> "RequestHook":
    """
    Factory to create a User-Agent hook for specific URL patterns.

    Completely replaces the User-Agent header with the Gradient ADK identifier
    for requests matching the specified URL patterns. Requests whose URL matches
    none of the patterns are returned untouched.

    Format: Gradient/adk/{version} or Gradient/adk/{version}/{uuid}, where
    {uuid} is the value of the AGENT_WORKSPACE_DEPLOYMENT_UUID environment
    variable, read on every call and appended only when it is non-empty.

    Args:
        version: The ADK version string (e.g., "0.0.5")
        url_patterns: List of URL substrings to match (e.g., ["inference.do-ai.run"])

    Returns:
        A request hook function that can be registered with NetworkInterceptor.
        The hook mutates the headers dict it is given and returns that same dict.
    """

    def hook(url: str, headers: Dict[str, str]) -> Dict[str, str]:
        # NOTE(review): this is a plain substring test against the whole URL, so
        # a pattern also matches when it appears in the path or query string.
        if not any(pattern in url for pattern in url_patterns):
            return headers

        # Header names are case-insensitive, so remove every spelling of
        # User-Agent (e.g. "USER-AGENT", "User-agent"), not just two of them;
        # otherwise the request would carry duplicate User-Agent headers.
        # Collect the keys first because the dict cannot change size while
        # it is being iterated.
        stale_keys = [
            key
            for key in headers
            if isinstance(key, str) and key.lower() == "user-agent"
        ]
        for key in stale_keys:
            del headers[key]

        # Build new User-Agent: Gradient/adk/{version} or Gradient/adk/{version}/{uuid}
        user_agent = f"Gradient/adk/{version}"
        deployment_uuid = os.environ.get("AGENT_WORKSPACE_DEPLOYMENT_UUID")
        if deployment_uuid:
            user_agent += f"/{deployment_uuid}"

        headers["User-Agent"] = user_agent
        return headers

    return hook
hits + - clear patterns, hits, and hooks """ intr = get_network_interceptor() try: @@ -25,9 +27,10 @@ def reset_global_interceptor(): finally: # brute-force cleanup of internal state intr.clear_hits() - # Not public, but safe for tests: nuke patterns set + # Not public, but safe for tests: nuke patterns set and hooks with intr._lock: intr._tracked_endpoints.clear() + intr._request_hooks.clear() intr._original_httpx_request = None intr._original_httpx_send = None intr._original_httpx_sync_request = None @@ -43,6 +46,7 @@ def reset_global_interceptor(): intr.clear_hits() with intr._lock: intr._tracked_endpoints.clear() + intr._request_hooks.clear() @pytest.fixture @@ -228,3 +232,239 @@ def test_setup_digitalocean_interception( # so we just exercise the internal recorder to confirm patterns work: intr._record_request("https://inference.do-ai.run/v1/chat") assert intr.hits_since(0) == 1 + + +# ---- Request Hook Tests ---- + + +def test_add_request_hook(intr): + """Test that hooks can be registered.""" + assert len(intr._request_hooks) == 0 + + def my_hook(url: str, headers: dict) -> dict: + headers["X-Custom"] = "value" + return headers + + intr.add_request_hook(my_hook) + assert len(intr._request_hooks) == 1 + assert intr._request_hooks[0] is my_hook + + +def test_apply_request_hooks_empty(intr): + """Test applying hooks when none are registered.""" + result = intr._apply_request_hooks("http://example.com", {"Existing": "header"}) + assert result == {"Existing": "header"} + + +def test_apply_request_hooks_single(intr): + """Test applying a single hook.""" + + def add_custom_header(url: str, headers: dict) -> dict: + headers["X-Custom"] = "added" + return headers + + intr.add_request_hook(add_custom_header) + result = intr._apply_request_hooks("http://example.com", {"Existing": "header"}) + assert result == {"Existing": "header", "X-Custom": "added"} + + +def test_apply_request_hooks_multiple(intr): + """Test applying multiple hooks in order.""" + + def hook1(url: str, 
def test_create_adk_user_agent_hook_matching_url(monkeypatch):
    """Test that the hook completely replaces User-Agent for matching URLs."""
    # The hook appends AGENT_WORKSPACE_DEPLOYMENT_UUID to the User-Agent when
    # that variable is set, so clear it to keep the exact-match assertion below
    # independent of the environment the suite runs in.
    monkeypatch.delenv("AGENT_WORKSPACE_DEPLOYMENT_UUID", raising=False)

    hook = create_adk_user_agent_hook(
        version="1.2.3", url_patterns=["api.example.com", "api.test.com"]
    )

    headers = {"User-Agent": "MyClient/1.0"}
    result = hook("https://api.example.com/v1/chat", headers)

    # Should completely replace, not append
    assert result["User-Agent"] == "Gradient/adk/1.2.3"
"MyClient/1.0"} + result = hook("https://other.domain.com/v1/chat", headers) + + # Should be unchanged + assert result["User-Agent"] == "MyClient/1.0" + + +def test_create_adk_user_agent_hook_no_existing_user_agent(): + """Test that the hook works when there's no existing User-Agent.""" + hook = create_adk_user_agent_hook(version="2.0.0", url_patterns=["api.test"]) + + headers = {} + result = hook("https://api.test/endpoint", headers) + + assert result["User-Agent"] == "Gradient/adk/2.0.0" + + +def test_create_adk_user_agent_hook_with_deployment_uuid(monkeypatch): + """Test that the hook includes deployment UUID when available.""" + monkeypatch.setenv("AGENT_WORKSPACE_DEPLOYMENT_UUID", "deploy-abc-123") + + hook = create_adk_user_agent_hook(version="1.0.0", url_patterns=["api.example"]) + + headers = {"User-Agent": "Client/1.0"} + result = hook("https://api.example/v1", headers) + + # Format: Gradient/adk/{version}/{uuid} + assert result["User-Agent"] == "Gradient/adk/1.0.0/deploy-abc-123" + + +def test_create_adk_user_agent_hook_without_deployment_uuid(monkeypatch): + """Test that the hook works without deployment UUID.""" + monkeypatch.delenv("AGENT_WORKSPACE_DEPLOYMENT_UUID", raising=False) + + hook = create_adk_user_agent_hook(version="1.0.0", url_patterns=["api.example"]) + + headers = {"User-Agent": "Client/1.0"} + result = hook("https://api.example/v1", headers) + + # Format: Gradient/adk/{version} - completely replaces original + assert result["User-Agent"] == "Gradient/adk/1.0.0" + + +def test_create_adk_user_agent_hook_lowercase_user_agent(): + """Test that the hook handles lowercase 'user-agent' header and replaces it.""" + hook = create_adk_user_agent_hook(version="1.0.0", url_patterns=["api.test"]) + + # Some HTTP libraries use lowercase headers + headers = {"user-agent": "LowercaseClient/1.0"} + result = hook("https://api.test/endpoint", headers) + + # Should completely replace with Gradient format and remove old lowercase key + assert 
def test_create_adk_user_agent_hook_removes_duplicate_keys(monkeypatch):
    """Test that the hook removes both cases to avoid duplicate headers."""
    # The hook appends AGENT_WORKSPACE_DEPLOYMENT_UUID to the User-Agent when
    # that variable is set, so clear it to keep the exact-match assertion below
    # independent of the environment the suite runs in.
    monkeypatch.delenv("AGENT_WORKSPACE_DEPLOYMENT_UUID", raising=False)

    hook = create_adk_user_agent_hook(version="1.0.0", url_patterns=["api.test"])

    # Test with lowercase key (common in gradient/openai SDK)
    headers = {"user-agent": "AsyncGradient/Python/3.10.1", "other-header": "value"}
    result = hook("https://api.test/endpoint", headers)

    # Should completely replace with Gradient format
    assert result["User-Agent"] == "Gradient/adk/1.0.0"
    assert "user-agent" not in result
    assert result["other-header"] == "value"
    # Count keys to ensure no duplicates
    user_agent_keys = [k for k in result if k.lower() == "user-agent"]
    assert len(user_agent_keys) == 1
intr.add_endpoint_pattern("billing.svc") + intr.add_request_hook(capture_hook) + intr.start_intercepting() + + s = requests.Session() + s.post("https://billing.svc/process", headers={"Content-Type": "application/json"}) + + assert captured_headers.get("X-From-Hook") == "hooked" + + +def test_setup_digitalocean_interception_registers_ua_hook(): + """Test that setup_digitalocean_interception registers the User-Agent hook.""" + setup_digitalocean_interception() + intr = get_network_interceptor() + + # Should have at least one hook registered + assert len(intr._request_hooks) >= 1 + + # Test that the hook modifies headers for DO inference URLs + headers = {"User-Agent": "TestClient/1.0"} + result = intr._apply_request_hooks("https://inference.do-ai.run/v1/chat", headers) + + # Should completely replace with Gradient/adk/{version} format + assert result["User-Agent"].startswith("Gradient/adk/") \ No newline at end of file