diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index 590b4436c..23f810e39 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -9,12 +9,13 @@
 2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")`
 """
 
+import asyncio
 import json
 import logging
 import os
 import random
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, AsyncIterator, Callable, Generator, List, Mapping, Optional, Type, TypeVar, Union, cast
+from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast
 
 from opentelemetry import trace
 from pydantic import BaseModel
@@ -378,33 +379,43 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
             - metrics: Performance metrics from the event loop
             - state: The final state of the event loop
         """
 
-        callback_handler = kwargs.get("callback_handler", self.callback_handler)
-        self._start_agent_trace_span(prompt)
+        def execute() -> AgentResult:
+            return asyncio.run(self.invoke_async(prompt, **kwargs))
 
-        try:
-            events = self._run_loop(prompt, kwargs)
-            for event in events:
-                if "callback" in event:
-                    callback_handler(**event["callback"])
+        with ThreadPoolExecutor() as executor:
+            future = executor.submit(execute)
+            return future.result()
 
-            stop_reason, message, metrics, state = event["stop"]
-            result = AgentResult(stop_reason, message, metrics, state)
+    async def invoke_async(self, prompt: str, **kwargs: Any) -> AgentResult:
+        """Process a natural language prompt through the agent's event loop.
 
-            self._end_agent_trace_span(response=result)
+        This method implements the conversational interface (e.g., `agent("hello!")`). It adds the user's prompt to
+        the conversation history, processes it through the model, executes any tool calls, and returns the final result.
 
-            return result
+        Args:
+            prompt: The natural language prompt from the user.
+            **kwargs: Additional parameters to pass through the event loop.
 
-        except Exception as e:
-            self._end_agent_trace_span(error=e)
-            raise
+        Returns:
+            Result object containing:
+
+            - stop_reason: Why the event loop stopped (e.g., "end_turn", "max_tokens")
+            - message: The final message from the model
+            - metrics: Performance metrics from the event loop
+            - state: The final state of the event loop
+        """
+        events = self.stream_async(prompt, **kwargs)
+        async for event in events:
+            _ = event
+
+        return cast(AgentResult, event["result"])
 
     def structured_output(self, output_model: Type[T], prompt: Optional[str] = None) -> T:
         """This method allows you to get structured output from the agent.
 
         If you pass in a prompt, it will be added to the conversation history and the agent will respond to it.
         If you don't pass in a prompt, it will use only the conversation history to respond.
-        If no conversation history exists and no prompt is provided, an error will be raised.
 
         For smaller models, you may want to use the optional prompt string to add additional instructions to explicitly
         instruct the model to output the structured data.
@@ -413,25 +424,52 @@ def structured_output(self, output_model: Type[T], prompt: Optional[str] = None)
             output_model: The output model (a JSON schema written as a Pydantic BaseModel)
                 that the agent will use when responding.
             prompt: The prompt to use for the agent.
+
+        Raises:
+            ValueError: If no conversation history or prompt is provided.
+        """
+
+        def execute() -> T:
+            return asyncio.run(self.structured_output_async(output_model, prompt))
+
+        with ThreadPoolExecutor() as executor:
+            future = executor.submit(execute)
+            return future.result()
+
+    async def structured_output_async(self, output_model: Type[T], prompt: Optional[str] = None) -> T:
+        """This method allows you to get structured output from the agent.
+
+        If you pass in a prompt, it will be added to the conversation history and the agent will respond to it.
+        If you don't pass in a prompt, it will use only the conversation history to respond.
+
+        For smaller models, you may want to use the optional prompt string to add additional instructions to explicitly
+        instruct the model to output the structured data.
+
+        Args:
+            output_model: The output model (a JSON schema written as a Pydantic BaseModel)
+                that the agent will use when responding.
+            prompt: The prompt to use for the agent.
+
+        Raises:
+            ValueError: If no conversation history or prompt is provided.
         """
         self._hooks.invoke_callbacks(StartRequestEvent(agent=self))
 
         try:
-            messages = self.messages
-            if not messages and not prompt:
+            if not self.messages and not prompt:
                 raise ValueError("No conversation history or prompt provided")
 
             # add the prompt as the last message
             if prompt:
-                messages.append({"role": "user", "content": [{"text": prompt}]})
+                self.messages.append({"role": "user", "content": [{"text": prompt}]})
 
-            # get the structured output from the model
-            events = self.model.structured_output(output_model, messages)
-            for event in events:
+            events = self.model.structured_output(output_model, self.messages)
+            async for event in events:
                 if "callback" in event:
                     self.callback_handler(**cast(dict, event["callback"]))
 
             return event["output"]
+
         finally:
             self._hooks.invoke_callbacks(EndRequestEvent(agent=self))
 
@@ -471,13 +509,14 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
 
         try:
             events = self._run_loop(prompt, kwargs)
-            for event in events:
+            async for event in events:
                 if "callback" in event:
                     callback_handler(**event["callback"])
                     yield event["callback"]
 
-            stop_reason, message, metrics, state = event["stop"]
-            result = AgentResult(stop_reason, message, metrics, state)
+            result = AgentResult(*event["stop"])
+            callback_handler(result=result)
+            yield {"result": result}
 
             self._end_agent_trace_span(response=result)
 
@@ -485,7 +524,7 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
             self._end_agent_trace_span(error=e)
             raise
 
-    def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> Generator[dict[str, Any], None, None]:
+    async def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
         """Execute the agent's event loop with the given prompt and parameters."""
         self._hooks.invoke_callbacks(StartRequestEvent(agent=self))
 
@@ -499,13 +538,15 @@ def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> Generator[dict[str,
                 self.messages.append(new_message)
 
             # Execute the event loop cycle with retry logic for context limits
-            yield from self._execute_event_loop_cycle(kwargs)
+            events = self._execute_event_loop_cycle(kwargs)
+            async for event in events:
+                yield event
 
         finally:
             self.conversation_manager.apply_management(self)
             self._hooks.invoke_callbacks(EndRequestEvent(agent=self))
 
-    def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> Generator[dict[str, Any], None, None]:
+    async def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
         """Execute the event loop cycle with retry logic for context window limits.
 
         This internal method handles the execution of the event loop cycle and implements
@@ -520,7 +561,7 @@
 
         try:
             # Execute the main event loop cycle
-            yield from event_loop_cycle(
+            events = event_loop_cycle(
                 model=self.model,
                 system_prompt=self.system_prompt,
                 messages=self.messages,  # will be modified by event_loop_cycle
@@ -531,11 +572,15 @@ def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> Generator[dict[st
                 event_loop_parent_span=self.trace_span,
                 kwargs=kwargs,
             )
+            async for event in events:
+                yield event
 
         except ContextWindowOverflowException as e:
             # Try reducing the context size and retrying
             self.conversation_manager.reduce_context(self, e=e)
-            yield from self._execute_event_loop_cycle(kwargs)
+            events = self._execute_event_loop_cycle(kwargs)
+            async for event in events:
+                yield event
 
     def _record_tool_execution(
         self,
@@ -560,7 +605,7 @@ def _record_tool_execution(
             messages: The message history to append to.
         """
 
        # Create user message describing the tool call
-        user_msg_content: List[ContentBlock] = [
+        user_msg_content: list[ContentBlock] = [
            {"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {json.dumps(tool['input'])}\n")}
        ]
diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py
index 61eb780c3..37ef6309a 100644
--- a/src/strands/event_loop/event_loop.py
+++ b/src/strands/event_loop/event_loop.py
@@ -13,7 +13,7 @@
 import uuid
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
-from typing import Any, Generator, Optional
+from typing import Any, AsyncGenerator, Optional
 
 from opentelemetry import trace
 
@@ -35,7 +35,7 @@
 MAX_DELAY = 240  # 4 minutes
 
 
-def event_loop_cycle(
+async def event_loop_cycle(
     model: Model,
     system_prompt: Optional[str],
     messages: Messages,
@@ -45,7 +45,7 @@ def event_loop_cycle(
     event_loop_metrics: EventLoopMetrics,
     event_loop_parent_span: Optional[trace.Span],
     kwargs: dict[str, Any],
-) -> Generator[dict[str, Any], None, None]:
+) -> AsyncGenerator[dict[str, Any], None]:
     """Execute a single cycle of the event loop.
 
     This core function processes a single conversation turn, handling model inference, tool execution, and error
@@ -132,7 +132,7 @@ def event_loop_cycle(
     try:
         # TODO: To maintain backwards compatibility, we need to combine the stream event with kwargs before yielding
         # to the callback handler. This will be revisited when migrating to strongly typed events.
- for event in stream_messages(model, system_prompt, messages, tool_config): + async for event in stream_messages(model, system_prompt, messages, tool_config): if "callback" in event: yield {"callback": {**event["callback"], **(kwargs if "delta" in event["callback"] else {})}} @@ -202,7 +202,7 @@ def event_loop_cycle( ) # Handle tool execution - yield from _handle_tool_execution( + events = _handle_tool_execution( stop_reason, message, model, @@ -218,6 +218,9 @@ def event_loop_cycle( cycle_start_time, kwargs, ) + async for event in events: + yield event + return # End the cycle and return results @@ -250,7 +253,7 @@ def event_loop_cycle( yield {"stop": (stop_reason, message, event_loop_metrics, kwargs["request_state"])} -def recurse_event_loop( +async def recurse_event_loop( model: Model, system_prompt: Optional[str], messages: Messages, @@ -260,7 +263,7 @@ def recurse_event_loop( event_loop_metrics: EventLoopMetrics, event_loop_parent_span: Optional[trace.Span], kwargs: dict[str, Any], -) -> Generator[dict[str, Any], None, None]: +) -> AsyncGenerator[dict[str, Any], None]: """Make a recursive call to event_loop_cycle with the current state. This function is used when the event loop needs to continue processing after tool execution. @@ -292,7 +295,8 @@ def recurse_event_loop( cycle_trace.add_child(recursive_trace) yield {"callback": {"start": True}} - yield from event_loop_cycle( + + events = event_loop_cycle( model=model, system_prompt=system_prompt, messages=messages, @@ -303,11 +307,13 @@ def recurse_event_loop( event_loop_parent_span=event_loop_parent_span, kwargs=kwargs, ) + async for event in events: + yield event recursive_trace.end() -def _handle_tool_execution( +async def _handle_tool_execution( stop_reason: StopReason, message: Message, model: Model, @@ -322,7 +328,7 @@ def _handle_tool_execution( cycle_span: Any, cycle_start_time: float, kwargs: dict[str, Any], -) -> Generator[dict[str, Any], None, None]: +) -> AsyncGenerator[dict[str, Any], None]: tool_uses: list[ToolUse] = [] tool_results: list[ToolResult] = [] invalid_tool_use_ids: list[str] = [] @@ -369,7 +375,7 @@ def _handle_tool_execution( kwargs=kwargs, ) - yield from run_tools( + tool_events = run_tools( handler=tool_handler_process, tool_uses=tool_uses, event_loop_metrics=event_loop_metrics, @@ -379,6 +385,8 @@ def _handle_tool_execution( parent_span=cycle_span, thread_pool=thread_pool, ) + for tool_event in tool_events: + yield tool_event # Store parent cycle ID for the next cycle kwargs["event_loop_parent_cycle_id"] = kwargs["event_loop_cycle_id"] @@ -400,7 +408,7 @@ def _handle_tool_execution( yield {"stop": (stop_reason, message, event_loop_metrics, kwargs["request_state"])} return - yield from recurse_event_loop( + events = recurse_event_loop( model=model, system_prompt=system_prompt, messages=messages, @@ -411,3 +419,5 @@ def _handle_tool_execution( event_loop_parent_span=event_loop_parent_span, kwargs=kwargs, ) + async for event in events: + yield event diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 0e9d472bd..6ecc3e270 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -2,7 +2,7 @@ import json import logging -from typing import Any, Generator, Iterable, Optional +from typing import Any, AsyncGenerator, AsyncIterable, Optional from ..types.content import ContentBlock, Message, Messages from ..types.models import Model @@ -251,10 +251,10 @@ def extract_usage_metrics(event: MetadataEvent) -> tuple[Usage, Metrics]: return 
usage, metrics -def process_stream( - chunks: Iterable[StreamEvent], +async def process_stream( + chunks: AsyncIterable[StreamEvent], messages: Messages, -) -> Generator[dict[str, Any], None, None]: +) -> AsyncGenerator[dict[str, Any], None]: """Processes the response stream from the API, constructing the final message and extracting usage metrics. Args: @@ -278,7 +278,7 @@ def process_stream( usage: Usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) metrics: Metrics = Metrics(latencyMs=0) - for chunk in chunks: + async for chunk in chunks: yield {"callback": {"event": chunk}} if "messageStart" in chunk: @@ -300,12 +300,12 @@ def process_stream( yield {"stop": (stop_reason, state["message"], usage, metrics)} -def stream_messages( +async def stream_messages( model: Model, system_prompt: Optional[str], messages: Messages, tool_config: Optional[ToolConfig], -) -> Generator[dict[str, Any], None, None]: +) -> AsyncGenerator[dict[str, Any], None]: """Streams messages to the model and processes the response. Args: @@ -323,4 +323,5 @@ def stream_messages( tool_specs = [tool["toolSpec"] for tool in tool_config.get("tools", [])] or None if tool_config else None chunks = model.converse(messages, tool_specs, system_prompt) - yield from process_stream(chunks, messages) + async for event in process_stream(chunks, messages): + yield event diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index e91cd4422..02c3d9089 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -7,7 +7,7 @@ import json import logging import mimetypes -from typing import Any, Generator, Iterable, Optional, Type, TypedDict, TypeVar, Union, cast +from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast import anthropic from pydantic import BaseModel @@ -344,7 +344,7 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: raise RuntimeError(f"event_type=<{event['type']} | unknown type") @override - def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: + async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: """Send the request to the Anthropic model and get the streaming response. Args: @@ -376,9 +376,9 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: raise error @override - def structured_output( + async def structured_output( self, output_model: Type[T], prompt: Messages - ) -> Generator[dict[str, Union[T, Any]], None, None]: + ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: """Get structured output from the model. 
Args: @@ -391,7 +391,7 @@ def structured_output( tool_spec = convert_pydantic_to_tool_spec(output_model) response = self.converse(messages=prompt, tool_specs=[tool_spec]) - for event in process_stream(response, prompt): + async for event in process_stream(response, prompt): yield event stop_reason, messages, _, _ = event["stop"] diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index e1fdfbc36..373dd4fff 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -6,7 +6,7 @@ import json import logging import os -from typing import Any, Generator, Iterable, List, Literal, Optional, Type, TypeVar, Union, cast +from typing import Any, AsyncGenerator, Iterable, List, Literal, Optional, Type, TypeVar, Union, cast import boto3 from botocore.config import Config as BotocoreConfig @@ -315,7 +315,7 @@ def _generate_redaction_events(self) -> list[StreamEvent]: return events @override - def stream(self, request: dict[str, Any]) -> Iterable[StreamEvent]: + async def stream(self, request: dict[str, Any]) -> AsyncGenerator[StreamEvent, None]: """Send the request to the Bedrock model and get the response. This method calls either the Bedrock converse_stream API or the converse API @@ -345,14 +345,16 @@ def stream(self, request: dict[str, Any]) -> Iterable[StreamEvent]: ): guardrail_data = chunk["metadata"]["trace"]["guardrail"] if self._has_blocked_guardrail(guardrail_data): - yield from self._generate_redaction_events() + for event in self._generate_redaction_events(): + yield event yield chunk else: # Non-streaming implementation response = self.client.converse(**request) # Convert and yield from the response - yield from self._convert_non_streaming_to_streaming(response) + for event in self._convert_non_streaming_to_streaming(response): + yield event # Check for guardrail triggers after yielding any events (same as streaming path) if ( @@ -360,7 +362,8 @@ def stream(self, request: dict[str, Any]) -> Iterable[StreamEvent]: and "guardrail" in response["trace"] and self._has_blocked_guardrail(response["trace"]["guardrail"]) ): - yield from self._generate_redaction_events() + for event in self._generate_redaction_events(): + yield event except ClientError as e: error_message = str(e) @@ -514,9 +517,9 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool: return False @override - def structured_output( + async def structured_output( self, output_model: Type[T], prompt: Messages - ) -> Generator[dict[str, Union[T, Any]], None, None]: + ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: """Get structured output from the model. 
Args: @@ -529,7 +532,7 @@ def structured_output( tool_spec = convert_pydantic_to_tool_spec(output_model) response = self.converse(messages=prompt, tool_specs=[tool_spec]) - for event in process_stream(response, prompt): + async for event in process_stream(response, prompt): yield event stop_reason, messages, _, _ = event["stop"] diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 691887b54..d894e58e2 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -5,7 +5,7 @@ import json import logging -from typing import Any, Generator, Optional, Type, TypedDict, TypeVar, Union, cast +from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast import litellm from litellm.utils import supports_response_schema @@ -104,9 +104,9 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] return super().format_request_message_content(content) @override - def structured_output( + async def structured_output( self, output_model: Type[T], prompt: Messages - ) -> Generator[dict[str, Union[T, Any]], None, None]: + ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: """Get structured output from the model. Args: diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 74c098e36..2b585439c 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -8,7 +8,7 @@ import json import logging import mimetypes -from typing import Any, Generator, Iterable, Optional, Type, TypeVar, Union, cast +from typing import Any, AsyncGenerator, Optional, Type, TypeVar, Union, cast import llama_api_client from llama_api_client import LlamaAPIClient @@ -324,7 +324,7 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") @override - def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: + async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: """Send the request to the model and get a streaming response. Args: @@ -391,7 +391,7 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: @override def structured_output( self, output_model: Type[T], prompt: Messages - ) -> Generator[dict[str, Union[T, Any]], None, None]: + ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: """Get structured output from the model. Args: diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index 3d44cbe23..6f8492b79 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -6,7 +6,7 @@ import base64 import json import logging -from typing import Any, Dict, Generator, Iterable, List, Optional, Type, TypeVar, Union +from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypeVar, Union from mistralai import Mistral from pydantic import BaseModel @@ -114,7 +114,7 @@ def get_config(self) -> MistralConfig: """ return self.config - def _format_request_message_content(self, content: ContentBlock) -> Union[str, Dict[str, Any]]: + def _format_request_message_content(self, content: ContentBlock) -> Union[str, dict[str, Any]]: """Format a Mistral content block. Args: @@ -170,7 +170,7 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any Returns: Mistral formatted tool message. 
""" - content_parts: List[str] = [] + content_parts: list[str] = [] for content in tool_result["content"]: if "json" in content: content_parts.append(json.dumps(content["json"])) @@ -205,9 +205,9 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s role = message["role"] contents = message["content"] - text_contents: List[str] = [] - tool_calls: List[Dict[str, Any]] = [] - tool_messages: List[Dict[str, Any]] = [] + text_contents: list[str] = [] + tool_calls: list[dict[str, Any]] = [] + tool_messages: list[dict[str, Any]] = [] for content in contents: if "text" in content: @@ -220,7 +220,7 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s tool_messages.append(self._format_request_tool_message(content["toolResult"])) if text_contents or tool_calls: - formatted_message: Dict[str, Any] = { + formatted_message: dict[str, Any] = { "role": role, "content": " ".join(text_contents) if text_contents else "", } @@ -252,7 +252,7 @@ def format_request( TypeError: If a message contains a content block type that cannot be converted to a Mistral-compatible format. """ - request: Dict[str, Any] = { + request: dict[str, Any] = { "model": self.config["model_id"], "messages": self._format_request_messages(messages, system_prompt), } @@ -393,7 +393,7 @@ def _handle_non_streaming_response(self, response: Any) -> Iterable[dict[str, An yield {"chunk_type": "metadata", "data": response.usage} @override - def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: + async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: """Send the request to the Mistral model and get the streaming response. Args: @@ -406,10 +406,11 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: ModelThrottledException: When the model service is throttling requests. """ try: - if self.config.get("stream", True) is False: + if not self.config.get("stream", True): # Use non-streaming API response = self.client.chat.complete(**request) - yield from self._handle_non_streaming_response(response) + for event in self._handle_non_streaming_response(response): + yield event return # Use the streaming API @@ -418,7 +419,7 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: yield {"chunk_type": "message_start"} content_started = False - current_tool_calls: Dict[str, Dict[str, str]] = {} + current_tool_calls: dict[str, dict[str, str]] = {} accumulated_text = "" for chunk in stream_response: @@ -470,11 +471,11 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: raise @override - def structured_output( + async def structured_output( self, output_model: Type[T], prompt: Messages, - ) -> Generator[dict[str, Union[T, Any]], None, None]: + ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: """Get structured output from the model. 
Args: diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 1c834bf6e..707672498 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -5,7 +5,7 @@ import json import logging -from typing import Any, Generator, Iterable, Optional, Type, TypeVar, Union, cast +from typing import Any, AsyncGenerator, Optional, Type, TypeVar, Union, cast from ollama import Client as OllamaClient from pydantic import BaseModel @@ -283,7 +283,7 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") @override - def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: + async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: """Send the request to the Ollama model and get the streaming response. This method calls the Ollama chat API and returns the stream of response events. @@ -315,9 +315,9 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: yield {"chunk_type": "metadata", "data": event} @override - def structured_output( + async def structured_output( self, output_model: Type[T], prompt: Messages - ) -> Generator[dict[str, Union[T, Any]], None, None]: + ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: """Get structured output from the model. Args: diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index eb58ae41c..5446cbd3d 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -4,7 +4,7 @@ """ import logging -from typing import Any, Generator, Iterable, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast +from typing import Any, AsyncGenerator, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast import openai from openai.types.chat.parsed_chat_completion import ParsedChatCompletion @@ -82,7 +82,7 @@ def get_config(self) -> OpenAIConfig: return cast(OpenAIModel.OpenAIConfig, self.config) @override - def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: + async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: """Send the request to the OpenAI model and get the streaming response. Args: @@ -139,9 +139,9 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: yield {"chunk_type": "metadata", "data": event.usage} @override - def structured_output( + async def structured_output( self, output_model: Type[T], prompt: Messages - ) -> Generator[dict[str, Union[T, Any]], None, None]: + ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: """Get structured output from the model. Args: diff --git a/src/strands/types/models/model.py b/src/strands/types/models/model.py index 6d8c5aee7..11abfa598 100644 --- a/src/strands/types/models/model.py +++ b/src/strands/types/models/model.py @@ -2,7 +2,7 @@ import abc import logging -from typing import Any, Generator, Iterable, Optional, Type, TypeVar, Union +from typing import Any, AsyncGenerator, AsyncIterable, Optional, Type, TypeVar, Union from pydantic import BaseModel @@ -46,7 +46,7 @@ def get_config(self) -> Any: # pragma: no cover def structured_output( self, output_model: Type[T], prompt: Messages - ) -> Generator[dict[str, Union[T, Any]], None, None]: + ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: """Get structured output from the model. 
Args: @@ -93,7 +93,7 @@ def format_chunk(self, event: Any) -> StreamEvent: @abc.abstractmethod # pragma: no cover - def stream(self, request: Any) -> Iterable[Any]: + def stream(self, request: Any) -> AsyncGenerator[Any, None]: """Send the request to the model and get a streaming response. Args: @@ -107,9 +107,9 @@ def stream(self, request: Any) -> Iterable[Any]: """ pass - def converse( + async def converse( self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None - ) -> Iterable[StreamEvent]: + ) -> AsyncIterable[StreamEvent]: """Converse with the model. This method handles the full lifecycle of conversing with the model: @@ -136,7 +136,7 @@ def converse( response = self.stream(request) logger.debug("got response from model") - for event in response: + async for event in response: yield self.format_chunk(event) logger.debug("finished streaming response from model") diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py index 25830bc39..30971c2ba 100644 --- a/src/strands/types/models/openai.py +++ b/src/strands/types/models/openai.py @@ -11,7 +11,7 @@ import json import logging import mimetypes -from typing import Any, Generator, Optional, Type, TypeVar, Union, cast +from typing import Any, AsyncGenerator, Optional, Type, TypeVar, Union, cast from pydantic import BaseModel from typing_extensions import override @@ -297,9 +297,9 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") @override - def structured_output( + async def structured_output( self, output_model: Type[T], prompt: Messages - ) -> Generator[dict[str, Union[T, Any]], None, None]: + ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: """Get structured output from the model. 
Args: diff --git a/tests-integ/conftest.py b/tests-integ/conftest.py new file mode 100644 index 000000000..4b38540c5 --- /dev/null +++ b/tests-integ/conftest.py @@ -0,0 +1,20 @@ +import pytest + +## Async + + +@pytest.fixture(scope="session") +def agenerator(): + async def agenerator(items): + for item in items: + yield item + + return agenerator + + +@pytest.fixture(scope="session") +def alist(): + async def alist(items): + return [item async for item in items] + + return alist diff --git a/tests-integ/test_model_bedrock.py b/tests-integ/test_model_bedrock.py index 5378a9b20..120f4036b 100644 --- a/tests-integ/test_model_bedrock.py +++ b/tests-integ/test_model_bedrock.py @@ -51,12 +51,13 @@ def test_non_streaming_agent(non_streaming_agent): assert len(str(result)) > 0 -def test_streaming_model_events(streaming_model): +@pytest.mark.asyncio +async def test_streaming_model_events(streaming_model, alist): """Test streaming model events.""" messages = [{"role": "user", "content": [{"text": "Hello"}]}] # Call converse and collect events - events = list(streaming_model.converse(messages)) + events = await alist(streaming_model.converse(messages)) # Verify basic structure of events assert any("messageStart" in event for event in events) @@ -64,12 +65,13 @@ def test_streaming_model_events(streaming_model): assert any("messageStop" in event for event in events) -def test_non_streaming_model_events(non_streaming_model): +@pytest.mark.asyncio +async def test_non_streaming_model_events(non_streaming_model, alist): """Test non-streaming model events.""" messages = [{"role": "user", "content": [{"text": "Hello"}]}] # Call converse and collect events - events = list(non_streaming_model.converse(messages)) + events = await alist(non_streaming_model.converse(messages)) # Verify basic structure of events assert any("messageStart" in event for event in events) diff --git a/tests/conftest.py b/tests/conftest.py index f00ae497a..3b82e362c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -70,6 +70,26 @@ def boto3_profile_path(boto3_profile, tmp_path, monkeypatch): return path +## Async + + +@pytest.fixture(scope="session") +def agenerator(): + async def agenerator(items): + for item in items: + yield item + + return agenerator + + +@pytest.fixture(scope="session") +def alist(): + async def alist(items): + return [item async for item in items] + + return alist + + ## Itertools diff --git a/tests/fixtures/mocked_model_provider.py b/tests/fixtures/mocked_model_provider.py index f89d56202..eed5a1b25 100644 --- a/tests/fixtures/mocked_model_provider.py +++ b/tests/fixtures/mocked_model_provider.py @@ -1,5 +1,5 @@ import json -from typing import Any, Callable, Iterable, Optional, Type, TypeVar +from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypeVar from pydantic import BaseModel @@ -38,13 +38,18 @@ def get_config(self) -> Any: def update_config(self, **model_config: Any) -> None: pass - def structured_output( - self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None - ) -> T: + async def structured_output( + self, + output_model: Type[T], + prompt: Messages, + ) -> AsyncGenerator[Any, None]: pass - def stream(self, request: Any) -> Iterable[Any]: - yield from self.map_agent_message_to_events(self.agent_responses[self.index]) + async def stream(self, request: Any) -> AsyncGenerator[Any, None]: + events = self.map_agent_message_to_events(self.agent_responses[self.index]) + for event in events: + yield event + self.index += 1 def 
map_agent_message_to_events(self, agent_message: Message) -> Iterable[dict[str, Any]]: diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 5a8985fb9..787494597 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -41,28 +41,32 @@ def converse(*args, **kwargs): @pytest.fixture -def mock_hook_messages(mock_model, tool): +def mock_hook_messages(mock_model, tool, agenerator): """Fixture which returns a standard set of events for verifying hooks.""" mock_model.mock_converse.side_effect = [ - [ - { - "contentBlockStart": { - "start": { - "toolUse": { - "toolUseId": "t1", - "name": tool.tool_spec["name"], + agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "t1", + "name": tool.tool_spec["name"], + }, }, }, }, - }, - {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"random_string": "abcdEfghI123"}'}}}}, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "tool_use"}}, - ], - [ - {"contentBlockDelta": {"delta": {"text": "test text"}}}, - {"contentBlockStop": {}}, - ], + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"random_string": "abcdEfghI123"}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ], + ), + agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "test text"}}}, + {"contentBlockStop": {}}, + ], + ), ] return mock_model.mock_converse @@ -199,6 +203,16 @@ def agent( return agent +@pytest.fixture +def user(): + class User(BaseModel): + name: str + age: int + email: str + + return User(name="Jane Doe", age=30, email="jane@doe.com") + + def test_agent__init__tool_loader_format(tool_decorated, tool_module, tool_imported, tool_registry): _ = tool_registry @@ -260,30 +274,35 @@ def test_agent__call__( callback_handler, agent, tool, + agenerator, ): conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager) agent.conversation_manager = conversation_manager_spy mock_model.mock_converse.side_effect = [ - [ - { - "contentBlockStart": { - "start": { - "toolUse": { - "toolUseId": "t1", - "name": tool.tool_spec["name"], + agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "t1", + "name": tool.tool_spec["name"], + }, }, }, }, - }, - {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"random_string": "abcdEfghI123"}'}}}}, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "tool_use"}}, - ], - [ - {"contentBlockDelta": {"delta": {"text": "test text"}}}, - {"contentBlockStop": {}}, - ], + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"random_string": "abcdEfghI123"}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ] + ), + agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "test text"}}}, + {"contentBlockStop": {}}, + ] + ), ] result = agent("test message") @@ -358,21 +377,23 @@ def test_agent__call__( conversation_manager_spy.apply_management.assert_called_with(agent) -def test_agent__call__passes_kwargs(mock_model, system_prompt, callback_handler, agent, tool, mock_event_loop_cycle): +def test_agent__call__passes_kwargs(mock_model, agent, tool, mock_event_loop_cycle, agenerator): mock_model.mock_converse.side_effect = [ - [ - { - "contentBlockStart": { - "start": { - "toolUse": { - "toolUseId": "t1", - "name": tool.tool_spec["name"], + agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "t1", + "name": tool.tool_spec["name"], + }, }, }, }, - }, - {"messageStop": 
{"stopReason": "tool_use"}}, - ], + {"messageStop": {"stopReason": "tool_use"}}, + ] + ), ] override_system_prompt = "Override system prompt" @@ -383,7 +404,7 @@ def test_agent__call__passes_kwargs(mock_model, system_prompt, callback_handler, override_messages = [{"role": "user", "content": [{"text": "override msg"}]}] override_tool_config = {"test": "config"} - def check_kwargs(**kwargs): + async def check_kwargs(**kwargs): kwargs_kwargs = kwargs["kwargs"] assert kwargs_kwargs["some_value"] == "a_value" assert kwargs_kwargs["system_prompt"] == override_system_prompt @@ -415,7 +436,7 @@ def check_kwargs(**kwargs): mock_event_loop_cycle.assert_called_once() -def test_agent__call__retry_with_reduced_context(mock_model, agent, tool): +def test_agent__call__retry_with_reduced_context(mock_model, agent, tool, agenerator): conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager) agent.conversation_manager = conversation_manager_spy @@ -435,14 +456,16 @@ def test_agent__call__retry_with_reduced_context(mock_model, agent, tool): mock_model.mock_converse.side_effect = [ ContextWindowOverflowException(RuntimeError("Input is too long for requested model")), - [ - { - "contentBlockStart": {"start": {}}, - }, - {"contentBlockDelta": {"delta": {"text": "Green!"}}}, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "end_turn"}}, - ], + agenerator( + [ + { + "contentBlockStart": {"start": {}}, + }, + {"contentBlockDelta": {"delta": {"text": "Green!"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ), ] agent("And now?") @@ -542,7 +565,7 @@ def test_agent__call__tool_truncation_doesnt_infinite_loop(mock_model, agent): agent("Test!") -def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool): +def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agenerator): conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager) agent.conversation_manager = conversation_manager_spy @@ -556,26 +579,28 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool): agent.messages = messages mock_model.mock_converse.side_effect = [ - [ - { - "contentBlockStart": { - "start": { - "toolUse": { - "toolUseId": "t1", - "name": tool.tool_spec["name"], + agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "t1", + "name": tool.tool_spec["name"], + }, }, }, }, - }, - {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"random_string": "abcdEfghI123"}'}}}}, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "tool_use"}}, - ], + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"random_string": "abcdEfghI123"}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ] + ), # Will truncate the tool result ContextWindowOverflowException(RuntimeError("Input is too long for requested model")), # Will reduce the context ContextWindowOverflowException(RuntimeError("Input is too long for requested model")), - [], + agenerator([]), ] agent("test message") @@ -612,22 +637,24 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool): assert conversation_manager_spy.apply_management.call_count == 1 -def test_agent__call__invalid_tool_use_event_loop_exception(mock_model, agent, tool): +def test_agent__call__invalid_tool_use_event_loop_exception(mock_model, agent, tool, agenerator): mock_model.mock_converse.side_effect = [ - [ - { - "contentBlockStart": { - "start": { - "toolUse": { - "toolUseId": "t1", 
- "name": tool.tool_spec["name"], + agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "t1", + "name": tool.tool_spec["name"], + }, }, }, }, - }, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "tool_use"}}, - ], + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ] + ), RuntimeError, ] @@ -635,19 +662,21 @@ def test_agent__call__invalid_tool_use_event_loop_exception(mock_model, agent, t agent("test message") -def test_agent__call__callback(mock_model, agent, callback_handler): - mock_model.mock_converse.return_value = [ - {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}, - {"contentBlockStop": {}}, - {"contentBlockStart": {"start": {}}}, - {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}, - {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}, - {"contentBlockStop": {}}, - {"contentBlockStart": {"start": {}}}, - {"contentBlockDelta": {"delta": {"text": "value"}}}, - {"contentBlockStop": {}}, - ] +def test_agent__call__callback(mock_model, agent, callback_handler, agenerator): + mock_model.mock_converse.return_value = agenerator( + [ + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "value"}}}, + {"contentBlockStop": {}}, + ] + ) agent("test") @@ -720,6 +749,46 @@ def test_agent__call__callback(mock_model, agent, callback_handler): ) +@pytest.mark.asyncio +async def test_agent__call__in_async_context(mock_model, agent, agenerator): + mock_model.mock_converse.return_value = agenerator( + [ + { + "contentBlockStart": {"start": {}}, + }, + {"contentBlockDelta": {"delta": {"text": "abc"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ) + + result = agent("test") + + tru_message = result.message + exp_message = {"content": [{"text": "abc"}], "role": "assistant"} + assert tru_message == exp_message + + +@pytest.mark.asyncio +async def test_agent_invoke_async(mock_model, agent, agenerator): + mock_model.mock_converse.return_value = agenerator( + [ + { + "contentBlockStart": {"start": {}}, + }, + {"contentBlockDelta": {"delta": {"text": "abc"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ) + + result = await agent.invoke_async("test") + + tru_message = result.message + exp_message = {"content": [{"text": "abc"}], "role": "assistant"} + assert tru_message == exp_message + + @unittest.mock.patch("strands.experimental.hooks.registry.HookRegistry.invoke_callbacks") def test_agent_hooks__init__(mock_invoke_callbacks): """Verify that the AgentInitializedEvent is emitted on Agent construction.""" @@ -752,16 +821,6 @@ async def test_agent_hooks_stream_async(agent, mock_hook_messages, hook_provider assert hook_provider.events_received == [StartRequestEvent(agent=agent), EndRequestEvent(agent=agent)] -def test_agent_hooks_structured_output(agent, mock_hook_messages, hook_provider): - """Verify that the correct hook events are emitted as part of 
structured_output.""" - - expected_user = User(name="Jane Doe", age=30, email="jane@doe.com") - agent.model.structured_output = unittest.mock.Mock(return_value=[{"output": expected_user}]) - agent.structured_output(User, "example prompt") - - assert hook_provider.events_received == [StartRequestEvent(agent=agent), EndRequestEvent(agent=agent)] - - def test_agent_tool(mock_randint, agent): conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager) agent.conversation_manager = conversation_manager_spy @@ -956,35 +1015,55 @@ def test_agent_callback_handler_custom_handler_used(): assert agent.callback_handler is custom_handler -# mock the User(name='Jane Doe', age=30, email='jane@doe.com') -class User(BaseModel): - """A user of the system.""" +def test_agent_structured_output(agent, user, agenerator): + agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}])) + + prompt = "Jane Doe is 30 years old and her email is jane@doe.com" + + tru_result = agent.structured_output(type(user), prompt) + exp_result = user + assert tru_result == exp_result + + agent.model.structured_output.assert_called_once_with(type(user), [{"role": "user", "content": [{"text": prompt}]}]) + + +@pytest.mark.asyncio +async def test_agent_structured_output_in_async_context(agent, user, agenerator): + agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}])) - name: str - age: int - email: str + prompt = "Jane Doe is 30 years old and her email is jane@doe.com" + tru_result = await agent.structured_output_async(type(user), prompt) + exp_result = user + assert tru_result == exp_result -def test_agent_method_structured_output(agent): - # Mock the structured_output method on the model - expected_user = User(name="Jane Doe", age=30, email="jane@doe.com") - agent.model.structured_output = unittest.mock.Mock(return_value=[{"output": expected_user}]) + +@pytest.mark.asyncio +async def test_agent_structured_output_async(agent, user, agenerator): + agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}])) prompt = "Jane Doe is 30 years old and her email is jane@doe.com" - result = agent.structured_output(User, prompt) - assert result == expected_user + tru_result = agent.structured_output(type(user), prompt) + exp_result = user + assert tru_result == exp_result + + agent.model.structured_output.assert_called_once_with(type(user), [{"role": "user", "content": [{"text": prompt}]}]) - # Verify the model's structured_output was called with correct arguments - agent.model.structured_output.assert_called_once_with(User, [{"role": "user", "content": [{"text": prompt}]}]) + +def test_agent_hooks_structured_output(agent, user, mock_hook_messages, hook_provider, agenerator): + agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}])) + agent.structured_output(type(user), "example prompt") + + assert hook_provider.events_received == [StartRequestEvent(agent=agent), EndRequestEvent(agent=agent)] @pytest.mark.asyncio -async def test_stream_async_returns_all_events(mock_event_loop_cycle): +async def test_stream_async_returns_all_events(mock_event_loop_cycle, alist): agent = Agent() # Define the side effect to simulate callback handler being called multiple times - def test_event_loop(*args, **kwargs): + async def test_event_loop(*args, **kwargs): yield {"callback": {"data": "First chunk"}} yield {"callback": {"data": "Second chunk"}} yield {"callback": {"data": "Final chunk", 
"complete": True}} @@ -995,14 +1074,22 @@ def test_event_loop(*args, **kwargs): mock_event_loop_cycle.side_effect = test_event_loop mock_callback = unittest.mock.Mock() - iterator = agent.stream_async("test message", callback_handler=mock_callback) + stream = agent.stream_async("test message", callback_handler=mock_callback) - tru_events = [e async for e in iterator] + tru_events = await alist(stream) exp_events = [ {"init_event_loop": True, "callback_handler": mock_callback}, {"data": "First chunk"}, {"data": "Second chunk"}, {"complete": True, "data": "Final chunk"}, + { + "result": AgentResult( + stop_reason="stop", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics={}, + state={}, + ), + }, ] assert tru_events == exp_events @@ -1011,24 +1098,26 @@ def test_event_loop(*args, **kwargs): @pytest.mark.asyncio -async def test_stream_async_passes_kwargs(agent, mock_model, mock_event_loop_cycle): +async def test_stream_async_passes_kwargs(agent, mock_model, mock_event_loop_cycle, agenerator, alist): mock_model.mock_converse.side_effect = [ - [ - { - "contentBlockStart": { - "start": { - "toolUse": { - "toolUseId": "t1", - "name": "a_tool", + agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "t1", + "name": "a_tool", + }, }, }, }, - }, - {"messageStop": {"stopReason": "tool_use"}}, - ], + {"messageStop": {"stopReason": "tool_use"}}, + ] + ), ] - def check_kwargs(**kwargs): + async def check_kwargs(**kwargs): kwargs_kwargs = kwargs["kwargs"] assert kwargs_kwargs["some_value"] == "a_value" # Return expected values from event_loop_cycle @@ -1036,10 +1125,22 @@ def check_kwargs(**kwargs): mock_event_loop_cycle.side_effect = check_kwargs - iterator = agent.stream_async("test message", some_value="a_value") - actual_events = [e async for e in iterator] + stream = agent.stream_async("test message", some_value="a_value") + + tru_events = await alist(stream) + exp_events = [ + {"init_event_loop": True, "some_value": "a_value"}, + { + "result": AgentResult( + stop_reason="stop", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics={}, + state={}, + ), + }, + ] + assert tru_events == exp_events - assert actual_events == [{"init_event_loop": True, "some_value": "a_value"}] assert mock_event_loop_cycle.call_count == 1 @@ -1048,11 +1149,11 @@ async def test_stream_async_raises_exceptions(mock_event_loop_cycle): mock_event_loop_cycle.side_effect = ValueError("Test exception") agent = Agent() - iterator = agent.stream_async("test message") + stream = agent.stream_async("test message") - await anext(iterator) + await anext(stream) with pytest.raises(ValueError, match="Test exception"): - await anext(iterator) + await anext(stream) def test_agent_init_with_trace_attributes(): @@ -1105,7 +1206,7 @@ def test_agent_init_initializes_tracer(mock_get_tracer): @unittest.mock.patch("strands.agent.agent.get_tracer") -def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model): +def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model, agenerator): """Test that __call__ creates and ends a span when the call succeeds.""" # Setup mock tracer and span mock_tracer = unittest.mock.MagicMock() @@ -1115,10 +1216,12 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model # Setup mock model response mock_model.mock_converse.side_effect = [ - [ - {"contentBlockDelta": {"delta": {"text": "test response"}}}, - {"contentBlockStop": {}}, - ], + agenerator( + [ + 
{"contentBlockDelta": {"delta": {"text": "test response"}}}, + {"contentBlockStop": {}}, + ] + ), ] # Create agent and make a call @@ -1140,7 +1243,7 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model @pytest.mark.asyncio @unittest.mock.patch("strands.agent.agent.get_tracer") -async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_tracer, mock_event_loop_cycle): +async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_tracer, mock_event_loop_cycle, alist): """Test that stream_async creates and ends a span when the call succeeds.""" # Setup mock tracer and span mock_tracer = unittest.mock.MagicMock() @@ -1148,16 +1251,15 @@ async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_trac mock_tracer.start_agent_span.return_value = mock_span mock_get_tracer.return_value = mock_tracer - def test_event_loop(*args, **kwargs): + async def test_event_loop(*args, **kwargs): yield {"stop": ("stop", {"role": "assistant", "content": [{"text": "Agent Response"}]}, {}, {})} mock_event_loop_cycle.side_effect = test_event_loop # Create agent and make a call agent = Agent(model=mock_model) - iterator = agent.stream_async("test prompt") - async for _event in iterator: - pass # NoOp + stream = agent.stream_async("test prompt") + await alist(stream) # Verify span was created mock_tracer.start_agent_span.assert_called_once_with( @@ -1211,7 +1313,7 @@ def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_mod @pytest.mark.asyncio @unittest.mock.patch("strands.agent.agent.get_tracer") -async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tracer, mock_model): +async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tracer, mock_model, alist): """Test that stream_async creates and ends a span when the call succeeds.""" # Setup mock tracer and span mock_tracer = unittest.mock.MagicMock() @@ -1228,9 +1330,8 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr # Call the agent and catch the exception with pytest.raises(ValueError): - iterator = agent.stream_async("test prompt") - async for _event in iterator: - pass # NoOp + stream = agent.stream_async("test prompt") + await alist(stream) # Verify span was created mock_tracer.start_agent_span.assert_called_once_with( @@ -1246,7 +1347,7 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr @unittest.mock.patch("strands.agent.agent.get_tracer") -def test_event_loop_cycle_includes_parent_span(mock_get_tracer, mock_event_loop_cycle, mock_model): +def test_event_loop_cycle_includes_parent_span(mock_get_tracer, mock_event_loop_cycle, mock_model, agenerator): """Test that event_loop_cycle is called with the parent span.""" # Setup mock tracer and span mock_tracer = unittest.mock.MagicMock() @@ -1255,9 +1356,9 @@ def test_event_loop_cycle_includes_parent_span(mock_get_tracer, mock_event_loop_ mock_get_tracer.return_value = mock_tracer # Setup mock for event_loop_cycle - mock_event_loop_cycle.return_value = [ - {"stop": ("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})} - ] + mock_event_loop_cycle.return_value = agenerator( + [{"stop": ("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})}] + ) # Create agent and make a call agent = Agent(model=mock_model) diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index f07f0d27a..291b7be30 100644 --- 
a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -99,18 +99,23 @@ def mock_tracer(): return tracer -def test_event_loop_cycle_text_response( +@pytest.mark.asyncio +async def test_event_loop_cycle_text_response( model, system_prompt, messages, tool_config, tool_handler, thread_pool, + agenerator, + alist, ): - model.converse.return_value = [ - {"contentBlockDelta": {"delta": {"text": "test text"}}}, - {"contentBlockStop": {}}, - ] + model.converse.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "test text"}}}, + {"contentBlockStop": {}}, + ] + ) stream = strands.event_loop.event_loop.event_loop_cycle( model=model, @@ -123,8 +128,8 @@ def test_event_loop_cycle_text_response( event_loop_parent_span=None, kwargs={}, ) - event = list(stream)[-1] - tru_stop_reason, tru_message, _, tru_request_state = event["stop"] + events = await alist(stream) + tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": "assistant", "content": [{"text": "test text"}]} @@ -133,7 +138,8 @@ def test_event_loop_cycle_text_response( assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state -def test_event_loop_cycle_text_response_throttling( +@pytest.mark.asyncio +async def test_event_loop_cycle_text_response_throttling( mock_time, model, system_prompt, @@ -141,13 +147,17 @@ def test_event_loop_cycle_text_response_throttling( tool_config, tool_handler, thread_pool, + agenerator, + alist, ): model.converse.side_effect = [ ModelThrottledException("ThrottlingException | ConverseStream"), - [ - {"contentBlockDelta": {"delta": {"text": "test text"}}}, - {"contentBlockStop": {}}, - ], + agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "test text"}}}, + {"contentBlockStop": {}}, + ] + ), ] stream = strands.event_loop.event_loop.event_loop_cycle( @@ -161,8 +171,8 @@ def test_event_loop_cycle_text_response_throttling( event_loop_parent_span=None, kwargs={}, ) - event = list(stream)[-1] - tru_stop_reason, tru_message, _, tru_request_state = event["stop"] + events = await alist(stream) + tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": "assistant", "content": [{"text": "test text"}]} @@ -173,7 +183,8 @@ def test_event_loop_cycle_text_response_throttling( mock_time.sleep.assert_called_once() -def test_event_loop_cycle_exponential_backoff( +@pytest.mark.asyncio +async def test_event_loop_cycle_exponential_backoff( mock_time, model, system_prompt, @@ -181,6 +192,8 @@ def test_event_loop_cycle_exponential_backoff( tool_config, tool_handler, thread_pool, + agenerator, + alist, ): """Test that the exponential backoff works correctly with multiple retries.""" # Set up the model to raise throttling exceptions multiple times before succeeding @@ -188,10 +201,12 @@ def test_event_loop_cycle_exponential_backoff( ModelThrottledException("ThrottlingException | ConverseStream"), ModelThrottledException("ThrottlingException | ConverseStream"), ModelThrottledException("ThrottlingException | ConverseStream"), - [ - {"contentBlockDelta": {"delta": {"text": "test text"}}}, - {"contentBlockStop": {}}, - ], + agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "test text"}}}, + {"contentBlockStop": {}}, + ] + ), ] stream = strands.event_loop.event_loop.event_loop_cycle( @@ -205,8 +220,8 @@ def test_event_loop_cycle_exponential_backoff( 
event_loop_parent_span=None, kwargs={}, ) - event = list(stream)[-1] - tru_stop_reason, tru_message, _, tru_request_state = event["stop"] + events = await alist(stream) + tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] # Verify the final response assert tru_stop_reason == "end_turn" @@ -219,7 +234,8 @@ def test_event_loop_cycle_exponential_backoff( assert mock_time.sleep.call_args_list == [call(4), call(8), call(16)] -def test_event_loop_cycle_text_response_throttling_exceeded( +@pytest.mark.asyncio +async def test_event_loop_cycle_text_response_throttling_exceeded( mock_time, model, system_prompt, @@ -227,6 +243,7 @@ def test_event_loop_cycle_text_response_throttling_exceeded( tool_config, tool_handler, thread_pool, + alist, ): model.converse.side_effect = [ ModelThrottledException("ThrottlingException | ConverseStream"), @@ -249,7 +266,7 @@ def test_event_loop_cycle_text_response_throttling_exceeded( event_loop_parent_span=None, kwargs={}, ) - list(stream) + await alist(stream) mock_time.sleep.assert_has_calls( [ @@ -262,13 +279,15 @@ def test_event_loop_cycle_text_response_throttling_exceeded( ) -def test_event_loop_cycle_text_response_error( +@pytest.mark.asyncio +async def test_event_loop_cycle_text_response_error( model, system_prompt, messages, tool_config, tool_handler, thread_pool, + alist, ): model.converse.side_effect = RuntimeError("Unhandled error") @@ -284,10 +303,11 @@ def test_event_loop_cycle_text_response_error( event_loop_parent_span=None, kwargs={}, ) - list(stream) + await alist(stream) -def test_event_loop_cycle_tool_result( +@pytest.mark.asyncio +async def test_event_loop_cycle_tool_result( model, system_prompt, messages, @@ -295,13 +315,17 @@ def test_event_loop_cycle_tool_result( tool_handler, thread_pool, tool_stream, + agenerator, + alist, ): model.converse.side_effect = [ - tool_stream, - [ - {"contentBlockDelta": {"delta": {"text": "test text"}}}, - {"contentBlockStop": {}}, - ], + agenerator(tool_stream), + agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "test text"}}}, + {"contentBlockStop": {}}, + ] + ), ] stream = strands.event_loop.event_loop.event_loop_cycle( @@ -315,8 +339,8 @@ def test_event_loop_cycle_tool_result( event_loop_parent_span=None, kwargs={}, ) - event = list(stream)[-1] - tru_stop_reason, tru_message, _, tru_request_state = event["stop"] + events = await alist(stream) + tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": "assistant", "content": [{"text": "test text"}]} @@ -358,7 +382,8 @@ def test_event_loop_cycle_tool_result( ) -def test_event_loop_cycle_tool_result_error( +@pytest.mark.asyncio +async def test_event_loop_cycle_tool_result_error( model, system_prompt, messages, @@ -366,8 +391,10 @@ def test_event_loop_cycle_tool_result_error( tool_handler, thread_pool, tool_stream, + agenerator, + alist, ): - model.converse.side_effect = [tool_stream] + model.converse.side_effect = [agenerator(tool_stream)] with pytest.raises(EventLoopException): stream = strands.event_loop.event_loop.event_loop_cycle( @@ -381,18 +408,21 @@ def test_event_loop_cycle_tool_result_error( event_loop_parent_span=None, kwargs={}, ) - list(stream) + await alist(stream) -def test_event_loop_cycle_tool_result_no_tool_handler( +@pytest.mark.asyncio +async def test_event_loop_cycle_tool_result_no_tool_handler( model, system_prompt, messages, tool_config, thread_pool, tool_stream, + agenerator, + alist, ): - model.converse.side_effect = [tool_stream] + 
model.converse.side_effect = [agenerator(tool_stream)] with pytest.raises(EventLoopException): stream = strands.event_loop.event_loop.event_loop_cycle( @@ -406,18 +436,21 @@ def test_event_loop_cycle_tool_result_no_tool_handler( event_loop_parent_span=None, kwargs={}, ) - list(stream) + await alist(stream) -def test_event_loop_cycle_tool_result_no_tool_config( +@pytest.mark.asyncio +async def test_event_loop_cycle_tool_result_no_tool_config( model, system_prompt, messages, tool_handler, thread_pool, tool_stream, + agenerator, + alist, ): - model.converse.side_effect = [tool_stream] + model.converse.side_effect = [agenerator(tool_stream)] with pytest.raises(EventLoopException): stream = strands.event_loop.event_loop.event_loop_cycle( @@ -431,10 +464,11 @@ def test_event_loop_cycle_tool_result_no_tool_config( event_loop_parent_span=None, kwargs={}, ) - list(stream) + await alist(stream) -def test_event_loop_cycle_stop( +@pytest.mark.asyncio +async def test_event_loop_cycle_stop( model, system_prompt, messages, @@ -442,22 +476,26 @@ def test_event_loop_cycle_stop( tool_handler, thread_pool, tool, + agenerator, + alist, ): model.converse.side_effect = [ - [ - { - "contentBlockStart": { - "start": { - "toolUse": { - "toolUseId": "t1", - "name": tool.tool_spec["name"], + agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "t1", + "name": tool.tool_spec["name"], + }, }, }, }, - }, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "tool_use"}}, - ], + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ] + ), ] stream = strands.event_loop.event_loop.event_loop_cycle( @@ -471,8 +509,8 @@ def test_event_loop_cycle_stop( event_loop_parent_span=None, kwargs={"request_state": {"stop_event_loop": True}}, ) - event = list(stream)[-1] - tru_stop_reason, tru_message, _, tru_request_state = event["stop"] + events = await alist(stream) + tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] exp_stop_reason = "tool_use" exp_message = { @@ -492,7 +530,8 @@ def test_event_loop_cycle_stop( assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state -def test_cycle_exception( +@pytest.mark.asyncio +async def test_cycle_exception( model, system_prompt, messages, @@ -500,8 +539,14 @@ def test_cycle_exception( tool_handler, thread_pool, tool_stream, + agenerator, ): - model.converse.side_effect = [tool_stream, tool_stream, tool_stream, ValueError("Invalid error presented")] + model.converse.side_effect = [ + agenerator(tool_stream), + agenerator(tool_stream), + agenerator(tool_stream), + ValueError("Invalid error presented"), + ] tru_stop_event = None exp_stop_event = {"callback": {"force_stop": True, "force_stop_reason": "Invalid error presented"}} @@ -518,14 +563,15 @@ def test_cycle_exception( event_loop_parent_span=None, kwargs={}, ) - for event in stream: + async for event in stream: tru_stop_event = event assert tru_stop_event == exp_stop_event @patch("strands.event_loop.event_loop.get_tracer") -def test_event_loop_cycle_creates_spans( +@pytest.mark.asyncio +async def test_event_loop_cycle_creates_spans( mock_get_tracer, model, system_prompt, @@ -534,6 +580,8 @@ def test_event_loop_cycle_creates_spans( tool_handler, thread_pool, mock_tracer, + agenerator, + alist, ): # Setup mock_get_tracer.return_value = mock_tracer @@ -542,10 +590,12 @@ def test_event_loop_cycle_creates_spans( model_span = MagicMock() mock_tracer.start_model_invoke_span.return_value = model_span - 
model.converse.return_value = [ - {"contentBlockDelta": {"delta": {"text": "test text"}}}, - {"contentBlockStop": {}}, - ] + model.converse.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "test text"}}}, + {"contentBlockStop": {}}, + ] + ) # Call event_loop_cycle stream = strands.event_loop.event_loop.event_loop_cycle( @@ -559,7 +609,7 @@ def test_event_loop_cycle_creates_spans( event_loop_parent_span=None, kwargs={}, ) - list(stream) + await alist(stream) # Verify tracer methods were called correctly mock_get_tracer.assert_called_once() @@ -570,7 +620,8 @@ def test_event_loop_cycle_creates_spans( @patch("strands.event_loop.event_loop.get_tracer") -def test_event_loop_tracing_with_model_error( +@pytest.mark.asyncio +async def test_event_loop_tracing_with_model_error( mock_get_tracer, model, system_prompt, @@ -579,6 +630,7 @@ def test_event_loop_tracing_with_model_error( tool_handler, thread_pool, mock_tracer, + alist, ): # Setup mock_get_tracer.return_value = mock_tracer @@ -603,14 +655,15 @@ def test_event_loop_tracing_with_model_error( event_loop_parent_span=None, kwargs={}, ) - list(stream) + await alist(stream) # Verify error handling span methods were called mock_tracer.end_span_with_error.assert_called_once_with(model_span, "Input too long", model.converse.side_effect) @patch("strands.event_loop.event_loop.get_tracer") -def test_event_loop_tracing_with_tool_execution( +@pytest.mark.asyncio +async def test_event_loop_tracing_with_tool_execution( mock_get_tracer, model, system_prompt, @@ -620,6 +673,8 @@ def test_event_loop_tracing_with_tool_execution( thread_pool, tool_stream, mock_tracer, + agenerator, + alist, ): # Setup mock_get_tracer.return_value = mock_tracer @@ -630,11 +685,13 @@ def test_event_loop_tracing_with_tool_execution( # Set up model to return tool use and then text response model.converse.side_effect = [ - tool_stream, - [ - {"contentBlockDelta": {"delta": {"text": "test text"}}}, - {"contentBlockStop": {}}, - ], + agenerator(tool_stream), + agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "test text"}}}, + {"contentBlockStop": {}}, + ] + ), ] # Call event_loop_cycle which should execute a tool @@ -649,7 +706,7 @@ def test_event_loop_tracing_with_tool_execution( event_loop_parent_span=None, kwargs={}, ) - list(stream) + await alist(stream) # Verify the parent_span parameter is passed to run_tools # At a minimum, verify both model spans were created (one for each model invocation) @@ -658,7 +715,8 @@ def test_event_loop_tracing_with_tool_execution( @patch("strands.event_loop.event_loop.get_tracer") -def test_event_loop_tracing_with_throttling_exception( +@pytest.mark.asyncio +async def test_event_loop_tracing_with_throttling_exception( mock_get_tracer, model, system_prompt, @@ -667,6 +725,8 @@ def test_event_loop_tracing_with_throttling_exception( tool_handler, thread_pool, mock_tracer, + agenerator, + alist, ): # Setup mock_get_tracer.return_value = mock_tracer @@ -678,10 +738,12 @@ def test_event_loop_tracing_with_throttling_exception( # Set up model to raise a throttling exception and then succeed model.converse.side_effect = [ ModelThrottledException("Throttling Error"), - [ - {"contentBlockDelta": {"delta": {"text": "test text"}}}, - {"contentBlockStop": {}}, - ], + agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "test text"}}}, + {"contentBlockStop": {}}, + ] + ), ] # Mock the time.sleep function to speed up the test @@ -697,7 +759,7 @@ def test_event_loop_tracing_with_throttling_exception( 
event_loop_parent_span=None, kwargs={}, ) - list(stream) + await alist(stream) # Verify error span was created for the throttling exception assert mock_tracer.end_span_with_error.call_count == 1 @@ -707,7 +769,8 @@ def test_event_loop_tracing_with_throttling_exception( @patch("strands.event_loop.event_loop.get_tracer") -def test_event_loop_cycle_with_parent_span( +@pytest.mark.asyncio +async def test_event_loop_cycle_with_parent_span( mock_get_tracer, model, system_prompt, @@ -716,6 +779,8 @@ def test_event_loop_cycle_with_parent_span( tool_handler, thread_pool, mock_tracer, + agenerator, + alist, ): # Setup mock_get_tracer.return_value = mock_tracer @@ -723,10 +788,12 @@ def test_event_loop_cycle_with_parent_span( cycle_span = MagicMock() mock_tracer.start_event_loop_cycle_span.return_value = cycle_span - model.converse.return_value = [ - {"contentBlockDelta": {"delta": {"text": "test text"}}}, - {"contentBlockStop": {}}, - ] + model.converse.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "test text"}}}, + {"contentBlockStop": {}}, + ] + ) # Call event_loop_cycle with a parent span stream = strands.event_loop.event_loop.event_loop_cycle( @@ -740,7 +807,7 @@ def test_event_loop_cycle_with_parent_span( event_loop_parent_span=parent_span, kwargs={}, ) - list(stream) + await alist(stream) # Verify parent_span was used when creating cycle span mock_tracer.start_event_loop_cycle_span.assert_called_once_with( @@ -748,7 +815,8 @@ def test_event_loop_cycle_with_parent_span( ) -def test_request_state_initialization(): +@pytest.mark.asyncio +async def test_request_state_initialization(alist): # Call without providing request_state stream = strands.event_loop.event_loop.event_loop_cycle( model=MagicMock(), @@ -761,8 +829,8 @@ def test_request_state_initialization(): event_loop_parent_span=None, kwargs={}, ) - event = list(stream)[-1] - _, _, _, tru_request_state = event["stop"] + events = await alist(stream) + _, _, _, tru_request_state = events[-1]["stop"] # Verify request_state was initialized to empty dict assert tru_request_state == {} @@ -780,33 +848,38 @@ def test_request_state_initialization(): event_loop_parent_span=None, kwargs={"request_state": initial_request_state}, ) - event = list(stream)[-1] - _, _, _, tru_request_state = event["stop"] + events = await alist(stream) + _, _, _, tru_request_state = events[-1]["stop"] # Verify existing request_state was preserved assert tru_request_state == initial_request_state -def test_prepare_next_cycle_in_tool_execution(model, tool_stream): +@pytest.mark.asyncio +async def test_prepare_next_cycle_in_tool_execution(model, tool_stream, agenerator, alist): """Test that cycle ID and metrics are properly updated during tool execution.""" model.converse.side_effect = [ - tool_stream, - [ - {"contentBlockStop": {}}, - ], + agenerator(tool_stream), + agenerator( + [ + {"contentBlockStop": {}}, + ] + ), ] # Create a mock for recurse_event_loop to capture the kwargs passed to it with unittest.mock.patch.object(strands.event_loop.event_loop, "recurse_event_loop") as mock_recurse: # Set up mock to return a valid response - mock_recurse.side_effect = [ - ( - "end_turn", - {"role": "assistant", "content": [{"text": "test text"}]}, - strands.telemetry.metrics.EventLoopMetrics(), - {}, - ), - ] + mock_recurse.return_value = agenerator( + [ + ( + "end_turn", + {"role": "assistant", "content": [{"text": "test text"}]}, + strands.telemetry.metrics.EventLoopMetrics(), + {}, + ), + ] + ) # Call event_loop_cycle which should execute a tool and 
then call recurse_event_loop stream = strands.event_loop.event_loop.event_loop_cycle( @@ -820,7 +893,7 @@ def test_prepare_next_cycle_in_tool_execution(model, tool_stream): event_loop_parent_span=None, kwargs={}, ) - list(stream) + await alist(stream) assert mock_recurse.called diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index e91f49867..7b64264e3 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -526,20 +526,24 @@ def test_extract_usage_metrics(): ), ], ) -def test_process_stream(response, exp_events): +@pytest.mark.asyncio +async def test_process_stream(response, exp_events, agenerator, alist): messages = [{"role": "user", "content": [{"text": "Some input!"}]}] - stream = strands.event_loop.streaming.process_stream(response, messages) + stream = strands.event_loop.streaming.process_stream(agenerator(response), messages) - tru_events = list(stream) + tru_events = await alist(stream) assert tru_events == exp_events -def test_stream_messages(): +@pytest.mark.asyncio +async def test_stream_messages(agenerator, alist): mock_model = unittest.mock.MagicMock() - mock_model.converse.return_value = [ - {"contentBlockDelta": {"delta": {"text": "test"}}}, - {"contentBlockStop": {}}, - ] + mock_model.converse.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + ] + ) stream = strands.event_loop.streaming.stream_messages( mock_model, @@ -548,7 +552,7 @@ def test_stream_messages(): tool_config=None, ) - tru_events = list(stream) + tru_events = await alist(stream) exp_events = [ { "callback": { diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 203352151..66046b7a8 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -624,7 +624,8 @@ def test_format_chunk_unknown(model): model.format_chunk(event) -def test_stream(anthropic_client, model): +@pytest.mark.asyncio +async def test_stream(anthropic_client, model, alist): mock_event_1 = unittest.mock.Mock( type="message_start", dict=lambda: {"type": "message_start"}, @@ -652,7 +653,7 @@ def test_stream(anthropic_client, model): request = {"model": "m1"} response = model.stream(request) - tru_events = list(response) + tru_events = await alist(response) exp_events = [ {"type": "message_start"}, { @@ -665,13 +666,14 @@ def test_stream(anthropic_client, model): anthropic_client.messages.stream.assert_called_once_with(**request) -def test_stream_rate_limit_error(anthropic_client, model): +@pytest.mark.asyncio +async def test_stream_rate_limit_error(anthropic_client, model, alist): anthropic_client.messages.stream.side_effect = anthropic.RateLimitError( "rate limit", response=unittest.mock.Mock(), body=None ) with pytest.raises(ModelThrottledException, match="rate limit"): - next(model.stream({})) + await alist(model.stream({})) @pytest.mark.parametrize( @@ -682,25 +684,28 @@ def test_stream_rate_limit_error(anthropic_client, model): "...input and output tokens exceed your context limit...", ], ) -def test_stream_bad_request_overflow_error(overflow_message, anthropic_client, model): +@pytest.mark.asyncio +async def test_stream_bad_request_overflow_error(overflow_message, anthropic_client, model): anthropic_client.messages.stream.side_effect = anthropic.BadRequestError( overflow_message, response=unittest.mock.Mock(), body=None ) with pytest.raises(ContextWindowOverflowException): - 
next(model.stream({})) + await anext(model.stream({})) -def test_stream_bad_request_error(anthropic_client, model): +@pytest.mark.asyncio +async def test_stream_bad_request_error(anthropic_client, model): anthropic_client.messages.stream.side_effect = anthropic.BadRequestError( "bad", response=unittest.mock.Mock(), body=None ) with pytest.raises(anthropic.BadRequestError, match="bad"): - next(model.stream({})) + await anext(model.stream({})) -def test_structured_output(anthropic_client, model, test_output_model_cls): +@pytest.mark.asyncio +async def test_structured_output(anthropic_client, model, test_output_model_cls, alist): messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] events = [ @@ -749,7 +754,8 @@ def test_structured_output(anthropic_client, model, test_output_model_cls): anthropic_client.messages.stream.return_value.__enter__.return_value = mock_stream stream = model.structured_output(test_output_model_cls, messages) + events = await alist(stream) - tru_result = list(stream)[-1] + tru_result = events[-1] exp_result = {"output": test_output_model_cls(name="John", age=30)} assert tru_result == exp_result diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index e6ade0dbb..e9fd9f34a 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -405,20 +405,22 @@ def test_format_chunk(model): assert tru_chunk == exp_chunk -def test_stream(bedrock_client, model): +@pytest.mark.asyncio +async def test_stream(bedrock_client, model, alist): bedrock_client.converse_stream.return_value = {"stream": ["e1", "e2"]} request = {"a": 1} response = model.stream(request) - tru_events = list(response) + tru_events = await alist(response) exp_events = ["e1", "e2"] assert tru_events == exp_events bedrock_client.converse_stream.assert_called_once_with(a=1) -def test_stream_throttling_exception_from_event_stream_error(bedrock_client, model): +@pytest.mark.asyncio +async def test_stream_throttling_exception_from_event_stream_error(bedrock_client, model, alist): error_message = "Rate exceeded" bedrock_client.converse_stream.side_effect = EventStreamError( {"Error": {"Message": error_message, "Code": "ThrottlingException"}}, "ConverseStream" @@ -427,13 +429,14 @@ def test_stream_throttling_exception_from_event_stream_error(bedrock_client, mod request = {"a": 1} with pytest.raises(ModelThrottledException) as excinfo: - list(model.stream(request)) + await alist(model.stream(request)) assert error_message in str(excinfo.value) bedrock_client.converse_stream.assert_called_once_with(a=1) -def test_stream_throttling_exception_from_general_exception(bedrock_client, model): +@pytest.mark.asyncio +async def test_stream_throttling_exception_from_general_exception(bedrock_client, model, alist): error_message = "ThrottlingException: Rate exceeded for ConverseStream" bedrock_client.converse_stream.side_effect = ClientError( {"Error": {"Message": error_message, "Code": "ThrottlingException"}}, "Any" @@ -442,26 +445,28 @@ def test_stream_throttling_exception_from_general_exception(bedrock_client, mode request = {"a": 1} with pytest.raises(ModelThrottledException) as excinfo: - list(model.stream(request)) + await alist(model.stream(request)) assert error_message in str(excinfo.value) bedrock_client.converse_stream.assert_called_once_with(a=1) -def test_general_exception_is_raised(bedrock_client, model): +@pytest.mark.asyncio +async def test_general_exception_is_raised(bedrock_client, model, alist): error_message = "Should be 
raised up" bedrock_client.converse_stream.side_effect = ValueError(error_message) request = {"a": 1} with pytest.raises(ValueError) as excinfo: - list(model.stream(request)) + await alist(model.stream(request)) assert error_message in str(excinfo.value) bedrock_client.converse_stream.assert_called_once_with(a=1) -def test_converse(bedrock_client, model, messages, tool_spec, model_id, additional_request_fields): +@pytest.mark.asyncio +async def test_converse(bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist): bedrock_client.converse_stream.return_value = {"stream": ["e1", "e2"]} request = { @@ -477,17 +482,18 @@ def test_converse(bedrock_client, model, messages, tool_spec, model_id, addition } model.update_config(additional_request_fields=additional_request_fields) - chunks = model.converse(messages, [tool_spec]) + response = model.converse(messages, [tool_spec]) - tru_chunks = list(chunks) + tru_chunks = await alist(response) exp_chunks = ["e1", "e2"] assert tru_chunks == exp_chunks bedrock_client.converse_stream.assert_called_once_with(**request) -def test_converse_stream_input_guardrails( - bedrock_client, model, messages, tool_spec, model_id, additional_request_fields +@pytest.mark.asyncio +async def test_converse_stream_input_guardrails( + bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist ): metadata_event = { "metadata": { @@ -527,9 +533,9 @@ def test_converse_stream_input_guardrails( } model.update_config(additional_request_fields=additional_request_fields) - chunks = model.converse(messages, [tool_spec]) + response = model.converse(messages, [tool_spec]) - tru_chunks = list(chunks) + tru_chunks = await alist(response) exp_chunks = [ {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, metadata_event, @@ -539,8 +545,9 @@ def test_converse_stream_input_guardrails( bedrock_client.converse_stream.assert_called_once_with(**request) -def test_converse_stream_output_guardrails( - bedrock_client, model, messages, tool_spec, model_id, additional_request_fields +@pytest.mark.asyncio +async def test_converse_stream_output_guardrails( + bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist ): model.update_config(guardrail_redact_input=False, guardrail_redact_output=True) metadata_event = { @@ -583,9 +590,9 @@ def test_converse_stream_output_guardrails( } model.update_config(additional_request_fields=additional_request_fields) - chunks = model.converse(messages, [tool_spec]) + response = model.converse(messages, [tool_spec]) - tru_chunks = list(chunks) + tru_chunks = await alist(response) exp_chunks = [ {"redactContent": {"redactAssistantContentMessage": "[Assistant output redacted.]"}}, metadata_event, @@ -595,8 +602,9 @@ def test_converse_stream_output_guardrails( bedrock_client.converse_stream.assert_called_once_with(**request) -def test_converse_output_guardrails_redacts_input_and_output( - bedrock_client, model, messages, tool_spec, model_id, additional_request_fields +@pytest.mark.asyncio +async def test_converse_output_guardrails_redacts_input_and_output( + bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist ): model.update_config(guardrail_redact_output=True) metadata_event = { @@ -639,9 +647,9 @@ def test_converse_output_guardrails_redacts_input_and_output( } model.update_config(additional_request_fields=additional_request_fields) - chunks = model.converse(messages, [tool_spec]) + response = model.converse(messages, 
[tool_spec]) - tru_chunks = list(chunks) + tru_chunks = await alist(response) exp_chunks = [ {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, {"redactContent": {"redactAssistantContentMessage": "[Assistant output redacted.]"}}, @@ -652,8 +660,9 @@ def test_converse_output_guardrails_redacts_input_and_output( bedrock_client.converse_stream.assert_called_once_with(**request) -def test_converse_output_no_blocked_guardrails_doesnt_redact( - bedrock_client, model, messages, tool_spec, model_id, additional_request_fields +@pytest.mark.asyncio +async def test_converse_output_no_blocked_guardrails_doesnt_redact( + bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist ): metadata_event = { "metadata": { @@ -695,17 +704,18 @@ def test_converse_output_no_blocked_guardrails_doesnt_redact( } model.update_config(additional_request_fields=additional_request_fields) - chunks = model.converse(messages, [tool_spec]) + response = model.converse(messages, [tool_spec]) - tru_chunks = list(chunks) + tru_chunks = await alist(response) exp_chunks = [metadata_event] assert tru_chunks == exp_chunks bedrock_client.converse_stream.assert_called_once_with(**request) -def test_converse_output_no_guardrail_redact( - bedrock_client, model, messages, tool_spec, model_id, additional_request_fields +@pytest.mark.asyncio +async def test_converse_output_no_guardrail_redact( + bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist ): metadata_event = { "metadata": { @@ -751,40 +761,43 @@ def test_converse_output_no_guardrail_redact( guardrail_redact_output=False, guardrail_redact_input=False, ) - chunks = model.converse(messages, [tool_spec]) + response = model.converse(messages, [tool_spec]) - tru_chunks = list(chunks) + tru_chunks = await alist(response) exp_chunks = [metadata_event] assert tru_chunks == exp_chunks bedrock_client.converse_stream.assert_called_once_with(**request) -def test_stream_with_streaming_false(bedrock_client): +@pytest.mark.asyncio +async def test_stream_with_streaming_false(bedrock_client, alist): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, "stopReason": "end_turn", } - expected_events = [ - {"messageStart": {"role": "assistant"}}, - {"contentBlockDelta": {"delta": {"text": "test"}}}, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, - ] # Create model and call stream model = BedrockModel(model_id="test-model", streaming=False) request = {"modelId": "test-model"} - events = list(model.stream(request)) + response = model.stream(request) - assert expected_events == events + tru_events = await alist(response) + exp_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, + ] + assert tru_events == exp_events bedrock_client.converse.assert_called_once() bedrock_client.converse_stream.assert_not_called() -def test_stream_with_streaming_false_and_tool_use(bedrock_client): +@pytest.mark.asyncio +async def test_stream_with_streaming_false_and_tool_use(bedrock_client, alist): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": { @@ -796,26 +809,27 @@ def test_stream_with_streaming_false_and_tool_use(bedrock_client): "stopReason": 
"tool_use", } - expected_events = [ + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + response = model.stream(request) + + tru_events = await alist(response) + exp_events = [ {"messageStart": {"role": "assistant"}}, {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "dummyTool"}}}}, {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"hello": "world!"}'}}}}, {"contentBlockStop": {}}, {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, ] - - # Create model and call stream - model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - events = list(model.stream(request)) - - assert expected_events == events + assert tru_events == exp_events bedrock_client.converse.assert_called_once() bedrock_client.converse_stream.assert_not_called() -def test_stream_with_streaming_false_and_reasoning(bedrock_client): +@pytest.mark.asyncio +async def test_stream_with_streaming_false_and_reasoning(bedrock_client, alist): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": { @@ -833,27 +847,28 @@ def test_stream_with_streaming_false_and_reasoning(bedrock_client): "stopReason": "tool_use", } - expected_events = [ + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + response = model.stream(request) + + tru_events = await alist(response) + exp_events = [ {"messageStart": {"role": "assistant"}}, {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Thinking really hard...."}}}}, {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "123"}}}}, {"contentBlockStop": {}}, {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, ] - - # Create model and call stream - model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - events = list(model.stream(request)) - - assert expected_events == events + assert tru_events == exp_events # Verify converse was called bedrock_client.converse.assert_called_once() bedrock_client.converse_stream.assert_not_called() -def test_converse_and_reasoning_no_signature(bedrock_client): +@pytest.mark.asyncio +async def test_converse_and_reasoning_no_signature(bedrock_client, alist): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": { @@ -871,25 +886,26 @@ def test_converse_and_reasoning_no_signature(bedrock_client): "stopReason": "tool_use", } - expected_events = [ + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + response = model.stream(request) + + tru_events = await alist(response) + exp_events = [ {"messageStart": {"role": "assistant"}}, {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Thinking really hard...."}}}}, {"contentBlockStop": {}}, {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, ] - - # Create model and call stream - model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - events = list(model.stream(request)) - - assert expected_events == events + assert tru_events == exp_events bedrock_client.converse.assert_called_once() bedrock_client.converse_stream.assert_not_called() -def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client): 
+@pytest.mark.asyncio +async def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client, alist): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, @@ -898,7 +914,13 @@ def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client): "stopReason": "tool_use", } - expected_events = [ + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + response = model.stream(request) + + tru_events = await alist(response) + exp_events = [ {"messageStart": {"role": "assistant"}}, {"contentBlockDelta": {"delta": {"text": "test"}}}, {"contentBlockStop": {}}, @@ -910,20 +932,15 @@ def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client): } }, ] - - # Create model and call stream - model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - events = list(model.stream(request)) - - assert expected_events == events + assert tru_events == exp_events # Verify converse was called bedrock_client.converse.assert_called_once() bedrock_client.converse_stream.assert_not_called() -def test_converse_input_guardrails(bedrock_client): +@pytest.mark.asyncio +async def test_converse_input_guardrails(bedrock_client, alist): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, @@ -939,7 +956,13 @@ def test_converse_input_guardrails(bedrock_client): "stopReason": "end_turn", } - expected_events = [ + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + response = model.stream(request) + + tru_events = await alist(response) + exp_events = [ {"messageStart": {"role": "assistant"}}, {"contentBlockDelta": {"delta": {"text": "test"}}}, {"contentBlockStop": {}}, @@ -961,19 +984,14 @@ def test_converse_input_guardrails(bedrock_client): }, {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, ] - - # Create model and call stream - model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - events = list(model.stream(request)) - - assert expected_events == events + assert tru_events == exp_events bedrock_client.converse.assert_called_once() bedrock_client.converse_stream.assert_not_called() -def test_converse_output_guardrails(bedrock_client): +@pytest.mark.asyncio +async def test_converse_output_guardrails(bedrock_client, alist): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, @@ -991,7 +1009,12 @@ def test_converse_output_guardrails(bedrock_client): "stopReason": "end_turn", } - expected_events = [ + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + response = model.stream(request) + + tru_events = await alist(response) + exp_events = [ {"messageStart": {"role": "assistant"}}, {"contentBlockDelta": {"delta": {"text": "test"}}}, {"contentBlockStop": {}}, @@ -1015,18 +1038,14 @@ def test_converse_output_guardrails(bedrock_client): }, {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, ] - - model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - events = 
list(model.stream(request)) - - assert expected_events == events + assert tru_events == exp_events bedrock_client.converse.assert_called_once() bedrock_client.converse_stream.assert_not_called() -def test_converse_output_guardrails_redacts_output(bedrock_client): +@pytest.mark.asyncio +async def test_converse_output_guardrails_redacts_output(bedrock_client, alist): """Test stream method with streaming=False.""" bedrock_client.converse.return_value = { "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, @@ -1044,7 +1063,12 @@ def test_converse_output_guardrails_redacts_output(bedrock_client): "stopReason": "end_turn", } - expected_events = [ + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + response = model.stream(request) + + tru_events = await alist(response) + exp_events = [ {"messageStart": {"role": "assistant"}}, {"contentBlockDelta": {"delta": {"text": "test"}}}, {"contentBlockStop": {}}, @@ -1068,18 +1092,14 @@ def test_converse_output_guardrails_redacts_output(bedrock_client): }, {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, ] - - model = BedrockModel(model_id="test-model", streaming=False) - request = {"modelId": "test-model"} - events = list(model.stream(request)) - - assert expected_events == events + assert tru_events == exp_events bedrock_client.converse.assert_called_once() bedrock_client.converse_stream.assert_not_called() -def test_structured_output(bedrock_client, model, test_output_model_cls): +@pytest.mark.asyncio +async def test_structured_output(bedrock_client, model, test_output_model_cls, alist): messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] bedrock_client.converse_stream.return_value = { @@ -1093,14 +1113,16 @@ def test_structured_output(bedrock_client, model, test_output_model_cls): } stream = model.structured_output(test_output_model_cls, messages) + events = await alist(stream) - tru_output = list(stream)[-1] + tru_output = events[-1] exp_output = {"output": test_output_model_cls(name="John", age=30)} assert tru_output == exp_output @pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") -def test_add_note_on_client_error(bedrock_client, model): +@pytest.mark.asyncio +async def test_add_note_on_client_error(bedrock_client, model, alist): """Test that add_note is called on ClientError with region and model ID information.""" # Mock the client error response error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}} @@ -1108,12 +1130,13 @@ def test_add_note_on_client_error(bedrock_client, model): # Call the stream method which should catch and add notes to the exception with pytest.raises(ClientError) as err: - list(model.stream({"modelId": "test-model"})) + await alist(model.stream({"modelId": "test-model"})) assert err.value.__notes__ == ["└ Bedrock region: us-west-2", "└ Model id: m1"] -def test_no_add_note_when_not_available(bedrock_client, model): +@pytest.mark.asyncio +async def test_no_add_note_when_not_available(bedrock_client, model, alist): """Verify that on any python version (even < 3.11 where add_note is not available, we get the right exception).""" # Mock the client error response error_response = {"Error": {"Code": "ValidationException", "Message": "Some error message"}} @@ -1121,11 +1144,12 @@ def test_no_add_note_when_not_available(bedrock_client, model): # Call the stream method which should catch and add notes to the 
exception with pytest.raises(ClientError): - list(model.stream({"modelId": "test-model"})) + await alist(model.stream({"modelId": "test-model"})) @pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") -def test_add_note_on_access_denied_exception(bedrock_client, model): +@pytest.mark.asyncio +async def test_add_note_on_access_denied_exception(bedrock_client, model, alist): """Test that add_note adds documentation link for AccessDeniedException.""" # Mock the client error response for access denied error_response = { @@ -1139,7 +1163,7 @@ def test_add_note_on_access_denied_exception(bedrock_client, model): # Call the stream method which should catch and add notes to the exception with pytest.raises(ClientError) as err: - list(model.stream({"modelId": "test-model"})) + await alist(model.stream({"modelId": "test-model"})) assert err.value.__notes__ == [ "└ Bedrock region: us-west-2", @@ -1150,7 +1174,8 @@ def test_add_note_on_access_denied_exception(bedrock_client, model): @pytest.mark.skipif(sys.version_info < (3, 11), reason="This test requires Python 3.11 or higher (need add_note)") -def test_add_note_on_validation_exception_throughput(bedrock_client, model): +@pytest.mark.asyncio +async def test_add_note_on_validation_exception_throughput(bedrock_client, model, alist): """Test that add_note adds documentation link for ValidationException about on-demand throughput.""" # Mock the client error response for validation exception error_response = { @@ -1166,7 +1191,7 @@ def test_add_note_on_validation_exception_throughput(bedrock_client, model): # Call the stream method which should catch and add notes to the exception with pytest.raises(ClientError) as err: - list(model.stream({"modelId": "test-model"})) + await alist(model.stream({"modelId": "test-model"})) assert err.value.__notes__ == [ "└ Bedrock region: us-west-2", diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 50a073ad3..989b7eae6 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -115,7 +115,8 @@ def test_format_request_message_content(content, exp_result): assert tru_result == exp_result -def test_structured_output(litellm_client, model, test_output_model_cls): +@pytest.mark.asyncio +async def test_structured_output(litellm_client, model, test_output_model_cls, alist): messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] mock_choice = unittest.mock.Mock() @@ -128,7 +129,8 @@ def test_structured_output(litellm_client, model, test_output_model_cls): with unittest.mock.patch.object(strands.models.litellm, "supports_response_schema", return_value=True): stream = model.structured_output(test_output_model_cls, messages) - tru_result = list(stream)[-1] + events = await alist(stream) + tru_result = events[-1] exp_result = {"output": test_output_model_cls(name="John", age=30)} assert tru_result == exp_result diff --git a/tests/strands/models/test_mistral.py b/tests/strands/models/test_mistral.py index 1b1f02764..786ba25b3 100644 --- a/tests/strands/models/test_mistral.py +++ b/tests/strands/models/test_mistral.py @@ -436,21 +436,24 @@ def test_format_chunk_unknown(model): model.format_chunk(event) -def test_stream_rate_limit_error(mistral_client, model): +@pytest.mark.asyncio +async def test_stream_rate_limit_error(mistral_client, model, alist): mistral_client.chat.stream.side_effect = Exception("rate limit exceeded (429)") with pytest.raises(ModelThrottledException, 
match="rate limit exceeded"): - list(model.stream({})) + await alist(model.stream({})) -def test_stream_other_error(mistral_client, model): +@pytest.mark.asyncio +async def test_stream_other_error(mistral_client, model, alist): mistral_client.chat.stream.side_effect = Exception("some other error") with pytest.raises(Exception, match="some other error"): - list(model.stream({})) + await alist(model.stream({})) -def test_structured_output_success(mistral_client, model, test_output_model_cls): +@pytest.mark.asyncio +async def test_structured_output_success(mistral_client, model, test_output_model_cls, alist): messages = [{"role": "user", "content": [{"text": "Extract data"}]}] mock_response = unittest.mock.Mock() @@ -461,13 +464,15 @@ def test_structured_output_success(mistral_client, model, test_output_model_cls) mistral_client.chat.complete.return_value = mock_response stream = model.structured_output(test_output_model_cls, messages) + events = await alist(stream) - tru_result = list(stream)[-1] + tru_result = events[-1] exp_result = {"output": test_output_model_cls(name="John", age=30)} assert tru_result == exp_result -def test_structured_output_no_tool_calls(mistral_client, model, test_output_model_cls): +@pytest.mark.asyncio +async def test_structured_output_no_tool_calls(mistral_client, model, test_output_model_cls): mock_response = unittest.mock.Mock() mock_response.choices = [unittest.mock.Mock()] mock_response.choices[0].message.tool_calls = None @@ -478,10 +483,11 @@ def test_structured_output_no_tool_calls(mistral_client, model, test_output_mode with pytest.raises(ValueError, match="No tool calls found in response"): stream = model.structured_output(test_output_model_cls, prompt) - next(stream) + await anext(stream) -def test_structured_output_invalid_json(mistral_client, model, test_output_model_cls): +@pytest.mark.asyncio +async def test_structured_output_invalid_json(mistral_client, model, test_output_model_cls): mock_response = unittest.mock.Mock() mock_response.choices = [unittest.mock.Mock()] mock_response.choices[0].message.tool_calls = [unittest.mock.Mock()] @@ -493,4 +499,4 @@ def test_structured_output_invalid_json(mistral_client, model, test_output_model with pytest.raises(ValueError, match="Failed to parse tool call arguments into model"): stream = model.structured_output(test_output_model_cls, prompt) - next(stream) + await anext(stream) diff --git a/tests/strands/models/test_ollama.py b/tests/strands/models/test_ollama.py index ead4caba0..c718a602c 100644 --- a/tests/strands/models/test_ollama.py +++ b/tests/strands/models/test_ollama.py @@ -415,7 +415,8 @@ def test_format_chunk_other(model): model.format_chunk(event) -def test_stream(ollama_client, model): +@pytest.mark.asyncio +async def test_stream(ollama_client, model, alist): mock_event = unittest.mock.Mock() mock_event.message.tool_calls = None mock_event.message.content = "Hello" @@ -426,7 +427,7 @@ def test_stream(ollama_client, model): request = {"model": "m1", "messages": [{"role": "user", "content": "Hello"}]} response = model.stream(request) - tru_events = list(response) + tru_events = await alist(response) exp_events = [ {"chunk_type": "message_start"}, {"chunk_type": "content_start", "data_type": "text"}, @@ -440,7 +441,8 @@ def test_stream(ollama_client, model): ollama_client.chat.assert_called_once_with(**request) -def test_stream_with_tool_calls(ollama_client, model): +@pytest.mark.asyncio +async def test_stream_with_tool_calls(ollama_client, model, alist): mock_event = unittest.mock.Mock() 
mock_tool_call = unittest.mock.Mock() mock_event.message.tool_calls = [mock_tool_call] @@ -452,7 +454,7 @@ def test_stream_with_tool_calls(ollama_client, model): request = {"model": "m1", "messages": [{"role": "user", "content": "Calculate 2+2"}]} response = model.stream(request) - tru_events = list(response) + tru_events = await alist(response) exp_events = [ {"chunk_type": "message_start"}, {"chunk_type": "content_start", "data_type": "text"}, @@ -469,7 +471,8 @@ def test_stream_with_tool_calls(ollama_client, model): ollama_client.chat.assert_called_once_with(**request) -def test_structured_output(ollama_client, model, test_output_model_cls): +@pytest.mark.asyncio +async def test_structured_output(ollama_client, model, test_output_model_cls, alist): messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] mock_response = unittest.mock.Mock() @@ -478,7 +481,8 @@ def test_structured_output(ollama_client, model, test_output_model_cls): ollama_client.chat.return_value = mock_response stream = model.structured_output(test_output_model_cls, messages) + events = await alist(stream) - tru_result = list(stream)[-1] + tru_result = events[-1] exp_result = {"output": test_output_model_cls(name="John", age=30)} assert tru_result == exp_result diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 63226bd2c..7bc16e5c2 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -69,7 +69,8 @@ def test_update_config(model, model_id): assert tru_model_id == exp_model_id -def test_stream(openai_client, model): +@pytest.mark.asyncio +async def test_stream(openai_client, model, alist): mock_tool_call_1_part_1 = unittest.mock.Mock(index=0) mock_tool_call_2_part_1 = unittest.mock.Mock(index=1) mock_delta_1 = unittest.mock.Mock( @@ -107,7 +108,7 @@ def test_stream(openai_client, model): request = {"model": "m1", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]} response = model.stream(request) - tru_events = list(response) + tru_events = await alist(response) exp_events = [ {"chunk_type": "message_start"}, {"chunk_type": "content_start", "data_type": "text"}, @@ -131,7 +132,8 @@ def test_stream(openai_client, model): openai_client.chat.completions.create.assert_called_once_with(**request) -def test_stream_empty(openai_client, model): +@pytest.mark.asyncio +async def test_stream_empty(openai_client, model, alist): mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None) mock_usage = unittest.mock.Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0) @@ -145,7 +147,7 @@ def test_stream_empty(openai_client, model): request = {"model": "m1", "messages": [{"role": "user", "content": []}]} response = model.stream(request) - tru_events = list(response) + tru_events = await alist(response) exp_events = [ {"chunk_type": "message_start"}, {"chunk_type": "content_start", "data_type": "text"}, @@ -158,7 +160,8 @@ def test_stream_empty(openai_client, model): openai_client.chat.completions.create.assert_called_once_with(**request) -def test_stream_with_empty_choices(openai_client, model): +@pytest.mark.asyncio +async def test_stream_with_empty_choices(openai_client, model, alist): mock_delta = unittest.mock.Mock(content="content", tool_calls=None, reasoning_content=None) mock_usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30) @@ -184,7 +187,7 @@ def test_stream_with_empty_choices(openai_client, model): request = {"model": "m1", 
"messages": [{"role": "user", "content": ["test"]}]} response = model.stream(request) - tru_events = list(response) + tru_events = await alist(response) exp_events = [ {"chunk_type": "message_start"}, {"chunk_type": "content_start", "data_type": "text"}, @@ -199,7 +202,8 @@ def test_stream_with_empty_choices(openai_client, model): openai_client.chat.completions.create.assert_called_once_with(**request) -def test_structured_output(openai_client, model, test_output_model_cls): +@pytest.mark.asyncio +async def test_structured_output(openai_client, model, test_output_model_cls, alist): messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] mock_parsed_instance = test_output_model_cls(name="John", age=30) @@ -211,7 +215,8 @@ def test_structured_output(openai_client, model, test_output_model_cls): openai_client.beta.chat.completions.parse.return_value = mock_response stream = model.structured_output(test_output_model_cls, messages) + events = await alist(stream) - tru_result = list(stream)[-1] + tru_result = events[-1] exp_result = {"output": test_output_model_cls(name="John", age=30)} assert tru_result == exp_result diff --git a/tests/strands/types/models/test_model.py b/tests/strands/types/models/test_model.py index dddb763d0..93635f157 100644 --- a/tests/strands/types/models/test_model.py +++ b/tests/strands/types/models/test_model.py @@ -1,5 +1,3 @@ -from typing import Type - import pytest from pydantic import BaseModel @@ -18,8 +16,8 @@ def update_config(self, **model_config): def get_config(self): return - def structured_output(self, output_model: Type[BaseModel]) -> BaseModel: - return output_model(name="test", age=20) + async def structured_output(self, output_model): + yield output_model(name="test", age=20) def format_request(self, messages, tool_specs, system_prompt): return { @@ -31,7 +29,7 @@ def format_request(self, messages, tool_specs, system_prompt): def format_chunk(self, event): return {"event": event} - def stream(self, request): + async def stream(self, request): yield {"request": request} @@ -74,10 +72,11 @@ def system_prompt(): return "s1" -def test_converse(model, messages, tool_specs, system_prompt): +@pytest.mark.asyncio +async def test_converse(model, messages, tool_specs, system_prompt, alist): response = model.converse(messages, tool_specs, system_prompt) - tru_events = list(response) + tru_events = await alist(response) exp_events = [ { "event": { @@ -92,13 +91,18 @@ def test_converse(model, messages, tool_specs, system_prompt): assert tru_events == exp_events -def test_structured_output(model): +@pytest.mark.asyncio +async def test_structured_output(model, alist): response = model.structured_output(Person) + events = await alist(response) - assert response == Person(name="test", age=20) + tru_output = events[-1] + exp_output = Person(name="test", age=20) + assert tru_output == exp_output -def test_converse_logging(model, messages, tool_specs, system_prompt, caplog): +@pytest.mark.asyncio +async def test_converse_logging(model, messages, tool_specs, system_prompt, caplog, alist): """Test that converse method logs the formatted request at debug level.""" import logging @@ -107,7 +111,7 @@ def test_converse_logging(model, messages, tool_specs, system_prompt, caplog): # Execute the converse method response = model.converse(messages, tool_specs, system_prompt) - list(response) # Consume the generator to trigger all logging + await alist(response) # Check that the expected log messages are present assert "formatting request" in caplog.text diff --git 
a/tests/strands/types/models/test_openai.py b/tests/strands/types/models/test_openai.py index a17294fa1..dc43b3fcd 100644 --- a/tests/strands/types/models/test_openai.py +++ b/tests/strands/types/models/test_openai.py @@ -16,7 +16,7 @@ def update_config(self, **model_config): def get_config(self): return - def stream(self, request): + async def stream(self, request): yield {"request": request}
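# Note on fixtures (a sketch, not part of the diff itself): the converted tests above depend on
# two pytest fixtures, `agenerator` and `alist`, which this diff references throughout but never
# defines -- they presumably live in a shared tests/conftest.py that is outside the scope of these
# hunks. Inferred from usage, `agenerator(events)` wraps a list of events into an async generator
# (so mocked `model.converse` / `model.stream` calls return async iterables), and
# `await alist(stream)` drains an async iterable into a plain list. The fixture names come from
# the diff; the placement and exact implementation below are assumptions.

import pytest


@pytest.fixture
def agenerator():
    """Factory fixture: turn an iterable of events into an async generator."""

    def factory(items):
        async def gen():
            for item in items:
                yield item

        return gen()

    return factory


@pytest.fixture
def alist():
    """Helper fixture: collect every item from an async iterable into a list."""

    async def collect(aiterable):
        return [item async for item in aiterable]

    return collect


# Typical usage, matching the pattern in the converted tests:
#     model.converse.return_value = agenerator([{"contentBlockStop": {}}])
#     events = await alist(stream)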