Skip to content

Commit

Permalink
Helper functions for _substate_key and _split_substate_key
Browse files Browse the repository at this point in the history
Unify the implementation of generating and decoding the token + state name
format used for redis state sharding.
  • Loading branch information
masenf committed Feb 21, 2024
1 parent 542ce97 commit 48bd579
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 40 deletions.
3 changes: 2 additions & 1 deletion reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
State,
StateManager,
StateUpdate,
_substate_key,
code_uses_state_contexts,
)
from reflex.utils import console, exceptions, format, prerequisites, types
Expand Down Expand Up @@ -1002,7 +1003,7 @@ async def upload_file(request: Request, files: List[UploadFile]):
)

# Get the state for the session.
substate_token = token + "_" + handler.rpartition(".")[0]
substate_token = _substate_key(token, handler.rpartition(".")[0])
state = await app.state_manager.get_state(substate_token)

# get the current session ID
Expand Down
74 changes: 55 additions & 19 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,41 @@ def __init__(self, router_data: Optional[dict] = None):
}


def _substate_key(
token: str,
state_cls_or_name: BaseState | Type[BaseState] | str | list[str],
) -> str:
"""Get the substate key.
Args:
token: The token of the state.
state_cls_or_name: The state class/instance or name or sequence of name parts.
Returns:
The substate key.
"""
if isinstance(state_cls_or_name, BaseState) or (
isinstance(state_cls_or_name, type) and issubclass(state_cls_or_name, BaseState)
):
state_cls_or_name = state_cls_or_name.get_full_name()
elif isinstance(state_cls_or_name, (list, tuple)):
state_cls_or_name = ".".join(state_cls_or_name)
return f"{token}_{state_cls_or_name}"


def _split_substate_key(substate_key: str) -> tuple[str, str]:
"""Split the substate key into token and state name.
Args:
substate_key: The substate key.
Returns:
Tuple of token and state name.
"""
token, _, state_name = substate_key.partition("_")
return token, state_name


def _is_testing_env() -> bool:
"""Check if the app is running in a testing environment.
Expand Down Expand Up @@ -1117,17 +1152,18 @@ async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
# Fetch all missing parent states and link them up to the common ancestor.
parent_state = parent_states_by_name[common_ancestor_name]
for parent_state_name in fetch_parent_states[1:-1]:
state_token = self.router.session.client_token + "_" + parent_state_name
parent_state = await state_manager.get_state(
state_token,
token=_substate_key(
self.router.session.client_token, parent_state_name
),
top_level=False,
get_substates=False,
parent_state=parent_state,
)

# Then get the target state and all its substates.
return await state_manager.get_state(
self.router.session.client_token + "_" + state_cls.get_full_name(),
token=_substate_key(self.router.session.client_token, state_cls),
top_level=False,
get_substates=True,
parent_state=parent_state,
Expand Down Expand Up @@ -1606,9 +1642,10 @@ async def __aenter__(self) -> StateProxy:
This StateProxy instance in mutable mode.
"""
self._self_actx = self._self_app.modify_state(
self.__wrapped__.router.session.client_token
+ "_"
+ ".".join(self._self_substate_path)
token=_substate_key(
self.__wrapped__.router.session.client_token,
self._self_substate_path,
)
)
mutable_state = await self._self_actx.__aenter__()
super().__setattr__(
Expand Down Expand Up @@ -1868,7 +1905,7 @@ async def get_state(self, token: str) -> BaseState:
The state for the token.
"""
# Memory state manager ignores the substate suffix and always returns the top-level state.
token = token.partition("_")[0]
token = _split_substate_key(token)[0]
if token not in self.states:
self.states[token] = self.state(_reflex_internal_init=True)
return self.states[token]
Expand All @@ -1893,7 +1930,7 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
The state for the token.
"""
# Memory state manager ignores the substate suffix and always returns the top-level state.
token = token.partition("_")[0]
token = _split_substate_key(token)[0]
if token not in self._states_locks:
async with self._state_manager_lock:
if token not in self._states_locks:
Expand Down Expand Up @@ -1955,7 +1992,7 @@ async def get_state(
RuntimeError: when the state_cls is not specified in the token
"""
# Split the actual token from the fully qualified substate name.
client_token, _, state_path = token.partition("_")
client_token, state_path = _split_substate_key(token)
if state_path:
# Get the State class associated with the given path.
state_cls = self.state.get_class_substate(tuple(state_path.split(".")))
Expand Down Expand Up @@ -1998,9 +2035,8 @@ async def get_state(
# Retrieve necessary substates from redis.
for substate_cls in fetch_substates:
substate_name = substate_cls.get_name()
substate_key = token + "." + substate_name
state.substates[substate_name] = await self.get_state(
substate_key,
token=_substate_key(client_token, substate_cls),
top_level=False,
get_substates=get_substates,
parent_state=state,
Expand All @@ -2018,9 +2054,10 @@ async def get_state(
parent_state_name = state_path.rpartition(".")[0]
if parent_state_name:
# Retrieve the parent state to populate event handlers onto this substate.
parent_state_key = client_token + "_" + parent_state_name
parent_state = await self.get_state(
parent_state_key, top_level=False, get_substates=False
token=_substate_key(client_token, parent_state_name),
top_level=False,
get_substates=False,
)
# Persist the new state class to redis.
await self.set_state(
Expand Down Expand Up @@ -2066,23 +2103,22 @@ async def set_state(
f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
"or use `@rx.background` decorator for long-running tasks."
)
client_token, _, substate_name = token.partition("_")
client_token, substate_name = _split_substate_key(token)
# If the substate name on the token doesn't match the instance name, it cannot have a parent.
if state.parent_state is not None and state.get_full_name() != substate_name:
raise RuntimeError(
f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}."
)
# Recursively set_state on all known substates.
for substate in state.substates.values():
substate_key = client_token + "_" + substate.get_full_name()
await self.set_state(
substate_key,
substate,
token=_substate_key(client_token, substate),
state=substate,
lock_id=lock_id,
)
# Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
await self.redis.set(
client_token + "_" + state.get_full_name(),
_substate_key(client_token, state),
cloudpickle.dumps(state),
ex=self.token_expiration,
)
Expand Down Expand Up @@ -2113,7 +2149,7 @@ def _lock_key(token: str) -> bytes:
The redis lock key for the token.
"""
# All substates share the same lock domain, so ignore any substate path suffix.
client_token = token.partition("_")[0]
client_token = _split_substate_key(token)[0]
return f"{client_token}_lock".encode()

async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
Expand Down
18 changes: 12 additions & 6 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@
from reflex.event import Event
from reflex.middleware import HydrateMiddleware
from reflex.model import Model
from reflex.state import BaseState, RouterData, State, StateManagerRedis, StateUpdate
from reflex.state import (
BaseState,
RouterData,
State,
StateManagerRedis,
StateUpdate,
_substate_key,
)
from reflex.style import Style
from reflex.utils import format
from reflex.vars import ComputedVar
Expand Down Expand Up @@ -340,7 +347,7 @@ async def test_initialize_with_state(test_state: Type[ATestState], token: str):
assert app.state == test_state

# Get a state for a given token.
state = await app.state_manager.get_state(f"{token}_{test_state.get_full_name()}")
state = await app.state_manager.get_state(_substate_key(token, test_state))
assert isinstance(state, test_state)
assert state.var == 0 # type: ignore

Expand Down Expand Up @@ -744,8 +751,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
# The App state must be the "root" of the state tree
app = App(state=State)
app.event_namespace.emit = AsyncMock() # type: ignore
substate_token = f"{token}_{state.get_full_name()}"
current_state = await app.state_manager.get_state(substate_token)
current_state = await app.state_manager.get_state(_substate_key(token, state))
data = b"This is binary data"

# Create a binary IO object and write data to it
Expand Down Expand Up @@ -774,7 +780,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
== StateUpdate(delta=delta, events=[], final=True).json() + "\n"
)

current_state = await app.state_manager.get_state(substate_token)
current_state = await app.state_manager.get_state(_substate_key(token, state))
state_dict = current_state.dict()[state.get_full_name()]
assert state_dict["img_list"] == [
"image1.jpg",
Expand Down Expand Up @@ -928,7 +934,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
}
assert constants.ROUTER in app.state()._computed_var_dependencies

substate_token = f"{token}_{DynamicState.get_full_name()}"
substate_token = _substate_key(token, DynamicState)
sid = "mock_sid"
client_ip = "127.0.0.1"
state = await app.state_manager.get_state(substate_token)
Expand Down
28 changes: 18 additions & 10 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
StateManagerRedis,
StateProxy,
StateUpdate,
_substate_key,
)
from reflex.utils import prerequisites, types
from reflex.utils.format import json_dumps
Expand Down Expand Up @@ -1484,7 +1485,7 @@ def substate_token(state_manager, token):
Returns:
Token concatenated with the state_manager's state full_name.
"""
return f"{token}_{state_manager.state.get_full_name()}"
return _substate_key(token, state_manager.state)


@pytest.mark.asyncio
Expand Down Expand Up @@ -1582,7 +1583,7 @@ def substate_token_redis(state_manager_redis, token):
Returns:
Token concatenated with the state_manager's state full_name.
"""
return f"{token}_{state_manager_redis.state.get_full_name()}"
return _substate_key(token, state_manager_redis.state)


@pytest.mark.asyncio
Expand Down Expand Up @@ -1738,8 +1739,9 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
assert sp.value2 == 42

# Get the state from the state manager directly and check that the value is updated
gc_token = f"{grandchild_state.get_token()}_{grandchild_state.get_full_name()}"
gotten_state = await mock_app.state_manager.get_state(gc_token)
gotten_state = await mock_app.state_manager.get_state(
_substate_key(grandchild_state.router.session.client_token, grandchild_state)
)
if isinstance(mock_app.state_manager, StateManagerMemory):
# For in-process store, only one instance of the state exists
assert gotten_state is parent_state
Expand Down Expand Up @@ -1935,8 +1937,11 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
"private",
]

substate_token = f"{token}_{BackgroundTaskState.get_name()}"
assert (await mock_app.state_manager.get_state(substate_token)).order == exp_order
assert (
await mock_app.state_manager.get_state(
_substate_key(token, BackgroundTaskState)
)
).order == exp_order

assert mock_app.event_namespace is not None
emit_mock = mock_app.event_namespace.emit
Expand Down Expand Up @@ -2013,8 +2018,11 @@ async def test_background_task_reset(mock_app: rx.App, token: str):
await task
assert not mock_app.background_tasks

substate_token = f"{token}_{BackgroundTaskState.get_name()}"
assert (await mock_app.state_manager.get_state(substate_token)).order == [
assert (
await mock_app.state_manager.get_state(
_substate_key(token, BackgroundTaskState)
)
).order == [
"reset",
]

Expand Down Expand Up @@ -2580,7 +2588,7 @@ async def test_get_state(mock_app: rx.App, token: str):

# Get instance of ChildState2.
test_state = await mock_app.state_manager.get_state(
f"{token}_{ChildState2.get_full_name()}"
_substate_key(token, ChildState2)
)
assert isinstance(test_state, TestState)
if isinstance(mock_app.state_manager, StateManagerMemory):
Expand Down Expand Up @@ -2642,7 +2650,7 @@ async def test_get_state(mock_app: rx.App, token: str):

# Get a fresh instance
new_test_state = await mock_app.state_manager.get_state(
f"{token}_{ChildState2.get_full_name()}"
_substate_key(token, ChildState2)
)
assert isinstance(new_test_state, TestState)
if isinstance(mock_app.state_manager, StateManagerMemory):
Expand Down
6 changes: 2 additions & 4 deletions tests/test_state_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

import reflex as rx
from reflex.state import BaseState, StateManager, StateManagerRedis
from reflex.state import BaseState, StateManager, StateManagerRedis, _substate_key


class Root(BaseState):
Expand Down Expand Up @@ -320,9 +320,7 @@ async def test_get_state_tree(
exp_root_substates: The expected substates of the root state.
exp_root_dict_keys: The expected keys of the root state dict.
"""
state = await state_manager_redis.get_state(
f"{token}_{substate_cls.get_full_name()}"
)
state = await state_manager_redis.get_state(_substate_key(token, substate_cls))
assert isinstance(state, Root)
assert sorted(state.substates) == sorted(exp_root_substates)

Expand Down

0 comments on commit 48bd579

Please sign in to comment.