diff --git a/gradient_adk/decorator.py b/gradient_adk/decorator.py
index da70631..8031eda 100644
--- a/gradient_adk/decorator.py
+++ b/gradient_adk/decorator.py
@@ -166,7 +166,49 @@ async def run(req: Request):
             logger.error("Error creating generator", error=str(e), exc_info=True)
             raise HTTPException(status_code=500, detail="Internal server error")
 
-        # Wrap in tracking iterator
+        # If evaluation mode, collect all chunks and return as single response with trace ID
+        if is_evaluation:
+            from fastapi.responses import JSONResponse
+
+            collected_chunks: List[str] = []
+            try:
+                async for chunk in user_gen:
+                    if isinstance(chunk, bytes):
+                        chunk_str = chunk.decode("utf-8", errors="replace")
+                    elif isinstance(chunk, dict):
+                        chunk_str = json.dumps(chunk)
+                    elif chunk is None:
+                        continue
+                    else:
+                        chunk_str = str(chunk)
+                    collected_chunks.append(chunk_str)
+
+                result = "".join(collected_chunks)
+
+                # Submit tracking and get trace ID
+                trace_id = None
+                if tr:
+                    try:
+                        tr._req["outputs"] = result
+                        trace_id = await tr.submit_and_get_trace_id()
+                    except Exception:
+                        pass
+
+                headers = {"X-Gradient-Trace-Id": trace_id} if trace_id else {}
+                return JSONResponse(content=result, headers=headers)
+
+            except Exception as e:
+                if tr:
+                    try:
+                        tr._req["outputs"] = "".join(collected_chunks)
+                        tr._req["error"] = str(e)
+                        await tr._submit()
+                    except Exception:
+                        pass
+                logger.error("Error in streaming evaluation", error=str(e), exc_info=True)
+                raise HTTPException(status_code=500, detail="Internal server error")
+
+        # Normal streaming case - wrap in tracking iterator
         streaming_iter = _StreamingIteratorWithTracking(user_gen, tr, func.__name__)
 
         return FastAPIStreamingResponse(
@@ -234,4 +276,4 @@ async def health():
 
 def run_server(fastapi_app: FastAPI, host: str = "0.0.0.0", port: int = 8080, **kwargs):
     """Run the FastAPI server with uvicorn."""
-    uvicorn.run(fastapi_app, host=host, port=port, **kwargs)
+    uvicorn.run(fastapi_app, host=host, port=port, **kwargs)
\ No newline at end of file
diff --git a/integration_tests/example_agents/streaming_echo_agent/main.py b/integration_tests/example_agents/streaming_echo_agent/main.py
new file mode 100644
index 0000000..4750d09
--- /dev/null
+++ b/integration_tests/example_agents/streaming_echo_agent/main.py
@@ -0,0 +1,17 @@
+"""
+Streaming echo agent for integration testing.
+Does not make any external API calls - just echoes back the input in chunks.
+Used to test streaming vs non-streaming behavior with evaluation-id header.
+"""
+
+from gradient_adk import entrypoint
+
+
+@entrypoint
+async def main(query, context):
+    """Streaming echo agent - yields the response in chunks."""
+    prompt = query.get("prompt", "no prompt provided")
+    # Stream the response in multiple chunks
+    yield "Echo: "
+    yield prompt
+    yield " [DONE]"
\ No newline at end of file
diff --git a/integration_tests/run/test_adk_agents_run.py b/integration_tests/run/test_adk_agents_run.py
index 156cb08..fc48232 100644
--- a/integration_tests/run/test_adk_agents_run.py
+++ b/integration_tests/run/test_adk_agents_run.py
@@ -59,6 +59,11 @@ def echo_agent_dir(self):
         """Get the path to the echo agent directory."""
         return Path(__file__).parent.parent / "example_agents" / "echo_agent"
 
+    @pytest.fixture
+    def streaming_echo_agent_dir(self):
+        """Get the path to the streaming echo agent directory."""
+        return Path(__file__).parent.parent / "example_agents" / "streaming_echo_agent"
+
     @pytest.fixture
     def setup_agent_in_temp(self, echo_agent_dir):
         """
@@ -86,6 +91,33 @@ def setup_agent_in_temp(self, echo_agent_dir):
 
         yield temp_path
 
+    @pytest.fixture
+    def setup_streaming_agent_in_temp(self, streaming_echo_agent_dir):
+        """
+        Setup a temporary directory with the streaming echo agent and proper configuration.
+        Yields the temp directory path and cleans up after.
+        """
+        with tempfile.TemporaryDirectory() as temp_dir:
+            temp_path = Path(temp_dir)
+
+            # Copy the streaming echo agent main.py
+            shutil.copy(streaming_echo_agent_dir / "main.py", temp_path / "main.py")
+
+            # Create .gradient directory and config
+            gradient_dir = temp_path / ".gradient"
+            gradient_dir.mkdir()
+
+            config = {
+                "agent_name": "test-streaming-echo-agent",
+                "agent_environment": "main",
+                "entrypoint_file": "main.py",
+            }
+
+            with open(gradient_dir / "agent.yml", "w") as f:
+                yaml.safe_dump(config, f)
+
+            yield temp_path
+
     @pytest.mark.cli
     def test_agent_run_happy_path(self, setup_agent_in_temp):
         """
@@ -386,5 +418,135 @@ def test_agent_run_run_endpoint_with_various_inputs(self, setup_agent_in_temp):
                 assert data["echo"] == "Hello `} E1-('"
                 logger.info("Unicode test passed")
 
+        finally:
+            cleanup_process(process)
+
+    @pytest.mark.cli
+    def test_streaming_agent_without_evaluation_id_streams_response(
+        self, setup_streaming_agent_in_temp
+    ):
+        """
+        Test that a streaming agent returns a streamed response when no evaluation-id header is sent.
+        Verifies:
+        - Response is streamed (text/event-stream content type)
+        - Response contains the expected content
+        """
+        logger = logging.getLogger(__name__)
+        temp_dir = setup_streaming_agent_in_temp
+        port = find_free_port()
+        process = None
+
+        try:
+            logger.info(f"Starting streaming agent on port {port} in {temp_dir}")
+
+            # Start the agent server
+            process = subprocess.Popen(
+                [
+                    "gradient",
+                    "agent",
+                    "run",
+                    "--port",
+                    str(port),
+                    "--no-dev",
+                ],
+                cwd=temp_dir,
+                start_new_session=True,
+            )
+
+            # Wait for server to be ready
+            server_ready = wait_for_server(port, timeout=30)
+            assert server_ready, "Server did not start within timeout"
+
+            # Make a streaming request WITHOUT evaluation-id header
+            with requests.post(
+                f"http://localhost:{port}/run",
+                json={"prompt": "Hello, World!"},
+                stream=True,
+                timeout=30,
+            ) as response:
+                assert response.status_code == 200
+
+                # Verify it's a streaming response (text/event-stream)
+                content_type = response.headers.get("content-type", "")
+                assert "text/event-stream" in content_type, (
+                    f"Expected text/event-stream content type for streaming, got: {content_type}"
+                )
+
+                # Collect chunks to verify content
+                chunks = list(response.iter_content(decode_unicode=True))
+                full_content = "".join(c for c in chunks if c)
+
+                # Verify the content contains the expected streamed output
+                assert "Echo:" in full_content or "Hello, World!" in full_content, (
+                    f"Expected streamed content to contain prompt, got: {full_content}"
+                )
+
+                logger.info(f"Streaming response received with {len(chunks)} chunks")
+                logger.info(f"Full content: {full_content}")
+
+        finally:
+            cleanup_process(process)
+
+    @pytest.mark.cli
+    def test_streaming_agent_with_evaluation_id_returns_single_response(
+        self, setup_streaming_agent_in_temp
+    ):
+        """
+        Test that a streaming agent returns a single JSON response (not streamed)
+        when the evaluation-id header is present.
+        Verifies:
+        - Response is NOT streamed (application/json content type)
+        - Response contains the complete collected content
+        """
+        logger = logging.getLogger(__name__)
+        temp_dir = setup_streaming_agent_in_temp
+        port = find_free_port()
+        process = None
+
+        try:
+            logger.info(f"Starting streaming agent on port {port} in {temp_dir}")
+
+            # Start the agent server
+            process = subprocess.Popen(
+                [
+                    "gradient",
+                    "agent",
+                    "run",
+                    "--port",
+                    str(port),
+                    "--no-dev",
+                ],
+                cwd=temp_dir,
+                start_new_session=True,
+            )
+
+            # Wait for server to be ready
+            server_ready = wait_for_server(port, timeout=30)
+            assert server_ready, "Server did not start within timeout"
+
+            # Make a request WITH evaluation-id header
+            response = requests.post(
+                f"http://localhost:{port}/run",
+                json={"prompt": "Hello, World!"},
+                headers={"evaluation-id": "test-eval-123"},
+                timeout=30,
+            )
+            assert response.status_code == 200
+
+            # Verify it's NOT a streaming response (should be application/json)
+            content_type = response.headers.get("content-type", "")
+            assert "application/json" in content_type, (
+                f"Expected application/json content type for evaluation mode, got: {content_type}"
+            )
+
+            # Verify the response contains the complete content
+            result = response.json()
+            expected_content = "Echo: Hello, World! [DONE]"
+            assert result == expected_content, (
+                f"Expected complete collected content '{expected_content}', got: {result}"
+            )
+
+            logger.info(f"Single JSON response received: {result}")
+
         finally:
             cleanup_process(process)
\ No newline at end of file
diff --git a/tests/decorator_test.py b/tests/decorator_test.py
index 9d993cc..6a16d0e 100644
--- a/tests/decorator_test.py
+++ b/tests/decorator_test.py
@@ -31,6 +31,7 @@ def __init__(self):
         self.ended = []
         self.closed = False
         self._req = {}
+        self._submitted_trace_id = None
 
     def on_request_start(self, name, inputs, is_evaluation=False):
         self.started.append((name, inputs, is_evaluation))
@@ -45,6 +46,12 @@ async def _submit(self):
         """Simulate async submission."""
         await asyncio.sleep(0)
 
+    async def submit_and_get_trace_id(self):
+        """Simulate async submission and return trace ID."""
+        await asyncio.sleep(0)
+        self._submitted_trace_id = "test-trace-id-12345"
+        return self._submitted_trace_id
+
     async def aclose(self):
         """Simulate async close."""
         await asyncio.sleep(0)
@@ -372,6 +379,110 @@ def handler(data, context):
     assert tracker.started[-1][2] is True  # is_evaluation flag
 
 
+def test_streaming_with_evaluation_id_collects_and_returns_complete_response(
+    patch_helpers,
+):
+    """Test that streaming with evaluation-id collects all chunks and returns complete response."""
+    tracker = patch_helpers
+
+    @entrypoint
+    async def handler(data):
+        yield "hello"
+        yield " "
+        yield "world"
+
+    fastapi_app = globals()["fastapi_app"]
+    with TestClient(fastapi_app) as client:
+        # With evaluation-id header, response should NOT be streamed
+        r = client.post(
+            "/run", json={"test": 1}, headers={"evaluation-id": "eval-123"}
+        )
+        assert r.status_code == 200
+        # Response should be the complete collected output
+        assert r.json() == "hello world"
+        # Trace ID should be in response headers
+        assert r.headers.get("X-Gradient-Trace-Id") == "test-trace-id-12345"
+
+    # Check that is_evaluation was passed correctly
+    assert tracker.started
+    assert tracker.started[-1][2] is True  # is_evaluation flag
+    # Tracker should have the collected output
+    assert tracker._req.get("outputs") == "hello world"
+    # submit_and_get_trace_id should have been called
+    assert tracker._submitted_trace_id == "test-trace-id-12345"
+
+
+def test_streaming_without_evaluation_id_streams_normally(patch_helpers):
+    """Test that streaming without evaluation-id continues to stream normally."""
+    tracker = patch_helpers
+
+    @entrypoint
+    async def handler(data):
+        yield "hello"
+        yield " "
+        yield "world"
+
+    fastapi_app = globals()["fastapi_app"]
+    with TestClient(fastapi_app) as client:
+        # Without evaluation-id header, response should be streamed
+        with client.stream("POST", "/run", json={"test": 1}) as resp:
+            assert resp.status_code == 200
+            # Read the full stream by iterating
+            body = "".join(chunk for chunk in resp.iter_text())
+            assert body == "hello world"
+
+    # Check that is_evaluation was passed as False
+    assert tracker.started
+    assert tracker.started[-1][2] is False  # is_evaluation flag
+    # submit_and_get_trace_id should NOT have been called (normal streaming)
+    assert tracker._submitted_trace_id is None
+
+
+def test_streaming_with_evaluation_id_handles_dict_chunks(patch_helpers):
+    """Test that streaming with evaluation-id properly handles dict chunks."""
+    tracker = patch_helpers
+
+    @entrypoint
+    async def handler(data):
+        yield {"type": "start"}
+        yield {"type": "data", "value": 42}
+        yield {"type": "end"}
+
+    fastapi_app = globals()["fastapi_app"]
+    with TestClient(fastapi_app) as client:
+        r = client.post(
+            "/run", json={"test": 1}, headers={"evaluation-id": "eval-123"}
+        )
+        assert r.status_code == 200
+        # Dict chunks should be JSON serialized and concatenated
+        result = r.json()
+        assert '{"type": "start"}' in result
+        assert '{"type": "data", "value": 42}' in result
+        assert '{"type": "end"}' in result
+        # Trace ID should be in response headers
+        assert r.headers.get("X-Gradient-Trace-Id") == "test-trace-id-12345"
+
+
+def test_streaming_with_evaluation_id_skips_none_chunks(patch_helpers):
+    """Test that streaming with evaluation-id properly skips None chunks."""
+    tracker = patch_helpers
+
+    @entrypoint
+    async def handler(data):
+        yield "a"
+        yield None  # Should be skipped
+        yield "b"
+
+    fastapi_app = globals()["fastapi_app"]
+    with TestClient(fastapi_app) as client:
+        r = client.post(
+            "/run", json={"test": 1}, headers={"evaluation-id": "eval-123"}
+        )
+        assert r.status_code == 200
+        assert r.json() == "ab"  # None skipped
+        assert r.headers.get("X-Gradient-Trace-Id") == "test-trace-id-12345"
+
+
 def test_shutdown_event_calls_tracker_aclose(patch_helpers):
     """Test that shutdown event calls tracker aclose."""
     tracker = patch_helpers
@@ -414,4 +525,4 @@ def handler(data, context):
     assert calls["host"] == "127.0.0.1"
     assert calls["port"] == 9999
     assert calls["kwargs"]["reload"] is True
-    assert calls["kwargs"]["log_level"] == "debug"
+    assert calls["kwargs"]["log_level"] == "debug"
\ No newline at end of file
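
Usage note (illustrative sketch, not part of the patch): the change above makes the /run endpoint's behavior depend on the evaluation-id request header. A minimal client-side sketch of the two paths, assuming an agent served locally on run_server's default port 8080; the URL, prompt, and evaluation ID below are placeholder values:

import requests

BASE_URL = "http://localhost:8080/run"  # assumes run_server's default port

# Without the evaluation-id header, the agent streams chunks as text/event-stream.
with requests.post(BASE_URL, json={"prompt": "Hello"}, stream=True, timeout=30) as r:
    assert "text/event-stream" in r.headers.get("content-type", "")
    streamed = "".join(c for c in r.iter_content(decode_unicode=True) if c)

# With the evaluation-id header, the server collects all chunks and returns a
# single JSON string; the trace ID (when tracking succeeds) arrives in a header.
r = requests.post(
    BASE_URL,
    json={"prompt": "Hello"},
    headers={"evaluation-id": "my-eval-run"},  # placeholder evaluation ID
    timeout=30,
)
assert "application/json" in r.headers.get("content-type", "")
collected = r.json()  # e.g. "Echo: Hello [DONE]" for the streaming echo agent
trace_id = r.headers.get("X-Gradient-Trace-Id")  # absent if tracking was skipped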