Skip to content

Commit

Permalink
Creating threads to update visualization asynchronously (#2656)
Browse files Browse the repository at this point in the history
This PR introduces a thread-based mechanism to handle model updates and visualization rendering concurrently. It ensures smooth execution of simulations by separating model step execution and visualization updates into independent threads, improving performance and responsiveness during simulations.
Fixes #2604 

### Motive
Previously, the visualization process could become a bottleneck during rapid simulations, as rendering was tightly coupled with model updates. This caused delays and UI responsiveness issues. By separating these operations into threads, the model execution is no longer hindered by the visualization process.

### Implementation

1. Introduced separate thread for visualisation:
Handles visualization updates triggered by threading.Event (visualization_pause_event) to synchronize rendering with model steps.
2. Thread Synchronization:
Implemented visualization_pause_event to signal the visualization thread after each model step is completed, ensuring rendering happens efficiently without blocking the simulation. Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
HMNS19 authored Feb 27, 2025
1 parent 1575c88 commit 41be443
Showing 1 changed file with 114 additions and 19 deletions.
133 changes: 114 additions & 19 deletions mesa/visualization/solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

import asyncio
import inspect
import threading
import time
from collections.abc import Callable
from typing import TYPE_CHECKING, Literal

Expand Down Expand Up @@ -57,6 +59,7 @@ def SolaraViz(
simulator: Simulator | None = None,
model_params=None,
name: str | None = None,
use_threads: bool = False,
):
"""Solara visualization component.
Expand All @@ -76,6 +79,8 @@ def SolaraViz(
This controls the speed of the model's automatic stepping. Defaults to 100 ms.
render_interval (int, optional): Controls how often plots are updated during a simulation,
allowing users to skip intermediate steps and update graphs less frequently.
use_threads: Flag for indicating whether to utilize multi-threading for model execution.
When checked, the model will utilize multiple threads,adjust based on system capabilities.
simulator: A simulator that controls the model (optional)
model_params (dict, optional): Parameters for (re-)instantiating a model.
Can include user-adjustable parameters and fixed parameters. Defaults to None.
Expand Down Expand Up @@ -114,6 +119,7 @@ def SolaraViz(
reactive_model_parameters = solara.use_reactive({})
reactive_play_interval = solara.use_reactive(play_interval)
reactive_render_interval = solara.use_reactive(render_interval)
reactive_use_threads = solara.use_reactive(use_threads)
with solara.AppBar():
solara.AppBarTitle(name if name else model.value.__class__.__name__)
solara.lab.ThemeToggle()
Expand All @@ -136,12 +142,25 @@ def SolaraViz(
max=100,
step=2,
)
if reactive_use_threads.value:
solara.Text("Increase play interval to avoid skipping plots")

def set_reactive_use_threads(value):
reactive_use_threads.set(value)

solara.Checkbox(
label="Use Threads",
value=reactive_use_threads,
on_value=set_reactive_use_threads,
)

if not isinstance(simulator, Simulator):
ModelController(
model,
model_parameters=reactive_model_parameters,
play_interval=reactive_play_interval,
render_interval=reactive_render_interval,
use_threads=reactive_use_threads,
)
else:
SimulatorController(
Expand All @@ -150,6 +169,7 @@ def SolaraViz(
model_parameters=reactive_model_parameters,
play_interval=reactive_play_interval,
render_interval=reactive_render_interval,
use_threads=reactive_use_threads,
)
with solara.Card("Model Parameters"):
ModelCreator(
Expand Down Expand Up @@ -211,6 +231,7 @@ def ModelController(
model_parameters: dict | solara.Reactive[dict] = None,
play_interval: int | solara.Reactive[int] = 100,
render_interval: int | solara.Reactive[int] = 1,
use_threads: bool | solara.Reactive[bool] = False,
):
"""Create controls for model execution (step, play, pause, reset).
Expand All @@ -219,37 +240,70 @@ def ModelController(
model_parameters: Reactive parameters for (re-)instantiating a model.
play_interval: Interval for playing the model steps in milliseconds.
render_interval: Controls how often the plots are updated during simulation steps.Higher value reduce update frequency.
use_threads: Flag for indicating whether to utilize multi-threading for model execution.
"""
playing = solara.use_reactive(False)
running = solara.use_reactive(True)

if model_parameters is None:
model_parameters = {}
model_parameters = solara.use_reactive(model_parameters)

async def step():
while playing.value and running.value:
await asyncio.sleep(play_interval.value / 1000)
do_step()
visualization_pause_event = solara.use_memo(lambda: threading.Event(), [])

def step():
try:
while running.value and playing.value:
time.sleep(play_interval.value / 1000)
do_step()
if use_threads.value:
visualization_pause_event.set()
except Exception as e:
print(f"Error in step: {e}")
return

def visualization_task():
if use_threads.value:
try:
while playing.value and running.value:
visualization_pause_event.wait()
visualization_pause_event.clear()
force_update()
except Exception as e:
print(f"Error in visualization_task: {e}")

solara.lab.use_task(
step, dependencies=[playing.value, running.value], prefer_threaded=False
step, dependencies=[playing.value, running.value], prefer_threaded=True
)

solara.use_thread(
visualization_task,
dependencies=[playing.value, running.value],
)

@function_logger(__name__)
def do_step():
"""Advance the model by the number of steps specified by the render_interval slider."""
for _ in range(render_interval.value):
model.value.step()
if playing.value:
for _ in range(render_interval.value):
model.value.step()
running.value = model.value.running
if not playing.value:
break
if not use_threads.value:
force_update()

running.value = model.value.running

force_update()
else:
for _ in range(render_interval.value):
model.value.step()
running.value = model.value.running
force_update()

@function_logger(__name__)
def do_reset():
"""Reset the model to its initial state."""
playing.value = False
running.value = True
visualization_pause_event.clear()
_mesa_logger.log(
10,
f"creating new {model.value.__class__} instance with {model_parameters.value}",
Expand Down Expand Up @@ -285,6 +339,7 @@ def SimulatorController(
model_parameters: dict | solara.Reactive[dict] = None,
play_interval: int | solara.Reactive[int] = 100,
render_interval: int | solara.Reactive[int] = 1,
use_threads: bool | solara.Reactive[bool] = False,
):
"""Create controls for model execution (step, play, pause, reset).
Expand All @@ -294,6 +349,7 @@ def SimulatorController(
model_parameters: Reactive parameters for (re-)instantiating a model.
play_interval: Interval for playing the model steps in milliseconds.
render_interval: Controls how often the plots are updated during simulation steps.Higher values reduce update frequency.
use_threads: Flag for indicating whether to utilize multi-threading for model execution.
Notes:
The `step button` increments the step by the value specified in the `render_interval` slider.
Expand All @@ -304,27 +360,66 @@ def SimulatorController(
if model_parameters is None:
model_parameters = {}
model_parameters = solara.use_reactive(model_parameters)

async def step():
while playing.value and running.value:
await asyncio.sleep(play_interval.value / 1000)
do_step()
visualization_pause_event = solara.use_memo(lambda: threading.Event(), [])
pause_step_event = solara.use_memo(lambda: threading.Event(), [])

def step():
try:
while running.value and playing.value:
time.sleep(play_interval.value / 1000)
if use_threads.value:
pause_step_event.wait()
pause_step_event.clear()
do_step()
if use_threads.value:
visualization_pause_event.set()
except Exception as e:
print(f"Error in step: {e}")

def visualization_task():
if use_threads.value:
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
pause_step_event.set()
while playing.value and running.value:
visualization_pause_event.wait()
visualization_pause_event.clear()
force_update()
pause_step_event.set()
except Exception as e:
print(f"Error in visualization_task: {e}")
return

solara.lab.use_task(
step, dependencies=[playing.value, running.value], prefer_threaded=False
)
solara.lab.use_task(visualization_task, dependencies=[playing.value])

def do_step():
"""Advance the model by the number of steps specified by the render_interval slider."""
simulator.run_for(render_interval.value)
running.value = model.value.running
force_update()
if playing.value:
for _ in range(render_interval.value):
simulator.run_for(1)
running.value = model.value.running
if not playing.value:
break
if not use_threads.value:
force_update()

else:
for _ in range(render_interval.value):
simulator.run_for(1)
running.value = model.value.running
force_update()

def do_reset():
"""Reset the model to its initial state."""
playing.value = False
running.value = True
simulator.reset()
visualization_pause_event.clear()
pause_step_event.clear()
model.value = model.value = model.value.__class__(
simulator=simulator, **model_parameters.value
)
Expand Down

0 comments on commit 41be443

Please sign in to comment.