Skip to content

Commit d9f3677

Browse files
DevilXDLilSpazJoekp
andcommitted
Implement delayed session creation and reuse existing session
Co-authored-by: Joel Payne <[email protected]>
1 parent 0b9dd4e commit d9f3677

File tree

9 files changed

+191
-142
lines changed

9 files changed

+191
-142
lines changed

asyncprawcore/auth.py

+13-10
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import inspect
66
import time
77
from abc import ABC, abstractmethod
8-
from typing import TYPE_CHECKING, Any, Awaitable, Callable
8+
from contextlib import asynccontextmanager
9+
from typing import TYPE_CHECKING, Any, AsyncContextManager, Awaitable, Callable
910

1011
from aiohttp import ClientRequest
1112
from aiohttp.helpers import BasicAuth
@@ -49,19 +50,20 @@ def __init__(
4950
self.client_id = client_id
5051
self.redirect_uri = redirect_uri
5152

53+
@asynccontextmanager
5254
async def _post(
5355
self, url: str, success_status: int = codes["ok"], **data: Any
54-
) -> ClientResponse:
55-
response = await self._requestor.request(
56+
) -> Callable[..., AsyncContextManager[ClientResponse]]:
57+
async with self._requestor.request(
5658
"POST",
5759
url,
5860
auth=self._auth(),
5961
data=sorted(data.items()),
6062
headers={"Connection": "close"},
61-
)
62-
if response.status != success_status:
63-
raise ResponseException(response)
64-
return response
63+
) as response:
64+
if response.status != success_status:
65+
raise ResponseException(response)
66+
yield response
6567

6668
def authorize_url(
6769
self, duration: str, scopes: list[str], state: str, implicit: bool = False
@@ -126,7 +128,8 @@ async def revoke_token(self, token: str, token_type: str | None = None):
126128
if token_type is not None:
127129
data["token_type_hint"] = token_type
128130
url = self._requestor.reddit_url + const.REVOKE_TOKEN_PATH
129-
await self._post(url, **data)
131+
async with self._post(url, **data) as _:
132+
pass # The response is not used.
130133

131134

132135
class BaseAuthorizer:
@@ -152,8 +155,8 @@ def _clear_access_token(self):
152155
async def _request_token(self, **data: Any):
153156
url = self._authenticator._requestor.reddit_url + const.ACCESS_TOKEN_PATH
154157
pre_request_time = time.time()
155-
response = await self._authenticator._post(url=url, **data)
156-
payload = await response.json()
158+
async with self._authenticator._post(url=url, **data) as response:
159+
payload = await response.json()
157160
if "error" in payload: # Why are these OKAY responses?
158161
raise OAuthException(
159162
response, payload["error"], payload.get("error_description")

asyncprawcore/rate_limit.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import asyncio
66
import logging
77
import time
8-
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Mapping
8+
from contextlib import asynccontextmanager
9+
from typing import TYPE_CHECKING, Any, AsyncContextManager, Awaitable, Callable, Mapping
910

1011
if TYPE_CHECKING:
1112
from aiohttp import ClientResponse
@@ -28,16 +29,15 @@ def __init__(self, *, window_size: int):
2829
self.used: int | None = None
2930
self.window_size: int = window_size
3031

32+
@asynccontextmanager
3133
async def call(
3234
self,
33-
request_function: Callable[
34-
[Any],
35-
Awaitable[ClientResponse],
36-
],
35+
# async context manager
36+
request_function: Callable[..., AsyncContextManager[ClientResponse]],
3737
set_header_callback: Callable[[], Awaitable[dict[str, str]]],
3838
*args: Any,
3939
**kwargs: Any,
40-
) -> ClientResponse:
40+
) -> Callable[..., AsyncContextManager[ClientResponse]]:
4141
"""Rate limit the call to ``request_function``.
4242
4343
:param request_function: A function call that returns an HTTP response object.
@@ -49,9 +49,9 @@ async def call(
4949
"""
5050
await self.delay()
5151
kwargs["headers"] = await set_header_callback()
52-
response = await request_function(*args, **kwargs)
53-
self.update(response.headers)
54-
return response
52+
async with request_function(*args, **kwargs) as response:
53+
self.update(response.headers)
54+
yield response
5555

5656
async def delay(self):
5757
"""Sleep for an amount of time to remain under the rate limit."""

asyncprawcore/requestor.py

+57-19
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
from __future__ import annotations
44

5-
import asyncio
6-
from typing import TYPE_CHECKING, Any
5+
from contextlib import asynccontextmanager
6+
from typing import TYPE_CHECKING, Any, AsyncContextManager, Callable
7+
from warnings import warn
78

89
import aiohttp
910
from aiohttp import ClientSession
1011

1112
from .const import TIMEOUT
12-
from .exceptions import InvalidInvocation, RequestException
13+
from .exceptions import InvalidInvocation, RequestException, ResponseException
1314

1415
if TYPE_CHECKING:
1516
from asyncio import AbstractEventLoop
@@ -23,7 +24,7 @@ class Requestor:
2324
def __getattr__(self, attribute: str) -> Any: # pragma: no cover
2425
"""Pass all undefined attributes to the ``_http`` attribute."""
2526
if attribute.startswith("__"):
26-
raise AttributeError
27+
raise AttributeError(attribute)
2728
return getattr(self._http, attribute)
2829

2930
def __init__(
@@ -32,7 +33,7 @@ def __init__(
3233
oauth_url: str = "https://oauth.reddit.com",
3334
reddit_url: str = "https://www.reddit.com",
3435
session: ClientSession | None = None,
35-
loop: AbstractEventLoop = None,
36+
loop: AbstractEventLoop | None = None,
3637
timeout: float = TIMEOUT,
3738
):
3839
"""Create an instance of the Requestor class.
@@ -45,40 +46,77 @@ def __init__(
4546
``"https://www.reddit.com"``).
4647
:param session: A session instance to handle requests, compatible with
4748
``aiohttp.ClientSession()`` (default: ``None``).
49+
:param loop: The event loop to run the requestor on (default: ``None``).
50+
51+
.. Deprecated:: 2.5.0
52+
53+
The ``loop`` argument is deprecated and will be ignored.
54+
4855
:param timeout: How many seconds to wait for the server to send data before
4956
giving up (default: ``asyncprawcore.const.TIMEOUT``).
5057
5158
"""
5259
# Imported locally to avoid an import cycle, with __init__
5360
from . import __version__
5461

62+
if loop is not None:
63+
msg = "The loop argument is deprecated and will be ignored."
64+
warn(msg, DeprecationWarning, stacklevel=2)
65+
5566
if user_agent is None or len(user_agent) < 7:
5667
msg = "user_agent is not descriptive"
5768
raise InvalidInvocation(msg)
5869

59-
self.loop = loop or asyncio.get_event_loop()
60-
self._http = session or aiohttp.ClientSession(
61-
loop=self.loop, timeout=aiohttp.ClientTimeout(total=None)
62-
)
63-
self._http._default_headers["User-Agent"] = (
64-
f"{user_agent} asyncprawcore/{__version__}"
65-
)
66-
70+
self.headers = {"User-Agent": f"{user_agent} asyncprawcore/{__version__}"}
6771
self.oauth_url = oauth_url
6872
self.reddit_url = reddit_url
6973
self.timeout = timeout
7074

75+
self._http = session
76+
if self._http is not None and "User-Agent" not in self._http.headers:
77+
# ensure user-agent is set
78+
self._http.headers.update(self.headers)
79+
80+
async def _ensure_session(self):
81+
"""Ensure that the session is open."""
82+
if self._http is None or self._http.closed:
83+
self._http = aiohttp.ClientSession(
84+
headers=self.headers,
85+
timeout=aiohttp.ClientTimeout(total=None),
86+
)
87+
7188
async def close(self):
7289
"""Call close on the underlying session."""
73-
await self._http.close()
90+
if self._http is not None and not self._http.closed:
91+
await self._http.close()
7492

93+
@asynccontextmanager
7594
async def request(
7695
self, *args: Any, timeout: float | None = None, **kwargs: Any
77-
) -> ClientResponse:
78-
"""Issue the HTTP request capturing any errors that may occur."""
96+
) -> Callable[..., AsyncContextManager[ClientResponse]]:
97+
"""Issue the HTTP request capturing any errors that may occur.
98+
99+
:param args: Positional arguments to pass to ``aiohttp.ClientSession.request``.
100+
:param timeout: How many seconds to wait for the server to send data before
101+
giving up (default: ``None``).
102+
:param kwargs: Keyword arguments to pass to ``aiohttp.ClientSession.request``.
103+
104+
:returns: The response from the request.
105+
106+
:raises: RequestException: If an error occurs while issuing the request.
107+
108+
"""
79109
try:
80-
return await self._http.request(
81-
*args, timeout=timeout or self.timeout, **kwargs
82-
)
110+
await self._ensure_session()
111+
kwargs_copy = kwargs.copy()
112+
async with self._http.request(
113+
*args,
114+
headers={**self.headers, **kwargs_copy.pop("headers", {})},
115+
timeout=timeout or self.timeout,
116+
**kwargs_copy,
117+
) as request:
118+
yield request
119+
except ResponseException as exc:
120+
raise exc
83121
except Exception as exc: # noqa: BLE001
84122
raise RequestException(exc, args, kwargs) from None

asyncprawcore/sessions.py

+59-55
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
import random
88
import time
99
from abc import ABC, abstractmethod
10+
from contextlib import asynccontextmanager
1011
from copy import deepcopy
1112
from pprint import pformat
12-
from typing import TYPE_CHECKING, Any, BinaryIO, TextIO
13+
from typing import TYPE_CHECKING, Any, AsyncContextManager, BinaryIO, Callable, TextIO
1314
from urllib.parse import urljoin
1415

1516
from aiohttp.web import HTTPRequestTimeout
@@ -179,6 +180,7 @@ async def _do_retry(
179180
# noqa: E501
180181
)
181182

183+
@asynccontextmanager
182184
async def _make_request(
183185
self,
184186
data: list[tuple[str, Any]],
@@ -188,9 +190,11 @@ async def _make_request(
188190
retry_strategy_state: FiniteRetryStrategy,
189191
timeout: float,
190192
url: str,
191-
) -> tuple[ClientResponse, None] | tuple[None, Exception]:
193+
) -> Callable[
194+
..., AsyncContextManager[tuple[ClientResponse | None, Exception | None]]
195+
]:
192196
try:
193-
response = await self._rate_limiter.call(
197+
async with self._rate_limiter.call(
194198
self._requestor.request,
195199
self._set_header_callback,
196200
method,
@@ -200,17 +204,17 @@ async def _make_request(
200204
json=json,
201205
params=params,
202206
timeout=timeout,
203-
)
204-
log.debug(
205-
"Response: %s (%s bytes) (rst-%s:rem-%s:used-%s ratelimit) at %s",
206-
response.status,
207-
response.headers.get("content-length"),
208-
response.headers.get("x-ratelimit-reset"),
209-
response.headers.get("x-ratelimit-remaining"),
210-
response.headers.get("x-ratelimit-used"),
211-
time.time(),
212-
)
213-
return response, None
207+
) as response:
208+
log.debug(
209+
"Response: %s (%s bytes) (rst-%s:rem-%s:used-%s ratelimit) at %s",
210+
response.status,
211+
response.headers.get("content-length"),
212+
response.headers.get("x-ratelimit-reset"),
213+
response.headers.get("x-ratelimit-remaining"),
214+
response.headers.get("x-ratelimit-used"),
215+
time.time(),
216+
)
217+
yield response, None
214218
except RequestException as exception:
215219
if (
216220
not retry_strategy_state.should_retry_on_failure()
@@ -219,7 +223,7 @@ async def _make_request(
219223
)
220224
):
221225
raise
222-
return None, exception.original_exception
226+
yield None, exception.original_exception
223227

224228
def _preprocess_data(
225229
self,
@@ -284,54 +288,54 @@ async def _request_with_retries(
284288

285289
await retry_strategy_state.sleep()
286290
self._log_request(data, method, params, url)
287-
response, saved_exception = await self._make_request(
291+
async with self._make_request(
288292
data,
289293
json,
290294
method,
291295
params,
292296
retry_strategy_state,
293297
timeout,
294298
url,
295-
)
296-
297-
do_retry = False
298-
if response is not None and response.status == codes["unauthorized"]:
299-
self._authorizer._clear_access_token()
300-
if hasattr(self._authorizer, "refresh"):
301-
do_retry = True
302-
303-
if retry_strategy_state.should_retry_on_failure() and (
304-
do_retry or response is None or response.status in self.RETRY_STATUSES
305-
):
306-
return await self._do_retry(
307-
data,
308-
json,
309-
method,
310-
params,
311-
response,
312-
retry_strategy_state,
313-
saved_exception,
314-
timeout,
315-
url,
316-
)
317-
if response.status in self.STATUS_EXCEPTIONS:
318-
if response.status == codes["media_type"]:
319-
# since exception class needs response.json
320-
raise self.STATUS_EXCEPTIONS[response.status](
321-
response, await response.json()
299+
) as (response, saved_exception):
300+
do_retry = False
301+
if response is not None and response.status == codes["unauthorized"]:
302+
# noinspection PyProtectedMember
303+
self._authorizer._clear_access_token()
304+
if hasattr(self._authorizer, "refresh"):
305+
do_retry = True
306+
307+
if retry_strategy_state.should_retry_on_failure() and (
308+
do_retry or response is None or response.status in self.RETRY_STATUSES
309+
):
310+
return await self._do_retry(
311+
data,
312+
json,
313+
method,
314+
params,
315+
response,
316+
retry_strategy_state,
317+
saved_exception,
318+
timeout,
319+
url,
322320
)
323-
raise self.STATUS_EXCEPTIONS[response.status](response)
324-
if response.status == codes["no_content"]:
325-
return None
326-
assert (
327-
response.status in self.SUCCESS_STATUSES
328-
), f"Unexpected status code: {response.status}"
329-
if response.headers.get("content-length") == "0":
330-
return ""
331-
try:
332-
return await response.json()
333-
except ValueError:
334-
raise BadJSON(response) from None
321+
if response.status in self.STATUS_EXCEPTIONS:
322+
if response.status == codes["media_type"]:
323+
# since exception class needs response.json
324+
raise self.STATUS_EXCEPTIONS[response.status](
325+
response, await response.json()
326+
)
327+
raise self.STATUS_EXCEPTIONS[response.status](response)
328+
if response.status == codes["no_content"]:
329+
return None
330+
assert (
331+
response.status in self.SUCCESS_STATUSES
332+
), f"Unexpected status code: {response.status}"
333+
if response.headers.get("content-length") == "0":
334+
return ""
335+
try:
336+
return await response.json()
337+
except ValueError:
338+
raise BadJSON(response) from None
335339

336340
async def _set_header_callback(self) -> dict[str, str]:
337341
if not self._authorizer.is_valid() and hasattr(self._authorizer, "refresh"):

0 commit comments

Comments
 (0)