diff --git a/.gitignore b/.gitignore
index 8b0fd989c..0b1375b50 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,3 +14,4 @@ repl_state
 .kiro
 uv.lock
 .audio_cache
+CLAUDE.md
diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py
index 013cd2c7d..ed511a05c 100644
--- a/src/strands/models/llamaapi.py
+++ b/src/strands/models/llamaapi.py
@@ -11,7 +11,6 @@
 from typing import Any, AsyncGenerator, Optional, Type, TypeVar, Union, cast
 
 import llama_api_client
-from llama_api_client import LlamaAPIClient
 from pydantic import BaseModel
 from typing_extensions import TypedDict, Unpack, override
 
@@ -65,10 +64,8 @@ def __init__(
         self.config = LlamaAPIModel.LlamaConfig(**model_config)
         logger.debug("config=<%s> | initializing", self.config)
 
-        if not client_args:
-            self.client = LlamaAPIClient()
-        else:
-            self.client = LlamaAPIClient(**client_args)
+        client_args = client_args or {}
+        self.client = llama_api_client.AsyncLlamaAPIClient(**client_args)
 
     @override
     def update_config(self, **model_config: Unpack[LlamaConfig]) -> None:  # type: ignore
@@ -358,7 +355,7 @@ async def stream(
         logger.debug("invoking model")
 
         try:
-            response = self.client.chat.completions.create(**request)
+            response = await self.client.chat.completions.create(**request)
         except llama_api_client.RateLimitError as e:
             raise ModelThrottledException(str(e)) from e
 
@@ -370,7 +367,7 @@
         curr_tool_call_id = None
         metrics_event = None
 
-        for chunk in response:
+        async for chunk in response:
             if chunk.event.event_type == "start":
                 yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
             elif chunk.event.event_type in ["progress", "complete"] and chunk.event.delta.type == "text":
diff --git a/tests/strands/models/test_llamaapi.py b/tests/strands/models/test_llamaapi.py
index a6bbf5673..2d9caeaea 100644
--- a/tests/strands/models/test_llamaapi.py
+++ b/tests/strands/models/test_llamaapi.py
@@ -9,7 +9,7 @@
 
 @pytest.fixture
 def llamaapi_client():
-    with unittest.mock.patch.object(strands.models.llamaapi, "LlamaAPIClient") as mock_client_cls:
+    with unittest.mock.patch.object(strands.models.llamaapi.llama_api_client, "AsyncLlamaAPIClient") as mock_client_cls:
         yield mock_client_cls.return_value
 
 
@@ -363,6 +363,43 @@ def test_format_chunk_other(model):
         model.format_chunk(event)
 
 
+@pytest.mark.asyncio
+async def test_stream(llamaapi_client, model, agenerator, alist):
+    mock_event_1 = unittest.mock.Mock(event=unittest.mock.Mock(event_type="start", stop_reason=None))
+    mock_event_2 = unittest.mock.Mock(
+        event=unittest.mock.Mock(
+            delta=unittest.mock.Mock(text="test stream", type="text"),
+            event_type="complete",
+            stop_reason="end_turn",
+        ),
+    )
+
+    llamaapi_client.chat.completions.create = unittest.mock.AsyncMock(
+        return_value=agenerator([mock_event_1, mock_event_2])
+    )
+
+    messages = [{"role": "user", "content": [{"text": "calculate 2+2"}]}]
+    response = model.stream(messages)
+
+    tru_events = await alist(response)
+    exp_events = [
+        {"messageStart": {"role": "assistant"}},
+        {"contentBlockStart": {"start": {}}},
+        {"contentBlockDelta": {"delta": {"text": "test stream"}}},
+        {"contentBlockStop": {}},
+        {"messageStop": {"stopReason": "end_turn"}},
+    ]
+    assert tru_events == exp_events
+
+    expected_request = {
+        "model": "Llama-4-Maverick-17B-128E-Instruct-FP8",
+        "messages": [{"role": "user", "content": [{"text": "calculate 2+2", "type": "text"}]}],
+        "stream": True,
+        "tools": [],
+    }
+    llamaapi_client.chat.completions.create.assert_called_once_with(**expected_request)
+
+
 def test_config_validation_warns_on_unknown_keys(llamaapi_client, captured_warnings):
     """Test that unknown config keys emit a warning."""
     LlamaAPIModel(model_id="test-model", invalid_param="test")
@@ -382,16 +419,18 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings)
 
 
 @pytest.mark.asyncio
-async def test_tool_choice_not_supported_warns(model, messages, captured_warnings, alist):
+async def test_tool_choice_not_supported_warns(model, messages, captured_warnings, alist, agenerator):
     """Test that non-None toolChoice emits warning for unsupported providers."""
     tool_choice = {"auto": {}}
 
-    with unittest.mock.patch.object(model.client.chat.completions, "create") as mock_create:
+    with unittest.mock.patch.object(
+        model.client.chat.completions, "create", new_callable=unittest.mock.AsyncMock
+    ) as mock_create:
         mock_chunk = unittest.mock.Mock()
         mock_chunk.event.event_type = "start"
         mock_chunk.event.stop_reason = "stop"
-        mock_create.return_value = [mock_chunk]
+        mock_create.return_value = agenerator([mock_chunk])
 
         response = model.stream(messages, tool_choice=tool_choice)
         await alist(response)
 
@@ -401,14 +440,16 @@
 
 
 @pytest.mark.asyncio
-async def test_tool_choice_none_no_warning(model, messages, captured_warnings, alist):
+async def test_tool_choice_none_no_warning(model, messages, captured_warnings, alist, agenerator):
     """Test that None toolChoice doesn't emit warning."""
-    with unittest.mock.patch.object(model.client.chat.completions, "create") as mock_create:
+    with unittest.mock.patch.object(
+        model.client.chat.completions, "create", new_callable=unittest.mock.AsyncMock
+    ) as mock_create:
         mock_chunk = unittest.mock.Mock()
         mock_chunk.event.event_type = "start"
         mock_chunk.event.stop_reason = "stop"
-        mock_create.return_value = [mock_chunk]
+        mock_create.return_value = agenerator([mock_chunk])
 
         response = model.stream(messages, tool_choice=None)
         await alist(response)
 
diff --git a/tests_integ/models/test_model_llamaapi.py b/tests_integ/models/test_model_llamaapi.py
index b36a63a28..322f16337 100644
--- a/tests_integ/models/test_model_llamaapi.py
+++ b/tests_integ/models/test_model_llamaapi.py
@@ -40,8 +40,28 @@ def agent(model, tools):
     return Agent(model=model, tools=tools)
 
 
-def test_agent(agent):
+def test_agent_invoke(agent):
     result = agent("What is the time and weather in New York?")
     text = result.message["content"][0]["text"].lower()
 
     assert all(string in text for string in ["12:00", "sunny"])
+
+
+@pytest.mark.asyncio
+async def test_agent_invoke_async(agent):
+    result = await agent.invoke_async("What is the time and weather in New York?")
+    text = result.message["content"][0]["text"].lower()
+
+    assert all(string in text for string in ["12:00", "sunny"])
+
+
+@pytest.mark.asyncio
+async def test_agent_stream_async(agent):
+    stream = agent.stream_async("What is the time and weather in New York?")
+    async for event in stream:
+        _ = event
+
+    result = event["result"]
+    text = result.message["content"][0]["text"].lower()
+
+    assert all(string in text for string in ["12:00", "sunny"])
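Not part of the patch: a minimal sketch of how the now-async `LlamaAPIModel` would be driven after this change. It assumes credentials are supplied through `client_args` or the `AsyncLlamaAPIClient` environment defaults, and that `stream()` yields the formatted event dicts exercised in the unit tests above.

```python
# Illustration only (not part of the patch). Assumes Llama API credentials are
# available via client_args or the client's environment defaults.
import asyncio

from strands.models.llamaapi import LlamaAPIModel


async def main() -> None:
    model = LlamaAPIModel(model_id="Llama-4-Maverick-17B-128E-Instruct-FP8")
    messages = [{"role": "user", "content": [{"text": "calculate 2+2"}]}]

    # stream() now awaits AsyncLlamaAPIClient.chat.completions.create(**request)
    # and walks the response with `async for`, so callers consume it as an
    # async generator of formatted event dicts.
    async for event in model.stream(messages):
        print(event)


if __name__ == "__main__":
    asyncio.run(main())
```

The unit tests mirror this consumption pattern by mocking `create` with `unittest.mock.AsyncMock` and returning an async generator of chunk mocks.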