
Commit 2174cd0 (2 parents: dcf98a3 + 4e0cb3d)

Merge pull request #21 from community-of-python/images

Add image support for OpenAI client

16 files changed: +417 -86 lines

README.md (+32 -3)
@@ -24,7 +24,7 @@ import any_llm_client
 
 
 config = any_llm_client.OpenAIConfig(
-    url="http://127.0.0.1:11434/v1/chat/completions",
+    url="http://127.0.0.1:11434/v1/chat/completions",
     model_name="qwen2.5-coder:1.5b",
     request_extra={"best_of": 3}
 )
@@ -57,7 +57,7 @@ import any_llm_client
 
 
 config = any_llm_client.OpenAIConfig(
-    url="http://127.0.0.1:11434/v1/chat/completions",
+    url="http://127.0.0.1:11434/v1/chat/completions",
     model_name="qwen2.5-coder:1.5b",
     request_extra={"best_of": 3}
 )
@@ -164,7 +164,9 @@ async with any_llm_client.OpenAIClient(config, ...) as client:
 
 #### Errors
 
-`any_llm_client.LLMClient.request_llm_message()` and `any_llm_client.LLMClient.stream_llm_message_chunks()` will raise `any_llm_client.LLMError` or `any_llm_client.OutOfTokensOrSymbolsError` when the LLM API responds with a failed HTTP status.
+`any_llm_client.LLMClient.request_llm_message()` and `any_llm_client.LLMClient.stream_llm_message_chunks()` will raise:
+- `any_llm_client.LLMError` or `any_llm_client.OutOfTokensOrSymbolsError` when the LLM API responds with a failed HTTP status,
+- `any_llm_client.LLMRequestValidationError` when images are passed to the YandexGPT client.
 
 #### Timeouts, proxy & other HTTP settings

@@ -203,3 +205,30 @@ await client.request_llm_message("Кек, чо как вообще на нара
 ```
 
 The `extra` parameter is merged with `request_extra` in OpenAIConfig
+
+
+#### Passing images
+
+You can pass images to the OpenAI client (YandexGPT doesn't support images yet):
+
+```python
+await client.request_llm_message(
+    messages=[
+        any_llm_client.TextContentItem("What's on the image?"),
+        any_llm_client.ImageContentItem("https://upload.wikimedia.org/wikipedia/commons/a/a9/Example.jpg"),
+    ]
+)
+```
+
+You can also pass a data URL with a base64-encoded image:
+
+```python
+await client.request_llm_message(
+    messages=[
+        any_llm_client.TextContentItem("What's on the image?"),
+        any_llm_client.ImageContentItem(
+            f"data:image/jpeg;base64,{base64.b64encode(image_content_bytes).decode('utf-8')}"
+        ),
+    ]
+)
+```
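The second README snippet above interpolates an `image_content_bytes` variable that the diff never defines. A minimal sketch of how one might produce it, assuming a local JPEG file (the path and MIME type are placeholders):

```python
import base64
import pathlib

# Placeholder path; any local JPEG works. Adjust the MIME type for other formats.
image_content_bytes = pathlib.Path("example.jpg").read_bytes()

# Interpolates into the data URL exactly as in the README example above.
data_url = f"data:image/jpeg;base64,{base64.b64encode(image_content_bytes).decode('utf-8')}"
```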

any_llm_client/__init__.py (+10 -0)
@@ -2,26 +2,35 @@
 from any_llm_client.clients.openai import OpenAIClient, OpenAIConfig
 from any_llm_client.clients.yandexgpt import YandexGPTClient, YandexGPTConfig
 from any_llm_client.core import (
+    AnyContentItem,
     AssistantMessage,
+    ContentItemList,
+    ImageContentItem,
     LLMClient,
     LLMConfig,
     LLMError,
+    LLMRequestValidationError,
     Message,
     MessageRole,
     OutOfTokensOrSymbolsError,
     SystemMessage,
+    TextContentItem,
     UserMessage,
 )
 from any_llm_client.main import AnyLLMConfig, get_client
 from any_llm_client.retry import RequestRetryConfig
 
 
 __all__ = [
+    "AnyContentItem",
     "AnyLLMConfig",
     "AssistantMessage",
+    "ContentItemList",
+    "ImageContentItem",
     "LLMClient",
     "LLMConfig",
     "LLMError",
+    "LLMRequestValidationError",
     "Message",
     "MessageRole",
     "MockLLMClient",
@@ -31,6 +40,7 @@
     "OutOfTokensOrSymbolsError",
     "RequestRetryConfig",
     "SystemMessage",
+    "TextContentItem",
     "UserMessage",
     "YandexGPTClient",
     "YandexGPTConfig",

any_llm_client/clients/openai.py (+65 -14)
@@ -19,6 +19,7 @@
     Message,
     MessageRole,
     OutOfTokensOrSymbolsError,
+    TextContentItem,
     UserMessage,
 )
 from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request
@@ -41,16 +42,34 @@ class OpenAIConfig(LLMConfig):
     api_type: typing.Literal["openai"] = "openai"
 
 
-class ChatCompletionsMessage(pydantic.BaseModel):
+class ChatCompletionsTextContentItem(pydantic.BaseModel):
+    type: typing.Literal["text"] = "text"
+    text: str
+
+
+class ChatCompletionsContentUrl(pydantic.BaseModel):
+    url: str
+
+
+class ChatCompletionsImageContentItem(pydantic.BaseModel):
+    type: typing.Literal["image_url"] = "image_url"
+    image_url: ChatCompletionsContentUrl
+
+
+ChatCompletionsAnyContentItem = ChatCompletionsImageContentItem | ChatCompletionsTextContentItem
+ChatCompletionsContentItemList = typing.Annotated[list[ChatCompletionsAnyContentItem], annotated_types.MinLen(1)]
+
+
+class ChatCompletionsInputMessage(pydantic.BaseModel):
     role: MessageRole
-    content: str
+    content: str | ChatCompletionsContentItemList
 
 
 class ChatCompletionsRequest(pydantic.BaseModel):
     model_config = pydantic.ConfigDict(extra="allow")
     stream: bool
     model: str
-    messages: list[ChatCompletionsMessage]
+    messages: list[ChatCompletionsInputMessage]
     temperature: float
 
 
@@ -67,22 +86,54 @@ class ChatCompletionsStreamingEvent(pydantic.BaseModel):
     choices: typing.Annotated[list[OneStreamingChoice], annotated_types.MinLen(1)]
 
 
+class OneNotStreamingChoiceMessage(pydantic.BaseModel):
+    role: MessageRole
+    content: str
+
+
 class OneNotStreamingChoice(pydantic.BaseModel):
-    message: ChatCompletionsMessage
+    message: OneNotStreamingChoiceMessage
 
 
 class ChatCompletionsNotStreamingResponse(pydantic.BaseModel):
     choices: typing.Annotated[list[OneNotStreamingChoice], annotated_types.MinLen(1)]
 
 
+def _prepare_one_message(one_message: Message) -> ChatCompletionsInputMessage:
+    if isinstance(one_message.content, str):
+        return ChatCompletionsInputMessage(role=one_message.role, content=one_message.content)
+    content_items: typing.Final = [
+        ChatCompletionsTextContentItem(text=one_content_item.text)
+        if isinstance(one_content_item, TextContentItem)
+        else ChatCompletionsImageContentItem(image_url=ChatCompletionsContentUrl(url=one_content_item.image_url))
+        for one_content_item in one_message.content
+    ]
+    return ChatCompletionsInputMessage(role=one_message.role, content=content_items)
+
+
+def _merge_content_chunks(
+    content_chunks: list[str | ChatCompletionsContentItemList],
+) -> str | ChatCompletionsContentItemList:
+    if all(isinstance(one_content_chunk, str) for one_content_chunk in content_chunks):
+        return "\n\n".join(typing.cast("list[str]", content_chunks))
+
+    new_content_items: ChatCompletionsContentItemList = []
+    for one_content_chunk in content_chunks:
+        if isinstance(one_content_chunk, str):
+            new_content_items.append(ChatCompletionsTextContentItem(text=one_content_chunk))
+        else:
+            new_content_items += one_content_chunk
+    return new_content_items
+
+
 def _make_user_assistant_alternate_messages(
-    messages: typing.Iterable[ChatCompletionsMessage],
-) -> typing.Iterable[ChatCompletionsMessage]:
+    messages: typing.Iterable[ChatCompletionsInputMessage],
+) -> typing.Iterable[ChatCompletionsInputMessage]:
     current_message_role = MessageRole.user
     current_message_content_chunks = []
 
     for one_message in messages:
-        if not one_message.content.strip():
+        if isinstance(one_message.content, str) and not one_message.content.strip():
             continue
 
         if (
@@ -91,14 +142,16 @@ def _make_user_assistant_alternate_messages(
             current_message_content_chunks.append(one_message.content)
         else:
             if current_message_content_chunks:
-                yield ChatCompletionsMessage(
-                    role=current_message_role, content="\n\n".join(current_message_content_chunks)
+                yield ChatCompletionsInputMessage(
+                    role=current_message_role, content=_merge_content_chunks(current_message_content_chunks)
                 )
             current_message_content_chunks = [one_message.content]
             current_message_role = one_message.role
 
     if current_message_content_chunks:
-        yield ChatCompletionsMessage(role=current_message_role, content="\n\n".join(current_message_content_chunks))
+        yield ChatCompletionsInputMessage(
+            role=current_message_role, content=_merge_content_chunks(current_message_content_chunks)
+        )
 
 
 def _handle_status_error(*, status_code: int, content: bytes) -> typing.NoReturn:
@@ -132,11 +185,9 @@ def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
             headers={"Authorization": f"Bearer {self.config.auth_token}"} if self.config.auth_token else None,
         )
 
-    def _prepare_messages(self, messages: str | list[Message]) -> list[ChatCompletionsMessage]:
+    def _prepare_messages(self, messages: str | list[Message]) -> list[ChatCompletionsInputMessage]:
         messages = [UserMessage(messages)] if isinstance(messages, str) else messages
-        initial_messages: typing.Final = (
-            ChatCompletionsMessage(role=one_message.role, content=one_message.text) for one_message in messages
-        )
+        initial_messages: typing.Final = (_prepare_one_message(one_message) for one_message in messages)
         return (
             list(_make_user_assistant_alternate_messages(initial_messages))
             if self.config.force_user_assistant_message_alternation
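To see what `_merge_content_chunks` does when `force_user_assistant_message_alternation` collapses consecutive same-role messages, here is an illustrative snippet against the private helpers named in the diff above (importing privates is for demonstration only, not recommended use):

```python
from any_llm_client.clients.openai import (
    ChatCompletionsContentUrl,
    ChatCompletionsImageContentItem,
    _merge_content_chunks,
)

# Chunks that are all strings collapse into one newline-separated string.
assert _merge_content_chunks(["first", "second"]) == "first\n\nsecond"

# Mixed chunks flatten into a single content item list: bare strings are
# wrapped as text items, item lists are spliced in as-is.
image_item = ChatCompletionsImageContentItem(
    image_url=ChatCompletionsContentUrl(url="https://example.com/cat.jpg")
)
print(_merge_content_chunks(["caption", [image_item]]))
```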

any_llm_client/clients/yandexgpt.py (+29 -5)
@@ -11,13 +11,15 @@
 import typing_extensions
 
 from any_llm_client.core import (
+    ImageContentItem,
     LLMClient,
     LLMConfig,
     LLMConfigValue,
     LLMError,
+    LLMRequestValidationError,
     Message,
+    MessageRole,
     OutOfTokensOrSymbolsError,
-    UserMessage,
 )
 from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request
 from any_llm_client.retry import RequestRetryConfig
@@ -50,15 +52,20 @@ class YandexGPTCompletionOptions(pydantic.BaseModel):
     max_tokens: int = pydantic.Field(gt=0, alias="maxTokens")
 
 
+class YandexGPTMessage(pydantic.BaseModel):
+    role: MessageRole
+    text: str
+
+
 class YandexGPTRequest(pydantic.BaseModel):
     model_config = pydantic.ConfigDict(protected_namespaces=(), extra="allow")
     model_uri: str = pydantic.Field(alias="modelUri")
     completion_options: YandexGPTCompletionOptions = pydantic.Field(alias="completionOptions")
-    messages: list[Message]
+    messages: list[YandexGPTMessage]
 
 
 class YandexGPTAlternative(pydantic.BaseModel):
-    message: Message
+    message: YandexGPTMessage
 
 
 class YandexGPTResult(pydantic.BaseModel):
@@ -111,15 +118,32 @@ def _prepare_payload(
         stream: bool,
         extra: dict[str, typing.Any] | None,
     ) -> dict[str, typing.Any]:
-        messages = [UserMessage(messages)] if isinstance(messages, str) else messages
+        if isinstance(messages, str):
+            prepared_messages = [YandexGPTMessage(role=MessageRole.user, text=messages)]
+        else:
+            prepared_messages = []
+            for one_message in messages:
+                if isinstance(one_message.content, list):
+                    if len(one_message.content) != 1:
+                        raise LLMRequestValidationError(
+                            "YandexGPTClient does not support multiple content items per message"
+                        )
+                    message_content = one_message.content[0]
+                    if isinstance(message_content, ImageContentItem):
+                        raise LLMRequestValidationError("YandexGPTClient does not support image content items")
+                    message_text = message_content.text
+                else:
+                    message_text = one_message.content
+                prepared_messages.append(YandexGPTMessage(role=one_message.role, text=message_text))
+
         return YandexGPTRequest(
             modelUri=f"gpt://{self.config.folder_id}/{self.config.model_name}/{self.config.model_version}",
             completionOptions=YandexGPTCompletionOptions(
                 stream=stream,
                 temperature=self.config._resolve_request_temperature(temperature),  # noqa: SLF001
                 maxTokens=self.config.max_tokens,
            ),
-            messages=messages,
+            messages=prepared_messages,
             **self.config.request_extra | (extra or {}),
         ).model_dump(mode="json", by_alias=True)
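The reworked `_prepare_payload` now rejects image content before any HTTP request is made. A hedged sketch of the caller-visible behavior: it assumes `UserMessage` accepts a content item list positionally (mirroring how `_prepare_payload` reads `message.content`) and elides client construction, which depends on your `YandexGPTConfig`:

```python
import any_llm_client

async def ask_with_image(client: any_llm_client.LLMClient) -> None:
    try:
        await client.request_llm_message(
            messages=[
                # Assumed constructor: UserMessage taking the content positionally.
                any_llm_client.UserMessage(
                    [any_llm_client.ImageContentItem("https://example.com/cat.jpg")]
                ),
            ],
        )
    except any_llm_client.LLMRequestValidationError:
        # Raised by the YandexGPT client before any HTTP call:
        # "YandexGPTClient does not support image content items".
        ...
```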
