From 7eb779b727b4e029391db50a5f6e24b9541dece2 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 6 Jun 2024 18:04:25 +0200 Subject: [PATCH] handle strict_exception_groups=True by unwrapping user exceptions from within exceptiongroups revert making close_connection CS shielded, as that would be a behaviour change causing very long stalls with the default timeout of 60s add comment for pylint disable move RaisesGroup import --- tests/test_connection.py | 31 +++++++++++++++++- trio_websocket/_impl.py | 68 ++++++++++++++++++++++++++++++++++------ 2 files changed, 89 insertions(+), 10 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 3d45eb7..ab43725 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -32,7 +32,9 @@ from __future__ import annotations from functools import partial, wraps +import re import ssl +import sys from unittest.mock import patch import attr @@ -48,6 +50,13 @@ except ImportError: from trio.hazmat import current_task # type: ignore # pylint: disable=ungrouped-imports + +# only available on trio>=0.25, we don't use it when testing lower versions +try: + from trio.testing import RaisesGroup +except ImportError: + pass + from trio_websocket import ( connect_websocket, connect_websocket_url, @@ -66,6 +75,9 @@ wrap_server_stream ) +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin + WS_PROTO_VERSION = tuple(map(int, wsproto.__version__.split('.'))) HOST = '127.0.0.1' @@ -427,6 +439,9 @@ async def handler(request): assert header_key == b'x-test-header' assert header_value == b'My test header' +def _trio_default_loose() -> bool: + assert re.match(r'^0\.\d\d\.', trio.__version__), "unexpected trio versioning scheme" + return int(trio.__version__[2:4]) < 25 @fail_after(1) async def test_handshake_exception_before_accept() -> None: @@ -436,7 +451,8 @@ async def test_handshake_exception_before_accept() -> None: async def handler(request): raise ValueError() - with pytest.raises(ValueError): + # pylint fails to resolve that BaseExceptionGroup will always be available + with pytest.raises((BaseExceptionGroup, ValueError)) as exc: # pylint: disable=possibly-used-before-assignment async with trio.open_nursery() as nursery: server = await nursery.start(serve_websocket, handler, HOST, 0, None) @@ -444,6 +460,19 @@ async def handler(request): use_ssl=False): pass + if _trio_default_loose(): + assert isinstance(exc.value, ValueError) + else: + # there's 4 levels of nurseries opened, leading to 4 nested groups: + # 1. this test + # 2. WebSocketServer.run + # 3. trio.serve_listeners + # 4. WebSocketServer._handle_connection + assert RaisesGroup( + RaisesGroup( + RaisesGroup( + RaisesGroup(ValueError)))).matches(exc.value) + @fail_after(1) async def test_reject_handshake(nursery): diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index b153034..f28eb15 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -13,6 +13,7 @@ import urllib.parse from typing import Iterable, List, Optional, Union +import outcome import trio import trio.abc from wsproto import ConnectionType, WSConnection @@ -44,6 +45,10 @@ logger = logging.getLogger('trio-websocket') +class TrioWebsocketInternalError(Exception): + ... + + def _ignore_cancel(exc): return None if isinstance(exc, trio.Cancelled) else exc @@ -125,10 +130,10 @@ async def open_websocket( client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`), or server rejection (:exc:`ConnectionRejected`) during handshakes. ''' - async with trio.open_nursery() as new_nursery: + async def open_connection(nursery: trio.Nursery) -> WebSocketConnection: try: with trio.fail_after(connect_timeout): - connection = await connect_websocket(new_nursery, host, port, + return await connect_websocket(nursery, host, port, resource, use_ssl=use_ssl, subprotocols=subprotocols, extra_headers=extra_headers, message_queue_size=message_queue_size, @@ -137,14 +142,59 @@ async def open_websocket( raise ConnectionTimeout from None except OSError as e: raise HandshakeError from e + + async def close_connection(connection: WebSocketConnection) -> None: try: - yield connection - finally: - try: - with trio.fail_after(disconnect_timeout): - await connection.aclose() - except trio.TooSlowError: - raise DisconnectionTimeout from None + with trio.fail_after(disconnect_timeout): + await connection.aclose() + except trio.TooSlowError: + raise DisconnectionTimeout from None + + connection: WebSocketConnection|None=None + result2: outcome.Maybe[None] | None = None + user_error = None + + try: + async with trio.open_nursery() as new_nursery: + result = await outcome.acapture(open_connection, new_nursery) + + if isinstance(result, outcome.Value): + connection = result.unwrap() + try: + yield connection + except BaseException as e: + user_error = e + raise + finally: + result2 = await outcome.acapture(close_connection, connection) + # This exception handler should only be entered if: + # 1. The _reader_task started in connect_websocket raises + # 2. User code raises an exception + except BaseExceptionGroup as e: + # user_error, or exception bubbling up from _reader_task + if len(e.exceptions) == 1: + raise e.exceptions[0] + # if the group contains two exceptions, one being Cancelled, and the other + # is user_error => drop Cancelled and raise user_error + # This Cancelled should only have been able to come from _reader_task + if ( + len(e.exceptions) == 2 + and user_error is not None + and user_error in e.exceptions + and any(isinstance(exc, trio.Cancelled) for exc in e.exceptions) + ): + raise user_error # pylint: disable=raise-missing-from,raising-bad-type + raise TrioWebsocketInternalError from e # pragma: no cover + ## TODO: handle keyboardinterrupt? + + finally: + if result2 is not None: + result2.unwrap() + + + # error setting up, unwrap that exception + if connection is None: + result.unwrap() async def connect_websocket(nursery, host, port, resource, *, use_ssl,