From 77a10136c217b07325bb49cdb82992fc376ea112 Mon Sep 17 00:00:00 2001 From: "tsalpekar21@gmail.com" Date: Thu, 20 Jul 2023 22:13:56 -0500 Subject: [PATCH] feat: add rate limit headers to base exceptions --- stream_chat/async_chat/client.py | 5 +-- stream_chat/base/exceptions.py | 15 +++++++- stream_chat/client.py | 8 +++- stream_chat/tests/async_chat/test_client.py | 12 ++++++ stream_chat/tests/test_client.py | 12 ++++++ stream_chat/types/header.py | 41 +++++++++++++++++++++ stream_chat/types/stream_response.py | 29 ++------------- 7 files changed, 90 insertions(+), 32 deletions(-) create mode 100644 stream_chat/types/header.py diff --git a/stream_chat/async_chat/client.py b/stream_chat/async_chat/client.py index ba5d9c1..aaf3aff 100644 --- a/stream_chat/async_chat/client.py +++ b/stream_chat/async_chat/client.py @@ -69,10 +69,9 @@ async def _parse_response(self, response: aiohttp.ClientResponse) -> StreamRespo try: parsed_result = await response.json() if text else {} except aiohttp.ClientResponseError: - raise StreamAPIException(text, response.status) + raise StreamAPIException(text, response.status, dict(response.headers)) if response.status >= 399: - raise StreamAPIException(text, response.status) - + raise StreamAPIException(text, response.status, dict(response.headers)) return StreamResponse(parsed_result, dict(response.headers), response.status) async def _make_request( diff --git a/stream_chat/base/exceptions.py b/stream_chat/base/exceptions.py index 9ed295b..7f2ca08 100644 --- a/stream_chat/base/exceptions.py +++ b/stream_chat/base/exceptions.py @@ -1,5 +1,8 @@ import json -from typing import Dict +from typing import Any, Dict, Optional + +from stream_chat.types.header import StreamHeaders +from stream_chat.types.rate_limit import RateLimitInfo class StreamChannelException(Exception): @@ -7,10 +10,14 @@ class StreamChannelException(Exception): class StreamAPIException(Exception): - def __init__(self, text: str, status_code: int) -> None: + def __init__( + self, text: str, status_code: int, headers: Dict[str, Any] = {} + ) -> None: self.response_text = text self.status_code = status_code self.json_response = False + self.headers = StreamHeaders(headers) + self.__rate_limit: Optional[RateLimitInfo] = self.headers.rate_limit() try: parsed_response: Dict = json.loads(text) @@ -25,3 +32,7 @@ def __str__(self) -> str: return f'StreamChat error code {self.error_code}: {self.error_message}"' else: return f"StreamChat error HTTP code: {self.status_code}" + + def rate_limit(self) -> Optional[RateLimitInfo]: + """Returns the ratelimit info of your API operation.""" + return self.__rate_limit diff --git a/stream_chat/client.py b/stream_chat/client.py index 1b36cea..83936aa 100644 --- a/stream_chat/client.py +++ b/stream_chat/client.py @@ -55,9 +55,13 @@ def _parse_response(self, response: requests.Response) -> StreamResponse: try: parsed_result = json.loads(response.text) if response.text else {} except ValueError: - raise StreamAPIException(response.text, response.status_code) + raise StreamAPIException( + response.text, response.status_code, dict(response.headers) + ) if response.status_code >= 399: - raise StreamAPIException(response.text, response.status_code) + raise StreamAPIException( + response.text, response.status_code, dict(response.headers) + ) return StreamResponse( parsed_result, dict(response.headers), response.status_code diff --git a/stream_chat/tests/async_chat/test_client.py b/stream_chat/tests/async_chat/test_client.py index a9df6cc..ab6383f 100644 --- a/stream_chat/tests/async_chat/test_client.py +++ b/stream_chat/tests/async_chat/test_client.py @@ -751,6 +751,18 @@ async def test_stream_response(self, client: StreamChatAsync): assert rate_limit.remaining > 0 assert type(rate_limit.reset) is datetime + async def test_stream_headers_in_exception(self, client: StreamChatAsync): + with pytest.raises(StreamAPIException) as stream_exception: + user = {"id": "bad id"} + response = await client.upsert_users([user]) + + assert stream_exception.value.headers is not None + rate_limit = stream_exception.value.rate_limit() + assert rate_limit is not None + assert rate_limit.limit > 0 + assert rate_limit.remaining > 0 + assert rate_limit.reset is not None + async def test_swap_http_client(self): client = StreamChatAsync( api_key=os.environ["STREAM_KEY"], api_secret=os.environ["STREAM_SECRET"] diff --git a/stream_chat/tests/test_client.py b/stream_chat/tests/test_client.py index 12f7f3f..fe065d8 100644 --- a/stream_chat/tests/test_client.py +++ b/stream_chat/tests/test_client.py @@ -88,6 +88,18 @@ def test_auth_exception(self): with pytest.raises(StreamAPIException): client.get_channel_type("team") + def test_stream_headers_in_exception(self, client: StreamChat): + with pytest.raises(StreamAPIException) as stream_exception: + user = {"id": "bad id"} + client.upsert_users([user]) + + assert stream_exception.value.headers is not None + rate_limit = stream_exception.value.rate_limit() + assert rate_limit is not None + assert rate_limit.limit > 0 + assert rate_limit.remaining > 0 + assert rate_limit.reset is not None + def test_get_channel_types(self, client: StreamChat): response = client.get_channel_type("team") assert "permissions" in response diff --git a/stream_chat/types/header.py b/stream_chat/types/header.py new file mode 100644 index 0000000..3375462 --- /dev/null +++ b/stream_chat/types/header.py @@ -0,0 +1,41 @@ +from datetime import datetime, timezone +from typing import Any, Dict, Optional + +from stream_chat.types.rate_limit import RateLimitInfo + + +class StreamHeaders(dict): + def __init__(self, headers: Dict[str, Any]) -> None: + super().__init__() + self.__headers = headers + self.__rate_limit: Optional[RateLimitInfo] = None + limit, remaining, reset = ( + headers.get("x-ratelimit-limit"), + headers.get("x-ratelimit-remaining"), + headers.get("x-ratelimit-reset"), + ) + if limit and remaining and reset: + self.__rate_limit = RateLimitInfo( + limit=int(self._clean_header(limit)), + remaining=int(self._clean_header(remaining)), + reset=datetime.fromtimestamp( + float(self._clean_header(reset)), timezone.utc + ), + ) + + super(StreamHeaders, self).__init__(headers) + + def _clean_header(self, header: str) -> int: + try: + values = (v.strip() for v in header.split(",")) + return int(next(v for v in values if v)) + except ValueError: + return 0 + + def rate_limit(self) -> Optional[RateLimitInfo]: + """Returns the ratelimit info of your API operation.""" + return self.__rate_limit + + def headers(self) -> Dict[str, Any]: + """Returns the headers of the response.""" + return self.__headers diff --git a/stream_chat/types/stream_response.py b/stream_chat/types/stream_response.py index a52e538..c3cd641 100644 --- a/stream_chat/types/stream_response.py +++ b/stream_chat/types/stream_response.py @@ -1,6 +1,6 @@ -from datetime import datetime, timezone from typing import Any, Dict, Optional +from stream_chat.types.header import StreamHeaders from stream_chat.types.rate_limit import RateLimitInfo @@ -28,39 +28,18 @@ class StreamResponse(dict): def __init__( self, response_dict: Dict[str, Any], headers: Dict[str, Any], status_code: int ) -> None: - self.__headers = headers + self.__headers = StreamHeaders(headers) self.__status_code = status_code - self.__rate_limit: Optional[RateLimitInfo] = None - limit, remaining, reset = ( - headers.get("x-ratelimit-limit"), - headers.get("x-ratelimit-remaining"), - headers.get("x-ratelimit-reset"), - ) - if limit and remaining and reset: - self.__rate_limit = RateLimitInfo( - limit=int(self._clean_header(limit)), - remaining=int(self._clean_header(remaining)), - reset=datetime.fromtimestamp( - float(self._clean_header(reset)), timezone.utc - ), - ) - + self.__rate_limit: Optional[RateLimitInfo] = self.__headers.rate_limit() super(StreamResponse, self).__init__(response_dict) - def _clean_header(self, header: str) -> int: - try: - values = (v.strip() for v in header.split(",")) - return int(next(v for v in values if v)) - except ValueError: - return 0 - def rate_limit(self) -> Optional[RateLimitInfo]: """Returns the ratelimit info of your API operation.""" return self.__rate_limit def headers(self) -> Dict[str, Any]: """Returns the headers of the response.""" - return self.__headers + return self.__headers.headers() def status_code(self) -> int: """Returns the HTTP status code of the response."""