Unable to use Pure Pydantic Models as types for a rx.State class #4250

Open
grizzlybearg opened this issue Oct 27, 2024 · 6 comments
Labels: bug (Something isn't working)

grizzlybearg commented Oct 27, 2024

Bug Description: I am unable to use a pure Pydantic model as the type of a var in my runner state. Reflex raises the following error:

File "...\ui\internal\main\views\runners.py", line 186, in <lambda>
  lambda item, index: _show_item(item, index),
                      ^^^^^^^^^^^^^^^^^^^^^^^
File "...\ui\internal\main\views\runners.py", line 26, in _show_item
  epic = runner.epic
         ^^^^^^^^^^^
File "...\site-packages\reflex\vars\base.py", line 909, in __getattr__
  raise VarAttributeError(
reflex.utils.exceptions.VarAttributeError: The State var has no attribute 'epic' or may have been annotated wrongly.

To Reproduce: Steps to reproduce the behavior:

from datetime import datetime
from typing import List, Optional
from uuid import UUID, uuid4

import reflex as rx
from pydantic import BaseModel, Field
from reflex.utils.exceptions import VarAttributeError


class RunnerDetails(BaseModel):
    # default_factory generates a fresh UUID per instance; default=uuid4()
    # would be evaluated once and shared by every instance.
    runner_id: Optional[UUID] = Field(default_factory=uuid4)
    epic: str
    status: str
    container: str
    endpoint: str
    lastUpdate: datetime
    local: bool
    initial_state: bool


class RunnersState(rx.State):
    items: List[RunnerDetails] = []
    search_value: str = ""
    sort_value: str = ""
    sort_reverse: bool = False
    total_items: int = 0
    offset: int = 0
    limit: int = 12  # Number of rows per page

    @rx.var(cache=True)
    def filtered_sorted_items(self) -> List[RunnerDetails]:
        items = self.items
        ...


def _show_item(item: RunnerDetails, index: int) -> rx.Component:
    # Alternate row background and hover colors on even/odd rows.
    bg_color = rx.cond(
        index % 2 == 0,
        rx.color("gray", 1),
        rx.color("accent", 2),
    )
    hover_color = rx.cond(
        index % 2 == 0,
        rx.color("gray", 3),
        rx.color("accent", 3),
    )
    try:
        epic = item.epic
        resolution = item.resolution  # note: RunnerDetails defines no `resolution` field
        lastUpdate = item.lastUpdate
        status = item.status
    except VarAttributeError:
        print(f"Runner: {item}")
        raise
    return rx.table.row(
        rx.table.row_header_cell(epic),
        rx.table.cell(f"${resolution}"),
        rx.table.cell(lastUpdate),
        # status_badge is a helper component not shown in this issue.
        rx.table.cell(status_badge(str(status))),
        style={"_hover": {"bg": hover_color}, "bg": bg_color},
        align="center",
    )
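A side note on the repro: Field(default=uuid4()) calls uuid4() once, when the class body runs, so every RunnerDetails created without an explicit runner_id would share the same UUID. Pydantic's Field(default_factory=uuid4) calls the factory per instance instead. The sketch below shows the same difference using the standard-library dataclasses module as a stand-in for Pydantic; the class names are hypothetical.

```python
from dataclasses import dataclass, field
from uuid import UUID, uuid4


@dataclass
class SharedDefault:
    # uuid4() runs once, at class-definition time; every instance reuses it.
    runner_id: UUID = uuid4()


@dataclass
class FactoryDefault:
    # default_factory is called on each instantiation, giving a fresh UUID.
    runner_id: UUID = field(default_factory=uuid4)


a, b = SharedDefault(), SharedDefault()
c, d = FactoryDefault(), FactoryDefault()
print(a.runner_id == b.runner_id)  # True: both share the class-level default
print(c.runner_id == d.runner_id)  # False: each instance got its own UUID
```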

Expected behavior: The app should compile and run after initialization.

  • Python Version: 3.12
  • Reflex Version: latest
  • OS: Windows

Additional context: I have also tried the following approach, but the result is the same:

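from typing import Annotated, Any, List

import reflex as rx
from pydantic import AfterValidator, PlainSerializer


def validate_object_id(v: Any) -> RunnerDetails:
    # An AfterValidator runs after Pydantic's own validation, so v should
    # already be a RunnerDetails here; the dict branch would only matter in
    # a BeforeValidator, which receives the raw input.
    if isinstance(v, RunnerDetails):
        return v
    elif isinstance(v, dict):
        return RunnerDetails.model_validate(v)
    raise ValueError(f"Invalid object: {v}")


class RunnersState(rx.State):
    items: List[
        Annotated[
            RunnerDetails,
            AfterValidator(validate_object_id),
            PlainSerializer(lambda x: x.model_dump(), return_type=dict),
        ],
    ] = []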

The RunnerDetails model cannot be modified to subclass rx.Model or SQLModel and must remain a pure Pydantic model because it is used in another part of the codebase that doesn't use the reflex package.

Help with solving this issue would be appreciated.

grizzlybearg added the bug (Something isn't working) label on Oct 27, 2024

linear bot commented Oct 27, 2024

benedikt-bartscher (Contributor) commented:

It seems like your example code is incomplete: where is _show_item called? The stack trace shows a lambda; maybe try passing a typed function instead.
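One plausible reading of this suggestion: a lambda has no syntax for parameter annotations, so anything that introspects the render function's type hints finds nothing, whereas a def-based function such as _show_item carries RunnerDetails/int hints and could be passed to rx.foreach directly instead of being wrapped in a lambda. Whether rx.foreach actually consults those hints is not confirmed in this thread. The stdlib sketch below only shows what introspection sees in each case; RunnerDetails here is a hypothetical placeholder class.

```python
import typing


class RunnerDetails:
    """Placeholder standing in for the Pydantic model."""


def show_item(item: RunnerDetails, index: int) -> str:
    # A def can annotate each parameter and its return value.
    return f"{index}: {item!r}"


# A lambda cannot annotate its parameters.
show_item_lambda = lambda item, index: show_item(item, index)  # noqa: E731

print(typing.get_type_hints(show_item))
print(typing.get_type_hints(show_item_lambda))  # {}
```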

grizzlybearg (Author) commented Oct 28, 2024

Hey @benedikt-bartscher
The _show_item function is called in:

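import reflex as rx


# RunnersState is defined above; _header_cell and _pagination_view are
# helper components not shown in this issue.
def main_table() -> rx.Component: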
    return rx.box(
        rx.flex(
            rx.flex(
                rx.cond(
                    RunnersState.sort_reverse,
                    rx.icon(
                        "arrow-down-z-a",
                        size=28,
                        stroke_width=1.5,
                        cursor="pointer",
                        flex_shrink="0",
                        on_click=RunnersState.toggle_sort,
                    ),
                    rx.icon(
                        "arrow-down-a-z",
                        size=28,
                        stroke_width=1.5,
                        cursor="pointer",
                        flex_shrink="0",
                        on_click=RunnersState.toggle_sort,
                    ),
                ),
                rx.select(
                    [
                        "Epic",
                        "Last Update",
                        "Status",
                    ],
                    placeholder="Sort By: Epic",
                    size="3",
                    on_change=RunnersState.set_sort_value,
                ),
                rx.input(
                    rx.input.slot(rx.icon("scan-search")),
                    rx.input.slot(
                        rx.icon("x"),
                        justify="end",
                        cursor="pointer",
                        on_click=RunnersState.setvar("search_value", ""),
                        display=rx.cond(RunnersState.search_value, "flex", "none"),
                    ),
                    value=RunnersState.search_value,
                    placeholder="Search here...",
                    size="3",
                    max_width=["150px", "150px", "200px", "250px"],
                    width="100%",
                    variant="surface",
                    color_scheme="gray",
                    on_change=RunnersState.set_search_value,
                ),
                align="center",
                justify="end",
                spacing="3",
            ),
            rx.button(
                rx.icon("arrow-down-to-line", size=20),
                "Export",
                size="3",
                variant="surface",
                display=["none", "none", "none", "flex"],
                on_click=rx.download(url="/items.csv"),
            ),
            spacing="3",
            justify="between",
            wrap="wrap",
            width="100%",
            padding_bottom="1em",
        ),
        rx.table.root(
            rx.table.header(
                rx.table.row(
                    _header_cell("Epic", "user"),
                    _header_cell("Last Update", "calendar"),
                    _header_cell("Status", "notebook-pen"),
                ),
            ),
            rx.table.body(
                rx.foreach(
                    RunnersState.get_current_page,
                    lambda item, index: _show_item(item, index), # < ---------- CALLED HERE
                )
            ),
            variant="surface",
            size="3",
            width="100%",
        ),
        _pagination_view(),
        width="100%",
    )

This follows the same structure as the dashboard example.

The RunnersState.get_current_page method is also the same as TableState.get_current_page in the dashboard example.
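For readers without the dashboard example at hand, offset/limit pagination can be sketched as two plain functions. These are hypothetical stand-ins, not the dashboard's actual get_current_page code; they assume offset is a zero-based row index and limit is the page size, matching the offset and limit vars in RunnersState.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def current_page(items: Sequence[T], offset: int, limit: int) -> List[T]:
    """Return the rows of the page that starts at row `offset` (hypothetical helper)."""
    # Slicing clamps out-of-range bounds, so a final partial page is fine.
    return list(items[offset : offset + limit])


def total_pages(total_items: int, limit: int) -> int:
    # Ceiling division: 25 items at 12 per page need 3 pages.
    return (total_items + limit - 1) // limit if limit > 0 else 0


rows = list(range(25))
print(current_page(rows, offset=24, limit=12))  # [24]
print(total_pages(len(rows), 12))  # 3
```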

The route/page is likewise defined as in the dashboard example's table page:

# template is the page decorator from the dashboard example (not shown here).
@template(route="/runners", title="Runners", on_load=RunnersState.load_entries)
def runners() -> rx.Component:
    """The table page.

    Returns:
        The UI for the table page.
    """
    return rx.vstack(
        rx.heading("Runners", size="5"),
        main_table(),
        spacing="8",
        width="100%",
    )

riebecj commented Oct 29, 2024

I'm also using Pydantic models in another part of my codebase, but I want to use those models as MutableProxy objects in Reflex so I don't have to recreate existing models.

grizzlybearg (Author) commented:

> I'm also using Pydantic models in another part of my codebase, but I want to use those models as MutableProxy objects in Reflex so I don't have to recreate existing models.

Any success on your end, @riebecj? The idea is to avoid recreating dozens of models, which would be a maintenance burden.

riebecj commented Oct 30, 2024

> I'm also using Pydantic models in another part of my codebase, but I want to use those models as MutableProxy objects in Reflex so I don't have to recreate existing models.

> Any success on your end, @riebecj? The idea is to avoid recreating dozens of models, which would be a maintenance burden.

OK, I had to do some hacky things to make it work, but it's possible. First, a caveat: you won't be able to use Pydantic models as frontend state vars. They need to be backend vars (names starting with an underscore, _). However, you can reference attributes of your custom Pydantic model inside an @rx.var method on your rx.State class, and the result will render.
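The frontend/backend split here is a naming convention: Reflex treats state vars whose names start with an underscore as backend-only and does not sync them to the browser, which is why a non-serializable object can live there. The stdlib sketch below only illustrates that naming rule by splitting a hypothetical class's annotated attributes on the prefix; it is not Reflex's implementation.

```python
from typing import Any, Dict, Tuple, get_type_hints


class ExampleState:
    # Hypothetical state-like class; the annotations mimic rx.State vars.
    search_value: str = ""
    limit: int = 12
    _custom_model: object = None  # backend-only by naming convention


def split_vars(cls: type) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split annotated attributes into (frontend, backend) by leading underscore."""
    frontend: Dict[str, Any] = {}
    backend: Dict[str, Any] = {}
    for name, hint in get_type_hints(cls).items():
        # Names starting with "_" stay on the server; the rest would be synced.
        (backend if name.startswith("_") else frontend)[name] = hint
    return frontend, backend


front, back = split_vars(ExampleState)
print(sorted(front))  # ['limit', 'search_value']
print(sorted(back))  # ['_custom_model']
```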

Note: I placed all of this custom code in rxconfig.py so that it runs before my Reflex app is loaded.

To make Reflex accept a Pydantic BaseModel as a valid State attribute, the type needs a registered serializer. There are some quirks in how Reflex uses rxconfig.py (it seems to be imported twice), so I had to guard the registration of the custom BaseModel serializer. That code looks like this:

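from pydantic import BaseModel
from reflex.utils import serializers

# rxconfig.py may be imported twice; only register the serializer once.
if BaseModel not in serializers.SERIALIZERS:
    @serializers.serializer
    def serialize_pydantic(value: BaseModel) -> dict:
        return value.model_dump(by_alias=True)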

I'm not sure whether the serializer is actually called, but having an entry in the SERIALIZERS map stops Reflex from raising an error about it. That alone isn't enough, though: Reflex looks serializers up by type, and the type it passes is your concrete model class (e.g. RunnerDetails) rather than BaseModel itself. So I monkeypatched the module with a custom get_serializer that checks issubclass(type_, BaseModel) and, if so, looks the serializer up under BaseModel instead of the passed type_. That code looks like this:

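import functools
from typing import Optional

from pydantic import BaseModel
from reflex.utils import serializers, types


@functools.lru_cache
def custom_get_serializer(type_: type) -> Optional[serializers.Serializer]:
    # Route every BaseModel subclass to the serializer registered for BaseModel.
    # The isinstance check avoids a TypeError from issubclass on non-class
    # inputs such as typing generics.
    if isinstance(type_, type) and issubclass(type_, BaseModel):
        serializer = serializers.SERIALIZERS.get(BaseModel)
    else:
        serializer = serializers.SERIALIZERS.get(type_)

    if serializer is not None:
        return serializer

    # If the type is not registered, check if it is a subclass of a registered type.
    for registered_type, serializer in reversed(serializers.SERIALIZERS.items()):
        if types._issubclass(type_, registered_type):  # noqa: SLF001
            return serializer

    # If there is no serializer, return None.
    return None


# Monkeypatch: lookups through serializers.get_serializer now use the custom function.
serializers.get_serializer = custom_get_serializer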

The only change is the added if issubclass(...): ... else: ... branch; the rest is copied from Reflex's own get_serializer(). With that in place, I was able to add my Pydantic model to my state like this:
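The lookup idea behind this patch, an exact-type match first and then a fallback to a serializer registered for a base class, can be sketched in plain Python. The registry, the register/get_serializer names, and the placeholder classes below are hypothetical stand-ins for reflex.utils.serializers and Pydantic, not Reflex's code. Instead of the issubclass checks used in the patch, this sketch walks the class's MRO, so the nearest registered ancestor wins; the guarded register mirrors the check used for the doubly imported rxconfig.py.

```python
from typing import Any, Callable, Dict, Optional

Serializer = Callable[[Any], Any]
SERIALIZERS: Dict[type, Serializer] = {}


def register(type_: type, fn: Serializer) -> None:
    # Guarded registration: running the module body twice is harmless.
    if type_ not in SERIALIZERS:
        SERIALIZERS[type_] = fn


def get_serializer(type_: Any) -> Optional[Serializer]:
    # Non-class inputs (e.g. typing generics) have no MRO to walk.
    if not isinstance(type_, type):
        return SERIALIZERS.get(type_)
    # Walk the MRO: the exact type first, then each base class in order.
    for klass in type_.__mro__:
        if klass in SERIALIZERS:
            return SERIALIZERS[klass]
    return None


class FakeBaseModel:
    """Placeholder for pydantic.BaseModel."""

    def model_dump(self) -> dict:
        return dict(vars(self))


class RunnerDetails(FakeBaseModel):
    def __init__(self, epic: str) -> None:
        self.epic = epic


register(FakeBaseModel, lambda v: v.model_dump())
register(FakeBaseModel, lambda v: "ignored")  # second registration is a no-op

serializer = get_serializer(RunnerDetails)
print(serializer(RunnerDetails("demo")))  # {'epic': 'demo'}
print(get_serializer(int))  # None
```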

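import reflex as rx
from pydantic import BaseModel


class CustomModel(BaseModel):
    foo: str = "bar"


class MyState(rx.State):
    # Leading underscore: a backend-only var, not sent to the frontend.
    _custom_model = CustomModel()

    @rx.var(cache=True)
    def foo(self) -> str:
        # Expose a single serializable field to the frontend.
        return self._custom_model.foo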

You can then reference MyState.foo in your frontend code, and the value stored in the model will render. With this approach, the Pydantic model serves only as a backend container for your data, and the @rx.var computed var (cached here with cache=True) exposes individual fields to the frontend.

Like I said: hacky, but it works.
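One caveat with serializers.get_serializer = custom_get_serializer: rebinding a module attribute only affects code that looks the name up through the module at call time, which includes functions defined inside that module. A module that ran from ... import get_serializer before the patch keeps its reference to the original function. Whether any Reflex module imports it that way is not verified here. A stdlib demonstration with a hypothetical in-memory module:

```python
import textwrap
import types

# Hypothetical module standing in for reflex.utils.serializers.
lib = types.ModuleType("lib")
exec(
    textwrap.dedent(
        """
        def get_serializer(t):
            return "original"

        def serialize(t):
            # Resolved through the module's globals at call time.
            return get_serializer(t)
        """
    ),
    lib.__dict__,
)

# Simulates another module doing `from lib import get_serializer` before the patch.
early_binding = lib.get_serializer

# The monkeypatch: rebind the module attribute.
lib.get_serializer = lambda t: "patched"

print(lib.serialize(int))  # patched: the module-global lookup sees the new function
print(early_binding(int))  # original: the earlier reference is unaffected
```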


4 participants