From 165e21e838474fbe63c6892b404487c188678c8f Mon Sep 17 00:00:00 2001 From: Ahmed TAHRI Date: Thu, 31 Oct 2024 07:56:20 +0100 Subject: [PATCH] :bug: fix async connection shutdown in HTTP/1.1 and HTTP/2 leaving a `asyncio.TransportSocket` and `_SelectorSocketTransport` partially closed --- CHANGES.rst | 5 ++ dummyserver/server.py | 22 +++++- pyproject.toml | 5 +- src/urllib3/_async/connection.py | 2 +- src/urllib3/backend/_async/hface.py | 8 +- src/urllib3/backend/hface.py | 5 +- src/urllib3/connection.py | 2 +- .../contrib/resolver/_async/doq/_qh3.py | 1 + .../contrib/resolver/_async/dou/_socket.py | 1 + .../contrib/resolver/_async/protocols.py | 7 +- src/urllib3/contrib/resolver/protocols.py | 5 +- src/urllib3/contrib/ssa/__init__.py | 79 +++++++++++++++++-- src/urllib3/util/ssl_.py | 66 +++++++++++++--- test/__init__.py | 2 +- test/test_util.py | 21 +++++ .../asynchronous/test_connectionpool.py | 2 + .../asynchronous/test_poolmanager.py | 2 + test/with_dummyserver/test_connection.py | 1 + test/with_dummyserver/test_connectionpool.py | 2 + test/with_dummyserver/test_https.py | 52 ++++++------ .../asynchronous/test_conn_info.py | 6 ++ .../asynchronous/test_connection.py | 14 ++++ .../test_connection_multiplexed.py | 2 + test/with_traefik/asynchronous/test_svn.py | 9 +++ test/with_traefik/test_conn_info.py | 8 ++ test/with_traefik/test_connection.py | 8 ++ .../test_connection_multiplexed.py | 2 + test/with_traefik/test_svn.py | 9 +++ 28 files changed, 293 insertions(+), 55 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 48cf0878d0..67e1f082b5 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,8 @@ +2.11.908 (2024-11-03) +===================== + +- Fixed async connection shutdown in HTTP/1.1 and HTTP/2 leaving a ``asyncio.TransportSocket`` and ``_SelectorSocketTransport`` partially closed. + 2.11.907 (2024-10-30) ===================== diff --git a/dummyserver/server.py b/dummyserver/server.py index d2a1e6351f..05ce4fe5ce 100755 --- a/dummyserver/server.py +++ b/dummyserver/server.py @@ -163,14 +163,32 @@ def ssl_options_to_context( # type: ignore[no-untyped-def] alpn_protocols=None, ) -> ssl.SSLContext: """Return an equivalent SSLContext based on ssl.wrap_socket args.""" - ssl_version = resolve_ssl_version(ssl_version) + _major, _minor, _patch = ssl.OPENSSL_VERSION_INFO[:3] + is_broken_old_ssl: bool = ( + "OpenSSL" in ssl.OPENSSL_VERSION and _major < 1 or (_major == 1 and _minor < 1) + ) + + ssl_version = resolve_ssl_version( # type: ignore[call-overload] + ssl_version, mitigate_tls_version=not is_broken_old_ssl + ) cert_none = resolve_cert_reqs("CERT_NONE") if cert_reqs is None: cert_reqs = cert_none else: cert_reqs = resolve_cert_reqs(cert_reqs) - ctx = ssl.SSLContext(ssl_version) + if ( + hasattr(ssl.SSLContext, "minimum_version") + and hasattr(ssl, "PROTOCOL_TLS_SERVER") + and is_broken_old_ssl is False + ): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + if hasattr(ssl, "TLSVersion") and isinstance(ssl_version, ssl.TLSVersion): + ctx.minimum_version = ssl_version + ctx.maximum_version = ssl_version + else: + ctx = ssl.SSLContext(ssl_version) # type: ignore[arg-type] + ctx.load_cert_chain(certfile, keyfile) ctx.verify_mode = cert_reqs if ctx.verify_mode != cert_none: diff --git a/pyproject.toml b/pyproject.toml index 9e561dc51e..22643eee20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,14 +108,12 @@ filterwarnings = [ "error", '''default:No IPv6 support. Falling back to IPv4:urllib3.exceptions.HTTPWarning''', '''default:No IPv6 support. skipping:urllib3.exceptions.HTTPWarning''', - '''default:ssl\.TLSVersion\.TLSv1 is deprecated:DeprecationWarning''', '''default:ssl\.PROTOCOL_TLS is deprecated:DeprecationWarning''', '''default:ssl\.PROTOCOL_TLSv1 is deprecated:DeprecationWarning''', '''default:ssl\.TLSVersion\.TLSv1_1 is deprecated:DeprecationWarning''', '''default:ssl\.PROTOCOL_TLSv1_1 is deprecated:DeprecationWarning''', '''default:ssl\.PROTOCOL_TLSv1_2 is deprecated:DeprecationWarning''', '''default:unclosed .*:ResourceWarning''', - '''default:loop is closed:ResourceWarning''', '''default:ssl NPN is deprecated, use ALPN instead:DeprecationWarning''', # https://github.com/pytest-dev/pytest/issues/10977 '''default:ast\.(Num|NameConstant|Str) is deprecated and will be removed in Python 3\.14; use ast\.Constant instead:DeprecationWarning:_pytest''', @@ -129,6 +127,9 @@ filterwarnings = [ '''ignore:Exception in thread:pytest.PytestUnhandledThreadExceptionWarning''', '''ignore:function _SSLProtocolTransport\.__del__:pytest.PytestUnraisableExceptionWarning''', '''ignore:The `hash` argument is deprecated in favor of `unsafe_hash`:DeprecationWarning''', + '''ignore:ssl\.TLSVersion\.TLSv1 is deprecated:DeprecationWarning''', + '''ignore:ssl\.TLSVersion\.TLSv1_1 is deprecated:DeprecationWarning''', + '''ignore:loop is closed:ResourceWarning''', ] [tool.isort] diff --git a/src/urllib3/_async/connection.py b/src/urllib3/_async/connection.py index 56e8c819d9..8d80ec23e7 100644 --- a/src/urllib3/_async/connection.py +++ b/src/urllib3/_async/connection.py @@ -920,7 +920,7 @@ async def _ssl_wrap_socket_and_match_hostname( """ default_ssl_context = False sharable_ext_options: dict[str, int | str | None] = { - "ssl_version": resolve_ssl_version(ssl_version), + "ssl_version": resolve_ssl_version(ssl_version, mitigate_tls_version=True), "ssl_minimum_version": ssl_minimum_version, "ssl_maximum_version": ssl_maximum_version, "cert_reqs": resolve_cert_reqs(cert_reqs), diff --git a/src/urllib3/backend/_async/hface.py b/src/urllib3/backend/_async/hface.py index ef4b03192d..32ad38cfd1 100644 --- a/src/urllib3/backend/_async/hface.py +++ b/src/urllib3/backend/_async/hface.py @@ -190,7 +190,10 @@ async def _new_conn(self) -> AsyncSocket | None: # type: ignore[override] # if conn target another host. if self._response and self._response.authority != self.host: self._svn = None - await self._new_conn() # restore socket defaults + self._response = None + if self.blocksize == UDP_DEFAULT_BLOCKSIZE: + self.blocksize = DEFAULT_BLOCKSIZE + self.socket_kind = SOCK_STREAM else: if self.blocksize == UDP_DEFAULT_BLOCKSIZE: self.blocksize = DEFAULT_BLOCKSIZE @@ -1644,6 +1647,9 @@ async def close(self) -> None: # type: ignore[override] try: self.sock.close() + # this avoids having SelectorSocketTransport in "closing" state + # pending. Thus avoid a ResourceWarning. + await self.sock.wait_for_close() except OSError: pass diff --git a/src/urllib3/backend/hface.py b/src/urllib3/backend/hface.py index 233bbc1d01..2a5351ef61 100644 --- a/src/urllib3/backend/hface.py +++ b/src/urllib3/backend/hface.py @@ -200,7 +200,10 @@ def _new_conn(self) -> socket.socket | None: # if conn target another host. if self._response and self._response.authority != self.host: self._svn = None - self._new_conn() # restore socket defaults + self._response = None # type: ignore[assignment] + if self.blocksize == UDP_DEFAULT_BLOCKSIZE: + self.blocksize = DEFAULT_BLOCKSIZE + self.socket_kind = SOCK_STREAM else: if self.blocksize == UDP_DEFAULT_BLOCKSIZE: self.blocksize = DEFAULT_BLOCKSIZE diff --git a/src/urllib3/connection.py b/src/urllib3/connection.py index 57fee813b3..a64ca4e0eb 100644 --- a/src/urllib3/connection.py +++ b/src/urllib3/connection.py @@ -893,7 +893,7 @@ def _ssl_wrap_socket_and_match_hostname( """ default_ssl_context = False sharable_ext_options: dict[str, int | str | None] = { - "ssl_version": resolve_ssl_version(ssl_version), + "ssl_version": resolve_ssl_version(ssl_version, mitigate_tls_version=True), "ssl_minimum_version": ssl_minimum_version, "ssl_maximum_version": ssl_maximum_version, "cert_reqs": resolve_cert_reqs(cert_reqs), diff --git a/src/urllib3/contrib/resolver/_async/doq/_qh3.py b/src/urllib3/contrib/resolver/_async/doq/_qh3.py index b11047e235..4ba21d61e0 100644 --- a/src/urllib3/contrib/resolver/_async/doq/_qh3.py +++ b/src/urllib3/contrib/resolver/_async/doq/_qh3.py @@ -104,6 +104,7 @@ async def close(self) -> None: # type: ignore[override] await self._socket.sendall(data) self._socket.close() + await self._socket.wait_for_close() self._terminated = True if self._socket.should_connect(): self._terminated = True diff --git a/src/urllib3/contrib/resolver/_async/dou/_socket.py b/src/urllib3/contrib/resolver/_async/dou/_socket.py index 86e2eb40c1..06e2950c1f 100644 --- a/src/urllib3/contrib/resolver/_async/dou/_socket.py +++ b/src/urllib3/contrib/resolver/_async/dou/_socket.py @@ -70,6 +70,7 @@ async def close(self) -> None: # type: ignore[override] if not self._terminated: with self._lock: self._socket.close() + await self._socket.wait_for_close() self._terminated = True def is_available(self) -> bool: diff --git a/src/urllib3/contrib/resolver/_async/protocols.py b/src/urllib3/contrib/resolver/_async/protocols.py index 9ce7b24cd9..20fe417c38 100644 --- a/src/urllib3/contrib/resolver/_async/protocols.py +++ b/src/urllib3/contrib/resolver/_async/protocols.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import socket import typing from abc import ABCMeta, abstractmethod @@ -131,7 +132,11 @@ async def create_connection( # type: ignore[override] if source_address: sock.bind(source_address) - await sock.connect(sa) + try: + await sock.connect(sa) + except asyncio.CancelledError: + sock.close() + raise # Break explicitly a reference cycle err = None diff --git a/src/urllib3/contrib/resolver/protocols.py b/src/urllib3/contrib/resolver/protocols.py index 64c346772a..3ea4e8d282 100644 --- a/src/urllib3/contrib/resolver/protocols.py +++ b/src/urllib3/contrib/resolver/protocols.py @@ -224,7 +224,10 @@ def create_connection( if source_address is not None: try: sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) - except (OSError, AttributeError): # Defensive: very old OS? + except ( + OSError, + AttributeError, + ): # Defensive: Windows or very old OS? try: sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) except ( diff --git a/src/urllib3/contrib/ssa/__init__.py b/src/urllib3/contrib/ssa/__init__.py index c44f091b2b..c647351ccf 100644 --- a/src/urllib3/contrib/ssa/__init__.py +++ b/src/urllib3/contrib/ssa/__init__.py @@ -27,6 +27,26 @@ from ..._typing import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT +def _can_shutdown_and_close_selector_loop_bug() -> bool: + import platform + + if platform.system() == "Windows" and platform.python_version_tuple()[:2] == ( + "3", + "7", + ): + return int(platform.python_version_tuple()[-1]) >= 17 + + return True + + +# Windows + asyncio bug where doing our shutdown procedure induce a crash +# in SelectorLoop +# File "C:\hostedtoolcache\windows\Python\3.7.9\x64\lib\selectors.py", line 314, in _select +# r, w, x = select.select(r, w, w, timeout) +# [WinError 10038] An operation was attempted on something that is not a socket +_CPYTHON_SELECTOR_CLOSE_BUG_EXIST = _can_shutdown_and_close_selector_loop_bug() is False + + class AsyncSocket: """ This class is brought to add a level of abstraction to an asyncio transport (reader, or writer) @@ -73,6 +93,23 @@ def __init__( def fileno(self) -> int: return self._fileno if self._fileno is not None else self._sock.fileno() + async def wait_for_close(self) -> None: + if self._connect_called: + return + + if self._writer is None: + return + + is_ssl = self._writer.get_extra_info("ssl_object") is not None + + if is_ssl: + # Give the connection a chance to write any data in the buffer, + # and then forcibly tear down the SSL connection. + await asyncio.sleep(0) + self._writer.transport.abort() + + await self._writer.wait_closed() + def close(self) -> None: if self._writer is not None: self._writer.close() @@ -83,15 +120,21 @@ def close(self) -> None: # probably not just uvloop. uvloop_edge_case_bug = False + # keep track of our clean exit procedure + shutdown_called = False + close_called = False + if hasattr(self._sock, "shutdown"): try: self._sock.shutdown(SHUT_RD) + shutdown_called = True except TypeError: uvloop_edge_case_bug = True # uvloop don't support shutdown! and sometime does not support close()... # see https://github.com/jawah/niquests/issues/166 for ctx. try: self._sock.close() + close_called = True except TypeError: # last chance of releasing properly the underlying fd! try: @@ -101,6 +144,7 @@ def close(self) -> None: else: try: direct_sock.shutdown(SHUT_RD) + shutdown_called = True except OSError: warnings.warn( ( @@ -113,15 +157,34 @@ def close(self) -> None: ) finally: direct_sock.detach() - elif hasattr(self._sock, "close"): - self._sock.close() - # we have to force call close() on our sock object in UDP ctx. (even after shutdown) + # we have to force call close() on our sock object (even after shutdown). # or we'll get a resource warning for sure! - if self.type == socket.SOCK_DGRAM and hasattr(self._sock, "close"): - if not uvloop_edge_case_bug: - self._sock.close() - except OSError: - pass + if isinstance(self._sock, socket.socket) and hasattr(self._sock, "close"): + if not uvloop_edge_case_bug and not _CPYTHON_SELECTOR_CLOSE_BUG_EXIST: + try: + self._sock.close() + close_called = True + except OSError: + pass + + if not close_called or not shutdown_called: + # this branch detect whether we have an asyncio.TransportSocket instead of socket.socket. + if ( + hasattr(self._sock, "_sock") + and not _CPYTHON_SELECTOR_CLOSE_BUG_EXIST + ): + try: + self._sock._sock.detach() + except (AttributeError, OSError): + pass + + except ( + OSError + ): # branch where we failed to connect and still try to release resource + try: + self._sock.close() + except (OSError, TypeError, AttributeError): + pass self._connect_called = False self._established.clear() diff --git a/src/urllib3/util/ssl_.py b/src/urllib3/util/ssl_.py index 477553d9ad..3a1517cc67 100644 --- a/src/urllib3/util/ssl_.py +++ b/src/urllib3/util/ssl_.py @@ -224,16 +224,14 @@ def _is_has_never_check_common_name_reliable( # Python built against (very) restrictive ssl library may ship with a single TLS version # thus, it seems to make attribute "minimum_version" and "maximum_version" unavailable. # note: it raises an exception! maybe a CPython bug. - SUPPORT_MIN_MAX_TLS_VERSION = hasattr( - ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT), "maximum_version" - ) + SUPPORT_MIN_MAX_TLS_VERSION = hasattr(ssl.SSLContext, "maximum_version") except ImportError: OP_NO_COMPRESSION = 0x20000 # type: ignore[assignment] OP_NO_TICKET = 0x4000 # type: ignore[assignment] OP_NO_SSLv2 = 0x1000000 # type: ignore[assignment] OP_NO_SSLv3 = 0x2000000 # type: ignore[assignment] PROTOCOL_SSLv23 = PROTOCOL_TLS = 2 # type: ignore[assignment] - PROTOCOL_TLS_CLIENT = 16 # type: ignore[assignment] + PROTOCOL_TLS_CLIENT = PROTOCOL_TLS OP_NO_RENEGOTIATION = None # type: ignore[assignment] SUPPORT_MIN_MAX_TLS_VERSION = False @@ -298,19 +296,59 @@ def resolve_cert_reqs(candidate: None | int | str) -> VerifyMode: return candidate # type: ignore[return-value] -def resolve_ssl_version(candidate: None | int | str) -> int: +@typing.overload +def resolve_ssl_version( + candidate: None | int | str, + *, + mitigate_tls_version: typing.Literal[True] = True, +) -> ssl.TLSVersion: + ... + + +@typing.overload +def resolve_ssl_version( + candidate: None | int | str, + *, + mitigate_tls_version: typing.Literal[False] = False, +) -> int: + ... + + +def resolve_ssl_version( + candidate: None | int | str, mitigate_tls_version: bool = False +) -> int | ssl.TLSVersion: """ like resolve_cert_reqs """ if candidate is None: + if mitigate_tls_version: + return PROTOCOL_TLS_CLIENT return PROTOCOL_TLS if isinstance(candidate, str): + if mitigate_tls_version and hasattr(ssl, "TLSVersion"): + res = getattr(ssl.TLSVersion, candidate, None) + + if res is not None: + return res # type: ignore[no-any-return] + + res = getattr(ssl.TLSVersion, candidate.replace("PROTOCOL_", ""), None) + + if res is not None: + return res # type: ignore[no-any-return] + res = getattr(ssl, candidate, None) if res is None: res = getattr(ssl, "PROTOCOL_" + candidate) return typing.cast(int, res) + if mitigate_tls_version: + if candidate in _SSL_VERSION_TO_TLS_VERSION: + return _SSL_VERSION_TO_TLS_VERSION[candidate] + if candidate == PROTOCOL_TLS_CLIENT or candidate == PROTOCOL_TLS: + return PROTOCOL_TLS_CLIENT + return ssl.TLSVersion.MAXIMUM_SUPPORTED + return candidate @@ -363,13 +401,17 @@ def create_urllib3_context( ) else: - # Use 'ssl_minimum_version' and 'ssl_maximum_version' instead. - ssl_minimum_version = _SSL_VERSION_TO_TLS_VERSION.get( - ssl_version, TLSVersion.MINIMUM_SUPPORTED - ) - ssl_maximum_version = _SSL_VERSION_TO_TLS_VERSION.get( - ssl_version, TLSVersion.MAXIMUM_SUPPORTED - ) + if hasattr(ssl, "TLSVersion") and isinstance(ssl_version, ssl.TLSVersion): + ssl_minimum_version = ssl_version + ssl_maximum_version = ssl_version + else: + # Use 'ssl_minimum_version' and 'ssl_maximum_version' instead. + ssl_minimum_version = _SSL_VERSION_TO_TLS_VERSION.get( + ssl_version, TLSVersion.MINIMUM_SUPPORTED + ) + ssl_maximum_version = _SSL_VERSION_TO_TLS_VERSION.get( + ssl_version, TLSVersion.MAXIMUM_SUPPORTED + ) # PROTOCOL_TLS is deprecated in Python 3.10 so we always use PROTOCOL_TLS_CLIENT context = SSLContext(PROTOCOL_TLS_CLIENT) diff --git a/test/__init__.py b/test/__init__.py index 2ce5a6f150..1550bcc15b 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -90,7 +90,7 @@ def _can_resolve(host: str, should_match: str | None = None) -> bool: def has_alpn(ctx_cls: type[ssl.SSLContext] | None = None) -> bool: """Detect if ALPN support is enabled.""" ctx_cls = ctx_cls or util.SSLContext - ctx = ctx_cls(protocol=ssl_.PROTOCOL_TLS) # type: ignore[misc, attr-defined] + ctx = ctx_cls(protocol=ssl_.PROTOCOL_TLS_CLIENT) # type: ignore[misc, attr-defined] try: if hasattr(ctx, "set_alpn_protocols"): ctx.set_alpn_protocols(ssl_.ALPN_PROTOCOLS) diff --git a/test/test_util.py b/test/test_util.py index 9232f40001..4409f3de2a 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -982,6 +982,27 @@ def test_resolve_cert_reqs( def test_resolve_ssl_version(self, candidate: int | str, version: int) -> None: assert resolve_ssl_version(candidate) == version + @pytest.mark.parametrize( + "candidate, version", + [ + (ssl.PROTOCOL_TLSv1, "TLSv1"), + ("PROTOCOL_TLSv1", "TLSv1"), + ("TLSv1", "TLSv1"), + ], + ) + @pytest.mark.skipif( + hasattr(ssl, "TLSVersion") is False, reason="test requires ssl.TLSVersion" + ) + def test_resolve_ssl_version_mitigated( + self, candidate: int | str, version: str + ) -> None: + version_ = getattr(ssl.TLSVersion, version, None) + + if version_ is None: + pytest.skip(f"unsupported TLSVersion.{version}") + + assert resolve_ssl_version(candidate, mitigate_tls_version=True) == version_ + def test_ssl_wrap_socket_loads_the_cert_chain(self) -> None: socket = Mock() mock_context = Mock() diff --git a/test/with_dummyserver/asynchronous/test_connectionpool.py b/test/with_dummyserver/asynchronous/test_connectionpool.py index 2b648eec77..bc36091e3f 100644 --- a/test/with_dummyserver/asynchronous/test_connectionpool.py +++ b/test/with_dummyserver/asynchronous/test_connectionpool.py @@ -1091,6 +1091,7 @@ async def test_headers_not_modified_by_request( else: conn = await pool._get_conn() await conn.request("GET", "/headers", chunked=chunked) + await pool._put_conn(conn) assert pool.headers == {"key": "val"} assert isinstance(pool.headers, header_type) @@ -1101,6 +1102,7 @@ async def test_headers_not_modified_by_request( else: conn = await pool._get_conn() await conn.request("GET", "/headers", headers=headers, chunked=chunked) + await pool._put_conn(conn) assert headers == {"key": "val"} diff --git a/test/with_dummyserver/asynchronous/test_poolmanager.py b/test/with_dummyserver/asynchronous/test_poolmanager.py index 90cac1f52d..a46755afc9 100644 --- a/test/with_dummyserver/asynchronous/test_poolmanager.py +++ b/test/with_dummyserver/asynchronous/test_poolmanager.py @@ -286,6 +286,7 @@ async def test_redirect_without_preload_releases_connection(self) -> None: assert r._pool.num_requests == 2 assert r._pool.num_connections == 1 assert len(http.pools) == 1 + await r.json() # consume content, avoid resource warning async def test_303_redirect_makes_request_lose_body(self) -> None: async with AsyncPoolManager() as http: @@ -300,6 +301,7 @@ async def test_303_redirect_makes_request_lose_body(self) -> None: data = await response.json() assert data["params"] == {} assert "Content-Type" not in HTTPHeaderDict(data["headers"]) + await response.json() # consume content, avoid resource warning async def test_unknown_scheme(self) -> None: async with AsyncPoolManager() as http: diff --git a/test/with_dummyserver/test_connection.py b/test/with_dummyserver/test_connection.py index eb814594cc..0e3365fc33 100644 --- a/test/with_dummyserver/test_connection.py +++ b/test/with_dummyserver/test_connection.py @@ -31,6 +31,7 @@ def test_returns_urllib3_HTTPResponse(pool: HTTPConnectionPool) -> None: response = conn.getresponse() assert isinstance(response, HTTPResponse) + pool._put_conn(conn) def test_does_not_release_conn(pool: HTTPConnectionPool) -> None: diff --git a/test/with_dummyserver/test_connectionpool.py b/test/with_dummyserver/test_connectionpool.py index 93056e6a9c..972695d7dd 100644 --- a/test/with_dummyserver/test_connectionpool.py +++ b/test/with_dummyserver/test_connectionpool.py @@ -1014,6 +1014,7 @@ def test_headers_not_modified_by_request( else: conn = pool._get_conn() conn.request("GET", "/headers", chunked=chunked) + pool._put_conn(conn) assert pool.headers == {"key": "val"} assert isinstance(pool.headers, header_type) @@ -1024,6 +1025,7 @@ def test_headers_not_modified_by_request( else: conn = pool._get_conn() conn.request("GET", "/headers", headers=headers, chunked=chunked) + pool._put_conn(conn) assert headers == {"key": "val"} diff --git a/test/with_dummyserver/test_https.py b/test/with_dummyserver/test_https.py index d8c8b85860..dbb9caa762 100644 --- a/test/with_dummyserver/test_https.py +++ b/test/with_dummyserver/test_https.py @@ -173,19 +173,21 @@ def test_client_intermediate(self) -> None: @notWindows() @notMacOS() def test_in_memory_client_intermediate(self) -> None: - with HTTPSConnectionPool( - self.host, - self.port, - key_data=open(os.path.join(self.certs_dir, CLIENT_INTERMEDIATE_KEY)).read(), - cert_data=open( + with open(os.path.join(self.certs_dir, CLIENT_INTERMEDIATE_KEY)) as fp_key_data: + with open( os.path.join(self.certs_dir, CLIENT_INTERMEDIATE_PEM) - ).read(), - ca_certs=DEFAULT_CA, - ssl_minimum_version=self.tls_version(), - ) as https_pool: - r = https_pool.request("GET", "/certificate") - subject = r.json() - assert subject["organizationalUnitName"].startswith("Testing cert") + ) as fp_cert_data: + with HTTPSConnectionPool( + self.host, + self.port, + key_data=fp_key_data.read(), + cert_data=fp_cert_data.read(), + ca_certs=DEFAULT_CA, + ssl_minimum_version=self.tls_version(), + ) as https_pool: + r = https_pool.request("GET", "/certificate") + subject = r.json() + assert subject["organizationalUnitName"].startswith("Testing cert") def test_client_no_intermediate(self) -> None: """Check that missing links in certificate chains indeed break @@ -221,18 +223,20 @@ def test_client_key_password(self) -> None: @notWindows() @notMacOS() def test_in_memory_client_key_password(self) -> None: - with HTTPSConnectionPool( - self.host, - self.port, - ca_certs=DEFAULT_CA, - key_data=open(os.path.join(self.certs_dir, PASSWORD_CLIENT_KEYFILE)).read(), - cert_data=open(os.path.join(self.certs_dir, CLIENT_CERT)).read(), - key_password="letmein", - ssl_minimum_version=self.tls_version(), - ) as https_pool: - r = https_pool.request("GET", "/certificate") - subject = r.json() - assert subject["organizationalUnitName"].startswith("Testing cert") + with open(os.path.join(self.certs_dir, PASSWORD_CLIENT_KEYFILE)) as fp_key_data: + with open(os.path.join(self.certs_dir, CLIENT_CERT)) as fp_cert_data: + with HTTPSConnectionPool( + self.host, + self.port, + ca_certs=DEFAULT_CA, + key_data=fp_key_data.read(), + cert_data=fp_cert_data.read(), + key_password="letmein", + ssl_minimum_version=self.tls_version(), + ) as https_pool: + r = https_pool.request("GET", "/certificate") + subject = r.json() + assert subject["organizationalUnitName"].startswith("Testing cert") def test_client_encrypted_key_requires_password(self) -> None: with HTTPSConnectionPool( diff --git a/test/with_traefik/asynchronous/test_conn_info.py b/test/with_traefik/asynchronous/test_conn_info.py index cd67b3d0d1..a922dfa51a 100644 --- a/test/with_traefik/asynchronous/test_conn_info.py +++ b/test/with_traefik/asynchronous/test_conn_info.py @@ -29,6 +29,8 @@ async def on_post_connection(o: ConnectionInfo) -> None: assert conn_info.http_version == HttpVersion.h11 assert conn_info.certificate_dict is None + await p.clear() + async def test_tls_on_tcp(self) -> None: p = AsyncPoolManager( ca_certs=self.ca_authority, resolver=self.test_async_resolver @@ -50,6 +52,8 @@ async def on_post_connection(o: ConnectionInfo) -> None: assert conn_info.tls_version is not None assert conn_info.cipher is not None + await p.clear() + @pytest.mark.usefixtures("requires_http3") async def test_tls_on_udp(self) -> None: p = AsyncPoolManager( @@ -75,3 +79,5 @@ async def on_post_connection(o: ConnectionInfo) -> None: assert conn_info.tls_version is not None assert conn_info.cipher is not None assert conn_info.http_version == HttpVersion.h3 + + await p.clear() diff --git a/test/with_traefik/asynchronous/test_connection.py b/test/with_traefik/asynchronous/test_connection.py index 38eb16ef7f..a4aa5ded87 100644 --- a/test/with_traefik/asynchronous/test_connection.py +++ b/test/with_traefik/asynchronous/test_connection.py @@ -66,6 +66,8 @@ async def test_h2_svn_conserved(self) -> None: assert resp.version == 20 + await conn.close() + async def test_getresponse_not_ready(self) -> None: conn = AsyncHTTPSConnection( self.host, @@ -99,6 +101,8 @@ async def test_quic_cache_capable(self) -> None: assert resp.status == 200 assert resp.version == 30 + await conn.close() + @pytest.mark.usefixtures("requires_http3") async def test_quic_cache_capable_but_disabled(self) -> None: quic_cache_resumption: dict[tuple[str, int], tuple[str, int] | None] = { @@ -120,6 +124,8 @@ async def test_quic_cache_capable_but_disabled(self) -> None: assert resp.status == 200 assert resp.version == 20 + await conn.close() + @pytest.mark.usefixtures("requires_http3") async def test_quic_cache_explicit_not_capable(self) -> None: quic_cache_resumption: dict[tuple[str, int], tuple[str, int] | None] = { @@ -140,6 +146,8 @@ async def test_quic_cache_explicit_not_capable(self) -> None: assert resp.status == 200 assert resp.version == 20 + await conn.close() + @pytest.mark.usefixtures("requires_http3") async def test_quic_cache_implicit_not_capable(self) -> None: quic_cache_resumption: dict[tuple[str, int], tuple[str, int] | None] = dict() @@ -161,6 +169,8 @@ async def test_quic_cache_implicit_not_capable(self) -> None: assert len(quic_cache_resumption.keys()) == 1 assert (self.host, self.https_port) in quic_cache_resumption + await conn.close() + @pytest.mark.usefixtures("requires_http3") async def test_quic_extract_ssl_ctx_ca_root(self) -> None: quic_cache_resumption: dict[tuple[str, int], tuple[str, int] | None] = { @@ -194,6 +204,8 @@ async def test_quic_extract_ssl_ctx_ca_root(self) -> None: assert resp.status == 200 assert resp.version == 30 + await conn.close() + @pytest.mark.xfail( reason="experimental support for reusable outgoing port", strict=False ) @@ -207,6 +219,8 @@ async def test_fast_reuse_outgoing_port(self) -> None: source_address=("0.0.0.0", 8745), ) + await conn.connect() + await conn.request("GET", "/get") resp = await conn.getresponse() diff --git a/test/with_traefik/asynchronous/test_connection_multiplexed.py b/test/with_traefik/asynchronous/test_connection_multiplexed.py index 19156125af..f5e3fd2ede 100644 --- a/test/with_traefik/asynchronous/test_connection_multiplexed.py +++ b/test/with_traefik/asynchronous/test_connection_multiplexed.py @@ -121,3 +121,5 @@ async def test_multiplexing_upgrade_h3(self) -> None: for i in range(3): r = await conn.getresponse() assert r.version == 30 + + await conn.close() diff --git a/test/with_traefik/asynchronous/test_svn.py b/test/with_traefik/asynchronous/test_svn.py index 73df2440af..44c20888ca 100644 --- a/test/with_traefik/asynchronous/test_svn.py +++ b/test/with_traefik/asynchronous/test_svn.py @@ -276,6 +276,8 @@ async def test_drop_h3_upgrade(self) -> None: assert resp.version == 20 assert resp.status == 200 + await conn.close() + @pytest.mark.usefixtures("requires_http3") async def test_drop_post_established_h3(self) -> None: conn = AsyncHTTPSConnection( @@ -309,6 +311,8 @@ async def test_drop_post_established_h3(self) -> None: assert resp.version == 20 assert resp.status == 200 + await conn.close() + @pytest.mark.usefixtures("requires_http3") async def test_pool_manager_quic_cache(self) -> None: dumb_cache: dict[tuple[str, int], tuple[str, int] | None] = dict() @@ -329,6 +333,8 @@ async def test_pool_manager_quic_cache(self) -> None: await conn.close() + await pm.clear() + pm = AsyncPoolManager( ca_certs=self.ca_authority, preemptive_quic_cache=dumb_cache, @@ -343,6 +349,9 @@ async def test_pool_manager_quic_cache(self) -> None: assert len(dumb_cache.keys()) == 1 + await conn.close() + await pm.clear() + async def test_http2_with_prior_knowledge(self) -> None: async with AsyncHTTPConnectionPool( self.host, diff --git a/test/with_traefik/test_conn_info.py b/test/with_traefik/test_conn_info.py index 2d126c619c..d5afc66b46 100644 --- a/test/with_traefik/test_conn_info.py +++ b/test/with_traefik/test_conn_info.py @@ -29,6 +29,8 @@ def on_post_connection(o: ConnectionInfo) -> None: assert conn_info.http_version == HttpVersion.h11 assert conn_info.certificate_dict is None + p.clear() + def test_tls_on_tcp(self) -> None: p = PoolManager(ca_certs=self.ca_authority, resolver=self.test_resolver) @@ -48,6 +50,8 @@ def on_post_connection(o: ConnectionInfo) -> None: assert conn_info.tls_version is not None assert conn_info.cipher is not None + p.clear() + @pytest.mark.skipif( sys.version_info < (3, 10), reason="unsupported due missing API (3.10+)", @@ -75,6 +79,8 @@ def on_post_connection(o: ConnectionInfo) -> None: assert conn_info.tls_version is not None assert conn_info.cipher is not None + p.clear() + @pytest.mark.usefixtures("requires_http3") def test_tls_on_udp(self) -> None: p = PoolManager( @@ -100,3 +106,5 @@ def on_post_connection(o: ConnectionInfo) -> None: assert conn_info.tls_version is not None assert conn_info.cipher is not None assert conn_info.http_version == HttpVersion.h3 + + p.clear() diff --git a/test/with_traefik/test_connection.py b/test/with_traefik/test_connection.py index 49428ab385..4926a26d31 100644 --- a/test/with_traefik/test_connection.py +++ b/test/with_traefik/test_connection.py @@ -65,6 +65,8 @@ def test_h2_svn_conserved(self) -> None: assert resp.version == 20 + conn.close() + def test_getresponse_not_ready(self) -> None: conn = HTTPSConnection( self.host, @@ -98,6 +100,8 @@ def test_quic_cache_capable(self) -> None: assert resp.status == 200 assert resp.version == 30 + conn.close() + def test_quic_cache_capable_but_disabled(self) -> None: quic_cache_resumption: dict[tuple[str, int], tuple[str, int] | None] = { (self.host, self.https_port): ("", self.https_port) @@ -192,6 +196,8 @@ def test_quic_extract_ssl_ctx_ca_root(self) -> None: assert resp.status == 200 assert resp.version == 30 + conn.close() + @pytest.mark.xfail( reason="experimental support for reusable outgoing port", strict=False ) @@ -205,6 +211,8 @@ def test_fast_reuse_outgoing_port(self) -> None: source_address=("0.0.0.0", 8845), ) + conn.connect() + conn.request("GET", "/get") resp = conn.getresponse() diff --git a/test/with_traefik/test_connection_multiplexed.py b/test/with_traefik/test_connection_multiplexed.py index 470eed704f..30f5178488 100644 --- a/test/with_traefik/test_connection_multiplexed.py +++ b/test/with_traefik/test_connection_multiplexed.py @@ -120,3 +120,5 @@ def test_multiplexing_upgrade_h3(self) -> None: for i in range(3): r = conn.getresponse() assert r.version == 30 + + conn.close() diff --git a/test/with_traefik/test_svn.py b/test/with_traefik/test_svn.py index 68f2ddfdb3..3dd07d3fb5 100644 --- a/test/with_traefik/test_svn.py +++ b/test/with_traefik/test_svn.py @@ -99,6 +99,8 @@ def test_can_disable_h11(self) -> None: assert r.status == 200 assert r.version == 20 + p.close() + def test_cannot_disable_everything(self) -> None: with pytest.raises(RuntimeError): p = HTTPSConnectionPool( @@ -268,6 +270,8 @@ def test_drop_h3_upgrade(self) -> None: assert resp.version == 20 assert resp.status == 200 + conn.close() + @pytest.mark.usefixtures("requires_http3") def test_drop_post_established_h3(self) -> None: conn = HTTPSConnection( @@ -321,6 +325,8 @@ def test_pool_manager_quic_cache(self) -> None: conn.close() + pm.clear() + pm = PoolManager( ca_certs=self.ca_authority, preemptive_quic_cache=dumb_cache, @@ -335,6 +341,9 @@ def test_pool_manager_quic_cache(self) -> None: assert len(dumb_cache.keys()) == 1 + conn.close() + pm.clear() + def test_can_upgrade_h2c_via_altsvc(self) -> None: with HTTPConnectionPool( self.host,