30
30
31
31
if SSL_AVAILABLE :
32
32
import ssl
33
- from ssl import SSLContext , TLSVersion
33
+ from ssl import SSLContext , TLSVersion , VerifyFlags
34
34
else :
35
35
ssl = None
36
36
TLSVersion = None
37
37
SSLContext = None
38
+ VerifyFlags = None
38
39
39
40
from ..auth .token import TokenInterface
40
41
from ..event import AsyncAfterConnectionReleasedEvent , EventDispatcher
@@ -793,6 +794,8 @@ def __init__(
793
794
ssl_keyfile : Optional [str ] = None ,
794
795
ssl_certfile : Optional [str ] = None ,
795
796
ssl_cert_reqs : Union [str , ssl .VerifyMode ] = "required" ,
797
+ ssl_include_verify_flags : Optional [List ["ssl.VerifyFlags" ]] = None ,
798
+ ssl_exclude_verify_flags : Optional [List ["ssl.VerifyFlags" ]] = None ,
796
799
ssl_ca_certs : Optional [str ] = None ,
797
800
ssl_ca_data : Optional [str ] = None ,
798
801
ssl_check_hostname : bool = True ,
@@ -807,6 +810,8 @@ def __init__(
807
810
keyfile = ssl_keyfile ,
808
811
certfile = ssl_certfile ,
809
812
cert_reqs = ssl_cert_reqs ,
813
+ include_verify_flags = ssl_include_verify_flags ,
814
+ exclude_verify_flags = ssl_exclude_verify_flags ,
810
815
ca_certs = ssl_ca_certs ,
811
816
ca_data = ssl_ca_data ,
812
817
check_hostname = ssl_check_hostname ,
@@ -832,6 +837,14 @@ def certfile(self):
832
837
def cert_reqs (self ):
833
838
return self .ssl_context .cert_reqs
834
839
840
+ @property
841
+ def include_verify_flags (self ):
842
+ return self .ssl_context .include_verify_flags
843
+
844
+ @property
845
+ def exclude_verify_flags (self ):
846
+ return self .ssl_context .exclude_verify_flags
847
+
835
848
@property
836
849
def ca_certs (self ):
837
850
return self .ssl_context .ca_certs
@@ -854,6 +867,8 @@ class RedisSSLContext:
854
867
"keyfile" ,
855
868
"certfile" ,
856
869
"cert_reqs" ,
870
+ "include_verify_flags" ,
871
+ "exclude_verify_flags" ,
857
872
"ca_certs" ,
858
873
"ca_data" ,
859
874
"context" ,
@@ -867,6 +882,8 @@ def __init__(
867
882
keyfile : Optional [str ] = None ,
868
883
certfile : Optional [str ] = None ,
869
884
cert_reqs : Optional [Union [str , ssl .VerifyMode ]] = None ,
885
+ include_verify_flags : Optional [List ["ssl.VerifyFlags" ]] = None ,
886
+ exclude_verify_flags : Optional [List ["ssl.VerifyFlags" ]] = None ,
870
887
ca_certs : Optional [str ] = None ,
871
888
ca_data : Optional [str ] = None ,
872
889
check_hostname : bool = False ,
@@ -892,6 +909,8 @@ def __init__(
892
909
)
893
910
cert_reqs = CERT_REQS [cert_reqs ]
894
911
self .cert_reqs = cert_reqs
912
+ self .include_verify_flags = include_verify_flags
913
+ self .exclude_verify_flags = exclude_verify_flags
895
914
self .ca_certs = ca_certs
896
915
self .ca_data = ca_data
897
916
self .check_hostname = (
@@ -906,6 +925,12 @@ def get(self) -> SSLContext:
906
925
context = ssl .create_default_context ()
907
926
context .check_hostname = self .check_hostname
908
927
context .verify_mode = self .cert_reqs
928
+ if self .include_verify_flags :
929
+ for flag in self .include_verify_flags :
930
+ context .verify_flags |= flag
931
+ if self .exclude_verify_flags :
932
+ for flag in self .exclude_verify_flags :
933
+ context .verify_flags &= ~ flag
909
934
if self .certfile and self .keyfile :
910
935
context .load_cert_chain (certfile = self .certfile , keyfile = self .keyfile )
911
936
if self .ca_certs or self .ca_data :
@@ -953,6 +978,20 @@ def to_bool(value) -> Optional[bool]:
953
978
return bool (value )
954
979
955
980
981
+ def parse_ssl_verify_flags (value ):
982
+ # flags are passed in as a string representation of a list,
983
+ # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
984
+ verify_flags_str = value .replace ("[" , "" ).replace ("]" , "" )
985
+
986
+ verify_flags = []
987
+ for flag in verify_flags_str .split ("," ):
988
+ flag = flag .strip ()
989
+ if not hasattr (VerifyFlags , flag ):
990
+ raise ValueError (f"Invalid ssl verify flag: { flag } " )
991
+ verify_flags .append (getattr (VerifyFlags , flag ))
992
+ return verify_flags
993
+
994
+
956
995
URL_QUERY_ARGUMENT_PARSERS : Mapping [str , Callable [..., object ]] = MappingProxyType (
957
996
{
958
997
"db" : int ,
@@ -963,6 +1002,8 @@ def to_bool(value) -> Optional[bool]:
963
1002
"max_connections" : int ,
964
1003
"health_check_interval" : int ,
965
1004
"ssl_check_hostname" : to_bool ,
1005
+ "ssl_include_verify_flags" : parse_ssl_verify_flags ,
1006
+ "ssl_exclude_verify_flags" : parse_ssl_verify_flags ,
966
1007
"timeout" : float ,
967
1008
}
968
1009
)
@@ -1021,6 +1062,7 @@ def parse_url(url: str) -> ConnectKwargs:
1021
1062
1022
1063
if parsed .scheme == "rediss" :
1023
1064
kwargs ["connection_class" ] = SSLConnection
1065
+
1024
1066
else :
1025
1067
valid_schemes = "redis://, rediss://, unix://"
1026
1068
raise ValueError (
0 commit comments