Skip to content

Commit

Permalink
🐛 minor fixes in DoH resolver and WS extension (#155)
Browse files Browse the repository at this point in the history
DoH resolver: no multiplex to multiplex (in single request; upgrade)

WS extension: remote send close but extension show still open
  • Loading branch information
Ousret authored Oct 8, 2024
1 parent f60b80a commit 3863dfc
Show file tree
Hide file tree
Showing 12 changed files with 128 additions and 27 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ jobs:

- name: "Traefik: Prerequisites - Colima (MacOS)"
if: ${{ matrix.traefik-server && contains(matrix.os, 'mac') }}
run: ./traefik/macos.sh
uses: douglascamata/setup-docker-macos-action@8d5fa43892aed7eee4effcdea113fd53e4d4bf83
with:
colima-network-address: true

- name: "Setup Python ${{ matrix.python-version }}"
uses: "actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3"
Expand Down
9 changes: 9 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
2.10.901 (2024-10-08)
=====================

- Fixed closed state on a WebSocketExtensionFromHTTP when the remote send a CloseConnection event.
- Fixed an edge case where a DNS-over-HTTPS would start of a non-multiplexed connection but immediately upgrade to a
multiplexed capable connection would induce an error.
- Allow to disable HTTP/1.1 in a DNS-over-HTTPS resolver.
- Extra "qh3" lower bound aligned with the main constraint ``>=1.2,<2``.

2.10.900 (2024-10-06)
=====================

Expand Down
9 changes: 8 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import time
import typing
from http.client import RemoteDisconnected
from socket import timeout as SocketTimeout
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen

Expand Down Expand Up @@ -171,7 +172,13 @@ def traefik_boot(session: nox.Session) -> typing.Generator[None, None, None]:
),
timeout=1.0,
)
except (HTTPError, URLError, RemoteDisconnected, TimeoutError) as e:
except (
HTTPError,
URLError,
RemoteDisconnected,
TimeoutError,
SocketTimeout,
) as e:
i += 1
time.sleep(1)
session.log(f"Waiting for the Traefik server: {e}...")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ socks = [
"python-socks>=2.0,<3.0",
]
qh3 = [
"qh3>=1.0.3,<2.0.0",
"qh3>=1.2.0,<2.0.0",
]
ws = [
"wsproto>=1.2,<2",
Expand Down
2 changes: 1 addition & 1 deletion src/urllib3/_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# This file is protected via CODEOWNERS
from __future__ import annotations

__version__ = "2.10.900"
__version__ = "2.10.901"
21 changes: 20 additions & 1 deletion src/urllib3/contrib/resolver/_async/doh/_urllib3.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ def __init__(
for svn in kwargs["disabled_svn"]:
svn = svn.lower()

if svn == "h2":
if svn == "h11":
disabled_svn.add(HttpVersion.h11)
elif svn == "h2":
disabled_svn.add(HttpVersion.h2)
elif svn == "h3":
disabled_svn.add(HttpVersion.h3)
Expand Down Expand Up @@ -308,6 +310,23 @@ async def getaddrinfo( # type: ignore[override]

no_multiplexing: bool = isinstance(promises[0], AsyncHTTPResponse)

# This edge case can happen when the initial request is emitted through HTTP/1.1
# and a connection upgrade happen just after that (no multiplexing to multiplexing...)
if (
no_multiplexing
and len(promises) > 1
and isinstance(promises[1], AsyncHTTPResponse) is False
):
force_resolve = []

for promise in promises:
if isinstance(promise, AsyncHTTPResponse):
force_resolve.append(promise)
continue
force_resolve.append(await self._pool.get_response(promise=promise)) # type: ignore[arg-type]

promises = force_resolve # type: ignore[assignment]

results: list[
tuple[
socket.AddressFamily,
Expand Down
29 changes: 24 additions & 5 deletions src/urllib3/contrib/resolver/doh/_urllib3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections import deque

from ...._collections import HTTPHeaderDict
from ....backend import ConnectionInfo, HttpVersion
from ....backend import ConnectionInfo, HttpVersion, ResponsePromise
from ....connectionpool import HTTPSConnectionPool
from ....response import HTTPResponse
from ....util.url import parse_url
Expand Down Expand Up @@ -117,7 +117,9 @@ def __init__(
for svn in kwargs["disabled_svn"]:
svn = svn.lower()

if svn == "h2":
if svn == "h11":
disabled_svn.add(HttpVersion.h11)
elif svn == "h2":
disabled_svn.add(HttpVersion.h2)
elif svn == "h3":
disabled_svn.add(HttpVersion.h3)
Expand Down Expand Up @@ -208,7 +210,7 @@ def getaddrinfo(

validate_length_of(host)

promises = []
promises: list[HTTPResponse | ResponsePromise] = []
remote_preemptive_quic_rr = False

if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
Expand Down Expand Up @@ -308,6 +310,23 @@ def getaddrinfo(

no_multiplexing: bool = isinstance(promises[0], HTTPResponse)

# This edge case can happen when the initial request is emitted through HTTP/1.1
# and a connection upgrade happen just after that (no multiplexing to multiplexing...)
if (
no_multiplexing
and len(promises) > 1
and isinstance(promises[1], HTTPResponse) is False
):
force_resolve: list[HTTPResponse] = []

for promise in promises:
if isinstance(promise, HTTPResponse):
force_resolve.append(promise)
continue
force_resolve.append(self._pool.get_response(promise=promise)) # type: ignore[arg-type]

promises = force_resolve # type: ignore[assignment]

results: list[
tuple[
socket.AddressFamily,
Expand All @@ -326,7 +345,7 @@ def getaddrinfo(
if self._unconsumed:
for unconsumed in self._unconsumed:
for pending_promise in promises:
if unconsumed.is_from_promise(pending_promise):
if unconsumed.is_from_promise(pending_promise): # type: ignore[arg-type]
response = unconsumed
break
if response:
Expand All @@ -345,7 +364,7 @@ def getaddrinfo(
p = None

for p in promises:
if response.is_from_promise(p):
if response.is_from_promise(p): # type: ignore[arg-type]
break

if p is None:
Expand Down
10 changes: 8 additions & 2 deletions src/urllib3/contrib/webextensions/_async/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(self) -> None:
super().__init__()
self._protocol = WSConnection(ConnectionType.CLIENT)
self._request_headers: dict[str, str] | None = None
self._remote_shutdown: bool = False

@staticmethod
def supported_svn() -> set[HttpVersion]:
Expand Down Expand Up @@ -89,8 +90,9 @@ def headers(self, http_version: HttpVersion) -> dict[str, str]:
async def close(self) -> None:
"""End/Notify close for sub protocol."""
if self._dsa is not None:
data_to_send: bytes = self._protocol.send(CloseConnection(0))
await self._dsa.sendall(data_to_send)
if self._remote_shutdown is False:
data_to_send: bytes = self._protocol.send(CloseConnection(0))
await self._dsa.sendall(data_to_send)
await self._dsa.close()
self._dsa = None
if self._response is not None:
Expand All @@ -115,6 +117,8 @@ async def next_payload(self) -> str | bytes | None:
elif isinstance(event, BytesMessage):
return event.data
elif isinstance(event, CloseConnection):
self._remote_shutdown = True
await self.close()
return None

while True:
Expand All @@ -129,6 +133,8 @@ async def next_payload(self) -> str | bytes | None:
elif isinstance(event, BytesMessage):
return event.data
elif isinstance(event, CloseConnection):
self._remote_shutdown = True
await self.close()
return None
elif isinstance(event, Ping):
data_to_send: bytes = self._protocol.send(Pong())
Expand Down
11 changes: 9 additions & 2 deletions src/urllib3/contrib/webextensions/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(self) -> None:
super().__init__()
self._protocol = WSConnection(ConnectionType.CLIENT)
self._request_headers: dict[str, str] | None = None
self._remote_shutdown: bool = False

@staticmethod
def supported_svn() -> set[HttpVersion]:
Expand Down Expand Up @@ -89,8 +90,9 @@ def headers(self, http_version: HttpVersion) -> dict[str, str]:
def close(self) -> None:
"""End/Notify close for sub protocol."""
if self._dsa is not None:
data_to_send: bytes = self._protocol.send(CloseConnection(0))
self._dsa.sendall(data_to_send)
if self._remote_shutdown is False:
data_to_send: bytes = self._protocol.send(CloseConnection(0))
self._dsa.sendall(data_to_send)
self._dsa.close()
self._dsa = None
if self._response is not None:
Expand All @@ -109,12 +111,15 @@ def next_payload(self) -> str | bytes | None:
if self._dsa is None or self._response is None or self._police_officer is None:
raise OSError("The HTTP extension is closed or uninitialized")

# we may have pending event to unpack!
for event in self._protocol.events():
if isinstance(event, TextMessage):
return event.data
elif isinstance(event, BytesMessage):
return event.data
elif isinstance(event, CloseConnection):
self._remote_shutdown = True
self.close()
return None

while True:
Expand All @@ -129,6 +134,8 @@ def next_payload(self) -> str | bytes | None:
elif isinstance(event, BytesMessage):
return event.data
elif isinstance(event, CloseConnection):
self._remote_shutdown = True
self.close()
return None
elif isinstance(event, Ping):
data_to_send: bytes = self._protocol.send(Pong())
Expand Down
23 changes: 23 additions & 0 deletions test/contrib/asynchronous/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,29 @@ async def test_doh_http11() -> None:
assert len(res)


@requires_network()
@pytest.mark.asyncio
@pytest.mark.xfail(
os.environ.get("CI", None) is not None and platform.system() != "Darwin",
reason="Github Action CI: Network Unreachable UDP/QUIC",
strict=False,
)
async def test_doh_http11_upgradable() -> None:
"""Ensure we can do DoH over HTTP/1.1 that can upgrade to HTTP/3"""
resolver = AsyncResolverDescription.from_url(
"doh+google://default/?disabled_svn=h2"
).new()

res = await resolver.getaddrinfo(
"www.cloudflare.com",
80,
socket.AF_UNSPEC,
socket.SOCK_STREAM,
)

assert len(res)


@requires_network()
@pytest.mark.asyncio
async def test_doh_on_connection_callback() -> None:
Expand Down
22 changes: 22 additions & 0 deletions test/contrib/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,28 @@ def test_doh_http11() -> None:
assert len(res)


@requires_network()
@pytest.mark.xfail(
os.environ.get("CI", None) is not None and platform.system() != "Darwin",
reason="Github Action CI: Network Unreachable UDP/QUIC",
strict=False,
)
def test_doh_http11_upgradable() -> None:
"""Ensure we can do DoH over HTTP/1.1 that can upgrade to HTTP/3"""
resolver = ResolverDescription.from_url(
"doh+google://default/?disabled_svn=h2"
).new()

res = resolver.getaddrinfo(
"www.cloudflare.com",
80,
socket.AF_UNSPEC,
socket.SOCK_STREAM,
)

assert len(res)


@requires_network()
def test_doh_on_connection_callback() -> None:
"""Ensure we can inspect the resolver connection with a callback."""
Expand Down
13 changes: 0 additions & 13 deletions traefik/macos.sh

This file was deleted.

0 comments on commit 3863dfc

Please sign in to comment.