From e09d33b1c3cd4c8e1fefffd65d8c3b6720e94875 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 7 Jul 2025 23:53:45 +0000 Subject: [PATCH 1/4] models - llamaapi - async --- src/strands/models/llamaapi.py | 11 ++++----- tests-integ/test_model_llamaapi.py | 29 +++++++++++++++++++---- tests/strands/models/test_llamaapi.py | 33 ++++++++++++++++++++++++++- 3 files changed, 60 insertions(+), 13 deletions(-) diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 2b585439c..ffef21d2f 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -11,7 +11,6 @@ from typing import Any, AsyncGenerator, Optional, Type, TypeVar, Union, cast import llama_api_client -from llama_api_client import LlamaAPIClient from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override @@ -63,10 +62,8 @@ def __init__( self.config = LlamaAPIModel.LlamaConfig(**model_config) logger.debug("config=<%s> | initializing", self.config) - if not client_args: - self.client = LlamaAPIClient() - else: - self.client = LlamaAPIClient(**client_args) + client_args = client_args or {} + self.client = llama_api_client.AsyncLlamaAPIClient(**client_args) @override def update_config(self, **model_config: Unpack[LlamaConfig]) -> None: # type: ignore @@ -337,7 +334,7 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any] ModelThrottledException: When the model service is throttling requests from the client. """ try: - response = self.client.chat.completions.create(**request) + response = await self.client.chat.completions.create(**request) except llama_api_client.RateLimitError as e: raise ModelThrottledException(str(e)) from e @@ -348,7 +345,7 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any] curr_tool_call_id = None metrics_event = None - for chunk in response: + async for chunk in response: if chunk.event.event_type == "start": yield {"chunk_type": "content_start", "data_type": "text"} elif chunk.event.event_type in ["progress", "complete"] and chunk.event.delta.type == "text": diff --git a/tests-integ/test_model_llamaapi.py b/tests-integ/test_model_llamaapi.py index dad6919e2..f177c15a8 100644 --- a/tests-integ/test_model_llamaapi.py +++ b/tests-integ/test_model_llamaapi.py @@ -36,12 +36,31 @@ def agent(model, tools): return Agent(model=model, tools=tools) -@pytest.mark.skipif( - "LLAMA_API_KEY" not in os.environ, - reason="LLAMA_API_KEY environment variable missing", -) -def test_agent(agent): +@pytest.mark.skipif("LLAMA_API_KEY" not in os.environ, reason="LLAMA_API_KEY environment variable missing") +def test_agent_invoke(agent): result = agent("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.skipif("LLAMA_API_KEY" not in os.environ, reason="LLAMA_API_KEY environment variable missing") +@pytest.mark.asyncio +async def test_agent_invoke_async(agent): + result = await agent.invoke_async("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.skipif("LLAMA_API_KEY" not in os.environ, reason="LLAMA_API_KEY environment variable missing") +@pytest.mark.asyncio +async def test_agent_stream_async(agent): + stream = agent.stream_async("What is the time and weather in New York?") + async for event in stream: + _ = event + + result = event["result"] + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) diff --git a/tests/strands/models/test_llamaapi.py b/tests/strands/models/test_llamaapi.py index 309dac2e9..9d77e3ae2 100644 --- a/tests/strands/models/test_llamaapi.py +++ b/tests/strands/models/test_llamaapi.py @@ -9,7 +9,7 @@ @pytest.fixture def llamaapi_client(): - with unittest.mock.patch.object(strands.models.llamaapi, "LlamaAPIClient") as mock_client_cls: + with unittest.mock.patch.object(strands.models.llamaapi.llama_api_client, "AsyncLlamaAPIClient") as mock_client_cls: yield mock_client_cls.return_value @@ -361,3 +361,34 @@ def test_format_chunk_other(model): with pytest.raises(RuntimeError, match="chunk_type= | unknown type"): model.format_chunk(event) + + +@pytest.mark.asyncio +async def test_stream(llamaapi_client, model, agenerator, alist): + mock_event_1 = unittest.mock.Mock(event=unittest.mock.Mock(event_type="start", stop_reason=None)) + mock_event_2 = unittest.mock.Mock( + event=unittest.mock.Mock( + delta=unittest.mock.Mock(text="test stream", type="text"), + event_type="complete", + stop_reason="end_turn", + ), + ) + + llamaapi_client.chat.completions.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_event_1, mock_event_2]) + ) + + request = {"model": "m1"} + response = model.stream(request) + + tru_events = await alist(response) + exp_events = [ + {"chunk_type": "message_start"}, + {"chunk_type": "content_start", "data_type": "text"}, + {"chunk_type": "content_delta", "data_type": "text", "data": "test stream"}, + {"chunk_type": "content_stop", "data_type": "text"}, + {"chunk_type": "message_stop", "data": "end_turn"}, + ] + assert tru_events == exp_events + + llamaapi_client.chat.completions.create.assert_called_once_with(**request) From 151686f032b23e0265bafe0475355d492ad6a489 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Thu, 10 Jul 2025 20:35:59 +0000 Subject: [PATCH 2/4] tests --- tests/strands/models/test_llamaapi.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/tests/strands/models/test_llamaapi.py b/tests/strands/models/test_llamaapi.py index 9d77e3ae2..ef84e43fe 100644 --- a/tests/strands/models/test_llamaapi.py +++ b/tests/strands/models/test_llamaapi.py @@ -378,17 +378,23 @@ async def test_stream(llamaapi_client, model, agenerator, alist): return_value=agenerator([mock_event_1, mock_event_2]) ) - request = {"model": "m1"} - response = model.stream(request) + messages = [{"role": "user", "content": [{"text": "calculate 2+2"}]}] + response = model.stream(messages) tru_events = await alist(response) exp_events = [ - {"chunk_type": "message_start"}, - {"chunk_type": "content_start", "data_type": "text"}, - {"chunk_type": "content_delta", "data_type": "text", "data": "test stream"}, - {"chunk_type": "content_stop", "data_type": "text"}, - {"chunk_type": "message_stop", "data": "end_turn"}, + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "test stream"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, ] assert tru_events == exp_events - llamaapi_client.chat.completions.create.assert_called_once_with(**request) + expected_request = { + "model": "Llama-4-Maverick-17B-128E-Instruct-FP8", + "messages": [{"role": "user", "content": [{"text": "calculate 2+2", "type": "text"}]}], + "stream": True, + "tools": [], + } + llamaapi_client.chat.completions.create.assert_called_once_with(**expected_request) From f87fad13ba7dea5ed1898177fcf63612a45956bd Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 9 Jan 2026 14:56:46 -0500 Subject: [PATCH 3/4] fix: update tests for async --- .gitignore | 1 + tests/strands/models/test_llamaapi.py | 12 ++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 8b0fd989c..0b1375b50 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ repl_state .kiro uv.lock .audio_cache +CLAUDE.md diff --git a/tests/strands/models/test_llamaapi.py b/tests/strands/models/test_llamaapi.py index 7254aca2e..df824f664 100644 --- a/tests/strands/models/test_llamaapi.py +++ b/tests/strands/models/test_llamaapi.py @@ -417,16 +417,16 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings @pytest.mark.asyncio -async def test_tool_choice_not_supported_warns(model, messages, captured_warnings, alist): +async def test_tool_choice_not_supported_warns(model, messages, captured_warnings, alist, agenerator): """Test that non-None toolChoice emits warning for unsupported providers.""" tool_choice = {"auto": {}} - with unittest.mock.patch.object(model.client.chat.completions, "create") as mock_create: + with unittest.mock.patch.object(model.client.chat.completions, "create", new_callable=unittest.mock.AsyncMock) as mock_create: mock_chunk = unittest.mock.Mock() mock_chunk.event.event_type = "start" mock_chunk.event.stop_reason = "stop" - mock_create.return_value = [mock_chunk] + mock_create.return_value = agenerator([mock_chunk]) response = model.stream(messages, tool_choice=tool_choice) await alist(response) @@ -436,14 +436,14 @@ async def test_tool_choice_not_supported_warns(model, messages, captured_warning @pytest.mark.asyncio -async def test_tool_choice_none_no_warning(model, messages, captured_warnings, alist): +async def test_tool_choice_none_no_warning(model, messages, captured_warnings, alist, agenerator): """Test that None toolChoice doesn't emit warning.""" - with unittest.mock.patch.object(model.client.chat.completions, "create") as mock_create: + with unittest.mock.patch.object(model.client.chat.completions, "create", new_callable=unittest.mock.AsyncMock) as mock_create: mock_chunk = unittest.mock.Mock() mock_chunk.event.event_type = "start" mock_chunk.event.stop_reason = "stop" - mock_create.return_value = [mock_chunk] + mock_create.return_value = agenerator([mock_chunk]) response = model.stream(messages, tool_choice=None) await alist(response) From 3a033b47ab070a697f42d1296810877deff61926 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 9 Jan 2026 15:02:24 -0500 Subject: [PATCH 4/4] formatting --- tests/strands/models/test_llamaapi.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/strands/models/test_llamaapi.py b/tests/strands/models/test_llamaapi.py index df824f664..2d9caeaea 100644 --- a/tests/strands/models/test_llamaapi.py +++ b/tests/strands/models/test_llamaapi.py @@ -398,6 +398,8 @@ async def test_stream(llamaapi_client, model, agenerator, alist): "tools": [], } llamaapi_client.chat.completions.create.assert_called_once_with(**expected_request) + + def test_config_validation_warns_on_unknown_keys(llamaapi_client, captured_warnings): """Test that unknown config keys emit a warning.""" LlamaAPIModel(model_id="test-model", invalid_param="test") @@ -421,7 +423,9 @@ async def test_tool_choice_not_supported_warns(model, messages, captured_warning """Test that non-None toolChoice emits warning for unsupported providers.""" tool_choice = {"auto": {}} - with unittest.mock.patch.object(model.client.chat.completions, "create", new_callable=unittest.mock.AsyncMock) as mock_create: + with unittest.mock.patch.object( + model.client.chat.completions, "create", new_callable=unittest.mock.AsyncMock + ) as mock_create: mock_chunk = unittest.mock.Mock() mock_chunk.event.event_type = "start" mock_chunk.event.stop_reason = "stop" @@ -438,7 +442,9 @@ async def test_tool_choice_not_supported_warns(model, messages, captured_warning @pytest.mark.asyncio async def test_tool_choice_none_no_warning(model, messages, captured_warnings, alist, agenerator): """Test that None toolChoice doesn't emit warning.""" - with unittest.mock.patch.object(model.client.chat.completions, "create", new_callable=unittest.mock.AsyncMock) as mock_create: + with unittest.mock.patch.object( + model.client.chat.completions, "create", new_callable=unittest.mock.AsyncMock + ) as mock_create: mock_chunk = unittest.mock.Mock() mock_chunk.event.event_type = "start" mock_chunk.event.stop_reason = "stop"