diff --git a/docs/index.rst b/docs/index.rst index ae3933c..a6d7471 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,6 +32,7 @@ Autobahn Test Suite `__. getting_started clients servers + timeouts api recipes contributing diff --git a/docs/recipes.rst b/docs/recipes.rst index 19ad0dd..416ecd4 100644 --- a/docs/recipes.rst +++ b/docs/recipes.rst @@ -42,7 +42,7 @@ feature. await trio.sleep(interval) async def main(): - async with open_websocket_url('ws://localhost/foo') as ws: + async with open_websocket_url('ws://my.example/') as ws: async with trio.open_nursery() as nursery: nursery.start_soon(heartbeat, ws, 5, 1) # Your application code goes here: diff --git a/docs/servers.rst b/docs/servers.rst index b01d078..f662036 100644 --- a/docs/servers.rst +++ b/docs/servers.rst @@ -40,7 +40,8 @@ As explained in the tutorial, a WebSocket server needs a handler function and a host/port to bind to. The handler function receives a :class:`WebSocketRequest` object, and it calls the request's :func:`~WebSocketRequest.accept` method to finish the handshake and obtain a -:class:`WebSocketConnection` object. +:class:`WebSocketConnection` object. When the handler function exits, the +connection is automatically closed. .. autofunction:: serve_websocket diff --git a/docs/timeouts.rst b/docs/timeouts.rst new file mode 100644 index 0000000..39e42af --- /dev/null +++ b/docs/timeouts.rst @@ -0,0 +1,197 @@ +Timeouts +======== + +.. currentmodule:: trio_websocket + +Networking code is inherently complex due to the unpredictable nature of network +failures and the possibility of a remote peer that is coded incorrectly—or even +maliciously! Therefore, your code needs to deal with unexpected circumstances. +One common failure mode that you should guard against is a slow or unresponsive +peer. + +This page describes the timeout behavior in ``trio-websocket`` and shows various +examples for implementing timeouts in your own code. Before reading this, you +might find it helpful to read `"Timeouts and cancellation for humans" +`__, an article +written by Trio's author that describes an overall philosophy regarding +timeouts. The short version is that Trio discourages libraries from using +internal timeouts. Instead, it encourages the caller to enforce timeouts, which +makes timeout code easier to compose and reason about. + +On the other hand, this library is intended to be safe to use, and omitting +timeouts could be a dangerous flaw. Therefore, this library takes a balanced +approach to timeouts, where high-level APIs have internal timeouts, but you may +disable them or use lower-level APIs if you want more control over the behavior. + +Message Timeouts +---------------- + +As a motivating example, let's write a client that sends one message and then +expects to receive one message. To guard against a misbehaving server or +network, we want to place a 15 second timeout on this combined send/receive +operation. In other libraries, you might find that the APIs have ``timeout`` +arguments, but that style of timeout is very tedious when composing multiple +operations. In Trio, we have helpful abstractions like cancel scopes, allowing +us to implement our example like this: + +.. code-block:: python + + async with open_websocket_url('ws://my.example/') as ws: + with trio.fail_after(15): + await ws.send_message('test') + msg = await ws.get_message() + print('Received message: {}'.format(msg)) + +The 15 second timeout covers the cumulative time to send one message and to wait +for one response. It raises ``TooSlowError`` if the runtime exceeds 15 seconds. + +Connection Timeouts +------------------- + +The example in the previous section ignores one obvious problem: what if +connecting to the server or closing the connection takes a long time? How do we +apply a timeout to those operations? One option is to put the entire connection +inside a cancel scope: + +.. code-block:: python + + with trio.fail_after(15): + async with open_websocket_url('ws://my.example/') as ws: + await ws.send_message('test') + msg = await ws.get_message() + print('Received message: {}'.format(msg)) + +The approach suffices if we want to compose all four operations into one +timeout: connect, send message, get message, and disconnect. But this approach +will not work if want to separate the timeouts for connecting/disconnecting from +the timeouts for sending and receiving. Let's write a new client that sends +messages periodically, waiting up to 15 seconds for a response to each message +before sending the next message. + +.. code-block:: python + + async with open_websocket_url('ws://my.example/') as ws: + for _ in range(10): + await trio.sleep(30) + with trio.fail_after(15): + await ws.send_message('test') + msg = await ws.get_message() + print('Received message: {}'.format(msg)) + +In this scenario, the ``for`` loop will take at least 300 seconds to run, so we +would like to specify timeouts that apply to connecting and disconnecting but do +not apply to the contents of the context manager block. This is tricky because +the connecting and disconnecting are handled automatically inside the context +manager :func:`open_websocket_url`. Here's one possible approach: + +.. code-block:: python + + with trio.fail_after(10) as cancel_scope: + async with open_websocket_url('ws://my.example'): + cancel_scope.deadline = math.inf + for _ in range(10): + await trio.sleep(30) + with trio.fail_after(15): + await ws.send_message('test') + msg = await ws.get_message() + print('Received message: {}'.format(msg)) + cancel_scope.deadline = trio.current_time() + 5 + +This example places a 10 second timeout on connecting and a separate 5 second +timeout on disconnecting. This is accomplished by wrapping the entire operation +in a cancel scope and then modifying the cancel scope's deadline when entering +and exiting the context manager block. + +This approach works but it is a bit complicated, and we don't want our safety +mechanisms to be complicated! Therefore, the high-level client APIs +:func:`open_websocket` and :func:`open_websocket_url` contain internal timeouts +that apply only to connecting and disconnecting. Let's rewrite the previous +example to use the library's internal timeouts: + +.. code-block:: python + + async with open_websocket_url('ws://my.example/', connect_timeout=10, + disconnect_timeout=5) as ws: + for _ in range(10): + await trio.sleep(30) + with trio.fail_after(15): + await ws.send_message('test') + msg = await ws.get_message() + print('Received message: {}'.format(msg)) + +Just like the previous example, this puts a 10 second timeout on connecting, a +separate 5 second timeout on disconnecting. These internal timeouts violate the +Trio philosophy of composable timeouts, but hopefully the examples in this +section have convinced you that breaking the rules a bit is justified by the +improved safety and ergonomics of this version. + +In fact, these timeouts have actually been present in all of our examples so +far! We just didn't see them because those arguments have default values. If you +really don't like the internal timeouts, you can disable them by passing +``math.inf``, or you can use the low-level APIs instead. + +Timeouts on Low-level APIs +-------------------------- + +In the previous section, we saw how the library's high-level APIs have internal +timeouts. The low-level APIs, like :func:`connect_websocket` and +:func:`connect_websocket_url` do not have internal timeouts, nor are they +context managers. These characteristics make the low-level APIs suitable for +situations where you want very fine-grained control over timeout behavior. + +.. code-block:: python + + async with trio.open_nursery(): + with trio.fail_after(10): + connection = await connect_websocket_url(nursery, 'ws://my.example/') + try: + for _ in range(10): + await trio.sleep(30) + with trio.fail_after(15): + await ws.send_message('test') + msg = await ws.get_message() + print('Received message: {}'.format(msg)) + finally: + with trio.fail_after(5): + await connection.aclose() + +This example applies the same 10 second timeout for connecting and 5 second +timeout for disconnecting as seen in the previous section, but it uses the +lower-level APIs. This approach gives you more control but the low-level APIs +also require more boilerplate, such as creating a nursery and using try/finally +to ensure that the connection is always closed. + +Server Timeouts +--------------- + +The server API also has internal timeouts. These timeouts are configured when +the server is created, and they are enforced on each connection. + +.. code-block:: python + + async def handler(request): + ws = await request.accept() + msg = await ws.get_message() + print('Received message: {}'.format(msg)) + + await serve_websocket(handler, 'localhost', 8080, ssl_context=None, + connect_timeout=10, disconnect_timeout=5) + +The server timeouts work slightly differently from the client timeouts. The +server's connect timeout measures the time between receiving a new TCP +connection and calling the user's handler. The connect timeout +includes waiting for the client's side of the handshake (which is represented by +the ``request`` object), *but it does not include the server's side of the +handshake.* The server handshake needs to be performed inside the user's +handler, e.g. ``await request.accept()``. The disconnect timeout applies to the +time between the handler exiting and the connection being closed. + +Each handler is spawned inside of a nursery, so there is no way for connect and +disconnect timeouts to raise exceptions to your code. (If they did raise +exceptions, they would cancel your nursery and crash your server!) Instead, +connect timeouts cause the connection to be silently closed, and the handler is +never called. For disconnect timeouts, your handler has already exited, so a +timeout will cause the connection to be silently closed. + +As with the client APIs, you can disable the internal timeouts by passing +``math.inf`` or you can use low-level APIs like :func:`wrap_server_stream`. diff --git a/tests/test_connection.py b/tests/test_connection.py index 9a7e498..a4ce171 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,4 +1,35 @@ -from functools import partial +''' +Unit tests for trio_websocket. + +Many of these tests involve networking, i.e. real TCP sockets. To maximize +reliability, all networking tests should follow the following rules: + +- Use localhost only. This is stored in the ``HOST`` global variable. +- Servers use dynamic ports: by passing zero as a port, the system selects a + port that is guaranteed to be available. +- The sequence of events between servers and clients should be controlled as + much as possible to make tests as deterministic. More on determinism below. +- If a test involves timing, e.g. a task needs to ``trio.sleep(…)`` for a bit, + then the ``autojump_clock`` fixture should be used. +- Most tests that involve I/O should have an absolute timeout placed on it to + prevent a hung test from blocking the entire test suite. If a hung test is + cancelled with ctrl+C, then PyTest discards its log messages, which makes + debugging really difficult! The ``fail_after(…)`` decorator places an absolute + timeout on test execution that as measured by Trio's clock. + +`Read more about writing tests with pytest-trio. +`__ + +Determinism is an important property of tests, but it can be tricky to +accomplish with network tests. For example, if a test has a client and a server, +then they may race each other to close the connection first. The test author +should select one side to always initiate the closing handshake. For example, if +a test needs to ensure that the client closes first, then it can have the server +call ``ws.get_message()`` without actually sending it a message. This will cause +the server to block until the client has sent the closing handshake. In other +circumstances +''' +from functools import partial, wraps import attr import pytest @@ -23,6 +54,15 @@ HOST = '127.0.0.1' RESOURCE = '/resource' +# Timeout tests follow a general pattern: one side waits TIMEOUT seconds for an +# event. The other side delays for FORCE_TIMEOUT seconds to force the timeout +# to trigger. Each test also has maximum runtime (measure by Trio's clock) to +# prevent a faulty test from hanging the entire suite. +TIMEOUT = 1 +FORCE_TIMEOUT = 2 +MAX_TIMEOUT_TEST_DURATION = 3 + + @pytest.fixture @async_generator async def echo_server(nursery): @@ -62,6 +102,23 @@ async def echo_conn_handler(conn): pass +class fail_after: + ''' This decorator fails if the runtime of the decorated function (as + measured by the Trio clock) exceeds the specified value. ''' + def __init__(self, seconds): + self._seconds = seconds + + def __call__(self, fn): + @wraps(fn) + async def wrapper(*args, **kwargs): + with trio.move_on_after(self._seconds) as cancel_scope: + await fn(*args, **kwargs) + if cancel_scope.cancelled_caught: + pytest.fail('Test runtime exceeded the maximum {} seconds' + .format(self._seconds)) + return wrapper + + @attr.s(hash=False, cmp=False) class MemoryListener(trio.abc.Listener): closed = attr.ib(default=False) @@ -334,6 +391,132 @@ async def handler(stream): await client.send_message('Hello from client!') +@fail_after(MAX_TIMEOUT_TEST_DURATION) +async def test_client_open_timeout(nursery, autojump_clock): + ''' + The client times out waiting for the server to complete the opening + handshake. + ''' + async def handler(request): + await trio.sleep(FORCE_TIMEOUT) + server_ws = await request.accept() + pytest.fail('Should not reach this line.') + + server = await nursery.start( + partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + + with pytest.raises(trio.TooSlowError): + async with open_websocket(HOST, server.port, '/', use_ssl=False, + connect_timeout=TIMEOUT) as client_ws: + pass + + +@fail_after(MAX_TIMEOUT_TEST_DURATION) +async def test_client_close_timeout(nursery, autojump_clock): + ''' + This client times out waiting for the server to complete the closing + handshake. + + To slow down the server's closing handshake, we make sure that its message + queue size is 0, and the client sends it exactly 1 message. This blocks the + server's reader so it won't do the closing handshake for at least + ``FORCE_TIMEOUT`` seconds. + ''' + async def handler(request): + server_ws = await request.accept() + await trio.sleep(FORCE_TIMEOUT) + # The next line should raise ConnectionClosed. + await server_ws.get_message() + pytest.fail('Should not reach this line.') + + server = await nursery.start( + partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + + with pytest.raises(trio.TooSlowError): + async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, + disconnect_timeout=TIMEOUT) as client_ws: + await client_ws.send_message('test') + + +@fail_after(MAX_TIMEOUT_TEST_DURATION) +async def test_server_open_timeout(autojump_clock): + ''' + The server times out waiting for the client to complete the opening + handshake. + + Server timeouts don't raise exceptions, because handler tasks are launched + in an internal nursery and sending exceptions wouldn't be helpful. Instead, + timed out tasks silently end. + ''' + async def handler(request): + pytest.fail('This handler should not be called.') + + async with trio.open_nursery() as nursery: + server = await nursery.start(partial(serve_websocket, handler, HOST, 0, + ssl_context=None, handler_nursery=nursery, connect_timeout=TIMEOUT)) + + old_task_count = len(nursery.child_tasks) + # This stream is not a WebSocket, so it won't send a handshake: + stream = await trio.open_tcp_stream(HOST, server.port) + # Checkpoint so the server's handler task can spawn: + await trio.sleep(0) + assert len(nursery.child_tasks) == old_task_count + 1, \ + "Server's reader task did not spawn" + # Sleep long enough to trigger server's connect_timeout: + await trio.sleep(FORCE_TIMEOUT) + assert len(nursery.child_tasks) == old_task_count, \ + "Server's reader task is still running" + # Cancel the server task: + nursery.cancel_scope.cancel() + + +@fail_after(MAX_TIMEOUT_TEST_DURATION) +async def test_server_close_timeout(autojump_clock): + ''' + The server times out waiting for the client to complete the closing + handshake. + + Server timeouts don't raise exceptions, because handler tasks are launched + in an internal nursery and sending exceptions wouldn't be helpful. Instead, + timed out tasks silently end. + + To prevent the client from doing the closing handshake, we make sure that + its message queue size is 0 and the server sends it exactly 1 message. This + blocks the client's reader and prevents it from doing the client handshake. + ''' + async def handler(request): + ws = await request.accept() + # Send one message to block the client's reader task: + await ws.send_message('test') + import logging + async with trio.open_nursery() as outer: + server = await outer.start(partial(serve_websocket, handler, HOST, 0, + ssl_context=None, handler_nursery=outer, + disconnect_timeout=TIMEOUT)) + + old_task_count = len(outer.child_tasks) + # Spawn client inside an inner nursery so that we can cancel it's reader + # so that it won't do a closing handshake. + async with trio.open_nursery() as inner: + ws = await connect_websocket(inner, HOST, server.port, RESOURCE, + use_ssl=False) + # Checkpoint so the server can spawn a handler task: + await trio.sleep(0) + assert len(outer.child_tasks) == old_task_count + 1, \ + "Server's reader task did not spawn" + # The client waits long enough to trigger the server's disconnect + # timeout: + await trio.sleep(FORCE_TIMEOUT) + # The server should have cancelled the handler: + assert len(outer.child_tasks) == old_task_count, \ + "Server's reader task is still running" + # Cancel the client's reader task: + inner.cancel_scope.cancel() + + # Cancel the server task: + outer.cancel_scope.cancel() + + async def test_client_does_not_close_handshake(nursery): async def handler(request): server_ws = await request.accept() diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 16b54de..f79d7bb 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -18,13 +18,16 @@ from .version import __version__ + +CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds RECEIVE_BYTES = 4096 logger = logging.getLogger('trio-websocket') @asynccontextmanager @async_generator -async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None): +async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None, + connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT): ''' Open a WebSocket client connection to a host. @@ -41,12 +44,21 @@ async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None): :type use_ssl: bool or ssl.SSLContext :param subprotocols: An iterable of strings representing preferred subprotocols. + :param float connect_timeout: The number of seconds to wait for the + connection before timing out. + :param float disconnect_timeout: The number of seconds to wait when closing + the connection before timing out. + :raises trio.TooSlowError: if connecting or disconnecting times out. ''' async with trio.open_nursery() as new_nursery: - connection = await connect_websocket(new_nursery, host, port, resource, - use_ssl=use_ssl, subprotocols=subprotocols) - async with connection: + with trio.fail_after(connect_timeout): + connection = await connect_websocket(new_nursery, host, port, + resource, use_ssl=use_ssl, subprotocols=subprotocols) + try: await yield_(connection) + finally: + with trio.fail_after(disconnect_timeout): + await connection.aclose() async def connect_websocket(nursery, host, port, resource, *, use_ssl, @@ -97,7 +109,8 @@ async def connect_websocket(nursery, host, port, resource, *, use_ssl, return connection -def open_websocket_url(url, ssl_context=None, *, subprotocols=None): +def open_websocket_url(url, ssl_context=None, *, subprotocols=None, + connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT): ''' Open a WebSocket client connection to a URL. @@ -111,6 +124,11 @@ def open_websocket_url(url, ssl_context=None, *, subprotocols=None): :type ssl_context: ssl.SSLContext or None :param subprotocols: An iterable of strings representing preferred subprotocols. + :param float connect_timeout: The number of seconds to wait for the + connection before timing out. + :param float disconnect_timeout: The number of seconds to wait when closing + the connection before timing out. + :raises trio.TooSlowError: if connecting or disconnecting times out. ''' host, port, resource, ssl_context = _url_to_host(url, ssl_context) return open_websocket(host, port, resource, use_ssl=ssl_context, @@ -209,7 +227,8 @@ async def wrap_server_stream(nursery, stream): async def serve_websocket(handler, host, port, ssl_context, *, - handler_nursery=None, task_status=trio.TASK_STATUS_IGNORED): + handler_nursery=None, connect_timeout=CONN_TIMEOUT, + disconnect_timeout=CONN_TIMEOUT, task_status=trio.TASK_STATUS_IGNORED): ''' Serve a WebSocket over TCP. @@ -233,6 +252,10 @@ async def serve_websocket(handler, host, port, ssl_context, *, :type ssl_context: ssl.SSLContext or None :param handler_nursery: An optional nursery to spawn handlers and background tasks in. If not specified, a new nursery will be created internally. + :param float connect_timeout: The number of seconds to wait for a client + to finish connection handshake before timing out. + :param float disconnect_timeout: The number of seconds to wait for a client + to finish the closing handshake before timing out. :param task_status: Part of Trio nursery start protocol. :returns: This function runs until cancelled. ''' @@ -243,7 +266,8 @@ async def serve_websocket(handler, host, port, ssl_context, *, ssl_context, host=host, https_compatible=True) listeners = await open_tcp_listeners() server = WebSocketServer(handler, listeners, - handler_nursery=handler_nursery) + handler_nursery=handler_nursery, connect_timeout=connect_timeout, + disconnect_timeout=disconnect_timeout) await server.run(task_status=task_status) @@ -498,7 +522,7 @@ async def aclose(self, code=1000, reason=None): Close the WebSocket connection. This sends a closing frame and suspends until the connection is closed. - After calling this method, any futher I/O on this WebSocket (such as + After calling this method, any further I/O on this WebSocket (such as ``get_message()`` or ``send_message()``) will raise ``ConnectionClosed``. @@ -512,7 +536,13 @@ async def aclose(self, code=1000, reason=None): # Per AsyncResource interface, calling aclose() on a closed resource # should succeed. return - self._wsproto.close(code=code, reason=reason) + # Wsproto will throw an AttributeError if you close it during the + # handshake phase. This is an open bug: + # https://github.com/python-hyper/wsproto/issues/59 + try: + self._wsproto.close(code=code, reason=reason) + except AttributeError: + pass try: await self._recv_channel.aclose() await self._write_pending() @@ -605,6 +635,11 @@ async def send_message(self, message): self._wsproto.send_data(message) await self._write_pending() + def __str__(self): + ''' Connection ID and type. ''' + type_ = 'client' if self.is_client else 'server' + return '{}-{}'.format(type_, self._id) + async def _abort_web_socket(self): ''' If a stream is closed outside of this class, e.g. due to network @@ -655,7 +690,7 @@ async def _close_web_socket(self, code, reason=None): ''' self._close_reason = CloseReason(code, reason) exc = ConnectionClosed(self._close_reason) - logger.debug('conn#%d websocket closed %r', self._id, exc) + logger.debug('%s websocket closed %r', self, exc) await self._send_channel.aclose() async def _get_request(self): @@ -747,7 +782,7 @@ async def _handle_ping_received_event(self, event): :param event: ''' - logger.debug('conn#%d ping %r', self._id, event.payload) + logger.debug('%s ping %r', self, event.payload) await self._write_pending() async def _handle_pong_received_event(self, event): @@ -774,7 +809,7 @@ async def _handle_pong_received_event(self, event): while self._pings: key, event = self._pings.popitem(0) skipped = ' [skipped] ' if payload != key else ' ' - logger.debug('conn#%d pong%s%r', self._id, skipped, key) + logger.debug('%s pong%s%r', self, skipped, key) event.set() if payload == key: break @@ -802,11 +837,11 @@ async def _reader_task(self): event_type = type(event).__name__ try: handler = handlers[event_type] - logger.debug('conn#%d received event: %s', self._id, + logger.debug('%s received event: %s', self, event_type) await handler(event) except KeyError: - logger.warning('Received unknown event type: "%s"', + logger.warning('%s received unknown event type: "%s"', self, event_type) # Get network data. @@ -816,18 +851,18 @@ async def _reader_task(self): await self._abort_web_socket() break if len(data) == 0: - logger.debug('conn#%d received zero bytes (connection closed)', - self._id) + logger.debug('%s received zero bytes (connection closed)', + self) # If TCP closed before WebSocket, then record it as an abnormal # closure. if not self._wsproto.closed: await self._abort_web_socket() break else: - logger.debug('conn#%d received %d bytes', self._id, len(data)) + logger.debug('%s received %d bytes', self, len(data)) self._wsproto.receive_bytes(data) - logger.debug('conn#%d reader task finished', self._id) + logger.debug('%s reader task finished', self) async def _write_pending(self): ''' Write any pending protocol data to the network socket. ''' @@ -836,14 +871,14 @@ async def _write_pending(self): # The reader task and one or more writers might try to send messages # at the same time, so we need to synchronize access to this stream. async with self._stream_lock: - logger.debug('conn#%d sending %d bytes', self._id, len(data)) + logger.debug('%s sending %d bytes', self, len(data)) try: await self._stream.send_all(data) except (trio.BrokenResourceError, trio.ClosedResourceError): await self._abort_web_socket() raise ConnectionClosed(self._close_reason) from None else: - logger.debug('conn#%d no pending data to send', self._id) + logger.debug('%s no pending data to send', self) class ListenPort: @@ -871,7 +906,8 @@ class WebSocketServer: instance and starts some background tasks, ''' - def __init__(self, handler, listeners, *, handler_nursery=None): + def __init__(self, handler, listeners, *, handler_nursery=None, + connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT): ''' Constructor. @@ -887,12 +923,18 @@ def __init__(self, handler, listeners, *, handler_nursery=None): :param handler_nursery: An optional nursery to spawn connection tasks inside of. If ``None``, then a new nursery will be created internally. + :param float connect_timeout: The number of seconds to wait for a client + to finish connection handshake before timing out. + :param float disconnect_timeout: The number of seconds to wait for a client + to finish the closing handshake before timing out. ''' if len(listeners) == 0: raise ValueError('Listeners must contain at least one item.') self._handler = handler self._handler_nursery = handler_nursery self._listeners = listeners + self._connect_timeout = connect_timeout + self._disconnect_timeout = disconnect_timeout @property def port(self): @@ -973,6 +1015,16 @@ async def _handle_connection(self, stream): wsproto = wsconnection.WSConnection(wsconnection.SERVER) connection = WebSocketConnection(stream, wsproto) nursery.start_soon(connection._reader_task) - async with connection: + with trio.move_on_after(self._connect_timeout) as connect_scope: request = await connection._get_request() + if connect_scope.cancelled_caught: + nursery.cancel_scope.cancel() + await stream.aclose() + return + try: await self._handler(request) + finally: + with trio.move_on_after(self._disconnect_timeout): + # aclose() will shut down the reader task even if its + # cancelled: + await connection.aclose()