Start moving to niquests
vrslev committed Nov 21, 2024
1 parent f26373a commit d51f5ae
Showing 9 changed files with 96 additions and 73 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -152,7 +152,7 @@ import any_llm_client
 
 async with any_llm_client.get_client(
     ...,
-    httpx_client=httpx.AsyncClient(
+    httpx_client=niquests.AsyncSession(
         mounts={"https://api.openai.com": httpx.AsyncHTTPTransport(proxy="http://localhost:8030")},
         timeout=httpx.Timeout(None, connect=5.0),
     ),
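Note that the new `niquests.AsyncSession` above still receives the httpx-specific `mounts` and `timeout` keyword arguments, so this README snippet leans on niquests accepting them. For contrast, a minimal sketch of per-host proxy routing via the requests-style `proxies` mapping that niquests inherits — the proxy URL is a placeholder:

```python
# Sketch only, assuming niquests keeps the requests-style `proxies` mapping.
import niquests

session = niquests.AsyncSession()
# Route only traffic to api.openai.com through a local proxy; other hosts connect directly.
session.proxies = {"https://api.openai.com": "http://localhost:8030"}
```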
52 changes: 32 additions & 20 deletions any_llm_client/clients/openai.py
@@ -5,8 +5,8 @@
 from http import HTTPStatus
 
 import annotated_types
-import httpx
 import httpx_sse
+import niquests
 import pydantic
 import typing_extensions
 
@@ -92,25 +92,27 @@ def _handle_status_error(*, status_code: int, content: bytes) -> typing.NoReturn
 @dataclasses.dataclass(slots=True, init=False)
 class OpenAIClient(LLMClient):
     config: OpenAIConfig
-    httpx_client: httpx.AsyncClient
+    httpx_client: niquests.AsyncSession
     request_retry: RequestRetryConfig
 
     def __init__(
         self,
         config: OpenAIConfig,
-        httpx_client: httpx.AsyncClient | None = None,
+        httpx_client: niquests.AsyncSession | None = None,
         request_retry: RequestRetryConfig | None = None,
     ) -> None:
         self.config = config
-        self.httpx_client = httpx_client or httpx.AsyncClient()
+        self.httpx_client = httpx_client or niquests.AsyncSession()
         self.request_retry = request_retry or RequestRetryConfig()
 
-    def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
-        return self.httpx_client.build_request(
-            method="POST",
-            url=str(self.config.url),
-            json=payload,
-            headers={"Authorization": f"Bearer {self.config.auth_token}"} if self.config.auth_token else None,
-        )
+    def _build_request(self, payload: dict[str, typing.Any]) -> niquests.PreparedRequest:
+        return self.httpx_client.prepare_request(
+            niquests.Request(
+                method="POST",
+                url=str(self.config.url),
+                json=payload,
+                headers={"Authorization": f"Bearer {self.config.auth_token}"} if self.config.auth_token else None,
+            )
+        )
 
     def _prepare_messages(self, messages: str | list[Message]) -> list[ChatCompletionsMessage]:
@@ -137,14 +139,21 @@ async def request_llm_message(self, messages: str | list[Message], temperature:
                 request_retry=self.request_retry,
                 build_request=lambda: self._build_request(payload),
             )
-        except httpx.HTTPStatusError as exception:
-            _handle_status_error(status_code=exception.response.status_code, content=exception.response.content)
+        except niquests.HTTPError as exception:
+            if exception.response and exception.response.status_code and exception.response.content:
+                _handle_status_error(status_code=exception.response.status_code, content=exception.response.content)
+            else:
+                raise
         try:
-            return ChatCompletionsNotStreamingResponse.model_validate_json(response.content).choices[0].message.content
+            return (
+                ChatCompletionsNotStreamingResponse.model_validate_json(response.content)  # type: ignore[arg-type]
+                .choices[0]
+                .message.content
+            )
         finally:
             await response.aclose()
 
-    async def _iter_partial_responses(self, response: httpx.Response) -> typing.AsyncIterable[str]:
+    async def _iter_partial_responses(self, response: niquests.AsyncResponse) -> typing.AsyncIterable[str]:
         text_chunks: typing.Final = []
         async for event in httpx_sse.EventSource(response).aiter_sse():
             if event.data == "[DONE]":
@@ -172,13 +181,16 @@ async def stream_llm_partial_messages(
                 build_request=lambda: self._build_request(payload),
             ) as response:
                 yield self._iter_partial_responses(response)
-        except httpx.HTTPStatusError as exception:
-            content: typing.Final = await exception.response.aread()
-            await exception.response.aclose()
-            _handle_status_error(status_code=exception.response.status_code, content=content)
+        except niquests.HTTPError as exception:
+            if exception.response and exception.response.status_code and exception.response.content:
+                content: typing.Final = exception.response.content
+                exception.response.close()
+                _handle_status_error(status_code=exception.response.status_code, content=content)
+            else:
+                raise
 
     async def __aenter__(self) -> typing_extensions.Self:
-        await self.httpx_client.__aenter__()
+        await self.httpx_client.__aenter__()  # type: ignore[no-untyped-call]
         return self
 
     async def __aexit__(
@@ -187,4 +199,4 @@ async def __aexit__(
         exc_value: BaseException | None,
         traceback: types.TracebackType | None,
     ) -> None:
-        await self.httpx_client.__aexit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback)
+        await self.httpx_client.__aexit__(exc_type, exc_value, traceback)  # type: ignore[no-untyped-call]
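The client's public surface is untouched by the migration: construction, `async with`, and `request_llm_message` read the same, only the session type changes. A minimal usage sketch — the `...` config placeholder follows the README's convention, and the temperature value is an arbitrary example:

```python
import asyncio

import niquests

import any_llm_client


async def main() -> None:
    # Replace `...` with a concrete OpenAIConfig or YandexGPTConfig (see the README).
    async with any_llm_client.get_client(..., httpx_client=niquests.AsyncSession()) as client:
        print(await client.request_llm_message("Kindly answer me", temperature=0.3))


asyncio.run(main())  # runs only once a real config is substituted for `...`
```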
48 changes: 27 additions & 21 deletions any_llm_client/clients/yandexgpt.py
@@ -5,7 +5,7 @@
 from http import HTTPStatus
 
 import annotated_types
-import httpx
+import niquests
 import pydantic
 import typing_extensions
 
@@ -64,25 +64,27 @@ def _handle_status_error(*, status_code: int, content: bytes) -> typing.NoReturn
 @dataclasses.dataclass(slots=True, init=False)
 class YandexGPTClient(LLMClient):
     config: YandexGPTConfig
-    httpx_client: httpx.AsyncClient
+    httpx_client: niquests.AsyncSession
     request_retry: RequestRetryConfig
 
     def __init__(
         self,
         config: YandexGPTConfig,
-        httpx_client: httpx.AsyncClient | None = None,
+        httpx_client: niquests.AsyncSession | None = None,
         request_retry: RequestRetryConfig | None = None,
     ) -> None:
         self.config = config
-        self.httpx_client = httpx_client or httpx.AsyncClient()
+        self.httpx_client = httpx_client or niquests.AsyncSession()
         self.request_retry = request_retry or RequestRetryConfig()
 
-    def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
-        return self.httpx_client.build_request(
-            method="POST",
-            url=str(self.config.url),
-            json=payload,
-            headers={"Authorization": self.config.auth_header, "x-data-logging-enabled": "false"},
-        )
+    def _build_request(self, payload: dict[str, typing.Any]) -> niquests.PreparedRequest:
+        return self.httpx_client.prepare_request(
+            niquests.Request(
+                method="POST",
+                url=str(self.config.url),
+                json=payload,
+                headers={"Authorization": self.config.auth_header, "x-data-logging-enabled": "false"},
+            )
+        )
 
     def _prepare_payload(
@@ -106,13 +108,16 @@ async def request_llm_message(self, messages: str | list[Message], temperature:
                 request_retry=self.request_retry,
                 build_request=lambda: self._build_request(payload),
             )
-        except httpx.HTTPStatusError as exception:
-            _handle_status_error(status_code=exception.response.status_code, content=exception.response.content)
+        except niquests.HTTPError as exception:
+            if exception.response and exception.response.status_code and exception.response.content:
+                _handle_status_error(status_code=exception.response.status_code, content=exception.response.content)
+            else:
+                raise
 
-        return YandexGPTResponse.model_validate_json(response.content).result.alternatives[0].message.text
+        return YandexGPTResponse.model_validate_json(response.content).result.alternatives[0].message.text  # type: ignore[arg-type]
 
-    async def _iter_completion_messages(self, response: httpx.Response) -> typing.AsyncIterable[str]:
-        async for one_line in response.aiter_lines():
+    async def _iter_completion_messages(self, response: niquests.AsyncResponse) -> typing.AsyncIterable[str]:
+        async for one_line in response.iter_lines():
             validated_response = YandexGPTResponse.model_validate_json(one_line)
             yield validated_response.result.alternatives[0].message.text
 
@@ -129,13 +134,14 @@ async def stream_llm_partial_messages(
                 build_request=lambda: self._build_request(payload),
             ) as response:
                 yield self._iter_completion_messages(response)
-        except httpx.HTTPStatusError as exception:
-            content: typing.Final = await exception.response.aread()
-            await exception.response.aclose()
-            _handle_status_error(status_code=exception.response.status_code, content=content)
+        except niquests.HTTPError as exception:
+            if exception.response and exception.response.status_code and exception.response.content:
+                content: typing.Final = exception.response.content
+                exception.response.close()
+                _handle_status_error(status_code=exception.response.status_code, content=content)
 
     async def __aenter__(self) -> typing_extensions.Self:
-        await self.httpx_client.__aenter__()
+        await self.httpx_client.__aenter__()  # type: ignore[no-untyped-call]
         return self
 
     async def __aexit__(
@@ -144,4 +150,4 @@
         exc_value: BaseException | None,
         traceback: types.TracebackType | None,
     ) -> None:
-        await self.httpx_client.__aexit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback)
+        await self.httpx_client.__aexit__(exc_type, exc_value, traceback)  # type: ignore[no-untyped-call]
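Both clients now share the same guard: a `niquests.HTTPError` is translated into a domain error only when its response carries a status code and a body. Unlike the OpenAI client, the streaming branch here has no `else: raise`, so an `HTTPError` without a usable response would end the stream silently. A standalone sketch of the guard pattern — the URL and helper are hypothetical stand-ins, not the module's real API:

```python
# Standalone sketch of the error-translation guard used in both clients.
import niquests


def handle_status_error(*, status_code: int, content: bytes) -> None:
    # Stand-in for the module-private _handle_status_error helpers above.
    raise RuntimeError(f"LLM request failed with {status_code}: {content!r}")


try:
    response = niquests.post("https://llm.example/api", json={"prompt": "hi"})  # placeholder URL
    response.raise_for_status()
except niquests.HTTPError as exception:
    if exception.response is not None and exception.response.content:
        handle_status_error(status_code=exception.response.status_code, content=exception.response.content)
    else:
        raise  # errors without a usable response propagate unchanged
```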
24 changes: 13 additions & 11 deletions any_llm_client/http.py
@@ -3,19 +3,20 @@
 import typing
 
 import httpx
+import niquests
 import stamina
 
 from any_llm_client.retry import RequestRetryConfig
 
 
 async def make_http_request(
     *,
-    httpx_client: httpx.AsyncClient,
+    httpx_client: niquests.AsyncSession,
     request_retry: RequestRetryConfig,
-    build_request: typing.Callable[[], httpx.Request],
-) -> httpx.Response:
-    @stamina.retry(on=httpx.HTTPError, **dataclasses.asdict(request_retry))
-    async def make_request_with_retries() -> httpx.Response:
+    build_request: typing.Callable[[], niquests.PreparedRequest],
+) -> niquests.Response:
+    @stamina.retry(on=niquests.HTTPError, **dataclasses.asdict(request_retry))
+    async def make_request_with_retries() -> niquests.Response:
         response: typing.Final = await httpx_client.send(build_request())
         response.raise_for_status()
         return response
@@ -26,18 +27,19 @@ async def make_request_with_retries() -> httpx.Response:
 @contextlib.asynccontextmanager
 async def make_streaming_http_request(
     *,
-    httpx_client: httpx.AsyncClient,
+    httpx_client: niquests.AsyncSession,
     request_retry: RequestRetryConfig,
-    build_request: typing.Callable[[], httpx.Request],
-) -> typing.AsyncIterator[httpx.Response]:
+    build_request: typing.Callable[[], niquests.PreparedRequest],
+) -> typing.AsyncIterator[niquests.AsyncResponse]:
     @stamina.retry(on=httpx.HTTPError, **dataclasses.asdict(request_retry))
-    async def make_request_with_retries() -> httpx.Response:
+    async def make_request_with_retries() -> niquests.AsyncResponse:
         response: typing.Final = await httpx_client.send(build_request(), stream=True)
         response.raise_for_status()
-        return response
+        return response  # type: ignore[return-value]
 
     response: typing.Final = await make_request_with_retries()
     try:
+        response.__aenter__()
         yield response
     finally:
-        await response.aclose()
+        await response.close()
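The non-streaming helper now retries on `niquests.HTTPError`, while the streaming helper's decorator still names `httpx.HTTPError` — consistent with the commit title, the move is only started here. For reference, a self-contained sketch of the stamina retry pattern these helpers build on; the attempt count and URL are illustrative only:

```python
import niquests
import stamina


@stamina.retry(on=niquests.HTTPError, attempts=3)
async def fetch_with_retries(session: niquests.AsyncSession, url: str) -> niquests.Response:
    # prepare_request/send mirror the build_request/send split used above.
    response = await session.send(session.prepare_request(niquests.Request(method="GET", url=url)))
    response.raise_for_status()  # non-2xx raises niquests.HTTPError, which triggers a retry
    return response
```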
12 changes: 6 additions & 6 deletions any_llm_client/main.py
@@ -1,7 +1,7 @@
 import functools
 import typing
 
-import httpx
+import niquests
 
 from any_llm_client.clients.mock import MockLLMClient, MockLLMConfig
 from any_llm_client.clients.openai import OpenAIClient, OpenAIConfig
@@ -18,7 +18,7 @@
     def get_client(
         config: AnyLLMConfig,
         *,
-        httpx_client: httpx.AsyncClient | None = None,
+        httpx_client: niquests.AsyncSession | None = None,
         request_retry: RequestRetryConfig | None = None,
     ) -> LLMClient: ...  # pragma: no cover
 else:
@@ -27,7 +27,7 @@ def get_client(
     def get_client(
         config: typing.Any,  # noqa: ANN401, ARG001
         *,
-        httpx_client: httpx.AsyncClient | None = None,  # noqa: ARG001
+        httpx_client: niquests.AsyncSession | None = None,  # noqa: ARG001
         request_retry: RequestRetryConfig | None = None,  # noqa: ARG001
     ) -> LLMClient:
         raise AssertionError("unknown LLM config type")
@@ -36,7 +36,7 @@ def get_client(
     def _(
         config: YandexGPTConfig,
         *,
-        httpx_client: httpx.AsyncClient | None = None,
+        httpx_client: niquests.AsyncSession | None = None,
         request_retry: RequestRetryConfig | None = None,
     ) -> LLMClient:
         return YandexGPTClient(config=config, httpx_client=httpx_client, request_retry=request_retry)
@@ -45,7 +45,7 @@ def _(
     def _(
         config: OpenAIConfig,
         *,
-        httpx_client: httpx.AsyncClient | None = None,
+        httpx_client: niquests.AsyncSession | None = None,
        request_retry: RequestRetryConfig | None = None,
     ) -> LLMClient:
         return OpenAIClient(config=config, httpx_client=httpx_client, request_retry=request_retry)
@@ -54,7 +54,7 @@ def _(
     def _(
         config: MockLLMConfig,
         *,
-        httpx_client: httpx.AsyncClient | None = None,  # noqa: ARG001
+        httpx_client: niquests.AsyncSession | None = None,  # noqa: ARG001
         request_retry: RequestRetryConfig | None = None,  # noqa: ARG001
     ) -> LLMClient:
         return MockLLMClient(config=config)
1 change: 1 addition & 0 deletions pyproject.toml
@@ -20,6 +20,7 @@ dependencies = [
     "httpx>=0.27.2",
     "pydantic>=2.9.2",
     "stamina>=24.3.0",
+    "niquests>=3.11.0",
 ]
 dynamic = ["version"]
 
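httpx stays pinned alongside the new niquests dependency, since the SSE iteration in the OpenAI client and the test transports still go through httpx. A quick sanity check (sketch) that an environment resolves the new lower bound:

```python
# Sketch: print the installed niquests version; the pin above requires >= 3.11.0.
import importlib.metadata

print(importlib.metadata.version("niquests"))
```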
13 changes: 7 additions & 6 deletions tests/test_openai_client.py
@@ -3,6 +3,7 @@
 
 import faker
 import httpx
+import niquests
 import pydantic
 import pytest
 from polyfactory.factories.pydantic_factory import ModelFactory
@@ -36,7 +37,7 @@ async def test_ok(self, faker: faker.Faker) -> None:
 
         result: typing.Final = await any_llm_client.get_client(
             OpenAIConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
+            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: response)),
         ).request_llm_message(**LLMFuncRequestFactory.build())
 
         assert result == expected_result
@@ -48,7 +49,7 @@ async def test_fails_without_alternatives(self) -> None:
         )
         client: typing.Final = any_llm_client.get_client(
             OpenAIConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
+            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: response)),
         )
 
         with pytest.raises(pydantic.ValidationError):
@@ -91,7 +92,7 @@ async def test_ok(self, faker: faker.Faker) -> None:
         )
         client: typing.Final = any_llm_client.get_client(
             config,
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
+            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: response)),
         )
 
         result: typing.Final = await consume_llm_partial_responses(client.stream_llm_partial_messages(**func_request))
@@ -107,7 +108,7 @@ async def test_fails_without_alternatives(self) -> None:
         )
         client: typing.Final = any_llm_client.get_client(
             OpenAIConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
+            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: response)),
         )
 
         with pytest.raises(pydantic.ValidationError):
@@ -120,7 +121,7 @@ class TestOpenAILLMErrors:
     async def test_fails_with_unknown_error(self, stream: bool, status_code: int) -> None:
         client: typing.Final = any_llm_client.get_client(
             OpenAIConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(status_code))),
+            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: httpx.Response(status_code))),
         )
 
         coroutine: typing.Final = (
@@ -145,7 +146,7 @@ async def test_fails_with_out_of_tokens_error(self, stream: bool, content: bytes
         response: typing.Final = httpx.Response(400, content=content)
         client: typing.Final = any_llm_client.get_client(
             OpenAIConfigFactory.build(),
-            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
+            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: response)),
         )
 
         coroutine: typing.Final = (
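These tests keep building `httpx.MockTransport` responses and now pass them to `niquests.AsyncSession(transport=...)` — that niquests accepts an httpx transport is an assumption the diff makes, not an established API. For reference, a runnable httpx-only illustration of what `MockTransport` does:

```python
# Every request is short-circuited to a canned response; no network I/O happens.
import httpx

transport = httpx.MockTransport(lambda request: httpx.Response(200, json={"ok": True}))
client = httpx.Client(transport=transport)
assert client.get("https://api.example/test").json() == {"ok": True}  # placeholder URL
```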
4 changes: 2 additions & 2 deletions tests/test_static.py
@@ -27,12 +27,12 @@ def test_llm_error_str(faker: faker.Faker) -> None:
 
 
 def test_llm_func_request_has_same_annotations_as_llm_client_methods() -> None:
-    all_objects = (
+    all_objects: typing.Final = (
         any_llm_client.LLMClient.request_llm_message,
         any_llm_client.LLMClient.stream_llm_partial_messages,
         LLMFuncRequest,
     )
-    all_annotations = [typing.get_type_hints(one_object) for one_object in all_objects]
+    all_annotations: typing.Final = [typing.get_type_hints(one_object) for one_object in all_objects]
 
     for one_ignored_prop in ("return",):
         for annotations in all_annotations: