diff --git a/reflex/state.py b/reflex/state.py index 434ee39217..bd147d3026 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -392,6 +392,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # A special event handler for setting base vars. setvar: ClassVar[EventHandler] + # Track if computed vars have changed since last serialization + _changed_computed_vars: Set[str] = set() + + # Track which computed vars have already been computed + _ready_computed_vars: Set[str] = set() + def __init__( self, parent_state: BaseState | None = None, @@ -1850,11 +1856,12 @@ def _mark_dirty_computed_vars(self) -> None: while dirty_vars: calc_vars, dirty_vars = dirty_vars, set() for cvar in self._dirty_computed_vars(from_vars=calc_vars): - self.dirty_vars.add(cvar) - dirty_vars.add(cvar) actual_var = self.computed_vars.get(cvar) - if actual_var is not None: + assert actual_var is not None + if actual_var.has_changed(instance=self): actual_var.mark_dirty(instance=self) + self.dirty_vars.add(cvar) + dirty_vars.add(cvar) def _expired_computed_vars(self) -> set[str]: """Determine ComputedVars that need to be recalculated based on the expiration time. @@ -2134,6 +2141,10 @@ def __getstate__(self): state["__dict__"].pop("parent_state", None) state["__dict__"].pop("substates", None) state["__dict__"].pop("_was_touched", None) + state["__dict__"].pop("_changed_computed_vars", None) + state["__dict__"].pop("_ready_computed_vars", None) + state["__fields_set__"].discard("_changed_computed_vars") + state["__fields_set__"].discard("_ready_computed_vars") # Remove all inherited vars. for inherited_var_name in self.inherited_vars: state["__dict__"].pop(inherited_var_name, None) @@ -2150,6 +2161,9 @@ def __setstate__(self, state: dict[str, Any]): state["__dict__"]["parent_state"] = None state["__dict__"]["substates"] = {} super().__setstate__(state) + self._was_touched = False + self._changed_computed_vars = set() + self._ready_computed_vars = set() def _check_state_size( self, @@ -3131,6 +3145,8 @@ async def get_state( root_state = self.states.get(client_token) if root_state is not None: # Retrieved state from memory. + root_state._changed_computed_vars = set() + root_state._ready_computed_vars = set() return root_state # Deserialize root state from disk. diff --git a/reflex/utils/types.py b/reflex/utils/types.py index b8bcbf2d69..7db53d813f 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -121,6 +121,8 @@ def override(func: Callable) -> Callable: "_abc_impl", "_backend_vars", "_was_touched", + "_changed_computed_vars", + "_ready_computed_vars", } if sys.version_info >= (3, 11): diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 941a9d81ab..0101833ba8 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -2022,18 +2022,7 @@ def __get__(self, instance: BaseState | None, owner): existing_var=self, ) - if not self._cache: - value = self.fget(instance) - else: - # handle caching - if not hasattr(instance, self._cache_attr) or self.needs_update(instance): - # Set cache attr on state instance. - setattr(instance, self._cache_attr, self.fget(instance)) - # Ensure the computed var gets serialized to redis. - instance._was_touched = True - # Set the last updated timestamp on the state instance. - setattr(instance, self._last_updated_attr, datetime.datetime.now()) - value = getattr(instance, self._cache_attr) + value = self.get_value(instance) if not _isinstance(value, self._var_type): console.deprecate( @@ -2158,14 +2147,71 @@ def _deps( self_is_top_of_stack = False return d - def mark_dirty(self, instance) -> None: + def mark_dirty(self, instance: BaseState) -> None: """Mark this ComputedVar as dirty. Args: instance: the state instance that needs to recompute the value. """ - with contextlib.suppress(AttributeError): - delattr(instance, self._cache_attr) + instance._ready_computed_vars.discard(self._js_expr) + + def already_computed(self, instance: BaseState) -> bool: + """Check if the ComputedVar has already been computed. + + Args: + instance: the state instance that needs to recompute the value. + + Returns: + True if the ComputedVar has already been computed, False otherwise. + """ + if self.needs_update(instance): + return False + return self._js_expr in instance._ready_computed_vars + + def get_value(self, instance: BaseState) -> RETURN_TYPE: + """Get the value of the ComputedVar. + + Args: + instance: the state instance that needs to recompute the value. + + Returns: + The value of the ComputedVar. + """ + if not self._cache: + instance._was_touched = True + new = self.fget(instance) + return new + + has_cache = hasattr(instance, self._cache_attr) + + if self.already_computed(instance) and has_cache: + return getattr(instance, self._cache_attr) + + cache_value = getattr(instance, self._cache_attr, None) + instance._ready_computed_vars.add(self._js_expr) + setattr(instance, self._last_updated_attr, datetime.datetime.now()) + new_value = self.fget(instance) + if cache_value != new_value: + instance._changed_computed_vars.add(self._js_expr) + instance._was_touched = True + setattr(instance, self._cache_attr, new_value) + return new_value + + def has_changed(self, instance: BaseState) -> bool: + """Check if the ComputedVar value has changed. + + Args: + instance: the state instance that needs to recompute the value. + + Returns: + True if the value has changed, False otherwise. + """ + if not self._cache: + return True + if self._js_expr in instance._changed_computed_vars: + return True + # TODO: prime the cache if it's not already? creates side effects and breaks order of computed var execution + return self._js_expr in instance._changed_computed_vars def _determine_var_type(self) -> Type: """Get the type of the var. diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 912d72f4f1..f47487f4aa 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -3563,6 +3563,33 @@ class DillState(BaseState): _ = state3._serialize() +def test_pickle(): + class PickleState(BaseState): + pass + + state = PickleState(_reflex_internal_init=True) # type: ignore + + # test computed var cache is persisted + setattr(state, "__cvcached", 1) + state = PickleState._deserialize(state._serialize()) + assert getattr(state, "__cvcached", None) == 1 + + # test ready computed vars set is not persisted + state._ready_computed_vars = {"foo"} + state = PickleState._deserialize(state._serialize()) + assert not state._ready_computed_vars + + # test that changed computed vars set is not persisted + state._changed_computed_vars = {"foo"} + state = PickleState._deserialize(state._serialize()) + assert not state._changed_computed_vars + + # test was_touched is not persisted + state._was_touched = True + state = PickleState._deserialize(state._serialize()) + assert not state._was_touched + + def test_typed_state() -> None: class TypedState(rx.State): field: rx.Field[str] = rx.field("")