@@ -346,17 +346,173 @@ class XTransport(Transport):
346
346
Channel = XChannel
347
347
348
348
conn = Connection (transport = XTransport )
349
+ conn .transport .cycle = Mock (name = 'cycle' )
349
350
client .ping .side_effect = RuntimeError ()
350
351
with pytest .raises (RuntimeError ):
351
352
conn .channel ()
352
353
pool .disconnect .assert_called_with ()
353
354
pool .disconnect .reset_mock ()
355
+ # Ensure that the channel without ensured connection to Redis
356
+ # won't be added to the cycle.
357
+ conn .transport .cycle .add .assert_not_called ()
358
+ assert len (conn .transport .channels ) == 0
354
359
355
360
pool_at_init = [None ]
356
361
with pytest .raises (RuntimeError ):
357
362
conn .channel ()
358
363
pool .disconnect .assert_not_called ()
359
364
365
+ def test_redis_connection_added_to_cycle_if_ping_succeeds (self ):
366
+ """Test should check the connection is added to the cycle only
367
+ if the ping to Redis was finished successfully."""
368
+ # given: mock pool and client
369
+ pool = Mock (name = 'pool' )
370
+ client = Mock (name = 'client' )
371
+
372
+ # override channel class with given mocks
373
+ class XChannel (Channel ):
374
+ def __init__ (self , * args , ** kwargs ):
375
+ self ._pool = pool
376
+ super ().__init__ (* args , ** kwargs )
377
+
378
+ def _get_client (self ):
379
+ return lambda * _ , ** __ : client
380
+
381
+ # override Channel in Transport with given channel
382
+ class XTransport (Transport ):
383
+ Channel = XChannel
384
+
385
+ # when: create connection with overridden transport
386
+ conn = Connection (transport = XTransport )
387
+ conn .transport .cycle = Mock (name = 'cycle' )
388
+ # create the channel
389
+ chan = conn .channel ()
390
+ # then: check if ping was called
391
+ client .ping .assert_called_once ()
392
+ # the connection was added to the cycle
393
+ conn .transport .cycle .add .assert_called_once ()
394
+ assert len (conn .transport .channels ) == 1
395
+ # the channel was flaged as registered into poller
396
+ assert chan ._registered
397
+
398
+ def test_redis_on_disconnect_channel_only_if_was_registered (self ):
399
+ """Test shoud check if the _on_disconnect method is called only
400
+ if the channel was registered into the poller."""
401
+ # given: mock pool and client
402
+ pool = Mock (name = 'pool' )
403
+ client = Mock (
404
+ name = 'client' ,
405
+ ping = Mock (return_value = True )
406
+ )
407
+
408
+ # create RedisConnectionMock class
409
+ # for the possibility to run disconnect method
410
+ class RedisConnectionMock :
411
+ def disconnect (self , * args ):
412
+ pass
413
+
414
+ # override Channel method with given mocks
415
+ class XChannel (Channel ):
416
+ connection_class = RedisConnectionMock
417
+
418
+ def __init__ (self , * args , ** kwargs ):
419
+ self ._pool = pool
420
+ # counter to check if the method was called
421
+ self .on_disconect_count = 0
422
+ super ().__init__ (* args , ** kwargs )
423
+
424
+ def _get_client (self ):
425
+ return lambda * _ , ** __ : client
426
+
427
+ def _on_connection_disconnect (self , connection ):
428
+ # increment the counter when the method is called
429
+ self .on_disconect_count += 1
430
+
431
+ # create the channel
432
+ chan = XChannel (Mock (
433
+ _used_channel_ids = [],
434
+ channel_max = 1 ,
435
+ channels = [],
436
+ client = Mock (
437
+ transport_options = {},
438
+ hostname = "127.0.0.1" ,
439
+ virtual_host = None )))
440
+ # create the _connparams with overriden connection_class
441
+ connparams = chan ._connparams (asynchronous = True )
442
+ # create redis.Connection
443
+ redis_connection = connparams ['connection_class' ]()
444
+ # the connection was added to the cycle
445
+ chan .connection .cycle .add .assert_called_once ()
446
+ # and the ping was called
447
+ client .ping .assert_called_once ()
448
+ # the channel was registered
449
+ assert chan ._registered
450
+ # than disconnect the Redis connection
451
+ redis_connection .disconnect ()
452
+ # the on_disconnect counter should be incremented
453
+ assert chan .on_disconect_count == 1
454
+
455
+ def test_redis__on_disconnect_should_not_be_called_if_not_registered (self ):
456
+ """Test should check if the _on_disconnect method is not called because
457
+ the connection to Redis isn't established properly."""
458
+ # given: mock pool
459
+ pool = Mock (name = 'pool' )
460
+ # client mock with ping method which return ConnectionError
461
+ from redis .exceptions import ConnectionError
462
+ client = Mock (
463
+ name = 'client' ,
464
+ ping = Mock (side_effect = ConnectionError ())
465
+ )
466
+
467
+ # create RedisConnectionMock
468
+ # for the possibility to run disconnect method
469
+ class RedisConnectionMock :
470
+ def disconnect (self , * args ):
471
+ pass
472
+
473
+ # override Channel method with given mocks
474
+ class XChannel (Channel ):
475
+ connection_class = RedisConnectionMock
476
+
477
+ def __init__ (self , * args , ** kwargs ):
478
+ self ._pool = pool
479
+ # counter to check if the method was called
480
+ self .on_disconect_count = 0
481
+ super ().__init__ (* args , ** kwargs )
482
+
483
+ def _get_client (self ):
484
+ return lambda * _ , ** __ : client
485
+
486
+ def _on_connection_disconnect (self , connection ):
487
+ # increment the counter when the method is called
488
+ self .on_disconect_count += 1
489
+
490
+ # then: exception was risen
491
+ with pytest .raises (ConnectionError ):
492
+ # when: create the channel
493
+ chan = XChannel (Mock (
494
+ _used_channel_ids = [],
495
+ channel_max = 1 ,
496
+ channels = [],
497
+ client = Mock (
498
+ transport_options = {},
499
+ hostname = "127.0.0.1" ,
500
+ virtual_host = None )))
501
+ # create the _connparams with overriden connection_class
502
+ connparams = chan ._connparams (asynchronous = True )
503
+ # create redis.Connection
504
+ redis_connection = connparams ['connection_class' ]()
505
+ # the connection wasn't added to the cycle
506
+ chan .connection .cycle .add .assert_not_called ()
507
+ # the ping was called once with the exception
508
+ client .ping .assert_called_once ()
509
+ # the channel was not registered
510
+ assert not chan ._registered
511
+ # then: disconnect the Redis connection
512
+ redis_connection .disconnect ()
513
+ # the on_disconnect counter shouldn't be incremented
514
+ assert chan .on_disconect_count == 0
515
+
360
516
def test_get_redis_ConnectionError (self ):
361
517
from redis .exceptions import ConnectionError
362
518
0 commit comments