Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix get_parent_state and get_root_state when using mixin=True #4976

Merged
merged 1 commit into from
Mar 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,14 @@ def get_parent_state(cls) -> Type[BaseState] | None:
]
if len(parent_states) >= 2:
raise ValueError(f"Only one parent state is allowed {parent_states}.")
return parent_states[0] if len(parent_states) == 1 else None
# The first non-mixin state in the mro is our parent.
for base in cls.mro()[1:]:
if base._mixin or not issubclass(base, BaseState):
continue
if base is BaseState:
break
return base
return None # No known parent

@classmethod
@functools.lru_cache()
Expand Down
33 changes: 33 additions & 0 deletions tests/units/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3424,6 +3424,18 @@ class ChildUsesMixinState(UsesMixinState):
pass


class ChildMixinState(ChildUsesMixinState, mixin=True):
"""A mixin state that inherits from a concrete state that uses mixins."""

pass


class GrandchildUsesMixinState(ChildMixinState):
"""A grandchild state that uses the mixin state."""

pass


def test_mixin_state() -> None:
"""Test that a mixin state works correctly."""
assert "num" in UsesMixinState.base_vars
Expand All @@ -3438,6 +3450,9 @@ def test_mixin_state() -> None:
is not UsesMixinState.backend_vars["_backend_no_default"]
)

assert UsesMixinState.get_parent_state() == State
assert UsesMixinState.get_root_state() == State


def test_child_mixin_state() -> None:
"""Test that mixin vars are only applied to the highest state in the hierarchy."""
Expand All @@ -3447,6 +3462,24 @@ def test_child_mixin_state() -> None:
assert "computed" in ChildUsesMixinState.inherited_vars
assert "computed" not in ChildUsesMixinState.computed_vars

assert ChildUsesMixinState.get_parent_state() == UsesMixinState
assert ChildUsesMixinState.get_root_state() == State


def test_grandchild_mixin_state() -> None:
"""Test that a mixin can inherit from a concrete state class."""
assert "num" in GrandchildUsesMixinState.inherited_vars
assert "num" not in GrandchildUsesMixinState.base_vars

assert "computed" in GrandchildUsesMixinState.inherited_vars
assert "computed" not in GrandchildUsesMixinState.computed_vars

assert ChildMixinState.get_parent_state() == ChildUsesMixinState
assert ChildMixinState.get_root_state() == State

assert GrandchildUsesMixinState.get_parent_state() == ChildUsesMixinState
assert GrandchildUsesMixinState.get_root_state() == State


def test_assignment_to_undeclared_vars():
"""Test that an attribute error is thrown when undeclared vars are set."""
Expand Down
Loading