diff --git a/gradient_adk/__init__.py b/gradient_adk/__init__.py
index 82989dd..fdfa6c5 100644
--- a/gradient_adk/__init__.py
+++ b/gradient_adk/__init__.py
@@ -3,7 +3,7 @@
 and the CLI (gradient command).
 """
 
-from .decorator import entrypoint
+from .decorator import entrypoint, RequestContext
 from .tracing import (  # manual tracing decorators
     trace_llm,
     trace_retriever,
@@ -12,6 +12,7 @@
 
 __all__ = [
     "entrypoint",
+    "RequestContext",
     "trace_llm",
     "trace_retriever",
     "trace_tool",
diff --git a/gradient_adk/decorator.py b/gradient_adk/decorator.py
index 50e4aa6..12e5e5a 100644
--- a/gradient_adk/decorator.py
+++ b/gradient_adk/decorator.py
@@ -8,8 +8,21 @@ from __future__ import annotations
 
 import inspect
 import json
+from dataclasses import dataclass
 from typing import Callable, Optional, Any, Dict, List
 
+
+@dataclass
+class RequestContext:
+    """Context passed to entrypoint functions containing request metadata.
+
+    Attributes:
+        session_id: The session ID for the request, if provided.
+    """
+
+    session_id: Optional[str] = None
+
+
 from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
 import uvicorn
@@ -144,12 +157,21 @@ async def run(req: Request):
 
         is_evaluation = "evaluation-id" in req.headers
 
+        # Extract session ID from headers
+        session_id = req.headers.get("session-id")
+        context = RequestContext(session_id=session_id)
+
         # Initialize tracker
         tr = None
         try:
             tr = get_tracker()
             if tr:
-                tr.on_request_start(func.__name__, body, is_evaluation=is_evaluation)
+                tr.on_request_start(
+                    func.__name__,
+                    body,
+                    is_evaluation=is_evaluation,
+                    session_id=session_id,
+                )
         except Exception:
             pass
 
@@ -159,7 +181,7 @@ async def run(req: Request):
             if num_params == 1:
                 user_gen = func(body)
             else:
-                user_gen = func(body, None)
+                user_gen = func(body, context)
         except Exception as e:
             if tr:
                 try:
@@ -232,9 +254,9 @@ async def run(req: Request):
                 result = func(body)
             else:
                 if inspect.iscoroutinefunction(func):
-                    result = await func(body, None)
+                    result = await func(body, context)
                 else:
-                    result = func(body, None)
+                    result = func(body, context)
         except Exception as e:
             if tr:
                 try:
diff --git a/gradient_adk/runtime/digitalocean_tracker.py b/gradient_adk/runtime/digitalocean_tracker.py
index 402b33e..9588942 100644
--- a/gradient_adk/runtime/digitalocean_tracker.py
+++ b/gradient_adk/runtime/digitalocean_tracker.py
@@ -54,12 +54,17 @@ def __init__(
         self._is_evaluation: bool = False
 
     def on_request_start(
-        self, entrypoint: str, inputs: Dict[str, Any], is_evaluation: bool = False
+        self,
+        entrypoint: str,
+        inputs: Dict[str, Any],
+        is_evaluation: bool = False,
+        session_id: Optional[str] = None,
     ) -> None:
         # NEW: reset buffers per request
         self._live.clear()
         self._done.clear()
         self._is_evaluation = is_evaluation
+        self._session_id = session_id
         self._req = {"entrypoint": entrypoint, "inputs": inputs}
 
     def _as_async_iterable_and_setter(
@@ -299,13 +304,14 @@ async def _submit(self) -> Optional[str]:
                 agent_workspace_name=self._ws,
                 agent_deployment_name=self._dep,
                 traces=[trace],
+                session_id=getattr(self, "_session_id", None),
             )
             result = await self._client.create_traces(req)
             # Return first trace_uuid if available
             if result.trace_uuids:
                 return result.trace_uuids[0]
             return None
-        except Exception as e:
+        except Exception:
             # never break user code on export errors
             return None
 
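Taken together, the three changes above mean a two-parameter entrypoint now receives a populated `RequestContext` instead of `None`, built from the request's `session-id` header, and the same value flows into the tracker. A minimal sketch of the resulting handler shape (assumes the `gradient_adk` package with this patch applied):

```python
from gradient_adk import entrypoint, RequestContext

@entrypoint
async def main(query: dict, context: RequestContext):
    # context.session_id carries the caller's Session-Id header value,
    # or None when the header was not sent.
    return {"prompt": query.get("prompt"), "session_id": context.session_id}
```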
diff --git a/integration_tests/example_agents/echo_agent/main.py b/integration_tests/example_agents/echo_agent/main.py
index aa416ef..364c3ef 100644
--- a/integration_tests/example_agents/echo_agent/main.py
+++ b/integration_tests/example_agents/echo_agent/main.py
@@ -3,11 +3,15 @@
 Does not make any external API calls - just echoes back the input.
 """
 
-from gradient_adk import entrypoint
+from gradient_adk import entrypoint, RequestContext
 
 
 @entrypoint
-async def main(query, context):
+async def main(query, context: RequestContext):
     """Echo the input back to the caller."""
     prompt = query.get("prompt", "no prompt provided")
-    return {"echo": prompt, "received": query}
\ No newline at end of file
+    return {
+        "echo": prompt,
+        "received": query,
+        "session_id": context.session_id if context else None,
+    }
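For reference, a caller can exercise the patched echo agent over HTTP like this (a sketch assuming the agent is serving locally; the port is illustrative):

```python
import requests

# "Session-Id" is matched case-insensitively by the server, so
# "session-id" works equally well; omit the header to get session_id=None.
resp = requests.post(
    "http://localhost:8080/run",
    json={"prompt": "hello"},
    headers={"Session-Id": "sess-abc-123"},
    timeout=10,
)
# {"echo": "hello", "received": {"prompt": "hello"}, "session_id": "sess-abc-123"}
print(resp.json())
```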
diff --git a/integration_tests/run/test_adk_agents_run.py b/integration_tests/run/test_adk_agents_run.py
index fc48232..c0d20a6 100644
--- a/integration_tests/run/test_adk_agents_run.py
+++ b/integration_tests/run/test_adk_agents_run.py
@@ -243,8 +243,10 @@ def test_agent_run_no_config(self):
 
         # Check for helpful error message
         combined_output = result.stdout + result.stderr
-        assert "error" in combined_output.lower() or "configuration" in combined_output.lower(), \
-            f"Expected error about missing configuration, got: {combined_output}"
+        assert (
+            "error" in combined_output.lower()
+            or "configuration" in combined_output.lower()
+        ), f"Expected error about missing configuration, got: {combined_output}"
         logger.info("Correctly failed without configuration")
 
     @pytest.mark.cli
@@ -282,12 +284,17 @@ def test_agent_run_missing_entrypoint(self):
         )
 
         # Should fail
-        assert result.returncode != 0, "Command should have failed with missing entrypoint"
+        assert (
+            result.returncode != 0
+        ), "Command should have failed with missing entrypoint"
 
         # Check for helpful error message
         combined_output = result.stdout + result.stderr
-        assert "error" in combined_output.lower() or "not exist" in combined_output.lower() or "nonexistent" in combined_output.lower(), \
-            f"Expected error about missing entrypoint, got: {combined_output}"
+        assert (
+            "error" in combined_output.lower()
+            or "not exist" in combined_output.lower()
+            or "nonexistent" in combined_output.lower()
+        ), f"Expected error about missing entrypoint, got: {combined_output}"
         logger.info("Correctly failed with missing entrypoint file")
 
     @pytest.mark.cli
@@ -302,10 +309,12 @@ def test_agent_run_invalid_entrypoint_no_decorator(self):
 
             # Create a Python file without @entrypoint decorator
             main_py = temp_path / "main.py"
-            main_py.write_text("""
+            main_py.write_text(
+                """
 def main(query, context):
     return {"result": "no decorator"}
-""")
+"""
+            )
 
             # Create .gradient directory and config
             gradient_dir = temp_path / ".gradient"
@@ -320,13 +329,22 @@ def main(query, context):
             with open(gradient_dir / "agent.yml", "w") as f:
                 yaml.safe_dump(config, f)
 
-            logger.info(f"Testing agent run with invalid entrypoint (no decorator) in {temp_dir}")
+            logger.info(
+                f"Testing agent run with invalid entrypoint (no decorator) in {temp_dir}"
+            )
 
             # Run gradient agent run
             # This might start but fail to find fastapi_app, or fail on validation
             # Either way it should not succeed
             process = subprocess.Popen(
-                ["gradient", "agent", "run", "--no-dev", "--port", str(find_free_port())],
+                [
+                    "gradient",
+                    "agent",
+                    "run",
+                    "--no-dev",
+                    "--port",
+                    str(find_free_port()),
+                ],
                 cwd=temp_dir,
                 stdout=subprocess.PIPE,
                 stderr=subprocess.PIPE,
@@ -339,7 +357,7 @@ def main(query, context):
 
                 # Check if process exited with error
                 return_code = process.poll()
-                
+
                 if return_code is None:
                     # Process is still running - try to connect and see if it works
                     # (It shouldn't work properly without @entrypoint)
@@ -347,7 +365,9 @@ def main(query, context):
                     logger.info("Process started but likely not functioning correctly")
                 else:
                     # Process exited - check return code
-                    assert return_code != 0 or return_code is None, "Expected process to fail or not work correctly"
+                    assert (
+                        return_code != 0 or return_code is None
+                    ), "Expected process to fail or not work correctly"
                     logger.info(f"Process correctly exited with code {return_code}")
         finally:
             cleanup_process(process)
@@ -397,7 +417,11 @@ def test_agent_run_run_endpoint_with_various_inputs(self, setup_agent_in_temp):
             # Test with additional fields
             response = requests.post(
                 f"http://localhost:{port}/run",
-                json={"prompt": "test", "extra_field": "value", "nested": {"key": "val"}},
+                json={
+                    "prompt": "test",
+                    "extra_field": "value",
+                    "nested": {"key": "val"},
+                },
                 timeout=10,
             )
             assert response.status_code == 200
@@ -421,6 +445,87 @@ def test_agent_run_run_endpoint_with_various_inputs(self, setup_agent_in_temp):
         finally:
             cleanup_process(process)
 
+    @pytest.mark.cli
+    def test_agent_run_session_id_header_passthrough(self, setup_agent_in_temp):
+        """
+        Test that the Session-Id header is passed to the agent context.
+        Verifies:
+        - Session-Id header is extracted from request
+        - Session-Id is available in RequestContext
+        - Agent can return session_id in response
+        """
+        logger = logging.getLogger(__name__)
+        temp_dir = setup_agent_in_temp
+        port = find_free_port()
+        process = None
+
+        try:
+            logger.info(f"Starting agent on port {port} in {temp_dir}")
+
+            # Start the agent server
+            process = subprocess.Popen(
+                [
+                    "gradient",
+                    "agent",
+                    "run",
+                    "--port",
+                    str(port),
+                    "--no-dev",
+                ],
+                cwd=temp_dir,
+                start_new_session=True,
+            )
+
+            # Wait for server to be ready
+            server_ready = wait_for_server(port, timeout=30)
+            assert server_ready, "Server did not start within timeout"
+
+            # Test with Session-Id header
+            test_session_id = "test-session-12345"
+            response = requests.post(
+                f"http://localhost:{port}/run",
+                json={"prompt": "Hello"},
+                headers={"Session-Id": test_session_id},
+                timeout=10,
+            )
+            assert response.status_code == 200
+            data = response.json()
+            assert (
+                data["session_id"] == test_session_id
+            ), f"Expected session_id '{test_session_id}', got '{data.get('session_id')}'"
+            logger.info(f"Session-Id header passthrough test passed: {data}")
+
+            # Test without Session-Id header (should be None)
+            response = requests.post(
+                f"http://localhost:{port}/run",
+                json={"prompt": "Hello without session"},
+                timeout=10,
+            )
+            assert response.status_code == 200
+            data = response.json()
+            assert (
+                data["session_id"] is None
+            ), f"Expected session_id to be None, got '{data.get('session_id')}'"
+            logger.info("No Session-Id header test passed (session_id is None)")
+
+            # Test with lowercase session-id header (case-insensitive)
+            lowercase_session_id = "lowercase-session-abc"
+            response = requests.post(
+                f"http://localhost:{port}/run",
+                json={"prompt": "Hello with lowercase header"},
+                headers={"session-id": lowercase_session_id},
+                timeout=10,
+            )
+            assert response.status_code == 200
+            data = response.json()
+            assert (
+                data["session_id"] == lowercase_session_id
+            ), f"Expected session_id '{lowercase_session_id}', got '{data.get('session_id')}'"
+            logger.info("Lowercase session-id header test passed")
+
+        finally:
+            cleanup_process(process)
+
     @pytest.mark.cli
     def test_streaming_agent_without_evaluation_id_streams_response(
         self, setup_streaming_agent_in_temp
@@ -468,18 +573,18 @@ def test_streaming_agent_without_evaluation_id_streams_response(
 
             # Verify it's a streaming response (text/event-stream)
             content_type = response.headers.get("content-type", "")
-            assert "text/event-stream" in content_type, (
-                f"Expected text/event-stream content type for streaming, got: {content_type}"
-            )
+            assert (
+                "text/event-stream" in content_type
+            ), f"Expected text/event-stream content type for streaming, got: {content_type}"
 
             # Collect chunks to verify content
             chunks = list(response.iter_content(decode_unicode=True))
             full_content = "".join(c for c in chunks if c)
 
             # Verify the content contains the expected streamed output
-            assert "Echo:" in full_content or "Hello, World!" in full_content, (
-                f"Expected streamed content to contain prompt, got: {full_content}"
-            )
+            assert (
+                "Echo:" in full_content or "Hello, World!" in full_content
+            ), f"Expected streamed content to contain prompt, got: {full_content}"
 
             logger.info(f"Streaming response received with {len(chunks)} chunks")
             logger.info(f"Full content: {full_content}")
@@ -535,18 +640,18 @@ def test_streaming_agent_with_evaluation_id_returns_single_response(
 
             # Verify it's NOT a streaming response (should be application/json)
             content_type = response.headers.get("content-type", "")
-            assert "application/json" in content_type, (
-                f"Expected application/json content type for evaluation mode, got: {content_type}"
-            )
+            assert (
+                "application/json" in content_type
+            ), f"Expected application/json content type for evaluation mode, got: {content_type}"
 
             # Verify the response contains the complete content
             result = response.json()
             expected_content = "Echo: Hello, World! [DONE]"
-            assert result == expected_content, (
-                f"Expected complete collected content '{expected_content}', got: {result}"
-            )
+            assert (
+                result == expected_content
+            ), f"Expected complete collected content '{expected_content}', got: {result}"
 
             logger.info(f"Single JSON response received: {result}")
 
         finally:
-            cleanup_process(process)
\ No newline at end of file
+            cleanup_process(process)
diff --git a/tests/decorator_test.py b/tests/decorator_test.py
index 6a16d0e..2aa9615 100644
--- a/tests/decorator_test.py
+++ b/tests/decorator_test.py
@@ -2,7 +2,7 @@
 import pytest
 from fastapi.testclient import TestClient
 
-from gradient_adk.decorator import entrypoint, run_server
+from gradient_adk.decorator import entrypoint, run_server, RequestContext
 import gradient_adk.decorator as entrypoint_mod
 
 
@@ -31,11 +31,13 @@ def __init__(self):
         self.ended = []
         self.closed = False
         self._req = {}
+        self._session_id = None
         self._submitted_trace_id = None
 
-    def on_request_start(self, name, inputs, is_evaluation=False):
-        self.started.append((name, inputs, is_evaluation))
+    def on_request_start(self, name, inputs, is_evaluation=False, session_id=None):
+        self.started.append((name, inputs, is_evaluation, session_id))
         self._req = {"entrypoint": name, "inputs": inputs}
+        self._session_id = session_id
 
     def on_request_end(self, outputs=None, error=None):
         self.ended.append((outputs, error))
@@ -121,7 +123,8 @@ def test_run_endpoint_non_streaming_sync_two_params(patch_helpers):
 
     @entrypoint
     def handler(data, context):
-        assert context is None
+        assert isinstance(context, RequestContext)
+        assert context.session_id is None  # No Session-Id header provided
         return {"echo": data}
 
     fastapi_app = globals()["fastapi_app"]
@@ -394,9 +397,7 @@ async def handler(data):
     fastapi_app = globals()["fastapi_app"]
     with TestClient(fastapi_app) as client:
         # With evaluation-id header, response should NOT be streamed
-        r = client.post(
-            "/run", json={"test": 1}, headers={"evaluation-id": "eval-123"}
-        )
+        r = client.post("/run", json={"test": 1}, headers={"evaluation-id": "eval-123"})
         assert r.status_code == 200
         # Response should be the complete collected output
         assert r.json() == "hello world"
@@ -450,9 +451,7 @@ async def handler(data):
 
     fastapi_app = globals()["fastapi_app"]
     with TestClient(fastapi_app) as client:
-        r = client.post(
-            "/run", json={"test": 1}, headers={"evaluation-id": "eval-123"}
-        )
+        r = client.post("/run", json={"test": 1}, headers={"evaluation-id": "eval-123"})
         assert r.status_code == 200
         # Dict chunks should be JSON serialized and concatenated
         result = r.json()
@@ -475,9 +474,7 @@ async def handler(data):
 
     fastapi_app = globals()["fastapi_app"]
     with TestClient(fastapi_app) as client:
-        r = client.post(
-            "/run", json={"test": 1}, headers={"evaluation-id": "eval-123"}
-        )
+        r = client.post("/run", json={"test": 1}, headers={"evaluation-id": "eval-123"})
         assert r.status_code == 200
         assert r.json() == "ab"  # None skipped
         assert r.headers.get("X-Gradient-Trace-Id") == "test-trace-id-12345"
@@ -525,4 +522,193 @@ def handler(data, context):
     assert calls["host"] == "127.0.0.1"
     assert calls["port"] == 9999
     assert calls["kwargs"]["reload"] is True
-    assert calls["kwargs"]["log_level"] == "debug"
\ No newline at end of file
+    assert calls["kwargs"]["log_level"] == "debug"
+
+
+# ----------------------------------
+# Tests for RequestContext and session_id
+# ----------------------------------
+
+
+def test_request_context_dataclass():
+    """Test that RequestContext is a proper dataclass with expected fields."""
+    # Default values
+    ctx = RequestContext()
+    assert ctx.session_id is None
+
+    # With session_id
+    ctx = RequestContext(session_id="test-session-123")
+    assert ctx.session_id == "test-session-123"
+
+
+def test_session_id_header_passed_to_context_sync(patch_helpers):
+    """Test that Session-Id header is passed to context for sync handler."""
+    tracker = patch_helpers
+    captured_context = {}
+
+    @entrypoint
+    def handler(data, context):
+        captured_context["context"] = context
+        return {"session_id": context.session_id}
+
+    fastapi_app = globals()["fastapi_app"]
+    with TestClient(fastapi_app) as client:
+        r = client.post(
+            "/run",
+            json={"test": 1},
+            headers={"Session-Id": "my-session-abc"},
+        )
+        assert r.status_code == 200
+        assert r.json() == {"session_id": "my-session-abc"}
+
+    # Verify the captured context
+    assert isinstance(captured_context["context"], RequestContext)
+    assert captured_context["context"].session_id == "my-session-abc"
+
+
+def test_session_id_header_case_insensitive(patch_helpers):
+    """Test that Session-Id header is case-insensitive (lowercase works)."""
+    tracker = patch_helpers
+
+    @entrypoint
+    def handler(data, context):
+        return {"session_id": context.session_id}
+
+    fastapi_app = globals()["fastapi_app"]
+    with TestClient(fastapi_app) as client:
+        # Test with lowercase header name
+        r = client.post(
+            "/run",
+            json={"test": 1},
+            headers={"session-id": "lowercase-session"},
+        )
+        assert r.status_code == 200
+        assert r.json() == {"session_id": "lowercase-session"}
+
+
+@pytest.mark.asyncio
+async def test_session_id_header_passed_to_context_async(patch_helpers):
+    """Test that Session-Id header is passed to context for async handler."""
+    tracker = patch_helpers
+
+    @entrypoint
+    async def handler(data, context):
+        await asyncio.sleep(0)
+        return {"session_id": context.session_id}
+
+    fastapi_app = globals()["fastapi_app"]
+    with TestClient(fastapi_app) as client:
+        r = client.post(
+            "/run",
+            json={"test": 1},
+            headers={"Session-Id": "async-session-xyz"},
+        )
+        assert r.status_code == 200
+        assert r.json() == {"session_id": "async-session-xyz"}
+
+
+def test_session_id_header_passed_to_streaming_handler(patch_helpers):
+    """Test that Session-Id header is passed to context for streaming handler."""
+    tracker = patch_helpers
+    captured_session_id = {}
+
+    @entrypoint
+    async def handler(data, context):
+        captured_session_id["value"] = context.session_id
+        yield f"session:{context.session_id}"
+
+    fastapi_app = globals()["fastapi_app"]
+    with TestClient(fastapi_app) as client:
+        with client.stream(
+            "POST",
+            "/run",
+            json={"test": 1},
+            headers={"Session-Id": "streaming-session"},
+        ) as resp:
+            assert resp.status_code == 200
+            body = "".join(chunk for chunk in resp.iter_text())
+            assert body == "session:streaming-session"
+
+    assert captured_session_id["value"] == "streaming-session"
+
+
+def test_session_id_none_when_header_not_provided(patch_helpers):
+    """Test that session_id is None when header is not provided."""
+    tracker = patch_helpers
+
+    @entrypoint
+    def handler(data, context):
+        return {
+            "has_context": context is not None,
+            "session_id": context.session_id,
+        }
+
+    fastapi_app = globals()["fastapi_app"]
+    with TestClient(fastapi_app) as client:
+        r = client.post("/run", json={"test": 1})
+        assert r.status_code == 200
+        result = r.json()
+        assert result["has_context"] is True
+        assert result["session_id"] is None
+
+
+def test_session_id_empty_string_when_header_is_empty(patch_helpers):
+    """Test that session_id is empty string when header value is empty."""
+    tracker = patch_helpers
+
+    @entrypoint
+    def handler(data, context):
+        return {"session_id": context.session_id}
+
+    fastapi_app = globals()["fastapi_app"]
+    with TestClient(fastapi_app) as client:
+        r = client.post(
+            "/run",
+            json={"test": 1},
+            headers={"Session-Id": ""},
+        )
+        assert r.status_code == 200
+        assert r.json() == {"session_id": ""}
+
+
+def test_session_id_passed_to_tracker(patch_helpers):
+    """Test that session_id is passed to the tracker for tracing."""
+    tracker = patch_helpers
+
+    @entrypoint
+    def handler(data, context):
+        return {"ok": True}
+
+    fastapi_app = globals()["fastapi_app"]
+    with TestClient(fastapi_app) as client:
+        r = client.post(
+            "/run",
+            json={"test": 1},
+            headers={"Session-Id": "trace-session-456"},
+        )
+        assert r.status_code == 200
+
+    # Verify that session_id was passed to the tracker
+    assert tracker.started
+    # started tuple is (name, inputs, is_evaluation, session_id)
+    assert tracker.started[-1][3] == "trace-session-456"
+    assert tracker._session_id == "trace-session-456"
+
+
+def test_session_id_none_passed_to_tracker_when_no_header(patch_helpers):
+    """Test that session_id is None in tracker when header is not provided."""
+    tracker = patch_helpers
+
+    @entrypoint
+    def handler(data, context):
+        return {"ok": True}
+
+    fastapi_app = globals()["fastapi_app"]
+    with TestClient(fastapi_app) as client:
+        r = client.post("/run", json={"test": 1})
+        assert r.status_code == 200
+
+    # Verify that session_id is None when header not provided
+    assert tracker.started
+    assert tracker.started[-1][3] is None
+    assert tracker._session_id is None
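The case-insensitive lookups these tests rely on come from Starlette's `Headers` mapping (what FastAPI exposes as `Request.headers`), not from anything added in this patch; a standalone check:

```python
from starlette.datastructures import Headers

h = Headers({"Session-Id": "abc"})
# Header lookups lowercase the key, so the runtime's
# req.headers.get("session-id") in decorator.py matches
# Session-Id, session-id, SESSION-ID, etc.
assert h.get("session-id") == "abc"
assert h.get("SESSION-ID") == "abc"
```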
diff --git a/tests/runtime/digitalocean_tracker_test.py b/tests/runtime/digitalocean_tracker_test.py
index 4237ea5..aad7d84 100644
--- a/tests/runtime/digitalocean_tracker_test.py
+++ b/tests/runtime/digitalocean_tracker_test.py
@@ -315,6 +315,60 @@ async def test_evaluation_mode_does_not_fire_and_forget(self, tracker, mock_clie
         assert mock_client.create_traces.called
 
 
+class TestSessionId:
+    """Test session_id handling in traces."""
+
+    @pytest.mark.asyncio
+    async def test_session_id_passed_to_create_traces(self, tracker, mock_client):
+        """Test that session_id is passed to CreateTracesInput."""
+        tracker.on_request_start(
+            "agent", {"input": "test"}, is_evaluation=False, session_id="sess-abc-123"
+        )
+        tracker.on_request_end(outputs={"result": "ok"}, error=None)
+        await tracker.aclose()
+
+        assert mock_client.create_traces.called
+        call_args = mock_client.create_traces.call_args[0][0]
+        assert isinstance(call_args, CreateTracesInput)
+        assert call_args.session_id == "sess-abc-123"
+
+    @pytest.mark.asyncio
+    async def test_session_id_none_when_not_provided(self, tracker, mock_client):
+        """Test that session_id is None when not provided."""
+        tracker.on_request_start("agent", {"input": "test"}, is_evaluation=False)
+        tracker.on_request_end(outputs={"result": "ok"}, error=None)
+        await tracker.aclose()
+
+        assert mock_client.create_traces.called
+        call_args = mock_client.create_traces.call_args[0][0]
+        assert call_args.session_id is None
+
+    @pytest.mark.asyncio
+    async def test_session_id_preserved_across_request(self, tracker, mock_client):
+        """Test that session_id is preserved throughout the request lifecycle."""
+        tracker.on_request_start(
+            "agent", {"input": "test"}, is_evaluation=True, session_id="eval-session"
+        )
+
+        node = NodeExecution(
+            node_id="node-1",
+            node_name="process",
+            framework="langgraph",
+            start_time=datetime.now(timezone.utc),
+            inputs={"test": "data"},
+        )
+        tracker.on_node_start(node)
+        tracker.on_node_end(node, {"result": "done"})
+        tracker.on_request_end(outputs={"final": "result"}, error=None)
+
+        # Use submit_and_get_trace_id for evaluation mode
+        await tracker.submit_and_get_trace_id()
+
+        assert mock_client.create_traces.called
+        call_args = mock_client.create_traces.call_args[0][0]
+        assert call_args.session_id == "eval-session"
+
+
 class TestTopLevelIO:
     """Test top-level input/output normalization."""
 
@@ -719,4 +773,4 @@ async def mock_stream():
 
     # After streaming, outputs should be filled
     assert tracker._req.get("outputs") is not None
-    assert "test" in str(tracker._req.get("outputs"))
+    assert "test" in str(tracker._req.get("outputs"))
\ No newline at end of file
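The tracker-side contract exercised above is small: `on_request_start` now accepts an optional `session_id`, which is stored on the tracker and later attached to `CreateTracesInput` at submit time. A minimal stand-in satisfying the same hook signatures (mirroring the test tracker stub in `tests/decorator_test.py`; the class name here is illustrative):

```python
from typing import Any, Dict, Optional


class RecordingTracker:
    """Minimal stand-in for the per-request hooks the ADK runtime calls."""

    def __init__(self) -> None:
        self._req: Dict[str, Any] = {}
        self._session_id: Optional[str] = None

    def on_request_start(
        self,
        entrypoint: str,
        inputs: Dict[str, Any],
        is_evaluation: bool = False,
        session_id: Optional[str] = None,
    ) -> None:
        # Reset per-request state; remember the session for trace export.
        self._session_id = session_id
        self._req = {"entrypoint": entrypoint, "inputs": inputs}

    def on_request_end(self, outputs: Any = None, error: Any = None) -> None:
        self._req["outputs"] = outputs


tr = RecordingTracker()
tr.on_request_start("main", {"prompt": "hi"}, session_id="sess-abc-123")
tr.on_request_end(outputs={"echo": "hi"})
assert tr._session_id == "sess-abc-123"
```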