Add reasoning content #22

Merged
merged 3 commits into from Apr 1, 2025
1 change: 1 addition & 0 deletions .gitignore
@@ -4,3 +4,4 @@ dist
uv.lock
.mypy_cache
.ruff_cache
__pycache__
17 changes: 15 additions & 2 deletions README.md
@@ -93,6 +93,20 @@ async with (
...
```

### Reasoning models

You can work with reasoning models served behind an OpenAI-compatible API and retrieve their reasoning content:

```python
async def main() -> None:
async with any_llm_client.get_client(config) as client:
llm_response = await client.request_llm_message("Кек, чо как вообще на нарах?")
print(f"Just a regular LLM response content: {llm_response.content}")
print(f"LLM reasoning response content: {llm_response.reasoning_content}")

...
```
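
Reasoning content is also surfaced when streaming: each chunk yielded by `stream_llm_message_chunks()` is an `LLMResponse` whose `content` and `reasoning_content` may each be `None`. A minimal sketch, reusing the endpoint and model name from the repository's examples (whether `reasoning_content` is actually populated depends on the model behind the endpoint):

```python
import asyncio

import any_llm_client

# Endpoint and model name are borrowed from the repository's examples; adjust for your setup.
config = any_llm_client.OpenAIConfig(url="http://127.0.0.1:11434/v1/chat/completions", model_name="qwen2.5-coder:1.5b")


async def main() -> None:
    async with (
        any_llm_client.get_client(config) as client,
        client.stream_llm_message_chunks("Reason step by step: how are things going?") as message_chunks,
    ):
        async for chunk in message_chunks:
            if chunk.reasoning_content:  # either field may be None for a given chunk
                print(chunk.reasoning_content, end="", flush=True)
            if chunk.content:
                print(chunk.content, end="", flush=True)


asyncio.run(main())
```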

### Other

#### Mock client
@@ -165,12 +179,12 @@ async with any_llm_client.OpenAIClient(config, ...) as client:
#### Errors

`any_llm_client.LLMClient.request_llm_message()` and `any_llm_client.LLMClient.stream_llm_message_chunks()` will raise:

- `any_llm_client.LLMError` or `any_llm_client.OutOfTokensOrSymbolsError` when the LLM API responds with a failed HTTP status,
- `any_llm_client.LLMRequestValidationError` when images are passed to the YandexGPT client.
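
A minimal handling sketch for the errors listed above; the helper name is hypothetical and `client` is assumed to be an already-entered `any_llm_client.LLMClient`:

```python
import any_llm_client


async def request_content(client: any_llm_client.LLMClient, prompt: str) -> str | None:
    # Hypothetical helper: returns None when the prompt or completion hit the token/symbol limit.
    try:
        response = await client.request_llm_message(prompt)
    except any_llm_client.OutOfTokensOrSymbolsError:
        return None
    except any_llm_client.LLMError as error:
        # Any other failed HTTP status from the LLM API.
        raise RuntimeError("LLM request failed") from error
    return response.content
```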

#### Timeouts, proxy & other HTTP settings


Pass custom [HTTPX](https://www.python-httpx.org) kwargs to `any_llm_client.get_client()`:

```python
@@ -206,7 +220,6 @@ await client.request_llm_message("Кек, чо как вообще на нара

The `extra` parameter is merged with `request_extra` from `OpenAIConfig`.


#### Passing images

You can pass images to OpenAI client (YandexGPT doesn't support images yet):
2 changes: 2 additions & 0 deletions any_llm_client/__init__.py
@@ -10,6 +10,7 @@
LLMConfig,
LLMError,
LLMRequestValidationError,
LLMResponse,
Message,
MessageRole,
OutOfTokensOrSymbolsError,
@@ -31,6 +32,7 @@
"LLMConfig",
"LLMError",
"LLMRequestValidationError",
"LLMResponse",
"Message",
"MessageRole",
"MockLLMClient",
12 changes: 6 additions & 6 deletions any_llm_client/clients/mock.py
@@ -6,12 +6,12 @@
import pydantic
import typing_extensions

from any_llm_client.core import LLMClient, LLMConfig, LLMConfigValue, Message
from any_llm_client.core import LLMClient, LLMConfig, LLMConfigValue, LLMResponse, Message


class MockLLMConfig(LLMConfig):
response_message: str = ""
stream_messages: list[str] = pydantic.Field([])
response_message: LLMResponse = LLMResponse(content="")
stream_messages: list[LLMResponse] = pydantic.Field([])
api_type: typing.Literal["mock"] = "mock"


@@ -25,10 +25,10 @@ async def request_llm_message(
*,
temperature: float = LLMConfigValue(attr="temperature"), # noqa: ARG002
extra: dict[str, typing.Any] | None = None, # noqa: ARG002
) -> str:
) -> LLMResponse:
return self.config.response_message

async def _iter_config_stream_messages(self) -> typing.AsyncIterable[str]:
async def _iter_config_stream_messages(self) -> typing.AsyncIterable[LLMResponse]:
for one_message in self.config.stream_messages:
yield one_message

@@ -39,7 +39,7 @@ async def stream_llm_message_chunks(
*,
temperature: float = LLMConfigValue(attr="temperature"), # noqa: ARG002
extra: dict[str, typing.Any] | None = None, # noqa: ARG002
) -> typing.AsyncIterator[typing.AsyncIterable[str]]:
) -> typing.AsyncIterator[typing.AsyncIterable[LLMResponse]]:
yield self._iter_config_stream_messages()

async def __aenter__(self) -> typing_extensions.Self:
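
With this change, `MockLLMConfig` holds `LLMResponse` objects instead of plain strings. A minimal test-style sketch, assuming `MockLLMConfig` is imported from its module and that `any_llm_client.get_client()` dispatches on `api_type="mock"` the way it does for the other configs:

```python
import asyncio

import any_llm_client
from any_llm_client.clients.mock import MockLLMConfig  # assumed import path, matching the module above


async def main() -> None:
    config = MockLLMConfig(
        response_message=any_llm_client.LLMResponse(content="stub answer", reasoning_content="stub reasoning"),
        stream_messages=[any_llm_client.LLMResponse(content="stub chunk")],
    )
    async with any_llm_client.get_client(config) as client:
        response = await client.request_llm_message("any prompt")
        assert response.content == "stub answer"
        assert response.reasoning_content == "stub reasoning"


asyncio.run(main())
```
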
47 changes: 36 additions & 11 deletions any_llm_client/clients/openai.py
@@ -16,6 +16,7 @@
LLMConfig,
LLMConfigValue,
LLMError,
LLMResponse,
Message,
MessageRole,
OutOfTokensOrSymbolsError,
@@ -76,6 +77,7 @@ class ChatCompletionsRequest(pydantic.BaseModel):
class OneStreamingChoiceDelta(pydantic.BaseModel):
role: typing.Literal[MessageRole.assistant] | None = None
content: str | None = None
reasoning_content: str | None = None


class OneStreamingChoice(pydantic.BaseModel):
@@ -89,6 +91,7 @@ class ChatCompletionsStreamingEvent(pydantic.BaseModel):
class OneNotStreamingChoiceMessage(pydantic.BaseModel):
role: MessageRole
content: str
reasoning_content: str | None = None


class OneNotStreamingChoice(pydantic.BaseModel):
@@ -143,14 +146,16 @@ def _make_user_assistant_alternate_messages(
else:
if current_message_content_chunks:
yield ChatCompletionsInputMessage(
role=current_message_role, content=_merge_content_chunks(current_message_content_chunks)
role=current_message_role,
content=_merge_content_chunks(current_message_content_chunks),
)
current_message_content_chunks = [one_message.content]
current_message_role = one_message.role

if current_message_content_chunks:
yield ChatCompletionsInputMessage(
role=current_message_role, content=_merge_content_chunks(current_message_content_chunks)
role=current_message_role,
content=_merge_content_chunks(current_message_content_chunks),
)


@@ -195,7 +200,12 @@ def _prepare_messages(self, messages: str | list[Message]) -> list[ChatCompletio
)

def _prepare_payload(
self, *, messages: str | list[Message], temperature: float, stream: bool, extra: dict[str, typing.Any] | None
self,
*,
messages: str | list[Message],
temperature: float,
stream: bool,
extra: dict[str, typing.Any] | None,
) -> dict[str, typing.Any]:
return ChatCompletionsRequest(
stream=stream,
@@ -211,9 +221,12 @@ async def request_llm_message(
*,
temperature: float = LLMConfigValue(attr="temperature"),
extra: dict[str, typing.Any] | None = None,
) -> str:
) -> LLMResponse:
payload: typing.Final = self._prepare_payload(
messages=messages, temperature=temperature, stream=False, extra=extra
messages=messages,
temperature=temperature,
stream=False,
extra=extra,
)
try:
response: typing.Final = await make_http_request(
@@ -224,18 +237,27 @@ async def request_llm_message(
except httpx.HTTPStatusError as exception:
_handle_status_error(status_code=exception.response.status_code, content=exception.response.content)
try:
return ChatCompletionsNotStreamingResponse.model_validate_json(response.content).choices[0].message.content
validated_message_model: typing.Final = (
ChatCompletionsNotStreamingResponse.model_validate_json(response.content).choices[0].message
)
return LLMResponse(
content=validated_message_model.content,
reasoning_content=validated_message_model.reasoning_content,
)
finally:
await response.aclose()

async def _iter_response_chunks(self, response: httpx.Response) -> typing.AsyncIterable[str]:
async def _iter_response_chunks(self, response: httpx.Response) -> typing.AsyncIterable[LLMResponse]:
async for event in httpx_sse.EventSource(response).aiter_sse():
if event.data == "[DONE]":
break
validated_response = ChatCompletionsStreamingEvent.model_validate_json(event.data)
if not (one_chunk := validated_response.choices[0].delta.content):
if not (
(validated_delta := validated_response.choices[0].delta)
and (validated_delta.content or validated_delta.reasoning_content)
):
continue
yield one_chunk
yield LLMResponse(content=validated_delta.content, reasoning_content=validated_delta.reasoning_content)

@contextlib.asynccontextmanager
async def stream_llm_message_chunks(
@@ -244,9 +266,12 @@ async def stream_llm_message_chunks(
*,
temperature: float = LLMConfigValue(attr="temperature"),
extra: dict[str, typing.Any] | None = None,
) -> typing.AsyncIterator[typing.AsyncIterable[str]]:
) -> typing.AsyncIterator[typing.AsyncIterable[LLMResponse]]:
payload: typing.Final = self._prepare_payload(
messages=messages, temperature=temperature, stream=True, extra=extra
messages=messages,
temperature=temperature,
stream=True,
extra=extra,
)
try:
async with make_streaming_http_request(
31 changes: 21 additions & 10 deletions any_llm_client/clients/yandexgpt.py
@@ -17,6 +17,7 @@
LLMConfigValue,
LLMError,
LLMRequestValidationError,
LLMResponse,
Message,
MessageRole,
OutOfTokensOrSymbolsError,
@@ -35,10 +36,12 @@ class YandexGPTConfig(LLMConfig):
else:
url: pydantic.HttpUrl = "https://llm.api.cloud.yandex.net/foundationModels/v1/completion"
auth_header: str = pydantic.Field( # type: ignore[assignment]
default_factory=lambda: os.environ.get(YANDEXGPT_AUTH_HEADER_ENV_NAME), validate_default=True
default_factory=lambda: os.environ.get(YANDEXGPT_AUTH_HEADER_ENV_NAME),
validate_default=True,
)
folder_id: str = pydantic.Field( # type: ignore[assignment]
default_factory=lambda: os.environ.get(YANDEXGPT_FOLDER_ID_ENV_NAME), validate_default=True
default_factory=lambda: os.environ.get(YANDEXGPT_FOLDER_ID_ENV_NAME),
validate_default=True,
)
model_name: str
model_version: str = "latest"
@@ -126,7 +129,7 @@ def _prepare_payload(
if isinstance(one_message.content, list):
if len(one_message.content) != 1:
raise LLMRequestValidationError(
"YandexGPTClient does not support multiple content items per message"
"YandexGPTClient does not support multiple content items per message",
)
message_content = one_message.content[0]
if isinstance(message_content, ImageContentItem):
@@ -153,9 +156,12 @@ async def request_llm_message(
*,
temperature: float = LLMConfigValue(attr="temperature"),
extra: dict[str, typing.Any] | None = None,
) -> str:
) -> LLMResponse:
payload: typing.Final = self._prepare_payload(
messages=messages, temperature=temperature, stream=False, extra=extra
messages=messages,
temperature=temperature,
stream=False,
extra=extra,
)

try:
@@ -167,14 +173,16 @@
except httpx.HTTPStatusError as exception:
_handle_status_error(status_code=exception.response.status_code, content=exception.response.content)

return YandexGPTResponse.model_validate_json(response.content).result.alternatives[0].message.text
return LLMResponse(
content=YandexGPTResponse.model_validate_json(response.content).result.alternatives[0].message.text,
)

async def _iter_response_chunks(self, response: httpx.Response) -> typing.AsyncIterable[str]:
async def _iter_response_chunks(self, response: httpx.Response) -> typing.AsyncIterable[LLMResponse]:
previous_cursor = 0
async for one_line in response.aiter_lines():
validated_response = YandexGPTResponse.model_validate_json(one_line)
response_text = validated_response.result.alternatives[0].message.text
yield response_text[previous_cursor:]
yield LLMResponse(content=response_text[previous_cursor:])
previous_cursor = len(response_text)

@contextlib.asynccontextmanager
@@ -184,9 +192,12 @@ async def stream_llm_message_chunks(
*,
temperature: float = LLMConfigValue(attr="temperature"),
extra: dict[str, typing.Any] | None = None,
) -> typing.AsyncIterator[typing.AsyncIterable[str]]:
) -> typing.AsyncIterator[typing.AsyncIterable[LLMResponse]]:
payload: typing.Final = self._prepare_payload(
messages=messages, temperature=temperature, stream=True, extra=extra
messages=messages,
temperature=temperature,
stream=True,
extra=extra,
)

try:
11 changes: 9 additions & 2 deletions any_llm_client/core.py
@@ -39,6 +39,12 @@ class Message:
content: str | ContentItemList


@pydantic.dataclasses.dataclass
class LLMResponse:
content: str | None = None
reasoning_content: str | None = None


if typing.TYPE_CHECKING:

@pydantic.dataclasses.dataclass
@@ -55,6 +61,7 @@ class UserMessage(Message):
class AssistantMessage(Message):
role: typing.Literal[MessageRole.assistant] = pydantic.Field(MessageRole.assistant, init=False)
content: str | ContentItemList

else:

def SystemMessage(content: str | ContentItemList) -> Message: # noqa: N802
@@ -102,7 +109,7 @@ async def request_llm_message(
*,
temperature: float = LLMConfigValue(attr="temperature"),
extra: dict[str, typing.Any] | None = None,
) -> str: ... # raises LLMError, LLMRequestValidationError
) -> LLMResponse: ... # raises LLMError, LLMRequestValidationError

@contextlib.asynccontextmanager
def stream_llm_message_chunks(
@@ -111,7 +118,7 @@ def stream_llm_message_chunks(
*,
temperature: float = LLMConfigValue(attr="temperature"),
extra: dict[str, typing.Any] | None = None,
) -> typing.AsyncIterator[typing.AsyncIterable[str]]: ... # raises LLMError, LLMRequestValidationError
) -> typing.AsyncIterator[typing.AsyncIterable[LLMResponse]]: ... # raises LLMError, LLMRequestValidationError

async def __aenter__(self) -> typing_extensions.Self: ...
async def __aexit__(
4 changes: 2 additions & 2 deletions examples/openai-image-input.py
@@ -17,7 +17,7 @@
content=[
any_llm_client.TextContentItem("What's on the image?"),
any_llm_client.ImageContentItem(f"data:image/jpeg;base64,{base64.b64encode(image_content).decode('utf-8')}"),
]
],
)


@@ -27,7 +27,7 @@ async def main() -> None:
client.stream_llm_message_chunks(messages=[message]) as message_chunks,
):
async for chunk in message_chunks:
print(chunk, end="", flush=True)
print(chunk.content, end="", flush=True)


asyncio.run(main())
21 changes: 21 additions & 0 deletions examples/openai-reasoning-response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Install ollama and pull the model to run this script: `ollama pull qwen2.5-coder:1.5b`."""

import asyncio
import typing

import any_llm_client


config = any_llm_client.OpenAIConfig(url="http://127.0.0.1:11434/v1/chat/completions", model_name="qwen2.5-coder:1.5b")


async def main() -> None:
async with any_llm_client.get_client(config) as client:
llm_response: typing.Final = await client.request_llm_message(
"Кек, чо как вообще на нарах? Порассуждай как философ.",
)
print(llm_response.reasoning_content)
print(llm_response.content)


asyncio.run(main())
2 changes: 1 addition & 1 deletion examples/openai-stream-advanced.py
@@ -20,7 +20,7 @@ async def main() -> None:
) as message_chunks,
):
async for chunk in message_chunks:
print(chunk, end="", flush=True)
print(chunk.content, end="", flush=True)


asyncio.run(main())