Skip to content

Commit

Permalink
Allow setting of colour to measurements and clutter in Plotterly
Browse files Browse the repository at this point in the history
Fixes #984
  • Loading branch information
sdhiscocks committed Apr 23, 2024
1 parent 12ccccc commit 92dc07f
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 55 deletions.
4 changes: 2 additions & 2 deletions docs/examples/KalmanFilterOOSMExample.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@

plotter = AnimatedPlotterly(timesteps=time_steps)
plotter.plot_ground_truths(truth, [0, 2])
plotter.plot_measurements(measurements1, [0, 2], marker=dict(color='blue', symbol='0'),
plotter.plot_measurements(measurements1, [0, 2], marker=dict(color='blue'),
measurements_label='Detections with no lag')
plotter.plot_measurements(measurements2, [0, 2], marker=dict(color='orange', symbol='0'),
plotter.plot_measurements(measurements2, [0, 2], marker=dict(color='orange'),
measurements_label='Detections with lag')
plotter.plot_sensors([sensor1_platform, sensor2_platform],
marker=dict(color='black', symbol='129', size=15),
Expand Down
14 changes: 2 additions & 12 deletions docs/examples/track_fusion_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
from stonesoup.types.array import CovarianceMatrix
from stonesoup.simulator.simple import SingleTargetGroundTruthSimulator
from stonesoup.models.clutter.clutter import ClutterModel
from stonesoup.types.detection import Clutter

# Instantiate the radars to collect measurements - Use a :class:`~.RadarBearingRange` radar.#
from stonesoup.sensor.radar.radar import RadarBearingRange
Expand Down Expand Up @@ -185,19 +184,10 @@

# Plot the detections from the two radars
plotter = Plotterly()
plotter.plot_measurements([d for ds in s1_detections for d in ds if not isinstance(d, Clutter)],
[0, 2], marker=dict(color='red'), measurements_label='Sensor 1 measurements')

plotter.plot_measurements([d for ds in s1_detections for d in ds if isinstance(d, Clutter)],
[0, 2], marker=dict(color='red', symbol='star-triangle-up'),
plotter.plot_measurements(s1_detections, [0, 2], marker=dict(color='red'),
measurements_label='Sensor 1 measurements')

plotter.plot_measurements([d for ds in s2_detections for d in ds if not isinstance(d, Clutter)],
[0, 2], marker=dict(color='blue'), measurements_label='Sensor 2 measurements')
plotter.plot_measurements([d for ds in s2_detections for d in ds if isinstance(d, Clutter)],
[0, 2], marker=dict(color='blue', symbol='star-triangle-up'),
plotter.plot_measurements(s2_detections, [0, 2], marker=dict(color='blue'),
measurements_label='Sensor 2 measurements')

plotter.plot_sensors({sensor1_platform, sensor2_platform}, [0, 1],
marker=dict(color='black', symbol='1', size=10))
plotter.plot_ground_truths(truths, [0, 2])
Expand Down
83 changes: 42 additions & 41 deletions stonesoup/plotter.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import warnings
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from enum import IntEnum
from itertools import chain
from typing import Collection, Iterable, Union, List, Optional, Tuple, Dict

import numpy as np
from matplotlib import animation as animation
from matplotlib import pyplot as plt
from matplotlib.legend_handler import HandlerPatch
from matplotlib.lines import Line2D
from matplotlib.patches import Ellipse
from mergedeep import merge
from scipy.integrate import quad
from scipy.optimize import brentq
from scipy.stats import kde
Expand All @@ -21,19 +24,15 @@
except ImportError:
go = None

from .base import Base, Property
from .models.base import LinearModel, Model
from .types import detection
from .types.groundtruth import GroundTruthPath
from .types.array import StateVector
from .types.groundtruth import GroundTruthPath
from .types.metric import SingleTimeMetric
from .types.state import State, StateMutableSequence
from .types.update import Update

from .base import Base, Property

from .models.base import LinearModel, Model

from enum import IntEnum


class Dimension(IntEnum):
"""Dimension Enum class for specifying plotting parameters in the Plotter class.
Expand Down Expand Up @@ -285,9 +284,11 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,
self.legend_dict[measurements_label] = measurements_handle

if plot_clutter:
clutter_kwargs = kwargs.copy()
clutter_kwargs.update(dict(marker='2'))
clutter_array = np.array(list(plot_clutter.values()))
artists.append(self.ax.scatter(*clutter_array.T, color='y', marker='2'))
clutter_handle = Line2D([], [], linestyle='', marker='2', color='y')
artists.append(self.ax.scatter(*clutter_array.T, **clutter_kwargs))
clutter_handle = Line2D([], [], linestyle='', **clutter_kwargs)
clutter_label = "Clutter"

# Generate legend items for clutter
Expand Down Expand Up @@ -998,7 +999,7 @@ def __init__(self, dimension=Dimension.TWO, axis_labels=None, **kwargs):
if self.dimension == 3:
layout_kwargs.update(dict(scene_aspectmode='data')) # auto shapes fig to fit data well

layout_kwargs.update(kwargs)
merge(layout_kwargs, kwargs)

# Generate plot axes
self.fig = go.Figure(layout=layout_kwargs)
Expand Down Expand Up @@ -1054,7 +1055,7 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwa
if self.dimension == 3: # make ground truth line thicker so easier to see in 3d plot
truths_kwargs.update(dict(line=dict(width=8, dash="longdashdot")))

truths_kwargs.update(kwargs)
merge(truths_kwargs, kwargs)
add_legend = truths_kwargs['legendgroup'] not in {trace.legendgroup
for trace in self.fig.data}

Expand Down Expand Up @@ -1142,7 +1143,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,
if self.dimension == 3: # make markers smaller in 3d plot
measurement_kwargs.update(dict(marker=dict(size=4, color='#636EFA')))

measurement_kwargs.update(kwargs)
merge(measurement_kwargs, kwargs)
if measurement_kwargs['legendgroup'] not in {trace.legendgroup
for trace in self.fig.data}:
measurement_kwargs['showlegend'] = True
Expand Down Expand Up @@ -1175,43 +1176,43 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,

if plot_clutter:
name = measurements_label + "<br>(Clutter)"
measurement_kwargs = dict(
clutter_kwargs = dict(
mode='markers', marker=dict(symbol="star-triangle-up", color='#FECB52'),
name=name, legendgroup=name, legendrank=210)

if self.dimension == 3: # update - star-triangle-up not in 3d plotly
measurement_kwargs.update(dict(marker=dict(size=4, symbol="diamond",
color='#FECB52')))

measurement_kwargs.update(kwargs)
if measurement_kwargs['legendgroup'] not in {trace.legendgroup
for trace in self.fig.data}:
measurement_kwargs['showlegend'] = True
merge(clutter_kwargs, kwargs)
if clutter_kwargs['legendgroup'] not in {trace.legendgroup
for trace in self.fig.data}:
clutter_kwargs['showlegend'] = True
else:
measurement_kwargs['showlegend'] = False
clutter_kwargs['showlegend'] = False
clutter_array = np.asarray(list(plot_clutter.values()), dtype=np.float64)

if self.dimension == 1:
self.fig.add_scatter(
x=[state.timestamp for state in plot_clutter.keys()],
y=clutter_array[:, 0],
text=[self._format_state_text(state) for state in plot_clutter.keys()],
**measurement_kwargs,
**clutter_kwargs,
)
elif self.dimension == 2:
self.fig.add_scatter(
x=clutter_array[:, 0],
y=clutter_array[:, 1],
text=[self._format_state_text(state) for state in plot_clutter.keys()],
**measurement_kwargs,
**clutter_kwargs,
)
elif self.dimension == 3:
self.fig.add_scatter3d(
x=clutter_array[:, 0],
y=clutter_array[:, 1],
z=clutter_array[:, 2],
text=[self._format_state_text(state) for state in plot_clutter.keys()],
**measurement_kwargs,
**clutter_kwargs,
)

def get_next_color(self):
Expand Down Expand Up @@ -1283,7 +1284,7 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_

if self.dimension == 3: # change visuals to work well in 3d
track_kwargs.update(dict(line=dict(width=7)), marker=dict(size=4))
track_kwargs.update(kwargs)
merge(track_kwargs, kwargs)
add_legend = track_kwargs['legendgroup'] not in {trace.legendgroup
for trace in self.fig.data}

Expand Down Expand Up @@ -1484,7 +1485,7 @@ def plot_sensors(self, sensors, mapping=[0, 1], sensor_label="Sensors", **kwargs

sensor_kwargs = dict(mode='markers', marker=dict(symbol='x', color='black'),
legendgroup=sensor_label, legendrank=50)
sensor_kwargs.update(kwargs)
merge(sensor_kwargs, kwargs)

sensor_kwargs['name'] = sensor_label
if sensor_kwargs['legendgroup'] not in {trace.legendgroup
Expand Down Expand Up @@ -1573,7 +1574,7 @@ def plot_state_sequence(self, state_sequences, angle_mapping: int, range_mapping
plotting_kwargs = dict(
mode="markers", legendgroup=label, legendrank=200,
name=label, thetaunit="radians")
plotting_kwargs.update(kwargs)
merge(plotting_kwargs, kwargs)
add_legend = plotting_kwargs['legendgroup'] not in {trace.legendgroup
for trace in self.fig.data}

Expand Down Expand Up @@ -1620,7 +1621,7 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwa
``line=dict(dash="dash")``.
"""
truths_kwargs = dict(mode="lines", line=dict(dash="dash"), legendrank=100)
truths_kwargs.update(kwargs)
merge(truths_kwargs, kwargs)
angle_mapping = mapping[0]
if len(mapping) > 1:
range_mapping = mapping[1]
Expand Down Expand Up @@ -1680,7 +1681,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,
if plot_detections:
name = measurements_label + "<br>(Detections)"
measurement_kwargs = dict(mode='markers', marker=dict(color='#636EFA'), legendrank=200)
measurement_kwargs.update(kwargs)
merge(measurement_kwargs, kwargs)
plotting_data = [State(state_vector=plotting_state_vector,
timestamp=det.timestamp)
for det, plotting_state_vector in plot_detections.items()]
Expand All @@ -1691,16 +1692,16 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,

if plot_clutter:
name = measurements_label + "<br>(Clutter)"
measurement_kwargs = dict(mode='markers', legendrank=210,
marker=dict(symbol="star-triangle-up", color='#FECB52'))
measurement_kwargs.update(kwargs)
clutter_kwargs = dict(mode='markers', legendrank=210,
marker=dict(symbol="star-triangle-up", color='#FECB52'))
merge(clutter_kwargs, kwargs)
plotting_data = [State(state_vector=plotting_state_vector,
timestamp=det.timestamp)
for det, plotting_state_vector in plot_clutter.items()]

self.plot_state_sequence(state_sequences=[plotting_data], angle_mapping=angle_mapping,
range_mapping=range_mapping, label=name,
**measurement_kwargs)
**clutter_kwargs)

def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_label="Tracks",
**kwargs):
Expand Down Expand Up @@ -1734,7 +1735,7 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_
raise NotImplementedError

track_kwargs = dict(mode='markers+lines', legendrank=300)
track_kwargs.update(kwargs)
merge(track_kwargs, kwargs)
angle_mapping = mapping[0]
if len(mapping) > 1:
range_mapping = mapping[1]
Expand Down Expand Up @@ -2460,7 +2461,7 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth",
truth_kwargs = dict(x=[], y=[], mode="lines", hoverinfo='none', legendgroup=truths_label,
line=dict(dash="dash", color=self.colorway[0]), legendrank=100,
name=truths_label, showlegend=True)
truth_kwargs.update(kwargs)
merge(truth_kwargs, kwargs)
# legend dummy trace
self.fig.add_trace(go.Scatter(truth_kwargs))

Expand All @@ -2469,9 +2470,8 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth",

for n, _ in enumerate(truths):
# change the colour of each truth and include n in its name
truth_kwargs.update({
"line": dict(dash="dash", color=self.colorway[n % len(self.colorway)])})
truth_kwargs.update(kwargs)
merge(truth_kwargs, dict(line=dict(color=self.colorway[n % len(self.colorway)])))
merge(truth_kwargs, kwargs)
self.fig.add_trace(go.Scatter(truth_kwargs)) # add to traces

for frame in self.fig.frames:
Expand Down Expand Up @@ -2613,7 +2613,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,
legendgroup='Detections (Measurements)',
legendrank=200, showlegend=True,
marker=dict(color="#636EFA"), hoverinfo='none')
measurement_kwargs.update(kwargs)
merge(measurement_kwargs, kwargs)

self.fig.add_trace(go.Scatter(measurement_kwargs)) # trace for legend

Expand All @@ -2622,11 +2622,12 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,

# change necessary kwargs to initialise clutter trace
name = measurements_label + "<br>(Clutter)"
measurement_kwargs.update({"legendgroup": 'Clutter', "legendrank": 300,
"marker": dict(symbol="star-triangle-up", color='#FECB52'),
"name": name, 'showlegend': True})
clutter_kwargs = {"legendgroup": 'Clutter', "legendrank": 300,
"marker": dict(symbol="star-triangle-up", color='#FECB52'),
"name": name, 'showlegend': True}
merge(clutter_kwargs, kwargs)

self.fig.add_trace(go.Scatter(measurement_kwargs)) # trace for plotting clutter
self.fig.add_trace(go.Scatter(clutter_kwargs)) # trace for plotting clutter

# add data to frames
for frame in self.fig.frames:
Expand Down Expand Up @@ -2968,7 +2969,7 @@ def plot_sensors(self, sensors, sensor_label="Sensors", resize=True, **kwargs):
sensor_kwargs = dict(mode='markers', marker=dict(symbol='x', color='black'),
legendgroup=sensor_label, legendrank=50,
name=sensor_label, showlegend=True)
sensor_kwargs.update(kwargs)
merge(sensor_kwargs, kwargs)

self.fig.add_trace(go.Scatter(sensor_kwargs)) # initialises trace

Expand Down

0 comments on commit 92dc07f

Please sign in to comment.