diff --git a/memcache.py b/memcache.py index 3e6bd67..95591b6 100644 --- a/memcache.py +++ b/memcache.py @@ -1048,7 +1048,7 @@ def _unsafe_set(): except _ConnectionDeadError: # retry once try: - if server._get_socket(): + if server._get_socket(reconnect=True): return _unsafe_set() except (_ConnectionDeadError, socket.error) as msg: server.mark_dead(msg) @@ -1101,7 +1101,7 @@ def _unsafe_get(): except _ConnectionDeadError: # retry once try: - if server.connect(): + if server._get_socket(reconnect=True): return _unsafe_get() return None except (_ConnectionDeadError, socket.error) as msg: @@ -1386,10 +1386,10 @@ def mark_dead(self, reason): self.flush_on_next_connect = 1 self.close_socket() - def _get_socket(self): + def _get_socket(self, reconnect=False): if self._check_dead(): return None - if self.socket: + if self.socket and not reconnect: return self.socket s = socket.socket(self.family, socket.SOCK_STREAM) if hasattr(s, 'settimeout'): diff --git a/tests/test_memcache.py b/tests/test_memcache.py index 0072813..43c715c 100644 --- a/tests/test_memcache.py +++ b/tests/test_memcache.py @@ -1,5 +1,6 @@ from __future__ import print_function +import socket import unittest import six @@ -167,5 +168,52 @@ def test_disconnect_all_delete_multi(self): self.assertEqual(ret, 1) +class TestMemcacheMarkDead(unittest.TestCase): + + def setUp(self): + self.status = locals() + self.address = ("127.0.0.1", 11213) + self._start_stub_server() + self.client = Client(["127.0.0.1:11213"], debug=1) + + def tearDown(self): + self._stop_stub_server() + + def _start_stub_server(self): + # setup stub server + stub_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + stub_socket.bind(self.address) + stub_socket.listen(1) + self.stub_socket = stub_socket + + def _stop_stub_server(self): + self.stub_socket.close() + + def test_mark_server_dead(self): + mc_host = self.client._get_server('foo'.encode('utf8'))[0] + client_socket = mc_host._get_socket() + + # make sure the server is not marked dead + self.assertEqual(0, mc_host._check_dead()) + + # stop the stub server + self._stop_stub_server() + + # host is not yet marked as dead + self.assertEqual(0, mc_host._check_dead()) + + # create a new stub socket again + self._start_stub_server() + + # The client will try to re-use the old socket and if it fails + # then should re-establish a new connection + # so the socket must be a new one + new_client_socket = mc_host._get_socket(reconnect=True) + self.assertNotEqual(new_client_socket, client_socket) + + # server is not marked dead, because connection succeeded + self.assertEqual(0, mc_host._check_dead()) + + if __name__ == '__main__': unittest.main()