Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion gradient_adk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
and the CLI (gradient command).
"""

from .decorator import entrypoint
from .decorator import entrypoint, RequestContext
from .tracing import ( # manual tracing decorators
trace_llm,
trace_retriever,
Expand All @@ -12,6 +12,7 @@

__all__ = [
"entrypoint",
"RequestContext",
"trace_llm",
"trace_retriever",
"trace_tool",
Expand Down
30 changes: 26 additions & 4 deletions gradient_adk/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,21 @@
from __future__ import annotations
import inspect
import json
from dataclasses import dataclass
from typing import Callable, Optional, Any, Dict, List


@dataclass
class RequestContext:
    """Per-request metadata handed to an entrypoint function.

    Attributes:
        session_id: Session identifier taken from the incoming request's
            headers, or ``None`` when the caller supplied no session header.
    """

    # Defaults to None so entrypoints work for requests without a session.
    session_id: Optional[str] = None


from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
import uvicorn
Expand Down Expand Up @@ -144,12 +157,21 @@ async def run(req: Request):

is_evaluation = "evaluation-id" in req.headers

# Extract session ID from headers
session_id = req.headers.get("session-id")
context = RequestContext(session_id=session_id)

# Initialize tracker
tr = None
try:
tr = get_tracker()
if tr:
tr.on_request_start(func.__name__, body, is_evaluation=is_evaluation)
tr.on_request_start(
func.__name__,
body,
is_evaluation=is_evaluation,
session_id=session_id,
)
except Exception:
pass

Expand All @@ -159,7 +181,7 @@ async def run(req: Request):
if num_params == 1:
user_gen = func(body)
else:
user_gen = func(body, None)
user_gen = func(body, context)
except Exception as e:
if tr:
try:
Expand Down Expand Up @@ -232,9 +254,9 @@ async def run(req: Request):
result = func(body)
else:
if inspect.iscoroutinefunction(func):
result = await func(body, None)
result = await func(body, context)
else:
result = func(body, None)
result = func(body, context)
except Exception as e:
if tr:
try:
Expand Down
10 changes: 8 additions & 2 deletions gradient_adk/runtime/digitalocean_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,17 @@ def __init__(
self._is_evaluation: bool = False

def on_request_start(
    self,
    entrypoint: str,
    inputs: Dict[str, Any],
    is_evaluation: bool = False,
    session_id: Optional[str] = None,
) -> None:
    """Begin tracking a new request, resetting all per-request state.

    Args:
        entrypoint: Name of the entrypoint function handling the request.
        inputs: The request payload as parsed by the server.
        is_evaluation: True when the request carries an evaluation header.
        session_id: Optional session identifier from the request headers;
            stored for inclusion in the trace submission.
    """
    # Reset buffers so spans from a previous request never leak into this one.
    self._live.clear()
    self._done.clear()
    self._is_evaluation = is_evaluation
    # Always assign (even when None) so a stale session id from a prior
    # request cannot be attached to this request's traces.
    self._session_id = session_id
    self._req = {"entrypoint": entrypoint, "inputs": inputs}

def _as_async_iterable_and_setter(
Expand Down Expand Up @@ -299,13 +304,14 @@ async def _submit(self) -> Optional[str]:
agent_workspace_name=self._ws,
agent_deployment_name=self._dep,
traces=[trace],
session_id=getattr(self, "_session_id", None),
)
result = await self._client.create_traces(req)
# Return first trace_uuid if available
if result.trace_uuids:
return result.trace_uuids[0]
return None
except Exception as e:
except Exception:
Copy link

Copilot AI Jan 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The exception is being caught but the variable 'e' is removed. While silently handling exceptions may be intentional here, consider logging the exception for debugging purposes even if you don't want to break user code.

Copilot uses AI. Check for mistakes.
# never break user code on export errors
return None

Expand Down
10 changes: 7 additions & 3 deletions integration_tests/example_agents/echo_agent/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
Does not make any external API calls - just echoes back the input.
"""

from gradient_adk import entrypoint
from gradient_adk import entrypoint, RequestContext


@entrypoint
async def main(query, context: RequestContext):
    """Echo the incoming payload back to the caller, tagging the session id."""
    # Fall back to a placeholder when the caller omits "prompt".
    prompt = query.get("prompt", "no prompt provided")
    session_id = context.session_id if context else None
    return {"echo": prompt, "received": query, "session_id": session_id}
155 changes: 130 additions & 25 deletions integration_tests/run/test_adk_agents_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,10 @@ def test_agent_run_no_config(self):

# Check for helpful error message
combined_output = result.stdout + result.stderr
assert "error" in combined_output.lower() or "configuration" in combined_output.lower(), \
f"Expected error about missing configuration, got: {combined_output}"
assert (
"error" in combined_output.lower()
or "configuration" in combined_output.lower()
), f"Expected error about missing configuration, got: {combined_output}"
logger.info("Correctly failed without configuration")

@pytest.mark.cli
Expand Down Expand Up @@ -282,12 +284,17 @@ def test_agent_run_missing_entrypoint(self):
)

# Should fail
assert result.returncode != 0, "Command should have failed with missing entrypoint"
assert (
result.returncode != 0
), "Command should have failed with missing entrypoint"

# Check for helpful error message
combined_output = result.stdout + result.stderr
assert "error" in combined_output.lower() or "not exist" in combined_output.lower() or "nonexistent" in combined_output.lower(), \
f"Expected error about missing entrypoint, got: {combined_output}"
assert (
"error" in combined_output.lower()
or "not exist" in combined_output.lower()
or "nonexistent" in combined_output.lower()
), f"Expected error about missing entrypoint, got: {combined_output}"
logger.info("Correctly failed with missing entrypoint file")

@pytest.mark.cli
Expand All @@ -302,10 +309,12 @@ def test_agent_run_invalid_entrypoint_no_decorator(self):

# Create a Python file without @entrypoint decorator
main_py = temp_path / "main.py"
main_py.write_text("""
main_py.write_text(
"""
def main(query, context):
return {"result": "no decorator"}
""")
"""
)

# Create .gradient directory and config
gradient_dir = temp_path / ".gradient"
Expand All @@ -320,13 +329,22 @@ def main(query, context):
with open(gradient_dir / "agent.yml", "w") as f:
yaml.safe_dump(config, f)

logger.info(f"Testing agent run with invalid entrypoint (no decorator) in {temp_dir}")
logger.info(
f"Testing agent run with invalid entrypoint (no decorator) in {temp_dir}"
)

# Run gradient agent run
# This might start but fail to find fastapi_app, or fail on validation
# Either way it should not succeed
process = subprocess.Popen(
["gradient", "agent", "run", "--no-dev", "--port", str(find_free_port())],
[
"gradient",
"agent",
"run",
"--no-dev",
"--port",
str(find_free_port()),
],
cwd=temp_dir,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
Expand All @@ -339,15 +357,17 @@ def main(query, context):

# Check if process exited with error
return_code = process.poll()

if return_code is None:
# Process is still running - try to connect and see if it works
# (It shouldn't work properly without @entrypoint)
# If it's running but not responding properly, that's also a failure mode
logger.info("Process started but likely not functioning correctly")
else:
# Process exited - check return code
assert return_code != 0 or return_code is None, "Expected process to fail or not work correctly"
assert (
return_code != 0 or return_code is None
), "Expected process to fail or not work correctly"
logger.info(f"Process correctly exited with code {return_code}")
finally:
cleanup_process(process)
Expand Down Expand Up @@ -397,7 +417,11 @@ def test_agent_run_run_endpoint_with_various_inputs(self, setup_agent_in_temp):
# Test with additional fields
response = requests.post(
f"http://localhost:{port}/run",
json={"prompt": "test", "extra_field": "value", "nested": {"key": "val"}},
json={
"prompt": "test",
"extra_field": "value",
"nested": {"key": "val"},
},
timeout=10,
)
assert response.status_code == 200
Expand All @@ -421,6 +445,87 @@ def test_agent_run_run_endpoint_with_various_inputs(self, setup_agent_in_temp):
finally:
cleanup_process(process)

@pytest.mark.cli
def test_agent_run_session_id_header_passthrough(self, setup_agent_in_temp):
    """
    Test that the Session-Id header is passed to the agent context.
    Verifies:
    - Session-Id header is extracted from request
    - Session-Id is available in RequestContext
    - Agent can return session_id in response
    """
    logger = logging.getLogger(__name__)
    temp_dir = setup_agent_in_temp
    port = find_free_port()
    process = None

    try:
        logger.info(f"Starting agent on port {port} in {temp_dir}")

        # Start the agent server
        # start_new_session=True puts the server in its own process group so
        # cleanup_process can tear it (and any children) down reliably.
        process = subprocess.Popen(
            [
                "gradient",
                "agent",
                "run",
                "--port",
                str(port),
                "--no-dev",
            ],
            cwd=temp_dir,
            start_new_session=True,
        )

        # Wait for server to be ready
        server_ready = wait_for_server(port, timeout=30)
        assert server_ready, "Server did not start within timeout"

        # Test with Session-Id header: the echo agent returns the session id
        # it saw in its RequestContext, so the response proves passthrough.
        test_session_id = "test-session-12345"
        response = requests.post(
            f"http://localhost:{port}/run",
            json={"prompt": "Hello"},
            headers={"Session-Id": test_session_id},
            timeout=10,
        )
        assert response.status_code == 200
        data = response.json()
        assert (
            data["session_id"] == test_session_id
        ), f"Expected session_id '{test_session_id}', got '{data.get('session_id')}'"
        logger.info(f"Session-Id header passthrough test passed: {data}")

        # Test without Session-Id header (should be None)
        response = requests.post(
            f"http://localhost:{port}/run",
            json={"prompt": "Hello without session"},
            timeout=10,
        )
        assert response.status_code == 200
        data = response.json()
        assert (
            data["session_id"] is None
        ), f"Expected session_id to be None, got '{data.get('session_id')}'"
        logger.info("No Session-Id header test passed (session_id is None)")

        # Test with lowercase session-id header (case-insensitive).
        # HTTP header names are case-insensitive, so the server must match
        # "session-id" regardless of the casing the client uses.
        lowercase_session_id = "lowercase-session-abc"
        response = requests.post(
            f"http://localhost:{port}/run",
            json={"prompt": "Hello with lowercase header"},
            headers={"session-id": lowercase_session_id},
            timeout=10,
        )
        assert response.status_code == 200
        data = response.json()
        assert (
            data["session_id"] == lowercase_session_id
        ), f"Expected session_id '{lowercase_session_id}', got '{data.get('session_id')}'"
        logger.info("Lowercase session-id header test passed")

    finally:
        # Always kill the spawned server, even when an assertion fails above.
        cleanup_process(process)

@pytest.mark.cli
def test_streaming_agent_without_evaluation_id_streams_response(
self, setup_streaming_agent_in_temp
Expand Down Expand Up @@ -468,18 +573,18 @@ def test_streaming_agent_without_evaluation_id_streams_response(

# Verify it's a streaming response (text/event-stream)
content_type = response.headers.get("content-type", "")
assert "text/event-stream" in content_type, (
f"Expected text/event-stream content type for streaming, got: {content_type}"
)
assert (
"text/event-stream" in content_type
), f"Expected text/event-stream content type for streaming, got: {content_type}"

# Collect chunks to verify content
chunks = list(response.iter_content(decode_unicode=True))
full_content = "".join(c for c in chunks if c)

# Verify the content contains the expected streamed output
assert "Echo:" in full_content or "Hello, World!" in full_content, (
f"Expected streamed content to contain prompt, got: {full_content}"
)
assert (
"Echo:" in full_content or "Hello, World!" in full_content
), f"Expected streamed content to contain prompt, got: {full_content}"

logger.info(f"Streaming response received with {len(chunks)} chunks")
logger.info(f"Full content: {full_content}")
Expand Down Expand Up @@ -535,18 +640,18 @@ def test_streaming_agent_with_evaluation_id_returns_single_response(

# Verify it's NOT a streaming response (should be application/json)
content_type = response.headers.get("content-type", "")
assert "application/json" in content_type, (
f"Expected application/json content type for evaluation mode, got: {content_type}"
)
assert (
"application/json" in content_type
), f"Expected application/json content type for evaluation mode, got: {content_type}"

# Verify the response contains the complete content
result = response.json()
expected_content = "Echo: Hello, World! [DONE]"
assert result == expected_content, (
f"Expected complete collected content '{expected_content}', got: {result}"
)
assert (
result == expected_content
), f"Expected complete collected content '{expected_content}', got: {result}"

logger.info(f"Single JSON response received: {result}")

finally:
cleanup_process(process)
cleanup_process(process)
Loading