Skip to content

Commit 89dbc2f

Browse files
authored
Adding ssl_verify_flags_config argument for ssl connection configuration (#3772)
1 parent e138000 commit 89dbc2f

File tree

9 files changed

+339
-6
lines changed

9 files changed

+339
-6
lines changed

redis/asyncio/client.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,11 @@
8181
)
8282

8383
if TYPE_CHECKING and SSL_AVAILABLE:
84-
from ssl import TLSVersion, VerifyMode
84+
from ssl import TLSVersion, VerifyFlags, VerifyMode
8585
else:
8686
TLSVersion = None
8787
VerifyMode = None
88+
VerifyFlags = None
8889

8990
PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
9091
_KeyT = TypeVar("_KeyT", bound=KeyT)
@@ -238,6 +239,8 @@ def __init__(
238239
ssl_keyfile: Optional[str] = None,
239240
ssl_certfile: Optional[str] = None,
240241
ssl_cert_reqs: Union[str, VerifyMode] = "required",
242+
ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
243+
ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
241244
ssl_ca_certs: Optional[str] = None,
242245
ssl_ca_data: Optional[str] = None,
243246
ssl_check_hostname: bool = True,
@@ -347,6 +350,8 @@ def __init__(
347350
"ssl_keyfile": ssl_keyfile,
348351
"ssl_certfile": ssl_certfile,
349352
"ssl_cert_reqs": ssl_cert_reqs,
353+
"ssl_include_verify_flags": ssl_include_verify_flags,
354+
"ssl_exclude_verify_flags": ssl_exclude_verify_flags,
350355
"ssl_ca_certs": ssl_ca_certs,
351356
"ssl_ca_data": ssl_ca_data,
352357
"ssl_check_hostname": ssl_check_hostname,

redis/asyncio/cluster.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,11 @@
8686
)
8787

8888
if SSL_AVAILABLE:
89-
from ssl import TLSVersion, VerifyMode
89+
from ssl import TLSVersion, VerifyFlags, VerifyMode
9090
else:
9191
TLSVersion = None
9292
VerifyMode = None
93+
VerifyFlags = None
9394

9495
TargetNodesT = TypeVar(
9596
"TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
@@ -299,6 +300,8 @@ def __init__(
299300
ssl_ca_certs: Optional[str] = None,
300301
ssl_ca_data: Optional[str] = None,
301302
ssl_cert_reqs: Union[str, VerifyMode] = "required",
303+
ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
304+
ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
302305
ssl_certfile: Optional[str] = None,
303306
ssl_check_hostname: bool = True,
304307
ssl_keyfile: Optional[str] = None,
@@ -358,6 +361,8 @@ def __init__(
358361
"ssl_ca_certs": ssl_ca_certs,
359362
"ssl_ca_data": ssl_ca_data,
360363
"ssl_cert_reqs": ssl_cert_reqs,
364+
"ssl_include_verify_flags": ssl_include_verify_flags,
365+
"ssl_exclude_verify_flags": ssl_exclude_verify_flags,
361366
"ssl_certfile": ssl_certfile,
362367
"ssl_check_hostname": ssl_check_hostname,
363368
"ssl_keyfile": ssl_keyfile,

redis/asyncio/connection.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,12 @@
3030

3131
if SSL_AVAILABLE:
3232
import ssl
33-
from ssl import SSLContext, TLSVersion
33+
from ssl import SSLContext, TLSVersion, VerifyFlags
3434
else:
3535
ssl = None
3636
TLSVersion = None
3737
SSLContext = None
38+
VerifyFlags = None
3839

3940
from ..auth.token import TokenInterface
4041
from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
@@ -793,6 +794,8 @@ def __init__(
793794
ssl_keyfile: Optional[str] = None,
794795
ssl_certfile: Optional[str] = None,
795796
ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required",
797+
ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
798+
ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
796799
ssl_ca_certs: Optional[str] = None,
797800
ssl_ca_data: Optional[str] = None,
798801
ssl_check_hostname: bool = True,
@@ -807,6 +810,8 @@ def __init__(
807810
keyfile=ssl_keyfile,
808811
certfile=ssl_certfile,
809812
cert_reqs=ssl_cert_reqs,
813+
include_verify_flags=ssl_include_verify_flags,
814+
exclude_verify_flags=ssl_exclude_verify_flags,
810815
ca_certs=ssl_ca_certs,
811816
ca_data=ssl_ca_data,
812817
check_hostname=ssl_check_hostname,
@@ -832,6 +837,14 @@ def certfile(self):
832837
def cert_reqs(self):
833838
return self.ssl_context.cert_reqs
834839

840+
@property
841+
def include_verify_flags(self):
842+
return self.ssl_context.include_verify_flags
843+
844+
@property
845+
def exclude_verify_flags(self):
846+
return self.ssl_context.exclude_verify_flags
847+
835848
@property
836849
def ca_certs(self):
837850
return self.ssl_context.ca_certs
@@ -854,6 +867,8 @@ class RedisSSLContext:
854867
"keyfile",
855868
"certfile",
856869
"cert_reqs",
870+
"include_verify_flags",
871+
"exclude_verify_flags",
857872
"ca_certs",
858873
"ca_data",
859874
"context",
@@ -867,6 +882,8 @@ def __init__(
867882
keyfile: Optional[str] = None,
868883
certfile: Optional[str] = None,
869884
cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None,
885+
include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
886+
exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
870887
ca_certs: Optional[str] = None,
871888
ca_data: Optional[str] = None,
872889
check_hostname: bool = False,
@@ -892,6 +909,8 @@ def __init__(
892909
)
893910
cert_reqs = CERT_REQS[cert_reqs]
894911
self.cert_reqs = cert_reqs
912+
self.include_verify_flags = include_verify_flags
913+
self.exclude_verify_flags = exclude_verify_flags
895914
self.ca_certs = ca_certs
896915
self.ca_data = ca_data
897916
self.check_hostname = (
@@ -906,6 +925,12 @@ def get(self) -> SSLContext:
906925
context = ssl.create_default_context()
907926
context.check_hostname = self.check_hostname
908927
context.verify_mode = self.cert_reqs
928+
if self.include_verify_flags:
929+
for flag in self.include_verify_flags:
930+
context.verify_flags |= flag
931+
if self.exclude_verify_flags:
932+
for flag in self.exclude_verify_flags:
933+
context.verify_flags &= ~flag
909934
if self.certfile and self.keyfile:
910935
context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
911936
if self.ca_certs or self.ca_data:
@@ -953,6 +978,20 @@ def to_bool(value) -> Optional[bool]:
953978
return bool(value)
954979

955980

981+
def parse_ssl_verify_flags(value):
982+
# flags are passed in as a string representation of a list,
983+
# e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
984+
verify_flags_str = value.replace("[", "").replace("]", "")
985+
986+
verify_flags = []
987+
for flag in verify_flags_str.split(","):
988+
flag = flag.strip()
989+
if not hasattr(VerifyFlags, flag):
990+
raise ValueError(f"Invalid ssl verify flag: {flag}")
991+
verify_flags.append(getattr(VerifyFlags, flag))
992+
return verify_flags
993+
994+
956995
URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
957996
{
958997
"db": int,
@@ -963,6 +1002,8 @@ def to_bool(value) -> Optional[bool]:
9631002
"max_connections": int,
9641003
"health_check_interval": int,
9651004
"ssl_check_hostname": to_bool,
1005+
"ssl_include_verify_flags": parse_ssl_verify_flags,
1006+
"ssl_exclude_verify_flags": parse_ssl_verify_flags,
9661007
"timeout": float,
9671008
}
9681009
)
@@ -1021,6 +1062,7 @@ def parse_url(url: str) -> ConnectKwargs:
10211062

10221063
if parsed.scheme == "rediss":
10231064
kwargs["connection_class"] = SSLConnection
1065+
10241066
else:
10251067
valid_schemes = "redis://, rediss://, unix://"
10261068
raise ValueError(

redis/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ def __init__(
224224
ssl_keyfile: Optional[str] = None,
225225
ssl_certfile: Optional[str] = None,
226226
ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required",
227+
ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
228+
ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
227229
ssl_ca_certs: Optional[str] = None,
228230
ssl_ca_path: Optional[str] = None,
229231
ssl_ca_data: Optional[str] = None,
@@ -330,6 +332,8 @@ def __init__(
330332
"ssl_keyfile": ssl_keyfile,
331333
"ssl_certfile": ssl_certfile,
332334
"ssl_cert_reqs": ssl_cert_reqs,
335+
"ssl_include_verify_flags": ssl_include_verify_flags,
336+
"ssl_exclude_verify_flags": ssl_exclude_verify_flags,
333337
"ssl_ca_certs": ssl_ca_certs,
334338
"ssl_ca_data": ssl_ca_data,
335339
"ssl_check_hostname": ssl_check_hostname,

redis/cluster.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,8 @@ def parse_cluster_myshardid(resp, **options):
184184
"ssl_ca_data",
185185
"ssl_certfile",
186186
"ssl_cert_reqs",
187+
"ssl_include_verify_flags",
188+
"ssl_exclude_verify_flags",
187189
"ssl_keyfile",
188190
"ssl_password",
189191
"ssl_check_hostname",

redis/connection.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,10 @@
6868

6969
if SSL_AVAILABLE:
7070
import ssl
71+
from ssl import VerifyFlags
7172
else:
7273
ssl = None
74+
VerifyFlags = None
7375

7476
if HIREDIS_AVAILABLE:
7577
import hiredis
@@ -1360,6 +1362,8 @@ def __init__(
13601362
ssl_keyfile=None,
13611363
ssl_certfile=None,
13621364
ssl_cert_reqs="required",
1365+
ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None,
1366+
ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None,
13631367
ssl_ca_certs=None,
13641368
ssl_ca_data=None,
13651369
ssl_check_hostname=True,
@@ -1378,7 +1382,10 @@ def __init__(
13781382
Args:
13791383
ssl_keyfile: Path to an ssl private key. Defaults to None.
13801384
ssl_certfile: Path to an ssl certificate. Defaults to None.
1381-
ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required), or an ssl.VerifyMode. Defaults to "required".
1385+
ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required),
1386+
or an ssl.VerifyMode. Defaults to "required".
1387+
ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None.
1388+
ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None.
13821389
ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
13831390
ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
13841391
ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True.
@@ -1414,6 +1421,8 @@ def __init__(
14141421
)
14151422
ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
14161423
self.cert_reqs = ssl_cert_reqs
1424+
self.ssl_include_verify_flags = ssl_include_verify_flags
1425+
self.ssl_exclude_verify_flags = ssl_exclude_verify_flags
14171426
self.ca_certs = ssl_ca_certs
14181427
self.ca_data = ssl_ca_data
14191428
self.ca_path = ssl_ca_path
@@ -1453,6 +1462,12 @@ def _wrap_socket_with_ssl(self, sock):
14531462
context = ssl.create_default_context()
14541463
context.check_hostname = self.check_hostname
14551464
context.verify_mode = self.cert_reqs
1465+
if self.ssl_include_verify_flags:
1466+
for flag in self.ssl_include_verify_flags:
1467+
context.verify_flags |= flag
1468+
if self.ssl_exclude_verify_flags:
1469+
for flag in self.ssl_exclude_verify_flags:
1470+
context.verify_flags &= ~flag
14561471
if self.certfile or self.keyfile:
14571472
context.load_cert_chain(
14581473
certfile=self.certfile,
@@ -1566,6 +1581,20 @@ def to_bool(value):
15661581
return bool(value)
15671582

15681583

1584+
def parse_ssl_verify_flags(value):
1585+
# flags are passed in as a string representation of a list,
1586+
# e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
1587+
verify_flags_str = value.replace("[", "").replace("]", "")
1588+
1589+
verify_flags = []
1590+
for flag in verify_flags_str.split(","):
1591+
flag = flag.strip()
1592+
if not hasattr(VerifyFlags, flag):
1593+
raise ValueError(f"Invalid ssl verify flag: {flag}")
1594+
verify_flags.append(getattr(VerifyFlags, flag))
1595+
return verify_flags
1596+
1597+
15691598
URL_QUERY_ARGUMENT_PARSERS = {
15701599
"db": int,
15711600
"socket_timeout": float,
@@ -1576,6 +1605,8 @@ def to_bool(value):
15761605
"max_connections": int,
15771606
"health_check_interval": int,
15781607
"ssl_check_hostname": to_bool,
1608+
"ssl_include_verify_flags": parse_ssl_verify_flags,
1609+
"ssl_exclude_verify_flags": parse_ssl_verify_flags,
15791610
"timeout": float,
15801611
}
15811612

tests/test_asyncio/test_ssl.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import ssl
2+
import unittest.mock
13
from urllib.parse import urlparse
24
import pytest
35
import pytest_asyncio
@@ -54,3 +56,88 @@ async def test_cert_reqs_none_with_check_hostname(self, request):
5456
assert conn.check_hostname is False
5557
finally:
5658
await r.aclose()
59+
60+
async def test_ssl_flags_applied_to_context(self, request):
61+
"""
62+
Test that ssl_include_verify_flags and ssl_exclude_verify_flags
63+
are properly applied to the SSL context
64+
"""
65+
ssl_url = request.config.option.redis_ssl_url
66+
parsed_url = urlparse(ssl_url)
67+
68+
# Test with specific SSL verify flags
69+
ssl_include_verify_flags = [
70+
ssl.VerifyFlags.VERIFY_CRL_CHECK_LEAF, # Disable strict verification
71+
ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN, # Enable partial chain
72+
]
73+
74+
ssl_exclude_verify_flags = [
75+
ssl.VerifyFlags.VERIFY_X509_STRICT, # Disable trusted first
76+
]
77+
78+
r = redis.Redis(
79+
host=parsed_url.hostname,
80+
port=parsed_url.port,
81+
ssl=True,
82+
ssl_cert_reqs="none",
83+
ssl_include_verify_flags=ssl_include_verify_flags,
84+
ssl_exclude_verify_flags=ssl_exclude_verify_flags,
85+
)
86+
87+
try:
88+
# Get the connection to trigger SSL context creation
89+
conn = r.connection_pool.make_connection()
90+
assert isinstance(conn, redis.SSLConnection)
91+
92+
# Verify the flags were processed by checking they're stored in connection
93+
assert conn.include_verify_flags is not None
94+
assert len(conn.include_verify_flags) == 2
95+
96+
assert conn.exclude_verify_flags is not None
97+
assert len(conn.exclude_verify_flags) == 1
98+
99+
# Check each flag individually
100+
for flag in ssl_include_verify_flags:
101+
assert flag in conn.include_verify_flags, (
102+
f"Flag {flag} not found in stored ssl_include_verify_flags"
103+
)
104+
for flag in ssl_exclude_verify_flags:
105+
assert flag in conn.exclude_verify_flags, (
106+
f"Flag {flag} not found in stored ssl_exclude_verify_flags"
107+
)
108+
109+
# Test the actual SSL context created by the connection's RedisSSLContext
110+
# We need to mock the ssl.create_default_context to capture the context
111+
captured_context = None
112+
original_create_default_context = ssl.create_default_context
113+
114+
def capture_context_create_default():
115+
nonlocal captured_context
116+
captured_context = original_create_default_context()
117+
return captured_context
118+
119+
with unittest.mock.patch(
120+
"ssl.create_default_context", capture_context_create_default
121+
):
122+
# Trigger SSL context creation by calling get() on the RedisSSLContext
123+
ssl_context = conn.ssl_context.get()
124+
125+
# Validate that we captured a context and it has the correct flags applied
126+
assert captured_context is not None, "SSL context was not captured"
127+
assert ssl_context is captured_context, (
128+
"Returned context should be the captured one"
129+
)
130+
131+
# Verify that VERIFY_X509_STRICT was disabled (bit cleared)
132+
assert not (
133+
captured_context.verify_flags & ssl.VerifyFlags.VERIFY_X509_STRICT
134+
), "VERIFY_X509_STRICT should be disabled but is enabled"
135+
136+
# Verify that VERIFY_CRL_CHECK_CHAIN was enabled (bit set)
137+
assert (
138+
captured_context.verify_flags
139+
& ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN
140+
), "VERIFY_CRL_CHECK_CHAIN should be enabled but is disabled"
141+
142+
finally:
143+
await r.aclose()

0 commit comments

Comments
 (0)