From 47637a7edd391678a4102aba573bc2160c8b7979 Mon Sep 17 00:00:00 2001 From: Rebecca Sutton Koeser Date: Fri, 1 Sep 2023 11:16:15 -0400 Subject: [PATCH] Add docstring for jupyterviz make_user_input that documents supported inputs (#1784) * Add docstring for make_user_input that documents supported inputs * Use field name as user input fallback label; error on unsupported type * Preliminary unit tests for jupyter viz make_user_input method * Remove unused variable flagged by ruff lint --- mesa/experimental/jupyter_viz.py | 53 ++++++++++++++++++++------------ tests/test_jupyter_viz.py | 48 +++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 19 deletions(-) create mode 100644 tests/test_jupyter_viz.py diff --git a/mesa/experimental/jupyter_viz.py b/mesa/experimental/jupyter_viz.py index c048be35d14..6007b9a8c8f 100644 --- a/mesa/experimental/jupyter_viz.py +++ b/mesa/experimental/jupyter_viz.py @@ -32,10 +32,10 @@ def JupyterViz( # 1. User inputs user_inputs = {} - for k, v in model_params_input.items(): - user_input = solara.use_reactive(v["value"]) - user_inputs[k] = user_input.value - make_user_input(user_input, k, v) + for name, options in model_params_input.items(): + user_input = solara.use_reactive(options["value"]) + user_inputs[name] = user_input.value + make_user_input(user_input, name, options) # 2. Model def make_model(): @@ -142,29 +142,44 @@ def check_param_is_fixed(param): return True -def make_user_input(user_input, k, v): - if v["type"] == "SliderInt": +def make_user_input(user_input, name, options): + """Initialize a user input for configurable model parameters. + Currently supports :class:`solara.SliderInt`, :class:`solara.SliderFloat`, + and :class:`solara.Select`. + + Args: + user_input: :class:`solara.reactive` object with initial value + name: field name; used as fallback for label if 'label' is not in options + options: dictionary with options for the input, including label, + min and max values, and other fields specific to the input type. + """ + # label for the input is "label" from options or name + label = options.get("label", name) + input_type = options.get("type") + if input_type == "SliderInt": solara.SliderInt( - v.get("label", "label"), + label, value=user_input, - min=v.get("min"), - max=v.get("max"), - step=v.get("step"), + min=options.get("min"), + max=options.get("max"), + step=options.get("step"), ) - elif v["type"] == "SliderFloat": + elif input_type == "SliderFloat": solara.SliderFloat( - v.get("label", "label"), + label, value=user_input, - min=v.get("min"), - max=v.get("max"), - step=v.get("step"), + min=options.get("min"), + max=options.get("max"), + step=options.get("step"), ) - elif v["type"] == "Select": + elif input_type == "Select": solara.Select( - v.get("label", "label"), - value=v.get("value"), - values=v.get("values"), + label, + value=options.get("value"), + values=options.get("values"), ) + else: + raise ValueError(f"{input_type} is not a supported input type") def make_space(model, agent_portrayal): diff --git a/tests/test_jupyter_viz.py b/tests/test_jupyter_viz.py new file mode 100644 index 00000000000..b2e701df7fd --- /dev/null +++ b/tests/test_jupyter_viz.py @@ -0,0 +1,48 @@ +import unittest +from unittest.mock import patch + +from mesa.experimental.jupyter_viz import make_user_input + + +class TestMakeUserInput(unittest.TestCase): + def test_unsupported_type(self): + """unsupported input type should raise ValueError""" + # bogus type + with self.assertRaisesRegex(ValueError, "not a supported input type"): + make_user_input(10, "input", {"type": "bogus"}) + # no type is specified + with self.assertRaisesRegex(ValueError, "not a supported input type"): + make_user_input(10, "input", {}) + + @patch("mesa.experimental.jupyter_viz.solara") + def test_slider_int(self, mock_solara): + value = 10 + name = "num_agents" + options = { + "type": "SliderInt", + "label": "number of agents", + "min": 10, + "max": 20, + "step": 1, + } + make_user_input(value, name, options) + mock_solara.SliderInt.assert_called_with( + options["label"], + value=value, + min=options["min"], + max=options["max"], + step=options["step"], + ) + + @patch("mesa.experimental.jupyter_viz.solara") + def test_label_fallback(self, mock_solara): + """name should be used as fallback label""" + value = 10 + name = "num_agents" + options = { + "type": "SliderInt", + } + make_user_input(value, name, options) + mock_solara.SliderInt.assert_called_with( + name, value=value, min=None, max=None, step=None + )