Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding setting title functionality to AnimationPlotter #919

Merged
merged 6 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions stonesoup/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1676,7 +1676,7 @@ class _AnimationPlotterDataClass(Base):
class AnimationPlotter(_Plotter):

def __init__(self, dimension=Dimension.TWO, x_label: str = "$x$", y_label: str = "$y$",
legend_kwargs: dict = {}, **kwargs):
title: str = None, legend_kwargs: dict = {}, **kwargs):

self.figure_kwargs = {"figsize": (10, 6)}
self.figure_kwargs.update(kwargs)
Expand All @@ -1689,6 +1689,10 @@ def __init__(self, dimension=Dimension.TWO, x_label: str = "$x$", y_label: str =
self.x_label: str = x_label
self.y_label: str = y_label

if title:
title += "\n"
self.title: str = title

self.plotting_data: List[_AnimationPlotterDataClass] = []

self.animation_output: animation.FuncAnimation = None
Expand All @@ -1710,7 +1714,6 @@ def run(self,
\\*\\*kwargs: dict
Additional arguments to be passed to the animation.FuncAnimation function
"""

if times_to_plot is None:
times_to_plot = sorted({
state.timestamp
Expand All @@ -1725,7 +1728,8 @@ def run(self,
y_label=self.y_label,
figure_kwargs=self.figure_kwargs,
legend_kwargs=self.legend_kwargs,
animation_input_kwargs=kwargs
animation_input_kwargs=kwargs,
plot_title=self.title
)
return self.animation_output

Expand Down Expand Up @@ -1774,7 +1778,7 @@ def plot_ground_truths(self, truths, mapping: List[int], truths_label: str = "Gr
self.plot_state_mutable_sequence(truths, mapping, truths_label, **truths_kwargs)

def plot_tracks(self, tracks, mapping: List[int], uncertainty=False, particle=False,
track_label="Tracks", **kwargs):
track_label="Tracks", **kwargs):
"""Plots track(s)

Plots each track generated, generating a legend automatically. Tracks are plotted as solid
Expand Down Expand Up @@ -1926,7 +1930,8 @@ def run_animation(cls,
animation_input_kwargs: dict = {},
legend_kwargs: dict = {},
x_label: str = "$x$",
y_label: str = "$y$"
y_label: str = "$y$",
plot_title: str = None
) -> animation.FuncAnimation:
"""
Parameters
Expand All @@ -1953,6 +1958,8 @@ def run_animation(cls,
Label for the x axis
y_label: str
Label for the y axis
plot_title: str
Title for the plot

Returns
-------
Expand Down Expand Up @@ -2019,7 +2026,7 @@ def run_animation(cls,
line_ani = animation.FuncAnimation(fig1, cls.update_animation,
frames=len(times_to_plot),
fargs=(the_lines, plotting_data, min_plot_times,
times_to_plot),
times_to_plot, plot_title),
**animation_kwargs)

plt.draw()
Expand All @@ -2028,7 +2035,7 @@ def run_animation(cls,

@staticmethod
def update_animation(index: int, lines: List[Line2D], data_list: List[List[State]],
start_times: List[datetime], end_times: List[datetime]):
start_times: List[datetime], end_times: List[datetime], title: str):
"""
Parameters
----------
Expand All @@ -2042,6 +2049,8 @@ def update_animation(index: int, lines: List[Line2D], data_list: List[List[State
lowest (earliest) time for an item to be plotted
end_times : List[datetime]
highest (latest) time for an item to be plotted
title: str
Title for the plot

Returns
-------
Expand All @@ -2052,7 +2061,9 @@ def update_animation(index: int, lines: List[Line2D], data_list: List[List[State
min_time = start_times[index]
max_time = end_times[index]

plt.title(max_time)
if title is None:
title = ""
plt.title(title + str(max_time))
for i, data_source in enumerate(data_list):

if data_source is not None:
Expand Down
14 changes: 13 additions & 1 deletion stonesoup/tests/test_plotter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from stonesoup.plotter import Plotter, Dimension, AnimatedPlotterly
from stonesoup.plotter import Plotter, Dimension, AnimatedPlotterly, AnimationPlotter
import pytest
import matplotlib.pyplot as plt

Expand Down Expand Up @@ -211,6 +211,18 @@ def test_plot_complex_uncertainty():
plotter.plot_tracks(track, mapping=[0, 1], uncertainty=True)


def test_animation_plotter():
animation_plotter = AnimationPlotter()
animation_plotter.plot_ground_truths(truth, [0, 2])
animation_plotter.plot_measurements(all_measurements, [0, 2])
animation_plotter.run()

animation_plotter_with_title = AnimationPlotter(title="Plot title")
animation_plotter_with_title.plot_ground_truths(truth, [0, 2])
animation_plotter_with_title.plot_tracks(track, [0, 2])
animation_plotter_with_title.run()


def test_animated_plotterly():
plotter = AnimatedPlotterly(timesteps)
plotter.plot_ground_truths(truth, [0, 2])
Expand Down