1
1
import asyncio
2
2
from http .client import responses as http_reasons
3
+ from typing import Callable , Optional
3
4
from unittest import mock
4
5
from urllib .parse import urlencode , urlunparse
5
6
from collections .abc import Mapping
6
7
8
+ import aiohttp
7
9
from aiohttp .helpers import TimerNoop
8
10
from aiohttp .streams import EmptyStreamReader
9
11
@@ -29,13 +31,8 @@ async def read(self, n=-1):
29
31
return self .content
30
32
31
33
32
- def HTTPResponse (* args , ** kw ):
33
- # Dynamically load package
34
- module = __import__ (RESPONSE_PATH , fromlist = (RESPONSE_CLASS ,))
35
- ClientResponse = getattr (module , RESPONSE_CLASS )
36
-
37
- # Return response instance
38
- return ClientResponse (
34
+ def HTTPResponse (session : aiohttp .ClientSession , * args , ** kw ):
35
+ return session ._response_class (
39
36
* args ,
40
37
request_info = mock .Mock (),
41
38
writer = None ,
@@ -53,22 +50,17 @@ class AIOHTTPInterceptor(BaseInterceptor):
53
50
aiohttp HTTP client traffic interceptor.
54
51
"""
55
52
56
- def _url (self , url ):
53
+ def _url (self , url ) -> Optional [ yarl . URL ] :
57
54
return yarl .URL (url ) if yarl else None
58
55
59
- async def _on_request (
60
- self , _request , session , method , url , data = None , headers = None , ** kw
61
- ):
62
- # Create request contract based on incoming params
63
- req = Request (method )
64
-
56
+ def set_headers (self , req , headers ) -> None :
65
57
# aiohttp's interface allows various mappings, as well as an iterable of key/value tuples
66
58
# ``pook.request`` only allows a dict, so we need to map the iterable to the matchable interface
67
59
if headers :
68
60
if isinstance (headers , Mapping ):
69
61
req .headers = headers
70
62
else :
71
- req_headers = {}
63
+ req_headers : dict [ str , str ] = {}
72
64
# If it isn't a mapping, then its an Iterable[Tuple[Union[str, istr], str]]
73
65
for req_header , req_header_value in headers :
74
66
normalised_header = req_header .lower ()
@@ -79,17 +71,37 @@ async def _on_request(
79
71
80
72
req .headers = req_headers
81
73
74
+ async def _on_request (
75
+ self ,
76
+ _request : Callable ,
77
+ session : aiohttp .ClientSession ,
78
+ method : str ,
79
+ url : str ,
80
+ data = None ,
81
+ headers = None ,
82
+ ** kw ,
83
+ ) -> aiohttp .ClientResponse :
84
+ # Create request contract based on incoming params
85
+ req = Request (method )
86
+
87
+ self .set_headers (req , headers )
88
+ self .set_headers (req , session .headers )
89
+
82
90
req .body = data
83
91
84
92
# Expose extra variadic arguments
85
93
req .extra = kw
86
94
95
+ full_url = session ._build_url (url )
96
+
87
97
# Compose URL
88
98
if not kw .get ("params" ):
89
- req .url = str (url )
99
+ req .url = str (full_url )
90
100
else :
91
101
req .url = (
92
- str (url ) + "?" + urlencode ([(x , y ) for x , y in kw ["params" ].items ()])
102
+ str (full_url )
103
+ + "?"
104
+ + urlencode ([(x , y ) for x , y in kw ["params" ].items ()])
93
105
)
94
106
95
107
# If a json payload is provided, serialize it for JSONMatcher support
@@ -122,13 +134,12 @@ async def _on_request(
122
134
headers .append ((key , res ._headers [key ]))
123
135
124
136
# Create mock equivalent HTTP response
125
- _res = HTTPResponse (req .method , self ._url (urlunparse (req .url )))
137
+ _res = HTTPResponse (session , req .method , self ._url (urlunparse (req .url )))
126
138
127
139
# response status
128
- _res .version = (1 , 1 )
140
+ _res .version = aiohttp . HttpVersion (1 , 1 )
129
141
_res .status = res ._status
130
142
_res .reason = http_reasons .get (res ._status )
131
- _res ._should_close = False
132
143
133
144
# Add response headers
134
145
_res ._raw_headers = tuple (headers )
@@ -144,7 +155,7 @@ async def _on_request(
144
155
# Return response based on mock definition
145
156
return _res
146
157
147
- def _patch (self , path ) :
158
+ def _patch (self , path : str ) -> None :
148
159
# If not able to import aiohttp dependencies, skip
149
160
if not yarl or not multidict :
150
161
return None
@@ -170,16 +181,18 @@ async def handler(session, method, url, data=None, headers=None, **kw):
170
181
else :
171
182
self .patchers .append (patcher )
172
183
173
- def activate (self ):
184
+ def activate (self ) -> None :
174
185
"""
175
186
Activates the traffic interceptor.
176
187
This method must be implemented by any interceptor.
177
188
"""
178
- [self ._patch (path ) for path in PATCHES ]
189
+ for path in PATCHES :
190
+ self ._patch (path )
179
191
180
- def disable (self ):
192
+ def disable (self ) -> None :
181
193
"""
182
194
Disables the traffic interceptor.
183
195
This method must be implemented by any interceptor.
184
196
"""
185
- [patch .stop () for patch in self .patchers ]
197
+ for patch in self .patchers :
198
+ patch .stop ()
0 commit comments