diff --git a/kazoo/aio/__init__.py b/kazoo/aio/__init__.py new file mode 100644 index 00000000..6af289a9 --- /dev/null +++ b/kazoo/aio/__init__.py @@ -0,0 +1,3 @@ +""" +Simple asyncio integration of the threaded async executor engine. +""" diff --git a/kazoo/aio/client.py b/kazoo/aio/client.py new file mode 100644 index 00000000..1a1e3b3e --- /dev/null +++ b/kazoo/aio/client.py @@ -0,0 +1,92 @@ +import asyncio + +from kazoo.aio.handler import AioSequentialThreadingHandler +from kazoo.client import KazooClient, TransactionRequest + + +class AioKazooClient(KazooClient): + """ + The asyncio compatibility mostly mimics the behaviour of the base async + one. All calls are wrapped in asyncio.shield() to prevent cancellation + that is not supported in the base async implementation. + + The sync and base-async API are still completely functional. Mixing the + use of any of the 3 should be okay. + """ + + def __init__(self, *args, **kwargs): + if not kwargs.get("handler"): + kwargs["handler"] = AioSequentialThreadingHandler() + KazooClient.__init__(self, *args, **kwargs) + + # asyncio compatible api wrappers + async def start_aio(self, timeout=15): + """ + There is no protection for calling this multiple times in parallel. + The start_async() seems to lack that as well. Maybe it is allowed and + handled internally. + """ + await self.handler.loop.run_in_executor(None, self.start, timeout) + + async def add_auth_aio(self, *args, **kwargs): + return await asyncio.shield( + self.add_auth_async(*args, **kwargs).future + ) + + async def sync_aio(self, *args, **kwargs): + return await asyncio.shield(self.sync_async(*args, **kwargs).future) + + async def create_aio(self, *args, **kwargs): + return await asyncio.shield(self.create_async(*args, **kwargs).future) + + async def ensure_path_aio(self, *args, **kwargs): + return await asyncio.shield( + self.ensure_path_async(*args, **kwargs).future + ) + + async def exists_aio(self, *args, **kwargs): + return await asyncio.shield(self.exists_async(*args, **kwargs).future) + + async def get_aio(self, *args, **kwargs): + return await asyncio.shield(self.get_async(*args, **kwargs).future) + + async def get_children_aio(self, *args, **kwargs): + return await asyncio.shield( + self.get_children_async(*args, **kwargs).future + ) + + async def get_acls_aio(self, *args, **kwargs): + return await asyncio.shield( + self.get_acls_async(*args, **kwargs).future + ) + + async def set_acls_aio(self, *args, **kwargs): + return await asyncio.shield( + self.set_acls_async(*args, **kwargs).future + ) + + async def set_aio(self, *args, **kwargs): + return await asyncio.shield(self.set_async(*args, **kwargs).future) + + def transaction_aio(self): + return AioTransactionRequest(self) + + async def delete_aio(self, *args, **kwargs): + return await asyncio.shield(self.delete_async(*args, **kwargs).future) + + async def reconfig_aio(self, *args, **kwargs): + return await asyncio.shield( + self.reconfig_async(*args, **kwargs).future + ) + + +class AioTransactionRequest(TransactionRequest): + async def commit_aio(self): + return await asyncio.shield(self.commit_async().future) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, exc_tb): + if not exc_type: + await self.commit_aio() diff --git a/kazoo/aio/handler.py b/kazoo/aio/handler.py new file mode 100644 index 00000000..333f5cb3 --- /dev/null +++ b/kazoo/aio/handler.py @@ -0,0 +1,60 @@ +import asyncio +import threading + +from kazoo.handlers.threading import AsyncResult, SequentialThreadingHandler + + +class AioAsyncResult(AsyncResult): + def __init__(self, handler): + self.future = handler.loop.create_future() + AsyncResult.__init__(self, handler) + + def set(self, value=None): + """ + The completion of the future has the same guarantees as the + notification emitting of the condition. + Provided that no callbacks raise it will complete. + """ + AsyncResult.set(self, value) + self._handler.loop.call_soon_threadsafe(self.future.set_result, value) + + def set_exception(self, exception): + """ + The completion of the future has the same guarantees as the + notification emitting of the condition. + Provided that no callbacks raise it will complete. + """ + AsyncResult.set_exception(self, exception) + self._handler.loop.call_soon_threadsafe( + self.future.set_exception, exception + ) + + +class AioSequentialThreadingHandler(SequentialThreadingHandler): + def __init__(self): + """ + Creating the handler must be done on the asyncio-loop's thread. + """ + self.loop = asyncio.get_running_loop() + self._aio_thread = threading.current_thread() + SequentialThreadingHandler.__init__(self) + + def async_result(self, api=False): + """ + Almost all async-result objects are created by a method that is + invoked from the user's thead. The one exception I'm aware of is + in the PatientChildrenWatch utility, that creates an async-result + in its worker thread. Just because of that it is imperative to + only create asyncio compatible results when the invoking code is + from the loop's thread. There is no PEP/API guarantee that + implementing the create_future() has to be thread-safe. The default + is mostly thread-safe. The only thing that may get synchronization + issue is a debug-feature for asyncio development. Quickly looking at + the alternate implementation of uvloop, they use the default Future + implementation, so no change there. + For now, just to be safe, we check the current thread and create an + async-result object based on the invoking thread's identity. + """ + if api and threading.current_thread() is self._aio_thread: + return AioAsyncResult(self) + return AsyncResult(self) diff --git a/kazoo/aio/retry.py b/kazoo/aio/retry.py new file mode 100644 index 00000000..3af5c8b9 --- /dev/null +++ b/kazoo/aio/retry.py @@ -0,0 +1,91 @@ +import asyncio +import random +import time +from functools import partial + +from kazoo.exceptions import ( + ConnectionClosedError, + ConnectionLoss, + OperationTimeoutError, + SessionExpiredError, +) +from kazoo.retry import ForceRetryError, RetryFailedError + + +EXCEPTIONS = ( + ConnectionLoss, + OperationTimeoutError, + ForceRetryError, +) + +EXCEPTIONS_WITH_EXPIRED = EXCEPTIONS + (SessionExpiredError,) + + +def kazoo_retry_aio( + max_tries=1, + delay=0.1, + backoff=2, + max_jitter=0.4, + max_delay=60.0, + ignore_expire=True, + deadline=None, +): + """ + This is similar to KazooRetry, but they do not have compatible + interfaces. The threaded and asyncio constructs are too different + to easily wrap the KazooRetry implementation. Unless, all retries + always get their own thread to work in. This is much more lightweight + compared to the object-copying and resetting implementation. + + There is no equivalent analogue to the interrupt API. + If interrupting the retry is necessary, it must be wrapped in + an asyncio.Task, which can be cancelled. Be aware though that + this will quit waiting on the Zookeeper API call immediately + unlike the threaded API. There is no way to interrupt/cancel an + internal request thread so it will continue and stop eventually + on its own. This means caller can't know if the call is still + in progress and may succeed or the retry was cancelled while it + was waiting for delay. + + Usage example. These are equivalent except that the latter lines + will retry the requests on specific exceptions: + await zk.create_aio("/x") + await zk.create_aio("/x/y") + + aio_retry = kazoo_retry_aio() + await aio_retry(zk.create_aio, "/x") + await aio_retry(zk.create_aio, "/x/y") + """ + retry_exceptions = ( + EXCEPTIONS_WITH_EXPIRED if ignore_expire else EXCEPTIONS + ) + max_jitter = max(min(max_jitter, 1.0), 0.0) + get_jitter = partial(random.uniform, 1.0 - max_jitter, 1.0 + max_jitter) + del max_jitter + + async def _retry(func, *args, **kwargs): + attempts = 0 + cur_delay = delay + stop_time = ( + None if deadline is None else time.perf_counter() + deadline + ) + while True: + try: + return await func(*args, **kwargs) + except ConnectionClosedError: + raise + except retry_exceptions: + # Note: max_tries == -1 means infinite tries. + if attempts == max_tries: + raise RetryFailedError("Too many retry attempts") + attempts += 1 + sleep_time = cur_delay * get_jitter() + if ( + stop_time is not None + and time.perf_counter() + sleep_time >= stop_time + ): + raise RetryFailedError("Exceeded retry deadline") + await asyncio.sleep(sleep_time) + cur_delay = min(sleep_time * backoff, max_delay) + + return _retry diff --git a/kazoo/client.py b/kazoo/client.py index 25baa683..c6681069 100644 --- a/kazoo/client.py +++ b/kazoo/client.py @@ -818,7 +818,7 @@ def add_auth_async(self, scheme, credential): # we need this auth data to re-authenticate on reconnect self.auth_data.add((scheme, credential)) - async_result = self.handler.async_result() + async_result = self.handler.async_result(api=True) self._call(Auth(0, scheme, credential), async_result) return async_result @@ -839,7 +839,7 @@ def sync_async(self, path): :rtype: :class:`~kazoo.interfaces.IAsyncResult` """ - async_result = self.handler.async_result() + async_result = self.handler.async_result(api=True) @wrap(async_result) def _sync_completion(result): @@ -997,7 +997,7 @@ def create_async(self, path, value=b"", acl=None, ephemeral=False, if acl is None: acl = OPEN_ACL_UNSAFE - async_result = self.handler.async_result() + async_result = self.handler.async_result(api=True) @capture_exceptions(async_result) def do_create(): @@ -1071,7 +1071,7 @@ def ensure_path_async(self, path, acl=None): """ acl = acl or self.default_acl - async_result = self.handler.async_result() + async_result = self.handler.async_result(api=True) @wrap(async_result) def create_completion(result): @@ -1134,7 +1134,7 @@ def exists_async(self, path, watch=None): if watch and not callable(watch): raise TypeError("Invalid type for 'watch' (must be a callable)") - async_result = self.handler.async_result() + async_result = self.handler.async_result(api=True) self._call(Exists(_prefix_root(self.chroot, path), watch), async_result) return async_result @@ -1176,7 +1176,7 @@ def get_async(self, path, watch=None): if watch and not callable(watch): raise TypeError("Invalid type for 'watch' (must be a callable)") - async_result = self.handler.async_result() + async_result = self.handler.async_result(api=True) self._call(GetData(_prefix_root(self.chroot, path), watch), async_result) return async_result @@ -1232,7 +1232,7 @@ def get_children_async(self, path, watch=None, include_data=False): if not isinstance(include_data, bool): raise TypeError("Invalid type for 'include_data' (bool expected)") - async_result = self.handler.async_result() + async_result = self.handler.async_result(api=True) if include_data: req = GetChildren2(_prefix_root(self.chroot, path), watch) else: @@ -1270,7 +1270,7 @@ def get_acls_async(self, path): if not isinstance(path, string_types): raise TypeError("Invalid type for 'path' (string expected)") - async_result = self.handler.async_result() + async_result = self.handler.async_result(api=True) self._call(GetACL(_prefix_root(self.chroot, path)), async_result) return async_result @@ -1318,7 +1318,7 @@ def set_acls_async(self, path, acls, version=-1): if not isinstance(version, int): raise TypeError("Invalid type for 'version' (int expected)") - async_result = self.handler.async_result() + async_result = self.handler.async_result(api=True) self._call(SetACL(_prefix_root(self.chroot, path), acls, version), async_result) return async_result @@ -1372,7 +1372,7 @@ def set_async(self, path, value, version=-1): if not isinstance(version, int): raise TypeError("Invalid type for 'version' (int expected)") - async_result = self.handler.async_result() + async_result = self.handler.async_result(api=True) self._call(SetData(_prefix_root(self.chroot, path), value, version), async_result) return async_result @@ -1443,7 +1443,7 @@ def delete_async(self, path, version=-1): raise TypeError("Invalid type for 'path' (string expected)") if not isinstance(version, int): raise TypeError("Invalid type for 'version' (int expected)") - async_result = self.handler.async_result() + async_result = self.handler.async_result(api=True) self._call(Delete(_prefix_root(self.chroot, path), version), async_result) return async_result @@ -1556,7 +1556,7 @@ def reconfig_async(self, joining, leaving, new_members, from_config): if not isinstance(from_config, int): raise TypeError("Invalid type for 'from_config' (int expected)") - async_result = self.handler.async_result() + async_result = self.handler.async_result(api=True) reconfig = Reconfig(joining, leaving, new_members, from_config) self._call(reconfig, async_result) @@ -1672,7 +1672,7 @@ def commit_async(self): """ self._check_tx_state() self.committed = True - async_object = self.client.handler.async_result() + async_object = self.client.handler.async_result(api=True) self.client._call(Transaction(self.operations), async_object) return async_object diff --git a/kazoo/handlers/threading.py b/kazoo/handlers/threading.py index 21925237..7797f881 100644 --- a/kazoo/handlers/threading.py +++ b/kazoo/handlers/threading.py @@ -271,8 +271,10 @@ def rlock_object(self): """Create an appropriate RLock object""" return threading.RLock() - def async_result(self): - """Create a :class:`AsyncResult` instance""" + def async_result(self, api=False): + """Create a :class:`AsyncResult` instance. The api flag will + indicate if this object will be used by a user code or an + internal one. It is necessary for asyncio support.""" return AsyncResult(self) def spawn(self, func, *args, **kwargs): diff --git a/kazoo/testing/__init__.py b/kazoo/testing/__init__.py index c1ae12cc..7dcd51ef 100644 --- a/kazoo/testing/__init__.py +++ b/kazoo/testing/__init__.py @@ -1,4 +1,12 @@ -from kazoo.testing.harness import KazooTestCase, KazooTestHarness +from kazoo.testing.harness import ( + KazooAioTestCase, + KazooTestCase, + KazooTestHarness, +) -__all__ = ('KazooTestHarness', 'KazooTestCase', ) +__all__ = ( + "KazooTestHarness", + "KazooTestCase", + "KazooAioTestCase", +) diff --git a/kazoo/testing/harness.py b/kazoo/testing/harness.py index ce8748aa..bc789641 100644 --- a/kazoo/testing/harness.py +++ b/kazoo/testing/harness.py @@ -1,10 +1,12 @@ """Kazoo testing harnesses""" +import asyncio import logging import os import uuid import unittest from kazoo import python2atexit as atexit +from kazoo.aio.client import AioKazooClient from kazoo.client import KazooClient from kazoo.exceptions import KazooException from kazoo.protocol.connection import _CONNECTION_DROP, _SESSION_EXPIRED @@ -144,6 +146,7 @@ def test_something_else(self): """ DEFAULT_CLIENT_TIMEOUT = 15 + CLIENT_CLS = KazooClient def __init__(self, *args, **kw): super(KazooTestHarness, self).__init__(*args, **kw) @@ -159,14 +162,14 @@ def servers(self): return ",".join([s.address for s in self.cluster]) def _get_nonchroot_client(self): - c = KazooClient(self.servers) + c = self.CLIENT_CLS(self.servers) self._clients.append(c) return c def _get_client(self, **client_options): if 'timeout' not in client_options: client_options['timeout'] = self.DEFAULT_CLIENT_TIMEOUT - c = KazooClient(self.hosts, **client_options) + c = self.CLIENT_CLS(self.hosts, **client_options) self._clients.append(c) return c @@ -245,3 +248,26 @@ def setUp(self): def tearDown(self): self.teardown_zookeeper() + + +class KazooAioTestCase(KazooTestHarness): + CLIENT_CLS = AioKazooClient + + def __init__(self, *args, **kw): + super(KazooAioTestCase, self).__init__(*args, **kw) + self.loop = None + + async def setup_zookeeper_aio(self): + # NOTE: could enhance this to call start_aio() on the client + self.setup_zookeeper() + + async def teardown_zookeeper_aio(self): + self.teardown_zookeeper() + + def setUp(self): + self.loop = asyncio.get_event_loop_policy().new_event_loop() + self.loop.run_until_complete(self.setup_zookeeper_aio()) + + def tearDown(self): + self.loop.run_until_complete(self.teardown_zookeeper_aio()) + self.loop.close() diff --git a/kazoo/tests/test_aio.py b/kazoo/tests/test_aio.py new file mode 100644 index 00000000..30602421 --- /dev/null +++ b/kazoo/tests/test_aio.py @@ -0,0 +1,105 @@ +import pytest + +from kazoo.aio.retry import kazoo_retry_aio +from kazoo.exceptions import NotEmptyError, NoNodeError +from kazoo.protocol.states import ZnodeStat +from kazoo.testing import KazooAioTestCase + + +class KazooAioTests(KazooAioTestCase): + def test_basic_aio_functionality(self): + self.loop.run_until_complete(self._test_basic_aio_functionality()) + + async def _test_basic_aio_functionality(self): + assert await self.client.create_aio("/tmp") == "/tmp" + assert await self.client.get_children_aio("/") == ["tmp"] + assert await self.client.ensure_path_aio("/tmp/x/y") == "/tmp/x/y" + assert await self.client.exists_aio("/tmp/x/y") + assert isinstance( + await self.client.set_aio("/tmp/x/y", b"very aio"), ZnodeStat + ) + data, stat = await self.client.get_aio("/tmp/x/y") + assert data == b"very aio" + assert isinstance(stat, ZnodeStat) + with pytest.raises(NotEmptyError): + await self.client.delete_aio("/tmp/x") + await self.client.delete_aio("/tmp/x/y") + with pytest.raises(NoNodeError): + await self.client.get_aio("/tmp/x/y") + async with self.client.transaction_aio() as tx: + tx.create("/tmp/z", b"ZZZ") + tx.set_data("/tmp/x", b"XXX") + assert (await self.client.get_aio("/tmp/x"))[0] == b"XXX" + assert (await self.client.get_aio("/tmp/z"))[0] == b"ZZZ" + self.client.stop() + assert self.client.connected is False + await self.client.start_aio() + assert self.client.connected is True + assert set(await self.client.get_children_aio("/tmp")) == set( + ["x", "z"] + ) + + def test_aio_retry_functionality(self): + self.loop.run_until_complete(self._test_aio_retry_functionality()) + + async def _test_aio_retry_functionality(self): + # Just lump them all in here for now, they are short enough that + # it does not matter much. + await self._test_aio_retry() + await self._test_too_many_tries() + await self._test_connection_closed() + await self._test_session_expired() + + async def _pass(self): + pass + + def _fail(self, times=1, scope=None): + from kazoo.retry import ForceRetryError + + if not scope: + scope = dict(times=0) + + async def inner(): + if scope["times"] >= times: + return scope + else: + scope["times"] += 1 + raise ForceRetryError("Failed!") + + return inner + + async def _test_aio_retry(self): + aio_retry = kazoo_retry_aio(delay=0, max_tries=2) + assert await aio_retry(self._fail()) == {"times": 1} + assert await aio_retry(self._fail()) == {"times": 1} + + async def _test_too_many_tries(self): + from kazoo.retry import RetryFailedError + + aio_retry = kazoo_retry_aio(delay=0, max_tries=3) + scope = dict(times=0) + with pytest.raises(RetryFailedError): + await aio_retry(self._fail(times=999, scope=scope)) + assert scope == {"times": 4} + + async def _test_connection_closed(self): + from kazoo.exceptions import ConnectionClosedError + + aio_retry = kazoo_retry_aio() + + async def testit(): + raise ConnectionClosedError() + + with pytest.raises(ConnectionClosedError): + await aio_retry(testit) + + async def _test_session_expired(self): + from kazoo.exceptions import SessionExpiredError + + aio_retry = kazoo_retry_aio() + + async def testit(): + raise SessionExpiredError() + + with pytest.raises(Exception): + await aio_retry(testit)