From bc8720d31b8aa5fb285ebf2e8e0ddb598d608b38 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Fri, 22 Nov 2024 09:57:18 +0100 Subject: [PATCH 1/5] test fast yielding in background task --- tests/integration/test_background_task.py | 30 +++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/integration/test_background_task.py b/tests/integration/test_background_task.py index d7fe208247b..2cf4d6a30ec 100644 --- a/tests/integration/test_background_task.py +++ b/tests/integration/test_background_task.py @@ -42,6 +42,11 @@ async def handle_event_yield_only(self): yield State.increment() # type: ignore await asyncio.sleep(0.005) + @rx.event(background=True) + async def fast_yielding(self): + for _ in range(100): + yield State.increment() + @rx.event def increment(self): self.counter += 1 @@ -375,3 +380,28 @@ def test_yield_in_async_with_self( yield_in_async_with_self_button.click() assert background_task._poll_for(lambda: counter.text == "2", timeout=5) + + +def test_fast_yielding( + background_task: AppHarness, + driver: WebDriver, + token: str, +) -> None: + """Test that fast yielding works as expected. + + Args: + background_task: harness for BackgroundTask app. + driver: WebDriver instance. + token: The token for the connected client. + """ + assert background_task.app_instance is not None + + # get a reference to all buttons + fast_yielding_button = driver.find_element(By.ID, "yield-increment") + + # get a reference to the counter + counter = driver.find_element(By.ID, "counter") + assert background_task._poll_for(lambda: counter.text == "0", timeout=5) + + fast_yielding_button.click() + assert background_task._poll_for(lambda: counter.text == "100", timeout=5) From 0bf8ffefeef7b25e1fceb826c72df50ccc873208 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Sat, 23 Nov 2024 11:44:58 +0100 Subject: [PATCH 2/5] accidentally pushed unfinished changes --- tests/integration/test_background_task.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/integration/test_background_task.py b/tests/integration/test_background_task.py index 2cf4d6a30ec..70e2202a68d 100644 --- a/tests/integration/test_background_task.py +++ b/tests/integration/test_background_task.py @@ -44,7 +44,7 @@ async def handle_event_yield_only(self): @rx.event(background=True) async def fast_yielding(self): - for _ in range(100): + for _ in range(1000): yield State.increment() @rx.event @@ -174,6 +174,11 @@ def index() -> rx.Component: on_click=State.yield_in_async_with_self, id="yield-in-async-with-self", ), + rx.button( + "Fast Yielding", + on_click=State.fast_yielding, + id="fast-yielding", + ), rx.button("Reset", on_click=State.reset_counter, id="reset"), ) @@ -397,11 +402,11 @@ def test_fast_yielding( assert background_task.app_instance is not None # get a reference to all buttons - fast_yielding_button = driver.find_element(By.ID, "yield-increment") + fast_yielding_button = driver.find_element(By.ID, "fast-yielding") # get a reference to the counter counter = driver.find_element(By.ID, "counter") assert background_task._poll_for(lambda: counter.text == "0", timeout=5) fast_yielding_button.click() - assert background_task._poll_for(lambda: counter.text == "100", timeout=5) + assert background_task._poll_for(lambda: counter.text == "1000", timeout=50) From f3e393e621c34eb357982d3c5002d0694613be73 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Sat, 23 Nov 2024 20:37:39 +0100 Subject: [PATCH 3/5] fix: only open one connection/sub for each token per worker bonus: properly cleanup StateManager connections on disconnect --- reflex/app.py | 5 +++-- reflex/state.py | 42 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index fc8efb42016..758d57c33ae 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1479,7 +1479,7 @@ def __init__(self, namespace: str, app: App): self.sid_to_token = {} self.app = app - def on_connect(self, sid, environ): + async def on_connect(self, sid, environ): """Event for when the websocket is connected. Args: @@ -1488,7 +1488,7 @@ def on_connect(self, sid, environ): """ pass - def on_disconnect(self, sid): + async def on_disconnect(self, sid): """Event for when the websocket disconnects. Args: @@ -1497,6 +1497,7 @@ def on_disconnect(self, sid): disconnect_token = self.sid_to_token.pop(sid, None) if disconnect_token: self.token_to_sid.pop(disconnect_token, None) + await self.app.state_manager.disconnect(sid) async def emit_update(self, update: StateUpdate, sid: str) -> None: """Emit an update to the client. diff --git a/reflex/state.py b/reflex/state.py index 95f7f64f68c..35d840498f8 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2830,6 +2830,14 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]: """ yield self.state() + async def disconnect(self, token: str) -> None: + """Disconnect the client with the given token. + + Args: + token: The token to disconnect. + """ + pass + class StateManagerMemory(StateManager): """A state manager that stores states in memory.""" @@ -2899,6 +2907,20 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]: yield state await self.set_state(token, state) + @override + async def disconnect(self, token: str) -> None: + """Disconnect the client with the given token. + + Args: + token: The token to disconnect. + """ + if token in self.states: + del self.states[token] + if lock := self._states_locks.get(token): + if lock.locked(): + lock.release() + del self._states_locks[token] + def _default_token_expiration() -> int: """Get the default token expiration time. @@ -3187,6 +3209,9 @@ class StateManagerRedis(StateManager): b"evicted", } + # This lock is used to ensure we only subscribe to keyspace events once per token and worker + _pubsub_locks: Dict[bytes, asyncio.Lock] = pydantic.PrivateAttr({}) + async def _get_parent_state( self, token: str, state: BaseState | None = None ) -> BaseState | None: @@ -3462,7 +3487,9 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None: # Some redis servers only allow out-of-band configuration, so ignore errors here. if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get(): raise - async with self.redis.pubsub() as pubsub: + if lock_key not in self._pubsub_locks: + self._pubsub_locks[lock_key] = asyncio.Lock() + async with self._pubsub_locks[lock_key], self.redis.pubsub() as pubsub: await pubsub.psubscribe(lock_key_channel) while not state_is_locked: # wait for the lock to be released @@ -3479,6 +3506,19 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None: break state_is_locked = await self._try_get_lock(lock_key, lock_id) + @override + async def disconnect(self, token: str): + """Disconnect the token from the redis client. + + Args: + token: The token to disconnect. + """ + lock_key = self._lock_key(token) + if lock := self._pubsub_locks.get(lock_key): + if lock.locked(): + lock.release() + del self._pubsub_locks[lock_key] + @contextlib.asynccontextmanager async def _lock(self, token: str): """Obtain a redis lock for a token. From 3126dcaef618b2cbd1e272b35c4acc21270122b7 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Sat, 23 Nov 2024 20:50:14 +0100 Subject: [PATCH 4/5] benchmark, revert me --- tests/integration/test_background_task.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/test_background_task.py b/tests/integration/test_background_task.py index 70e2202a68d..2804b1ef819 100644 --- a/tests/integration/test_background_task.py +++ b/tests/integration/test_background_task.py @@ -44,7 +44,7 @@ async def handle_event_yield_only(self): @rx.event(background=True) async def fast_yielding(self): - for _ in range(1000): + for _ in range(100000): yield State.increment() @rx.event @@ -409,4 +409,4 @@ def test_fast_yielding( assert background_task._poll_for(lambda: counter.text == "0", timeout=5) fast_yielding_button.click() - assert background_task._poll_for(lambda: counter.text == "1000", timeout=50) + assert background_task._poll_for(lambda: counter.text == "100000", timeout=1200) From 82b829c662055ca2fe6dae134aa2f84e17662c2e Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Sat, 23 Nov 2024 21:06:58 +0100 Subject: [PATCH 5/5] PoC experimental valkey-glide client --- .gitignore | 1 + poetry.lock | 264 +++++++++++++++++++++- pyproject.toml | 30 +-- reflex/app.py | 2 +- reflex/state.py | 171 ++++++++++---- reflex/utils/exceptions.py | 4 + reflex/utils/prerequisites.py | 52 ++++- tests/integration/test_background_task.py | 4 +- tests/units/test_state.py | 17 +- 9 files changed, 459 insertions(+), 86 deletions(-) diff --git a/.gitignore b/.gitignore index 0f7d9e5ffa0..8384e60f8b5 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ assets/external/* dist/* examples/ +glide-logs/ .web .idea .vscode diff --git a/poetry.lock b/poetry.lock index 3310f106a50..dbde2ae4cb4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "alembic" @@ -144,6 +144,17 @@ typing = ["build[uv]", "importlib-metadata (>=5.1)", "mypy (>=1.9.0,<1.10.0)", " uv = ["uv (>=0.1.18)"] virtualenv = ["virtualenv (>=20.0.35)"] +[[package]] +name = "cachetools" +version = "5.5.0" +description = "Extensible memoizing collections and decorators" +optional = false +python-versions = ">=3.7" +files = [ + {file = "cachetools-5.5.0-py3-none-any.whl", hash = "sha256:02134e8439cdc2ffb62023ce1debca2944c3f289d66bb17ead3ab3dede74b292"}, + {file = "cachetools-5.5.0.tar.gz", hash = "sha256:2cc24fb4cbe39633fb7badd9db9ca6295d766d9c2995f245725a46715d050f2a"}, +] + [[package]] name = "certifi" version = "2024.8.30" @@ -619,6 +630,106 @@ docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2. testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"] typing = ["typing-extensions (>=4.12.2)"] +[[package]] +name = "google-api-core" +version = "2.23.0" +description = "Google API client core library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google_api_core-2.23.0-py3-none-any.whl", hash = "sha256:c20100d4c4c41070cf365f1d8ddf5365915291b5eb11b83829fbd1c999b5122f"}, + {file = "google_api_core-2.23.0.tar.gz", hash = "sha256:2ceb087315e6af43f256704b871d99326b1f12a9d6ce99beaedec99ba26a0ace"}, +] + +[package.dependencies] +google-auth = ">=2.14.1,<3.0.dev0" +googleapis-common-protos = ">=1.56.2,<2.0.dev0" +proto-plus = [ + {version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""}, + {version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""}, +] +protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" +requests = ">=2.18.0,<3.0.0.dev0" + +[package.extras] +async-rest = ["google-auth[aiohttp] (>=2.35.0,<3.0.dev0)"] +grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status (>=1.33.2,<2.0.dev0)", "grpcio-status (>=1.49.1,<2.0.dev0)"] +grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] +grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] + +[[package]] +name = "google-api-python-client" +version = "2.85.0" +description = "Google API Client Library for Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google-api-python-client-2.85.0.tar.gz", hash = "sha256:07b21ef21a542dd69cd7c09817a6079b2769cc2a791981402e8f0fcdb2d47f90"}, + {file = "google_api_python_client-2.85.0-py2.py3-none-any.whl", hash = "sha256:baf3c6f9b1679d89fcb88c29941a8b04b9a815d721880786baecc6a7f5bd376f"}, +] + +[package.dependencies] +google-api-core = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0dev" +google-auth = ">=1.19.0,<3.0.0dev" +google-auth-httplib2 = ">=0.1.0" +httplib2 = ">=0.15.0,<1dev" +uritemplate = ">=3.0.1,<5" + +[[package]] +name = "google-auth" +version = "2.36.0" +description = "Google Authentication Library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google_auth-2.36.0-py2.py3-none-any.whl", hash = "sha256:51a15d47028b66fd36e5c64a82d2d57480075bccc7da37cde257fc94177a61fb"}, + {file = "google_auth-2.36.0.tar.gz", hash = "sha256:545e9618f2df0bcbb7dcbc45a546485b1212624716975a1ea5ae8149ce769ab1"}, +] + +[package.dependencies] +cachetools = ">=2.0.0,<6.0" +pyasn1-modules = ">=0.2.1" +rsa = ">=3.1.4,<5" + +[package.extras] +aiohttp = ["aiohttp (>=3.6.2,<4.0.0.dev0)", "requests (>=2.20.0,<3.0.0.dev0)"] +enterprise-cert = ["cryptography", "pyopenssl"] +pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"] +reauth = ["pyu2f (>=0.1.5)"] +requests = ["requests (>=2.20.0,<3.0.0.dev0)"] + +[[package]] +name = "google-auth-httplib2" +version = "0.2.0" +description = "Google Authentication Library: httplib2 transport" +optional = false +python-versions = "*" +files = [ + {file = "google-auth-httplib2-0.2.0.tar.gz", hash = "sha256:38aa7badf48f974f1eb9861794e9c0cb2a0511a4ec0679b1f886d108f5640e05"}, + {file = "google_auth_httplib2-0.2.0-py2.py3-none-any.whl", hash = "sha256:b65a0a2123300dd71281a7bf6e64d65a0759287df52729bdd1ae2e47dc311a3d"}, +] + +[package.dependencies] +google-auth = "*" +httplib2 = ">=0.19.0" + +[[package]] +name = "googleapis-common-protos" +version = "1.66.0" +description = "Common protobufs used in Google APIs" +optional = false +python-versions = ">=3.7" +files = [ + {file = "googleapis_common_protos-1.66.0-py2.py3-none-any.whl", hash = "sha256:d7abcd75fabb2e0ec9f74466401f6c119a0b498e27370e9be4c94cb7e382b8ed"}, + {file = "googleapis_common_protos-1.66.0.tar.gz", hash = "sha256:c3e7b33d15fdca5374cc0a7346dd92ffa847425cc4ea941d970f13680052ec8c"}, +] + +[package.dependencies] +protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" + +[package.extras] +grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] + [[package]] name = "greenlet" version = "3.1.1" @@ -758,6 +869,20 @@ http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] trio = ["trio (>=0.22.0,<1.0)"] +[[package]] +name = "httplib2" +version = "0.22.0" +description = "A comprehensive HTTP client library." +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "httplib2-0.22.0-py3-none-any.whl", hash = "sha256:14ae0a53c1ba8f3d37e9e27cf37eabb0fb9980f435ba405d546948b009dd64dc"}, + {file = "httplib2-0.22.0.tar.gz", hash = "sha256:d7a10bc5ef5ab08322488bde8c726eeee5c8618723fdb399597ec58f3d82df81"}, +] + +[package.dependencies] +pyparsing = {version = ">=2.4.2,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.0.2 || >3.0.2,<3.0.3 || >3.0.3,<4", markers = "python_version > \"3.0\""} + [[package]] name = "httpx" version = "0.27.2" @@ -1602,6 +1727,43 @@ nodeenv = ">=0.11.1" pyyaml = ">=5.1" virtualenv = ">=20.10.0" +[[package]] +name = "proto-plus" +version = "1.25.0" +description = "Beautiful, Pythonic protocol buffers." +optional = false +python-versions = ">=3.7" +files = [ + {file = "proto_plus-1.25.0-py3-none-any.whl", hash = "sha256:c91fc4a65074ade8e458e95ef8bac34d4008daa7cce4a12d6707066fca648961"}, + {file = "proto_plus-1.25.0.tar.gz", hash = "sha256:fbb17f57f7bd05a68b7707e745e26528b0b3c34e378db91eef93912c54982d91"}, +] + +[package.dependencies] +protobuf = ">=3.19.0,<6.0.0dev" + +[package.extras] +testing = ["google-api-core (>=1.31.5)"] + +[[package]] +name = "protobuf" +version = "5.28.3" +description = "" +optional = false +python-versions = ">=3.8" +files = [ + {file = "protobuf-5.28.3-cp310-abi3-win32.whl", hash = "sha256:0c4eec6f987338617072592b97943fdbe30d019c56126493111cf24344c1cc24"}, + {file = "protobuf-5.28.3-cp310-abi3-win_amd64.whl", hash = "sha256:91fba8f445723fcf400fdbe9ca796b19d3b1242cd873907979b9ed71e4afe868"}, + {file = "protobuf-5.28.3-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:a3f6857551e53ce35e60b403b8a27b0295f7d6eb63d10484f12bc6879c715687"}, + {file = "protobuf-5.28.3-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:3fa2de6b8b29d12c61911505d893afe7320ce7ccba4df913e2971461fa36d584"}, + {file = "protobuf-5.28.3-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:712319fbdddb46f21abb66cd33cb9e491a5763b2febd8f228251add221981135"}, + {file = "protobuf-5.28.3-cp38-cp38-win32.whl", hash = "sha256:3e6101d095dfd119513cde7259aa703d16c6bbdfae2554dfe5cfdbe94e32d548"}, + {file = "protobuf-5.28.3-cp38-cp38-win_amd64.whl", hash = "sha256:27b246b3723692bf1068d5734ddaf2fccc2cdd6e0c9b47fe099244d80200593b"}, + {file = "protobuf-5.28.3-cp39-cp39-win32.whl", hash = "sha256:135658402f71bbd49500322c0f736145731b16fc79dc8f367ab544a17eab4535"}, + {file = "protobuf-5.28.3-cp39-cp39-win_amd64.whl", hash = "sha256:70585a70fc2dd4818c51287ceef5bdba6387f88a578c86d47bb34669b5552c36"}, + {file = "protobuf-5.28.3-py3-none-any.whl", hash = "sha256:cee1757663fa32a1ee673434fcf3bf24dd54763c79690201208bafec62f19eed"}, + {file = "protobuf-5.28.3.tar.gz", hash = "sha256:64badbc49180a5e401f373f9ce7ab1d18b63f7dd4a9cdc43c92b9f0b481cef7b"}, +] + [[package]] name = "psutil" version = "6.1.0" @@ -1643,6 +1805,31 @@ files = [ {file = "py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5"}, ] +[[package]] +name = "pyasn1" +version = "0.6.1" +description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"}, + {file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"}, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.1" +description = "A collection of ASN.1-based protocols modules" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd"}, + {file = "pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c"}, +] + +[package.dependencies] +pyasn1 = ">=0.4.6,<0.7.0" + [[package]] name = "pycparser" version = "2.22" @@ -1817,6 +2004,20 @@ files = [ [package.extras] windows-terminal = ["colorama (>=0.4.6)"] +[[package]] +name = "pyparsing" +version = "3.2.0" +description = "pyparsing module - Classes and methods to define and execute parsing grammars" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pyparsing-3.2.0-py3-none-any.whl", hash = "sha256:93d9577b88da0bbea8cc8334ee8b918ed014968fd2ec383e868fb8afb1ccef84"}, + {file = "pyparsing-3.2.0.tar.gz", hash = "sha256:cbf74e27246d595d9a74b186b810f6fbb86726dbf3b9532efb343f6d7294fe9c"}, +] + +[package.extras] +diagrams = ["jinja2", "railroad-diagrams"] + [[package]] name = "pyproject-hooks" version = "1.2.0" @@ -2295,6 +2496,20 @@ typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.1 [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] +[[package]] +name = "rsa" +version = "4.9" +description = "Pure-Python RSA implementation" +optional = false +python-versions = ">=3.6,<4" +files = [ + {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"}, + {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"}, +] + +[package.dependencies] +pyasn1 = ">=0.1.3" + [[package]] name = "ruff" version = "0.7.4" @@ -2759,6 +2974,17 @@ files = [ {file = "tzdata-2024.2.tar.gz", hash = "sha256:7d85cc416e9382e69095b7bdf4afd9e3880418a2413feec7069d533d6b4e31cc"}, ] +[[package]] +name = "uritemplate" +version = "4.1.1" +description = "Implementation of RFC 6570 URI Templates" +optional = false +python-versions = ">=3.6" +files = [ + {file = "uritemplate-4.1.1-py2.py3-none-any.whl", hash = "sha256:830c08b8d99bdd312ea4ead05994a38e8936266f84b9a7878232db50b044e02e"}, + {file = "uritemplate-4.1.1.tar.gz", hash = "sha256:4346edfc5c3b79f694bccd6d6099a322bbeb628dbf2cd86eea55a456ce5124f0"}, +] + [[package]] name = "urllib3" version = "2.2.3" @@ -2798,6 +3024,40 @@ typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} [package.extras] standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"] +[[package]] +name = "valkey-glide" +version = "1.1.0" +description = "" +optional = false +python-versions = ">=3.8" +files = [ + {file = "valkey_glide-1.1.0-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:fb4e5098d398ad63a96d363abee8b0ce4875aff9b61fbc4dbed441a896c123ad"}, + {file = "valkey_glide-1.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3887236b7ddf29a9910011655f477a159157b022467b3b25837ab5bbf6f8c6ee"}, + {file = "valkey_glide-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:025d7cecaf789022af07e454c4717e065634ecbb7496e66877a640d20cd64a44"}, + {file = "valkey_glide-1.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:33b4c78702f38e2ee20dfac18cd3e147df19d5a3b93ce523160ff23f8564f8b8"}, + {file = "valkey_glide-1.1.0-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:5816b201d1c79b1e2e1b1470d66155a1a5a30e371b4262ebfe7239e684bffbef"}, + {file = "valkey_glide-1.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:10f4ff6eccb172472fd552dd8b7dc9272fafdb5a99b3fae5857bf7bd302f7a4a"}, + {file = "valkey_glide-1.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd44755bba6463803f58085c55ae4f9224594ed1e728a70bed3223d6ac357bb8"}, + {file = "valkey_glide-1.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:87066ac5cbac106b430a5f39e684d85e4f31e28fc0e0c8569725c376e7fdb51a"}, + {file = "valkey_glide-1.1.0-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:4ea5d9f58a2aa6171359e8a5ebaccac7fdecf946ce66b77d4e702083f4519a5f"}, + {file = "valkey_glide-1.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:67eb38622c2338782b789a97390cca9a8f8b42dab8f1990f3086dd39f43b73f9"}, + {file = "valkey_glide-1.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46cf4ca974905cbed0e5d27b2d6fe3e18ca8709270bac244e99434b2c57715f8"}, + {file = "valkey_glide-1.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:5ec318cd95acf0b12d0f179b9e11a85ccec61dd0a80e2119d9f7cc7e3ca64651"}, + {file = "valkey_glide-1.1.0-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:41609ff67bcf3a3206ac94eff77cd8fa94a318f7555b937028b943fff248eded"}, + {file = "valkey_glide-1.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8effb1a9fd3fab991bb95b4494c179917a7a92a819fe316832850d5ed3648088"}, + {file = "valkey_glide-1.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5814aac65bcb2855ccdbee9efa6b3176c940af0b1e6a6c5d7fb00f43815c8c5"}, + {file = "valkey_glide-1.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:127d60a7e91296ab85b6d280bf570b6f9aab9eb856a69ae0c02a02c7490a32d9"}, + {file = "valkey_glide-1.1.0-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:5314909101d2362d9a19f1473d69e844a27119c4bb349044924603a3ad5b1f1e"}, + {file = "valkey_glide-1.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2b05d20e4cc10705d61031a4c0547c717ea670eabe5ed21b63133fad6fcff445"}, + {file = "valkey_glide-1.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f4b7aae8e1a81e3eeb8389fae55e5e9802aeb78348afef191700882e4d1a8724"}, + {file = "valkey_glide-1.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:7adcf6df73e0e4e027c0a727dfb61b94d3fe410c13db00ad74eb8405d457915e"}, +] + +[package.dependencies] +async-timeout = ">=4.0.2" +google-api-python-client = "2.85.0" +typing-extensions = ">=4.8.0" + [[package]] name = "virtualenv" version = "20.27.1" @@ -3041,4 +3301,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "8000601d48cfc1b10d0ae18c6046cc59a50cb6c45e6d3ef4775a3203769f2154" +content-hash = "b892ee66e115652a7ea4d44ead2edc0d89baa0811430667a5133736a2f59924a" diff --git a/pyproject.toml b/pyproject.toml index 980c16f9723..13a7b419c39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,26 +4,19 @@ version = "0.6.6dev1" description = "Web apps in pure Python." license = "Apache-2.0" authors = [ - "Nikhil Rao ", - "Alek Petuskey ", - "Masen Furer ", - "Elijah Ahianyo ", - "Thomas Brandého ", + "Nikhil Rao ", + "Alek Petuskey ", + "Masen Furer ", + "Elijah Ahianyo ", + "Thomas Brandého ", ] readme = "README.md" homepage = "https://reflex.dev" repository = "https://github.com/reflex-dev/reflex" documentation = "https://reflex.dev/docs/getting-started/introduction" -keywords = [ - "web", - "framework", -] -classifiers = [ - "Development Status :: 4 - Beta", -] -packages = [ - {include = "reflex"} -] +keywords = ["web", "framework"] +classifiers = ["Development Status :: 4 - Beta"] +packages = [{ include = "reflex" }] [tool.poetry.dependencies] python = "^3.9" @@ -42,11 +35,11 @@ uvicorn = ">=0.20.0" starlette-admin = ">=0.11.0,<1.0" alembic = ">=1.11.1,<2.0" platformdirs = ">=3.10.0,<5.0" -distro = {version = ">=1.8.0,<2.0", platform = "linux"} +distro = { version = ">=1.8.0,<2.0", platform = "linux" } python-engineio = "!=4.6.0" wrapt = [ - {version = ">=1.14.0,<2.0", python = ">=3.11"}, - {version = ">=1.11.0,<2.0", python = "<3.11"}, + { version = ">=1.14.0,<2.0", python = ">=3.11" }, + { version = ">=1.11.0,<2.0", python = "<3.11" }, ] packaging = ">=23.1,<25.0" reflex-hosting-cli = ">=0.1.17,<2.0" @@ -60,6 +53,7 @@ tomlkit = ">=0.12.4,<1.0" lazy_loader = ">=0.4" reflex-chakra = ">=0.6.0" typing_extensions = ">=4.6.0" +valkey-glide = "^1.1.0" [tool.poetry.group.dev.dependencies] pytest = ">=7.1.2,<9.0" diff --git a/reflex/app.py b/reflex/app.py index 758d57c33ae..901a6845f19 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1497,7 +1497,7 @@ async def on_disconnect(self, sid): disconnect_token = self.sid_to_token.pop(sid, None) if disconnect_token: self.token_to_sid.pop(disconnect_token, None) - await self.app.state_manager.disconnect(sid) + await self.app.state_manager.disconnect(disconnect_token) async def emit_update(self, update: StateUpdate, sid: str) -> None: """Emit an update to the client. diff --git a/reflex/state.py b/reflex/state.py index 35d840498f8..d2baa50027e 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -39,6 +39,14 @@ get_type_hints, ) +from glide import ( + OK, + ConditionalChange, + ExpirySet, + ExpiryType, + GlideClient, + GlideClientConfiguration, +) from sqlalchemy.orm import DeclarativeBase from typing_extensions import Self @@ -70,8 +78,6 @@ BaseModelV1 = BaseModelV2 import wrapt -from redis.asyncio import Redis -from redis.exceptions import ResponseError import reflex.istate.dynamic from reflex import constants @@ -94,6 +100,7 @@ ImmutableStateError, InvalidStateManagerMode, LockExpiredError, + RedisConfigError, ReflexRuntimeError, SetUndefinedStateVarError, StateSchemaMismatchError, @@ -2781,16 +2788,17 @@ def create(cls, state: Type[BaseState]): return StateManagerMemory(state=state) if config.state_manager_mode == constants.StateManagerMode.DISK: return StateManagerDisk(state=state) - if config.state_manager_mode == constants.StateManagerMode.REDIS: - redis = prerequisites.get_redis() - if redis is not None: - # make sure expiration values are obtained only from the config object on creation - return StateManagerRedis( - state=state, - redis=redis, - token_expiration=config.redis_token_expiration, - lock_expiration=config.redis_lock_expiration, - ) + if ( + config.state_manager_mode == constants.StateManagerMode.REDIS + and prerequisites.parse_redis_url() is not None + ): + # make sure expiration values are obtained only from the config object on creation + return StateManagerRedis( + state=state, + # redis=redis, + token_expiration=config.redis_token_expiration, + lock_expiration=config.redis_lock_expiration, + ) raise InvalidStateManagerMode( f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}" ) @@ -3185,7 +3193,7 @@ class StateManagerRedis(StateManager): """A state manager that stores states in redis.""" # The redis client to use. - redis: Redis + redis: Optional[GlideClient] = None # The token expiration time (s). token_expiration: int = pydantic.Field(default_factory=_default_token_expiration) @@ -3212,6 +3220,34 @@ class StateManagerRedis(StateManager): # This lock is used to ensure we only subscribe to keyspace events once per token and worker _pubsub_locks: Dict[bytes, asyncio.Lock] = pydantic.PrivateAttr({}) + async def get_redis(self) -> GlideClient: + """Get the redis client. + + Returns: + The redis client. + + Raises: + RedisConfigError: If the redis client could not be configured. + """ + if self.redis is not None: + return self.redis + redis = await prerequisites.get_redis() + assert redis is not None + config_result = await redis.config_set( + {"notify-keyspace-events": self._redis_notify_keyspace_events}, + ) + # Some redis servers only allow out-of-band configuration, so ignore errors here. + if ( + config_result != OK + and not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get() + ): + raise RedisConfigError( + f"Failed to set notify-keyspace-events: {config_result}" + ) + + self.redis = redis + return redis + async def _get_parent_state( self, token: str, state: BaseState | None = None ) -> BaseState | None: @@ -3321,7 +3357,8 @@ async def get_state( state = None # Fetch the serialized substate from redis. - redis_state = await self.redis.get(token) + redis = await self.get_redis() + redis_state = await redis.get(token) if redis_state is not None: # Deserialize the substate. @@ -3374,10 +3411,9 @@ async def set_state( RuntimeError: If the state instance doesn't match the state name in the token. """ # Check that we're holding the lock. - if ( - lock_id is not None - and await self.redis.get(self._lock_key(token)) != lock_id - ): + redis = await self.get_redis() + + if lock_id is not None and await redis.get(self._lock_key(token)) != lock_id: raise LockExpiredError( f"Lock expired for token {token} while processing. Consider increasing " f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) " @@ -3404,13 +3440,21 @@ async def set_state( ) # Persist only the given state (parents or substates are excluded by BaseState.__getstate__). if state._get_was_touched(): + redis = await self.get_redis() pickle_state = state._serialize() if pickle_state: - await self.redis.set( + _ = await redis.set( _substate_key(client_token, state), pickle_state, - ex=self.token_expiration, + expiry=ExpirySet( + expiry_type=ExpiryType.MILLSEC, + value=self.token_expiration, + ), ) + # if str(res) != OK: + # raise RuntimeError( + # f"Failed to set state for token {token}. {res} {OK}" + # ) # Wait for substates to be persisted. for t in tasks: @@ -3456,12 +3500,42 @@ async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None: Returns: True if the lock was obtained. """ - return await self.redis.set( + redis = await self.get_redis() + response = await redis.set( lock_key, lock_id, - px=self.lock_expiration, - nx=True, # only set if it doesn't exist + expiry=ExpirySet( + expiry_type=ExpiryType.MILLSEC, + value=self.lock_expiration, + ), + conditional_set=ConditionalChange.ONLY_IF_DOES_NOT_EXIST, ) + return str(response) == OK + + async def get_pubsub(self, lock_key: bytes) -> GlideClient: + """Get the pubsub client for a lock key channel. + + Args: + lock_key: The redis key for the lock. + + Returns: + The pubsub client. + """ + lock_key_channel = f"__keyspace@0__:{lock_key.decode()}" + pubsub_config = GlideClientConfiguration.PubSubSubscriptions( + channels_and_patterns={ + GlideClientConfiguration.PubSubChannelModes.Pattern: {lock_key_channel}, + # GlideClientConfiguration.PubSubChannelModes.Exact: {lock_key_channel}, + }, + callback=None, + context=None, + ) + config = prerequisites.get_glide_client_configuration( + pubsub_subscriptions=pubsub_config + ) + assert config is not None + pubsub = await GlideClient.create(config) + return pubsub async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None: """Wait for a redis lock to be released via pubsub. @@ -3471,43 +3545,35 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None: Args: lock_key: The redis key for the lock. lock_id: The ID of the lock. - - Raises: - ResponseError: when the keyspace config cannot be set. """ state_is_locked = False - lock_key_channel = f"__keyspace@0__:{lock_key.decode()}" # Enable keyspace notifications for the lock key, so we know when it is available. - try: - await self.redis.config_set( - "notify-keyspace-events", - self._redis_notify_keyspace_events, - ) - except ResponseError: - # Some redis servers only allow out-of-band configuration, so ignore errors here. - if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get(): - raise + redis = await self.get_redis() if lock_key not in self._pubsub_locks: self._pubsub_locks[lock_key] = asyncio.Lock() - async with self._pubsub_locks[lock_key], self.redis.pubsub() as pubsub: - await pubsub.psubscribe(lock_key_channel) + async with self._pubsub_locks[lock_key]: + pubsub = await self.get_pubsub(lock_key) while not state_is_locked: # wait for the lock to be released while True: - if not await self.redis.exists(lock_key): + # check if we missed lock release events + if await redis.exists([lock_key]) == 0: break # key was removed, try to get the lock again - message = await pubsub.get_message( - ignore_subscribe_messages=True, - timeout=self.lock_expiration / 1000.0, - ) - if message is None: + + try: + # TODO: alternative to ignore_subscribe_messages? + message = await asyncio.wait_for( + pubsub.get_pubsub_message(), + timeout=self.lock_expiration / 1000.0, + ) + except asyncio.TimeoutError: continue - if message["data"] in self._redis_keyspace_lock_release_events: + if message.message in self._redis_keyspace_lock_release_events: break state_is_locked = await self._try_get_lock(lock_key, lock_id) @override - async def disconnect(self, token: str): + async def disconnect(self, token: str) -> None: """Disconnect the token from the redis client. Args: @@ -3520,7 +3586,7 @@ async def disconnect(self, token: str): del self._pubsub_locks[lock_key] @contextlib.asynccontextmanager - async def _lock(self, token: str): + async def _lock(self, token: str) -> AsyncIterator[bytes]: """Obtain a redis lock for a token. Args: @@ -3548,9 +3614,12 @@ async def _lock(self, token: str): finally: if state_is_locked: # only delete our lock - await self.redis.delete(lock_key) + redis = await self.get_redis() + _ = await redis.delete([lock_key]) + # if not res: + # raise RuntimeError(f"Failed to release lock for token {token}") - async def close(self): + async def close(self) -> None: """Explicitly close the redis connection and connection_pool. It is necessary in testing scenarios to close between asyncio test cases @@ -3559,7 +3628,9 @@ async def close(self): Note: Connections will be automatically reopened when needed. """ - await self.redis.aclose(close_connection_pool=True) + if self.redis is not None: + await self.redis.close() + self.redis = None def get_state_manager() -> StateManager: diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index 714dc912c7a..303df606924 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -87,6 +87,10 @@ class LockExpiredError(ReflexError): """Raised when the state lock expires while an event is being processed.""" +class RedisConfigError(ReflexError): + """Raised when the Redis configuration is not applied correctly.""" + + class MatchTypeError(ReflexError, TypeError): """Raised when the return types of match cases are different.""" diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index ec79b3297b2..bad9a159645 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -21,15 +21,15 @@ from datetime import datetime from pathlib import Path from types import ModuleType -from typing import Callable, List, Optional +from typing import TYPE_CHECKING, Any, Callable, List, Optional import httpx import typer from alembic.util.exc import CommandError +from glide import GlideClient, GlideClientConfiguration, NodeAddress from packaging import version from redis import Redis as RedisSync from redis import exceptions -from redis.asyncio import Redis from reflex import constants, model from reflex.compiler import templates @@ -44,6 +44,9 @@ CURRENTLY_INSTALLING_NODE = False +if TYPE_CHECKING: + from reflex.app import App + @dataclasses.dataclass(frozen=True) class Template: @@ -320,7 +323,7 @@ def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType: The compiled app based on the default config. """ app_module = get_app(reload=reload) - app = getattr(app_module, constants.CompileVars.APP) + app: App = getattr(app_module, constants.CompileVars.APP) # For py3.9 compatibility when redis is used, we MUST add any decorator pages # before compiling the app in a thread to avoid event loop error (REF-2172). app._apply_decorated_pages() @@ -328,17 +331,46 @@ def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType: return app_module -def get_redis() -> Redis | None: +def get_node_addresses() -> List[NodeAddress]: + """Get the node addresses from the config. + + Returns: + The node addresses. + """ + redis = get_redis_sync() + if redis is None: + return [] + host = redis.connection_pool.connection_kwargs["host"] + port = redis.connection_pool.connection_kwargs["port"] + return [NodeAddress(host=host, port=port)] + + +def get_glide_client_configuration(**kwargs) -> GlideClientConfiguration | None: + """Get the glide client configuration. + + Args: + kwargs: Additional keyword arguments to pass to the GlideClientConfiguration. + + Returns: + The glide client configuration. + """ + addresses = get_node_addresses() + if not addresses: + return None + return GlideClientConfiguration(addresses=addresses, **kwargs) + + +async def get_redis() -> GlideClient | None: """Get the asynchronous redis client. Returns: The asynchronous redis client. """ - if isinstance((redis_url_or_options := parse_redis_url()), str): - return Redis.from_url(redis_url_or_options) - elif isinstance(redis_url_or_options, dict): - return Redis(**redis_url_or_options) - return None + config = get_glide_client_configuration() + if config is None: + return None + client = await GlideClient.create(config) + return client def get_redis_sync() -> RedisSync | None: @@ -354,7 +386,7 @@ def get_redis_sync() -> RedisSync | None: return None -def parse_redis_url() -> str | dict | None: +def parse_redis_url() -> str | dict[Any, Any] | None: """Parse the REDIS_URL in config if applicable. Returns: diff --git a/tests/integration/test_background_task.py b/tests/integration/test_background_task.py index 2804b1ef819..91ef7349810 100644 --- a/tests/integration/test_background_task.py +++ b/tests/integration/test_background_task.py @@ -44,7 +44,7 @@ async def handle_event_yield_only(self): @rx.event(background=True) async def fast_yielding(self): - for _ in range(100000): + for _ in range(1000): yield State.increment() @rx.event @@ -409,4 +409,4 @@ def test_fast_yielding( assert background_task._poll_for(lambda: counter.text == "0", timeout=5) fast_yielding_button.click() - assert background_task._poll_for(lambda: counter.text == "100000", timeout=1200) + assert background_task._poll_for(lambda: counter.text == "1000", timeout=20) diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 45c021bd82c..7c5c1097696 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -1683,7 +1683,8 @@ async def test_state_manager_modify_state( """ async with state_manager.modify_state(substate_token) as state: if isinstance(state_manager, StateManagerRedis): - assert await state_manager.redis.get(f"{token}_lock") + redis = await state_manager.get_redis() + assert await redis.get(f"{token}_lock") elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): assert token in state_manager._states_locks assert state_manager._states_locks[token].locked() @@ -1693,7 +1694,8 @@ async def test_state_manager_modify_state( state.complex[3] = complex_1 # lock should be dropped after exiting the context if isinstance(state_manager, StateManagerRedis): - assert (await state_manager.redis.get(f"{token}_lock")) is None + redis = await state_manager.get_redis() + assert (await redis.get(f"{token}_lock")) is None elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): assert not state_manager._states_locks[token].locked() @@ -1735,7 +1737,8 @@ async def _coro(): assert (await state_manager.get_state(substate_token)).num1 == exp_num1 if isinstance(state_manager, StateManagerRedis): - assert (await state_manager.redis.get(f"{token}_lock")) is None + redis = await state_manager.get_redis() + assert (await redis.get(f"{token}_lock")) is None elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): assert token in state_manager._states_locks assert not state_manager._states_locks[token].locked() @@ -1925,6 +1928,14 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): # Cannot access substates sp.substates[""] + assert ( + sp.router.session.client_token == grandchild_state.router.session.client_token + ) + assert ( + sp.__wrapped__.router.session.client_token + == grandchild_state.router.session.client_token + ) + assert sp.router.session.client_token is not None async with sp: assert sp._self_actx is not None assert sp._self_mutable # proxy is mutable inside context