[REF-1988] API to Get instance of Arbitrary State class #2678

Merged: 31 commits merged on Feb 27, 2024

@masenf (Collaborator) commented Feb 21, 2024

Implements a new api on rx.State:

other = await self.get_state(MyOtherState)

This is called from an event handler to get access to the latest instance of MyOtherState associated with the same token as the currently running event handler.

How it Works

With the recent Redis state sharding PR #2574, instead of fetching the entire state tree to process an event, only the "branch" of the state tree pertaining to the event handler is fetched from redis and deserialized.

For the given state containing the event handler the following are considered a "branch":

  • all parent states (up to rx.State)
    • any substates of any parent with computed vars (so that the values can be recomputed after the event handler runs)
  • all substates (recursively)
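
The branch rule above can be sketched with a toy tree. This is only an illustration, not the code in this PR: `Node`, `descendants`, and `branch` are made-up names, and the sketch assumes that "substates of any parent with computed vars" means the direct children of each ancestor.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Toy stand-in for a state class in the tree (not a Reflex type)."""

    name: str
    has_computed_vars: bool = False
    parent: "Node | None" = None
    children: list["Node"] = field(default_factory=list)

    def add(self, child: "Node") -> "Node":
        child.parent = self
        self.children.append(child)
        return child


def descendants(node: Node) -> set[str]:
    """Names of all substates below node, recursively."""
    names: set[str] = set()
    for child in node.children:
        names.add(child.name)
        names |= descendants(child)
    return names


def branch(target: Node) -> set[str]:
    """Names of the states fetched to run an event handler on target."""
    names = {target.name} | descendants(target)
    parent = target.parent
    while parent is not None:
        names.add(parent.name)
        # Direct substates of an ancestor that hold computed vars come along too.
        names |= {c.name for c in parent.children if c.has_computed_vars}
        parent = parent.parent
    return names


root = Node("State")
sub = root.add(Node("Sub"))
child = sub.add(Node("Child"))
gchild = child.add(Node("GChild"))
other = root.add(Node("Other"))
dt_state = root.add(Node("DTState", has_computed_vars=True))

# "Other" is a sibling branch with no computed vars, so it is left out.
print(sorted(branch(child)))  # ['Child', 'DTState', 'GChild', 'State', 'Sub']
```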

To access other "sibling" / "cousin" states, a new async API State.get_state is provided which can dynamically fetch the requested state from redis and patch it into the state tree as another "branch".
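
As a rough mental model of "fetch on demand and patch into the tree", here is a self-contained sketch. It does not use Reflex or redis: `FAKE_REDIS`, `ToyTree`, and the string state names are assumptions made for the illustration, and the real `get_state` takes a state class rather than a name.

```python
import asyncio

# Toy "redis": maps (token, state name) to that state's stored vars.
FAKE_REDIS = {
    ("token-1", "State.Sub"): {"bar": "hello"},
    ("token-1", "State.Other"): {"quuc": ""},
}


class ToyTree:
    """Holds only the state branches loaded for one event (illustrative only)."""

    def __init__(self, token: str):
        self.token = token
        self.loaded: dict[str, dict] = {}

    async def get_state(self, name: str) -> dict:
        # Fast path: the branch is already part of the tree.
        if name not in self.loaded:
            await asyncio.sleep(0)  # stands in for the network round trip
            self.loaded[name] = dict(FAKE_REDIS[(self.token, name)])
        return self.loaded[name]


async def handle_event() -> ToyTree:
    tree = ToyTree("token-1")
    sub = await tree.get_state("State.Sub")      # the handler's own branch
    other = await tree.get_state("State.Other")  # sibling, fetched on demand
    other["quuc"] = str(len(sub["bar"]))
    return tree


tree = asyncio.run(handle_event())
print(tree.loaded["State.Other"])  # {'quuc': '5'}
```

Only the two branches that were actually requested end up in `tree.loaded`; nothing else under the token is read.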

Practically, this means that state can scale horizontally without significant performance overhead (aside from hydration, which occurs when the websocket comes up, and on_load, which occurs on page navigation). It also minimizes the changes needed to the existing State API, retaining compatibility with existing apps (that don't use parent_state chasing).

Delta calculations are performed in exactly the same way as before: by finding the top-level rx.State and recursively calling get_delta on all its substates... the difference is that not all substates are included anymore, so the calculation can be performed more quickly.
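
A minimal sketch of that recursion, assuming a toy `ToyState` class (not Reflex's `BaseState`) whose `substates` dict only contains the substates that were loaded for the current event:

```python
class ToyState:
    """Illustrative node: tracks dirty vars and only the substates that were loaded."""

    def __init__(self, name: str, **initial_vars):
        self.name = name
        self.vars = dict(initial_vars)
        self.dirty: set[str] = set()
        self.substates: dict[str, "ToyState"] = {}

    def set(self, key: str, value) -> None:
        self.vars[key] = value
        self.dirty.add(key)

    def get_delta(self) -> dict:
        """Collect dirty vars from this state and every loaded substate."""
        delta = {}
        if self.dirty:
            delta[self.name] = {k: self.vars[k] for k in sorted(self.dirty)}
        for substate in self.substates.values():
            delta.update(substate.get_delta())
        return delta


root = ToyState("state")
sub = ToyState("state.sub", bar="")
root.substates["sub"] = sub  # "state.other" was never fetched, so it is never visited

sub.set("bar", "x")
print(root.get_delta())  # {'state.sub': {'bar': 'x'}}
```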

Similarly, during serialization, the top-level rx.State and its substates are persisted to redis, but only the substates that were previously fetched are serialized. If an event only targets a single substate in a mostly flat state hierarchy, this has significant performance benefits for rapidly occurring events, like controlled sliders.

Limitations and Caveats

Computed vars and cached computed vars cannot use this API, because properties do not work with async / await. This is probably fine: just define the computed var within the state containing the vars it will use.
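
The underlying Python behavior can be shown without Reflex. In this sketch (the class and attribute names are made up), reading a property backed by an `async def` hands back a coroutine object instead of a value, which is why a property-style computed var has no place to `await`:

```python
import asyncio
import inspect


class Example:
    @property
    def sync_value(self) -> str:
        return "ready"

    @property
    async def async_value(self) -> str:
        await asyncio.sleep(0)
        return "ready"


obj = Example()
print(obj.sync_value)  # ready

pending = obj.async_value  # a coroutine object, not a string
print(inspect.iscoroutine(pending))  # True
print(asyncio.run(pending))  # ready (only after the caller runs the coroutine)
```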

Avoid computed vars in a deeply nested substate, because all of the parents will have to be fetched and instantiated for every event 😱. Any substate containing a computed var will be fetched for every event, to ensure it always gets recalculated.

Using an rx.cached_var is nice here, because those only get fetched when the "branch" containing it is fetched, since we already know what vars and substates the cached var depends on, and can be assured that it can only depend on vars in itself or one of its parent states.

Breaking Change: When calculating the deps of a computed var, we no longer allow access to parent_state, substates, get_substate, or get_state (and it wouldn't work anyway). This ensures that all vars can be properly accounted for, and avoids the user shooting themselves in the foot with a cached_var that won't update properly.

Background events must first call async with self before accessing this API. However, after the block exits, there is no protection against subsequent modifications to the instance of the state, although these changes cannot be persisted and will not be considered for delta calculation (i.e. it just won't work). Down the road, we might implement some guardrails that will raise an exception when attempting to modify a stale state instance.
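
For illustration only, one shape such a guardrail could take is sketched below. Nothing here exists in Reflex: `StaleStateError` and `GuardedState` are hypothetical names, and the sketch simply rejects attribute writes made outside the `async with` block.

```python
import asyncio


class StaleStateError(RuntimeError):
    """Hypothetical error; Reflex does not define this name."""


class GuardedState:
    """Toy state that only accepts writes inside `async with self`."""

    def __init__(self):
        # Bypass the guard while setting up the initial attributes.
        object.__setattr__(self, "_open", False)
        object.__setattr__(self, "quuc", "")

    async def __aenter__(self):
        object.__setattr__(self, "_open", True)
        return self

    async def __aexit__(self, *exc_info):
        object.__setattr__(self, "_open", False)

    def __setattr__(self, name, value):
        if not self._open:
            raise StaleStateError(f"cannot set {name!r} outside `async with self`")
        object.__setattr__(self, name, value)


async def background(state: GuardedState) -> str:
    async with state:
        state.quuc = "From background"  # allowed: inside the block
    try:
        state.quuc = "too late"  # rejected: the block has exited
    except StaleStateError:
        pass
    return state.quuc


print(asyncio.run(background(GuardedState())))  # From background
```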

All of this only applies to redis use. The in-memory state manager used in dev mode works the same way it always has: keeping everything in memory all the time.

Bonus

  • [REF-2009] Instantiating a State class directly now raises an exception pointing to the state docs
  • [CR Feedback] New helper methods for constructing and deconstructing "substate_key" (the combination of client_token and state full name)
  • State sharding is now using create_task to achieve better concurrency. In reflex-web, this results in a 2x speedup on initial state creation (for a new user) and a 10x speedup deserializing existing app state.
  • All the goodness from Move on_load_internal and update_vars_internal to substates #2725
    • [REF-2038] Do not send update_vars_internal as on_load when no cookies are set
    • More efficient on_load_internal handling.
    • More efficient state class initialization when the state does not exist in redis for the given token.
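
To make two of the items above concrete, here is a small sketch of key helpers plus concurrent reads with `asyncio.create_task`. The function names, the `_` separator, and the `STORE` dict are assumptions for this example and may not match the actual helpers or the redis client calls in this PR.

```python
import asyncio


def substate_key(token: str, state_name: str) -> str:
    """Join a client token and a state's full name (separator is an assumption)."""
    return f"{token}_{state_name}"


def split_substate_key(key: str) -> tuple[str, str]:
    """Inverse of substate_key; assumes the token itself contains no underscore."""
    token, _, state_name = key.partition("_")
    return token, state_name


# Toy key-value store standing in for redis.
STORE = {
    substate_key("abc123", "state"): {"is_hydrated": True},
    substate_key("abc123", "state.sub"): {"bar": "hi"},
    substate_key("abc123", "state.sub.child"): {"baz": ""},
}


async def fetch(key: str) -> dict:
    await asyncio.sleep(0.01)  # simulated network latency
    return STORE[key]


async def fetch_branch(token: str, names: list[str]) -> dict[str, dict]:
    # Start every read before awaiting any of them, so the waits overlap.
    tasks = {n: asyncio.create_task(fetch(substate_key(token, n))) for n in names}
    return {n: await t for n, t in tasks.items()}


loaded = asyncio.run(fetch_branch("abc123", ["state", "state.sub", "state.sub.child"]))
print(split_substate_key(substate_key("abc123", "state.sub.child")))
print(loaded["state.sub"])
```

With sequential awaits the three simulated reads would take roughly three times the latency of one; starting the tasks first lets them wait at the same time.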

Some example code i was playing with

Don't ask me what it does 😹

import asyncio
import datetime

import reflex as rx


class State(rx.State):
    foo: str = ""


class Sub(rx.State):
    bar: str = ""

    def do_reset(self):
        self.reset()

    async def do_mod_other(self):
        other = await self.get_state(Other)
        other.quuc = str(len(self.bar))


class Child(Sub):
    baz: str = ""

    async def get_in_other_tree(self):
        sub = await self.get_state(Sub)
        print(sub.bar)
        gc = await self.get_state(GChild)
        print(gc.bar_cap)
        state = await self.get_state(rx.State)
        print(state.is_hydrated)


class GChild(Child):
    @rx.cached_var
    def bar_cap(self) -> str:
        return self.bar.upper()

    @rx.cached_var
    def baz_cap(self) -> str:
        return self.baz.upper()


class Other(rx.State):
    quuc: str = ""

    async def set_quuc(self, value: str):
        self.quuc = value
        child = await self.get_state(Child)
        child.baz = value


class DTState(rx.State):
    @rx.var
    def last_ts(self) -> str:
        return datetime.datetime.now().isoformat()


class DTChild(Other):
    @rx.cached_var
    def last_quuc(self) -> str:
        return self.quuc + " at " + datetime.datetime.now().isoformat()


class BackgroundState(rx.State):
    @rx.background
    async def background(self):
        # The following line raises an exception
        # print(self.substates)
        await asyncio.sleep(1)
        async with self:
            other = await self.get_state(Other)
            other.quuc = "From background"
            self.get_substate(tuple(BGSub.get_full_name().split(".")[1:])).bg_foo = "From background2"
        await asyncio.sleep(5)
        async with self:
            bg_sub = await self.get_state(BGSub)
            bg_sub.bg_foo = "Task Complete"


class BGSub(BackgroundState):
    bg_foo: str


class RouterDepState(BGSub):
    @rx.cached_var
    def router_dep(self) -> str:
        return self.router.page.full_raw_path


def index() -> rx.Component:
    return rx.fragment(
        rx.color_mode.button(rx.color_mode.icon(), float="right"),
        rx.vstack(
            rx.input(value=State.foo, on_change=State.set_foo),
            rx.input(value=Sub.bar, on_change=Sub.set_bar),
            rx.input(value=Child.baz, on_change=Child.set_baz),
            rx.input(value=Other.quuc, on_change=Other.set_quuc),
            rx.button("Reset", on_click=Sub.do_reset),
            rx.button("Get in other tree", on_click=Child.get_in_other_tree),
            rx.button("Measure bar", on_click=Sub.do_mod_other),
            rx.button(f"Background {BGSub.bg_foo}", on_click=BackgroundState.background),
            rx.heading("Quuc: ", Other.quuc),
            rx.heading("Last TS: ", DTState.last_ts),
            rx.heading("Last Quuc: ", DTChild.last_quuc),
            rx.text("GChild.bar_cap: ", GChild.bar_cap),
            rx.text("GChild.baz_cap: ", GChild.baz_cap),
            rx.text("Full path raw: ", RouterDepState.router_dep),
            rx.link("Go to another page", href="/foo"),
            gap="1.5em",
            padding_top="10%",
            width="50%",
        ),
    )


# Create app instance and add index page.
app = rx.App()
app.add_page(index)
app.add_page(index, route="/foo")

@masenf masenf marked this pull request as draft February 21, 2024 02:17
Base automatically changed from masenf/state-shard-redis2 to main February 21, 2024 09:50
Rudimentary protection for state instance access from a background task
(StateProxy)
Fix StateProxy for substates and parent_state attributes (have to handle in
__getattr__, not property)

Fix type annotation for `get_state`
Reset the substate tracking only when the class is instantiated.
Ensure that `get_state` returns the proper "branch" of the state tree depending
on what substate is requested.
…True

Avoid user errors trying to directly instantiate State classes
Unify the implementation of generating and decoding the token + state name
format used for redis state sharding.
read and write substates concurrently (allow redis to shine)
@masenf masenf changed the title WiP Get arbitrary state [REF-1988] API to Get instance of Arbitrary State class Feb 21, 2024
@masenf masenf marked this pull request as ready for review February 21, 2024 11:03
@picklelo (Contributor) left a comment

The syntax for accessing the state looks good.

On a broader note, I'm scared the code is getting too complicated, particularly now the state class has so many methods and that file is huge.

The methods are also a bit intimidating - for a newcomer I'm not sure they'd be able to reason around our codebase very easily now. We should have a design meeting to look over how we're doing anything to see if there's some things we can simplify.

reflex/state.py Outdated
"""
if not _reflex_internal_init and not _is_testing_env():
Contributor:

I wonder if instead of checking _is_testing_env we should check the inverse, is_reflex_run. So that when external users work with it, they don't have to set this flag (similar to the hack I had to do in flexdown)

Collaborator, Author:

let's talk about this on monday, i'm not really grokking the suggestion. how does is_reflex_run get determined?

@@ -218,24 +262,39 @@ def __init__(
*args,
parent_state: BaseState | None = None,
init_substates: bool = True,
_reflex_internal_init: bool = False,
Contributor:

Can we handle this with only the flag, to avoid adding another argument here?

@masenf (Collaborator, Author) left a comment

> On a broader note, I'm scared the code is getting too complicated, particularly now the state class has so many methods and that file is huge.

I agree, there's too much going on inside the BaseState. I have some ideas for simplifying it, particularly via the use of descriptors for Var and EventHandler access, could clean up a lot of the __init_subclass__ logic where we treat the same attribute differently on the class vs the instance.
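
One possible reading of the descriptor idea, sketched outside of Reflex: a single descriptor object answers differently depending on whether it is accessed on the class or on an instance. `VarRef`, `Field`, and `Counter` are hypothetical names for this illustration and are not part of BaseState.

```python
class VarRef:
    """Hypothetical class-level handle, e.g. for use when building components."""

    def __init__(self, full_name: str):
        self.full_name = full_name

    def __repr__(self) -> str:
        return f"VarRef({self.full_name!r})"


class Field:
    """Descriptor: returns a VarRef on the class and the stored value on an instance."""

    def __init__(self, default):
        self.default = default

    def __set_name__(self, owner, name):
        self.name = name

    def __get__(self, instance, owner):
        if instance is None:
            # Accessed as Counter.count: hand back a reference, not a value.
            return VarRef(f"{owner.__name__}.{self.name}")
        return instance.__dict__.get(self.name, self.default)

    def __set__(self, instance, value):
        instance.__dict__[self.name] = value


class Counter:
    count = Field(0)


c = Counter()
c.count = 3
print(Counter.count)  # VarRef('Counter.count')
print(c.count)        # 3
```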

> The methods are also a bit intimidating

I split up get_state into 6 smaller methods with hopefully more descriptive readable names to help paint a picture of what's going on. I don't see an easy way to get away from the intimidation of the BaseState generally though... I was thinking we could split up some of the distinct functionality into a few mixin classes (i.e. one for var/dirty/delta management, one for substate/parent state management, and one for event processing), but i'm not sure that would be more enlightening or worth the time at this point.

I think a bigger state refactoring is on the horizon, but I don't think we can justify further investment on the BaseState after getting this API in.

@masenf (Collaborator, Author) left a comment

There's a problem here: if a substate has a cached_var that depends on something in a parent state, but that substate isn't loaded because the event is being processed on a sibling state, then the cached var doesn't update (and it throws an exception).

This has been fixed with the merge of #2725: substates containing a cached_var are now pre-fetched, similarly to those containing a var. This might be inefficient, but it's necessary to ensure consistency of updates.

@picklelo picklelo mentioned this pull request Feb 27, 2024
27 tasks
Avoid loading the entire state tree to process these common internal events. If
the state tree is very large, this allows page navigation to occur more quickly.

Pre-fetch substates that contain cached vars, as they may need to be recomputed
if certain vars change.
This is a waste of time and memory, and can be handled via a special case in
__getattribute__
Avoid wasting time serializing states that have no modifications
Wait until the state is actually modified, and then persist it as part of `set_state`.

Factor out common logic into helper methods for readability and to reduce
duplication of common logic.

To avoid having to recursively call `get_state`, which would require persisting
the instance and then getting it again, some of the initialization logic
regarding parent_state and substates is duplicated when creating a new
instance. This is for performance reasons.
linear bot commented Feb 27, 2024

substate = substates[substate_name]
substate.dirty_vars.add(var)
substate._mark_dirty()

def _get_was_touched(self) -> bool:
Contributor:

Should this be called set_was_touched? I got confused below on line 1566 when you called it, as I didn't realize this also sets the value. The get is just the attribute itself right?

Collaborator, Author:

Yeah i went back and forth on what to call this one... The flag is called _was_touched, and this function _get_was_touched checks both the flag and other conditions (and then updates the flag when those conditions indicate touching has occurred). The reason you would call this function is to decide whether the state has been modified since it was instantiated. I'll split the getting and updating logic out to clarify.
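
A minimal sketch of what that split could look like, assuming a simplified rule where "touched" just means some var was marked dirty (the conditions checked in the actual code may differ, and `TouchTracking` is a made-up class):

```python
class TouchTracking:
    """Toy state that remembers whether it was modified since instantiation."""

    def __init__(self):
        self._was_touched = False
        self.dirty_vars: set[str] = set()

    def _update_was_touched(self) -> None:
        """Side effect only: latch the flag once any var has been marked dirty."""
        if self.dirty_vars:
            self._was_touched = True

    def _get_was_touched(self) -> bool:
        """Pure read: report the flag without changing anything."""
        return self._was_touched


state = TouchTracking()
state.dirty_vars.add("bar")
print(state._get_was_touched())  # False: nothing has updated the flag yet

state._update_was_touched()
state.dirty_vars.clear()  # e.g. after the delta was sent
print(state._get_was_touched())  # True: the flag stays latched
```

Callers that only want the side effect use `_update_was_touched`, and the serializer can call `_get_was_touched` to decide whether a state needs to be written back.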

@masenf masenf merged commit deae662 into main Feb 27, 2024
45 checks passed
@masenf masenf deleted the masenf/get-state branch February 27, 2024 21:02
benedikt-bartscher pushed a commit to benedikt-bartscher/reflex that referenced this pull request Feb 27, 2024
)

* WiP get_state

* Refactor get_state fast path

Rudimentary protection for state instance access from a background task
(StateProxy)

* retain dirty substate marking per `_mark_dirty` call to avoid test changes

* Find common ancestor by part instead of by character

Fix StateProxy for substates and parent_state attributes (have to handle in
__getattr__, not property)

Fix type annotation for `get_state`

* test_state: workflow test for `get_state` functionality

* Do not reset _always_dirty_substates when adding vars

Reset the substate tracking only when the class is instantiated.

* test_state_tree: test substate access in a larger state tree

Ensure that `get_state` returns the proper "branch" of the state tree depending
on what substate is requested.

* test_format: fixup broken tests from adding substates of TestState

* Fix flaky integration tests with more polling

* AppHarness: reset _always_dirty_substates on rx.State

* RuntimeError unless State is instantiated with _reflex_internal_init=True

Avoid user errors trying to directly instantiate State classes

* Helper functions for _substate_key and _split_substate_key

Unify the implementation of generating and decoding the token + state name
format used for redis state sharding.

* StateManagerRedis: use create_task in get_state and set_state

read and write substates concurrently (allow redis to shine)

* test_state_inheritance: use polling cuz life too short for flaky tests

kthnxbai ❤️

* Move _is_testing_env to reflex.utils.exec.is_testing_env

Reuse the code in app.py

* Break up `BaseState.get_state` and friends into separate methods

* Add test case for pre-fetching cached var dependency

* Move on_load_internal and update_vars_internal to substates

Avoid loading the entire state tree to process these common internal events. If
the state tree is very large, this allows page navigation to occur more quickly.

Pre-fetch substates that contain cached vars, as they may need to be recomputed
if certain vars change.

* Do not copy ROUTER_DATA into all substates.

This is a waste of time and memory, and can be handled via a special case in
__getattribute__

* Track whether State instance _was_touched

Avoid wasting time serializing states that have no modifications

* Do not persist states in `StateManagerRedis.get_state`

Wait until the state is actually modified, and then persist it as part of `set_state`.

Factor out common logic into helper methods for readability and to reduce
duplication of common logic.

To avoid having to recursively call `get_state`, which would require persisting
the instance and then getting it again, some of the initialization logic
regarding parent_state and substates is duplicated when creating a new
instance. This is for performance reasons.

* Remove stray print()

* context.js.jinja2: fix check for empty local storage / cookie vars

* Add comments for onLoadInternalEvent and initialEvents

* nit: typo

* split _get_was_touched into _update_was_touched

Improve clarity in cases where _get_was_touched was being called for its side
effects only.

* Remove extraneous information from incorrect State instantiation error

* Update missing redis exception message
Yummy-Yums pushed a commit to Yummy-Yums/reflex-codebase that referenced this pull request Feb 28, 2024