Skip to content

Commit

Permalink
🐛 fix async connection shutdown in HTTP/1.1 and HTTP/2 leaving a `asy…
Browse files Browse the repository at this point in the history
…ncio.TransportSocket` and `_SelectorSocketTransport` partially closed
  • Loading branch information
Ousret committed Nov 3, 2024
1 parent faaa80f commit 165e21e
Show file tree
Hide file tree
Showing 28 changed files with 293 additions and 55 deletions.
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
2.11.908 (2024-11-03)
=====================

- Fixed async connection shutdown in HTTP/1.1 and HTTP/2 leaving a ``asyncio.TransportSocket`` and ``_SelectorSocketTransport`` partially closed.

2.11.907 (2024-10-30)
=====================

Expand Down
22 changes: 20 additions & 2 deletions dummyserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,32 @@ def ssl_options_to_context( # type: ignore[no-untyped-def]
alpn_protocols=None,
) -> ssl.SSLContext:
"""Return an equivalent SSLContext based on ssl.wrap_socket args."""
ssl_version = resolve_ssl_version(ssl_version)
_major, _minor, _patch = ssl.OPENSSL_VERSION_INFO[:3]
is_broken_old_ssl: bool = (
"OpenSSL" in ssl.OPENSSL_VERSION and _major < 1 or (_major == 1 and _minor < 1)
)

ssl_version = resolve_ssl_version( # type: ignore[call-overload]
ssl_version, mitigate_tls_version=not is_broken_old_ssl
)
cert_none = resolve_cert_reqs("CERT_NONE")
if cert_reqs is None:
cert_reqs = cert_none
else:
cert_reqs = resolve_cert_reqs(cert_reqs)

ctx = ssl.SSLContext(ssl_version)
if (
hasattr(ssl.SSLContext, "minimum_version")
and hasattr(ssl, "PROTOCOL_TLS_SERVER")
and is_broken_old_ssl is False
):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
if hasattr(ssl, "TLSVersion") and isinstance(ssl_version, ssl.TLSVersion):
ctx.minimum_version = ssl_version
ctx.maximum_version = ssl_version
else:
ctx = ssl.SSLContext(ssl_version) # type: ignore[arg-type]

ctx.load_cert_chain(certfile, keyfile)
ctx.verify_mode = cert_reqs
if ctx.verify_mode != cert_none:
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,12 @@ filterwarnings = [
"error",
'''default:No IPv6 support. Falling back to IPv4:urllib3.exceptions.HTTPWarning''',
'''default:No IPv6 support. skipping:urllib3.exceptions.HTTPWarning''',
'''default:ssl\.TLSVersion\.TLSv1 is deprecated:DeprecationWarning''',
'''default:ssl\.PROTOCOL_TLS is deprecated:DeprecationWarning''',
'''default:ssl\.PROTOCOL_TLSv1 is deprecated:DeprecationWarning''',
'''default:ssl\.TLSVersion\.TLSv1_1 is deprecated:DeprecationWarning''',
'''default:ssl\.PROTOCOL_TLSv1_1 is deprecated:DeprecationWarning''',
'''default:ssl\.PROTOCOL_TLSv1_2 is deprecated:DeprecationWarning''',
'''default:unclosed .*:ResourceWarning''',
'''default:loop is closed:ResourceWarning''',
'''default:ssl NPN is deprecated, use ALPN instead:DeprecationWarning''',
# https://github.com/pytest-dev/pytest/issues/10977
'''default:ast\.(Num|NameConstant|Str) is deprecated and will be removed in Python 3\.14; use ast\.Constant instead:DeprecationWarning:_pytest''',
Expand All @@ -129,6 +127,9 @@ filterwarnings = [
'''ignore:Exception in thread:pytest.PytestUnhandledThreadExceptionWarning''',
'''ignore:function _SSLProtocolTransport\.__del__:pytest.PytestUnraisableExceptionWarning''',
'''ignore:The `hash` argument is deprecated in favor of `unsafe_hash`:DeprecationWarning''',
'''ignore:ssl\.TLSVersion\.TLSv1 is deprecated:DeprecationWarning''',
'''ignore:ssl\.TLSVersion\.TLSv1_1 is deprecated:DeprecationWarning''',
'''ignore:loop is closed:ResourceWarning''',
]

[tool.isort]
Expand Down
2 changes: 1 addition & 1 deletion src/urllib3/_async/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,7 @@ async def _ssl_wrap_socket_and_match_hostname(
"""
default_ssl_context = False
sharable_ext_options: dict[str, int | str | None] = {
"ssl_version": resolve_ssl_version(ssl_version),
"ssl_version": resolve_ssl_version(ssl_version, mitigate_tls_version=True),
"ssl_minimum_version": ssl_minimum_version,
"ssl_maximum_version": ssl_maximum_version,
"cert_reqs": resolve_cert_reqs(cert_reqs),
Expand Down
8 changes: 7 additions & 1 deletion src/urllib3/backend/_async/hface.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,10 @@ async def _new_conn(self) -> AsyncSocket | None: # type: ignore[override]
# if conn target another host.
if self._response and self._response.authority != self.host:
self._svn = None
await self._new_conn() # restore socket defaults
self._response = None
if self.blocksize == UDP_DEFAULT_BLOCKSIZE:
self.blocksize = DEFAULT_BLOCKSIZE
self.socket_kind = SOCK_STREAM
else:
if self.blocksize == UDP_DEFAULT_BLOCKSIZE:
self.blocksize = DEFAULT_BLOCKSIZE
Expand Down Expand Up @@ -1644,6 +1647,9 @@ async def close(self) -> None: # type: ignore[override]

try:
self.sock.close()
# this avoids having SelectorSocketTransport in "closing" state
# pending. Thus avoid a ResourceWarning.
await self.sock.wait_for_close()
except OSError:
pass

Expand Down
5 changes: 4 additions & 1 deletion src/urllib3/backend/hface.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,10 @@ def _new_conn(self) -> socket.socket | None:
# if conn target another host.
if self._response and self._response.authority != self.host:
self._svn = None
self._new_conn() # restore socket defaults
self._response = None # type: ignore[assignment]
if self.blocksize == UDP_DEFAULT_BLOCKSIZE:
self.blocksize = DEFAULT_BLOCKSIZE
self.socket_kind = SOCK_STREAM
else:
if self.blocksize == UDP_DEFAULT_BLOCKSIZE:
self.blocksize = DEFAULT_BLOCKSIZE
Expand Down
2 changes: 1 addition & 1 deletion src/urllib3/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,7 @@ def _ssl_wrap_socket_and_match_hostname(
"""
default_ssl_context = False
sharable_ext_options: dict[str, int | str | None] = {
"ssl_version": resolve_ssl_version(ssl_version),
"ssl_version": resolve_ssl_version(ssl_version, mitigate_tls_version=True),
"ssl_minimum_version": ssl_minimum_version,
"ssl_maximum_version": ssl_maximum_version,
"cert_reqs": resolve_cert_reqs(cert_reqs),
Expand Down
1 change: 1 addition & 0 deletions src/urllib3/contrib/resolver/_async/doq/_qh3.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ async def close(self) -> None: # type: ignore[override]
await self._socket.sendall(data)

self._socket.close()
await self._socket.wait_for_close()
self._terminated = True
if self._socket.should_connect():
self._terminated = True
Expand Down
1 change: 1 addition & 0 deletions src/urllib3/contrib/resolver/_async/dou/_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ async def close(self) -> None: # type: ignore[override]
if not self._terminated:
with self._lock:
self._socket.close()
await self._socket.wait_for_close()
self._terminated = True

def is_available(self) -> bool:
Expand Down
7 changes: 6 additions & 1 deletion src/urllib3/contrib/resolver/_async/protocols.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import socket
import typing
from abc import ABCMeta, abstractmethod
Expand Down Expand Up @@ -131,7 +132,11 @@ async def create_connection( # type: ignore[override]
if source_address:
sock.bind(source_address)

await sock.connect(sa)
try:
await sock.connect(sa)
except asyncio.CancelledError:
sock.close()
raise

# Break explicitly a reference cycle
err = None
Expand Down
5 changes: 4 additions & 1 deletion src/urllib3/contrib/resolver/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,10 @@ def create_connection(
if source_address is not None:
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
except (OSError, AttributeError): # Defensive: very old OS?
except (
OSError,
AttributeError,
): # Defensive: Windows or very old OS?
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
except (
Expand Down
79 changes: 71 additions & 8 deletions src/urllib3/contrib/ssa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,26 @@
from ..._typing import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT


def _can_shutdown_and_close_selector_loop_bug() -> bool:
import platform

if platform.system() == "Windows" and platform.python_version_tuple()[:2] == (
"3",
"7",
):
return int(platform.python_version_tuple()[-1]) >= 17

return True


# Windows + asyncio bug where doing our shutdown procedure induce a crash
# in SelectorLoop
# File "C:\hostedtoolcache\windows\Python\3.7.9\x64\lib\selectors.py", line 314, in _select
# r, w, x = select.select(r, w, w, timeout)
# [WinError 10038] An operation was attempted on something that is not a socket
_CPYTHON_SELECTOR_CLOSE_BUG_EXIST = _can_shutdown_and_close_selector_loop_bug() is False


class AsyncSocket:
"""
This class is brought to add a level of abstraction to an asyncio transport (reader, or writer)
Expand Down Expand Up @@ -73,6 +93,23 @@ def __init__(
def fileno(self) -> int:
return self._fileno if self._fileno is not None else self._sock.fileno()

async def wait_for_close(self) -> None:
if self._connect_called:
return

if self._writer is None:
return

is_ssl = self._writer.get_extra_info("ssl_object") is not None

if is_ssl:
# Give the connection a chance to write any data in the buffer,
# and then forcibly tear down the SSL connection.
await asyncio.sleep(0)
self._writer.transport.abort()

await self._writer.wait_closed()

def close(self) -> None:
if self._writer is not None:
self._writer.close()
Expand All @@ -83,15 +120,21 @@ def close(self) -> None:
# probably not just uvloop.
uvloop_edge_case_bug = False

# keep track of our clean exit procedure
shutdown_called = False
close_called = False

if hasattr(self._sock, "shutdown"):
try:
self._sock.shutdown(SHUT_RD)
shutdown_called = True
except TypeError:
uvloop_edge_case_bug = True
# uvloop don't support shutdown! and sometime does not support close()...
# see https://github.com/jawah/niquests/issues/166 for ctx.
try:
self._sock.close()
close_called = True
except TypeError:
# last chance of releasing properly the underlying fd!
try:
Expand All @@ -101,6 +144,7 @@ def close(self) -> None:
else:
try:
direct_sock.shutdown(SHUT_RD)
shutdown_called = True
except OSError:
warnings.warn(
(
Expand All @@ -113,15 +157,34 @@ def close(self) -> None:
)
finally:
direct_sock.detach()
elif hasattr(self._sock, "close"):
self._sock.close()
# we have to force call close() on our sock object in UDP ctx. (even after shutdown)
# we have to force call close() on our sock object (even after shutdown).
# or we'll get a resource warning for sure!
if self.type == socket.SOCK_DGRAM and hasattr(self._sock, "close"):
if not uvloop_edge_case_bug:
self._sock.close()
except OSError:
pass
if isinstance(self._sock, socket.socket) and hasattr(self._sock, "close"):
if not uvloop_edge_case_bug and not _CPYTHON_SELECTOR_CLOSE_BUG_EXIST:
try:
self._sock.close()
close_called = True
except OSError:
pass

if not close_called or not shutdown_called:
# this branch detect whether we have an asyncio.TransportSocket instead of socket.socket.
if (
hasattr(self._sock, "_sock")
and not _CPYTHON_SELECTOR_CLOSE_BUG_EXIST
):
try:
self._sock._sock.detach()
except (AttributeError, OSError):
pass

except (
OSError
): # branch where we failed to connect and still try to release resource
try:
self._sock.close()
except (OSError, TypeError, AttributeError):
pass

self._connect_called = False
self._established.clear()
Expand Down
66 changes: 54 additions & 12 deletions src/urllib3/util/ssl_.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,16 +224,14 @@ def _is_has_never_check_common_name_reliable(
# Python built against (very) restrictive ssl library may ship with a single TLS version
# thus, it seems to make attribute "minimum_version" and "maximum_version" unavailable.
# note: it raises an exception! maybe a CPython bug.
SUPPORT_MIN_MAX_TLS_VERSION = hasattr(
ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT), "maximum_version"
)
SUPPORT_MIN_MAX_TLS_VERSION = hasattr(ssl.SSLContext, "maximum_version")
except ImportError:
OP_NO_COMPRESSION = 0x20000 # type: ignore[assignment]
OP_NO_TICKET = 0x4000 # type: ignore[assignment]
OP_NO_SSLv2 = 0x1000000 # type: ignore[assignment]
OP_NO_SSLv3 = 0x2000000 # type: ignore[assignment]
PROTOCOL_SSLv23 = PROTOCOL_TLS = 2 # type: ignore[assignment]
PROTOCOL_TLS_CLIENT = 16 # type: ignore[assignment]
PROTOCOL_TLS_CLIENT = PROTOCOL_TLS
OP_NO_RENEGOTIATION = None # type: ignore[assignment]
SUPPORT_MIN_MAX_TLS_VERSION = False

Expand Down Expand Up @@ -298,19 +296,59 @@ def resolve_cert_reqs(candidate: None | int | str) -> VerifyMode:
return candidate # type: ignore[return-value]


def resolve_ssl_version(candidate: None | int | str) -> int:
@typing.overload
def resolve_ssl_version(
candidate: None | int | str,
*,
mitigate_tls_version: typing.Literal[True] = True,
) -> ssl.TLSVersion:
...


@typing.overload
def resolve_ssl_version(
candidate: None | int | str,
*,
mitigate_tls_version: typing.Literal[False] = False,
) -> int:
...


def resolve_ssl_version(
candidate: None | int | str, mitigate_tls_version: bool = False
) -> int | ssl.TLSVersion:
"""
like resolve_cert_reqs
"""
if candidate is None:
if mitigate_tls_version:
return PROTOCOL_TLS_CLIENT
return PROTOCOL_TLS

if isinstance(candidate, str):
if mitigate_tls_version and hasattr(ssl, "TLSVersion"):
res = getattr(ssl.TLSVersion, candidate, None)

if res is not None:
return res # type: ignore[no-any-return]

res = getattr(ssl.TLSVersion, candidate.replace("PROTOCOL_", ""), None)

if res is not None:
return res # type: ignore[no-any-return]

res = getattr(ssl, candidate, None)
if res is None:
res = getattr(ssl, "PROTOCOL_" + candidate)
return typing.cast(int, res)

if mitigate_tls_version:
if candidate in _SSL_VERSION_TO_TLS_VERSION:
return _SSL_VERSION_TO_TLS_VERSION[candidate]
if candidate == PROTOCOL_TLS_CLIENT or candidate == PROTOCOL_TLS:
return PROTOCOL_TLS_CLIENT
return ssl.TLSVersion.MAXIMUM_SUPPORTED

return candidate


Expand Down Expand Up @@ -363,13 +401,17 @@ def create_urllib3_context(
)

else:
# Use 'ssl_minimum_version' and 'ssl_maximum_version' instead.
ssl_minimum_version = _SSL_VERSION_TO_TLS_VERSION.get(
ssl_version, TLSVersion.MINIMUM_SUPPORTED
)
ssl_maximum_version = _SSL_VERSION_TO_TLS_VERSION.get(
ssl_version, TLSVersion.MAXIMUM_SUPPORTED
)
if hasattr(ssl, "TLSVersion") and isinstance(ssl_version, ssl.TLSVersion):
ssl_minimum_version = ssl_version
ssl_maximum_version = ssl_version
else:
# Use 'ssl_minimum_version' and 'ssl_maximum_version' instead.
ssl_minimum_version = _SSL_VERSION_TO_TLS_VERSION.get(
ssl_version, TLSVersion.MINIMUM_SUPPORTED
)
ssl_maximum_version = _SSL_VERSION_TO_TLS_VERSION.get(
ssl_version, TLSVersion.MAXIMUM_SUPPORTED
)

# PROTOCOL_TLS is deprecated in Python 3.10 so we always use PROTOCOL_TLS_CLIENT
context = SSLContext(PROTOCOL_TLS_CLIENT)
Expand Down
Loading

0 comments on commit 165e21e

Please sign in to comment.