Skip to content

Commit

Permalink
Implement compatible version negotiation
Browse files Browse the repository at this point in the history
Enable compatible version negotiation according to RFC 9368. Unlike
incompatible negotiation which uses a Version Negotiation packet, this
allows a switch between compatible versions without an additional
roundtrip.

As an example we support switching between version 1 (RFC 9000) and
version 2 (RFC 9369) and vice versa.

On the server side we honour the client's preferences.
  • Loading branch information
jlaine committed Jun 30, 2024
1 parent 88d2ac2 commit bb5a03d
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 22 deletions.
2 changes: 2 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ Features
* connection migration and NAT rebinding
* logging TLS traffic secrets
* logging QUIC events in QLOG format
* version negotiation conforming with `RFC 9368`_
- HTTP/3 stack conforming with `RFC 9114`_
* server push support
* WebSocket bootstrapping conforming with `RFC 9220`_
Expand Down Expand Up @@ -156,4 +157,5 @@ License
.. _RFC 9114: https://datatracker.ietf.org/doc/html/rfc9114
.. _RFC 9220: https://datatracker.ietf.org/doc/html/rfc9220
.. _RFC 9297: https://datatracker.ietf.org/doc/html/rfc9297
.. _RFC 9368: https://datatracker.ietf.org/doc/html/rfc9368
.. _RFC 9369: https://datatracker.ietf.org/doc/html/rfc9369
3 changes: 3 additions & 0 deletions src/aioquic/quic/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class QuicConfiguration:
.. note:: This is only used by clients.
"""

# For internal purposes, not guaranteed to be stable.
cadata: Optional[bytes] = None
cafile: Optional[str] = None
capath: Optional[str] = None
Expand All @@ -108,11 +109,13 @@ class QuicConfiguration:
cipher_suites: Optional[List[CipherSuite]] = None
initial_rtt: float = 0.1
max_datagram_frame_size: Optional[int] = None
original_version: Optional[int] = None
private_key: Any = None
quantum_readiness_test: bool = False
supported_versions: List[int] = field(
default_factory=lambda: [
QuicProtocolVersion.VERSION_1,
QuicProtocolVersion.VERSION_2,
]
)
verify_mode: Optional[int] = None
Expand Down
108 changes: 95 additions & 13 deletions src/aioquic/quic/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from . import events
from .configuration import SMALLEST_MAX_DATAGRAM_SIZE, QuicConfiguration
from .congestion.base import K_GRANULARITY
from .crypto import CryptoError, CryptoPair, KeyUnavailableError
from .crypto import CryptoError, CryptoPair, KeyUnavailableError, NoCallback
from .logger import QuicLoggerTrace
from .packet import (
CONNECTION_ID_MAX_SIZE,
Expand All @@ -41,8 +41,10 @@
QuicFrameType,
QuicHeader,
QuicPacketType,
QuicProtocolVersion,
QuicStreamFrame,
QuicTransportParameters,
QuicVersionInformation,
get_retry_integrity_tag,
get_spin_bit,
pretty_protocol_version,
Expand Down Expand Up @@ -113,6 +115,18 @@ def EPOCHS(shortcut: str) -> FrozenSet[tls.Epoch]:
return frozenset(EPOCH_SHORTCUTS[i] for i in shortcut)


def is_version_compatible(from_version: int, to_version: int) -> bool:
"""
Return whether it is possible to perform compatible version negotiation
from `from_version` to `to_version`.
"""
# Version 1 is compatible with version 2 and vice versa. These are the
# only compatible versions so far.
return set([from_version, to_version]) == set(
[QuicProtocolVersion.VERSION_1, QuicProtocolVersion.VERSION_2]
)


def dump_cid(cid: bytes) -> str:
return binascii.hexlify(cid).decode("ascii")

Expand Down Expand Up @@ -205,6 +219,7 @@ class QuicReceiveContext:
network_path: QuicNetworkPath
quic_logger_frames: Optional[List[Any]]
time: float
version: Optional[int]


QuicTokenHandler = Callable[[bytes], None]
Expand Down Expand Up @@ -278,8 +293,10 @@ def __init__(
self._close_event: Optional[events.ConnectionTerminated] = None
self._connect_called = False
self._cryptos: Dict[tls.Epoch, CryptoPair] = {}
self._cryptos_initial: Dict[int, CryptoPair] = {}
self._crypto_buffers: Dict[tls.Epoch, Buffer] = {}
self._crypto_frame_type: Optional[int] = None
self._crypto_packet_version: Optional[int] = None
self._crypto_retransmitted = False
self._crypto_streams: Dict[tls.Epoch, QuicStream] = {}
self._events: Deque[events.QuicEvent] = deque()
Expand Down Expand Up @@ -342,6 +359,7 @@ def __init__(
self._remote_max_stream_data_uni = 0
self._remote_max_streams_bidi = 0
self._remote_max_streams_uni = 0
self._remote_version_information: Optional[QuicVersionInformation] = None
self._retry_count = 0
self._retry_source_connection_id = retry_source_connection_id
self._spaces: Dict[tls.Epoch, QuicPacketSpace] = {}
Expand All @@ -354,7 +372,8 @@ def __init__(
self._streams_blocked_uni: List[QuicStream] = []
self._streams_finished: Set[int] = set()
self._version: Optional[int] = None
self._version_negotiation_count = 0
self._version_negotiated_compatible = False
self._version_negotiated_incompatible = False

if self._is_client:
self._original_destination_connection_id = self._peer_cid.cid
Expand Down Expand Up @@ -498,7 +517,10 @@ def connect(self, addr: NetworkAddress, now: float) -> None:
self._connect_called = True

self._network_paths = [QuicNetworkPath(addr, is_validated=True)]
self._version = self._configuration.supported_versions[0]
if self._configuration.original_version is not None:
self._version = self._configuration.original_version
else:
self._version = self._configuration.supported_versions[0]
self._connect(now=now)

def datagrams_to_send(self, now: float) -> List[Tuple[bytes, NetworkAddress]]:
Expand Down Expand Up @@ -860,7 +882,10 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non

# determine crypto and packet space
epoch = get_epoch(header.packet_type)
crypto = self._cryptos[epoch]
if epoch == tls.Epoch.INITIAL:
crypto = self._cryptos_initial[header.version]
else:
crypto = self._cryptos[epoch]
if epoch == tls.Epoch.ZERO_RTT:
space = self._spaces[tls.Epoch.ONE_RTT]
else:
Expand Down Expand Up @@ -987,6 +1012,7 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
network_path=network_path,
quic_logger_frames=quic_logger_frames,
time=now,
version=header.version,
)
try:
is_ack_eliciting, is_probing = self._payload_received(
Expand Down Expand Up @@ -1123,7 +1149,8 @@ def stop_stream(self, stream_id: int, error_code: int) -> None:

def _alpn_handler(self, alpn_protocol: str) -> None:
"""
Callback which is invoked by the TLS engine when ALPN negotiation completes.
Callback which is invoked by the TLS engine at most once, when the
ALPN negotiation completes.
At this point, TLS extensions have been received so we can parse the
transport parameters.
Expand All @@ -1141,6 +1168,30 @@ def _alpn_handler(self, alpn_protocol: str) -> None:
reason_phrase="No QUIC transport parameters received",
)

# For servers, determine the Negotiated Version.
if not self._is_client:
if self._remote_version_information is not None:
# Pick the first version we support in the client's available versions,
# which is compatible with the current version.
for version in self._remote_version_information.available_versions:
if version == self._version:
# Stay with the current version.
break
elif (
version in self._configuration.supported_versions
and is_version_compatible(self._version, version)
):
# Change version.
self._version = version
self._cryptos[tls.Epoch.INITIAL] = self._cryptos_initial[
version
]
break
self._version_negotiated_compatible = True
self._logger.info(
"Negotiated protocol version %s", pretty_protocol_version(self._version)
)

# Notify the application.
self._events.append(events.ProtocolNegotiated(alpn_protocol=alpn_protocol))

Expand Down Expand Up @@ -1235,6 +1286,13 @@ def _discard_epoch(self, epoch: tls.Epoch) -> None:
if not self._spaces[epoch].discarded:
self._logger.debug("Discarding epoch %s", epoch)
self._cryptos[epoch].teardown()
if epoch == tls.Epoch.INITIAL:
# Tear the crypto pairs, but do not log the event,
# to avoid duplicate log entries.
for crypto in self._cryptos_initial.values():
crypto.recv._teardown_cb = NoCallback
crypto.send._teardown_cb = NoCallback
crypto.teardown()
self._loss.discard_space(self._spaces[epoch])
self._spaces[epoch].discarded = True

Expand Down Expand Up @@ -1427,15 +1485,24 @@ def create_crypto_pair(epoch: tls.Epoch) -> CryptoPair:
send_teardown_cb=partial(self._log_key_retired, send_secret_name),
)

# To enable version negotiation, setup encryption keys for all
# our supported versions.
self._cryptos_initial = {}
for version in self._configuration.supported_versions:
pair = CryptoPair()
pair.setup_initial(cid=peer_cid, is_client=self._is_client, version=version)
self._cryptos_initial[version] = pair

self._cryptos = dict(
(epoch, create_crypto_pair(epoch))
for epoch in (
tls.Epoch.INITIAL,
tls.Epoch.ZERO_RTT,
tls.Epoch.HANDSHAKE,
tls.Epoch.ONE_RTT,
)
)
self._cryptos[tls.Epoch.INITIAL] = self._cryptos_initial[self._version]

self._crypto_buffers = {
tls.Epoch.INITIAL: Buffer(capacity=CRYPTO_BUFFER_SIZE),
tls.Epoch.HANDSHAKE: Buffer(capacity=CRYPTO_BUFFER_SIZE),
Expand All @@ -1451,11 +1518,6 @@ def create_crypto_pair(epoch: tls.Epoch) -> CryptoPair:
tls.Epoch.HANDSHAKE: QuicPacketSpace(),
tls.Epoch.ONE_RTT: QuicPacketSpace(),
}

self._cryptos[tls.Epoch.INITIAL].setup_initial(
cid=peer_cid, is_client=self._is_client, version=self._version
)

self._loss.spaces = list(self._spaces.values())

def _handle_ack_frame(
Expand Down Expand Up @@ -1567,6 +1629,7 @@ def _handle_crypto_frame(
# - _alpn_handler
# - _update_traffic_key
self._crypto_frame_type = frame_type
self._crypto_packet_version = context.version
try:
self.tls.handle_message(event.data, self._crypto_buffers)
self._push_crypto_data()
Expand Down Expand Up @@ -2476,7 +2539,7 @@ def _receive_version_negotiation_packet(
if (
self._is_client
and self._state == QuicConnectionState.FIRSTFLIGHT
and not self._version_negotiation_count
and not self._version_negotiated_incompatible
):
if self._quic_logger is not None:
self._quic_logger.log_event(
Expand Down Expand Up @@ -2536,7 +2599,7 @@ def _receive_version_negotiation_packet(
return
self._packet_number = 0
self._version = chosen_version
self._version_negotiation_count += 1
self._version_negotiated_incompatible = True
self._logger.info(
"Retrying with protocol version %s",
pretty_protocol_version(self._version),
Expand Down Expand Up @@ -2736,6 +2799,9 @@ def _parse_transport_parameters(
self._peer_cid.stateless_reset_token = (
quic_transport_parameters.stateless_reset_token
)
self._remote_version_information = (
quic_transport_parameters.version_information
)

if quic_transport_parameters.active_connection_id_limit is not None:
self._remote_active_connection_id_limit = (
Expand Down Expand Up @@ -2780,6 +2846,10 @@ def _serialize_transport_parameters(self) -> bytes:
else None
),
stateless_reset_token=self._host_cids[0].stateless_reset_token,
version_information=QuicVersionInformation(
chosen_version=self._version,
available_versions=self._configuration.supported_versions,
),
)
if not self._is_client:
quic_transport_parameters.original_destination_connection_id = (
Expand Down Expand Up @@ -2846,6 +2916,18 @@ def _update_traffic_key(
Callback which is invoked by the TLS engine when new traffic keys are
available.
"""
# For clients, determine the negotiated protocol version.
if (
self._is_client
and self._crypto_packet_version is not None
and not self._version_negotiated_compatible
):
self._version = self._crypto_packet_version
self._version_negotiated_compatible = True
self._logger.info(
"Negotiated protocol version %s", pretty_protocol_version(self._version)
)

secrets_log_file = self._configuration.secrets_log_file
if secrets_log_file is not None:
label_row = self._is_client == (direction == tls.Direction.DECRYPT)
Expand Down
Loading

0 comments on commit bb5a03d

Please sign in to comment.