Skip to content

Commit b102c3b

Browse files
authored
Adding 'auto' option to MaintNotificationsConfig.enabled (#3779)
1 parent 89dbc2f commit b102c3b

File tree

4 files changed

+105
-49
lines changed

4 files changed

+105
-49
lines changed

redis/connection.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -688,8 +688,12 @@ def on_connect_check_health(self, check_health: bool = True):
688688
):
689689
raise ConnectionError("Invalid RESP version")
690690

691-
# Send maintenance notifications handshake if RESP3 is active and maintenance notifications are enabled
691+
# Send maintenance notifications handshake if RESP3 is active
692+
# and maintenance notifications are enabled
692693
# and we have a host to determine the endpoint type from
694+
# When the maint_notifications_config enabled mode is "auto",
695+
# we just log a warning if the handshake fails
696+
# When the mode is enabled=True, we raise an exception in case of failure
693697
if (
694698
self.protocol not in [2, "2"]
695699
and self.maint_notifications_config
@@ -711,15 +715,21 @@ def on_connect_check_health(self, check_health: bool = True):
711715
)
712716
response = self.read_response()
713717
if str_if_bytes(response) != "OK":
714-
raise ConnectionError(
718+
raise ResponseError(
715719
"The server doesn't support maintenance notifications"
716720
)
717721
except Exception as e:
718-
# Log warning but don't fail the connection
719-
import logging
722+
if (
723+
isinstance(e, ResponseError)
724+
and self.maint_notifications_config.enabled == "auto"
725+
):
726+
# Log warning but don't fail the connection
727+
import logging
720728

721-
logger = logging.getLogger(__name__)
722-
logger.warning(f"Failed to enable maintenance notifications: {e}")
729+
logger = logging.getLogger(__name__)
730+
logger.warning(f"Failed to enable maintenance notifications: {e}")
731+
else:
732+
raise
723733

724734
# if a client_name is given, set it
725735
if self.client_name:

redis/maint_notifications.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import threading
66
import time
77
from abc import ABC, abstractmethod
8-
from typing import TYPE_CHECKING, Optional, Union
8+
from typing import TYPE_CHECKING, Literal, Optional, Union
99

1010
from redis.typing import Number
1111

@@ -447,7 +447,7 @@ class MaintNotificationsConfig:
447447

448448
def __init__(
449449
self,
450-
enabled: bool = True,
450+
enabled: Union[bool, Literal["auto"]] = "auto",
451451
proactive_reconnect: bool = True,
452452
relaxed_timeout: Optional[Number] = 10,
453453
endpoint_type: Optional[EndpointType] = None,
@@ -456,8 +456,13 @@ def __init__(
456456
Initialize a new MaintNotificationsConfig.
457457
458458
Args:
459-
enabled (bool): Whether to enable maintenance notifications handling.
460-
Defaults to False.
459+
enabled (bool | "auto"): Controls maintenance notifications handling behavior.
460+
- True: The CLIENT MAINT_NOTIFICATIONS command must succeed during connection setup,
461+
otherwise a ResponseError is raised.
462+
- "auto": The CLIENT MAINT_NOTIFICATIONS command is attempted but failures are
463+
gracefully handled - a warning is logged and normal operation continues.
464+
- False: Maintenance notifications are completely disabled.
465+
Defaults to "auto".
461466
proactive_reconnect (bool): Whether to proactively reconnect when a node is replaced.
462467
Defaults to True.
463468
relaxed_timeout (Number): The relaxed timeout to use for the connection during maintenance.

tests/test_maint_notifications.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ class TestMaintNotificationsConfig:
387387
def test_init_defaults(self):
388388
"""Test MaintNotificationsConfig initialization with defaults."""
389389
config = MaintNotificationsConfig()
390-
assert config.enabled is True
390+
assert config.enabled == "auto"
391391
assert config.proactive_reconnect is True
392392
assert config.relaxed_timeout == 10
393393

tests/test_maint_notifications_handling.py

Lines changed: 79 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
BlockingConnectionPool,
1414
MaintenanceState,
1515
)
16+
from redis.exceptions import ResponseError
1617
from redis.maint_notifications import (
18+
EndpointType,
1719
MaintNotificationsConfig,
1820
NodeMigratingNotification,
1921
NodeMigratedNotification,
@@ -201,6 +203,10 @@ def send(self, data):
201203
if b"HELLO" in data:
202204
response = b"%7\r\n$6\r\nserver\r\n$5\r\nredis\r\n$7\r\nversion\r\n$5\r\n7.0.0\r\n$5\r\nproto\r\n:3\r\n$2\r\nid\r\n:1\r\n$4\r\nmode\r\n$10\r\nstandalone\r\n$4\r\nrole\r\n$6\r\nmaster\r\n$7\r\nmodules\r\n*0\r\n"
203205
self.pending_responses.append(response)
206+
elif b"MAINT_NOTIFICATIONS" in data and b"internal-ip" in data:
207+
# Simulate error response - activate it only for internal-ip tests
208+
response = b"+ERROR\r\n"
209+
self.pending_responses.append(response)
204210
elif b"SET" in data:
205211
response = b"+OK\r\n"
206212

@@ -337,8 +343,8 @@ def shutdown(self, how):
337343
pass
338344

339345

340-
class TestMaintenanceNotificationsHandlingSingleProxy:
341-
"""Integration tests for maintenance notifications handling with real connection pool."""
346+
class TestMaintenanceNotificationsBase:
347+
"""Base class for maintenance notifications handling tests."""
342348

343349
def setup_method(self):
344350
"""Set up test fixtures with mocked sockets."""
@@ -393,7 +399,7 @@ def _get_client(
393399
pool_class: The connection pool class (ConnectionPool or BlockingConnectionPool)
394400
max_connections: Maximum number of connections in the pool (default: 10)
395401
maint_notifications_config: Optional MaintNotificationsConfig to use. If not provided,
396-
uses self.config from setup_method (default: None)
402+
uses self.config from setup_method (default: None)
397403
setup_pool_handler: Whether to set up pool handler for moving notifications (default: False)
398404
399405
Returns:
@@ -425,6 +431,71 @@ def _get_client(
425431

426432
return test_redis_client
427433

434+
435+
class TestMaintenanceNotificationsHandshake(TestMaintenanceNotificationsBase):
436+
"""Integration tests for maintenance notifications handling with real connection pool."""
437+
438+
def test_handshake_success_when_enabled(self):
439+
"""Test that handshake is performed correctly."""
440+
maint_notifications_config = MaintNotificationsConfig(
441+
enabled=True, endpoint_type=EndpointType.EXTERNAL_IP
442+
)
443+
test_redis_client = self._get_client(
444+
ConnectionPool, maint_notifications_config=maint_notifications_config
445+
)
446+
447+
try:
448+
# Perform Redis operations that should work with our improved mock responses
449+
result_set = test_redis_client.set("hello", "world")
450+
result_get = test_redis_client.get("hello")
451+
452+
# Verify operations completed successfully
453+
assert result_set is True
454+
assert result_get == b"world"
455+
456+
finally:
457+
test_redis_client.close()
458+
459+
def test_handshake_success_when_auto_and_command_not_supported(self):
460+
"""Test that when maintenance notifications are set to 'auto', the client gracefully handles unsupported MAINT_NOTIFICATIONS commands and normal Redis operations succeed."""
461+
maint_notifications_config = MaintNotificationsConfig(
462+
enabled="auto", endpoint_type=EndpointType.INTERNAL_IP
463+
)
464+
test_redis_client = self._get_client(
465+
ConnectionPool, maint_notifications_config=maint_notifications_config
466+
)
467+
468+
try:
469+
# Perform Redis operations that should work with our improved mock responses
470+
result_set = test_redis_client.set("hello", "world")
471+
result_get = test_redis_client.get("hello")
472+
473+
# Verify operations completed successfully
474+
assert result_set is True
475+
assert result_get == b"world"
476+
477+
finally:
478+
test_redis_client.close()
479+
480+
def test_handshake_failure_when_enabled(self):
481+
"""Test that handshake is performed correctly."""
482+
maint_notifications_config = MaintNotificationsConfig(
483+
enabled=True, endpoint_type=EndpointType.INTERNAL_IP
484+
)
485+
test_redis_client = self._get_client(
486+
ConnectionPool, maint_notifications_config=maint_notifications_config
487+
)
488+
try:
489+
with pytest.raises(ResponseError):
490+
test_redis_client.set("hello", "world")
491+
492+
finally:
493+
test_redis_client.close()
494+
495+
496+
class TestMaintenanceNotificationsHandlingSingleProxy(TestMaintenanceNotificationsBase):
497+
"""Integration tests for maintenance notifications handling with real connection pool."""
498+
428499
def _validate_connection_handlers(self, conn, pool_handler, config):
429500
"""Helper method to validate connection handlers are properly set."""
430501
# Test that the node moving handler function is correctly set
@@ -1891,40 +1962,16 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class):
18911962
pool.disconnect()
18921963

18931964

1894-
class TestMaintenanceNotificationsHandlingMultipleProxies:
1965+
class TestMaintenanceNotificationsHandlingMultipleProxies(
1966+
TestMaintenanceNotificationsBase
1967+
):
18951968
"""Integration tests for maintenance notifications handling with real connection pool."""
18961969

18971970
def setup_method(self):
18981971
"""Set up test fixtures with mocked sockets."""
1899-
self.mock_sockets = []
1900-
self.original_socket = socket.socket
1972+
super().setup_method()
19011973
self.orig_host = "test.address.com"
19021974

1903-
# Mock socket creation to return our mock sockets
1904-
def mock_socket_factory(*args, **kwargs):
1905-
mock_sock = MockSocket()
1906-
self.mock_sockets.append(mock_sock)
1907-
return mock_sock
1908-
1909-
self.socket_patcher = patch("socket.socket", side_effect=mock_socket_factory)
1910-
self.socket_patcher.start()
1911-
1912-
# Mock select.select to simulate data availability for reading
1913-
def mock_select(rlist, wlist, xlist, timeout=0):
1914-
# Check if any of the sockets in rlist have data available
1915-
ready_sockets = []
1916-
for sock in rlist:
1917-
if hasattr(sock, "connected") and sock.connected and not sock.closed:
1918-
# Only return socket as ready if it actually has data to read
1919-
if hasattr(sock, "pending_responses") and sock.pending_responses:
1920-
ready_sockets.append(sock)
1921-
# Don't return socket as ready just because it received commands
1922-
# Only when there are actual responses available
1923-
return (ready_sockets, [], [])
1924-
1925-
self.select_patcher = patch("select.select", side_effect=mock_select)
1926-
self.select_patcher.start()
1927-
19281975
ips = ["1.2.3.4", "5.6.7.8", "9.10.11.12"]
19291976
ips = ips * 3
19301977

@@ -1952,15 +1999,9 @@ def mock_socket_getaddrinfo(host, port, family=0, type=0, proto=0, flags=0):
19521999
)
19532000
self.getaddrinfo_patcher.start()
19542001

1955-
# Create maintenance notifications config
1956-
self.config = MaintNotificationsConfig(
1957-
enabled=True, proactive_reconnect=True, relaxed_timeout=30
1958-
)
1959-
19602002
def teardown_method(self):
19612003
"""Clean up test fixtures."""
1962-
self.socket_patcher.stop()
1963-
self.select_patcher.stop()
2004+
super().teardown_method()
19642005
self.getaddrinfo_patcher.stop()
19652006

19662007
@pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool])

0 commit comments

Comments
 (0)