Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 44 additions & 2 deletions gradient_adk/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,49 @@ async def run(req: Request):
logger.error("Error creating generator", error=str(e), exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")

# Wrap in tracking iterator
# If evaluation mode, collect all chunks and return as single response with trace ID
if is_evaluation:
from fastapi.responses import JSONResponse

collected_chunks: List[str] = []
try:
async for chunk in user_gen:
if isinstance(chunk, bytes):
chunk_str = chunk.decode("utf-8", errors="replace")
elif isinstance(chunk, dict):
chunk_str = json.dumps(chunk)
elif chunk is None:
continue
else:
chunk_str = str(chunk)
collected_chunks.append(chunk_str)

result = "".join(collected_chunks)

# Submit tracking and get trace ID
trace_id = None
if tr:
try:
tr._req["outputs"] = result
trace_id = await tr.submit_and_get_trace_id()
except Exception:
pass

headers = {"X-Gradient-Trace-Id": trace_id} if trace_id else {}
return JSONResponse(content=result, headers=headers)

except Exception as e:
if tr:
try:
tr._req["outputs"] = "".join(collected_chunks)
tr._req["error"] = str(e)
await tr._submit()
except Exception:
pass
logger.error("Error in streaming evaluation", error=str(e), exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")

# Normal streaming case - wrap in tracking iterator
streaming_iter = _StreamingIteratorWithTracking(user_gen, tr, func.__name__)

return FastAPIStreamingResponse(
Expand Down Expand Up @@ -234,4 +276,4 @@ async def health():

def run_server(fastapi_app: FastAPI, host: str = "0.0.0.0", port: int = 8080, **kwargs):
    """Serve *fastapi_app* with uvicorn on the given host and port.

    Any extra keyword arguments are forwarded unchanged to ``uvicorn.run``.
    """
    uvicorn.run(fastapi_app, host=host, port=port, **kwargs)
17 changes: 17 additions & 0 deletions integration_tests/example_agents/streaming_echo_agent/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
Streaming echo agent for integration testing.
Does not make any external API calls - just echoes back the input in chunks.
Used to test streaming vs non-streaming behavior with evaluation-id header.
"""

from gradient_adk import entrypoint


@entrypoint
async def main(query, context):
    """Echo the incoming prompt back as a short stream of chunks.

    Yields three pieces — a prefix, the prompt itself, and a terminator —
    so callers can exercise streaming behavior without external API calls.
    """
    text = query.get("prompt", "no prompt provided")
    for piece in ("Echo: ", text, " [DONE]"):
        yield piece
162 changes: 162 additions & 0 deletions integration_tests/run/test_adk_agents_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ def echo_agent_dir(self):
"""Get the path to the echo agent directory."""
return Path(__file__).parent.parent / "example_agents" / "echo_agent"

@pytest.fixture
def streaming_echo_agent_dir(self):
    """Resolve the directory holding the streaming echo example agent."""
    agents_root = Path(__file__).parent.parent / "example_agents"
    return agents_root / "streaming_echo_agent"

@pytest.fixture
def setup_agent_in_temp(self, echo_agent_dir):
"""
Expand Down Expand Up @@ -86,6 +91,33 @@ def setup_agent_in_temp(self, echo_agent_dir):

yield temp_path

@pytest.fixture
def setup_streaming_agent_in_temp(self, streaming_echo_agent_dir):
    """
    Stage the streaming echo agent in a throwaway directory with its config.

    Yields the temporary directory path; everything is removed on teardown.
    """
    with tempfile.TemporaryDirectory() as workdir:
        workdir_path = Path(workdir)

        # Bring over the agent entrypoint.
        shutil.copy(streaming_echo_agent_dir / "main.py", workdir_path / "main.py")

        # Write the .gradient/agent.yml configuration the CLI expects.
        config_dir = workdir_path / ".gradient"
        config_dir.mkdir()
        agent_config = {
            "agent_name": "test-streaming-echo-agent",
            "agent_environment": "main",
            "entrypoint_file": "main.py",
        }
        with open(config_dir / "agent.yml", "w") as handle:
            yaml.safe_dump(agent_config, handle)

        yield workdir_path

@pytest.mark.cli
def test_agent_run_happy_path(self, setup_agent_in_temp):
"""
Expand Down Expand Up @@ -386,5 +418,135 @@ def test_agent_run_run_endpoint_with_various_inputs(self, setup_agent_in_temp):
assert data["echo"] == "Hello `} E1-('"
logger.info("Unicode test passed")

finally:
cleanup_process(process)

@pytest.mark.cli
def test_streaming_agent_without_evaluation_id_streams_response(
    self, setup_streaming_agent_in_temp
):
    """
    Test that a streaming agent returns a streamed response when no evaluation-id header is sent.
    Verifies:
    - Response is streamed (text/event-stream content type)
    - Response contains the expected content
    """
    log = logging.getLogger(__name__)
    workdir = setup_streaming_agent_in_temp
    port = find_free_port()
    process = None

    try:
        log.info(f"Starting streaming agent on port {port} in {workdir}")

        # Launch the agent server as a detached subprocess.
        launch_cmd = ["gradient", "agent", "run", "--port", str(port), "--no-dev"]
        process = subprocess.Popen(launch_cmd, cwd=workdir, start_new_session=True)

        assert wait_for_server(port, timeout=30), "Server did not start within timeout"

        # Streaming request, deliberately WITHOUT the evaluation-id header.
        with requests.post(
            f"http://localhost:{port}/run",
            json={"prompt": "Hello, World!"},
            stream=True,
            timeout=30,
        ) as response:
            assert response.status_code == 200

            # Streamed responses are served as text/event-stream.
            content_type = response.headers.get("content-type", "")
            assert "text/event-stream" in content_type, (
                f"Expected text/event-stream content type for streaming, got: {content_type}"
            )

            # Drain the stream and reassemble it to verify the payload.
            chunks = [c for c in response.iter_content(decode_unicode=True)]
            full_content = "".join(filter(None, chunks))

            assert "Echo:" in full_content or "Hello, World!" in full_content, (
                f"Expected streamed content to contain prompt, got: {full_content}"
            )

            log.info(f"Streaming response received with {len(chunks)} chunks")
            log.info(f"Full content: {full_content}")

    finally:
        cleanup_process(process)

@pytest.mark.cli
def test_streaming_agent_with_evaluation_id_returns_single_response(
    self, setup_streaming_agent_in_temp
):
    """
    Test that a streaming agent returns a single JSON response (not streamed)
    when the evaluation-id header is present.
    Verifies:
    - Response is NOT streamed (application/json content type)
    - Response contains the complete collected content
    """
    log = logging.getLogger(__name__)
    workdir = setup_streaming_agent_in_temp
    port = find_free_port()
    process = None

    try:
        log.info(f"Starting streaming agent on port {port} in {workdir}")

        # Launch the agent server as a detached subprocess.
        launch_cmd = ["gradient", "agent", "run", "--port", str(port), "--no-dev"]
        process = subprocess.Popen(launch_cmd, cwd=workdir, start_new_session=True)

        assert wait_for_server(port, timeout=30), "Server did not start within timeout"

        # Same endpoint, but WITH the evaluation-id header set.
        response = requests.post(
            f"http://localhost:{port}/run",
            json={"prompt": "Hello, World!"},
            headers={"evaluation-id": "test-eval-123"},
            timeout=30,
        )
        assert response.status_code == 200

        # Evaluation mode must buffer the stream into one JSON body.
        content_type = response.headers.get("content-type", "")
        assert "application/json" in content_type, (
            f"Expected application/json content type for evaluation mode, got: {content_type}"
        )

        # The body is the concatenation of every chunk the agent yielded.
        result = response.json()
        expected_content = "Echo: Hello, World! [DONE]"
        assert result == expected_content, (
            f"Expected complete collected content '{expected_content}', got: {result}"
        )

        log.info(f"Single JSON response received: {result}")

    finally:
        cleanup_process(process)
113 changes: 112 additions & 1 deletion tests/decorator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(self):
self.ended = []
self.closed = False
self._req = {}
self._submitted_trace_id = None

def on_request_start(self, name, inputs, is_evaluation=False):
    # Record the request as a (name, inputs, is_evaluation) triple for later assertions.
    record = (name, inputs, is_evaluation)
    self.started.append(record)
Expand All @@ -45,6 +46,12 @@ async def _submit(self):
"""Simulate async submission."""
await asyncio.sleep(0)

async def submit_and_get_trace_id(self):
    """Simulate async submission and return trace ID."""
    await asyncio.sleep(0)  # yield control once, like a real network call would
    trace_id = "test-trace-id-12345"
    self._submitted_trace_id = trace_id
    return trace_id

async def aclose(self):
"""Simulate async close."""
await asyncio.sleep(0)
Expand Down Expand Up @@ -372,6 +379,110 @@ def handler(data, context):
assert tracker.started[-1][2] is True # is_evaluation flag


def test_streaming_with_evaluation_id_collects_and_returns_complete_response(
    patch_helpers,
):
    """Test that streaming with evaluation-id collects all chunks and returns complete response."""
    tracker = patch_helpers

    @entrypoint
    async def handler(data):
        for piece in ("hello", " ", "world"):
            yield piece

    app = globals()["fastapi_app"]
    with TestClient(app) as client:
        # Evaluation mode: the server buffers the stream into one JSON body.
        resp = client.post(
            "/run", json={"test": 1}, headers={"evaluation-id": "eval-123"}
        )
        assert resp.status_code == 200
        assert resp.json() == "hello world"
        # The trace ID surfaces via a response header.
        assert resp.headers.get("X-Gradient-Trace-Id") == "test-trace-id-12345"

    # Check that is_evaluation was passed correctly
    assert tracker.started
    assert tracker.started[-1][2] is True  # is_evaluation flag
    # Tracker should have the collected output
    assert tracker._req.get("outputs") == "hello world"
    # submit_and_get_trace_id should have been called
    assert tracker._submitted_trace_id == "test-trace-id-12345"


def test_streaming_without_evaluation_id_streams_normally(patch_helpers):
    """Test that streaming without evaluation-id continues to stream normally."""
    tracker = patch_helpers

    @entrypoint
    async def handler(data):
        for piece in ("hello", " ", "world"):
            yield piece

    app = globals()["fastapi_app"]
    with TestClient(app) as client:
        # No evaluation-id header, so the response stays a stream.
        with client.stream("POST", "/run", json={"test": 1}) as resp:
            assert resp.status_code == 200
            collected = []
            for fragment in resp.iter_text():
                collected.append(fragment)
            assert "".join(collected) == "hello world"

    # Check that is_evaluation was passed as False
    assert tracker.started
    assert tracker.started[-1][2] is False  # is_evaluation flag
    # submit_and_get_trace_id should NOT have been called (normal streaming)
    assert tracker._submitted_trace_id is None


def test_streaming_with_evaluation_id_handles_dict_chunks(patch_helpers):
    """Test that streaming with evaluation-id properly handles dict chunks."""
    tracker = patch_helpers

    @entrypoint
    async def handler(data):
        yield {"type": "start"}
        yield {"type": "data", "value": 42}
        yield {"type": "end"}

    app = globals()["fastapi_app"]
    with TestClient(app) as client:
        resp = client.post(
            "/run", json={"test": 1}, headers={"evaluation-id": "eval-123"}
        )
        assert resp.status_code == 200
        body = resp.json()
        # Each dict chunk is JSON-serialized then concatenated into one string.
        for expected in (
            '{"type": "start"}',
            '{"type": "data", "value": 42}',
            '{"type": "end"}',
        ):
            assert expected in body
        # Trace ID should be in response headers
        assert resp.headers.get("X-Gradient-Trace-Id") == "test-trace-id-12345"


def test_streaming_with_evaluation_id_skips_none_chunks(patch_helpers):
    """Test that streaming with evaluation-id properly skips None chunks."""
    tracker = patch_helpers

    @entrypoint
    async def handler(data):
        yield "a"
        yield None  # Should be skipped
        yield "b"

    app = globals()["fastapi_app"]
    with TestClient(app) as client:
        resp = client.post(
            "/run", json={"test": 1}, headers={"evaluation-id": "eval-123"}
        )
        assert resp.status_code == 200
        assert resp.json() == "ab"  # None skipped
        assert resp.headers.get("X-Gradient-Trace-Id") == "test-trace-id-12345"


def test_shutdown_event_calls_tracker_aclose(patch_helpers):
"""Test that shutdown event calls tracker aclose."""
tracker = patch_helpers
Expand Down Expand Up @@ -414,4 +525,4 @@ def handler(data, context):
assert calls["host"] == "127.0.0.1"
assert calls["port"] == 9999
assert calls["kwargs"]["reload"] is True
assert calls["kwargs"]["log_level"] == "debug"
assert calls["kwargs"]["log_level"] == "debug"