diff --git a/docs/index.rst b/docs/index.rst index ae3933c..a6d7471 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,6 +32,7 @@ Autobahn Test Suite `__. getting_started clients servers + timeouts api recipes contributing diff --git a/docs/recipes.rst b/docs/recipes.rst index 19ad0dd..416ecd4 100644 --- a/docs/recipes.rst +++ b/docs/recipes.rst @@ -42,7 +42,7 @@ feature. await trio.sleep(interval) async def main(): - async with open_websocket_url('ws://localhost/foo') as ws: + async with open_websocket_url('ws://my.example/') as ws: async with trio.open_nursery() as nursery: nursery.start_soon(heartbeat, ws, 5, 1) # Your application code goes here: diff --git a/docs/servers.rst b/docs/servers.rst index b01d078..f662036 100644 --- a/docs/servers.rst +++ b/docs/servers.rst @@ -40,7 +40,8 @@ As explained in the tutorial, a WebSocket server needs a handler function and a host/port to bind to. The handler function receives a :class:`WebSocketRequest` object, and it calls the request's :func:`~WebSocketRequest.accept` method to finish the handshake and obtain a -:class:`WebSocketConnection` object. +:class:`WebSocketConnection` object. When the handler function exits, the +connection is automatically closed. .. autofunction:: serve_websocket diff --git a/docs/timeouts.rst b/docs/timeouts.rst new file mode 100644 index 0000000..a15eec4 --- /dev/null +++ b/docs/timeouts.rst @@ -0,0 +1,180 @@ +Timeouts +======== + +.. currentmodule:: trio_websocket + +Networking code is inherently complex due to the unpredictable nature of network +failures and the possibility of a remote peer that is coded incorrectly—or even +maliciously! Therefore, your code needs to deal with unexpected circumstances. +One common failure mode that you should guard against is a slow or unresponsive +peer. + +This page describes the timeout behavior in ``trio-websocket`` and shows various +examples for implementing timeouts in your own code. Before reading this, you +might find it helpful to read `"Timeouts and cancellation for humans" +`__, an article +written by Trio's author that describes an overall philosophy regarding +timeouts. The short version is that Trio discourages libraries from using +internal timeouts. Instead, it encourages the caller to enforce timeouts, which +makes timeout code easier to compose and reason about. + +On the other hand, this library is intended to be safe to use, and omitting +timeouts could be a dangerous flaw. Therefore, this library contains takes a +balanced approach to timeouts, where high-level APIs have internal timeouts, but +you may disable them or use lower-level APIs if you want more control over the +behavior. + +Built-in Client Timeouts +------------------------ + +The high-level client APIs :func:`open_websocket` and :func:`open_websocket_url` +contain built-in timeouts for connecting to a WebSocket and disconnecting from a +WebSocket. These timeouts are built-in for two reasons: + +1. Omitting timeouts may be dangerous, and this library strives to make safe + code easy to write. +2. These high-level APIs are context managers, and composing timeouts with + context managers is tricky. + +These built-in timeouts make it easy to write a WebSocket client that won't hang +indefinitely if the remote endpoint or network are misbehaving. The following +example shows a connect timeout of 10 seconds. This guarantees that the block +will start executing (reaching the line that prints "Connected") within 10 +seconds. When the context manager exits after the ``print(Received response: +…)``, the disconnect timeout guarantees that it will take no more than 5 seconds +to reach the line that prints "Disconnected". If either timeout is exceeded, +then the entire block raises ``trio.TooSlowError``. + +.. code-block:: python + + async with open_websocket_url('ws://my.example/', connect_timeout=10, + disconnect_timeout=5) as ws: + print("Connected") + await ws.send_message('hello from client!') + response = await ws.get_message() + print('Received response: {}'.format(response)) + print("Disconnected") + +.. note:: + + The built-in timeouts do not affect the contents of the block! In this + example, the client waits to receive a message. If the server never sends a + message, then the client will block indefinitely on ``ws.get_message()``. + Placing timeouts inside blocks is discussed below. + +What if you decided that you really wanted to manage the timeouts yourself? The +following example implements the same timeout behavior explicitly, without +relying on the library's built-in timeouts. + +.. code-block:: python + + with trio.move_on_after(10) as cancel_scope: + async with open_websocket_url('ws://my.example', + connect_timeout=math.inf, disconnect_timeout=math.inf): + print("Connected") + cancel_scope.deadline = math.inf + await ws.send_message('hello from client!') + response = await ws.get_message() + print('Received response: {}'.format(response)) + cancel_scope.deadline = trio.current_time() + 5 + print("Disconnected") + +Notice that the library's internal timeouts are disabled by passing +``math.inf``. This example is less ergonomic than using the built-in timeouts. +If you really want to customize this behavior, you may want to use the low-level +APIs instead, which are discussed below. + +Timeouts Inside Blocks +---------------------- + +The built-in timeouts do not apply to the contents of the block. One of the +examples above would hang on ``ws.get_message()`` if the remote endpoint never +sends a message. If you want to enforce a timeout in this situation, you must to +do it explicitly: + +.. code-block:: python + + async with open_websocket_url('ws://my.example/', connect_timeout=10, + disconnect_timeout=5) as ws: + with trio.fail_after(15): + msg = await ws.get_message() + print('Received message: {}'.format(msg)) + +This example waits up to 15 seconds to get one message from the server, raising +``trio.TooSlowError`` if the timeout is exceeded. Notice in this example that +the message timeout is larger than the connect and disconnect timeouts, +illustrating that the connect and disconnect timeouts do not apply to the +contents of the block. + +Alternatively, you might apply one timeout to the entire operation: connect to +the server, get one message, and disconnect. + +.. code-block:: python + + with trio.fail_after(15): + async with open_websocket_url('ws://my.example/', + connect_timeout=math.inf, disconnect_timeout=math.inf) as ws: + msg = await ws.get_message() + print('Received message: {}'.format(msg)) + +Note that the internal timeouts are disabled in this example. + +Timeouts on Low-level APIs +-------------------------- + +We saw an example above where explicit timeouts were applied to the context +managers. In practice, if you need to customize timeout behavior, the low-level +APIs like :func:`connect_websocket_url` etc. will be clearer and easier to use. +This example implements the same timeouts above using the low-level APIs. + +.. code-block:: python + + with trio.fail_after(10): + connection = await connect_websocket_url('ws://my.example/') + print("Connected") + try: + await ws.send_message('hello from client!') + response = await ws.get_message() + print('Received response: {}'.format(response)) + finally: + with trio.fail_after(5): + await connection.aclose() + print("Disconnected") + +The low-level APIs make the timeout code easier to read, but we also have to add +try/finally blocks if we want the same behavior that the context manager +guarantees. + +Built-in Server Timeouts +------------------------ + +The server API also offer built-in timeouts. These timeouts are configured when +the server is created, and they are enforced on each connection. + +.. code-block:: python + + async def handler(request): + ws = await request.accept() + msg = await ws.get_message() + print('Received message: {}'.format(msg)) + + await serve_websocket(handler, 'localhost', 8080, ssl_context=None, + connect_timeout=10, disconnect_timeout=5) + +The server timeouts work slightly differently from the client timeouts. The +connect timeout measures the time between when a TCP connection is received and +when the user's handler is called. As a consequence, the connect timeout +includes waiting for the client's side of the handshake, which is represented by +the ``request`` object. *It does not include the server's side of the +handshake,* because the server handshake needs to be performed inside the user's +handler, i.e. ``await request.accept()``. The disconnect timeout applies to the +time between the handler exiting and the connection being closed. + +Each handler is spawned inside of a nursery, so there is no way for connect and +disconnect timeouts to raise exceptions to your code. Instead, connect timeouts +result cause the connection to be silently closed, and handler is never called. +For disconnect timeouts, your handler has already exited, so a timeout will +cause the connection to be silently closed. + +As with the client APIs, you can disable the internal timeouts by passing +``math.inf``. diff --git a/tests/test_connection.py b/tests/test_connection.py index 9a7e498..3592fb5 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,4 +1,35 @@ -from functools import partial +''' +Unit tests for trio_websocket. + +Many of these tests involve networking, i.e. real TCP sockets. To maximize +reliability, all networking tests should follow the following rules: + +- Use localhost only. This is stored in the ``HOST`` global variable. +- Servers use dynamic ports: by passing zero as a port, the system selects a + port that is guaranteed to be available. +- The sequence of events between servers and clients should be controlled as + much as possible to make tests as deterministic. More on determinism below. +- If a test involves timing, e.g. a task needs to ``trio.sleep(…)`` for a bit, + then the ``autojump_clock`` fixture should be used. +- Most tests that involve I/O should have an absolute timeout placed on it to + prevent a hung test from blocking the entire test suite. If a hung test is + cancelled with ctrl+C, then PyTest discards its log messages, which makes + debugging really difficult! The ``fail_after(…)`` decorator places an absolute + timeout on test execution that as measured by Trio's clock. + +`Read more about writing tests with pytest-trio. +`__ + +Determinism is an important property of tests, but it can be tricky to +accomplish with network tests. For example, if a test has a client and a server, +then they may race each other to close the connection first. The test author +should select one side to always initiate the closing handshake. For example, if +a test needs to ensure that the client closes first, then it can have the server +call ``ws.get_message()`` without actually sending it a message. This will cause +the server to block until the client has sent the closing handshake. In other +circumstances +''' +from functools import partial, wraps import attr import pytest @@ -23,6 +54,15 @@ HOST = '127.0.0.1' RESOURCE = '/resource' +# Timeout tests follow a general pattern: one side waits TIMEOUT seconds for an +# event. The other side delays for FORCE_TIMEOUT seconds to force the timeout +# to trigger. Each test also has maximum runtime (measure by Trio's clock) to +# prevent a faulty test from hanging the entire suite. +TIMEOUT = 1 +FORCE_TIMEOUT = 2 +MAX_TIMEOUT_TEST_DURATION = 3 + + @pytest.fixture @async_generator async def echo_server(nursery): @@ -62,6 +102,23 @@ async def echo_conn_handler(conn): pass +class fail_after: + ''' This decorator fails a if its runtime (as measured by the Trio clock) + exceeds the specified value. ''' + def __init__(self, seconds): + self._seconds = seconds + + def __call__(self, fn): + @wraps(fn) + async def wrapper(*args, **kwargs): + with trio.move_on_after(self._seconds) as cancel_scope: + await fn(*args, **kwargs) + if cancel_scope.cancelled_caught: + pytest.fail('Test runtime exceeded the maximum {} seconds' + .format(self._seconds)) + return wrapper + + @attr.s(hash=False, cmp=False) class MemoryListener(trio.abc.Listener): closed = attr.ib(default=False) @@ -334,6 +391,132 @@ async def handler(stream): await client.send_message('Hello from client!') +@fail_after(MAX_TIMEOUT_TEST_DURATION) +async def test_client_open_timeout(nursery, autojump_clock): + ''' + The client times out waiting for the server to complete the opening + handshake. + ''' + async def handler(request): + await trio.sleep(FORCE_TIMEOUT) + server_ws = await request.accept() + pytest.fail('Should not reach this line.') + + server = await nursery.start( + partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + + with pytest.raises(trio.TooSlowError): + async with open_websocket(HOST, server.port, '/', use_ssl=False, + connect_timeout=TIMEOUT) as client_ws: + pass + + +@fail_after(MAX_TIMEOUT_TEST_DURATION) +async def test_client_close_timeout(nursery, autojump_clock): + ''' + This client times out waiting for the server to complete the closing + handshake. + + To slow down the server's closing handshake, we make sure that its message + queue size is 0, and the client sends it exactly 1 message. This blocks the + server's reader so it won't do the closing handshake for at least + ``FORCE_TIMEOUT`` seconds. + ''' + async def handler(request): + server_ws = await request.accept() + await trio.sleep(FORCE_TIMEOUT) + # The next line should raise ConnectionClosed. + await server_ws.get_message() + pytest.fail('Should not reach this line.') + + server = await nursery.start( + partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + + with pytest.raises(trio.TooSlowError): + async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, + disconnect_timeout=TIMEOUT) as client_ws: + await client_ws.send_message('test') + + +@fail_after(MAX_TIMEOUT_TEST_DURATION) +async def test_server_open_timeout(autojump_clock): + ''' + The server times out waiting for the client to complete the opening + handshake. + + Server timeouts don't raise exceptions, because handler tasks are launched + in an internal nursery and sending exceptions wouldn't be helpful. Instead, + timed out tasks silently end. + ''' + async def handler(request): + pytest.fail('This handler should not be called.') + + async with trio.open_nursery() as nursery: + server = await nursery.start(partial(serve_websocket, handler, HOST, 0, + ssl_context=None, handler_nursery=nursery, connect_timeout=TIMEOUT)) + + old_task_count = len(nursery.child_tasks) + # This stream is not a WebSocket, so it won't send a handshake: + stream = await trio.open_tcp_stream(HOST, server.port) + # Checkpoint so the server's handler task can spawn: + await trio.sleep(0) + assert len(nursery.child_tasks) == old_task_count + 1, \ + "Server's reader task did not spawn" + # Sleep long enough to trigger server's connect_timeout: + await trio.sleep(FORCE_TIMEOUT) + assert len(nursery.child_tasks) == old_task_count, \ + "Server's reader task is still running" + # Cancel the server task: + nursery.cancel_scope.cancel() + + +@fail_after(MAX_TIMEOUT_TEST_DURATION) +async def test_server_close_timeout(autojump_clock): + ''' + The server times out waiting for the client to complete the closing + handshake. + + Server timeouts don't raise exceptions, because handler tasks are launched + in an internal nursery and sending exceptions wouldn't be helpful. Instead, + timed out tasks silently end. + + To prevent the client from doing the closing handshake, we make sure that + its message queue size is 0 and the server sends it exactly 1 message. This + blocks the client's reader and prevents it from doing the client handshake. + ''' + async def handler(request): + ws = await request.accept() + # Send one message to block the client's reader task: + await ws.send_message('test') + import logging + async with trio.open_nursery() as outer: + server = await outer.start(partial(serve_websocket, handler, HOST, 0, + ssl_context=None, handler_nursery=outer, + disconnect_timeout=TIMEOUT)) + + old_task_count = len(outer.child_tasks) + # Spawn client inside an inner nursery so that we can cancel it's reader + # so that it won't do a closing handshake. + async with trio.open_nursery() as inner: + ws = await connect_websocket(inner, HOST, server.port, RESOURCE, + use_ssl=False) + # Checkpoint so the server can spawn a handler task: + await trio.sleep(0) + assert len(outer.child_tasks) == old_task_count + 1, \ + "Server's reader task did not spawn" + # The client waits long enough to trigger the server's disconnect + # timeout: + await trio.sleep(FORCE_TIMEOUT) + # The server should have cancelled the handler: + assert len(outer.child_tasks) == old_task_count, \ + "Server's reader task is still running" + # Cancel the client's reader task: + inner.cancel_scope.cancel() + + # Cancel the server task: + outer.cancel_scope.cancel() + + async def test_client_does_not_close_handshake(nursery): async def handler(request): server_ws = await request.accept() diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 16b54de..b05c12a 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -18,13 +18,16 @@ from .version import __version__ + +CONN_TIMEOUT = 30 # default connect & disconnect timeout, in seconds RECEIVE_BYTES = 4096 logger = logging.getLogger('trio-websocket') @asynccontextmanager @async_generator -async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None): +async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None, + connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT): ''' Open a WebSocket client connection to a host. @@ -41,12 +44,21 @@ async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None): :type use_ssl: bool or ssl.SSLContext :param subprotocols: An iterable of strings representing preferred subprotocols. + :param float connect_timeout: The number of seconds to wait for the + connection before timing out. + :param float disconnect_timeout: The number of seconds to wait when closing + the connection before timing out. + :raises trio.TooSlowError: if connecting or disconnecting times out. ''' async with trio.open_nursery() as new_nursery: - connection = await connect_websocket(new_nursery, host, port, resource, - use_ssl=use_ssl, subprotocols=subprotocols) - async with connection: + with trio.fail_after(connect_timeout): + connection = await connect_websocket(new_nursery, host, port, + resource, use_ssl=use_ssl, subprotocols=subprotocols) + try: await yield_(connection) + finally: + with trio.fail_after(disconnect_timeout): + await connection.aclose() async def connect_websocket(nursery, host, port, resource, *, use_ssl, @@ -97,7 +109,8 @@ async def connect_websocket(nursery, host, port, resource, *, use_ssl, return connection -def open_websocket_url(url, ssl_context=None, *, subprotocols=None): +def open_websocket_url(url, ssl_context=None, *, subprotocols=None, + connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT): ''' Open a WebSocket client connection to a URL. @@ -111,6 +124,11 @@ def open_websocket_url(url, ssl_context=None, *, subprotocols=None): :type ssl_context: ssl.SSLContext or None :param subprotocols: An iterable of strings representing preferred subprotocols. + :param float connect_timeout: The number of seconds to wait for the + connection before timing out. + :param float disconnect_timeout: The number of seconds to wait when closing + the connection before timing out. + :raises trio.TooSlowError: if connecting or disconnecting times out. ''' host, port, resource, ssl_context = _url_to_host(url, ssl_context) return open_websocket(host, port, resource, use_ssl=ssl_context, @@ -209,7 +227,8 @@ async def wrap_server_stream(nursery, stream): async def serve_websocket(handler, host, port, ssl_context, *, - handler_nursery=None, task_status=trio.TASK_STATUS_IGNORED): + handler_nursery=None, connect_timeout=CONN_TIMEOUT, + disconnect_timeout=CONN_TIMEOUT, task_status=trio.TASK_STATUS_IGNORED): ''' Serve a WebSocket over TCP. @@ -233,6 +252,10 @@ async def serve_websocket(handler, host, port, ssl_context, *, :type ssl_context: ssl.SSLContext or None :param handler_nursery: An optional nursery to spawn handlers and background tasks in. If not specified, a new nursery will be created internally. + :param float connect_timeout: The number of seconds to wait for a client + to finish connection handshake before timing out. + :param float disconnect_timeout: The number of seconds to wait for a client + to finish the closing handshake before timing out. :param task_status: Part of Trio nursery start protocol. :returns: This function runs until cancelled. ''' @@ -243,7 +266,8 @@ async def serve_websocket(handler, host, port, ssl_context, *, ssl_context, host=host, https_compatible=True) listeners = await open_tcp_listeners() server = WebSocketServer(handler, listeners, - handler_nursery=handler_nursery) + handler_nursery=handler_nursery, connect_timeout=connect_timeout, + disconnect_timeout=disconnect_timeout) await server.run(task_status=task_status) @@ -498,7 +522,7 @@ async def aclose(self, code=1000, reason=None): Close the WebSocket connection. This sends a closing frame and suspends until the connection is closed. - After calling this method, any futher I/O on this WebSocket (such as + After calling this method, any further I/O on this WebSocket (such as ``get_message()`` or ``send_message()``) will raise ``ConnectionClosed``. @@ -512,7 +536,13 @@ async def aclose(self, code=1000, reason=None): # Per AsyncResource interface, calling aclose() on a closed resource # should succeed. return - self._wsproto.close(code=code, reason=reason) + # Wsproto will throw an AttributeError if you close it during the + # handshake phase. This is an open bug: + # https://github.com/python-hyper/wsproto/issues/59 + try: + self._wsproto.close(code=code, reason=reason) + except AttributeError: + pass try: await self._recv_channel.aclose() await self._write_pending() @@ -605,6 +635,11 @@ async def send_message(self, message): self._wsproto.send_data(message) await self._write_pending() + def __str__(self): + ''' Connection ID and type. ''' + type_ = 'client' if self.is_client else 'server' + return '{}-{}'.format(type_, self._id) + async def _abort_web_socket(self): ''' If a stream is closed outside of this class, e.g. due to network @@ -655,7 +690,7 @@ async def _close_web_socket(self, code, reason=None): ''' self._close_reason = CloseReason(code, reason) exc = ConnectionClosed(self._close_reason) - logger.debug('conn#%d websocket closed %r', self._id, exc) + logger.debug('%s websocket closed %r', self, exc) await self._send_channel.aclose() async def _get_request(self): @@ -747,7 +782,7 @@ async def _handle_ping_received_event(self, event): :param event: ''' - logger.debug('conn#%d ping %r', self._id, event.payload) + logger.debug('%s ping %r', self, event.payload) await self._write_pending() async def _handle_pong_received_event(self, event): @@ -774,7 +809,7 @@ async def _handle_pong_received_event(self, event): while self._pings: key, event = self._pings.popitem(0) skipped = ' [skipped] ' if payload != key else ' ' - logger.debug('conn#%d pong%s%r', self._id, skipped, key) + logger.debug('%s pong%s%r', self, skipped, key) event.set() if payload == key: break @@ -802,11 +837,11 @@ async def _reader_task(self): event_type = type(event).__name__ try: handler = handlers[event_type] - logger.debug('conn#%d received event: %s', self._id, + logger.debug('%s received event: %s', self, event_type) await handler(event) except KeyError: - logger.warning('Received unknown event type: "%s"', + logger.warning('%s received unknown event type: "%s"', self, event_type) # Get network data. @@ -816,18 +851,18 @@ async def _reader_task(self): await self._abort_web_socket() break if len(data) == 0: - logger.debug('conn#%d received zero bytes (connection closed)', - self._id) + logger.debug('%s received zero bytes (connection closed)', + self) # If TCP closed before WebSocket, then record it as an abnormal # closure. if not self._wsproto.closed: await self._abort_web_socket() break else: - logger.debug('conn#%d received %d bytes', self._id, len(data)) + logger.debug('%s received %d bytes', self, len(data)) self._wsproto.receive_bytes(data) - logger.debug('conn#%d reader task finished', self._id) + logger.debug('%s reader task finished', self) async def _write_pending(self): ''' Write any pending protocol data to the network socket. ''' @@ -836,14 +871,14 @@ async def _write_pending(self): # The reader task and one or more writers might try to send messages # at the same time, so we need to synchronize access to this stream. async with self._stream_lock: - logger.debug('conn#%d sending %d bytes', self._id, len(data)) + logger.debug('%s sending %d bytes', self, len(data)) try: await self._stream.send_all(data) except (trio.BrokenResourceError, trio.ClosedResourceError): await self._abort_web_socket() raise ConnectionClosed(self._close_reason) from None else: - logger.debug('conn#%d no pending data to send', self._id) + logger.debug('%s no pending data to send', self) class ListenPort: @@ -871,7 +906,8 @@ class WebSocketServer: instance and starts some background tasks, ''' - def __init__(self, handler, listeners, *, handler_nursery=None): + def __init__(self, handler, listeners, *, handler_nursery=None, + connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT): ''' Constructor. @@ -887,12 +923,18 @@ def __init__(self, handler, listeners, *, handler_nursery=None): :param handler_nursery: An optional nursery to spawn connection tasks inside of. If ``None``, then a new nursery will be created internally. + :param float connect_timeout: The number of seconds to wait for a client + to finish connection handshake before timing out. + :param float disconnect_timeout: The number of seconds to wait for a client + to finish the closing handshake before timing out. ''' if len(listeners) == 0: raise ValueError('Listeners must contain at least one item.') self._handler = handler self._handler_nursery = handler_nursery self._listeners = listeners + self._connect_timeout = connect_timeout + self._disconnect_timeout = disconnect_timeout @property def port(self): @@ -973,6 +1015,16 @@ async def _handle_connection(self, stream): wsproto = wsconnection.WSConnection(wsconnection.SERVER) connection = WebSocketConnection(stream, wsproto) nursery.start_soon(connection._reader_task) - async with connection: + with trio.move_on_after(self._connect_timeout) as connect_scope: request = await connection._get_request() + if connect_scope.cancelled_caught: + nursery.cancel_scope.cancel() + await stream.aclose() + return + try: await self._handler(request) + finally: + with trio.move_on_after(self._disconnect_timeout): + # aclose() will shut down the reader task even if its + # cancelled: + await connection.aclose()