Skip to content

Commit dcb43be

Browse files
authored
Fix Redis connections after reconnect - consumer starts consuming the tasks after crash. (#2007)
* Add more logs * Launch _on_connection_disconnect in Conection only if channel was added properly to the poller * Prepare test which check the flow of the channel removal from poller * Change the comment
1 parent 1217865 commit dcb43be

File tree

2 files changed

+164
-2
lines changed

2 files changed

+164
-2
lines changed

kombu/transport/redis.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -722,7 +722,7 @@ def __init__(self, *args, **kwargs):
722722

723723
if not self.ack_emulation: # disable visibility timeout
724724
self.QoS = virtual.QoS
725-
725+
self._registered = False
726726
self._queue_cycle = cycle_by_name(self.queue_order_strategy)()
727727
self.Client = self._get_client()
728728
self.ResponseError = self._get_response_error()
@@ -747,6 +747,9 @@ def __init__(self, *args, **kwargs):
747747
raise
748748

749749
self.connection.cycle.add(self) # add to channel poller.
750+
# and set to true after sucessfuly added channel to the poll.
751+
self._registered = True
752+
750753
# copy errors, in case channel closed but threads still
751754
# are still waiting for data.
752755
self.connection_errors = self.connection.connection_errors
@@ -1201,7 +1204,10 @@ def _connparams(self, asynchronous=False):
12011204
class Connection(connection_cls):
12021205
def disconnect(self, *args):
12031206
super().disconnect(*args)
1204-
channel._on_connection_disconnect(self)
1207+
# We remove the connection from the poller
1208+
# only if it has been added properly.
1209+
if channel._registered:
1210+
channel._on_connection_disconnect(self)
12051211
connection_cls = Connection
12061212

12071213
connparams['connection_class'] = connection_cls

t/unit/transport/test_redis.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,17 +346,173 @@ class XTransport(Transport):
346346
Channel = XChannel
347347

348348
conn = Connection(transport=XTransport)
349+
conn.transport.cycle = Mock(name='cycle')
349350
client.ping.side_effect = RuntimeError()
350351
with pytest.raises(RuntimeError):
351352
conn.channel()
352353
pool.disconnect.assert_called_with()
353354
pool.disconnect.reset_mock()
355+
# Ensure that the channel without ensured connection to Redis
356+
# won't be added to the cycle.
357+
conn.transport.cycle.add.assert_not_called()
358+
assert len(conn.transport.channels) == 0
354359

355360
pool_at_init = [None]
356361
with pytest.raises(RuntimeError):
357362
conn.channel()
358363
pool.disconnect.assert_not_called()
359364

365+
def test_redis_connection_added_to_cycle_if_ping_succeeds(self):
366+
"""Test should check the connection is added to the cycle only
367+
if the ping to Redis was finished successfully."""
368+
# given: mock pool and client
369+
pool = Mock(name='pool')
370+
client = Mock(name='client')
371+
372+
# override channel class with given mocks
373+
class XChannel(Channel):
374+
def __init__(self, *args, **kwargs):
375+
self._pool = pool
376+
super().__init__(*args, **kwargs)
377+
378+
def _get_client(self):
379+
return lambda *_, **__: client
380+
381+
# override Channel in Transport with given channel
382+
class XTransport(Transport):
383+
Channel = XChannel
384+
385+
# when: create connection with overridden transport
386+
conn = Connection(transport=XTransport)
387+
conn.transport.cycle = Mock(name='cycle')
388+
# create the channel
389+
chan = conn.channel()
390+
# then: check if ping was called
391+
client.ping.assert_called_once()
392+
# the connection was added to the cycle
393+
conn.transport.cycle.add.assert_called_once()
394+
assert len(conn.transport.channels) == 1
395+
# the channel was flaged as registered into poller
396+
assert chan._registered
397+
398+
def test_redis_on_disconnect_channel_only_if_was_registered(self):
399+
"""Test shoud check if the _on_disconnect method is called only
400+
if the channel was registered into the poller."""
401+
# given: mock pool and client
402+
pool = Mock(name='pool')
403+
client = Mock(
404+
name='client',
405+
ping=Mock(return_value=True)
406+
)
407+
408+
# create RedisConnectionMock class
409+
# for the possibility to run disconnect method
410+
class RedisConnectionMock:
411+
def disconnect(self, *args):
412+
pass
413+
414+
# override Channel method with given mocks
415+
class XChannel(Channel):
416+
connection_class = RedisConnectionMock
417+
418+
def __init__(self, *args, **kwargs):
419+
self._pool = pool
420+
# counter to check if the method was called
421+
self.on_disconect_count = 0
422+
super().__init__(*args, **kwargs)
423+
424+
def _get_client(self):
425+
return lambda *_, **__: client
426+
427+
def _on_connection_disconnect(self, connection):
428+
# increment the counter when the method is called
429+
self.on_disconect_count += 1
430+
431+
# create the channel
432+
chan = XChannel(Mock(
433+
_used_channel_ids=[],
434+
channel_max=1,
435+
channels=[],
436+
client=Mock(
437+
transport_options={},
438+
hostname="127.0.0.1",
439+
virtual_host=None)))
440+
# create the _connparams with overriden connection_class
441+
connparams = chan._connparams(asynchronous=True)
442+
# create redis.Connection
443+
redis_connection = connparams['connection_class']()
444+
# the connection was added to the cycle
445+
chan.connection.cycle.add.assert_called_once()
446+
# and the ping was called
447+
client.ping.assert_called_once()
448+
# the channel was registered
449+
assert chan._registered
450+
# than disconnect the Redis connection
451+
redis_connection.disconnect()
452+
# the on_disconnect counter should be incremented
453+
assert chan.on_disconect_count == 1
454+
455+
def test_redis__on_disconnect_should_not_be_called_if_not_registered(self):
456+
"""Test should check if the _on_disconnect method is not called because
457+
the connection to Redis isn't established properly."""
458+
# given: mock pool
459+
pool = Mock(name='pool')
460+
# client mock with ping method which return ConnectionError
461+
from redis.exceptions import ConnectionError
462+
client = Mock(
463+
name='client',
464+
ping=Mock(side_effect=ConnectionError())
465+
)
466+
467+
# create RedisConnectionMock
468+
# for the possibility to run disconnect method
469+
class RedisConnectionMock:
470+
def disconnect(self, *args):
471+
pass
472+
473+
# override Channel method with given mocks
474+
class XChannel(Channel):
475+
connection_class = RedisConnectionMock
476+
477+
def __init__(self, *args, **kwargs):
478+
self._pool = pool
479+
# counter to check if the method was called
480+
self.on_disconect_count = 0
481+
super().__init__(*args, **kwargs)
482+
483+
def _get_client(self):
484+
return lambda *_, **__: client
485+
486+
def _on_connection_disconnect(self, connection):
487+
# increment the counter when the method is called
488+
self.on_disconect_count += 1
489+
490+
# then: exception was risen
491+
with pytest.raises(ConnectionError):
492+
# when: create the channel
493+
chan = XChannel(Mock(
494+
_used_channel_ids=[],
495+
channel_max=1,
496+
channels=[],
497+
client=Mock(
498+
transport_options={},
499+
hostname="127.0.0.1",
500+
virtual_host=None)))
501+
# create the _connparams with overriden connection_class
502+
connparams = chan._connparams(asynchronous=True)
503+
# create redis.Connection
504+
redis_connection = connparams['connection_class']()
505+
# the connection wasn't added to the cycle
506+
chan.connection.cycle.add.assert_not_called()
507+
# the ping was called once with the exception
508+
client.ping.assert_called_once()
509+
# the channel was not registered
510+
assert not chan._registered
511+
# then: disconnect the Redis connection
512+
redis_connection.disconnect()
513+
# the on_disconnect counter shouldn't be incremented
514+
assert chan.on_disconect_count == 0
515+
360516
def test_get_redis_ConnectionError(self):
361517
from redis.exceptions import ConnectionError
362518

0 commit comments

Comments
 (0)