diff --git a/reflex/vars/object.py b/reflex/vars/object.py index d86c65e6b4a..5ee57059ec6 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -221,6 +221,29 @@ def __getitem__(self, key: Var | Any) -> Var: return self.__getattr__(key) return ObjectItemOperation.create(self, key).guess_type() + def get(self, key: Var | Any, default: Var | Any | None = None) -> Var: + """Get an item from the object. + + Args: + key: The key to get from the object. + default: The default value if the key is not found. + + Returns: + The item from the object. + """ + from reflex.components.core.cond import cond + + if default is None: + default = Var.create(None) + + value = self.__getitem__(key) # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue,reportUnknownMemberType] + + return cond( # pyright: ignore[reportUnknownVariableType] + value, + value, + default, + ) + # NoReturn is used here to catch when key value is Any @overload def __getattr__( # pyright: ignore [reportOverlappingOverload] diff --git a/tests/integration/test_var_operations.py b/tests/integration/test_var_operations.py index ba9c018e6de..14a4781d4ba 100644 --- a/tests/integration/test_var_operations.py +++ b/tests/integration/test_var_operations.py @@ -18,6 +18,8 @@ def VarOperations(): class Object(rx.Base): name: str = "hello" + optional_none: str | None = None + optional_str: str | None = "hello" class Person(TypedDict): name: str @@ -47,6 +49,7 @@ class VarOperationState(rx.State): people: rx.Field[list[Person]] = rx.field( [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}] ) + obj: rx.Field[Object] = rx.field(Object()) app = rx.App(_state=rx.State) @@ -713,6 +716,27 @@ def index(): rx.text.span(f"{rx.Var.create(13212312312.1231231):_.2f}"), id="float_format_underscore_2f", ), + # ObjectVar + rx.box( + rx.text(VarOperationState.obj.name), + id="obj_name", + ), + rx.box( + rx.text(VarOperationState.obj.optional_none), + id="obj_optional_none", + ), + rx.box( + rx.text(VarOperationState.obj.optional_str), + id="obj_optional_str", + ), + rx.box( + rx.text(VarOperationState.obj.get("optional_none")), + id="obj_optional_none_get_none", + ), + rx.box( + rx.text(VarOperationState.obj.get("optional_none", "foo")), + id="obj_optional_none_get_foo", + ), ) @@ -936,6 +960,11 @@ def test_var_operations(driver, var_operations: AppHarness): ("float_format_underscore_0f", "13_212_312_312"), ("float_format_underscore_1f", "13_212_312_312.1"), ("float_format_underscore_2f", "13_212_312_312.12"), + ("obj_name", "hello"), + ("obj_optional_none", ""), + ("obj_optional_str", "hello"), + ("obj_optional_none_get_none", ""), + ("obj_optional_none_get_foo", "foo"), ] for tag, expected in tests: