Merged

29 commits
01f7d1b
async models
pgrayy Jun 27, 2025
0fd2671
Merge branch 'main' of https://github.com/strands-agents/sdk-python i…
pgrayy Jun 30, 2025
ab2c088
lint
pgrayy Jun 30, 2025
3fd243e
tests
pgrayy Jun 30, 2025
d9deb93
agent - asyncio.run stream_async in call
pgrayy Jun 30, 2025
b864b65
tests - agenerator helper
pgrayy Jun 30, 2025
bd2adff
tests - agent - stream async result
pgrayy Jun 30, 2025
7ea90e9
lint
pgrayy Jun 30, 2025
3462e1d
agent - stream async - result
pgrayy Jun 30, 2025
8063c08
typing
pgrayy Jun 30, 2025
f8e58a1
tests - anext
pgrayy Jun 30, 2025
5bb0620
tests - alist
pgrayy Jun 30, 2025
760fcfb
lint
pgrayy Jun 30, 2025
61bb44d
tests - async utilities - scope session
pgrayy Jun 30, 2025
e47567f
tests integ - conftest
pgrayy Jun 30, 2025
90aaa47
Merge branch 'main' of https://github.com/strands-agents/sdk-python i…
pgrayy Jul 1, 2025
ccc44d9
lint
pgrayy Jul 1, 2025
e8c7bda
tests - async mock model provider
pgrayy Jul 1, 2025
17b24b3
lint
pgrayy Jul 1, 2025
6f46740
Merge branch 'main' of https://github.com/strands-agents/sdk-python i…
pgrayy Jul 2, 2025
1c667d9
async invoke and structured output
pgrayy Jul 2, 2025
21f80cc
thread asyncio run
pgrayy Jul 2, 2025
885f98d
test async threading
pgrayy Jul 2, 2025
773bef1
lint
pgrayy Jul 2, 2025
eb1ffd5
move invoke_async up for clarity
pgrayy Jul 2, 2025
2ce4581
Merge branch 'main' of https://github.com/strands-agents/sdk-python i…
pgrayy Jul 3, 2025
3a58ef4
lint
pgrayy Jul 3, 2025
780b13e
tests
pgrayy Jul 3, 2025
5761545
Merge branch 'main' of https://github.com/strands-agents/sdk-python i…
pgrayy Jul 3, 2025
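
The commit titles "agent - asyncio.run stream_async in call" and "thread asyncio run" name the bridging pattern this PR lands on: synchronous entry points drive the new async internals by running them to completion on an event loop. A minimal sketch of that idea, using hypothetical names (run_sync, coro_factory) that are not the SDK's actual helpers:

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Awaitable, Callable


def run_sync(coro_factory: Callable[[], Awaitable[Any]]) -> Any:
    """Run an async callable from sync code, even under a running loop."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop in this thread: run the coroutine right here.
        return asyncio.run(coro_factory())
    # A loop is already running, so asyncio.run would raise; run the
    # coroutine on a fresh loop in a worker thread and wait for it.
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(lambda: asyncio.run(coro_factory())).result()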
61 changes: 35 additions & 26 deletions src/strands/agent/agent.py
@@ -9,12 +9,13 @@
 2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")`
 """

+import asyncio
 import json
 import logging
 import os
 import random
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, AsyncIterator, Callable, Generator, Mapping, Optional, Type, TypeVar, Union, cast
+from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast

 from opentelemetry import trace
 from pydantic import BaseModel
@@ -381,21 +382,22 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
             - metrics: Performance metrics from the event loop
             - state: The final state of the event loop
         """
-        callback_handler = kwargs.get("callback_handler", self.callback_handler)
-
-        self._start_agent_trace_span(prompt)
+        async def acall() -> AgentResult:
+            callback_handler = kwargs.get("callback_handler", self.callback_handler)

-        try:
             events = self._run_loop(callback_handler, prompt, kwargs)
-            for event in events:
+            async for event in events:
                 if "callback" in event:
                     callback_handler(**event["callback"])

-            stop_reason, message, metrics, state = event["stop"]
-            result = AgentResult(stop_reason, message, metrics, state)
+            return AgentResult(*event["stop"])

-            self._end_agent_trace_span(response=result)
+        self._start_agent_trace_span(prompt)
+
+        try:
+            result = asyncio.run(acall())
+            self._end_agent_trace_span(response=result)
             return result

         except Exception as e:
@@ -417,6 +419,15 @@ def structured_output(self, output_model: Type[T], prompt: Optional[str] = None)
             that the agent will use when responding.
             prompt: The prompt to use for the agent.
         """
+
+        async def acall(messages: Messages) -> T:
+            events = self.model.structured_output(output_model, messages)
+            async for event in events:
+                if "callback" in event:
+                    self.callback_handler(**cast(dict, event["callback"]))
+
+            return event["output"]
+
         messages = self.messages
         if not messages and not prompt:
             raise ValueError("No conversation history or prompt provided")
@@ -425,13 +436,7 @@ def structured_output(self, output_model: Type[T], prompt: Optional[str] = None)
         if prompt:
             messages.append({"role": "user", "content": [{"text": prompt}]})

-        # get the structured output from the model
-        events = self.model.structured_output(output_model, messages)
-        for event in events:
-            if "callback" in event:
-                self.callback_handler(**cast(dict, event["callback"]))
-
-        return event["output"]
+        return asyncio.run(acall(messages))

     async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
         """Process a natural language prompt and yield events as an async iterator.
@@ -469,23 +474,21 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:

         try:
             events = self._run_loop(callback_handler, prompt, kwargs)
-            for event in events:
+            async for event in events:
                 if "callback" in event:
                     callback_handler(**event["callback"])
                     yield event["callback"]

-            stop_reason, message, metrics, state = event["stop"]
-            result = AgentResult(stop_reason, message, metrics, state)
-
+            result = AgentResult(*event["stop"])
             self._end_agent_trace_span(response=result)

         except Exception as e:
             self._end_agent_trace_span(error=e)
             raise

-    def _run_loop(
+    async def _run_loop(
         self, callback_handler: Callable[..., Any], prompt: str, kwargs: dict[str, Any]
-    ) -> Generator[dict[str, Any], None, None]:
+    ) -> AsyncGenerator[dict[str, Any], None]:
         """Execute the agent's event loop with the given prompt and parameters."""
         try:
             # Extract key parameters
@@ -497,14 +500,16 @@ def _run_loop(
             self.messages.append(new_message)

             # Execute the event loop cycle with retry logic for context limits
-            yield from self._execute_event_loop_cycle(callback_handler, kwargs)
+            events = self._execute_event_loop_cycle(callback_handler, kwargs)
+            async for event in events:
+                yield event

         finally:
             self.conversation_manager.apply_management(self)

-    def _execute_event_loop_cycle(
+    async def _execute_event_loop_cycle(
         self, callback_handler: Callable[..., Any], kwargs: dict[str, Any]
-    ) -> Generator[dict[str, Any], None, None]:
+    ) -> AsyncGenerator[dict[str, Any], None]:
         """Execute the event loop cycle with retry logic for context window limits.

         This internal method handles the execution of the event loop cycle and implements
@@ -527,7 +532,7 @@ def _execute_event_loop_cycle(

         try:
             # Execute the main event loop cycle
-            yield from event_loop_cycle(
+            events = event_loop_cycle(
                 model=model,
                 system_prompt=system_prompt,
                 messages=messages,  # will be modified by event_loop_cycle
@@ -540,11 +545,15 @@ def _execute_event_loop_cycle(
                 event_loop_parent_span=self.trace_span,
                 **kwargs,
             )
+            async for event in events:
+                yield event

         except ContextWindowOverflowException as e:
             # Try reducing the context size and retrying
             self.conversation_manager.reduce_context(self, e=e)
-            yield from self._execute_event_loop_cycle(callback_handler_override, kwargs)
+            events = self._execute_event_loop_cycle(callback_handler_override, kwargs)
+            async for event in events:
+                yield event

     def _record_tool_execution(
         self,
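The diff above keeps __call__ and structured_output synchronous by wrapping the new async code paths in asyncio.run, while stream_async hands events to the caller directly. A usage sketch, assuming Agent is importable from the package root and a default model plus credentials are configured:

import asyncio

from strands import Agent  # assumed import path

agent = Agent()

# Synchronous call: __call__ now drives the async event loop to
# completion internally via asyncio.run.
result = agent("What is 2 + 2?")
print(result)


# Async streaming: stream_async yields callback events as they arrive.
async def main() -> None:
    async for event in agent.stream_async("What is 2 + 2?"):
        print(event)


asyncio.run(main())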
29 changes: 18 additions & 11 deletions src/strands/event_loop/event_loop.py
@@ -12,7 +12,7 @@
 import time
 import uuid
 from functools import partial
-from typing import Any, Callable, Generator, Optional, cast
+from typing import Any, AsyncGenerator, Callable, Optional, cast

 from ..telemetry.metrics import EventLoopMetrics, Trace
 from ..telemetry.tracer import get_tracer
@@ -33,7 +33,7 @@
 MAX_DELAY = 240  # 4 minutes


-def event_loop_cycle(
+async def event_loop_cycle(
     model: Model,
     system_prompt: Optional[str],
     messages: Messages,
@@ -42,7 +42,7 @@ def event_loop_cycle(
     tool_handler: Optional[ToolHandler],
     tool_execution_handler: Optional[ParallelToolExecutorInterface] = None,
     **kwargs: Any,
-) -> Generator[dict[str, Any], None, None]:
+) -> AsyncGenerator[dict[str, Any], None]:
     """Execute a single cycle of the event loop.

     This core function processes a single conversation turn, handling model inference, tool execution, and error
@@ -132,7 +132,7 @@ def event_loop_cycle(
     try:
         # TODO: To maintain backwards compatability, we need to combine the stream event with kwargs before yielding
         # to the callback handler. This will be revisited when migrating to strongly typed events.
-        for event in stream_messages(model, system_prompt, messages, tool_config):
+        async for event in stream_messages(model, system_prompt, messages, tool_config):
             if "callback" in event:
                 yield {"callback": {**event["callback"], **(kwargs if "delta" in event["callback"] else {})}}

@@ -202,7 +202,7 @@ def event_loop_cycle(
             )

             # Handle tool execution
-            yield from _handle_tool_execution(
+            events = _handle_tool_execution(
                 stop_reason,
                 message,
                 model,
@@ -218,6 +218,9 @@ def event_loop_cycle(
                 cycle_start_time,
                 kwargs,
             )
+            async for event in events:
+                yield event
+
             return

         # End the cycle and return results
@@ -250,9 +253,9 @@ def event_loop_cycle(
    yield {"stop": (stop_reason, message, event_loop_metrics, kwargs["request_state"])}


-def recurse_event_loop(
+async def recurse_event_loop(
     **kwargs: Any,
-) -> Generator[dict[str, Any], None, None]:
+) -> AsyncGenerator[dict[str, Any], None]:
     """Make a recursive call to event_loop_cycle with the current state.

     This function is used when the event loop needs to continue processing after tool execution.
@@ -284,12 +287,14 @@ def recurse_event_loop(
     cycle_trace.add_child(recursive_trace)

     yield {"callback": {"start": True}}
-    yield from event_loop_cycle(**kwargs)
+    events = event_loop_cycle(**kwargs)
+    async for event in events:
+        yield event

     recursive_trace.end()


-def _handle_tool_execution(
+async def _handle_tool_execution(
     stop_reason: StopReason,
     message: Message,
     model: Model,
@@ -304,7 +309,7 @@ def _handle_tool_execution(
     cycle_span: Any,
     cycle_start_time: float,
     kwargs: dict[str, Any],
-) -> Generator[dict[str, Any], None, None]:
+) -> AsyncGenerator[dict[str, Any], None]:
     tool_uses: list[ToolUse] = []
     tool_results: list[ToolResult] = []
     invalid_tool_use_ids: list[str] = []
@@ -385,7 +390,7 @@ def _handle_tool_execution(
         yield {"stop": (stop_reason, message, event_loop_metrics, kwargs["request_state"])}
         return

-    yield from recurse_event_loop(
+    events = recurse_event_loop(
         model=model,
         system_prompt=system_prompt,
         messages=messages,
@@ -394,3 +399,5 @@ def _handle_tool_execution(
         tool_handler=tool_handler,
         **kwargs,
     )
+    async for event in events:
+        yield event
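
Every yield from in this file becomes an explicit async for loop that re-yields, because yield from is a syntax error inside an async generator; delegation has to be spelled out. A standalone sketch of the pattern:

import asyncio
from typing import AsyncGenerator


async def inner() -> AsyncGenerator[int, None]:
    yield 1
    yield 2


async def outer() -> AsyncGenerator[int, None]:
    # "yield from inner()" would be a SyntaxError here, so the
    # delegation is written as an explicit loop:
    async for item in inner():
        yield item


async def main() -> None:
    async for item in outer():
        print(item)  # 1, then 2


asyncio.run(main())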
17 changes: 9 additions & 8 deletions src/strands/event_loop/streaming.py
@@ -2,7 +2,7 @@

 import json
 import logging
-from typing import Any, Generator, Iterable, Optional
+from typing import Any, AsyncGenerator, AsyncIterable, Optional

 from ..types.content import ContentBlock, Message, Messages
 from ..types.models import Model
@@ -251,10 +251,10 @@ def extract_usage_metrics(event: MetadataEvent) -> tuple[Usage, Metrics]:
     return usage, metrics


-def process_stream(
-    chunks: Iterable[StreamEvent],
+async def process_stream(
+    chunks: AsyncIterable[StreamEvent],
     messages: Messages,
-) -> Generator[dict[str, Any], None, None]:
+) -> AsyncGenerator[dict[str, Any], None]:
     """Processes the response stream from the API, constructing the final message and extracting usage metrics.

     Args:
@@ -278,7 +278,7 @@ def process_stream(
     usage: Usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0)
     metrics: Metrics = Metrics(latencyMs=0)

-    for chunk in chunks:
+    async for chunk in chunks:
         yield {"callback": {"event": chunk}}

         if "messageStart" in chunk:
@@ -300,12 +300,12 @@ def process_stream(
     yield {"stop": (stop_reason, state["message"], usage, metrics)}


-def stream_messages(
+async def stream_messages(
     model: Model,
     system_prompt: Optional[str],
     messages: Messages,
     tool_config: Optional[ToolConfig],
-) -> Generator[dict[str, Any], None, None]:
+) -> AsyncGenerator[dict[str, Any], None]:
     """Streams messages to the model and processes the response.

     Args:
@@ -323,4 +323,5 @@ def stream_messages(
     tool_specs = [tool["toolSpec"] for tool in tool_config.get("tools", [])] or None if tool_config else None

     chunks = model.converse(messages, tool_specs, system_prompt)
-    yield from process_stream(chunks, messages)
+    async for event in process_stream(chunks, messages):
+        yield event
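
process_stream now accepts an AsyncIterable of chunks, so any async source, including a hand-rolled async generator in a test, can drive it. A toy sketch of that shape; the chunk dictionaries are illustrative stand-ins, not the full event schema:

import asyncio
from typing import Any, AsyncGenerator, AsyncIterable


async def fake_chunks() -> AsyncGenerator[dict[str, Any], None]:
    # Stand-in for the chunks returned by model.converse(...).
    yield {"messageStart": {"role": "assistant"}}
    yield {"messageStop": {"stopReason": "end_turn"}}


async def consume(chunks: AsyncIterable[dict[str, Any]]) -> None:
    async for chunk in chunks:
        print(chunk)


asyncio.run(consume(fake_chunks()))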
10 changes: 5 additions & 5 deletions src/strands/models/anthropic.py
@@ -7,7 +7,7 @@
 import json
 import logging
 import mimetypes
-from typing import Any, Generator, Iterable, Optional, Type, TypedDict, TypeVar, Union, cast
+from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast

 import anthropic
 from pydantic import BaseModel
@@ -344,7 +344,7 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
         raise RuntimeError(f"event_type=<{event['type']} | unknown type")

     @override
-    def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
+    async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
         """Send the request to the Anthropic model and get the streaming response.

         Args:
@@ -376,9 +376,9 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
         raise error

     @override
-    def structured_output(
+    async def structured_output(
         self, output_model: Type[T], prompt: Messages
-    ) -> Generator[dict[str, Union[T, Any]], None, None]:
+    ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
         """Get structured output from the model.

         Args:
@@ -391,7 +391,7 @@ def structured_output(
         tool_spec = convert_pydantic_to_tool_spec(output_model)

         response = self.converse(messages=prompt, tool_specs=[tool_spec])
-        for event in process_stream(response, prompt):
+        async for event in process_stream(response, prompt):
             yield event

         stop_reason, messages, _, _ = event["stop"]
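
With stream now an async def, the Anthropic provider can sit on the SDK's async client. A minimal sketch of async streaming with anthropic.AsyncAnthropic, shown for orientation rather than as this PR's internal implementation; the model id is illustrative:

import asyncio

import anthropic


async def demo() -> None:
    client = anthropic.AsyncAnthropic()  # reads ANTHROPIC_API_KEY from the environment
    async with client.messages.stream(
        model="claude-3-5-sonnet-latest",  # illustrative model id
        max_tokens=256,
        messages=[{"role": "user", "content": "Say hello."}],
    ) as stream:
        async for text in stream.text_stream:
            print(text, end="", flush=True)


asyncio.run(demo())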