Skip to content

Commit

Permalink
Merge pull request #584 from PACarniglia/main
Browse files Browse the repository at this point in the history
Added 3D Plotting Functionality to plotter.py
  • Loading branch information
sdhiscocks authored Mar 2, 2022
2 parents d648229 + f190be0 commit e3da986
Show file tree
Hide file tree
Showing 2 changed files with 266 additions and 82 deletions.
240 changes: 158 additions & 82 deletions stonesoup/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,37 +10,70 @@
from .types import detection
from .models.base import LinearModel, Model

from enum import Enum


class Dimension(Enum):
"""Dimension Enum class for specifying plotting parameters in the Plotter class.
Used to sanitize inputs for the dimension attribute of Plotter().
Attributes
----------
TWO: str
Specifies 2D plotting for Plotter object
THREE: str
Specifies 3D plotting for Plotter object
"""
TWO = 2 # 2D plotting mode (original plotter.py functionality)
THREE = 3 # 3D plotting mode


class Plotter:
"""Plotting class for building graphs of Stone Soup simulations
A plotting class which is used to simplify the process of plotting ground truths,
measurements, clutter and tracks. Tracks can be plotted with uncertainty ellipses or
particles if required. Legends are automatically generated with each plot.
Three dimensional plots can be created using the optional dimension parameter.
Parameters
----------
dimension: enum \'Dimension\'
Optional parameter to specify 2D or 3D plotting. Default is 2D plotting.
Attributes
----------
fig: matplotlib.figure.Figure
Generated figure for graphs to be plotted on
ax: matplotlib.axes.Axes
Generated axes for graphs to be plotted on
handles_list: list of :class:`matplotlib.legend_handler.HandlerBase`
A list of generated legend handles
labels_list: list of str
A list of generated legend labels
legend_dict: dict
Dictionary of legend handles as :class:`matplotlib.legend_handler.HandlerBase`
and labels as str
"""

def __init__(self):
def __init__(self, dimension=Dimension.TWO):
if isinstance(dimension, type(Dimension.TWO)):
self.dimension = dimension
else:
raise TypeError("""%s is an unsupported type for \'dimension\';
expected type %s""" % (type(dimension), type(Dimension.TWO)))
# Generate plot axes
self.fig = plt.figure(figsize=(10, 6))
self.ax = self.fig.add_subplot(1, 1, 1)
if self.dimension is Dimension.TWO: # 2D axes
self.ax = self.fig.add_subplot(1, 1, 1)
self.ax.axis('equal')
else: # 3D axes
self.ax = self.fig.add_subplot(111, projection='3d')
self.ax.axis('auto')
self.ax.set_zlabel("$z$")
self.ax.set_xlabel("$x$")
self.ax.set_ylabel("$y$")
self.ax.axis('equal')

# Create empty lists for legend handles and labels
self.handles_list = []
self.labels_list = []
# Create empty dictionary for legend handles and labels - dict used to
# prevent multiple entries with the same label from displaying on legend
# This is new compared to plotter.py
self.legend_dict = {} # create an empty dictionary to hold legend entries

def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwargs):
"""Plots ground truth(s)
Expand All @@ -58,7 +91,7 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwa
:class:`~.GroundTruthPath` type, the argument is modified to be a set to allow for
iteration.
mapping: list
List of 2 items specifying the mapping of the x and y components of the state space.
List of items specifying the mapping of the position components of the state space.
\\*\\*kwargs: dict
Additional arguments to be passed to plot function. Default is ``linestyle="--"``.
"""
Expand All @@ -69,17 +102,22 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwa
truths = {truths} # Make a set of length 1

for truth in truths:
self.ax.plot([state.state_vector[mapping[0]] for state in truth],
[state.state_vector[mapping[1]] for state in truth],
**truths_kwargs)

if self.dimension is Dimension.TWO: # plots the ground truths in xy
self.ax.plot([state.state_vector[mapping[0]] for state in truth],
[state.state_vector[mapping[1]] for state in truth],
**truths_kwargs)
elif self.dimension is Dimension.THREE: # plots the ground truths in xyz
self.ax.plot3D([state.state_vector[mapping[0]] for state in truth],
[state.state_vector[mapping[1]] for state in truth],
[state.state_vector[mapping[2]] for state in truth],
**truths_kwargs)
else:
raise NotImplementedError('Unsupported dimension type for truth plotting')
# Generate legend items
truths_handle = Line2D([], [], linestyle=truths_kwargs['linestyle'], color='black')
self.handles_list.append(truths_handle)
self.labels_list.append(truths_label)

self.legend_dict[truths_label] = truths_handle
# Generate legend
self.ax.legend(handles=self.handles_list, labels=self.labels_list)
self.ax.legend(handles=self.legend_dict.values(), labels=self.legend_dict.keys())

def plot_measurements(self, measurements, mapping, measurement_model=None,
measurements_label="Measurements", **kwargs):
Expand All @@ -97,7 +135,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,
measurements : list of :class:`~.Detection`
Detections which will be plotted. If measurements is a set of lists it is flattened.
mapping: list
List of 2 items specifying the mapping of the x and y components of the state space.
List of items specifying the mapping of the position components of the state space.
measurement_model : :class:`~.Model`, optional
User-defined measurement model to be used in finding measurement state inverses if
they cannot be found from the measurements themselves.
Expand Down Expand Up @@ -151,36 +189,38 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,

if plot_detections:
detection_array = np.array(plot_detections)
self.ax.scatter(detection_array[:, 0], detection_array[:, 1], **measurement_kwargs)
# *detection_array.T unpacks detection_array by coloumns
# (same as passing in detection_array[:,0], detection_array[:,1], etc...)
self.ax.scatter(*detection_array.T, **measurement_kwargs)
measurements_handle = Line2D([], [], linestyle='', **measurement_kwargs)

# Generate legend items for measurements
self.handles_list.append(measurements_handle)
self.labels_list.append(measurements_label)
self.legend_dict[measurements_label] = measurements_handle

if plot_clutter:
clutter_array = np.array(plot_clutter)
self.ax.scatter(clutter_array[:, 0], clutter_array[:, 1], color='y', marker='2')
self.ax.scatter(*clutter_array.T, color='y', marker='2')
clutter_handle = Line2D([], [], linestyle='', marker='2', color='y')
clutter_label = "Clutter"

# Generate legend items for clutter
self.handles_list.append(clutter_handle)
self.labels_list.append(clutter_label)
self.legend_dict[clutter_label] = clutter_handle

# Generate legend
self.ax.legend(handles=self.handles_list, labels=self.labels_list)
self.ax.legend(handles=self.legend_dict.values(), labels=self.legend_dict.keys())

def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_label="Track",
**kwargs):
err_freq=1, **kwargs):
"""Plots track(s)
Plots each track generated, generating a legend automatically. If ``uncertainty=True``,
uncertainty ellipses are plotted. If ``particle=True``, particles are plotted.
Tracks are plotted as solid lines with point markers and default colors.
Uncertainty ellipses are plotted with a default color which is the same for all tracks.
Plots each track generated, generating a legend automatically. If ``uncertainty=True``
and is being plotted in 2D, error elipses are plotted. If being plotted in
3D, uncertainty bars are plotted every :attr:`err_freq` measurement, default
plots unceratinty bars at every track step. Tracks are plotted as solid
lines with point markers and default colors. Uncertainty bars are plotted
with a default color which is the same for all tracks.
Users can change linestyle, color and marker using keyword arguments. Uncertainty ellipses
Users can change linestyle, color and marker using keyword arguments. Uncertainty metrics
will also be plotted with the user defined colour and any changes will apply to all tracks.
Parameters
Expand All @@ -189,13 +229,17 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_
Set of tracks which will be plotted. If not a set, and instead a single
:class:`~.Track` type, the argument is modified to be a set to allow for iteration.
mapping: list
List of 2 items specifying the mapping of the x and y components of the state space.
List of items specifying the mapping of the position
components of the state space.
uncertainty : bool
If True, function plots uncertainty ellipses.
If True, function plots uncertainty ellipses or bars.
particle : bool
If True, function plots particles.
track_label: str
Label to apply to all tracks for legend.
err_freq: int
Frequency of error bar plotting on tracks. Default value is 1, meaning
error bars are plotted at every track step.
\\*\\*kwargs: dict
Additional arguments to be passed to plot function. Defaults are ``linestyle="-"``,
``marker='.'`` and ``color=None``.
Expand All @@ -209,9 +253,15 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_
# Plot tracks
track_colors = {}
for track in tracks:
line = self.ax.plot([state.state_vector[mapping[0]] for state in track],
[state.state_vector[mapping[1]] for state in track],
**tracks_kwargs)
if self.dimension is Dimension.TWO:
line = self.ax.plot([state.state_vector[mapping[0]] for state in track],
[state.state_vector[mapping[1]] for state in track],
**tracks_kwargs)
else:
line = self.ax.plot([state.state_vector[mapping[0]] for state in track],
[state.state_vector[mapping[1]] for state in track],
[state.state_vector[mapping[2]] for state in track],
**tracks_kwargs)
track_colors[track] = plt.getp(line[0], 'color')

# Assuming a single track or all plotted as the same colour then the following will work.
Expand All @@ -221,55 +271,81 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_
# Generate legend items for track
track_handle = Line2D([], [], linestyle=tracks_kwargs['linestyle'],
marker=tracks_kwargs['marker'], color=tracks_kwargs['color'])
self.handles_list.append(track_handle)
self.labels_list.append(track_label)

self.legend_dict[track_label] = track_handle
if uncertainty:
# Plot uncertainty ellipses
for track in tracks:
HH = np.eye(track.ndim)[mapping, :] # Get position mapping matrix
for state in track:
w, v = np.linalg.eig(HH @ state.covar @ HH.T)
max_ind = np.argmax(w)
min_ind = np.argmin(w)
orient = np.arctan2(v[1, max_ind], v[0, max_ind])
ellipse = Ellipse(xy=state.state_vector[mapping[:2], 0],
width=2 * np.sqrt(w[max_ind]),
height=2 * np.sqrt(w[min_ind]),
angle=np.rad2deg(orient), alpha=0.2,
color=track_colors[track])
self.ax.add_artist(ellipse)

# Generate legend items for uncertainty ellipses
ellipse_handle = Ellipse((0.5, 0.5), 0.5, 0.5, alpha=0.2, color=tracks_kwargs['color'])
ellipse_label = "Uncertainty"

self.handles_list.append(ellipse_handle)
self.labels_list.append(ellipse_label)

# Generate legend
self.ax.legend(handles=self.handles_list, labels=self.labels_list,
handler_map={Ellipse: _HandlerEllipse()})
if self.dimension is Dimension.TWO:
# Plot uncertainty ellipses
for track in tracks:
HH = np.eye(track.ndim)[mapping, :] # Get position mapping matrix
for state in track:
w, v = np.linalg.eig(HH @ state.covar @ HH.T)
max_ind = np.argmax(w)
min_ind = np.argmin(w)
orient = np.arctan2(v[1, max_ind], v[0, max_ind])
ellipse = Ellipse(xy=state.state_vector[mapping[:2], 0],
width=2 * np.sqrt(w[max_ind]),
height=2 * np.sqrt(w[min_ind]),
angle=np.rad2deg(orient), alpha=0.2,
color=track_colors[track])
self.ax.add_artist(ellipse)

# Generate legend items for uncertainty ellipses
ellipse_handle = Ellipse((0.5, 0.5), 0.5, 0.5, alpha=0.2,
color=tracks_kwargs['color'])
ellipse_label = "Uncertainty"
self.legend_dict[ellipse_label] = ellipse_handle
# Generate legend
self.ax.legend(handles=self.legend_dict.values(),
labels=self.legend_dict.keys(),
handler_map={Ellipse: _HandlerEllipse()})
else:
# Plot 3D error bars on tracks
for track in tracks:
HH = np.eye(track.ndim)[mapping, :] # Get position mapping matrix
check = err_freq
for state in track:
if not check % err_freq:
w, v = np.linalg.eig(HH @ state.covar @ HH.T)

xl = state.state_vector[mapping[0]]
yl = state.state_vector[mapping[1]]
zl = state.state_vector[mapping[2]]

x_err = w[0]
y_err = w[1]
z_err = w[2]

self.ax.plot3D([xl+x_err, xl-x_err], [yl, yl], [zl, zl],
marker="_", color=tracks_kwargs['color'])
self.ax.plot3D([xl, xl], [yl+y_err, yl-y_err], [zl, zl],
marker="_", color=tracks_kwargs['color'])
self.ax.plot3D([xl, xl], [yl, yl], [zl+z_err, zl-z_err],
marker="_", color=tracks_kwargs['color'])
check += 1

elif particle:
# Plot particles
for track in tracks:
for state in track:
data = state.particles.state_vector[mapping[:2], :]
self.ax.plot(data[0], data[1], linestyle='', marker=".",
markersize=1, alpha=0.5)

# Generate legend items for particles
particle_handle = Line2D([], [], linestyle='', color="black", marker='.', markersize=1)
particle_label = "Particles"
self.handles_list.append(particle_handle)
self.labels_list.append(particle_label)

# Generate legend
self.ax.legend(handles=self.handles_list, labels=self.labels_list)
if self.dimension is Dimension.TWO:
# Plot particles
for track in tracks:
for state in track:
data = state.particles.state_vector[mapping[:2], :]
self.ax.plot(data[0], data[1], linestyle='', marker=".",
markersize=1, alpha=0.5)

# Generate legend items for particles
particle_handle = Line2D([], [], linestyle='', color="black", marker='.',
markersize=1)
particle_label = "Particles"
self.legend_dict[particle_label] = particle_handle
# Generate legend
self.ax.legend(handles=self.legend_dict.values(),
labels=self.legend_dict.keys()) # particle error legend
else:
raise NotImplementedError("""Particle plotting is not currently supported for
3D visualization""")

else:
self.ax.legend(handles=self.handles_list, labels=self.labels_list)
self.ax.legend(handles=self.legend_dict.values(), labels=self.legend_dict.keys())

# Ellipse legend patch (used in Tutorial 3)
@staticmethod
Expand Down
Loading

0 comments on commit e3da986

Please sign in to comment.