Skip to content

Commit 886a6b8

Browse files
authored
Upgrade lastest websockets and Exceptions overhaul (#543)
1 parent fbe03c4 commit 886a6b8

24 files changed

+336
-279
lines changed

gql/client.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
from .graphql_request import GraphQLRequest
3737
from .transport.async_transport import AsyncTransport
38-
from .transport.exceptions import TransportClosed, TransportQueryError
38+
from .transport.exceptions import TransportConnectionFailed, TransportQueryError
3939
from .transport.local_schema import LocalSchemaTransport
4040
from .transport.transport import Transport
4141
from .utilities import build_client_schema, get_introspection_query_ast
@@ -1730,6 +1730,7 @@ async def _connection_loop(self):
17301730
# Then wait for the reconnect event
17311731
self._reconnect_request_event.clear()
17321732
await self._reconnect_request_event.wait()
1733+
await self.transport.close()
17331734

17341735
async def start_connecting_task(self):
17351736
"""Start the task responsible to restart the connection
@@ -1758,7 +1759,7 @@ async def _execute_once(
17581759
**kwargs: Any,
17591760
) -> ExecutionResult:
17601761
"""Same Coroutine as parent method _execute but requesting a
1761-
reconnection if we receive a TransportClosed exception.
1762+
reconnection if we receive a TransportConnectionFailed exception.
17621763
"""
17631764

17641765
try:
@@ -1770,7 +1771,7 @@ async def _execute_once(
17701771
parse_result=parse_result,
17711772
**kwargs,
17721773
)
1773-
except TransportClosed:
1774+
except TransportConnectionFailed:
17741775
self._reconnect_request_event.set()
17751776
raise
17761777

@@ -1786,7 +1787,8 @@ async def _execute(
17861787
**kwargs: Any,
17871788
) -> ExecutionResult:
17881789
"""Same Coroutine as parent, but with optional retries
1789-
and requesting a reconnection if we receive a TransportClosed exception.
1790+
and requesting a reconnection if we receive a
1791+
TransportConnectionFailed exception.
17901792
"""
17911793

17921794
return await self._execute_with_retries(
@@ -1808,7 +1810,7 @@ async def _subscribe(
18081810
**kwargs: Any,
18091811
) -> AsyncGenerator[ExecutionResult, None]:
18101812
"""Same Async generator as parent method _subscribe but requesting a
1811-
reconnection if we receive a TransportClosed exception.
1813+
reconnection if we receive a TransportConnectionFailed exception.
18121814
"""
18131815

18141816
inner_generator: AsyncGenerator[ExecutionResult, None] = super()._subscribe(
@@ -1824,7 +1826,7 @@ async def _subscribe(
18241826
async for result in inner_generator:
18251827
yield result
18261828

1827-
except TransportClosed:
1829+
except TransportConnectionFailed:
18281830
self._reconnect_request_event.set()
18291831
raise
18301832

gql/transport/common/adapters/aiohttp.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -178,12 +178,14 @@ async def send(self, message: str) -> None:
178178
TransportConnectionFailed: If connection closed
179179
"""
180180
if self.websocket is None:
181-
raise TransportConnectionFailed("Connection is already closed")
181+
raise TransportConnectionFailed("WebSocket connection is already closed")
182182

183183
try:
184184
await self.websocket.send_str(message)
185-
except ConnectionResetError as e:
186-
raise TransportConnectionFailed("Connection was closed") from e
185+
except Exception as e:
186+
raise TransportConnectionFailed(
187+
f"Error trying to send data: {type(e).__name__}"
188+
) from e
187189

188190
async def receive(self) -> str:
189191
"""Receive message from the WebSocket server.
@@ -200,6 +202,9 @@ async def receive(self) -> str:
200202
raise TransportConnectionFailed("Connection is already closed")
201203

202204
while True:
205+
# Should not raise any exception:
206+
# https://docs.aiohttp.org/en/stable/_modules/aiohttp/client_ws.html
207+
# #ClientWebSocketResponse.receive
203208
ws_message = await self.websocket.receive()
204209

205210
# Ignore low-level ping and pong received

gql/transport/common/adapters/websockets.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Any, Dict, Optional, Union
44

55
import websockets
6-
from websockets.client import WebSocketClientProtocol
6+
from websockets import ClientConnection
77
from websockets.datastructures import Headers, HeadersLike
88

99
from ...exceptions import TransportConnectionFailed, TransportProtocolError
@@ -40,7 +40,7 @@ def __init__(
4040
self._headers: Optional[HeadersLike] = headers
4141
self.ssl = ssl
4242

43-
self.websocket: Optional[WebSocketClientProtocol] = None
43+
self.websocket: Optional[ClientConnection] = None
4444
self._response_headers: Optional[Headers] = None
4545

4646
async def connect(self) -> None:
@@ -57,7 +57,7 @@ async def connect(self) -> None:
5757
# Set default arguments used in the websockets.connect call
5858
connect_args: Dict[str, Any] = {
5959
"ssl": ssl,
60-
"extra_headers": self.headers,
60+
"additional_headers": self.headers,
6161
}
6262

6363
if self.subprotocols:
@@ -68,11 +68,13 @@ async def connect(self) -> None:
6868

6969
# Connection to the specified url
7070
try:
71-
self.websocket = await websockets.client.connect(self.url, **connect_args)
71+
self.websocket = await websockets.connect(self.url, **connect_args)
7272
except Exception as e:
7373
raise TransportConnectionFailed("Connect failed") from e
7474

75-
self._response_headers = self.websocket.response_headers
75+
assert self.websocket.response is not None
76+
77+
self._response_headers = self.websocket.response.headers
7678

7779
async def send(self, message: str) -> None:
7880
"""Send message to the WebSocket server.
@@ -84,12 +86,14 @@ async def send(self, message: str) -> None:
8486
TransportConnectionFailed: If connection closed
8587
"""
8688
if self.websocket is None:
87-
raise TransportConnectionFailed("Connection is already closed")
89+
raise TransportConnectionFailed("WebSocket connection is already closed")
8890

8991
try:
9092
await self.websocket.send(message)
9193
except Exception as e:
92-
raise TransportConnectionFailed("Connection was closed") from e
94+
raise TransportConnectionFailed(
95+
f"Error trying to send data: {type(e).__name__}"
96+
) from e
9397

9498
async def receive(self) -> str:
9599
"""Receive message from the WebSocket server.
@@ -109,7 +113,9 @@ async def receive(self) -> str:
109113
try:
110114
data = await self.websocket.recv()
111115
except Exception as e:
112-
raise TransportConnectionFailed("Connection was closed") from e
116+
raise TransportConnectionFailed(
117+
f"Error trying to receive data: {type(e).__name__}"
118+
) from e
113119

114120
# websocket.recv() can return either str or bytes
115121
# In our case, we should receive only str here

gql/transport/common/base.py

+23-12
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,13 @@ async def _send(self, message: str) -> None:
127127
"""Send the provided message to the adapter connection and log the message"""
128128

129129
if not self._connected:
130-
raise TransportClosed(
131-
"Transport is not connected"
132-
) from self.close_exception
130+
if isinstance(self.close_exception, TransportConnectionFailed):
131+
raise self.close_exception
132+
else:
133+
raise TransportConnectionFailed() from self.close_exception
133134

134135
try:
136+
# Can raise TransportConnectionFailed
135137
await self.adapter.send(message)
136138
log.info(">>> %s", message)
137139
except TransportConnectionFailed as e:
@@ -143,7 +145,7 @@ async def _receive(self) -> str:
143145

144146
# It is possible that the connection has been already closed in another task
145147
if not self._connected:
146-
raise TransportClosed("Transport is already closed")
148+
raise TransportConnectionFailed() from self.close_exception
147149

148150
# Wait for the next frame.
149151
# Can raise TransportConnectionFailed or TransportProtocolError
@@ -214,8 +216,6 @@ async def _receive_data_loop(self) -> None:
214216
except (TransportConnectionFailed, TransportProtocolError) as e:
215217
await self._fail(e, clean_close=False)
216218
break
217-
except TransportClosed:
218-
break
219219

220220
# Parse the answer
221221
try:
@@ -482,6 +482,10 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
482482
# We should always have an active websocket connection here
483483
assert self._connected
484484

485+
# Saving exception to raise it later if trying to use the transport
486+
# after it has already closed.
487+
self.close_exception = e
488+
485489
# Properly shut down liveness checker if enabled
486490
if self.check_keep_alive_task is not None:
487491
# More info: https://stackoverflow.com/a/43810272/1113207
@@ -492,18 +496,17 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
492496
# Calling the subclass close hook
493497
await self._close_hook()
494498

495-
# Saving exception to raise it later if trying to use the transport
496-
# after it has already closed.
497-
self.close_exception = e
498-
499499
if clean_close:
500500
log.debug("_close_coro: starting clean_close")
501501
try:
502502
await self._clean_close(e)
503503
except Exception as exc: # pragma: no cover
504504
log.warning("Ignoring exception in _clean_close: " + repr(exc))
505505

506-
log.debug("_close_coro: sending exception to listeners")
506+
if log.isEnabledFor(logging.DEBUG):
507+
log.debug(
508+
f"_close_coro: sending exception to {len(self.listeners)} listeners"
509+
)
507510

508511
# Send an exception to all remaining listeners
509512
for query_id, listener in self.listeners.items():
@@ -530,7 +533,15 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
530533
log.debug("_close_coro: exiting")
531534

532535
async def _fail(self, e: Exception, clean_close: bool = True) -> None:
533-
log.debug("_fail: starting with exception: " + repr(e))
536+
if log.isEnabledFor(logging.DEBUG):
537+
import inspect
538+
539+
current_frame = inspect.currentframe()
540+
assert current_frame is not None
541+
caller_frame = current_frame.f_back
542+
assert caller_frame is not None
543+
caller_name = inspect.getframeinfo(caller_frame).function
544+
log.debug(f"_fail from {caller_name}: " + repr(e))
534545

535546
if self.close_task is None:
536547

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
]
5252

5353
install_websockets_requires = [
54-
"websockets>=10.1,<14",
54+
"websockets>=14.2,<16",
5555
]
5656

5757
install_botocore_requires = [

tests/conftest.py

+13-26
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def __init__(self, with_ssl: bool = False):
197197

198198
async def start(self, handler, extra_serve_args=None):
199199

200-
import websockets.server
200+
import websockets
201201

202202
print("Starting server")
203203

@@ -209,16 +209,21 @@ async def start(self, handler, extra_serve_args=None):
209209
extra_serve_args["ssl"] = ssl_context
210210

211211
# Adding dummy response headers
212-
extra_serve_args["extra_headers"] = {"dummy": "test1234"}
212+
extra_headers = {"dummy": "test1234"}
213+
214+
def process_response(connection, request, response):
215+
response.headers.update(extra_headers)
216+
return response
213217

214218
# Start a server with a random open port
215-
self.start_server = websockets.server.serve(
216-
handler, "127.0.0.1", 0, **extra_serve_args
219+
self.server = await websockets.serve(
220+
handler,
221+
"127.0.0.1",
222+
0,
223+
process_response=process_response,
224+
**extra_serve_args,
217225
)
218226

219-
# Wait that the server is started
220-
self.server = await self.start_server
221-
222227
# Get hostname and port
223228
hostname, port = self.server.sockets[0].getsockname()[:2] # type: ignore
224229
assert hostname == "127.0.0.1"
@@ -603,32 +608,14 @@ async def graphqlws_server(request):
603608

604609
subprotocol = "graphql-transport-ws"
605610

606-
from websockets.server import WebSocketServerProtocol
607-
608-
class CustomSubprotocol(WebSocketServerProtocol):
609-
def select_subprotocol(self, client_subprotocols, server_subprotocols):
610-
print(f"Client subprotocols: {client_subprotocols!r}")
611-
print(f"Server subprotocols: {server_subprotocols!r}")
612-
613-
return subprotocol
614-
615-
def process_subprotocol(self, headers, available_subprotocols):
616-
# Overwriting available subprotocols
617-
available_subprotocols = [subprotocol]
618-
619-
print(f"headers: {headers!r}")
620-
# print (f"Available subprotocols: {available_subprotocols!r}")
621-
622-
return super().process_subprotocol(headers, available_subprotocols)
623-
624611
server_handler = get_server_handler(request)
625612

626613
try:
627614
test_server = WebSocketServer()
628615

629616
# Starting the server with the fixture param as the handler function
630617
await test_server.start(
631-
server_handler, extra_serve_args={"create_protocol": CustomSubprotocol}
618+
server_handler, extra_serve_args={"subprotocols": [subprotocol]}
632619
)
633620

634621
yield test_server

tests/test_aiohttp_online.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ async def test_aiohttp_simple_query():
1919
url = "https://countries.trevorblades.com/graphql"
2020

2121
# Get transport
22-
sample_transport = AIOHTTPTransport(url=url)
22+
transport = AIOHTTPTransport(url=url)
2323

2424
# Instanciate client
25-
async with Client(transport=sample_transport) as session:
25+
async with Client(transport=transport) as session:
2626

2727
query = gql(
2828
"""
@@ -60,11 +60,9 @@ async def test_aiohttp_invalid_query():
6060

6161
from gql.transport.aiohttp import AIOHTTPTransport
6262

63-
sample_transport = AIOHTTPTransport(
64-
url="https://countries.trevorblades.com/graphql"
65-
)
63+
transport = AIOHTTPTransport(url="https://countries.trevorblades.com/graphql")
6664

67-
async with Client(transport=sample_transport) as session:
65+
async with Client(transport=transport) as session:
6866

6967
query = gql(
7068
"""
@@ -89,12 +87,12 @@ async def test_aiohttp_two_queries_in_parallel_using_two_tasks():
8987

9088
from gql.transport.aiohttp import AIOHTTPTransport
9189

92-
sample_transport = AIOHTTPTransport(
90+
transport = AIOHTTPTransport(
9391
url="https://countries.trevorblades.com/graphql",
9492
)
9593

9694
# Instanciate client
97-
async with Client(transport=sample_transport) as session:
95+
async with Client(transport=transport) as session:
9896

9997
query1 = gql(
10098
"""

0 commit comments

Comments
 (0)