Skip to content

Commit

Permalink
[REF-1988] API to Get instance of Arbitrary State class (#2678)
Browse files Browse the repository at this point in the history
* WiP get_state

* Refactor get_state fast path

Rudimentary protection for state instance access from a background task
(StateProxy)

* retain dirty substate marking per `_mark_dirty` call to avoid test changes

* Find common ancestor by part instead of by character

Fix StateProxy for substates and parent_state attributes (have to handle in
__getattr__, not property)

Fix type annotation for `get_state`

* test_state: workflow test for `get_state` functionality

* Do not reset _always_dirty_substates when adding vars

Reset the substate tracking only when the class is instantiated.

* test_state_tree: test substate access in a larger state tree

Ensure that `get_state` returns the proper "branch" of the state tree depending
on what substate is requested.

* test_format: fixup broken tests from adding substates of TestState

* Fix flaky integration tests with more polling

* AppHarness: reset _always_dirty_substates on rx.State

* RuntimeError unless State is instantiated with _reflex_internal_init=True

Avoid user errors trying to directly instantiate State classes

* Helper functions for _substate_key and _split_substate_key

Unify the implementation of generating and decoding the token + state name
format used for redis state sharding.

* StateManagerRedis: use create_task in get_state and set_state

read and write substates concurrently (allow redis to shine)

* test_state_inheritance: use polling cuz life too short for flaky tests

kthnxbai ❤️

* Move _is_testing_env to reflex.utils.exec.is_testing_env

Reuse the code in app.py

* Break up `BaseState.get_state` and friends into separate methods

* Add test case for pre-fetching cached var dependency

* Move on_load_internal and update_vars_internal to substates

Avoid loading the entire state tree to process these common internal events. If
the state tree is very large, this allow page navigation to occur more quickly.

Pre-fetch substates that contain cached vars, as they may need to be recomputed
if certain vars change.

* Do not copy ROUTER_DATA into all substates.

This is a waste of time and memory, and can be handled via a special case in
__getattribute__

* Track whether State instance _was_touched

Avoid wasting time serializing states that have no modifications

* Do not persist states in `StateManagerRedis.get_state`

Wait until the state is actually modified, and then persist it as part of `set_state`.

Factor out common logic into helper methods for readability and to reduce
duplication of common logic.

To avoid having to recursively call `get_state`, which would require persisting
the instance and then getting it again, some of the initialization logic
regarding parent_state and substates is duplicated when creating a new
instance. This is for performance reasons.

* Remove stray print()

* context.js.jinja2: fix check for empty local storage / cookie vars

* Add comments for onLoadInternalEvent and initialEvents

* nit: typo

* split _get_was_touched into _update_was_touched

Improve clarity in cases where _get_was_touched was being called for its side
effects only.

* Remove extraneous information from incorrect State instantiation error

* Update missing redis exception message
  • Loading branch information
masenf authored Feb 27, 2024
1 parent bf07315 commit deae662
Show file tree
Hide file tree
Showing 15 changed files with 1,208 additions and 153 deletions.
4 changes: 2 additions & 2 deletions integration/test_client_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,17 +518,17 @@ def set_sub_sub(var: str, value: str):
set_sub("l6", "l6 value")
l5 = driver.find_element(By.ID, "l5")
l6 = driver.find_element(By.ID, "l6")
assert AppHarness._poll_for(lambda: l6.text == "l6 value")
assert l5.text == "l5 value"
assert l6.text == "l6 value"

# Switch back to main window.
driver.switch_to.window(main_tab)

# The values should have updated automatically.
l5 = driver.find_element(By.ID, "l5")
l6 = driver.find_element(By.ID, "l6")
assert AppHarness._poll_for(lambda: l6.text == "l6 value")
assert l5.text == "l5 value"
assert l6.text == "l6 value"

# clear the cookie jar and local storage, ensure state reset to default
driver.delete_all_cookies()
Expand Down
23 changes: 19 additions & 4 deletions integration/test_state_inheritance.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,29 @@
"""Test state inheritance."""

import time
from contextlib import suppress
from typing import Generator

import pytest
from selenium.common.exceptions import NoAlertPresentException
from selenium.webdriver.common.alert import Alert
from selenium.webdriver.common.by import By

from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver


def get_alert_or_none(driver: WebDriver) -> Alert | None:
"""Switch to an alert if present.
Args:
driver: WebDriver instance.
Returns:
The alert if present, otherwise None.
"""
with suppress(NoAlertPresentException):
return driver.switch_to.alert


def raises_alert(driver: WebDriver, element: str) -> None:
"""Click an element and check that an alert is raised.
Expand All @@ -18,8 +33,8 @@ def raises_alert(driver: WebDriver, element: str) -> None:
"""
btn = driver.find_element(By.ID, element)
btn.click()
time.sleep(0.2) # wait for the alert to appear
alert = driver.switch_to.alert
alert = AppHarness._poll_for(lambda: get_alert_or_none(driver))
assert isinstance(alert, Alert)
assert alert.text == "clicked"
alert.accept()

Expand Down Expand Up @@ -355,7 +370,7 @@ def test_state_inheritance(
child3_other_mixin_btn = driver.find_element(By.ID, "child3-other-mixin-btn")
child3_other_mixin_btn.click()
child2_other_mixin_value = state_inheritance.poll_for_content(
child2_other_mixin, exp_not_equal="other_mixin"
child2_other_mixin, exp_not_equal="Child2.clicked.1"
)
child2_computed_mixin_value = state_inheritance.poll_for_content(
child2_computed_other_mixin, exp_not_equal="other_mixin"
Expand Down
28 changes: 24 additions & 4 deletions reflex/.templates/jinja/web/utils/context.js.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,31 @@ export const clientStorage = {}

{% if state_name %}
export const state_name = "{{state_name}}"
export const onLoadInternalEvent = () => [
Event('{{state_name}}.{{const.update_vars_internal}}', {vars: hydrateClientStorage(clientStorage)}),
Event('{{state_name}}.{{const.on_load_internal}}')
]

// Theses events are triggered on initial load and each page navigation.
export const onLoadInternalEvent = () => {
const internal_events = [];

// Get tracked cookie and local storage vars to send to the backend.
const client_storage_vars = hydrateClientStorage(clientStorage);
// But only send the vars if any are actually set in the browser.
if (client_storage_vars && Object.keys(client_storage_vars).length !== 0) {
internal_events.push(
Event(
'{{state_name}}.{{const.update_vars_internal}}',
{vars: client_storage_vars},
),
);
}

// `on_load_internal` triggers the correct on_load event(s) for the current page.
// If the page does not define any on_load event, this will just set `is_hydrated = true`.
internal_events.push(Event('{{state_name}}.{{const.on_load_internal}}'));

return internal_events;
}

// The following events are sent when the websocket connects or reconnects.
export const initialEvents = () => [
Event('{{state_name}}.{{const.hydrate}}'),
...onLoadInternalEvent()
Expand Down
2 changes: 1 addition & 1 deletion reflex/.templates/web/utils/state.js
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ export const useEventLoop = (
if (storage_to_state_map[e.key]) {
const vars = {}
vars[storage_to_state_map[e.key]] = e.newValue
const event = Event(`${state_name}.update_vars_internal`, {vars: vars})
const event = Event(`${state_name}.update_vars_internal_state.update_vars_internal`, {vars: vars})
addEvents([event], e);
}
};
Expand Down
10 changes: 6 additions & 4 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@
State,
StateManager,
StateUpdate,
_substate_key,
code_uses_state_contexts,
)
from reflex.utils import console, exceptions, format, prerequisites, types
from reflex.utils.exec import is_testing_env
from reflex.utils.imports import ImportVar

# Define custom types.
Expand Down Expand Up @@ -159,10 +161,9 @@ def __init__(self, *args, **kwargs):
)
super().__init__(*args, **kwargs)
state_subclasses = BaseState.__subclasses__()
is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ

# Special case to allow test cases have multiple subclasses of rx.BaseState.
if not is_testing_env:
if not is_testing_env():
# Only one Base State class is allowed.
if len(state_subclasses) > 1:
raise ValueError(
Expand All @@ -176,7 +177,8 @@ def __init__(self, *args, **kwargs):
deprecation_version="0.3.5",
removal_version="0.5.0",
)
if len(State.class_subclasses) > 0:
# 2 substates are built-in and not considered when determining if app is stateless.
if len(State.class_subclasses) > 2:
self.state = State
# Get the config
config = get_config()
Expand Down Expand Up @@ -1002,7 +1004,7 @@ async def upload_file(request: Request, files: List[UploadFile]):
)

# Get the state for the session.
substate_token = token + "_" + handler.rpartition(".")[0]
substate_token = _substate_key(token, handler.rpartition(".")[0])
state = await app.state_manager.get_state(substate_token)

# get the current session ID
Expand Down
4 changes: 2 additions & 2 deletions reflex/compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,12 @@ def compile_state(state: Type[BaseState]) -> dict:
A dictionary of the compiled state.
"""
try:
initial_state = state().dict(initial=True)
initial_state = state(_reflex_internal_init=True).dict(initial=True)
except Exception as e:
console.warn(
f"Failed to compile initial state with computed vars, excluding them: {e}"
)
initial_state = state().dict(include_computed=False)
initial_state = state(_reflex_internal_init=True).dict(include_computed=False)
return format.format_state(initial_state)


Expand Down
4 changes: 2 additions & 2 deletions reflex/constants/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ class CompileVars(SimpleNamespace):
# The name of the function for converting a dict to an event.
TO_EVENT = "Event"
# The name of the internal on_load event.
ON_LOAD_INTERNAL = "on_load_internal"
ON_LOAD_INTERNAL = "on_load_internal_state.on_load_internal"
# The name of the internal event to update generic state vars.
UPDATE_VARS_INTERNAL = "update_vars_internal"
UPDATE_VARS_INTERNAL = "update_vars_internal_state.update_vars_internal"


class PageNames(SimpleNamespace):
Expand Down
Loading

0 comments on commit deae662

Please sign in to comment.