Skip to content

Commit

Permalink
RuntimeError unless State is instantiated with _reflex_internal_init=…
Browse files Browse the repository at this point in the history
…True

Avoid user errors trying to directly instantiate State classes
  • Loading branch information
masenf committed Feb 21, 2024
1 parent 5cb6b0f commit 7ecda2d
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 8 deletions.
4 changes: 2 additions & 2 deletions reflex/compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,12 @@ def compile_state(state: Type[BaseState]) -> dict:
A dictionary of the compiled state.
"""
try:
initial_state = state().dict()
initial_state = state(_reflex_internal_init=True).dict()
except Exception as e:
console.warn(
f"Failed to compile initial state with computed vars, excluding them: {e}"
)
initial_state = state().dict(include_computed=False)
initial_state = state(_reflex_internal_init=True).dict(include_computed=False)
return format.format_state(initial_state)


Expand Down
32 changes: 28 additions & 4 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,15 @@ def __init__(self, router_data: Optional[dict] = None):
}


def _is_testing_env() -> bool:
"""Check if the app is running in a testing environment.
Returns:
True if the app is running in a testing environment, False otherwise.
"""
return constants.PYTEST_CURRENT_TEST in os.environ


class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"""The state of the app."""

Expand Down Expand Up @@ -218,24 +227,39 @@ def __init__(
*args,
parent_state: BaseState | None = None,
init_substates: bool = True,
_reflex_internal_init: bool = False,
**kwargs,
):
"""Initialize the state.
DO NOT INSTANTIATE STATE CLASSES DIRECTLY! Use StateManager.get_state() instead.
Args:
*args: The args to pass to the Pydantic init method.
parent_state: The parent state.
init_substates: Whether to initialize the substates in this instance.
_reflex_internal_init: A flag to indicate that the state is being initialized by the framework.
**kwargs: The kwargs to pass to the Pydantic init method.
Raises:
RuntimeError: If the state is instantiated directly by end user.
"""
if not _reflex_internal_init and not _is_testing_env():
raise RuntimeError(
"State classes should not be instantiated directly. The rx.App and its StateManager "
"are responsible for creating and tracking instances of State. "
"See https://reflex.dev/docs/state for further information."
)
kwargs["parent_state"] = parent_state
super().__init__(*args, **kwargs)

# Setup the substates (for memory state manager only).
if init_substates:
for substate in self.get_substates():
self.substates[substate.get_name()] = substate(parent_state=self)
self.substates[substate.get_name()] = substate(
parent_state=self,
_reflex_internal_init=True,
)
# Convert the event handlers to functions.
self._init_event_handlers()

Expand Down Expand Up @@ -286,7 +310,6 @@ def __init_subclass__(cls, **kwargs):
Raises:
ValueError: If a substate class shadows another.
"""
is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
super().__init_subclass__(**kwargs)
# Event handlers should not shadow builtin state methods.
cls._check_overridden_methods()
Expand All @@ -305,7 +328,7 @@ def __init_subclass__(cls, **kwargs):

# Check if another substate class with the same name has already been defined.
if cls.__name__ in set(c.__name__ for c in parent_state.class_subclasses):
if is_testing_env:
if _is_testing_env():
# Clear existing subclass with same name when app is reloaded via
# utils.prerequisites.get_app(reload=True)
parent_state.class_subclasses = set(
Expand Down Expand Up @@ -1847,7 +1870,7 @@ async def get_state(self, token: str) -> BaseState:
# Memory state manager ignores the substate suffix and always returns the top-level state.
token = token.partition("_")[0]
if token not in self.states:
self.states[token] = self.state()
self.states[token] = self.state(_reflex_internal_init=True)
return self.states[token]

async def set_state(self, token: str, state: BaseState):
Expand Down Expand Up @@ -2005,6 +2028,7 @@ async def get_state(
state_cls(
parent_state=parent_state,
init_substates=False,
_reflex_internal_init=True,
),
)
# After creating the state key, recursively call `get_state` to populate substates.
Expand Down
4 changes: 2 additions & 2 deletions tests/utils/test_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ def test_format_query_params(input, output):
"input, output",
[
(
TestState().dict(), # type: ignore
TestState(_reflex_internal_init=True).dict(), # type: ignore
{
TestState.get_full_name(): {
"array": [1, 2, 3.14],
Expand Down Expand Up @@ -684,7 +684,7 @@ def test_format_query_params(input, output):
},
),
(
DateTimeState().dict(),
DateTimeState(_reflex_internal_init=True).dict(), # type: ignore
{
DateTimeState.get_full_name(): {
"d": "1989-11-09",
Expand Down

0 comments on commit 7ecda2d

Please sign in to comment.