Skip to content

Commit

Permalink
handle strict_exception_groups=True by unwrapping user exceptions fro…
Browse files Browse the repository at this point in the history
…m within exceptiongroups

revert making close_connection CS shielded, as that would be a behaviour change causing very long stalls with the default timeout of 60s

add comment for pylint disable

move RaisesGroup import
  • Loading branch information
jakkdl committed Jun 17, 2024
1 parent 8f04a5c commit 7eb779b
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 10 deletions.
31 changes: 30 additions & 1 deletion tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
from __future__ import annotations

from functools import partial, wraps
import re
import ssl
import sys
from unittest.mock import patch

import attr
Expand All @@ -48,6 +50,13 @@
except ImportError:
from trio.hazmat import current_task # type: ignore # pylint: disable=ungrouped-imports


# only available on trio>=0.25, we don't use it when testing lower versions
try:
from trio.testing import RaisesGroup
except ImportError:
pass

from trio_websocket import (
connect_websocket,
connect_websocket_url,
Expand All @@ -66,6 +75,9 @@
wrap_server_stream
)

if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin

WS_PROTO_VERSION = tuple(map(int, wsproto.__version__.split('.')))

HOST = '127.0.0.1'
Expand Down Expand Up @@ -427,6 +439,9 @@ async def handler(request):
assert header_key == b'x-test-header'
assert header_value == b'My test header'

def _trio_default_loose() -> bool:
assert re.match(r'^0\.\d\d\.', trio.__version__), "unexpected trio versioning scheme"
return int(trio.__version__[2:4]) < 25

@fail_after(1)
async def test_handshake_exception_before_accept() -> None:
Expand All @@ -436,14 +451,28 @@ async def test_handshake_exception_before_accept() -> None:
async def handler(request):
raise ValueError()

with pytest.raises(ValueError):
# pylint fails to resolve that BaseExceptionGroup will always be available
with pytest.raises((BaseExceptionGroup, ValueError)) as exc: # pylint: disable=possibly-used-before-assignment
async with trio.open_nursery() as nursery:
server = await nursery.start(serve_websocket, handler, HOST, 0,
None)
async with open_websocket(HOST, server.port, RESOURCE,
use_ssl=False):
pass

if _trio_default_loose():
assert isinstance(exc.value, ValueError)
else:
# there's 4 levels of nurseries opened, leading to 4 nested groups:
# 1. this test
# 2. WebSocketServer.run
# 3. trio.serve_listeners
# 4. WebSocketServer._handle_connection
assert RaisesGroup(
RaisesGroup(
RaisesGroup(
RaisesGroup(ValueError)))).matches(exc.value)


@fail_after(1)
async def test_reject_handshake(nursery):
Expand Down
68 changes: 59 additions & 9 deletions trio_websocket/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import urllib.parse
from typing import Iterable, List, Optional, Union

import outcome
import trio
import trio.abc
from wsproto import ConnectionType, WSConnection
Expand Down Expand Up @@ -44,6 +45,10 @@
logger = logging.getLogger('trio-websocket')


class TrioWebsocketInternalError(Exception):
...


def _ignore_cancel(exc):
return None if isinstance(exc, trio.Cancelled) else exc

Expand Down Expand Up @@ -125,10 +130,10 @@ async def open_websocket(
client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`),
or server rejection (:exc:`ConnectionRejected`) during handshakes.
'''
async with trio.open_nursery() as new_nursery:
async def open_connection(nursery: trio.Nursery) -> WebSocketConnection:
try:
with trio.fail_after(connect_timeout):
connection = await connect_websocket(new_nursery, host, port,
return await connect_websocket(nursery, host, port,
resource, use_ssl=use_ssl, subprotocols=subprotocols,
extra_headers=extra_headers,
message_queue_size=message_queue_size,
Expand All @@ -137,14 +142,59 @@ async def open_websocket(
raise ConnectionTimeout from None
except OSError as e:
raise HandshakeError from e

async def close_connection(connection: WebSocketConnection) -> None:
try:
yield connection
finally:
try:
with trio.fail_after(disconnect_timeout):
await connection.aclose()
except trio.TooSlowError:
raise DisconnectionTimeout from None
with trio.fail_after(disconnect_timeout):
await connection.aclose()
except trio.TooSlowError:
raise DisconnectionTimeout from None

connection: WebSocketConnection|None=None
result2: outcome.Maybe[None] | None = None
user_error = None

try:
async with trio.open_nursery() as new_nursery:
result = await outcome.acapture(open_connection, new_nursery)

if isinstance(result, outcome.Value):
connection = result.unwrap()
try:
yield connection
except BaseException as e:
user_error = e
raise
finally:
result2 = await outcome.acapture(close_connection, connection)
# This exception handler should only be entered if:
# 1. The _reader_task started in connect_websocket raises
# 2. User code raises an exception
except BaseExceptionGroup as e:
# user_error, or exception bubbling up from _reader_task
if len(e.exceptions) == 1:
raise e.exceptions[0]
# if the group contains two exceptions, one being Cancelled, and the other
# is user_error => drop Cancelled and raise user_error
# This Cancelled should only have been able to come from _reader_task
if (
len(e.exceptions) == 2
and user_error is not None
and user_error in e.exceptions
and any(isinstance(exc, trio.Cancelled) for exc in e.exceptions)
):
raise user_error # pylint: disable=raise-missing-from,raising-bad-type
raise TrioWebsocketInternalError from e # pragma: no cover
## TODO: handle keyboardinterrupt?

finally:
if result2 is not None:
result2.unwrap()


# error setting up, unwrap that exception
if connection is None:
result.unwrap()


async def connect_websocket(nursery, host, port, resource, *, use_ssl,
Expand Down

0 comments on commit 7eb779b

Please sign in to comment.