diff --git a/mesa/visualization/__init__.py b/mesa/visualization/__init__.py index 3208642d562..d6e50c37e36 100644 --- a/mesa/visualization/__init__.py +++ b/mesa/visualization/__init__.py @@ -2,13 +2,12 @@ from .components.altair import make_space_altair from .components.matplotlib import make_plot_measure, make_space_matplotlib -from .solara_viz import JupyterViz, SolaraViz, make_text +from .solara_viz import JupyterViz, SolaraViz from .UserParam import Slider __all__ = [ "JupyterViz", "SolaraViz", - "make_text", "Slider", "make_space_altair", "make_space_matplotlib", diff --git a/mesa/visualization/solara_viz.py b/mesa/visualization/solara_viz.py index 7b1a0464f65..129a94863d3 100644 --- a/mesa/visualization/solara_viz.py +++ b/mesa/visualization/solara_viz.py @@ -7,7 +7,6 @@ - SolaraViz: Main component for creating visualizations, supporting grid displays and plots - ModelController: Handles model execution controls (step, play, pause, reset) - UserInputs: Generates UI elements for adjusting model parameters - - Card: Renders individual visualization elements (space, measures) The module uses Solara for rendering in Jupyter notebooks or as standalone web applications. It supports various types of visualizations including matplotlib plots, agent grids, and @@ -22,10 +21,14 @@ See the Visualization Tutorial and example models for more details. """ +from __future__ import annotations + import copy import time +from collections.abc import Callable from typing import TYPE_CHECKING, Literal +import reacton.core import solara from solara.alias import rv @@ -89,31 +92,57 @@ def Card( @solara.component def SolaraViz( - model: "Model" | solara.Reactive["Model"], - components: list[solara.component] | Literal["default"] = "default", - play_interval=100, + model: Model | solara.Reactive[Model], + components: list[reacton.core.Component] + | list[Callable[[Model], reacton.core.Component]] + | Literal["default"] = "default", + play_interval: int = 100, model_params=None, - seed=0, + seed: float = 0, name: str | None = None, ): """Solara visualization component. + This component provides a visualization interface for a given model using Solara. + It supports various visualization components and allows for interactive model + stepping and parameter adjustments. + Args: - model: a Model instance - components: list of solara components - play_interval: int - model_params: parameters for instantiating a model - seed: the seed for the rng - name: str + model (Model | solara.Reactive[Model]): A Model instance or a reactive Model. + This is the main model to be visualized. If a non-reactive model is provided, + it will be converted to a reactive model. + components (list[solara.component] | Literal["default"], optional): List of solara + components or functions that return a solara component. + These components are used to render different parts of the model visualization. + Defaults to "default", which uses the default Altair space visualization. + play_interval (int, optional): Interval for playing the model steps in milliseconds. + This controls the speed of the model's automatic stepping. Defaults to 100 ms. + model_params (dict, optional): Parameters for (re-)instantiating a model. + Can include user-adjustable parameters and fixed parameters. Defaults to None. + seed (int, optional): Seed for the random number generator. This ensures reproducibility + of the model's behavior. Defaults to 0. + name (str | None, optional): Name of the visualization. Defaults to the models class name. + Returns: + solara.component: A Solara component that renders the visualization interface for the model. + + Example: + >>> model = MyModel() + >>> page = SolaraViz(model) + >>> page + + Notes: + - The `model` argument can be either a direct model instance or a reactive model. If a direct + model instance is provided, it will be converted to a reactive model using `solara.use_reactive`. + - The `play_interval` argument controls the speed of the model's automatic stepping. A lower + value results in faster stepping, while a higher value results in slower stepping. """ - update_counter.get() if components == "default": components = [components_altair.make_space_altair()] # Convert model to reactive if not isinstance(model, solara.Reactive): - model = solara.use_reactive(model) + model = solara.use_reactive(model) # noqa: SH102, RUF100 def connect_to_model(): # Patch the step function to force updates @@ -133,39 +162,68 @@ def step(): with solara.AppBar(): solara.AppBarTitle(name if name else model.value.__class__.__name__) - with solara.Sidebar(): - with solara.Card("Controls", margin=1, elevation=2): - if model_params is not None: + with solara.Sidebar(), solara.Column(): + with solara.Card("Controls"): + ModelController(model, play_interval) + + if model_params is not None: + with solara.Card("Model Parameters"): ModelCreator( model, model_params, seed=seed, ) - ModelController(model, play_interval) - with solara.Card("Information", margin=1, elevation=2): + with solara.Card("Information"): ShowSteps(model.value) - solara.Column( - [ - *(component(model.value) for component in components), - ] - ) + ComponentsView(components, model.value) + + +def _wrap_component( + component: reacton.core.Component | Callable[[Model], reacton.core.Component], +) -> reacton.core.Component: + """Wrap a component in an auto-updated Solara component if needed.""" + if isinstance(component, reacton.core.Component): + return component + + @solara.component + def WrappedComponent(model): + update_counter.get() + return component(model) + + return WrappedComponent + + +@solara.component +def ComponentsView( + components: list[reacton.core.Component] + | list[Callable[[Model], reacton.core.Component]], + model: Model, +): + """Display a list of components. + + Args: + components: List of components to display + model: Model instance to pass to each component + """ + wrapped_components = [_wrap_component(component) for component in components] + + with solara.Column(): + for component in wrapped_components: + component(model) JupyterViz = SolaraViz @solara.component -def ModelController(model: solara.Reactive["Model"], play_interval=100): +def ModelController(model: solara.Reactive[Model], play_interval=100): """Create controls for model execution (step, play, pause, reset). Args: - model: The reactive model being visualized - play_interval: Interval between steps during play + model (solara.Reactive[Model]): Reactive model instance + play_interval (int, optional): Interval for playing the model steps in milliseconds. """ - if not isinstance(model, solara.Reactive): - model = solara.use_reactive(model) - playing = solara.use_reactive(False) original_model = solara.use_reactive(None) @@ -188,24 +246,25 @@ def do_step(): """Advance the model by one step.""" model.value.step() - def do_play(): - """Run the model continuously.""" - playing.value = True - - def do_pause(): - """Pause the model execution.""" - playing.value = False - def do_reset(): """Reset the model to its initial state.""" playing.value = False model.value = copy.deepcopy(original_model.value) + def do_play_pause(): + """Toggle play/pause.""" + playing.value = not playing.value + with solara.Row(justify="space-between"): solara.Button(label="Reset", color="primary", on_click=do_reset) - solara.Button(label="Step", color="primary", on_click=do_step) - solara.Button(label="▶", color="primary", on_click=do_play) - solara.Button(label="⏸︎", color="primary", on_click=do_pause) + solara.Button( + label="▶" if not playing.value else "❚❚", + color="primary", + on_click=do_play_pause, + ) + solara.Button( + label="Step", color="primary", on_click=do_step, disabled=playing.value + ) def split_model_params(model_params): @@ -246,13 +305,34 @@ def check_param_is_fixed(param): @solara.component def ModelCreator(model, model_params, seed=1): - """Helper class to create a new Model instance. + """Solara component for creating and managing a model instance with user-defined parameters. + + This component allows users to create a model instance with specified parameters and seed. + It provides an interface for adjusting model parameters and reseeding the model's random + number generator. Args: - model: model instance - model_params: model parameters - seed: the seed to use for the random number generator + model (solara.Reactive[Model]): A reactive model instance. This is the main model to be created and managed. + model_params (dict): Dictionary of model parameters. This includes both user-adjustable parameters and fixed parameters. + seed (int, optional): Initial seed for the random number generator. Defaults to 1. + Returns: + solara.component: A Solara component that renders the model creation and management interface. + + Example: + >>> model = solara.reactive(MyModel()) + >>> model_params = { + >>> "param1": {"type": "slider", "value": 10, "min": 0, "max": 100}, + >>> "param2": {"type": "slider", "value": 5, "min": 1, "max": 10}, + >>> } + >>> creator = ModelCreator(model, model_params) + >>> creator + + Notes: + - The `model_params` argument should be a dictionary where keys are parameter names and values either fixed values + or are dictionaries containing parameter details such as type, value, min, and max. + - The `seed` argument ensures reproducibility by setting the initial seed for the model's random number generator. + - The component provides an interface for adjusting user-defined parameters and reseeding the model. """ user_params, fixed_params = split_model_params(model_params) @@ -279,13 +359,14 @@ def create_model(): solara.use_effect(create_model, [model_parameters, reactive_seed.value]) - solara.InputText( - label="Seed", - value=reactive_seed, - continuous_update=True, - ) + with solara.Row(justify="space-between"): + solara.InputText( + label="Seed", + value=reactive_seed, + continuous_update=True, + ) - solara.Button(label="Reseed", color="primary", on_click=do_reseed) + solara.Button(label="Reseed", color="primary", on_click=do_reseed) UserInputs(user_params, on_change=on_change) @@ -358,22 +439,6 @@ def change_handler(value, name=name): raise ValueError(f"{input_type} is not a supported input type") -def make_text(renderer): - """Create a function that renders text using Markdown. - - Args: - renderer: Function that takes a model and returns a string - - Returns: - function: A function that renders the text as Markdown - """ - - def function(model): - solara.Markdown(renderer(model)) - - return function - - def make_initial_grid_layout(layout_types): """Create an initial grid layout for visualization components. @@ -397,6 +462,7 @@ def make_initial_grid_layout(layout_types): @solara.component -def ShowSteps(model): # noqa: D103 +def ShowSteps(model): + """Display the current step of the model.""" update_counter.get() return solara.Text(f"Step: {model.steps}") diff --git a/tests/test_solara_viz.py b/tests/test_solara_viz.py index 798777bb5eb..301294f25ba 100644 --- a/tests/test_solara_viz.py +++ b/tests/test_solara_viz.py @@ -1,12 +1,13 @@ """Test Solara visualizations.""" import unittest -from unittest.mock import Mock import ipyvuetify as vw import solara import mesa +import mesa.visualization.components.altair +import mesa.visualization.components.matplotlib from mesa.visualization.components.matplotlib import make_space_matplotlib from mesa.visualization.solara_viz import Slider, SolaraViz, UserInputs @@ -86,10 +87,12 @@ def Test(user_params): def test_call_space_drawer(mocker): # noqa: D103 - mock_space_matplotlib = mocker.patch( - "mesa.visualization.components.matplotlib.SpaceMatplotlib" + mock_space_matplotlib = mocker.spy( + mesa.visualization.components.matplotlib, "SpaceMatplotlib" ) + mock_space_altair = mocker.spy(mesa.visualization.components.altair, "SpaceAltair") + model = mesa.Model() mocker.patch.object(mesa.Model, "__init__", return_value=None) @@ -105,13 +108,19 @@ def test_call_space_drawer(mocker): # noqa: D103 # specify no space should be drawn mock_space_matplotlib.reset_mock() - solara.render(SolaraViz(model, components=[])) + solara.render(SolaraViz(model)) # should call default method with class instance and agent portrayal assert mock_space_matplotlib.call_count == 0 + assert mock_space_altair.call_count > 0 # specify a custom space method - altspace_drawer = Mock() - solara.render(SolaraViz(model, components=[altspace_drawer])) + class AltSpace: + @staticmethod + def drawer(model): + return + + altspace_drawer = mocker.spy(AltSpace, "drawer") + solara.render(SolaraViz(model, components=[AltSpace.drawer])) altspace_drawer.assert_called_with(model) # check voronoi space drawer