diff --git a/README.rst b/README.rst index 01c337e7..f90d59bd 100644 --- a/README.rst +++ b/README.rst @@ -62,6 +62,7 @@ Features * connection migration and NAT rebinding * logging TLS traffic secrets * logging QUIC events in QLOG format + * version negotiation conforming with `RFC 9368`_ - HTTP/3 stack conforming with `RFC 9114`_ * server push support * WebSocket bootstrapping conforming with `RFC 9220`_ @@ -156,4 +157,5 @@ License .. _RFC 9114: https://datatracker.ietf.org/doc/html/rfc9114 .. _RFC 9220: https://datatracker.ietf.org/doc/html/rfc9220 .. _RFC 9297: https://datatracker.ietf.org/doc/html/rfc9297 +.. _RFC 9368: https://datatracker.ietf.org/doc/html/rfc9368 .. _RFC 9369: https://datatracker.ietf.org/doc/html/rfc9369 diff --git a/src/aioquic/quic/configuration.py b/src/aioquic/quic/configuration.py index 31ec779d..5dc9612d 100644 --- a/src/aioquic/quic/configuration.py +++ b/src/aioquic/quic/configuration.py @@ -100,6 +100,7 @@ class QuicConfiguration: .. note:: This is only used by clients. """ + # For internal purposes, not guaranteed to be stable. cadata: Optional[bytes] = None cafile: Optional[str] = None capath: Optional[str] = None @@ -108,11 +109,13 @@ class QuicConfiguration: cipher_suites: Optional[List[CipherSuite]] = None initial_rtt: float = 0.1 max_datagram_frame_size: Optional[int] = None + original_version: Optional[int] = None private_key: Any = None quantum_readiness_test: bool = False supported_versions: List[int] = field( default_factory=lambda: [ QuicProtocolVersion.VERSION_1, + QuicProtocolVersion.VERSION_2, ] ) verify_mode: Optional[int] = None diff --git a/src/aioquic/quic/connection.py b/src/aioquic/quic/connection.py index b3356aac..92aa0376 100644 --- a/src/aioquic/quic/connection.py +++ b/src/aioquic/quic/connection.py @@ -29,7 +29,7 @@ from . import events from .configuration import SMALLEST_MAX_DATAGRAM_SIZE, QuicConfiguration from .congestion.base import K_GRANULARITY -from .crypto import CryptoError, CryptoPair, KeyUnavailableError +from .crypto import CryptoError, CryptoPair, KeyUnavailableError, NoCallback from .logger import QuicLoggerTrace from .packet import ( CONNECTION_ID_MAX_SIZE, @@ -41,8 +41,10 @@ QuicFrameType, QuicHeader, QuicPacketType, + QuicProtocolVersion, QuicStreamFrame, QuicTransportParameters, + QuicVersionInformation, get_retry_integrity_tag, get_spin_bit, pretty_protocol_version, @@ -113,6 +115,18 @@ def EPOCHS(shortcut: str) -> FrozenSet[tls.Epoch]: return frozenset(EPOCH_SHORTCUTS[i] for i in shortcut) +def is_version_compatible(from_version: int, to_version: int) -> bool: + """ + Return whether it is possible to perform compatible version negotiation + from `from_version` to `to_version`. + """ + # Version 1 is compatible with version 2 and vice versa. These are the + # only compatible versions so far. + return set([from_version, to_version]) == set( + [QuicProtocolVersion.VERSION_1, QuicProtocolVersion.VERSION_2] + ) + + def dump_cid(cid: bytes) -> str: return binascii.hexlify(cid).decode("ascii") @@ -205,6 +219,7 @@ class QuicReceiveContext: network_path: QuicNetworkPath quic_logger_frames: Optional[List[Any]] time: float + version: Optional[int] QuicTokenHandler = Callable[[bytes], None] @@ -278,8 +293,10 @@ def __init__( self._close_event: Optional[events.ConnectionTerminated] = None self._connect_called = False self._cryptos: Dict[tls.Epoch, CryptoPair] = {} + self._cryptos_initial: Dict[int, CryptoPair] = {} self._crypto_buffers: Dict[tls.Epoch, Buffer] = {} self._crypto_frame_type: Optional[int] = None + self._crypto_packet_version: Optional[int] = None self._crypto_retransmitted = False self._crypto_streams: Dict[tls.Epoch, QuicStream] = {} self._events: Deque[events.QuicEvent] = deque() @@ -342,6 +359,7 @@ def __init__( self._remote_max_stream_data_uni = 0 self._remote_max_streams_bidi = 0 self._remote_max_streams_uni = 0 + self._remote_version_information: Optional[QuicVersionInformation] = None self._retry_count = 0 self._retry_source_connection_id = retry_source_connection_id self._spaces: Dict[tls.Epoch, QuicPacketSpace] = {} @@ -354,7 +372,8 @@ def __init__( self._streams_blocked_uni: List[QuicStream] = [] self._streams_finished: Set[int] = set() self._version: Optional[int] = None - self._version_negotiation_count = 0 + self._version_negotiated_compatible = False + self._version_negotiated_incompatible = False if self._is_client: self._original_destination_connection_id = self._peer_cid.cid @@ -498,7 +517,10 @@ def connect(self, addr: NetworkAddress, now: float) -> None: self._connect_called = True self._network_paths = [QuicNetworkPath(addr, is_validated=True)] - self._version = self._configuration.supported_versions[0] + if self._configuration.original_version is not None: + self._version = self._configuration.original_version + else: + self._version = self._configuration.supported_versions[0] self._connect(now=now) def datagrams_to_send(self, now: float) -> List[Tuple[bytes, NetworkAddress]]: @@ -860,7 +882,10 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non # determine crypto and packet space epoch = get_epoch(header.packet_type) - crypto = self._cryptos[epoch] + if epoch == tls.Epoch.INITIAL: + crypto = self._cryptos_initial[header.version] + else: + crypto = self._cryptos[epoch] if epoch == tls.Epoch.ZERO_RTT: space = self._spaces[tls.Epoch.ONE_RTT] else: @@ -987,6 +1012,7 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non network_path=network_path, quic_logger_frames=quic_logger_frames, time=now, + version=header.version, ) try: is_ack_eliciting, is_probing = self._payload_received( @@ -1123,7 +1149,8 @@ def stop_stream(self, stream_id: int, error_code: int) -> None: def _alpn_handler(self, alpn_protocol: str) -> None: """ - Callback which is invoked by the TLS engine when ALPN negotiation completes. + Callback which is invoked by the TLS engine at most once, when the + ALPN negotiation completes. At this point, TLS extensions have been received so we can parse the transport parameters. @@ -1141,6 +1168,30 @@ def _alpn_handler(self, alpn_protocol: str) -> None: reason_phrase="No QUIC transport parameters received", ) + # For servers, determine the Negotiated Version. + if not self._is_client: + if self._remote_version_information is not None: + # Pick the first version we support in the client's available versions, + # which is compatible with the current version. + for version in self._remote_version_information.available_versions: + if version == self._version: + # Stay with the current version. + break + elif ( + version in self._configuration.supported_versions + and is_version_compatible(self._version, version) + ): + # Change version. + self._version = version + self._cryptos[tls.Epoch.INITIAL] = self._cryptos_initial[ + version + ] + break + self._version_negotiated_compatible = True + self._logger.info( + "Negotiated protocol version %s", pretty_protocol_version(self._version) + ) + # Notify the application. self._events.append(events.ProtocolNegotiated(alpn_protocol=alpn_protocol)) @@ -1235,6 +1286,13 @@ def _discard_epoch(self, epoch: tls.Epoch) -> None: if not self._spaces[epoch].discarded: self._logger.debug("Discarding epoch %s", epoch) self._cryptos[epoch].teardown() + if epoch == tls.Epoch.INITIAL: + # Tear the crypto pairs, but do not log the event, + # to avoid duplicate log entries. + for crypto in self._cryptos_initial.values(): + crypto.recv._teardown_cb = NoCallback + crypto.send._teardown_cb = NoCallback + crypto.teardown() self._loss.discard_space(self._spaces[epoch]) self._spaces[epoch].discarded = True @@ -1427,15 +1485,24 @@ def create_crypto_pair(epoch: tls.Epoch) -> CryptoPair: send_teardown_cb=partial(self._log_key_retired, send_secret_name), ) + # To enable version negotiation, setup encryption keys for all + # our supported versions. + self._cryptos_initial = {} + for version in self._configuration.supported_versions: + pair = CryptoPair() + pair.setup_initial(cid=peer_cid, is_client=self._is_client, version=version) + self._cryptos_initial[version] = pair + self._cryptos = dict( (epoch, create_crypto_pair(epoch)) for epoch in ( - tls.Epoch.INITIAL, tls.Epoch.ZERO_RTT, tls.Epoch.HANDSHAKE, tls.Epoch.ONE_RTT, ) ) + self._cryptos[tls.Epoch.INITIAL] = self._cryptos_initial[self._version] + self._crypto_buffers = { tls.Epoch.INITIAL: Buffer(capacity=CRYPTO_BUFFER_SIZE), tls.Epoch.HANDSHAKE: Buffer(capacity=CRYPTO_BUFFER_SIZE), @@ -1451,11 +1518,6 @@ def create_crypto_pair(epoch: tls.Epoch) -> CryptoPair: tls.Epoch.HANDSHAKE: QuicPacketSpace(), tls.Epoch.ONE_RTT: QuicPacketSpace(), } - - self._cryptos[tls.Epoch.INITIAL].setup_initial( - cid=peer_cid, is_client=self._is_client, version=self._version - ) - self._loss.spaces = list(self._spaces.values()) def _handle_ack_frame( @@ -1567,6 +1629,7 @@ def _handle_crypto_frame( # - _alpn_handler # - _update_traffic_key self._crypto_frame_type = frame_type + self._crypto_packet_version = context.version try: self.tls.handle_message(event.data, self._crypto_buffers) self._push_crypto_data() @@ -2476,7 +2539,7 @@ def _receive_version_negotiation_packet( if ( self._is_client and self._state == QuicConnectionState.FIRSTFLIGHT - and not self._version_negotiation_count + and not self._version_negotiated_incompatible ): if self._quic_logger is not None: self._quic_logger.log_event( @@ -2536,7 +2599,7 @@ def _receive_version_negotiation_packet( return self._packet_number = 0 self._version = chosen_version - self._version_negotiation_count += 1 + self._version_negotiated_incompatible = True self._logger.info( "Retrying with protocol version %s", pretty_protocol_version(self._version), @@ -2736,6 +2799,9 @@ def _parse_transport_parameters( self._peer_cid.stateless_reset_token = ( quic_transport_parameters.stateless_reset_token ) + self._remote_version_information = ( + quic_transport_parameters.version_information + ) if quic_transport_parameters.active_connection_id_limit is not None: self._remote_active_connection_id_limit = ( @@ -2780,6 +2846,10 @@ def _serialize_transport_parameters(self) -> bytes: else None ), stateless_reset_token=self._host_cids[0].stateless_reset_token, + version_information=QuicVersionInformation( + chosen_version=self._version, + available_versions=self._configuration.supported_versions, + ), ) if not self._is_client: quic_transport_parameters.original_destination_connection_id = ( @@ -2846,6 +2916,18 @@ def _update_traffic_key( Callback which is invoked by the TLS engine when new traffic keys are available. """ + # For clients, determine the negotiated protocol version. + if ( + self._is_client + and self._crypto_packet_version is not None + and not self._version_negotiated_compatible + ): + self._version = self._crypto_packet_version + self._version_negotiated_compatible = True + self._logger.info( + "Negotiated protocol version %s", pretty_protocol_version(self._version) + ) + secrets_log_file = self._configuration.secrets_log_file if secrets_log_file is not None: label_row = self._is_client == (direction == tls.Direction.DECRYPT) diff --git a/tests/test_connection.py b/tests/test_connection.py index b469f9bd..4b073c17 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -46,6 +46,7 @@ CLIENT_ADDR = ("1.2.3.4", 1234) SERVER_ADDR = ("2.3.4.5", 4433) +SERVER_INITIAL_DATAGRAM_SIZES = [1200, 1200, 986] TICK = 0.05 # seconds @@ -68,6 +69,7 @@ def client_receive_context(client, epoch=tls.Epoch.ONE_RTT): network_path=client._network_paths[0], quic_logger_frames=[], time=time.time(), + version=None, ) @@ -440,7 +442,7 @@ def test_connect_with_loss_1(self): self.assertEqual(datagram_sizes(items), [1200]) self.assertEqual(client.get_timer(), 0.2) - # INITIAL is lost + # INITIAL is lost and retransmitted now = client.get_timer() client.handle_timer(now=now) items = client.datagrams_to_send(now=now) @@ -451,7 +453,7 @@ def test_connect_with_loss_1(self): now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1200, 1200, 972]) + self.assertEqual(datagram_sizes(items), SERVER_INITIAL_DATAGRAM_SIZES) self.assertAlmostEqual(server.get_timer(), 0.45) self.assertEqual(len(server._loss.spaces[0].sent_packets), 1) self.assertEqual(len(server._loss.spaces[1].sent_packets), 2) @@ -523,7 +525,7 @@ def test_connect_with_loss_2(self): now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1200, 1200, 972]) + self.assertEqual(datagram_sizes(items), SERVER_INITIAL_DATAGRAM_SIZES) self.assertEqual(server.get_timer(), 0.25) self.assertEqual(len(server._loss.spaces[0].sent_packets), 1) self.assertEqual(len(server._loss.spaces[1].sent_packets), 2) @@ -544,7 +546,7 @@ def test_connect_with_loss_2(self): now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1200, 1200, 972]) + self.assertEqual(datagram_sizes(items), SERVER_INITIAL_DATAGRAM_SIZES) self.assertAlmostEqual(server.get_timer(), 0.35) self.assertEqual(len(server._loss.spaces[0].sent_packets), 1) self.assertEqual(len(server._loss.spaces[1].sent_packets), 2) @@ -613,7 +615,7 @@ def test_connect_with_loss_3(self): now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1200, 1200, 972]) + self.assertEqual(datagram_sizes(items), SERVER_INITIAL_DATAGRAM_SIZES) self.assertEqual(server.get_timer(), 0.25) self.assertEqual(len(server._loss.spaces[0].sent_packets), 1) self.assertEqual(len(server._loss.spaces[1].sent_packets), 2) @@ -632,7 +634,7 @@ def test_connect_with_loss_3(self): now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1200, 1200, 972]) + self.assertEqual(datagram_sizes(items), SERVER_INITIAL_DATAGRAM_SIZES) self.assertEqual(server.get_timer(), 0.45) self.assertEqual(len(server._loss.spaces[0].sent_packets), 1) self.assertEqual(len(server._loss.spaces[1].sent_packets), 2) @@ -697,7 +699,7 @@ def test_connect_with_loss_4(self): now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1200, 1200, 972]) + self.assertEqual(datagram_sizes(items), SERVER_INITIAL_DATAGRAM_SIZES) self.assertEqual(server.get_timer(), 0.25) self.assertEqual(len(server._loss.spaces[0].sent_packets), 1) self.assertEqual(len(server._loss.spaces[1].sent_packets), 2) @@ -735,7 +737,7 @@ def test_connect_with_loss_4(self): now = server.get_timer() server.handle_timer(now=now) items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1200, 972]) + self.assertEqual(datagram_sizes(items), [1200, 986]) self.assertAlmostEqual(server.get_timer(), 0.65) self.assertEqual(len(server._loss.spaces[0].sent_packets), 0) self.assertEqual(len(server._loss.spaces[1].sent_packets), 3) @@ -795,7 +797,7 @@ def test_connect_with_loss_5(self): now += TICK server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) items = server.datagrams_to_send(now=now) - self.assertEqual(datagram_sizes(items), [1200, 1200, 972]) + self.assertEqual(datagram_sizes(items), SERVER_INITIAL_DATAGRAM_SIZES) self.assertEqual(server.get_timer(), 0.25) self.assertEqual(len(server._loss.spaces[0].sent_packets), 1) self.assertEqual(len(server._loss.spaces[1].sent_packets), 2) @@ -913,6 +915,78 @@ def patched_initialize(peer_cid: bytes): "No QUIC transport parameters received", ) + def test_connect_with_compatible_version_negotiation_1(self): + """ + The client only supports version 1. + + The server sets the Negotiated Version to version 1. + """ + with client_and_server( + client_options={ + "supported_versions": [QuicProtocolVersion.VERSION_1], + }, + ) as (client, server): + # check handshake completed + self.check_handshake(client=client, server=server) + self.assertEqual(client._version, QuicProtocolVersion.VERSION_1) + self.assertEqual(server._version, QuicProtocolVersion.VERSION_1) + + def test_connect_with_compatible_version_negotiation_1_to_2(self): + """ + The client originally connects using version 1 but prefers version 2. + + The server sets the Negotiated Version to version 2. + """ + with client_and_server( + client_options={ + "original_version": QuicProtocolVersion.VERSION_1, + "supported_versions": [ + QuicProtocolVersion.VERSION_2, + QuicProtocolVersion.VERSION_1, + ], + }, + ) as (client, server): + # check handshake completed + self.check_handshake(client=client, server=server) + self.assertEqual(client._version, QuicProtocolVersion.VERSION_2) + self.assertEqual(server._version, QuicProtocolVersion.VERSION_2) + + def test_connect_with_compatible_version_negotiation_2(self): + """ + The client only supports version 2. + + The server sets the Negotiated Version to version 2. + """ + with client_and_server( + client_options={ + "supported_versions": [QuicProtocolVersion.VERSION_2], + }, + ) as (client, server): + # check handshake completed + self.check_handshake(client=client, server=server) + self.assertEqual(client._version, QuicProtocolVersion.VERSION_2) + self.assertEqual(server._version, QuicProtocolVersion.VERSION_2) + + def test_connect_with_compatible_version_negotiation_2_to_1(self): + """ + The client originally connects using version 2 but prefers version 1. + + The server sets the Negotiated Version to version 1. + """ + with client_and_server( + client_options={ + "original_version": QuicProtocolVersion.VERSION_2, + "supported_versions": [ + QuicProtocolVersion.VERSION_1, + QuicProtocolVersion.VERSION_2, + ], + }, + ) as (client, server): + # check handshake completed + self.check_handshake(client=client, server=server) + self.assertEqual(client._version, QuicProtocolVersion.VERSION_1) + self.assertEqual(server._version, QuicProtocolVersion.VERSION_1) + def test_connect_with_quantum_readiness(self): with client_and_server(client_options={"quantum_readiness_test": True}) as ( client,