diff --git a/mesa/experimental/jupyter_viz.py b/mesa/experimental/jupyter_viz.py index 63165eb6ca4..f4ac1271054 100644 --- a/mesa/experimental/jupyter_viz.py +++ b/mesa/experimental/jupyter_viz.py @@ -406,8 +406,17 @@ def make_plot(model, measure): fig = Figure() ax = fig.subplots() df = model.datacollector.get_model_vars_dataframe() - ax.plot(df.loc[:, measure]) - ax.set_ylabel(measure) + if isinstance(measure, str): + ax.plot(df.loc[:, measure]) + ax.set_ylabel(measure) + elif isinstance(measure, dict): + for m, color in measure.items(): + ax.plot(df.loc[:, m], label=m, color=color) + fig.legend() + elif isinstance(measure, (list, tuple)): + for m in measure: + ax.plot(df.loc[:, m], label=m) + fig.legend() # Set integer x axis ax.xaxis.set_major_locator(MaxNLocator(integer=True)) solara.FigureMatplotlib(fig)