From 95a191370140d9ba8b039d32caa2e83be099ca7a Mon Sep 17 00:00:00 2001
From: Lev Vereshchagin
Date: Thu, 21 Nov 2024 18:49:18 +0300
Subject: [PATCH] Revert "Start moving to niquests"

It is still quite raw: https://github.com/jawah/niquests/pull/182

This reverts commit d51f5aee281ed23af45b280e38ad511930208b33.
---
 README.md                           |  2 +-
 any_llm_client/clients/openai.py    | 52 +++++++++++------------------
 any_llm_client/clients/yandexgpt.py | 48 ++++++++++++--------------
 any_llm_client/http.py              | 24 ++++++-------
 any_llm_client/main.py              | 12 +++----
 pyproject.toml                      |  1 -
 tests/test_openai_client.py         | 13 ++++----
 tests/test_static.py                |  4 +--
 tests/test_yandexgpt_client.py      | 13 ++++----
 9 files changed, 73 insertions(+), 96 deletions(-)

diff --git a/README.md b/README.md
index 25eac61..926af80 100644
--- a/README.md
+++ b/README.md
@@ -152,7 +152,7 @@ import any_llm_client
 
 async with any_llm_client.get_client(
     ...,
-    httpx_client=niquests.AsyncSession(
+    httpx_client=httpx.AsyncClient(
         mounts={"https://api.openai.com": httpx.AsyncHTTPTransport(proxy="http://localhost:8030")},
         timeout=httpx.Timeout(None, connect=5.0),
     ),
diff --git a/any_llm_client/clients/openai.py b/any_llm_client/clients/openai.py
index e2bfc68..438fa4b 100644
--- a/any_llm_client/clients/openai.py
+++ b/any_llm_client/clients/openai.py
@@ -5,8 +5,8 @@
 from http import HTTPStatus
 
 import annotated_types
+import httpx
 import httpx_sse
-import niquests
 import pydantic
 import typing_extensions
 
@@ -92,27 +92,25 @@ def _handle_status_error(*, status_code: int, content: bytes) -> typing.NoReturn
 @dataclasses.dataclass(slots=True, init=False)
 class OpenAIClient(LLMClient):
     config: OpenAIConfig
-    httpx_client: niquests.AsyncSession
+    httpx_client: httpx.AsyncClient
     request_retry: RequestRetryConfig
 
     def __init__(
         self,
         config: OpenAIConfig,
-        httpx_client: niquests.AsyncSession | None = None,
+        httpx_client: httpx.AsyncClient | None = None,
         request_retry: RequestRetryConfig | None = None,
     ) -> None:
         self.config = config
-        self.httpx_client = httpx_client or niquests.AsyncSession()
+        self.httpx_client = httpx_client or httpx.AsyncClient()
         self.request_retry = request_retry or RequestRetryConfig()
 
-    def _build_request(self, payload: dict[str, typing.Any]) -> niquests.PreparedRequest:
-        return self.httpx_client.prepare_request(
-            niquests.Request(
-                method="POST",
-                url=str(self.config.url),
-                json=payload,
-                headers={"Authorization": f"Bearer {self.config.auth_token}"} if self.config.auth_token else None,
-            )
+    def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
+        return self.httpx_client.build_request(
+            method="POST",
+            url=str(self.config.url),
+            json=payload,
+            headers={"Authorization": f"Bearer {self.config.auth_token}"} if self.config.auth_token else None,
         )
 
     def _prepare_messages(self, messages: str | list[Message]) -> list[ChatCompletionsMessage]:
@@ -139,21 +137,14 @@ async def request_llm_message(self, messages: str | list[Message], temperature:
                 request_retry=self.request_retry,
                 build_request=lambda: self._build_request(payload),
             )
-        except niquests.HTTPError as exception:
-            if exception.response and exception.response.status_code and exception.response.content:
-                _handle_status_error(status_code=exception.response.status_code, content=exception.response.content)
-            else:
-                raise
+        except httpx.HTTPStatusError as exception:
+            _handle_status_error(status_code=exception.response.status_code, content=exception.response.content)
         try:
-            return (
-                ChatCompletionsNotStreamingResponse.model_validate_json(response.content)  # type: ignore[arg-type]
-                .choices[0]
-                .message.content
-            )
+            return ChatCompletionsNotStreamingResponse.model_validate_json(response.content).choices[0].message.content
         finally:
             await response.aclose()
 
-    async def _iter_partial_responses(self, response: niquests.AsyncResponse) -> typing.AsyncIterable[str]:
+    async def _iter_partial_responses(self, response: httpx.Response) -> typing.AsyncIterable[str]:
         text_chunks: typing.Final = []
         async for event in httpx_sse.EventSource(response).aiter_sse():
             if event.data == "[DONE]":
@@ -181,16 +172,13 @@ async def stream_llm_partial_messages(
                 build_request=lambda: self._build_request(payload),
             ) as response:
                 yield self._iter_partial_responses(response)
-        except niquests.HTTPError as exception:
-            if exception.response and exception.response.status_code and exception.response.content:
-                content: typing.Final = exception.response.content
-                exception.response.close()
-                _handle_status_error(status_code=exception.response.status_code, content=content)
-            else:
-                raise
+        except httpx.HTTPStatusError as exception:
+            content: typing.Final = await exception.response.aread()
+            await exception.response.aclose()
+            _handle_status_error(status_code=exception.response.status_code, content=content)
 
     async def __aenter__(self) -> typing_extensions.Self:
-        await self.httpx_client.__aenter__()  # type: ignore[no-untyped-call]
+        await self.httpx_client.__aenter__()
         return self
 
     async def __aexit__(
@@ -199,4 +187,4 @@ async def __aexit__(
         exc_value: BaseException | None,
         traceback: types.TracebackType | None,
     ) -> None:
-        await self.httpx_client.__aexit__(exc_type, exc_value, traceback)  # type: ignore[no-untyped-call]
+        await self.httpx_client.__aexit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback)
diff --git a/any_llm_client/clients/yandexgpt.py b/any_llm_client/clients/yandexgpt.py
index 327e53f..8463d9b 100644
--- a/any_llm_client/clients/yandexgpt.py
+++ b/any_llm_client/clients/yandexgpt.py
@@ -5,7 +5,7 @@
 from http import HTTPStatus
 
 import annotated_types
-import niquests
+import httpx
 import pydantic
 import typing_extensions
 
@@ -64,27 +64,25 @@ def _handle_status_error(*, status_code: int, content: bytes) -> typing.NoReturn
 @dataclasses.dataclass(slots=True, init=False)
 class YandexGPTClient(LLMClient):
     config: YandexGPTConfig
-    httpx_client: niquests.AsyncSession
+    httpx_client: httpx.AsyncClient
     request_retry: RequestRetryConfig
 
     def __init__(
         self,
         config: YandexGPTConfig,
-        httpx_client: niquests.AsyncSession | None = None,
+        httpx_client: httpx.AsyncClient | None = None,
         request_retry: RequestRetryConfig | None = None,
     ) -> None:
         self.config = config
-        self.httpx_client = httpx_client or niquests.AsyncSession()
+        self.httpx_client = httpx_client or httpx.AsyncClient()
         self.request_retry = request_retry or RequestRetryConfig()
 
-    def _build_request(self, payload: dict[str, typing.Any]) -> niquests.PreparedRequest:
-        return self.httpx_client.prepare_request(
-            niquests.Request(
-                method="POST",
-                url=str(self.config.url),
-                json=payload,
-                headers={"Authorization": self.config.auth_header, "x-data-logging-enabled": "false"},
-            )
+    def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
+        return self.httpx_client.build_request(
+            method="POST",
+            url=str(self.config.url),
+            json=payload,
+            headers={"Authorization": self.config.auth_header, "x-data-logging-enabled": "false"},
         )
 
     def _prepare_payload(
@@ -108,16 +106,13 @@ async def request_llm_message(self, messages: str | list[Message], temperature:
                 request_retry=self.request_retry,
                 build_request=lambda: self._build_request(payload),
             )
-        except niquests.HTTPError as exception:
-            if exception.response and exception.response.status_code and exception.response.content:
-                _handle_status_error(status_code=exception.response.status_code, content=exception.response.content)
-            else:
-                raise
+        except httpx.HTTPStatusError as exception:
+            _handle_status_error(status_code=exception.response.status_code, content=exception.response.content)
 
-        return YandexGPTResponse.model_validate_json(response.content).result.alternatives[0].message.text  # type: ignore[arg-type]
+        return YandexGPTResponse.model_validate_json(response.content).result.alternatives[0].message.text
 
-    async def _iter_completion_messages(self, response: niquests.AsyncResponse) -> typing.AsyncIterable[str]:
-        async for one_line in response.iter_lines():
+    async def _iter_completion_messages(self, response: httpx.Response) -> typing.AsyncIterable[str]:
+        async for one_line in response.aiter_lines():
             validated_response = YandexGPTResponse.model_validate_json(one_line)
             yield validated_response.result.alternatives[0].message.text
 
@@ -134,14 +129,13 @@ async def stream_llm_partial_messages(
                 build_request=lambda: self._build_request(payload),
             ) as response:
                 yield self._iter_completion_messages(response)
-        except niquests.HTTPError as exception:
-            if exception.response and exception.response.status_code and exception.response.content:
-                content: typing.Final = exception.response.content
-                exception.response.close()
-                _handle_status_error(status_code=exception.response.status_code, content=content)
+        except httpx.HTTPStatusError as exception:
+            content: typing.Final = await exception.response.aread()
+            await exception.response.aclose()
+            _handle_status_error(status_code=exception.response.status_code, content=content)
 
     async def __aenter__(self) -> typing_extensions.Self:
-        await self.httpx_client.__aenter__()  # type: ignore[no-untyped-call]
+        await self.httpx_client.__aenter__()
         return self
 
     async def __aexit__(
@@ -150,4 +144,4 @@ async def __aexit__(
         exc_value: BaseException | None,
         traceback: types.TracebackType | None,
     ) -> None:
-        await self.httpx_client.__aexit__(exc_type, exc_value, traceback)  # type: ignore[no-untyped-call]
+        await self.httpx_client.__aexit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback)
diff --git a/any_llm_client/http.py b/any_llm_client/http.py
index 4964210..1bc5d34 100644
--- a/any_llm_client/http.py
+++ b/any_llm_client/http.py
@@ -3,7 +3,6 @@
 import typing
 
 import httpx
-import niquests
 import stamina
 
 from any_llm_client.retry import RequestRetryConfig
@@ -11,12 +10,12 @@
 
 async def make_http_request(
     *,
-    httpx_client: niquests.AsyncSession,
+    httpx_client: httpx.AsyncClient,
     request_retry: RequestRetryConfig,
-    build_request: typing.Callable[[], niquests.PreparedRequest],
-) -> niquests.Response:
-    @stamina.retry(on=niquests.HTTPError, **dataclasses.asdict(request_retry))
-    async def make_request_with_retries() -> niquests.Response:
+    build_request: typing.Callable[[], httpx.Request],
+) -> httpx.Response:
+    @stamina.retry(on=httpx.HTTPError, **dataclasses.asdict(request_retry))
+    async def make_request_with_retries() -> httpx.Response:
         response: typing.Final = await httpx_client.send(build_request())
         response.raise_for_status()
         return response
@@ -27,19 +26,18 @@ async def make_request_with_retries() -> niquests.Response:
 @contextlib.asynccontextmanager
 async def make_streaming_http_request(
     *,
-    httpx_client: niquests.AsyncSession,
+    httpx_client: httpx.AsyncClient,
     request_retry: RequestRetryConfig,
-    build_request: typing.Callable[[], niquests.PreparedRequest],
-) -> typing.AsyncIterator[niquests.AsyncResponse]:
+    build_request: typing.Callable[[], httpx.Request],
+) -> typing.AsyncIterator[httpx.Response]:
     @stamina.retry(on=httpx.HTTPError, **dataclasses.asdict(request_retry))
-    async def make_request_with_retries() -> niquests.AsyncResponse:
+    async def make_request_with_retries() -> httpx.Response:
         response: typing.Final = await httpx_client.send(build_request(), stream=True)
         response.raise_for_status()
-        return response  # type: ignore[return-value]
+        return response
 
     response: typing.Final = await make_request_with_retries()
     try:
-        response.__aenter__()
         yield response
     finally:
-        await response.close()
+        await response.aclose()
diff --git a/any_llm_client/main.py b/any_llm_client/main.py
index 000322c..ce8c3b5 100644
--- a/any_llm_client/main.py
+++ b/any_llm_client/main.py
@@ -1,7 +1,7 @@
 import functools
 import typing
 
-import niquests
+import httpx
 
 from any_llm_client.clients.mock import MockLLMClient, MockLLMConfig
 from any_llm_client.clients.openai import OpenAIClient, OpenAIConfig
@@ -18,7 +18,7 @@ def get_client(
         config: AnyLLMConfig,
         *,
-        httpx_client: niquests.AsyncSession | None = None,
+        httpx_client: httpx.AsyncClient | None = None,
         request_retry: RequestRetryConfig | None = None,
     ) -> LLMClient: ...  # pragma: no cover
 else:
@@ -27,7 +27,7 @@ def get_client(
     def get_client(
         config: typing.Any,  # noqa: ANN401, ARG001
         *,
-        httpx_client: niquests.AsyncSession | None = None,  # noqa: ARG001
+        httpx_client: httpx.AsyncClient | None = None,  # noqa: ARG001
         request_retry: RequestRetryConfig | None = None,  # noqa: ARG001
     ) -> LLMClient:
         raise AssertionError("unknown LLM config type")
@@ -36,7 +36,7 @@ def _(
     config: YandexGPTConfig,
     *,
-    httpx_client: niquests.AsyncSession | None = None,
+    httpx_client: httpx.AsyncClient | None = None,
     request_retry: RequestRetryConfig | None = None,
 ) -> LLMClient:
     return YandexGPTClient(config=config, httpx_client=httpx_client, request_retry=request_retry)
@@ -45,7 +45,7 @@ def _(
     config: OpenAIConfig,
     *,
-    httpx_client: niquests.AsyncSession | None = None,
+    httpx_client: httpx.AsyncClient | None = None,
     request_retry: RequestRetryConfig | None = None,
 ) -> LLMClient:
     return OpenAIClient(config=config, httpx_client=httpx_client, request_retry=request_retry)
@@ -54,7 +54,7 @@ def _(
     config: MockLLMConfig,
     *,
-    httpx_client: niquests.AsyncSession | None = None,  # noqa: ARG001
+    httpx_client: httpx.AsyncClient | None = None,  # noqa: ARG001
     request_retry: RequestRetryConfig | None = None,  # noqa: ARG001
 ) -> LLMClient:
     return MockLLMClient(config=config)
diff --git a/pyproject.toml b/pyproject.toml
index 89411c8..7fe9959 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,6 @@ dependencies = [
     "httpx>=0.27.2",
     "pydantic>=2.9.2",
     "stamina>=24.3.0",
-    "niquests>=3.11.0",
 ]
 dynamic = ["version"]
diff --git a/tests/test_openai_client.py b/tests/test_openai_client.py
index 5fc1b53..b23c851 100644
--- a/tests/test_openai_client.py
+++ b/tests/test_openai_client.py
@@ -3,7 +3,6 @@
 
 import faker
 import httpx
-import niquests
 import pydantic
 import pytest
 from polyfactory.factories.pydantic_factory import ModelFactory
@@ -37,7 +36,7 @@ async def test_ok(self, faker: faker.Faker) -> None:
 
         result: typing.Final = await any_llm_client.get_client(
             OpenAIConfigFactory.build(),
-            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: response)),
+            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
         ).request_llm_message(**LLMFuncRequestFactory.build())
 
         assert result == expected_result
@@ -49,7 +48,7 @@ async def test_fails_without_alternatives(self) -> None:
         )
         client: typing.Final = any_llm_client.get_client(
             OpenAIConfigFactory.build(),
-            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: response)),
+            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
         )
 
         with pytest.raises(pydantic.ValidationError):
@@ -92,7 +91,7 @@ async def test_ok(self, faker: faker.Faker) -> None:
         )
         client: typing.Final = any_llm_client.get_client(
             config,
-            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: response)),
+            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
         )
 
         result: typing.Final = await consume_llm_partial_responses(client.stream_llm_partial_messages(**func_request))
@@ -108,7 +107,7 @@ async def test_fails_without_alternatives(self) -> None:
         )
         client: typing.Final = any_llm_client.get_client(
             OpenAIConfigFactory.build(),
-            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: response)),
+            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
         )
 
         with pytest.raises(pydantic.ValidationError):
@@ -121,7 +120,7 @@ class TestOpenAILLMErrors:
     async def test_fails_with_unknown_error(self, stream: bool, status_code: int) -> None:
         client: typing.Final = any_llm_client.get_client(
             OpenAIConfigFactory.build(),
-            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: httpx.Response(status_code))),
+            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(status_code))),
         )
 
         coroutine: typing.Final = (
@@ -146,7 +145,7 @@ async def test_fails_with_out_of_tokens_error(self, stream: bool, content: bytes
         response: typing.Final = httpx.Response(400, content=content)
         client: typing.Final = any_llm_client.get_client(
             OpenAIConfigFactory.build(),
-            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: response)),
+            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
         )
 
         coroutine: typing.Final = (
diff --git a/tests/test_static.py b/tests/test_static.py
index d5f2fed..78beaf6 100644
--- a/tests/test_static.py
+++ b/tests/test_static.py
@@ -27,12 +27,12 @@ def test_llm_error_str(faker: faker.Faker) -> None:
 
 def test_llm_func_request_has_same_annotations_as_llm_client_methods() -> None:
-    all_objects: typing.Final = (
+    all_objects = (
         any_llm_client.LLMClient.request_llm_message,
         any_llm_client.LLMClient.stream_llm_partial_messages,
         LLMFuncRequest,
     )
-    all_annotations: typing.Final = [typing.get_type_hints(one_object) for one_object in all_objects]
+    all_annotations = [typing.get_type_hints(one_object) for one_object in all_objects]
 
     for one_ignored_prop in ("return",):
         for annotations in all_annotations:
diff --git a/tests/test_yandexgpt_client.py b/tests/test_yandexgpt_client.py
index 03b769b..a4c23d4 100644
--- a/tests/test_yandexgpt_client.py
+++ b/tests/test_yandexgpt_client.py
@@ -2,7 +2,6 @@
 
 import faker
 import httpx
-import niquests
 import pydantic
 import pytest
 from polyfactory.factories.pydantic_factory import ModelFactory
@@ -31,7 +30,7 @@ async def test_ok(self, faker: faker.Faker) -> None:
 
         result: typing.Final = await any_llm_client.get_client(
             YandexGPTConfigFactory.build(),
-            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: response)),
+            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
         ).request_llm_message(**LLMFuncRequestFactory.build())
 
         assert result == expected_result
@@ -42,7 +41,7 @@ async def test_fails_without_alternatives(self) -> None:
         )
         client: typing.Final = any_llm_client.get_client(
             YandexGPTConfigFactory.build(),
-            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: response)),
+            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
         )
 
         with pytest.raises(pydantic.ValidationError):
@@ -71,7 +70,7 @@ async def test_ok(self, faker: faker.Faker) -> None:
 
         result: typing.Final = await consume_llm_partial_responses(
             any_llm_client.get_client(
-                config, httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: response))
+                config, httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response))
             ).stream_llm_partial_messages(**func_request)
         )
 
@@ -85,7 +84,7 @@ async def test_fails_without_alternatives(self) -> None:
 
         client: typing.Final = any_llm_client.get_client(
             YandexGPTConfigFactory.build(),
-            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: response)),
+            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
         )
 
         with pytest.raises(pydantic.ValidationError):
@@ -98,7 +97,7 @@ class TestYandexGPTLLMErrors:
    async def test_fails_with_unknown_error(self, stream: bool, status_code: int) -> None:
         client: typing.Final = any_llm_client.get_client(
             YandexGPTConfigFactory.build(),
-            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: httpx.Response(status_code))),
+            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(status_code))),
         )
 
         coroutine: typing.Final = (
@@ -123,7 +122,7 @@ async def test_fails_with_out_of_tokens_error(self, stream: bool, response_conte
         response: typing.Final = httpx.Response(400, content=response_content)
         client: typing.Final = any_llm_client.get_client(
             YandexGPTConfigFactory.build(),
-            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: response)),
+            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
         )
 
         coroutine: typing.Final = (