Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

json patch #4890

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"gunicorn >=20.1.0,<24.0",
"httpx >=0.25.1,<1.0",
"jinja2 >=3.1.2,<4.0",
"jsonpatch >=1.33,<2.0",
"lazy_loader >=0.4",
"packaging >=23.1,<25.0",
"platformdirs >=3.10.0,<5.0",
Expand Down
43 changes: 40 additions & 3 deletions reflex/.templates/web/utils/state.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {
} from "$/utils/context.js";
import debounce from "$/utils/helpers/debounce";
import throttle from "$/utils/helpers/throttle";
import { applyPatch } from "fast-json-patch/index.mjs";

// Endpoint URLs.
const EVENTURL = env.EVENT;
Expand Down Expand Up @@ -127,7 +128,7 @@ export const isStateful = () => {
if (event_queue.length === 0) {
return false;
}
return event_queue.some((event) => event.name.startsWith("reflex___state"));
return event_queue.some((event) => event.name.startsWith(state_name));
};

/**
Expand Down Expand Up @@ -485,13 +486,48 @@ export const connect = async (
window.removeEventListener("pagehide", pagehideHandler);
});

const last_substate_info = {};
const last_substate_hash = {};

const getSubstateFromUpdate = (update, substate_name) => {
if (update.__patch) {
if (last_substate_hash[substate_name] !== update.__previous_hash) {
return null;
}
last_substate_hash[substate_name] = update.__hash;
return applyPatch(last_substate_info[substate_name], update.__patch)
.newDocument;
} else {
last_substate_hash[substate_name] = update.__hash;
return update.__full;
}
};

// On each received message, queue the updates and events.
socket.current.on("event", async (update) => {
const failed_substates = [];
for (const substate in update.delta) {
dispatch[substate](update.delta[substate]);
const new_substate_info = getSubstateFromUpdate(
update.delta[substate],
substate,
);
if (new_substate_info === null) {
console.error("Received patch out of order", update.delta[substate]);
failed_substates.push(substate);
delete update.delta[substate];
continue;
}
last_substate_info[substate] = new_substate_info;
update.delta[substate] = new_substate_info;
dispatch[substate](new_substate_info);
}
applyClientStorageDelta(client_storage, update.delta);
event_processing = !update.final;
if (failed_substates.length > 0) {
update.events.push(
Event(state_name + ".partial_hydrate", { states: failed_substates }),
);
}
if (update.events) {
queueEvents(update.events, socket);
}
Expand Down Expand Up @@ -872,6 +908,7 @@ export const useEventLoop = (
(async () => {
// Process all outstanding events.
while (event_queue.length > 0 && !event_processing) {
await new Promise((resolve) => setTimeout(resolve, 0));
await processEvent(socket.current);
}
})();
Expand Down Expand Up @@ -911,7 +948,7 @@ export const useEventLoop = (
// Route after the initial page hydration.
useEffect(() => {
const change_start = () => {
const main_state_dispatch = dispatch["reflex___state____state"];
const main_state_dispatch = dispatch[state_name];
if (main_state_dispatch !== undefined) {
main_state_dispatch({ is_hydrated: false });
}
Expand Down
2 changes: 1 addition & 1 deletion reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1407,7 +1407,7 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
async with self.state_manager.modify_state(token) as state:
# No other event handler can modify the state while in this context.
yield state
delta = state.get_delta()
delta = state.get_delta(token=token)
if delta:
# When the state is modified reset dirty status and emit the delta to the frontend.
state._clean()
Expand Down
2 changes: 2 additions & 0 deletions reflex/app_mixins/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from reflex.event import Event
from reflex.middleware import HydrateMiddleware, Middleware
from reflex.middleware.hydrate_middleware import PartialHyderateMiddleware
from reflex.state import BaseState, StateUpdate

from .mixin import AppMixin
Expand All @@ -21,6 +22,7 @@ class MiddlewareMixin(AppMixin):

def _init_mixin(self):
self._middlewares.append(HydrateMiddleware())
self._middlewares.append(PartialHyderateMiddleware())

def add_middleware(self, middleware: Middleware, index: int | None = None):
"""Add middleware to the app.
Expand Down
8 changes: 4 additions & 4 deletions reflex/compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from reflex.components.component import Component, ComponentStyle, CustomComponent
from reflex.istate.storage import Cookie, LocalStorage, SessionStorage
from reflex.state import BaseState, _resolve_delta
from reflex.state import BaseState, StateDelta, _resolve_delta
from reflex.style import Style
from reflex.utils import console, format, imports, path_ops
from reflex.utils.exec import is_in_app_harness
Expand Down Expand Up @@ -187,7 +187,7 @@ def compile_state(state: Type[BaseState]) -> dict:
Returns:
A dictionary of the compiled state.
"""
initial_state = state(_reflex_internal_init=True).dict(initial=True)
initial_state = StateDelta(state(_reflex_internal_init=True).dict(initial=True))
try:
_ = asyncio.get_running_loop()
except RuntimeError:
Expand All @@ -202,10 +202,10 @@ def compile_state(state: Type[BaseState]) -> dict:
console.warn(
f"Had to get initial state in a thread 🤮 {resolved_initial_state}",
)
return resolved_initial_state
return dict(**resolved_initial_state.data)

# Normally the compile runs before any event loop starts, we asyncio.run is available for calling.
return asyncio.run(_resolve_delta(initial_state))
return dict(**asyncio.run(_resolve_delta(initial_state)).data)


def _compile_client_storage_field(
Expand Down
3 changes: 3 additions & 0 deletions reflex/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,9 @@ class EnvironmentVariables:
# Used by flexgen to enumerate the pages.
REFLEX_ADD_ALL_ROUTES_ENDPOINT: EnvVar[bool] = env_var(False)

# Use the JSON patch format for websocket messages.
REFLEX_USE_JSON_PATCH: EnvVar[bool] = env_var(True)


environment = EnvironmentVariables()

Expand Down
2 changes: 2 additions & 0 deletions reflex/constants/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class CompileVars(SimpleNamespace):
EVENTS = "events"
# The name of the initial hydrate event.
HYDRATE = "hydrate"
# The name of the partial hydrate event.
PARTIAL_HYDRATE = "partial_hydrate"
# The name of the is_hydrated variable.
IS_HYDRATED = "is_hydrated"
# The name of the function to add events to the queue.
Expand Down
1 change: 1 addition & 0 deletions reflex/constants/installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class Commands(SimpleNamespace):
"react-dom": "19.0.0",
"react-focus-lock": "2.13.6",
"socket.io-client": "4.8.1",
"fast-json-patch": "3.1.1",
"universal-cookie": "7.2.2",
}
DEV_DEPENDENCIES = {
Expand Down
12 changes: 12 additions & 0 deletions reflex/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -1199,6 +1199,18 @@ def get_hydrate_event(state: BaseState) -> str:
return get_event(state, constants.CompileVars.HYDRATE)


def get_partial_hydrate_event(state: BaseState) -> str:
"""Get the name of the partial hydrate event for the state.

Args:
state: The state.

Returns:
The name of the partial hydrate event.
"""
return get_event(state, constants.CompileVars.PARTIAL_HYDRATE)


def call_event_handler(
event_callback: EventHandler | EventSpec,
event_spec: ArgsSpec | Sequence[ArgsSpec],
Expand Down
5 changes: 3 additions & 2 deletions reflex/middleware/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Reflex middleware."""

from .hydrate_middleware import HydrateMiddleware
from .middleware import Middleware
from .hydrate_middleware import HydrateMiddleware as HydrateMiddleware
from .hydrate_middleware import PartialHyderateMiddleware as PartialHydrateMiddleware
from .middleware import Middleware as Middleware
61 changes: 57 additions & 4 deletions reflex/middleware/hydrate_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from __future__ import annotations

import dataclasses
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, ChainMap

from reflex import constants
from reflex.event import Event, get_hydrate_event
from reflex.event import Event, get_hydrate_event, get_partial_hydrate_event
from reflex.middleware.middleware import Middleware
from reflex.state import BaseState, StateUpdate, _resolve_delta
from reflex.state import BaseState, StateDelta, StateUpdate, _resolve_delta

if TYPE_CHECKING:
from reflex.app import App
Expand Down Expand Up @@ -42,9 +42,62 @@ async def preprocess(
setattr(state, constants.CompileVars.IS_HYDRATED, False)

# Get the initial state.
delta = await _resolve_delta(state.dict())
delta = await _resolve_delta(
StateDelta(
state.dict(),
client_token=state.router.session.client_token,
flush=True,
)
)
# since a full dict was captured, clean any dirtiness
state._clean()

# Return the state update.
return StateUpdate(delta=delta, events=[])


@dataclasses.dataclass(init=True)
class PartialHyderateMiddleware(Middleware):
"""Middleware to handle partial app hydration."""

async def preprocess(
self, app: App, state: BaseState, event: Event
) -> StateUpdate | None:
"""Preprocess the event.

Args:
app: The app to apply the middleware to."
state: The client state.""
event: The event to preprocess.""

Returns:
An optional delta or list of state updates to return.""
"""
# If this is not the partial hydrate event, return None
if event.name != get_partial_hydrate_event(state):
return None

substates_names = event.payload.get("states", [])
if not substates_names:
return None

substates = [
substate
for substate_name in substates_names
if (substate := state.get_substate(substate_name.split("."))) is not None
]

delta = await _resolve_delta(
StateDelta(
ChainMap(*[substate.dict() for substate in substates]),
client_token=state.router.session.client_token,
flush=True,
)
)

# since a full dict was captured, clean any dirtiness
for substate in substates:
substate._clean()

# Return the state update.
return StateUpdate(delta=delta, events=[])
Loading
Loading