From 88912c68249b110c177a756df946b1c881c5e442 Mon Sep 17 00:00:00 2001 From: Akshay Agrawal Date: Wed, 24 Apr 2024 11:31:04 -0700 Subject: [PATCH 1/2] fix: close matplotlib figures in app.run() Each cell gets a fresh figure, just like in edit mode. Without this using the imperative api makes all plots share the same object. --- marimo/_ast/app.py | 25 +++++++++++++++++++++---- tests/_ast/test_app.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/marimo/_ast/app.py b/marimo/_ast/app.py index e7437ab4d30..bb873b0e80c 100644 --- a/marimo/_ast/app.py +++ b/marimo/_ast/app.py @@ -31,6 +31,7 @@ MultipleDefinitionError, UnparsableError, ) +from marimo._dependencies.dependencies import DependencyManager from marimo._messaging.mimetypes import KnownMimeType from marimo._messaging.types import NoopStream from marimo._output.rich_help import mddoc @@ -249,7 +250,9 @@ def _outputs_and_defs( {name: glbls[name] for name in self._defs if name in glbls}, ) - def _run_sync(self) -> tuple[Sequence[Any], dict[str, Any]]: + def _run_sync( + self, post_execute_hooks: list[Callable[..., Any]] + ) -> tuple[Sequence[Any], dict[str, Any]]: from marimo._runtime.context.types import ExecutionContext # No need to provide `file`, `input_override` here, since this @@ -265,9 +268,13 @@ def _run_sync(self) -> tuple[Sequence[Any], dict[str, Any]]: cell_id=cid, setting_element_value=False ) outputs[cid] = execute_cell(cell._cell, glbls) + for hook in post_execute_hooks: + hook() return self._outputs_and_defs(outputs, glbls) - async def _run_async(self) -> tuple[Sequence[Any], dict[str, Any]]: + async def _run_async( + self, post_execute_hooks: list[Callable[..., Any]] + ) -> tuple[Sequence[Any], dict[str, Any]]: from marimo._runtime.context.types import ExecutionContext # No need to provide `file`, `input_override` here, since this @@ -283,6 +290,8 @@ async def _run_async(self) -> tuple[Sequence[Any], dict[str, Any]]: cell_id=cid, setting_element_value=False ) outputs[cid] = await execute_cell_async(cell._cell, glbls) + for hook in post_execute_hooks: + hook() self._execution_context = None return self._outputs_and_defs(outputs, glbls) @@ -329,10 +338,18 @@ def run(self) -> tuple[Sequence[Any], dict[str, Any]]: if not FORMATTERS: register_formatters() + post_execute_hooks = [] + if DependencyManager.has_matplotlib(): + from marimo._output.mpl import close_figures + + post_execute_hooks.append(close_figures) + if is_async: - return asyncio.run(self._run_async()) + return asyncio.run( + self._run_async(post_execute_hooks=post_execute_hooks) + ) else: - return self._run_sync() + return self._run_sync(post_execute_hooks=post_execute_hooks) finally: if installed_script_context: teardown_context() diff --git a/tests/_ast/test_app.py b/tests/_ast/test_app.py index 1c865b4bfc4..a8a8500b0b4 100644 --- a/tests/_ast/test_app.py +++ b/tests/_ast/test_app.py @@ -14,6 +14,7 @@ MultipleDefinitionError, UnparsableError, ) +from marimo._dependencies.dependencies import DependencyManager if TYPE_CHECKING: import pathlib @@ -385,6 +386,33 @@ def __(x: int) -> tuple[int]: assert defs["x"] == 0 assert defs["y"] == 1 + @pytest.mark.skipif( + condition=not DependencyManager.has_matplotlib(), + reason="requires matplotlib", + ) + @staticmethod + def test_app_run_matplotlib_figures_closed() -> None: + from matplotlib.axes import Axes + + app = App() + + @app.cell + def __() -> None: + import matplotlib.pyplot as plt + + plt.plot([1, 2]) + plt.gca() + + @app.cell + def __(plt: Any) -> None: + plt.plot([1, 1]) + plt.gca() + + outputs, _ = app.run() + assert isinstance(outputs[0], Axes) + assert isinstance(outputs[1], Axes) + assert outputs[0] != outputs[1] + def test_app_config() -> None: config = _AppConfig.from_untrusted_dict({"width": "full"}) From a3a366c5aaa760840c34673144ee866cbb0b6452 Mon Sep 17 00:00:00 2001 From: Akshay Agrawal Date: Wed, 24 Apr 2024 11:48:12 -0700 Subject: [PATCH 2/2] fix test --- tests/_ast/test_app.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/_ast/test_app.py b/tests/_ast/test_app.py index a8a8500b0b4..6e541ba670a 100644 --- a/tests/_ast/test_app.py +++ b/tests/_ast/test_app.py @@ -390,8 +390,7 @@ def __(x: int) -> tuple[int]: condition=not DependencyManager.has_matplotlib(), reason="requires matplotlib", ) - @staticmethod - def test_app_run_matplotlib_figures_closed() -> None: + def test_app_run_matplotlib_figures_closed(self) -> None: from matplotlib.axes import Axes app = App()