Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Mapping to plot_sensors. Hide plotting elements in Plotterly #836

Merged
merged 1 commit into from
Sep 1, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions stonesoup/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,18 @@

from .models.base import LinearModel, Model

from enum import Enum
from enum import IntEnum


class Dimension(Enum):
class Dimension(IntEnum):
"""Dimension Enum class for specifying plotting parameters in the Plotter class.
Used to sanitize inputs for the dimension attribute of Plotter().

Attributes
----------
TWO: str
TWO: int
Specifies 2D plotting for Plotter object
THREE: str
THREE: int
Specifies 3D plotting for Plotter object
"""
TWO = 2 # 2D plotting mode (original plotter.py functionality)
Expand All @@ -66,7 +66,7 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_
raise NotImplementedError

@abstractmethod
def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs):
def plot_sensors(self, sensors, mapping, sensor_label="Sensors", **kwargs):
raise NotImplementedError

def _conv_measurements(self, measurements, mapping, measurement_model=None,
Expand Down Expand Up @@ -470,7 +470,7 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_

return artists

def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs):
def plot_sensors(self, sensors, mapping=None, sensor_label="Sensors", **kwargs):
"""Plots sensor(s)

Plots sensors. Users can change the color and marker of detections using keyword
Expand All @@ -480,6 +480,9 @@ def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs):
----------
sensors : Collection of :class:`~.Sensor`
Sensors to plot
mapping: list
List of items specifying the mapping of the position components of the
sensor's position. Default is either [0, 1] or [0, 1, 2] depending on `self.dimension`
sensor_label: str
Label to apply to all tracks for legend.
\\*\\*kwargs: dict
Expand All @@ -498,16 +501,19 @@ def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs):
if not isinstance(sensors, Collection):
sensors = {sensors} # Make a set of length 1

if mapping is None:
mapping = list(range(self.dimension))

artists = []
for sensor in sensors:
if self.dimension is Dimension.TWO: # plots the sensors in xy
artists.append(self.ax.scatter(sensor.position[0],
sensor.position[1],
artists.append(self.ax.scatter(sensor.position[mapping[0]],
sensor.position[mapping[1]],
**sensor_kwargs))
elif self.dimension is Dimension.THREE: # plots the sensors in xyz
artists.extend(self.ax.plot3D(sensor.position[0],
sensor.position[1],
sensor.position[2],
artists.extend(self.ax.plot3D(sensor.position[mapping[0]],
sensor.position[mapping[1]],
sensor.position[mapping[2]],
**sensor_kwargs))
else:
raise NotImplementedError('Unsupported dimension type for sensor plotting')
Expand Down Expand Up @@ -991,7 +997,7 @@ def func3(x):
points = rotational_matrix @ points.T
return points + state.mean[mapping[:2], :]

def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs):
def plot_sensors(self, sensors, mapping=[0, 1], sensor_label="Sensors", **kwargs):
"""Plots sensor(s)

Plots sensors. Users can change the color and marker of detections using keyword
Expand All @@ -1001,6 +1007,9 @@ def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs):
----------
sensors : Collection of :class:`~.Sensor`
Sensors to plot
mapping: list
List of items specifying the mapping of the position
components of the sensor's position.
sensor_label: str
Label to apply to all tracks for legend.
\\*\\*kwargs: dict
Expand All @@ -1022,9 +1031,27 @@ def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs):
else:
sensor_kwargs['showlegend'] = True

sensor_xy = np.array([sensor.position[[0, 1], 0] for sensor in sensors])
sensor_xy = np.array([sensor.position[mapping, 0] for sensor in sensors])
self.fig.add_scatter(x=sensor_xy[:, 0], y=sensor_xy[:, 1], **sensor_kwargs)

def hide_plot_traces(self, items_to_hide: set):
"""Hide Plot Traces

This function allows plotting items to be invisible as default. Users can toggle the plot
trace to visible.

Parameters
----------
items_to_hide : set[str]
The legend label (`legendgroups`) for the plot traces that should be invisible as
default
"""
for fig_data in self.fig.data:
if fig_data.legendgroup in items_to_hide:
fig_data.visible = "legendonly"
else:
fig_data.visible = None


class _AnimationPlotterDataClass(Base):
plotting_data = Property(Iterable[State])
Expand Down