diff --git a/.gitignore b/.gitignore
index e0358a7..ff9d58e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,3 +4,4 @@ dist
 uv.lock
 .mypy_cache
 .ruff_cache
+__pycache__
\ No newline at end of file
diff --git a/README.md b/README.md
index e58fb62..20bf667 100644
--- a/README.md
+++ b/README.md
@@ -93,6 +93,20 @@ async with (
     ...
 ```

+### Reasoning models
+
+You can access OpenAI-compatible reasoning models and retrieve their reasoning content:
+
+```python
+async def main() -> None:
+    async with any_llm_client.get_client(config) as client:
+        llm_response = await client.request_llm_message("Кек, чо как вообще на нарах?")
+        print(f"Just a regular LLM response content: {llm_response.content}")
+        print(f"LLM reasoning response content: {llm_response.reasoning_content}")
+
+    ...
+```
+
 ### Other

 #### Mock client
@@ -165,12 +179,12 @@ async with any_llm_client.OpenAIClient(config, ...) as client:
 #### Errors

 `any_llm_client.LLMClient.request_llm_message()` and `any_llm_client.LLMClient.stream_llm_message_chunks()` will raise:
+
 - `any_llm_client.LLMError` or `any_llm_client.OutOfTokensOrSymbolsError` when the LLM API responds with a failed HTTP status,
 - `any_llm_client.LLMRequestValidationError` when images are passed to YandexGPT client.

 #### Timeouts, proxy & other HTTP settings
-
 Pass custom [HTTPX](https://www.python-httpx.org) kwargs to `any_llm_client.get_client()`:

 ```python
@@ -206,7 +220,6 @@ await client.request_llm_message("Кек, чо как вообще на нара

 The `extra` parameter is united with `request_extra` in OpenAIConfig

-
 #### Passing images

 You can pass images to OpenAI client (YandexGPT doesn't support images yet):
diff --git a/any_llm_client/__init__.py b/any_llm_client/__init__.py
index 076f2f5..ad50259 100644
--- a/any_llm_client/__init__.py
+++ b/any_llm_client/__init__.py
@@ -10,6 +10,7 @@
     LLMConfig,
     LLMError,
     LLMRequestValidationError,
+    LLMResponse,
     Message,
     MessageRole,
     OutOfTokensOrSymbolsError,
@@ -31,6 +32,7 @@
     "LLMConfig",
     "LLMError",
     "LLMRequestValidationError",
+    "LLMResponse",
     "Message",
     "MessageRole",
     "MockLLMClient",
diff --git a/any_llm_client/clients/mock.py b/any_llm_client/clients/mock.py
index a465eaf..7ee8a02 100644
--- a/any_llm_client/clients/mock.py
+++ b/any_llm_client/clients/mock.py
@@ -6,12 +6,12 @@
 import pydantic
 import typing_extensions

-from any_llm_client.core import LLMClient, LLMConfig, LLMConfigValue, Message
+from any_llm_client.core import LLMClient, LLMConfig, LLMConfigValue, LLMResponse, Message


 class MockLLMConfig(LLMConfig):
-    response_message: str = ""
-    stream_messages: list[str] = pydantic.Field([])
+    response_message: LLMResponse = LLMResponse(content="")
+    stream_messages: list[LLMResponse] = pydantic.Field([])
     api_type: typing.Literal["mock"] = "mock"


@@ -25,10 +25,10 @@ async def request_llm_message(
         *,
         temperature: float = LLMConfigValue(attr="temperature"),  # noqa: ARG002
         extra: dict[str, typing.Any] | None = None,  # noqa: ARG002
-    ) -> str:
+    ) -> LLMResponse:
         return self.config.response_message

-    async def _iter_config_stream_messages(self) -> typing.AsyncIterable[str]:
+    async def _iter_config_stream_messages(self) -> typing.AsyncIterable[LLMResponse]:
         for one_message in self.config.stream_messages:
             yield one_message

@@ -39,7 +39,7 @@ async def stream_llm_message_chunks(
         *,
         temperature: float = LLMConfigValue(attr="temperature"),  # noqa: ARG002
         extra: dict[str, typing.Any] | None = None,  # noqa: ARG002
-    ) -> typing.AsyncIterator[typing.AsyncIterable[str]]:
+    ) -> typing.AsyncIterator[typing.AsyncIterable[LLMResponse]]:
         yield self._iter_config_stream_messages()

     async def __aenter__(self) -> typing_extensions.Self:
diff --git a/any_llm_client/clients/openai.py b/any_llm_client/clients/openai.py
index 01bbf1b..f295e33 100644
--- a/any_llm_client/clients/openai.py
+++ b/any_llm_client/clients/openai.py
@@ -16,6 +16,7 @@
     LLMConfig,
     LLMConfigValue,
     LLMError,
+    LLMResponse,
     Message,
     MessageRole,
     OutOfTokensOrSymbolsError,
@@ -76,6 +77,7 @@ class ChatCompletionsRequest(pydantic.BaseModel):
 class OneStreamingChoiceDelta(pydantic.BaseModel):
     role: typing.Literal[MessageRole.assistant] | None = None
     content: str | None = None
+    reasoning_content: str | None = None


 class OneStreamingChoice(pydantic.BaseModel):
@@ -89,6 +91,7 @@ class ChatCompletionsStreamingEvent(pydantic.BaseModel):
 class OneNotStreamingChoiceMessage(pydantic.BaseModel):
     role: MessageRole
     content: str
+    reasoning_content: str | None = None


 class OneNotStreamingChoice(pydantic.BaseModel):
@@ -143,14 +146,16 @@ def _make_user_assistant_alternate_messages(
         else:
             if current_message_content_chunks:
                 yield ChatCompletionsInputMessage(
-                    role=current_message_role, content=_merge_content_chunks(current_message_content_chunks)
+                    role=current_message_role,
+                    content=_merge_content_chunks(current_message_content_chunks),
                 )
             current_message_content_chunks = [one_message.content]
             current_message_role = one_message.role
     if current_message_content_chunks:
         yield ChatCompletionsInputMessage(
-            role=current_message_role, content=_merge_content_chunks(current_message_content_chunks)
+            role=current_message_role,
+            content=_merge_content_chunks(current_message_content_chunks),
         )


@@ -195,7 +200,12 @@ def _prepare_messages(self, messages: str | list[Message]) -> list[ChatCompletio
         )

     def _prepare_payload(
-        self, *, messages: str | list[Message], temperature: float, stream: bool, extra: dict[str, typing.Any] | None
+        self,
+        *,
+        messages: str | list[Message],
+        temperature: float,
+        stream: bool,
+        extra: dict[str, typing.Any] | None,
     ) -> dict[str, typing.Any]:
         return ChatCompletionsRequest(
             stream=stream,
@@ -211,9 +221,12 @@ async def request_llm_message(
         *,
         temperature: float = LLMConfigValue(attr="temperature"),
         extra: dict[str, typing.Any] | None = None,
-    ) -> str:
+    ) -> LLMResponse:
         payload: typing.Final = self._prepare_payload(
-            messages=messages, temperature=temperature, stream=False, extra=extra
+            messages=messages,
+            temperature=temperature,
+            stream=False,
+            extra=extra,
         )
         try:
             response: typing.Final = await make_http_request(
@@ -224,18 +237,27 @@ async def request_llm_message(
         except httpx.HTTPStatusError as exception:
             _handle_status_error(status_code=exception.response.status_code, content=exception.response.content)
         try:
-            return ChatCompletionsNotStreamingResponse.model_validate_json(response.content).choices[0].message.content
+            validated_message_model: typing.Final = (
+                ChatCompletionsNotStreamingResponse.model_validate_json(response.content).choices[0].message
+            )
+            return LLMResponse(
+                content=validated_message_model.content,
+                reasoning_content=validated_message_model.reasoning_content,
+            )
         finally:
             await response.aclose()

-    async def _iter_response_chunks(self, response: httpx.Response) -> typing.AsyncIterable[str]:
+    async def _iter_response_chunks(self, response: httpx.Response) -> typing.AsyncIterable[LLMResponse]:
         async for event in httpx_sse.EventSource(response).aiter_sse():
             if event.data == "[DONE]":
                 break
             validated_response = ChatCompletionsStreamingEvent.model_validate_json(event.data)
-            if not (one_chunk := validated_response.choices[0].delta.content):
+            if not (
+                (validated_delta := validated_response.choices[0].delta)
+                and (validated_delta.content or validated_delta.reasoning_content)
+            ):
                 continue
-            yield one_chunk
+            yield LLMResponse(content=validated_delta.content, reasoning_content=validated_delta.reasoning_content)

     @contextlib.asynccontextmanager
     async def stream_llm_message_chunks(
@@ -244,9 +266,12 @@ async def stream_llm_message_chunks(
         *,
         temperature: float = LLMConfigValue(attr="temperature"),
         extra: dict[str, typing.Any] | None = None,
-    ) -> typing.AsyncIterator[typing.AsyncIterable[str]]:
+    ) -> typing.AsyncIterator[typing.AsyncIterable[LLMResponse]]:
         payload: typing.Final = self._prepare_payload(
-            messages=messages, temperature=temperature, stream=True, extra=extra
+            messages=messages,
+            temperature=temperature,
+            stream=True,
+            extra=extra,
         )
         try:
             async with make_streaming_http_request(
diff --git a/any_llm_client/clients/yandexgpt.py b/any_llm_client/clients/yandexgpt.py
index 5e2bae2..72afcf4 100644
--- a/any_llm_client/clients/yandexgpt.py
+++ b/any_llm_client/clients/yandexgpt.py
@@ -17,6 +17,7 @@
     LLMConfigValue,
     LLMError,
     LLMRequestValidationError,
+    LLMResponse,
     Message,
     MessageRole,
     OutOfTokensOrSymbolsError,
@@ -35,10 +36,12 @@ class YandexGPTConfig(LLMConfig):
     else:
         url: pydantic.HttpUrl = "https://llm.api.cloud.yandex.net/foundationModels/v1/completion"
     auth_header: str = pydantic.Field(  # type: ignore[assignment]
-        default_factory=lambda: os.environ.get(YANDEXGPT_AUTH_HEADER_ENV_NAME), validate_default=True
+        default_factory=lambda: os.environ.get(YANDEXGPT_AUTH_HEADER_ENV_NAME),
+        validate_default=True,
     )
     folder_id: str = pydantic.Field(  # type: ignore[assignment]
-        default_factory=lambda: os.environ.get(YANDEXGPT_FOLDER_ID_ENV_NAME), validate_default=True
+        default_factory=lambda: os.environ.get(YANDEXGPT_FOLDER_ID_ENV_NAME),
+        validate_default=True,
     )
     model_name: str
     model_version: str = "latest"
@@ -126,7 +129,7 @@ def _prepare_payload(
             if isinstance(one_message.content, list):
                 if len(one_message.content) != 1:
                     raise LLMRequestValidationError(
-                        "YandexGPTClient does not support multiple content items per message"
+                        "YandexGPTClient does not support multiple content items per message",
                     )
                 message_content = one_message.content[0]
                 if isinstance(message_content, ImageContentItem):
@@ -153,9 +156,12 @@ async def request_llm_message(
         *,
         temperature: float = LLMConfigValue(attr="temperature"),
         extra: dict[str, typing.Any] | None = None,
-    ) -> str:
+    ) -> LLMResponse:
         payload: typing.Final = self._prepare_payload(
-            messages=messages, temperature=temperature, stream=False, extra=extra
+            messages=messages,
+            temperature=temperature,
+            stream=False,
+            extra=extra,
         )

         try:
@@ -167,14 +173,16 @@ async def request_llm_message(
         except httpx.HTTPStatusError as exception:
             _handle_status_error(status_code=exception.response.status_code, content=exception.response.content)

-        return YandexGPTResponse.model_validate_json(response.content).result.alternatives[0].message.text
+        return LLMResponse(
+            content=YandexGPTResponse.model_validate_json(response.content).result.alternatives[0].message.text,
+        )

-    async def _iter_response_chunks(self, response: httpx.Response) -> typing.AsyncIterable[str]:
+    async def _iter_response_chunks(self, response: httpx.Response) -> typing.AsyncIterable[LLMResponse]:
         previous_cursor = 0
         async for one_line in response.aiter_lines():
             validated_response = YandexGPTResponse.model_validate_json(one_line)
             response_text = validated_response.result.alternatives[0].message.text
-            yield response_text[previous_cursor:]
+            yield LLMResponse(content=response_text[previous_cursor:])
             previous_cursor = len(response_text)

     @contextlib.asynccontextmanager
@@ -184,9 +192,12 @@ async def stream_llm_message_chunks(
         *,
         temperature: float = LLMConfigValue(attr="temperature"),
         extra: dict[str, typing.Any] | None = None,
-    ) -> typing.AsyncIterator[typing.AsyncIterable[str]]:
+    ) -> typing.AsyncIterator[typing.AsyncIterable[LLMResponse]]:
         payload: typing.Final = self._prepare_payload(
-            messages=messages, temperature=temperature, stream=True, extra=extra
+            messages=messages,
+            temperature=temperature,
+            stream=True,
+            extra=extra,
         )

         try:
diff --git a/any_llm_client/core.py b/any_llm_client/core.py
index 4a48c7c..c0cf1d9 100644
--- a/any_llm_client/core.py
+++ b/any_llm_client/core.py
@@ -39,6 +39,12 @@ class Message:
     content: str | ContentItemList


+@pydantic.dataclasses.dataclass
+class LLMResponse:
+    content: str | None = None
+    reasoning_content: str | None = None
+
+
 if typing.TYPE_CHECKING:

     @pydantic.dataclasses.dataclass
@@ -55,6 +61,7 @@ class UserMessage(Message):
     class AssistantMessage(Message):
         role: typing.Literal[MessageRole.assistant] = pydantic.Field(MessageRole.assistant, init=False)
         content: str | ContentItemList
+

 else:
     def SystemMessage(content: str | ContentItemList) -> Message:  # noqa: N802
@@ -102,7 +109,7 @@ async def request_llm_message(
         *,
         temperature: float = LLMConfigValue(attr="temperature"),
         extra: dict[str, typing.Any] | None = None,
-    ) -> str: ...  # raises LLMError, LLMRequestValidationError
+    ) -> LLMResponse: ...  # raises LLMError, LLMRequestValidationError

     @contextlib.asynccontextmanager
     def stream_llm_message_chunks(
@@ -111,7 +118,7 @@ def stream_llm_message_chunks(
         *,
         temperature: float = LLMConfigValue(attr="temperature"),
         extra: dict[str, typing.Any] | None = None,
-    ) -> typing.AsyncIterator[typing.AsyncIterable[str]]: ...  # raises LLMError, LLMRequestValidationError
+    ) -> typing.AsyncIterator[typing.AsyncIterable[LLMResponse]]: ...  # raises LLMError, LLMRequestValidationError

     async def __aenter__(self) -> typing_extensions.Self: ...
     async def __aexit__(
diff --git a/examples/openai-image-input.py b/examples/openai-image-input.py
index cde02a9..f0ff29a 100644
--- a/examples/openai-image-input.py
+++ b/examples/openai-image-input.py
@@ -17,7 +17,7 @@
     content=[
         any_llm_client.TextContentItem("What's on the image?"),
         any_llm_client.ImageContentItem(f"data:image/jpeg;base64,{base64.b64encode(image_content).decode('utf-8')}"),
-    ]
+    ],
 )


@@ -27,7 +27,7 @@ async def main() -> None:
         client.stream_llm_message_chunks(messages=[message]) as message_chunks,
     ):
         async for chunk in message_chunks:
-            print(chunk, end="", flush=True)
+            print(chunk.content, end="", flush=True)


 asyncio.run(main())
diff --git a/examples/openai-reasoning-response.py b/examples/openai-reasoning-response.py
new file mode 100644
index 0000000..401afb6
--- /dev/null
+++ b/examples/openai-reasoning-response.py
@@ -0,0 +1,21 @@
+"""Install ollama and pull the model to run this script: `ollama pull qwen2.5-coder:1.5b`."""
+
+import asyncio
+import typing
+
+import any_llm_client
+
+
+config = any_llm_client.OpenAIConfig(url="http://127.0.0.1:11434/v1/chat/completions", model_name="qwen2.5-coder:1.5b")
+
+
+async def main() -> None:
+    async with any_llm_client.get_client(config) as client:
+        llm_response: typing.Final = await client.request_llm_message(
+            "Кек, чо как вообще на нарах? Порассуждай как философ.",
+        )
+        print(llm_response.reasoning_content)
+        print(llm_response.content)
+
+
+asyncio.run(main())
diff --git a/examples/openai-stream-advanced.py b/examples/openai-stream-advanced.py
index 892f1b2..7fe0783 100644
--- a/examples/openai-stream-advanced.py
+++ b/examples/openai-stream-advanced.py
@@ -20,7 +20,7 @@ async def main() -> None:
         ) as message_chunks,
     ):
         async for chunk in message_chunks:
-            print(chunk, end="", flush=True)
+            print(chunk.content, end="", flush=True)


 asyncio.run(main())
diff --git a/examples/openai-stream.py b/examples/openai-stream.py
index f1beafa..e16ac91 100644
--- a/examples/openai-stream.py
+++ b/examples/openai-stream.py
@@ -14,7 +14,7 @@ async def main() -> None:
         client.stream_llm_message_chunks("Кек, чо как вообще на нарах?") as message_chunks,
     ):
         async for chunk in message_chunks:
-            print(chunk, end="", flush=True)
+            print(chunk.content, end="", flush=True)


 asyncio.run(main())
diff --git a/tests/conftest.py b/tests/conftest.py
index 1c9da09..34395cd 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,6 +3,7 @@
 from functools import reduce
 from itertools import combinations

+import faker
 import pytest
 import stamina
 import typing_extensions
@@ -10,6 +11,7 @@
 from polyfactory.factories.typed_dict_factory import TypedDictFactory

 import any_llm_client
+from any_llm_client.core import LLMResponse


 @pytest.fixture(scope="session", autouse=True)
@@ -37,6 +39,14 @@ class ImageContentItemFactory(DataclassFactory[any_llm_client.ImageContentItem])
 class TextContentItemFactory(DataclassFactory[any_llm_client.TextContentItem]): ...


+class LLMResponseFactory(DataclassFactory[any_llm_client.LLMResponse]): ...
+
+
+@pytest.fixture
+def random_llm_response(faker: faker.Faker) -> any_llm_client.LLMResponse:
+    return LLMResponseFactory.build(content=faker.pystr())
+
+
 def set_no_temperature(llm_func_request: LLMFuncRequest) -> LLMFuncRequest:
     llm_func_request.pop("temperature")
     return llm_func_request
@@ -49,7 +59,7 @@ def set_no_extra(llm_func_request: LLMFuncRequest) -> LLMFuncRequest:

 def set_message_content_as_image_with_description(llm_func_request: LLMFuncRequest) -> LLMFuncRequest:
     llm_func_request["messages"] = [
-        MessageFactory.build(content=[TextContentItemFactory.build(), ImageContentItemFactory.build()])
+        MessageFactory.build(content=[TextContentItemFactory.build(), ImageContentItemFactory.build()]),
     ]
     return llm_func_request
@@ -89,8 +99,10 @@ def coverage(cls, **kwargs: typing.Any) -> typing.Iterator[LLMFuncRequest]:  # n


 async def consume_llm_message_chunks(
-    stream_llm_message_chunks_context_manager: contextlib._AsyncGeneratorContextManager[typing.AsyncIterable[str]],
+    stream_llm_message_chunks_context_manager: contextlib._AsyncGeneratorContextManager[
+        typing.AsyncIterable[LLMResponse]
+    ],
     /,
-) -> list[str]:
+) -> list[LLMResponse]:
     async with stream_llm_message_chunks_context_manager as response_iterable:
         return [one_item async for one_item in response_iterable]
diff --git a/tests/test_mock_client.py b/tests/test_mock_client.py
index 9cd88c8..c6bdd75 100644
--- a/tests/test_mock_client.py
+++ b/tests/test_mock_client.py
@@ -21,6 +21,6 @@ async def test_mock_client_request_llm_message_returns_config_value(func_request
 async def test_mock_client_stream_llm_message_chunks_returns_config_value(func_request: LLMFuncRequest) -> None:
     config: typing.Final = MockLLMConfigFactory.build()
     response: typing.Final = await consume_llm_message_chunks(
-        any_llm_client.get_client(config).stream_llm_message_chunks(**func_request)
+        any_llm_client.get_client(config).stream_llm_message_chunks(**func_request),
     )
     assert response == config.stream_messages
diff --git a/tests/test_openai_client.py b/tests/test_openai_client.py
index 4a90f7d..674a94a 100644
--- a/tests/test_openai_client.py
+++ b/tests/test_openai_client.py
@@ -25,26 +25,28 @@ class OpenAIConfigFactory(ModelFactory[any_llm_client.OpenAIConfig]): ...


 class TestOpenAIRequestLLMResponse:
     @pytest.mark.parametrize("func_request", LLMFuncRequestFactory.coverage())
-    async def test_ok(self, faker: faker.Faker, func_request: LLMFuncRequest) -> None:
-        expected_result: typing.Final = faker.pystr()
+    async def test_ok(self, func_request: LLMFuncRequest, random_llm_response: any_llm_client.LLMResponse) -> None:
         response: typing.Final = httpx.Response(
             200,
             json=ChatCompletionsNotStreamingResponse(
                 choices=[
                     OneNotStreamingChoice(
                         message=OneNotStreamingChoiceMessage(
-                            role=any_llm_client.MessageRole.assistant, content=expected_result
-                        )
-                    )
-                ]
+                            role=any_llm_client.MessageRole.assistant,
+                            content=random_llm_response.content,
+                            reasoning_content=random_llm_response.reasoning_content,
+                        ),
+                    ),
+                ],
             ).model_dump(mode="json"),
         )

         result: typing.Final = await any_llm_client.get_client(
-            OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
+            OpenAIConfigFactory.build(),
+            transport=httpx.MockTransport(lambda _: response),
         ).request_llm_message(**func_request)

-        assert result == expected_result
+        assert result == random_llm_response

     async def test_fails_without_alternatives(self) -> None:
         response: typing.Final = httpx.Response(
@@ -52,7 +54,8 @@ async def test_fails_without_alternatives(self) -> None:
             json=ChatCompletionsNotStreamingResponse.model_construct(choices=[]).model_dump(mode="json"),
         )
         client: typing.Final = any_llm_client.get_client(
-            OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
+            OpenAIConfigFactory.build(),
+            transport=httpx.MockTransport(lambda _: response),
         )

         with pytest.raises(pydantic.ValidationError):
@@ -74,12 +77,12 @@ async def test_ok(self, faker: faker.Faker, func_request: LLMFuncRequest) -> Non
             OneStreamingChoiceDelta(),
         ]
         expected_result: typing.Final = [
-            "H",
-            "i",
-            " t",
-            "here",
-            ". How is you",
-            "r day?",
+            any_llm_client.LLMResponse("H"),
+            any_llm_client.LLMResponse("i"),
+            any_llm_client.LLMResponse(" t"),
+            any_llm_client.LLMResponse("here"),
+            any_llm_client.LLMResponse(". How is you"),
+            any_llm_client.LLMResponse("r day?"),
         ]
         config: typing.Final = OpenAIConfigFactory.build()
         response_content: typing.Final = (
@@ -91,7 +94,9 @@ async def test_ok(self, faker: faker.Faker, func_request: LLMFuncRequest) -> Non
             + f"\n\ndata: [DONE]\n\ndata: {faker.pystr()}\n\n"
         )
         response: typing.Final = httpx.Response(
-            200, headers={"Content-Type": "text/event-stream"}, content=response_content
+            200,
+            headers={"Content-Type": "text/event-stream"},
+            content=response_content,
         )

         client: typing.Final = any_llm_client.get_client(config, transport=httpx.MockTransport(lambda _: response))
@@ -104,10 +109,13 @@ async def test_fails_without_alternatives(self) -> None:
             f"data: {ChatCompletionsStreamingEvent.model_construct(choices=[]).model_dump_json()}\n\n"
         )
         response: typing.Final = httpx.Response(
-            200, headers={"Content-Type": "text/event-stream"}, content=response_content
+            200,
+            headers={"Content-Type": "text/event-stream"},
+            content=response_content,
         )
         client: typing.Final = any_llm_client.get_client(
-            OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
+            OpenAIConfigFactory.build(),
+            transport=httpx.MockTransport(lambda _: response),
         )

         with pytest.raises(pydantic.ValidationError):
@@ -119,7 +127,8 @@ class TestOpenAILLMErrors:
     @pytest.mark.parametrize("status_code", [400, 500])
     async def test_fails_with_unknown_error(self, stream: bool, status_code: int) -> None:
         client: typing.Final = any_llm_client.get_client(
-            OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: httpx.Response(status_code))
+            OpenAIConfigFactory.build(),
+            transport=httpx.MockTransport(lambda _: httpx.Response(status_code)),
         )

         coroutine: typing.Final = (
@@ -143,7 +152,8 @@ async def test_fails_with_unknown_error(self, stream: bool, status_code: int) ->
     async def test_fails_with_out_of_tokens_error(self, stream: bool, content: bytes | None) -> None:
         response: typing.Final = httpx.Response(400, content=content)
         client: typing.Final = any_llm_client.get_client(
-            OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
+            OpenAIConfigFactory.build(),
+            transport=httpx.MockTransport(lambda _: response),
         )

         coroutine: typing.Final = (
@@ -187,7 +197,8 @@ class TestOpenAIMessageAlternation:
            [
                ChatCompletionsInputMessage(role=any_llm_client.MessageRole.user, content="Hi there"),
                ChatCompletionsInputMessage(
-                    role=any_llm_client.MessageRole.assistant, content="Hi! How can I help you?"
+                    role=any_llm_client.MessageRole.assistant,
+                    content="Hi! How can I help you?",
                ),
            ],
        ),
@@ -200,7 +211,8 @@ class TestOpenAIMessageAlternation:
            [
                ChatCompletionsInputMessage(role=any_llm_client.MessageRole.user, content="Hi there"),
                ChatCompletionsInputMessage(
-                    role=any_llm_client.MessageRole.assistant, content="Hi! How can I help you?"
+                    role=any_llm_client.MessageRole.assistant,
+                    content="Hi! How can I help you?",
                ),
            ],
        ),
@@ -230,7 +242,8 @@ class TestOpenAIMessageAlternation:
                    content="Hi!\n\nI'm your answer to everything.\n\nHow can I help you?",
                ),
                ChatCompletionsInputMessage(
-                    role=any_llm_client.MessageRole.user, content="Hi there\n\nWhy is the sky blue?"
+                    role=any_llm_client.MessageRole.user,
+                    content="Hi there\n\nWhy is the sky blue?",
                ),
                ChatCompletionsInputMessage(role=any_llm_client.MessageRole.assistant, content="Well..."),
                ChatCompletionsInputMessage(role=any_llm_client.MessageRole.user, content="Hmmm..."),
@@ -243,7 +256,7 @@ class TestOpenAIMessageAlternation:
                    [
                        any_llm_client.TextContentItem("Hi there"),
                        any_llm_client.TextContentItem("Why is the sky blue?"),
-                    ]
+                    ],
                ),
            ],
            [
@@ -254,7 +267,7 @@ class TestOpenAIMessageAlternation:
                        ChatCompletionsTextContentItem(text="Hi there"),
                        ChatCompletionsTextContentItem(text="Why is the sky blue?"),
                    ],
-                )
+                ),
            ],
        ),
        (
@@ -280,19 +293,21 @@ class TestOpenAIMessageAlternation:
        ],
    )
    def test_with_alternation(
-        self, messages: list[any_llm_client.Message], expected_result: list[ChatCompletionsInputMessage]
+        self,
+        messages: list[any_llm_client.Message],
+        expected_result: list[ChatCompletionsInputMessage],
    ) -> None:
        client: typing.Final = any_llm_client.OpenAIClient(
-            OpenAIConfigFactory.build(force_user_assistant_message_alternation=True)
+            OpenAIConfigFactory.build(force_user_assistant_message_alternation=True),
        )
        assert client._prepare_messages(messages) == expected_result  # noqa: SLF001

    def test_without_alternation(self) -> None:
        client: typing.Final = any_llm_client.OpenAIClient(
-            OpenAIConfigFactory.build(force_user_assistant_message_alternation=False)
+            OpenAIConfigFactory.build(force_user_assistant_message_alternation=False),
        )
        assert client._prepare_messages(  # noqa: SLF001
-            [any_llm_client.SystemMessage("Be nice"), any_llm_client.UserMessage("Hi there")]
+            [any_llm_client.SystemMessage("Be nice"), any_llm_client.UserMessage("Hi there")],
        ) == [
            ChatCompletionsInputMessage(role=any_llm_client.MessageRole.system, content="Be nice"),
            ChatCompletionsInputMessage(role=any_llm_client.MessageRole.user, content="Hi there"),
diff --git a/tests/test_yandexgpt_client.py b/tests/test_yandexgpt_client.py
index 4ddeb7f..d662a42 100644
--- a/tests/test_yandexgpt_client.py
+++ b/tests/test_yandexgpt_client.py
@@ -1,6 +1,5 @@
 import typing

-import faker
 import httpx
 import pydantic
 import pytest
@@ -56,24 +55,27 @@ def func_request_has_image_content_or_list_of_not_one_items(func_request: LLMFun

 class TestYandexGPTRequestLLMResponse:
     @pytest.mark.parametrize("func_request", LLMFuncRequestFactory.coverage())
-    async def test_ok(self, faker: faker.Faker, func_request: LLMFuncRequest) -> None:
-        expected_result: typing.Final = faker.pystr()
+    async def test_ok(self, func_request: LLMFuncRequest, random_llm_response: any_llm_client.LLMResponse) -> None:
         response: typing.Final = httpx.Response(
             200,
             json=YandexGPTResponse(
                 result=YandexGPTResult(
                     alternatives=[
                         YandexGPTAlternative(
-                            message=YandexGPTMessage(role=any_llm_client.MessageRole.assistant, text=expected_result)
-                        )
-                    ]
-                )
+                            message=YandexGPTMessage(
+                                role=any_llm_client.MessageRole.assistant,
+                                text=random_llm_response.content,
+                            ),
+                        ),
+                    ],
+                ),
             ).model_dump(mode="json"),
         )

-        async def make_request() -> str:
+        async def make_request() -> any_llm_client.LLMResponse:
             return await any_llm_client.get_client(
-                YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
+                YandexGPTConfigFactory.build(),
+                transport=httpx.MockTransport(lambda _: response),
             ).request_llm_message(**func_request)

         if func_request_has_image_content_or_list_of_not_one_items(func_request):
@@ -81,14 +83,16 @@ async def make_request() -> str:
             await make_request()
         else:
             result: typing.Final = await make_request()
-            assert result == expected_result
+            assert result.content == random_llm_response.content

     async def test_fails_without_alternatives(self) -> None:
         response: typing.Final = httpx.Response(
-            200, json=YandexGPTResponse(result=YandexGPTResult.model_construct(alternatives=[])).model_dump(mode="json")
+            200,
+            json=YandexGPTResponse(result=YandexGPTResult.model_construct(alternatives=[])).model_dump(mode="json"),
         )
         client: typing.Final = any_llm_client.get_client(
-            YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
+            YandexGPTConfigFactory.build(),
+            transport=httpx.MockTransport(lambda _: response),
         )

         with pytest.raises(pydantic.ValidationError):
@@ -97,8 +101,12 @@ async def test_fails_without_alternatives(self) -> None:

 class TestYandexGPTRequestLLMMessageChunks:
     @pytest.mark.parametrize("func_request", LLMFuncRequestFactory.coverage())
-    async def test_ok(self, faker: faker.Faker, func_request: LLMFuncRequest) -> None:
-        expected_result: typing.Final = faker.pylist(value_types=[str])
+    async def test_ok(self, func_request: LLMFuncRequest, random_llm_response: any_llm_client.LLMResponse) -> None:
+        assert random_llm_response.content
+        expected_result: typing.Final = [
+            any_llm_client.LLMResponse(content="".join(random_llm_response.content[one_index : one_index + 1]))
+            for one_index in range(len(random_llm_response.content))
+        ]
         config: typing.Final = YandexGPTConfigFactory.build()
         response_content: typing.Final = (
             "\n".join(
@@ -108,11 +116,11 @@ async def test_ok(self, faker: faker.Faker, func_request: LLMFuncRequest) -> Non
                         YandexGPTAlternative(
                             message=YandexGPTMessage(
                                 role=any_llm_client.MessageRole.assistant,
-                                text="".join(expected_result[: one_index + 1]),
-                            )
-                        )
-                    ]
-                )
+                                text="".join(random_llm_response.content[: one_index + 1]),
+                            ),
+                        ),
+                    ],
+                ),
                 ).model_dump_json()
                 for one_index in range(len(expected_result))
             )
             + "\n"
         )
         response: typing.Final = httpx.Response(200, content=response_content)

-        async def make_request() -> list[str]:
+        async def make_request() -> list[any_llm_client.LLMResponse]:
             return await consume_llm_message_chunks(
                 any_llm_client.get_client(
-                    config, transport=httpx.MockTransport(lambda _: response)
-                ).stream_llm_message_chunks(**func_request)
+                    config,
+                    transport=httpx.MockTransport(lambda _: response),
+                ).stream_llm_message_chunks(**func_request),
             )

         if func_request_has_image_content_or_list_of_not_one_items(func_request):
@@ -140,12 +149,13 @@ async def test_fails_without_alternatives(self) -> None:
         response: typing.Final = httpx.Response(200, content=response_content)

         client: typing.Final = any_llm_client.get_client(
-            YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
+            YandexGPTConfigFactory.build(),
+            transport=httpx.MockTransport(lambda _: response),
         )

         with pytest.raises(pydantic.ValidationError):
             await consume_llm_message_chunks(
-                client.stream_llm_message_chunks(**LLMFuncRequestWithTextContentMessagesFactory.build())
+                client.stream_llm_message_chunks(**LLMFuncRequestWithTextContentMessagesFactory.build()),
             )


@@ -154,12 +164,13 @@ class TestYandexGPTLLMErrors:
     @pytest.mark.parametrize("status_code", [400, 500])
     async def test_fails_with_unknown_error(self, stream: bool, status_code: int) -> None:
         client: typing.Final = any_llm_client.get_client(
-            YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: httpx.Response(status_code))
+            YandexGPTConfigFactory.build(),
+            transport=httpx.MockTransport(lambda _: httpx.Response(status_code)),
         )

         coroutine: typing.Final = (
             consume_llm_message_chunks(
-                client.stream_llm_message_chunks(**LLMFuncRequestWithTextContentMessagesFactory.build())
+                client.stream_llm_message_chunks(**LLMFuncRequestWithTextContentMessagesFactory.build()),
             )
             if stream
             else client.request_llm_message(**LLMFuncRequestWithTextContentMessagesFactory.build())
@@ -180,12 +191,13 @@ async def test_fails_with_unknown_error(self, stream: bool, status_code: int) ->
     async def test_fails_with_out_of_tokens_error(self, stream: bool, response_content: bytes | None) -> None:
         response: typing.Final = httpx.Response(400, content=response_content)
         client: typing.Final = any_llm_client.get_client(
-            YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
+            YandexGPTConfigFactory.build(),
+            transport=httpx.MockTransport(lambda _: response),
         )

         coroutine: typing.Final = (
             consume_llm_message_chunks(
-                client.stream_llm_message_chunks(**LLMFuncRequestWithTextContentMessagesFactory.build())
+                client.stream_llm_message_chunks(**LLMFuncRequestWithTextContentMessagesFactory.build()),
             )
             if stream
             else client.request_llm_message(**LLMFuncRequestWithTextContentMessagesFactory.build())
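Usage sketch (not part of the diff): how downstream code might consume the new `LLMResponse` chunks when streaming, splitting reasoning text from the answer. The URL and model name are copied from `examples/openai-reasoning-response.py` above and are assumptions about a local Ollama setup; on any given chunk either `content` or `reasoning_content` may be `None`.

```python
import asyncio

import any_llm_client


# Assumed local endpoint and model, as in the example file added by this diff.
config = any_llm_client.OpenAIConfig(url="http://127.0.0.1:11434/v1/chat/completions", model_name="qwen2.5-coder:1.5b")


async def main() -> None:
    reasoning_parts: list[str] = []
    answer_parts: list[str] = []
    async with (
        any_llm_client.get_client(config) as client,
        client.stream_llm_message_chunks("Кек, чо как вообще на нарах?") as message_chunks,
    ):
        async for chunk in message_chunks:
            # Each chunk is an LLMResponse; collect the optional fields separately.
            if chunk.reasoning_content:
                reasoning_parts.append(chunk.reasoning_content)
            if chunk.content:
                answer_parts.append(chunk.content)
    print("Reasoning:", "".join(reasoning_parts))
    print("Answer:", "".join(answer_parts))


asyncio.run(main())
```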