From c78f2a4e870ec72abd34b2f344af76d9af1c7bf5 Mon Sep 17 00:00:00 2001
From: pesslovany
Date: Tue, 11 Jun 2024 11:22:01 +0200
Subject: [PATCH 01/16] Added Lagrangian PMF and TAN in raw form

---
 01_terrain_aided_navigation.py                    | 309 ++++++++++++++++++
 02_ParticleFilter.py                              | 284 ++++++++++++++++
 docs/source/_templates/breadcrumbs.html           |   3 -
 functions.py                                      | 199 +++++++++++
 pokus1.py                                         |  45 +++
 stonesoup/functions/__init__.py                   |  37 +++
 stonesoup/models/measurement/nonlinear.py         |  95 +++++-
 stonesoup/models/measurement/tests/test_models.py |  38 +--
 stonesoup/movable/grid.py                         |   6 +-
 stonesoup/platform/base.py                        |  28 --
 stonesoup/platform/tests/test_platform_base.py    | 123 -------
 stonesoup/plotter.py                              |   4 +-
 stonesoup/predictor/pointMass.py                  | 189 +++++++++++
 stonesoup/types/state.py                          |  37 +++
 stonesoup/updater/kalman.py                       |  39 +--
 stonesoup/updater/pointMass.py                    | 122 +++++++
 stonesoup/updater/recursive.py                    |  20 +-
 stonesoup/updater/tests/test_kalman.py            |  81 +++--
 terrain_aided_navigation.py                       | 309 ++++++++++++++++++
 19 files changed, 1713 insertions(+), 255 deletions(-)
 create mode 100644 01_terrain_aided_navigation.py
 create mode 100644 02_ParticleFilter.py
 create mode 100644 functions.py
 create mode 100644 pokus1.py
 create mode 100644 stonesoup/predictor/pointMass.py
 create mode 100644 stonesoup/updater/pointMass.py
 create mode 100644 terrain_aided_navigation.py

diff --git a/01_terrain_aided_navigation.py b/01_terrain_aided_navigation.py
new file mode 100644
index 000000000..d4c06d405
--- /dev/null
+++ b/01_terrain_aided_navigation.py
@@ -0,0 +1,309 @@
+#!/usr/bin/env python
+
+# ========================================
+# Terrain aided navigation (TAN) example
+# ========================================
+
+# %%
+#
+# Ground truth
+# ^^^^^^^^^^^^
+# Import the necessary libraries
+
+import numpy as np
+import numpy.matlib  # needed for the np.matlib.repmat calls below
+import matplotlib.pyplot as plt
+import time
+
+from datetime import datetime
+from datetime import timedelta
+
+# Initialise Stone Soup ground-truth and transition models.
+from stonesoup.models.transition.linear import CombinedLinearGaussianTransitionModel, \
+    ConstantVelocity
+from stonesoup.models.transition.linear import KnownTurnRate
+from stonesoup.types.groundtruth import GroundTruthPath, GroundTruthState
+from stonesoup.types.detection import Detection
+from stonesoup.models.measurement.nonlinear import TerrainAidedNavigation
+from stonesoup.models.measurement.linear import LinearGaussian
+from scipy.interpolate import RegularGridInterpolator
+from stonesoup.predictor.particle import ParticlePredictor
+from stonesoup.resampler.particle import ESSResampler
+from stonesoup.resampler.particle import MultinomialResampler
+from stonesoup.updater.particle import ParticleUpdater
+from stonesoup.functions import gridCreation
+from numpy.linalg import inv
+from stonesoup.types.state import PointMassState
+from stonesoup.types.hypothesis import SingleHypothesis
+from stonesoup.types.track import Track
+from stonesoup.types.state import GaussianState
+
+from stonesoup.predictor.pointMass import PointMassPredictor
+from stonesoup.updater.pointMass import PointMassUpdater
+from scipy.stats import multivariate_normal
+
+from stonesoup.predictor.kalman import KalmanPredictor
+from stonesoup.updater.kalman import KalmanUpdater
+
+from stonesoup.types.numeric import Probability  # Similar to a float type
+from stonesoup.types.state import ParticleState
+from stonesoup.types.array import StateVectors
+import json
+
+
+# Initialize arrays to store RMSE values
+matrixTruePMF = []
+matrixTruePF = []
+matrixTrueKF = []
+MC = 10
+for mc in range(0, MC):
+    print(mc)
+    start_time = datetime.now().replace(microsecond=0)
+
+    # %%
+
+    # np.random.seed(1991)
+
+    # %%
+
+    transition_model = KnownTurnRate(turn_noise_diff_coeffs=[2, 2],
+                                     turn_rate=np.deg2rad(30))
+
+    # TODO: derive this from the measurement timestamps instead of hard-coding it
+    time_difference = timedelta(days=0, hours=0, minutes=0, seconds=1)
+
+    timesteps = [start_time]
+    truth = GroundTruthPath([GroundTruthState([36569, 50, 55581, 50],
+                                              timestamp=start_time)])
+
+    # %%
+    # Create the truth path
+    for k in range(1, 20):
+        timesteps.append(start_time+timedelta(seconds=k))
+        truth.append(GroundTruthState(
+            transition_model.function(truth[k-1], noise=True,
+                                      time_interval=timedelta(seconds=1)),
+            timestamp=timesteps[k]))
+
+    # %%
+    # Initialise the terrain aided navigation sensor using the appropriate
+    # measurement model.
+
+    # Open the JSON file holding the terrain map
+    with open('/Users/matoujak/Desktop/file.json', 'r') as file:
+        # Load JSON data
+        data = json.load(file)
+
+    map_x = data['x']
+    map_y = data['y']
+    map_z = data['z']
+
+    map_x = np.array(map_x)
+    map_y = np.array(map_y)
+    map_z = np.matrix(map_z)
+
+    interpolator = RegularGridInterpolator((map_x[0, :], map_y[:, 0]), map_z)
+
+    measurement_model = TerrainAidedNavigation(interpolator, noise_covar=1,
+                                               mapping=(0, 2))
+    # matrix = np.array([
+    #     [1, 0],
+    #     [0, 1],
+    # ])
+    # measurement_model = LinearGaussian(ndim_state=4, mapping=(0, 2),
+    #                                    noise_covar=matrix)
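+
+    # %%
+    # A minimal sketch (with a made-up 2x2 map) of what the interpolator in
+    # the TAN model does: it maps a horizontal position (x, y) to a terrain
+    # height, and that height is the measurement.
+    _demo_x = np.array([0.0, 1.0])
+    _demo_y = np.array([0.0, 1.0])
+    _demo_z = np.array([[0.0, 1.0],
+                        [2.0, 3.0]])
+    _demo_interp = RegularGridInterpolator((_demo_x, _demo_y), _demo_z)
+    assert np.isclose(_demo_interp(np.array([[0.5, 0.5]]))[0], 1.5)  # bilinear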
+
+    # %%
+    # Populate the measurement array
+    measurements = []
+    for state in truth:
+        measurement = measurement_model.function(state, noise=True)
+        measurements.append(Detection(measurement, timestamp=state.timestamp,
+                                      measurement_model=measurement_model))
+
+    predictor = ParticlePredictor(transition_model)
+    resampler = MultinomialResampler()
+    updater = ParticleUpdater(measurement_model, resampler)
+
+    predictorKF = KalmanPredictor(transition_model)
+    updaterKF = KalmanUpdater(measurement_model)
+
+    # %%
+    # Initialise a prior
+    # ^^^^^^^^^^^^^^^^^^
+    # To start we create a prior estimate. This is a :class:`~.ParticleState` which
+    # describes the state as a distribution of particles using
+    # :class:`~.StateVectors` and weights. This is sampled from the Gaussian
+    # distribution (using the same parameters we had in the previous examples).
+
+    number_particles = 10000
+
+    # Sample from the prior Gaussian distribution
+    samples = multivariate_normal.rvs(np.array([36569, 50, 55581, 50]),
+                                      np.diag([90, 5, 160, 5]),
+                                      size=number_particles)
+
+    # Create prior particle state.
+    prior = ParticleState(state_vector=StateVectors(samples.T),
+                          weight=np.array([Probability(1/number_particles)]*number_particles),
+                          timestamp=start_time)
+
+    priorKF = GaussianState([[36569], [50], [55581], [50]], np.diag([90, 5, 160, 5]),
+                            timestamp=start_time)
+
+    # %% PMF prior
+
+    pmfPredictor = PointMassPredictor(transition_model)
+    pmfUpdater = PointMassUpdater(measurement_model)
+    # Initial condition - Gaussian
+    nx = 4
+    meanX0 = np.array([36569, 50, 55581, 50])  # mean value
+    varX0 = np.diag([90, 5, 160, 5])  # variance
+    Npa = np.array([31, 31, 27, 27])  # number of points per axis; must be odd for the FFT
+    N = np.prod(Npa)  # total number of points
+    sFactor = 4  # scaling factor (number of sigmas covered by the grid)
+
+    [predGrid, predGridDelta, gridDimOld, xOld, Ppold] = gridCreation(np.vstack(meanX0),
+                                                                      varX0, sFactor, nx, Npa)
+
+    # Evaluate the Gaussian prior pdf on the grid
+    meanX0 = np.vstack(meanX0)
+    pom = predGrid - np.matlib.repmat(meanX0, 1, N)
+    denominator = np.sqrt((2*np.pi)**nx * np.linalg.det(varX0))
+    pompom = np.sum(-0.5*np.multiply(pom.T@inv(varX0), pom.T), 1)  # elementwise multiplication
+    pomexp = np.exp(pompom)
+    predDensityProb = pomexp/denominator  # Gaussian pdf values at the grid points
+    predDensityProb = predDensityProb/(sum(predDensityProb)*np.prod(predGridDelta))
+
+    priorPMF = PointMassState(state_vector=StateVectors(predGrid),
+                              weight=predDensityProb,
+                              grid_delta=predGridDelta,
+                              grid_dim=gridDimOld,
+                              center=xOld,
+                              eigVec=Ppold,
+                              Npa=Npa,
+                              timestamp=start_time)
+
+    F = transition_model.matrix(prior=prior, time_interval=time_difference)
+    Q = transition_model.covar(time_interval=time_difference)
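+
+    # %%
+    # Sanity check (illustrative only): a pdf represented on the grid should
+    # integrate to one, i.e. sum(weights) * prod(grid steps) ~= 1, which the
+    # normalisation above enforces.
+    assert np.isclose(np.sum(predDensityProb) * np.prod(predGridDelta), 1)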
+
+    matrixPMF = []
+
+    start_time = time.time()
+    track = Track()
+    for measurement in measurements:
+        prediction = pmfPredictor.predict(priorPMF, timestamp=measurement.timestamp)
+        hypothesis = SingleHypothesis(prediction, measurement)
+        post = pmfUpdater.update(hypothesis)
+        priorPMF = post
+        matrixPMF.append(post.mean)
+        # print(post.mean)
+
+    # Record the end time
+    end_time = time.time()
+
+    # Calculate the elapsed time
+    # print(end_time - start_time)
+
+    # matrixKF = []
+
+    # start_time = time.time()
+    # track = Track()
+    # for measurement in measurements:
+    #     prediction = predictorKF.predict(priorKF, timestamp=measurement.timestamp)
+    #     hypothesis = SingleHypothesis(prediction, measurement)
+    #     post = updaterKF.update(hypothesis)
+    #     priorKF = post
+    #     matrixKF.append(post.mean)
+    #     # print(post.mean)
+
+    # # Record the end time
+    # end_time = time.time()
+
+    # %%
+    # Run the tracker
+    # ^^^^^^^^^^^^^^^
+    # We now run the predict and update steps, propagating the collection of
+    # particles and resampling when told to (at every step).
+
+    matrixPF = []
+    start_time = time.time()
+    track = Track()
+    for measurement in measurements:
+        prediction = predictor.predict(prior, timestamp=measurement.timestamp)
+        hypothesis = SingleHypothesis(prediction, measurement)
+        post = updater.update(hypothesis)
+        # print(post.mean)
+        track.append(post)
+        matrixPF.append(post.mean)
+        prior = track[-1]
+
+    # Record the end time
+    end_time = time.time()
+
+    # Calculate the elapsed time
+    # print(end_time - start_time)
+
+    # Collect the per-step estimation errors of this Monte Carlo run
+    for ind in range(0, 20):
+        matrixTruePMF.append(np.ravel(np.vstack(matrixPMF[ind])-truth.states[ind].state_vector))
+        matrixTruePF.append(np.ravel(matrixPF[ind]-truth.states[ind].state_vector))
+        # matrixTrueKF.append(np.ravel(matrixKF[ind]-truth.states[ind].state_vector))
+
+
+def rmse(errors):
+    """
+    Calculate the Root Mean Square Error (RMSE) from a list of errors.
+
+    Args:
+        errors (list): List of error vectors.
+
+    Returns:
+        numpy.ndarray: RMSE value per state component.
+    """
+    # Convert the list of errors into a numpy array for easier computation
+    errors_array = np.array(errors)
+
+    # Square the errors
+    squared_errors = np.square(errors_array)
+
+    # Calculate the mean squared error per component, across steps and runs
+    mean_squared_error = np.mean(squared_errors, 0)
+
+    # Calculate the root mean squared error
+    rmse_value = np.sqrt(mean_squared_error)
+
+    return rmse_value
+
+
+print(rmse(matrixTruePF))
+print(rmse(matrixTruePMF))
+# print(rmse(matrixTrueKF))
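+
+# %%
+# Worked example of the rmse helper (numbers made up): two error vectors
+# [3, 4] and [3, 4] give a per-component RMSE of [3, 4].
+assert np.allclose(rmse([np.array([3., 4.]), np.array([3., 4.])]), [3., 4.])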
+ """ + # Convert the list of errors into a numpy array for easier computation + errors_array = np.array(errors) + + # Square the errors + squared_errors = np.square(errors_array) + + # Calculate the mean squared error + mean_squared_error = np.mean(squared_errors,0) + + # Calculate the root mean squared error + rmse_value = np.sqrt(mean_squared_error) + + return rmse_value + + +print(rmse(matrixTruePF)) +print(rmse(matrixTruePMF)) +# print(rmse(matrixTrueKF)) + + + + diff --git a/02_ParticleFilter.py b/02_ParticleFilter.py new file mode 100644 index 000000000..020aa578c --- /dev/null +++ b/02_ParticleFilter.py @@ -0,0 +1,284 @@ +#!/usr/bin/env python + +""" +===================================== +4 - Sampling methods: particle filter +===================================== +""" + + + +# %% +# +# Nearly-constant velocity example +# -------------------------------- +# We continue in the same vein as the previous tutorials. +# +# Ground truth +# ^^^^^^^^^^^^ +# Import the necessary libraries + +import numpy as np +import matplotlib.pyplot as plt +import time + +from datetime import datetime +from datetime import timedelta + + +from stonesoup.functions import gridCreation +from numpy.linalg import inv +from stonesoup.types.state import PointMassState +from stonesoup.types.hypothesis import SingleHypothesis +from stonesoup.types.track import Track + +from stonesoup.predictor.pointMass import PointMassPredictor +from stonesoup.updater.pointMass import PointMassUpdater + +from stonesoup.models.transition.linear import CombinedLinearGaussianTransitionModel, \ + ConstantVelocity +from stonesoup.types.groundtruth import GroundTruthPath, GroundTruthState +from stonesoup.models.measurement.nonlinear import CartesianToBearingRange +from stonesoup.types.detection import Detection + + +from stonesoup.predictor.particle import ParticlePredictor + +from stonesoup.resampler.particle import ESSResampler + +from stonesoup.updater.particle import ParticleUpdater + + +from scipy.stats import multivariate_normal + +from stonesoup.types.numeric import Probability # Similar to a float type +from stonesoup.types.state import ParticleState +from stonesoup.types.array import StateVectors + + + +timePMF = [] +timePF = [] + +# %% + +#np.random.seed(1991) + +# %% +# Initialise Stone Soup ground-truth and transition models. + +kf = 10 +matrixTruePMF = [] +matrixTruePF = [] +MC = 10 +for mc in range(0,MC): + print(mc) + start_time = datetime.now().replace(microsecond=0) + transition_model = CombinedLinearGaussianTransitionModel([ConstantVelocity(1), + ConstantVelocity(1)]) + + # This needs to be done in other way + time_difference = timedelta(days=0, hours=0, minutes=0, seconds=1) + + + timesteps = [start_time] + truth = GroundTruthPath([GroundTruthState([48, 1, -5, 1], timestamp=start_time)]) + + # %% + # Create the truth path + for k in range(1, kf): + timesteps.append(start_time+timedelta(seconds=k)) + truth.append(GroundTruthState( + transition_model.function(truth[k-1], noise=True, time_interval=timedelta(seconds=1)), + timestamp=timesteps[k])) + + + # %% + # Initialise the bearing, range sensor using the appropriate measurement model. 
+
+    # %%
+    # Populate the measurement array
+    measurements = []
+    for state in truth:
+        measurement = measurement_model.function(state, noise=True)
+        measurements.append(Detection(measurement, timestamp=state.timestamp,
+                                      measurement_model=measurement_model))
+
+    resampler = ESSResampler()
+    predictor = ParticlePredictor(transition_model)
+    updater = ParticleUpdater(measurement_model, resampler)
+
+    # %%
+    # Initialise a prior
+    # ^^^^^^^^^^^^^^^^^^
+    # To start we create a prior estimate. This is a :class:`~.ParticleState` which
+    # describes the state as a distribution of particles using
+    # :class:`~.StateVectors` and weights. This is sampled from the Gaussian
+    # distribution (using the same parameters we had in the previous examples).
+
+    number_particles = 130000
+
+    # Sample from the prior Gaussian distribution
+    samples = multivariate_normal.rvs(np.array([48, 1, -5, 1]),
+                                      np.diag([1, 0.5, 1, 0.5]),
+                                      size=number_particles)
+
+    # Create prior particle state.
+    prior = ParticleState(state_vector=StateVectors(samples.T),
+                          weight=np.array([Probability(1/number_particles)]*number_particles),
+                          timestamp=start_time)
+
+    # %% PMF prior
+
+    pmfPredictor = PointMassPredictor(transition_model)
+    pmfUpdater = PointMassUpdater(measurement_model)
+    # Initial condition - Gaussian
+    nx = 4
+    meanX0 = np.array([48, 1, -5, 1])  # mean value
+    varX0 = np.diag([1, 0.5, 1, 0.5])  # variance
+    Npa = np.array([19, 19, 19, 19])  # number of points per axis; must be odd for the FFT
+    N = np.prod(Npa)  # total number of points
+    sFactor = 4  # scaling factor (number of sigmas covered by the grid)
+
+    [predGrid, predGridDelta, gridDimOld, xOld, Ppold] = gridCreation(np.vstack(meanX0),
+                                                                      varX0, sFactor, nx, Npa)
+
+    # Evaluate the Gaussian prior pdf on the grid
+    meanX0 = np.vstack(meanX0)
+    pom = predGrid - np.matlib.repmat(meanX0, 1, N)
+    denominator = np.sqrt((2*np.pi)**nx * np.linalg.det(varX0))
+    pompom = np.sum(-0.5*np.multiply(pom.T@inv(varX0), pom.T), 1)  # elementwise multiplication
+    pomexp = np.exp(pompom)
+    predDensityProb = pomexp/denominator  # Gaussian pdf values at the grid points
+    predDensityProb = predDensityProb/(sum(predDensityProb)*np.prod(predGridDelta))
+
+    priorPMF = PointMassState(state_vector=StateVectors(predGrid),
+                              weight=predDensityProb,
+                              grid_delta=predGridDelta,
+                              grid_dim=gridDimOld,
+                              center=xOld,
+                              eigVec=Ppold,
+                              Npa=Npa,
+                              timestamp=start_time)
+
+    F = transition_model.matrix(prior=prior, time_interval=time_difference)
+    Q = transition_model.covar(time_interval=time_difference)
+
+    FqF = np.linalg.inv(F)@Q@np.linalg.inv(F.T)
+
+    matrixPMF = []
+
+    start_time = time.time()
+    track = Track()
+    for measurement in measurements:
+        prediction = pmfPredictor.predict(priorPMF, timestamp=measurement.timestamp)
+        hypothesis = SingleHypothesis(prediction, measurement)
+        post = pmfUpdater.update(hypothesis)
+        priorPMF = post
+
+        matrixPMF.append(post.mean)
+        # print(post.mean)
+
+    # Record the end time
+    end_time = time.time()
+
+    # Calculate the elapsed time
+    timePMF.append(end_time - start_time)
+
+    # %%
+    # Run the tracker
+    # ^^^^^^^^^^^^^^^
+    # We now run the predict and update steps, propagating the collection of
+    # particles and resampling when told to (at every step).
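+
+    # The ESSResampler above resamples only when the effective sample size
+    # N_eff = 1 / sum(w_i^2) drops below a threshold; a minimal sketch of that
+    # quantity (for the uniform prior weights it equals number_particles):
+    _w = np.array([float(w) for w in prior.weight])
+    _n_eff = 1 / np.sum(_w ** 2)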
+
+    matrixPF = []
+    start_time = time.time()
+    track = Track()
+    for measurement in measurements:
+        prediction = predictor.predict(prior, timestamp=measurement.timestamp)
+        hypothesis = SingleHypothesis(prediction, measurement)
+        post = updater.update(hypothesis)
+        # print(post.mean)
+        track.append(post)
+        matrixPF.append(post.mean)
+        prior = track[-1]
+
+    # Record the end time
+    end_time = time.time()
+
+    # Calculate the elapsed time
+    timePF.append(end_time - start_time)
+
+    # Collect the per-step estimation errors of this Monte Carlo run
+    for ind in range(0, kf):
+        matrixTruePF.append(np.ravel(matrixPF[ind]-truth.states[ind].state_vector))
+        # print(np.vstack(matrixPF[ind]))
+
+    for ind in range(0, kf):
+        matrixTruePMF.append(np.ravel(np.vstack(matrixPMF[ind])-truth.states[ind].state_vector))
+        # print(np.vstack(matrixPMF[ind]))
+
+    # for ind in range(0, kf):
+    #     print(truth.states[ind].state_vector)
+
+
+def rmse(errors):
+    """
+    Calculate the Root Mean Square Error (RMSE) from a list of errors.
+
+    Args:
+        errors (list): List of error vectors.
+
+    Returns:
+        float: scalar RMSE value over all state components.
+    """
+    # Convert the list of errors into a numpy array for easier computation
+    errors_array = np.array(errors)
+
+    # Square the errors
+    squared_errors = np.square(errors_array)
+
+    # Calculate the mean squared error over all components, steps and runs
+    mean_squared_error = np.mean(squared_errors)
+
+    # Calculate the root mean squared error
+    rmse_value = np.sqrt(mean_squared_error)
+
+    return rmse_value
+
+
+print('PF rmse', rmse(matrixTruePF))
+print('PMF rmse', rmse(matrixTruePMF))
+print('PF s', np.mean(timePF))
+print('PMF s', np.mean(timePMF))
diff --git a/functions.py b/functions.py
new file mode 100644
index 000000000..d7c2b9030
--- /dev/null
+++ b/functions.py
@@ -0,0 +1,199 @@
+import numpy as np
+import numpy.matlib  # needed for the np.matlib.repmat calls below
+from scipy.stats import mvn
+import itertools
+from numpy import linalg as LA
+import scipy.special as sciSpec
+from scipy.interpolate import RegularGridInterpolator
+from scipy.signal import fftconvolve
+
+
+def boxvertex(n, bound):
+    """Return the 2**n vertices of the box [-bound, bound] in n dimensions."""
+    bound = np.flipud(bound)
+    vertices = np.zeros((2**n, n))
+    for k in range(2**n):
+        for d in range(n):
+            if k & (1 << d):
+                vertices[k, d] = bound[d]
+            else:
+                vertices[k, d] = -bound[d]
+    return vertices
+
+
+def measPdfPrepFFT(measPdf, gridDimOld, predMeanEst, predVarEst, F, sFactor, nx, Npa, k):
+    """Set up the measurement grid and interpolate the filtering pdf onto it."""
+    # Setup the measurement grid
+    eigVal, eigVect = np.linalg.eig(predVarEst)  # eigenvalues/eigenvectors, for setting up the grid
+    gridBoundWant = np.sqrt(eigVal) * sFactor  # Wanted boundaries of pred grid
+    gridBoundWantCorners = np.dot(boxvertex(nx, gridBoundWant), eigVect.T).T + predMeanEst  # Wanted corners of predictive grid
+    gridBoundWantCorners = np.dot(np.linalg.inv(F), gridBoundWantCorners)  # Back to filtering space
+    maxF = np.max(gridBoundWantCorners, axis=1)  # Min/max measurement-grid corners
+    minF = np.min(gridBoundWantCorners, axis=1)
+    gridDim = []
+    gridStep = np.zeros((nx, 1))
+    for ind3 in range(nx):  # Create the filtering grid so that it yields the wanted predictive grid
+        gridDim.append(np.linspace(minF[ind3], maxF[ind3], Npa))
+        gridStep[ind3] = abs(gridDim[ind3][0] - gridDim[ind3][1])
+    measGridNew = np.array(np.meshgrid(*gridDim)).reshape(nx, -1, order='C')
+
+    GridDelta = gridStep  # Grid step size
+    GridDelta = np.squeeze(GridDelta)
+
+    # Interpolation
+    Fint = RegularGridInterpolator(gridDimOld, measPdf.reshape(Npa, Npa, order='F'),
+                                   method="linear", bounds_error=False, fill_value=0)
+    if k == 0:
+        filtGridInterpInvTrsf = measGridNew.T
+    else:
+        filtGridInterpInvTrsf = np.dot(np.linalg.inv(F), measGridNew).T
+    measPdf = Fint(filtGridInterpInvTrsf)
+
+    gridDimOld = gridDim
+
+    # # Unpack x, y coordinates from measGrid
+    # x_coords, y_coords = measGridNew
+
+    # # Plot the data as a scatter plot
+    # plt.figure()
+    # plt.scatter(x_coords, y_coords, c=measPdf, cmap='viridis')
+
+    return measPdf, gridDimOld, GridDelta, measGridNew
+
+
+def pmfUpdateFFT(F, measPdf, measGridNew, GridDelta, k, Npa, invQ, predDenDenomW, nx):
+    """Time-update (prediction) of the point-mass filter using FFT convolution."""
+    # Predictive grid
+    predGrid = np.dot(F, measGridNew)
+
+    # Grid step size
+    GridDelta[:, k+1] = np.dot(F, GridDelta[:, k])
+
+    # ULTRA FAST PMF
+    filtDenDOTprodDeltas = np.dot(measPdf, np.prod(GridDelta[:, k]))  # filtering pdf * grid step volume
+    filtDenDOTprodDeltasCub = np.reshape(filtDenDOTprodDeltas, (Npa, Npa), order='C')  # Into physical space
+
+    halfGrid = (np.ceil(predGrid.shape[1] / 2)-1).astype(int)
+
+    pom = np.transpose(predGrid[:, halfGrid][:, np.newaxis] - predGrid)  # Middle row of the TPM matrix
+    TPMrow = (np.exp(np.sum(-0.5 * pom @ invQ * pom, axis=1)) / predDenDenomW).reshape(1, -1, order='C')  # Middle row of the TPM matrix
+    TPMrowCubPom = np.reshape(TPMrow, (Npa, Npa), order='F')  # Into physical space
+
+    # Compute the convolution using scipy.signal.fftconvolve
+    convolution_result_complex = fftconvolve(filtDenDOTprodDeltasCub, TPMrowCubPom, mode='same')
+
+    # Take the real part of the convolution result to get a real-valued result
+    convolution_result_real = np.real(convolution_result_complex).T
+
+    predDensityProb = np.reshape(convolution_result_real, (-1, 1), order='F')
+    predDensityProb = predDensityProb / (np.sum(predDensityProb) * np.prod(GridDelta[:, k+1]))  # Normalization (theoretically not needed)
+
+    return predDensityProb, predGrid, GridDelta
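+
+
+# A minimal 1-D sketch (numbers made up) of the trick pmfUpdateFFT relies on:
+# on a regular grid a symmetric transition kernel depends only on the point
+# difference, so the dense TPM-times-density product equals a convolution with
+# the kernel's middle row (index ceil(N/2)-1, as above), which fftconvolve
+# evaluates in O(N log N). The kernel here is narrow enough that truncation
+# at the grid edge is negligible.
+_g = np.linspace(-3, 3, 7)
+_dens = np.exp(-0.5 * _g ** 2)
+_kern = np.exp(-2.0 * (_g[:, None] - _g[None, :]) ** 2)  # Toeplitz TPM
+assert np.allclose(_kern @ _dens, fftconvolve(_dens, _kern[3], mode='same'))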
+
+
+def ukfUpdate(measVar, nx, kappa, measMean, ffunct, k, Q):
+    """UKF prediction used for grid placement."""
+    S = np.linalg.cholesky(measVar)  # lower Cholesky factor
+    decomp = np.sqrt(nx+kappa)*S
+    rep = np.matlib.repmat(measMean.T, 2*nx, 1).T + np.c_[decomp, -decomp]  # concatenate
+    chi = np.c_[measMean, rep]
+    wUKF = np.array(np.c_[kappa, np.matlib.repmat(0.5, 1, 2*nx)])/(nx+kappa)  # weights
+
+    Y = ffunct(chi, np.zeros((nx, 1)), k)
+    xp_aux = Y @ wUKF.T
+    Ydiff = Y - xp_aux
+    Pp_aux = np.multiply(Ydiff, np.matlib.repmat(wUKF, nx, 1))@Ydiff.T + Q.T  # UKF predictive variance
+    return xp_aux, Pp_aux
+
+
+def gridCreationFFT(xp_aux, Pp_aux, sFactor, nx, Npa):
+    """Create an axis-aligned grid centred on the predictive mean."""
+    # Boundaries of grid
+    gridBound = np.sqrt(np.diag(Pp_aux)) * sFactor
+
+    # Creation of propagated grid
+    gridDim = []
+    gridStep = np.zeros((nx, 1))
+    for ind3 in range(nx):
+        gridDim.append(np.linspace(-gridBound[ind3], gridBound[ind3], Npa) + xp_aux[ind3])
+        gridStep[ind3] = abs(gridDim[ind3][0] - gridDim[ind3][1])
+
+    # Grid centred on the computed unscented mean
+    predGrid = np.array(np.meshgrid(*gridDim)).reshape(nx, -1, order='C')
+
+    # Grid step size
+    predGridDelta = np.squeeze(gridStep)
+
+    return predGrid, gridDim, predGridDelta
+
+
+def gridCreation(xp_aux, Pp_aux, sFactor, nx, Npa):
+    """Create a grid rotated by the covariance eigenvectors."""
+    gridDim = np.zeros((nx, Npa))
+    gridStep = np.zeros(nx)
+    eigVal, eigVect = LA.eig(Pp_aux)  # eigenvalues and eigenvectors for setting up the grid
+    gridBound = np.sqrt(eigVal)*sFactor  # Boundaries of grid
+
+    for ind3 in range(0, nx):  # Creation of propagated grid
+        gridDim[ind3] = np.linspace(-gridBound[ind3], gridBound[ind3], Npa)  # New grid centred on 0
+        gridStep[ind3] = np.absolute(gridDim[ind3][0] - gridDim[ind3][1])  # Grid step
+
+    combvec_predGrid = np.array(list(itertools.product(*gridDim)))
+    predGrid_pom = np.dot(combvec_predGrid, eigVect).T
+    size_pom = np.size(predGrid_pom, 1)
+    # Grid rotation by eigenvectors and translation to the computed unscented mean
+    predGrid = predGrid_pom + np.matlib.repmat(xp_aux, 1, size_pom)
+    predGridDelta = gridStep  # Grid step size
+    return predGrid, predGridDelta
+
+
+def pmfMeas(predGrid, nz, k, z, invR, predDenDenomV, predDensityProb, predGridDelta, hfunct):
+    """Measurement (Bayes) update of the point-mass filter."""
+    predThrMeasEq = hfunct(predGrid, np.zeros((nz, 1)), k+1)  # Predictive grid through measurement equation
+    pom = np.matlib.repmat(z, np.size(predThrMeasEq.T, 0), 1) - predThrMeasEq.T  # Measurement - measurementEQ(grid)
+    citatel = np.exp(np.sum(-0.5*np.multiply(pom @ invR, pom), 1))  # likelihood numerator
+    filterDensityNoNorm = np.multiply(citatel / predDenDenomV, predDensityProb.T)
+    filterDensityNoNorm = filterDensityNoNorm.T
+    measPdf = (filterDensityNoNorm / np.sum(np.prod(predGridDelta)*filterDensityNoNorm, 0))
+    return measPdf
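+
+
+# A minimal 1-D sketch (numbers made up) of the Bayes update pmfMeas performs:
+# multiply the predictive grid density by the measurement likelihood and
+# renormalise so that sum(density) * grid_step == 1.
+_g1 = np.linspace(-2, 2, 5)                 # grid points
+_step = 1.0                                 # grid step
+_prior1 = np.ones(5) / (5 * _step)          # flat predictive density on the grid
+_lik1 = np.exp(-0.5 * (1.0 - _g1) ** 2)     # N(z=1; h(x)=x, R=1), up to a constant
+_post1 = _prior1 * _lik1
+_post1 = _post1 / (np.sum(_post1) * _step)  # same normalisation as pmfMeas
+assert np.isclose(np.sum(_post1) * _step, 1)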
+
+
+def pmfUpdateSTD(measGrid, measPdf, predGridDelta, ffunct, predGrid, nx, k, invQ, predDenDenomW, N):
+    """Time-update of the point-mass filter by direct summation (standard form)."""
+    fitDenDOTprodDeltas = measPdf*np.prod(predGridDelta[:, k])  # filtering pdf * grid step volume
+    gridNext = ffunct(measGrid, np.zeros((nx, 1)), k+1)  # Old grid through dynamics
+
+    predDensityProb = np.zeros((N, 1))
+    for ind2 in range(0, N):  # Over the points of the predictive grid
+        pom = (predGrid[:, ind2].T-(gridNext).T)
+        suma = np.sum(-0.5*np.multiply(pom@invQ, pom), 1)
+        predDensityProb[ind2, 0] = (np.exp(suma)/predDenDenomW).T@fitDenDOTprodDeltas
+    predDensityProb = predDensityProb/(np.sum(predDensityProb)*np.prod(predGridDelta[:, k+1]))  # Normalization (theoretically not needed)
+    return predDensityProb
+
+
+def pmfUpdateDWC(invF, predGrid, measGrid, predGridDelta, Qa, cnormHere, measPdf, N, k):
+    """Time-update with distribution-weighted convolution (DWC)."""
+    predDensityProb = np.zeros((N, N))
+    for i in range(0, N):  # Unnecessary for-cycle, kept for clearer understanding
+        ma = invF @ predGrid[:, i]
+        for n in range(0, N):
+            lowerBound = np.array([measGrid[:, n]-predGridDelta[:, k]/2]).T  # boundary of rectangular region M
+            upperBound = np.array([measGrid[:, n]+predGridDelta[:, k]/2]).T
+            cdfAct = mvn.mvnun(lowerBound, upperBound, ma, Qa)[0]  # Integral calculation
+            predDensityProb[i, n] = cnormHere*cdfAct*measPdf[n]  # Predictive density
+    predDensityProb = np.sum(predDensityProb, 1)
+    predDensityProb = predDensityProb/(np.sum(predDensityProb)*np.prod(predGridDelta[:, k+1]))  # Normalization (theoretically not needed)
+    return predDensityProb
+
+
+def pmfUpdateDiagDWC(measGrid, N, predGridDelta, F, predGrid, Q, s2, measPdf, normDiagDWC, k):
+    """Time-update with DWC for a diagonal noise covariance."""
+    predDensityProb = np.zeros((N, N))
+    bound = np.zeros((np.size(measGrid, 0), np.size(measGrid, 0)))  # boundary of rectangular region M
+
+    for i in range(0, N):
+        for n in range(0, N):
+            bound = np.array([measGrid[:, n]-predGridDelta[:, k]/2,
+                              measGrid[:, n]+predGridDelta[:, k]/2]).T
+            pom = np.array([np.divide(-F@bound[:, 0] + predGrid[:, i], np.sqrt(np.diag(Q))),
+                            np.divide(-F@bound[:, 1] + predGrid[:, i], np.sqrt(np.diag(Q)))]).T  # FIXME: the dimensions do not match!
+            erfAct = np.prod((0.5 - 0.5*sciSpec.erf(pom[:, 1]/s2)) - (0.5 - 0.5*sciSpec.erf(pom[:, 0]/s2)))
+            predDensityProb[i, n] = measPdf[n] * normDiagDWC * erfAct  # Predictive density
+    predDensityProb = np.sum(predDensityProb, 1)
+    predDensityProb = predDensityProb/(np.sum(predDensityProb)*np.prod(predGridDelta[:, k+1]))  # Normalization (theoretically not needed)
+    return predDensityProb
diff --git a/pokus1.py b/pokus1.py
new file mode 100644
index 000000000..1f2bca21c
--- /dev/null
+++ b/pokus1.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Tue Apr 2 08:53:07 2024
+
+@author: matoujak
+"""
+
+from stonesoup.types.state import PointMassState
+from stonesoup.functions import gridCreationFFT
+import numpy as np
+from numpy.linalg import inv
+from stonesoup.types.array import StateVectors
+from datetime import datetime
+import numpy.matlib
+start_time = datetime.now().replace(microsecond=0)
+
+
+# Initial condition - Gaussian
+nx = 2
+meanX0 = np.array([20, 5])  # mean value
+varX0 = np.array([[0.1, 0], [0, 0.1]])  # variance
+Npa = np.array([21, 21])  # number of points per axis; must be odd for the FFT
+N = np.prod(Npa)  # total number of points
+sFactor = 4  # scaling factor (number of sigmas covered by the grid)
+
+
+[predGrid, gridDimOld, predGridDelta] = gridCreationFFT(np.vstack(meanX0), varX0,
+                                                        sFactor, nx, Npa)
+
+# Evaluate the Gaussian prior pdf on the grid
+meanX0 = np.vstack(meanX0)
+pom = predGrid - np.matlib.repmat(meanX0, 1, N)
+denominator = np.sqrt((2*np.pi)**nx * np.linalg.det(varX0))
+pompom = np.sum(-0.5*np.multiply(pom.T@inv(varX0), pom.T), 1)  # elementwise multiplication
+pomexp = np.exp(pompom)
+predDensityProb = pomexp/denominator  # Gaussian pdf values at the grid points
+predDensityProb = predDensityProb/(sum(predDensityProb)*np.prod(predGridDelta))
+
+# Grid approximation of the mean: sum_i x_i * p_i * prod(grid steps)
+np.hstack(predGrid @ predDensityProb * np.prod(predGridDelta))
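+
+# Illustrative check (tolerance made up): the grid mean above approximates
+# E[x] = int x p(x) dx; for this symmetric Gaussian grid it reproduces meanX0
+# almost exactly.
+assert np.allclose(np.hstack(predGrid @ predDensityProb * np.prod(predGridDelta)),
+                   meanX0.ravel(), atol=1e-6)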
+
+prior = PointMassState(state_vector=StateVectors(predGrid),
+                       weight=predDensityProb,
+                       grid_delta=predGridDelta,
+                       timestamp=start_time)
+
+a = prior.covar
\ No newline at end of file
diff --git a/stonesoup/functions/__init__.py b/stonesoup/functions/__init__.py
index b7441da72..db7a1e1c9 100644
--- a/stonesoup/functions/__init__.py
+++ b/stonesoup/functions/__init__.py
@@ -7,6 +7,43 @@
 from ..types.numeric import Probability
 from ..types.array import StateVector, StateVectors, CovarianceMatrix
 from ..types.state import State
+import itertools
+from numpy import linalg as LA
+from numpy import matlib as matlib
+
+
+def gridCreation(xp_aux, Pp_aux, sFactor, nx, Npa):
+    """Create a point-mass grid rotated by the covariance eigenvectors."""
+    eigVal, eigVect = LA.eig(Pp_aux)  # eigenvalues and eigenvectors for setting up the grid
+    gridBound = np.sqrt(eigVal)*sFactor  # Boundaries of grid
+
+    # Ensure the grid steps are in the right order
+    I = np.argsort(np.diag(Pp_aux))
+    I = np.argsort(I)
+
+    pom = np.sort(gridBound)
+    gridBound = pom[I]
+
+    Ipom = np.argsort(gridBound)
+    pom2 = eigVect[:, Ipom]
+    eigVect = pom2[:, I]
+
+    gridDim = []  # 1-D grid per axis
+    gridStep = []  # grid step per axis
+    for ind3 in range(0, nx):  # Creation of propagated grid
+        gridDim.append(np.linspace(-gridBound[ind3], gridBound[ind3], Npa[ind3]))  # New grid centred on 0
+        gridStep.append(np.absolute(gridDim[ind3][0] - gridDim[ind3][1]))  # Grid step
+
+    combvec_predGrid = np.array(list(itertools.product(*gridDim)))
+    predGrid_pom = np.dot(eigVect, combvec_predGrid.T)
+    size_pom = np.size(predGrid_pom, 1)
+    # Grid rotation by eigenvectors and translation to the computed mean
+    predGrid = predGrid_pom + matlib.repmat(xp_aux, 1, size_pom)
+    predGridDelta = gridStep  # Grid step size
+    return predGrid, predGridDelta, gridDim, xp_aux, eigVect
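+
+
+# Example use of gridCreation (illustrative; argument values are made up):
+#   grid, delta, dims, centre, eigvec = gridCreation(
+#       np.zeros((2, 1)), np.eye(2), 4, 2, np.array([21, 21]))
+# This returns a (2, 441) array of grid points covering +/-4 sigma around the
+# mean, along with the per-axis grid steps, the per-axis 1-D grids, the grid
+# centre, and the eigenvector matrix used to rotate the grid.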
 
 
 def tria(matrix):
diff --git a/stonesoup/models/measurement/nonlinear.py b/stonesoup/models/measurement/nonlinear.py
index be2f63590..f129a297b 100644
--- a/stonesoup/models/measurement/nonlinear.py
+++ b/stonesoup/models/measurement/nonlinear.py
@@ -17,6 +17,92 @@
 from ...types.angle import Bearing, Elevation, Azimuth
 from ..base import LinearModel, GaussianModel, ReversibleModel
 from .base import MeasurementModel
+from ...types.state import State
+
+
+class TerrainAidedNavigation:
+    """Measurement model for terrain aided navigation (TAN): maps a state's
+    horizontal position through a terrain-map interpolator to a terrain-height
+    measurement with additive Gaussian noise."""
+
+    def __init__(self, interpolator, noise_covar, mapping):
+        self.interpolator = interpolator
+        self.noise_covar = noise_covar
+        self.mapping = mapping
+
+    @property
+    def ndim_meas(self) -> int:
+        """ndim_meas getter method
+
+        Returns
+        -------
+        :class:`int`
+            The number of measurement dimensions
+        """
+        return 1
+
+    def function(self, state, noise=False, **kwargs) -> StateVector:
+        out = self.interpolator(state.state_vector[self.mapping, :].T)
+        if isinstance(noise, bool) or noise is None:
+            if noise:
+                # noise_covar holds a variance, so sample with its square root
+                noise = np.random.normal(0., np.sqrt(self.noise_covar), out.size)
+            else:
+                noise = 0.
+
+        # Return the interpolated measurements with added noise
+        return out + noise
+
+    def covar(self, **kwargs) -> CovarianceMatrix:
+        """Returns the measurement model noise covariance matrix.
+
+        Returns
+        -------
+        :class:`~.CovarianceMatrix` of shape\
+        (:py:attr:`~ndim_meas`, :py:attr:`~ndim_meas`)
+            The measurement noise covariance.
+        """
+
+        return self.noise_covar
+
+    def logpdf(self, state1: State, state2: State, **kwargs) -> Union[float, np.ndarray]:
+        r"""Model log pdf/likelihood evaluation function
+
+        Evaluates the pdf/likelihood of ``state1``, given the state
+        ``state2`` which is passed to :meth:`function()`.
+
+        In mathematical terms, this can be written as:
+
+        .. math::
+
+            p = p(y_t | x_t) = \mathcal{N}(y_t; h(x_t), R)
+
+        where :math:`y_t` = ``state_vector1``, :math:`x_t` = ``state_vector2``
+        and :math:`R` = :attr:`covar`.
+
+        Parameters
+        ----------
+        state1 : State
+        state2 : State
+
+        Returns
+        -------
+        : float or :class:`~.numpy.ndarray`
+            The log likelihood of ``state1``, given ``state2``
+        """
+        covar = self.covar(**kwargs)
+
+        # Calculate difference before to handle custom types (mean defaults to zero)
+        # This is required as log pdf converts arrays to floats
+        likelihood = np.atleast_1d(
+            multivariate_normal.logpdf((state1.state_vector - self.function(state2, **kwargs)).T,
+                                       cov=covar))
+
+        if len(likelihood) == 1:
+            likelihood = likelihood[0]
+
+        return likelihood
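+
+
+# Example use of TerrainAidedNavigation (illustrative; the map arrays and
+# state are assumptions, not part of this module):
+#   interp = RegularGridInterpolator((map_x[0, :], map_y[:, 0]), map_z)
+#   model = TerrainAidedNavigation(interp, noise_covar=1, mapping=(0, 2))
+#   height = model.function(state, noise=False)  # terrain height under state
+#   ll = model.logpdf(Detection(height), state)  # measurement log-likelihood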
 
 
 class CombinedReversibleGaussianMeasurementModel(ReversibleModel, GaussianModel, MeasurementModel):
@@ -279,6 +365,8 @@ def rvs(self, num_samples=1, **kwargs) -> Union[StateVector, StateVectors]:
         out = super().rvs(num_samples, **kwargs)
         out = np.array([[Elevation(0.)], [Bearing(0.)], [0.]]) + out
         return out
+
+
 class CartesianToBearingRange(NonLinearGaussianMeasurement, ReversibleModel):
@@ -424,6 +512,7 @@ def rvs(self, num_samples=1, **kwargs) -> Union[StateVector, StateVectors]:
         return out
 
+
 class CartesianToElevationBearing(NonLinearGaussianMeasurement):
     r"""This is a class implementation of a time-invariant measurement model, \
     where measurements are assumed to be received in the form of bearing \
@@ -942,7 +1031,9 @@ def inverse_function(self, detection, **kwargs) -> StateVector:
         x, y, z = sphere2cart(rho, phi, theta)
         # because only rho_rate is known, only the components in
         # x, y and z of the range rate can be found.
-        x_rate, y_rate, z_rate = sphere2cart(rho_rate, phi, theta)
+        x_rate = np.cos(phi) * np.cos(theta) * rho_rate
+        y_rate = np.cos(phi) * np.sin(theta) * rho_rate
+        z_rate = np.sin(phi) * rho_rate
 
         inv_rotation_matrix = inv(self.rotation_matrix)
 
@@ -955,8 +1046,6 @@ def inverse_function(self, detection, **kwargs) -> StateVector:
             inv_rotation_matrix @ out_vector[self.velocity_mapping, :]
 
         out_vector[self.mapping, :] = out_vector[self.mapping, :] + self.translation_offset
-        out_vector[self.velocity_mapping, :] = out_vector[self.velocity_mapping, :] \
-            + self.velocity
 
         return out_vector
diff --git a/stonesoup/models/measurement/tests/test_models.py b/stonesoup/models/measurement/tests/test_models.py
index 25c537e13..293a9f922 100644
--- a/stonesoup/models/measurement/tests/test_models.py
+++ b/stonesoup/models/measurement/tests/test_models.py
@@ -558,7 +558,7 @@ def h3d_rr(state_vector, pos_map, vel_map, translation_offset, rotation_offset,
     "h, modelclass, state_vec, ndim_state, pos_mapping, vel_mapping,\
     noise_covar, position, orientation",
     [
-        (   # rrRB_1. 3D meas, 6D state
+        (   # 3D meas, 6D state
            h2d_rr,  # h
            CartesianToBearingRangeRate,  # ModelClass
            StateVector([[200.], [10.], [0.], [0.], [0.], [0.]]),  # state_vec
@@ -571,7 +571,7 @@ def h3d_rr(state_vector, pos_map, vel_map, translation_offset, rotation_offset,
            StateVector([[1], [-1], [0]]),  # position (translation offset)
            StateVector([[0], [0], [1]])  # orientation (rotation offset)
         ),
-        (   # rrRB_2. 3D meas, 6D state
+        (   # 3D meas, 6D state
            h2d_rr,  # h
           CartesianToBearingRangeRate,  # ModelClass
            StateVector([[200.], [10.], [0.], [0.], [0.], [0.]]),  # state_vec
@@ -584,7 +584,7 @@ def h3d_rr(state_vector, pos_map, vel_map, translation_offset, rotation_offset,
            None,  # position (translation offset)
            None  # orientation (rotation offset)
         ),
-        (   # rrRBE_1, 4D meas, 6D state
+        (   # 4D meas, 6D state
            h3d_rr,  # h
            CartesianToElevationBearingRangeRate,  # ModelClass
            StateVector([[200.], [10.], [0.], [0.], [0.], [0.]]),  # state_vec
@@ -598,7 +598,7 @@ def h3d_rr(state_vector, pos_map, vel_map, translation_offset, rotation_offset,
            StateVector([[100], [0], [0]]),  # position (translation offset)
            StateVector([[0], [0], [0]])  # orientation (rotation offset)
         ),
-        (   # rrRBE_2. 4D meas, 6D state
+        (   # 4D meas, 6D state
            h3d_rr,  # h
            CartesianToElevationBearingRangeRate,  # ModelClass
            StateVector([[200.], [10.], [0.], [0.], [0.], [0.]]),  # state_vec
@@ -611,37 +611,9 @@ def h3d_rr(state_vector, pos_map, vel_map, translation_offset, rotation_offset,
                              [0, 0, 0, 10]]),  # noise_covar
            None,  # position (translation offset)
            None  # orientation (rotation offset)
-        ),
-        (   # rrRBE_3. 4D meas, 6D state. Changed orientation compared to rrRBE_2.
-            h3d_rr,  # h
-            CartesianToElevationBearingRangeRate,  # ModelClass
-            StateVector([[200.], [10.], [0.], [0.], [0.], [0.]]),  # state_vec
-            6,  # ndim_state
-            np.array([0, 2, 4]),  # pos_mapping
-            np.array([1, 3, 5]),  # vel_mapping
-            CovarianceMatrix([[0.05, 0, 0, 0],
-                              [0, 0.05, 0, 0],
-                              [0, 0, 0.015, 0],
-                              [0, 0, 0, 10]]),  # noise_covar
-            None,  # position (translation offset)
-            StateVector([[0], [0], [np.pi / 2]])  # orientation (rotation offset)
-        ),
-        (   # rrRBE_4. 4D meas, 6D state. Range rate not alligned with any axis
-            h3d_rr,  # h
-            CartesianToElevationBearingRangeRate,  # ModelClass
-            StateVector([[300.], [30.], [200.], [20.], [10.], [1.]]),  # state_vec
-            6,  # ndim_state
-            np.array([0, 2, 4]),  # pos_mapping
-            np.array([1, 3, 5]),  # vel_mapping
-            CovarianceMatrix([[0.05, 0, 0, 0],
-                              [0, 0.05, 0, 0],
-                              [0, 0, 0.015, 0],
-                              [0, 0, 0, 10]]),  # noise_covar
-            StateVector([[0], [0], [0]]),  # position (translation offset)
-            StateVector([[0], [0], [np.pi / 2]])  # orientation (rotation offset), Facing North
         )
     ],
-    ids=["rrRB_1", "rrRB_2", "rrRBE_1", "rrRBE_2", "rrRBE_3", "rrRBE_4"]
+    ids=["rrRB_1", "rrRB_2", "rrRBE_1", "rrRBE_2"]
 )
 def test_rangeratemodels(h, modelclass, state_vec, ndim_state, pos_mapping, vel_mapping,
                          noise_covar, position, orientation):
diff --git a/stonesoup/movable/grid.py b/stonesoup/movable/grid.py
index 3a7c3fe4f..9428eac04 100644
--- a/stonesoup/movable/grid.py
+++ b/stonesoup/movable/grid.py
@@ -34,6 +34,7 @@ class _GridActionableMovable(FixedMovable):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._next_action = None
+        self._generator_kwargs = type(self)._generator_kwargs
 
     def actions(self, timestamp, start_timestamp=None):
         """Method to return a set of grid action generators available up to a provided timestamp.
@@ -58,7 +59,7 @@ def actions(self, timestamp, start_timestamp=None):
             attribute="position",
             start_time=start_timestamp,
             end_time=timestamp,
-            **{name: getattr(self, name) for name in type(self)._generator_kwargs}))
+            **{name: getattr(self, name) for name in self._generator_kwargs}))
 
         return generators
 
@@ -87,7 +88,8 @@ class NStepDirectionalGridMovable(_GridActionableMovable):
     of each axis.
 
     This movable implements the :class:`~.NStepDirectionalGridActionGenerator`"""
     generator = NStepDirectionalGridActionGenerator
-    _generator_kwargs = _GridActionableMovable._generator_kwargs | {'n_steps', 'step_size'}
+    _generator_kwargs = set(_GridActionableMovable._generator_kwargs)
+    _generator_kwargs.update({'n_steps', 'step_size'})
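+    # Note: set.update() mutates in place and returns None, so the union must
+    # not be assigned from its return value; the two-step form above keeps
+    # _generator_kwargs a set.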
 
     n_steps: int = Property(
         default=1,
diff --git a/stonesoup/platform/base.py b/stonesoup/platform/base.py
index 38bfc03f7..d58a7e444 100644
--- a/stonesoup/platform/base.py
+++ b/stonesoup/platform/base.py
@@ -1,10 +1,8 @@
-import uuid
 from typing import MutableSequence
 
 from stonesoup.base import Property, Base
 from stonesoup.movable import Movable, FixedMovable, MovingMovable, MultiTransitionMovable
 from stonesoup.sensor.sensor import Sensor
-from stonesoup.types.groundtruth import GroundTruthPath
 
 
 class Platform(Base):
@@ -40,10 +38,6 @@ class Platform(Base):
         default=None, readonly=True,
         doc="A list of N mounted sensors. Defaults to an empty list.")
-    id: str = Property(
-        default=None,
-        doc="The unique platform ID. Default `None` where random UUID is generated.")
-
     _default_movable_class = None  # Will be overridden by subclasses
 
     def __getattribute__(self, name):
@@ -85,8 +79,6 @@ def __init__(self, *args, **kwargs):
         self._property_sensors = []
         for sensor in self.sensors:
             sensor.movement_controller = self.movement_controller
-        if self.id is None:
-            self.id = str(uuid.uuid4())
 
     @staticmethod
     def _tuple_or_none(value):
@@ -158,26 +150,6 @@ def orientation(self, value):
 
     def __getitem__(self, item):
         return self.movement_controller.__getitem__(item)
-
-    @property
-    def ground_truth_path(self) -> GroundTruthPath:
-        """ Produce a :class:`.GroundTruthPath` with the same `id` and `states` as the platform.
-
-        The `states` property for the platform and `ground_truth_path` are dynamically linked:
-        ``self.ground_truth_path.states is self.states``
-
-        So after `platform.move()` the `ground_truth_path` will contain the new state. However,
-        replacing the `id`, `states` or `movement_controller` variables in either the platform or
-        ground truth path will not be reflected in the other object.
-        ``platform_gtp = self.ground_truth_path``
-        ``platform_gtp.states = []``
-        ``self.states is not platform_gtp.states``
-
-        `Platform.ground_truth_path` produces a new :class:`.GroundTruthPath` on every instance.
-        It is not an object that is updated
-        ``self.ground_truth_path.states is not self.ground_truth_path.states``
-        """
-        return GroundTruthPath(id=self.id, states=self.movement_controller.states)
-
 
 class FixedPlatform(Platform):
     _default_movable_class = FixedMovable
diff --git a/stonesoup/platform/tests/test_platform_base.py b/stonesoup/platform/tests/test_platform_base.py
index 9723f1f99..4ac5675e7 100644
--- a/stonesoup/platform/tests/test_platform_base.py
+++ b/stonesoup/platform/tests/test_platform_base.py
@@ -12,7 +12,6 @@
 from stonesoup.types.array import StateVector
 from stonesoup.platform import MovingPlatform, FixedPlatform, MultiTransitionMovingPlatform
 from ...types.state import State
-from ...types.groundtruth import GroundTruthPath
 
 
 def test_base():
@@ -700,125 +699,3 @@ def test_platform_getitem():
     state_after = platform.state
     assert platform[0] is state_before
     assert platform[1] is state_after
-
-
-def test_ground_truth_path():
-    timestamp = datetime.datetime.now()
-    state_before = State(np.array([[2], [1], [2], [1], [0], [1]]), timestamp)
-    cv_model = CombinedLinearGaussianTransitionModel((ConstantVelocity(0),
-                                                      ConstantVelocity(0),
-                                                      ConstantVelocity(0)))
-    platform = MovingPlatform(states=state_before,
-                              transition_model=cv_model,
-                              position_mapping=[0, 2, 4], velocity_mapping=[1, 3, 5])
-
-    platform_gtp: GroundTruthPath = platform.ground_truth_path
-
-    # Test the id and states match
-    assert platform.id == platform_gtp.id
-    assert platform.state is platform_gtp.state
-
-    # Test the platform states are dynamically linked to the ground truth path state
-    platform.move(timestamp + datetime.timedelta(seconds=1))
-    assert platform.state is platform_gtp.state
-    assert platform.states is platform_gtp.states
-
-
-def test_default_platform_id():
-    fixed_state = State(np.array([[2], [2], [0]]), datetime.datetime.now())
-    platform1 = FixedPlatform(states=fixed_state, position_mapping=(0, 1, 2))
-
-    # Test `id` is string
-    assert isinstance(platform1.id, str)
-
-    # Test string is suitable long to avoid conflict
-    assert len(platform1.id) >= 32
-
-    # Test id isn't replicated
-    platform2 = FixedPlatform(states=fixed_state, position_mapping=(0, 1, 2))
-    assert platform1.id != platform2
-
-
-test_id_str = "hello"
-
-
-def test_platform_id_assignment():
-    fixed_state = State(np.array([[2], [2], [0]]), datetime.datetime.now())
-    fixed_platform = FixedPlatform(states=fixed_state, position_mapping=(0, 1, 2))
-    fixed_platform.id = test_id_str
-    assert fixed_platform.id == test_id_str
-
-
-def test_platform_initialised_with_id():
-    fixed_state = State(np.array([[2], [2], [0]]), datetime.datetime.now())
-    platform = FixedPlatform(id=test_id_str, states=fixed_state, position_mapping=(0, 1, 2))
-    assert platform.id == test_id_str
-
-
-def test_ground_truth_path_story():
-    # Set Up
-    timestamp = datetime.datetime.now()
-    state_before = State(np.array([[2], [1], [2], [1], [0], [1]]), timestamp)
-    cv_model = CombinedLinearGaussianTransitionModel((ConstantVelocity(0),
-                                                      ConstantVelocity(0),
-                                                      ConstantVelocity(0)))
-    platform = MovingPlatform(states=state_before,
-                              transition_model=cv_model,
-                              position_mapping=[0, 2, 4], velocity_mapping=[1, 3, 5])
-
-    def are_equal(gtp1: GroundTruthPath, gtp2: GroundTruthPath):
-        return gtp1.id == gtp2.id and gtp1.states == gtp2.states
-
-    # `Platform.ground_truth_path` produces a GroundTruthPath with `id` and `states` matching the
-    # platform.
-    platform_gtp = platform.ground_truth_path
-    assert isinstance(platform_gtp, GroundTruthPath)
-
-    # Generally `Platform.ground_truth_path` produces equal objects
-    assert are_equal(platform.ground_truth_path, platform_gtp)
-
-    # This is still true even if the platform moves. This is because they share the same `states`
-    # object
-    platform.move(timestamp + datetime.timedelta(seconds=1))
-    assert are_equal(platform.ground_truth_path, platform_gtp)
-
-    # However they are not the same object
-    assert platform.ground_truth_path is not platform_gtp
-
-    # Therefore changing the `id` will result in a different GroundTruthPath
-    platform.id = test_id_str
-    assert not are_equal(platform.ground_truth_path, platform_gtp)
-
-    # Reset the id. They should now be equal again
-    platform.id = platform_gtp.id
-    assert are_equal(platform.ground_truth_path, platform_gtp)
-
-    # They will remain equal if the `states` property is appended/altered
-    platform_gtp.states.append(State(np.array([[10], [9], [8], [7], [6], [5]]),
-                                     timestamp + datetime.timedelta(seconds=2)))
-    platform.move(timestamp + datetime.timedelta(seconds=3))
-    assert are_equal(platform.ground_truth_path, platform_gtp)
-
-    # However if the `states` property is replaced, they are no longer equal
-    platform_gtp.states = []
-    assert not are_equal(platform.ground_truth_path, platform_gtp)
-
-
-@pytest.mark.xfail
-def test_setting_movement_controller_sensors():
-    timestamp = datetime.datetime.now()
-    fixed_state = State(np.array([[2], [2], [0]]),
-                        timestamp)
-    fixed = FixedMovable(states=fixed_state, position_mapping=(0, 1, 2))
-    platform = MovingPlatform(movement_controller=fixed)
-
-    sensor = DummySensor()
-    platform.add_sensor(sensor)
-    assert platform.movement_controller is sensor.movement_controller
-
-    moving_state = State(np.array([[2], [1], [2], [-1], [2], [0]]), timestamp)
-    moving = MovingMovable(states=moving_state, position_mapping=(0, 2, 4), transition_model=None)
-
-    platform.movement_controller = moving
-
-    assert platform.movement_controller is sensor.movement_controller
diff --git a/stonesoup/plotter.py b/stonesoup/plotter.py
index 26ae3b07e..528597cce 100644
--- a/stonesoup/plotter.py
+++ b/stonesoup/plotter.py
@@ -1192,8 +1192,8 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,
                                    name=name, legendgroup=name, legendrank=210)
             if self.dimension == 3:  # update - star-triangle-up not in 3d plotly
-                clutter_kwargs.update(dict(marker=dict(size=4, symbol="diamond",
-                                                       color='#FECB52')))
+                measurement_kwargs.update(dict(marker=dict(size=4, symbol="diamond",
+                                                           color='#FECB52')))

             merge(clutter_kwargs, kwargs)
             if clutter_kwargs['legendgroup'] not in {trace.legendgroup
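The PointMassState added to state.py below represents the posterior as masses on a deterministic
grid; its `mean` and `covar` are Riemann sums, so the weights must integrate to one with respect
to the grid cell volume. A minimal standalone sanity check of that convention (illustrative
values only, plain NumPy rather than the class itself):

    import numpy as np

    # 2-D grid of N points with uniform spacing dx
    x, dx = np.linspace(-3, 3, 31, retstep=True)
    X, Y = np.meshgrid(x, x, indexing='ij')
    grid = np.vstack([X.ravel(), Y.ravel()])      # shape (2, N), like state_vector

    # weights: standard normal density on the grid, normalised as a density
    w = np.exp(-0.5 * np.sum(grid**2, axis=0)) / (2 * np.pi)
    w /= w.sum() * dx * dx                        # sum(w) * cell volume == 1

    mean = grid @ w * dx * dx                     # ~[0, 0], mirrors PointMassState.mean
    chip = grid - mean[:, np.newaxis]
    covar = (chip * w) @ chip.T * dx * dx         # ~identity, mirrors PointMassState.covar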
diff --git a/stonesoup/types/state.py b/stonesoup/types/state.py
index 9374d1ddb..26466d76a 100644
--- a/stonesoup/types/state.py
+++ b/stonesoup/types/state.py
@@ -165,7 +165,44 @@ def from_state(
         target_type = CreatableFromState.class_mapping[cls][state_type]

         return target_type.from_state(state, *args, **kwargs, target_type=target_type)
+
+
+class PointMassState(State):
+
+    state_vector: StateVectors = Property(doc='State vectors.')
+    weight: MutableSequence[Probability] = Property(default=None, doc='Masses of grid points')
+    grid_delta: np.ndarray = Property(default=None, doc='Grid step per dim')
+    grid_dim: np.ndarray = Property(default=None, doc='Grid coordinates per dimension before rotation and translation')
+    center: np.ndarray = Property(default=None, doc='Center of the grid')
+    eigVec: np.ndarray = Property(default=None, doc='Eigenvectors of the grid')
+    Npa: np.ndarray = Property(default=None, doc='Points per dim')
+
+    def __len__(self):
+        return self.state_vector.shape[1]
+
+    @property
+    def ndim(self):
+        """The number of dimensions represented by the state."""
+        return self.state_vector.shape[0]
+
+    @clearable_cached_property('state_vector')
+    def mean(self):
+        """Sample mean for particles"""
+        return np.hstack(self.state_vector @ self.weight * np.prod(self.grid_delta))
+
+    #@profile
+    def covar(self):
+        # Measurement update covariance
+        chip_ = self.state_vector - self.mean[:, np.newaxis]
+        chip_w = chip_ * self.weight.reshape(1, -1, order='C')
+        measVar = (chip_w @ chip_.T) * np.prod(self.grid_delta)
+        return measVar
+
+
+State.register(PointMassState)  # noqa: E305
+
+
 class ASDState(Type):
     """ASD State type
diff --git a/stonesoup/updater/kalman.py b/stonesoup/updater/kalman.py
index aa9f129e5..c2d6abb5f 100644
--- a/stonesoup/updater/kalman.py
+++ b/stonesoup/updater/kalman.py
@@ -70,10 +70,6 @@ class KalmanUpdater(Updater):
         default=False,
         doc="A flag to force the output covariance matrix to be symmetric by way of a simple "
             "geometric combination of the matrix and transpose. Default is False.")
-    use_joseph_cov: bool = Property(
-        default=False,
-        doc="Bool dictating the method of covariance calculation. If use_joseph_cov is True then "
-            "the Joseph form of the covariance equation is used.")

     def _measurement_matrix(self, predicted_state=None, measurement_model=None,
                             **kwargs):
@@ -189,38 +185,11 @@ def _posterior_covariance(self, hypothesis):
             The Kalman gain, :math:`K = P_{k|k-1} H_k^T S^{-1}`

         """
-        if self.use_joseph_cov:
-            # Identity matrix
-            id_matrix = np.identity(hypothesis.prediction.ndim)
-
-            # Calculate Kalman gain
-            kalman_gain = hypothesis.measurement_prediction.cross_covar @ \
-                np.linalg.inv(hypothesis.measurement_prediction.covar)
-
-            measurement_model = self._check_measurement_model(
-                hypothesis.measurement.measurement_model)
-
-            # Calculate measurement matrix/jacobian matrix
-            meas_matrix = self._measurement_matrix(hypothesis.prediction,
-                                                   measurement_model)
-
-            # Calculate Prior covariance
-            prior_covar = hypothesis.prediction.covar
-
-            # Calculate measurement covariance
-            meas_covar = measurement_model.covar()
-
-            # Compute posterior covariance matrix
-            I_KH = id_matrix - kalman_gain @ meas_matrix
-            post_cov = I_KH @ prior_covar @ I_KH.T \
-                + kalman_gain @ meas_covar @ kalman_gain.T
-
-        else:
-            kalman_gain = hypothesis.measurement_prediction.cross_covar @ \
-                np.linalg.inv(hypothesis.measurement_prediction.covar)
+        kalman_gain = hypothesis.measurement_prediction.cross_covar @ \
+            np.linalg.inv(hypothesis.measurement_prediction.covar)

-            post_cov = hypothesis.prediction.covar - kalman_gain @ \
-                hypothesis.measurement_prediction.covar @ kalman_gain.T
+        post_cov = hypothesis.prediction.covar - kalman_gain @ \
+            hypothesis.measurement_prediction.covar @ kalman_gain.T

         return post_cov.view(CovarianceMatrix), kalman_gain

diff --git a/stonesoup/updater/pointMass.py b/stonesoup/updater/pointMass.py
new file mode 100644
index 000000000..c218ba392
--- /dev/null
+++ b/stonesoup/updater/pointMass.py
@@ -0,0 +1,122 @@
+import copy
+from functools import lru_cache
+from typing import Callable
+import warnings
+
+import numpy as np
+from scipy.linalg import inv
+from scipy.special import logsumexp
+
+from .base import Updater
+from .kalman import KalmanUpdater, ExtendedKalmanUpdater
+from ..base import Property
+from ..functions import cholesky_eps, sde_euler_maruyama_integration
+from ..predictor.particle import MultiModelPredictor, RaoBlackwellisedMultiModelPredictor
+from ..resampler import Resampler
+from ..regulariser import Regulariser
+from ..types.prediction import (
+    Prediction, ParticleMeasurementPrediction, GaussianStatePrediction, MeasurementPrediction)
+from ..types.update import ParticleStateUpdate, Update
+from scipy.stats import multivariate_normal
+from stonesoup.types.state import PointMassState
+
+import matplotlib.pyplot as plt
+
+
+class PointMassUpdater(Updater):
+    """Point mass Updater
+
+    Perform an update by multiplying the grid point masses by the PDF of the
+    measurement model (either :attr:`~.Detection.measurement_model` or
+    :attr:`measurement_model`) evaluated at the detection, and normalising the
+    masses so that they again integrate to one over the grid.
+    """
+    sFactor: float = Property(
+        default=3,
+        doc="How many sigma to cover by the grid")
+    resampler: Resampler = Property(default=None, doc='Resampler to prevent particle degeneracy')
+    regulariser: Regulariser = Property(
+        default=None,
+        doc='Regulariser to prevent particle impoverishment. The regulariser '
+            'is normally used after resampling. If a :class:`~.Resampler` is defined, '
+            'then regularisation will only take place if the particles have been '
+            'resampled. If the :class:`~.Resampler` is not defined but a '
+            ':class:`~.Regulariser` is, then regularisation will be conducted under the '
+            'assumption that the user intends for this to occur.')
+
+    constraint_func: Callable = Property(
+        default=None,
+        doc="Callable, user defined function for applying "
+            "constraints to the states. This is done by setting the weights "
+            "of grid points to 0 for points that are not correctly constrained. "
+            "This function provides indices of the unconstrained points and "
+            "should accept a :class:`~.PointMassState` object and return an array-like "
+            "object of logical indices. "
+    )
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    #@profile
+    def update(self, hypothesis, **kwargs):
+        """Point mass filter update step
+
+        Parameters
+        ----------
+        hypothesis : :class:`~.Hypothesis`
+            Hypothesis with predicted state and associated detection used for
+            updating.
+
+        Returns
+        -------
+        : :class:`~.PointMassState`
+            The state posterior
+        """
+
+        predicted_state = Update.from_state(
+            state=hypothesis.prediction,
+            hypothesis=hypothesis,
+            timestamp=hypothesis.prediction.timestamp
+        )
+
+        if hypothesis.measurement.measurement_model is None:
+            measurement_model = self.measurement_model
+        else:
+            measurement_model = hypothesis.measurement.measurement_model
+
+        R = measurement_model.covar()
+
+        x = measurement_model.function(predicted_state)  # dominant cost of the update (per profiling)
+        pdf_value = multivariate_normal.pdf(x.T, np.ravel(hypothesis.measurement.state_vector), R)
+        new_weight = np.ravel(hypothesis.prediction.weight) * np.ravel(pdf_value)
+
+        new_weight = new_weight/(np.prod(hypothesis.prediction.grid_delta)*sum(new_weight))
+
+        predicted_state = PointMassState(state_vector=hypothesis.prediction.state_vector,
+                                         weight=new_weight,
+                                         grid_delta=hypothesis.prediction.grid_delta,
+                                         grid_dim=hypothesis.prediction.grid_dim,
+                                         center=hypothesis.prediction.center,
+                                         eigVec=hypothesis.prediction.eigVec,
+                                         Npa=hypothesis.prediction.Npa,
+                                         timestamp=hypothesis.prediction.timestamp)
+
+        return predicted_state
+
+    @lru_cache()
+    def predict_measurement(self, state_prediction, measurement_model=None,
+                            **kwargs):
+
+        if measurement_model is None:
+            measurement_model = self.measurement_model
+
+        new_state_vector = measurement_model.function(state_prediction, **kwargs)
+
+        return MeasurementPrediction.from_state(
+            state_prediction, state_vector=new_state_vector, timestamp=state_prediction.timestamp)
diff --git a/stonesoup/updater/recursive.py b/stonesoup/updater/recursive.py
index 4ba66d6ed..e0d84e204 100644
--- a/stonesoup/updater/recursive.py
+++ b/stonesoup/updater/recursive.py
@@ -19,6 +19,10 @@ class BayesianRecursiveUpdater(ExtendedKalmanUpdater):
     """
     number_steps: int = Property(doc="Number of recursive steps", default=1)
+    use_joseph_cov: bool = Property(doc="Bool dictating the method of covariance calculation.
If " + "use_joseph_cov is True then the Joseph form of the " + "covariance equation is used.", + default=False) @classmethod def _get_meas_cov_scale_factor(cls, n=1, step_no=None): @@ -80,24 +84,22 @@ def _posterior_covariance(self, hypothesis, scale_factor=1): kalman_gain = hypothesis.measurement_prediction.cross_covar @ \ np.linalg.inv(hypothesis.measurement_prediction.covar) - measurement_model = self._check_measurement_model( - hypothesis.measurement.measurement_model) - # Calculate measurement matrix/jacobian matrix - meas_matrix = self._measurement_matrix(hypothesis.prediction, - measurement_model) + meas_matrix = self._measurement_matrix(hypothesis.prediction) # Calculate Prior covariance prior_covar = hypothesis.prediction.covar # Calculate measurement covariance - meas_covar = measurement_model.covar() + meas_covar = hypothesis.measurement.measurement_model.covar() # Compute posterior covariance matrix I_KH = id_matrix - kalman_gain @ meas_matrix post_cov = I_KH @ prior_covar @ I_KH.T \ + kalman_gain @ (scale_factor * meas_covar) @ kalman_gain.T + return post_cov.view(CovarianceMatrix), kalman_gain + else: kalman_gain = hypothesis.measurement_prediction.cross_covar @ \ np.linalg.inv(hypothesis.measurement_prediction.covar) @@ -105,7 +107,7 @@ def _posterior_covariance(self, hypothesis, scale_factor=1): post_cov = hypothesis.prediction.covar - kalman_gain @ \ hypothesis.measurement_prediction.covar @ kalman_gain.T - return post_cov.view(CovarianceMatrix), kalman_gain + return post_cov.view(CovarianceMatrix), kalman_gain def update(self, hypothesis, **kwargs): r"""The Kalman update method. Given a hypothesised association between @@ -385,6 +387,10 @@ class VariableStepBayesianRecursiveUpdater(BayesianRecursiveUpdater): """ number_steps: int = Property(doc="Number of recursive steps", default=1) + use_joseph_cov: bool = Property(doc="Bool dictating the method of covariance calculation. 
If " + "use_joseph_cov is True then the Joseph form of the " + "covariance equation is used.", + default=False) @classmethod def _get_meas_cov_scale_factor(cls, n=1, step_no=None): diff --git a/stonesoup/updater/tests/test_kalman.py b/stonesoup/updater/tests/test_kalman.py index 6e0a848df..332ee7d87 100644 --- a/stonesoup/updater/tests/test_kalman.py +++ b/stonesoup/updater/tests/test_kalman.py @@ -18,24 +18,67 @@ CubatureKalmanUpdater) -@pytest.fixture(params=[KalmanUpdater, ExtendedKalmanUpdater, UnscentedKalmanUpdater, - IteratedKalmanUpdater, SchmidtKalmanUpdater, CubatureKalmanUpdater]) -def updater_class(request): - return request.param - - -@pytest.fixture(params=[True, False]) -def use_joseph_cov(request): - return request.param - - -def test_kalman(updater_class, use_joseph_cov): - measurement_model = LinearGaussian(ndim_state=2, mapping=[0], - noise_covar=np.array([[0.04]])) - prediction = GaussianStatePrediction(np.array([[-6.45], [0.7]]), - np.array([[4.1123, 0.0013], - [0.0013, 0.0365]])) - measurement = Detection(np.array([[-6.23]])) +@pytest.mark.parametrize( + "UpdaterClass, measurement_model, prediction, measurement", + [ + ( # Standard Kalman + KalmanUpdater, + LinearGaussian(ndim_state=2, mapping=[0], + noise_covar=np.array([[0.04]])), + GaussianStatePrediction(np.array([[-6.45], [0.7]]), + np.array([[4.1123, 0.0013], + [0.0013, 0.0365]])), + Detection(np.array([[-6.23]])) + ), + ( # Extended Kalman + ExtendedKalmanUpdater, + LinearGaussian(ndim_state=2, mapping=[0], + noise_covar=np.array([[0.04]])), + GaussianStatePrediction(np.array([[-6.45], [0.7]]), + np.array([[4.1123, 0.0013], + [0.0013, 0.0365]])), + Detection(np.array([[-6.23]])) + ), + ( # Unscented Kalman + UnscentedKalmanUpdater, + LinearGaussian(ndim_state=2, mapping=[0], + noise_covar=np.array([[0.04]])), + GaussianStatePrediction(np.array([[-6.45], [0.7]]), + np.array([[4.1123, 0.0013], + [0.0013, 0.0365]])), + Detection(np.array([[-6.23]])) + ), + ( # Iterated Kalman + IteratedKalmanUpdater, + LinearGaussian(ndim_state=2, mapping=[0], + noise_covar=np.array([[0.04]])), + GaussianStatePrediction(np.array([[-6.45], [0.7]]), + np.array([[4.1123, 0.0013], + [0.0013, 0.0365]])), + Detection(np.array([[-6.23]])) + ), + ( # Schmidt Kalman + SchmidtKalmanUpdater, + LinearGaussian(ndim_state=2, mapping=[0], + noise_covar=np.array([[0.04]])), + GaussianStatePrediction(np.array([[-6.45], [0.7]]), + np.array([[4.1123, 0.0013], + [0.0013, 0.0365]])), + Detection(np.array([[-6.23]])) + ), + ( # Cubature Kalman + CubatureKalmanUpdater, + LinearGaussian(ndim_state=2, mapping=[0], + noise_covar=np.array([[0.04]])), + GaussianStatePrediction(np.array([[-6.45], [0.7]]), + np.array([[4.1123, 0.0013], + [0.0013, 0.0365]])), + Detection(np.array([[-6.23]])) + ), + ], + ids=["standard", "extended", "unscented", "iterated", "schmidtkalman", "cubaturekalman"] +) +def test_kalman(UpdaterClass, measurement_model, prediction, measurement): # Calculate evaluation variables eval_measurement_prediction = GaussianMeasurementPrediction( @@ -54,7 +97,7 @@ def test_kalman(updater_class, use_joseph_cov): - kalman_gain@eval_measurement_prediction.covar @ kalman_gain.T) # Initialise a kalman updater - updater = updater_class(measurement_model=measurement_model, use_joseph_cov=use_joseph_cov) + updater = UpdaterClass(measurement_model=measurement_model) # Get and assert measurement prediction without measurement noise measurement_prediction = updater.predict_measurement(prediction, measurement_noise=False) diff --git 
diff --git a/terrain_aided_navigation.py b/terrain_aided_navigation.py
new file mode 100644
index 000000000..d4c06d405
--- /dev/null
+++ b/terrain_aided_navigation.py
@@ -0,0 +1,309 @@
+#!/usr/bin/env python
+
+# ==========================================================
+# Terrain aided navigation: point mass and particle filters
+# ==========================================================
+
+# %%
+#
+# Known turn rate example
+# -----------------------
+# Compare the point mass filter against a particle filter (and optionally a
+# Kalman filter) on a terrain aided navigation scenario.
+#
+# Ground truth
+# ^^^^^^^^^^^^
+# Import the necessary libraries
+
+import numpy as np
+import numpy.matlib  # needed for np.matlib.repmat below
+import matplotlib.pyplot as plt
+import time
+
+from datetime import datetime
+from datetime import timedelta
+
+
+# Initialise Stone Soup ground-truth and transition models.
+from stonesoup.models.transition.linear import CombinedLinearGaussianTransitionModel, \
+    ConstantVelocity
+from stonesoup.models.transition.linear import KnownTurnRate
+from stonesoup.types.groundtruth import GroundTruthPath, GroundTruthState
+from stonesoup.types.detection import Detection
+from stonesoup.models.measurement.nonlinear import TerrainAidedNavigation
+from stonesoup.models.measurement.linear import LinearGaussian
+from scipy.interpolate import RegularGridInterpolator
+from stonesoup.predictor.particle import ParticlePredictor
+from stonesoup.resampler.particle import ESSResampler
+from stonesoup.resampler.particle import MultinomialResampler
+from stonesoup.updater.particle import ParticleUpdater
+from stonesoup.functions import gridCreation
+from numpy.linalg import inv
+from stonesoup.types.state import PointMassState
+from stonesoup.types.hypothesis import SingleHypothesis
+from stonesoup.types.track import Track
+from stonesoup.types.state import GaussianState
+
+from stonesoup.predictor.pointMass import PointMassPredictor
+from stonesoup.updater.pointMass import PointMassUpdater
+from scipy.stats import multivariate_normal
+
+from stonesoup.predictor.kalman import KalmanPredictor
+from stonesoup.updater.kalman import KalmanUpdater
+
+from stonesoup.types.numeric import Probability  # Similar to a float type
+from stonesoup.types.state import ParticleState
+from stonesoup.types.array import StateVectors
+import json
+
+
+# Initialize arrays to store RMSE values
+matrixTruePMF = []
+matrixTruePF = []
+matrixTrueKF = []
+MC = 10
+for mc in range(0, MC):
+    print(mc)
+    start_time = datetime.now().replace(microsecond=0)
+
+    # %%
+
+    # np.random.seed(1991)
+
+    # %%
+
+    transition_model = KnownTurnRate(turn_noise_diff_coeffs=[2, 2], turn_rate=np.deg2rad(30))
+
+    # This needs to be done in another way
+    time_difference = timedelta(days=0, hours=0, minutes=0, seconds=1)
+
+    timesteps = [start_time]
+    truth = GroundTruthPath([GroundTruthState([36569, 50, 55581, 50], timestamp=start_time)])
+
+    # %%
+    # Create the truth path
+    for k in range(1, 20):
+        timesteps.append(start_time+timedelta(seconds=k))
+        truth.append(GroundTruthState(
+            transition_model.function(truth[k-1], noise=True, time_interval=timedelta(seconds=1)),
+            timestamp=timesteps[k]))
+
+    # %%
+    # Initialise the terrain aided navigation measurement model.
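+    # NOTE (assumed layout): the JSON terrain file is expected to hold
+    # meshgrid-style 'x' and 'y' coordinate arrays plus a height matrix 'z',
+    # oriented so that RegularGridInterpolator((x[0, :], y[:, 0]), z) below is
+    # valid, i.e. z has shape (len(x[0, :]), len(y[:, 0])). A synthetic
+    # stand-in with the same layout could be built as:
+    #
+    #   xs, ys = np.arange(30e3, 60e3, 50.0), np.arange(40e3, 70e3, 50.0)
+    #   xg, yg = np.meshgrid(xs, ys)  # xg[0, :] == xs, yg[:, 0] == ys
+    #   zg = 100 * np.sin(xg / 2e3) * np.cos(yg / 2e3)  # shape (len(ys), len(xs))
+    #   data = {'x': xg.tolist(), 'y': yg.tolist(), 'z': zg.T.tolist()}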
+    # Open the JSON file
+    with open('/Users/matoujak/Desktop/file.json', 'r') as file:
+        # Load JSON data
+        data = json.load(file)
+
+    map_x = data['x']
+    map_y = data['y']
+    map_z = data['z']
+
+    map_x = np.array(map_x)
+    map_y = np.array(map_y)
+    map_z = np.matrix(map_z)
+
+    interpolator = RegularGridInterpolator((map_x[0, :], map_y[:, 0]), map_z)
+
+    measurement_model = TerrainAidedNavigation(interpolator, noise_covar=1, mapping=(0, 2))
+    # matrix = np.array([
+    #     [1, 0],
+    #     [0, 1],
+    # ])
+    # measurement_model = LinearGaussian(ndim_state = 4, mapping = (0, 2), noise_covar = matrix)
+
+    # %%
+    # Populate the measurement array
+    measurements = []
+    for state in truth:
+        measurement = measurement_model.function(state, noise=True)
+        measurements.append(Detection(measurement, timestamp=state.timestamp,
+                                      measurement_model=measurement_model))
+
+    predictor = ParticlePredictor(transition_model)
+    resampler = MultinomialResampler()
+    updater = ParticleUpdater(measurement_model, resampler)
+
+    predictorKF = KalmanPredictor(transition_model)
+    updaterKF = KalmanUpdater(measurement_model)
+
+    # %%
+    # Initialise a prior
+    # ^^^^^^^^^^^^^^^^^^
+    # To start we create a prior estimate. This is a :class:`~.ParticleState` which describes
+    # the state as a distribution of particles using :class:`~.StateVectors` and weights.
+    # This is sampled from the Gaussian distribution (using the same parameters we
+    # had in the previous examples).
+
+    number_particles = 10000
+
+    # Sample from the prior Gaussian distribution
+    samples = multivariate_normal.rvs(np.array([36569, 50, 55581, 50]),
+                                      np.diag([90, 5, 160, 5]),
+                                      size=number_particles)
+
+    # Create prior particle state.
+    prior = ParticleState(state_vector=StateVectors(samples.T),
+                          weight=np.array([Probability(1/number_particles)]*number_particles),
+                          timestamp=start_time)
+
+    priorKF = GaussianState([[36569], [50], [55581], [50]], np.diag([90, 5, 160, 5]), timestamp=start_time)
+
+    # %% PMF prior
+
+    pmfPredictor = PointMassPredictor(transition_model)
+    pmfUpdater = PointMassUpdater(measurement_model)
+    # Initial condition - Gaussian
+    nx = 4
+    meanX0 = np.array([36569, 50, 55581, 50])  # mean value
+    varX0 = np.diag([90, 5, 160, 5])  # variance
+    Npa = np.array([31, 31, 27, 27])  # number of points per axis; must be odd for the FFT-based prediction
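+    # For reference: gridCreation below returns the grid points (rotated into
+    # the eigenbasis of the prior covariance and shifted to the prior mean),
+    # the per-axis steps, the per-dimension axes, the grid centre and the
+    # rotation matrix. The weights assigned to the grid are the Gaussian
+    # density evaluated at each point, normalised so that
+    # sum(weights) * prod(grid step) == 1, i.e. they approximate a probability
+    # density on the grid rather than point probabilities.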
+    N = np.prod(Npa)  # number of points - total
+    sFactor = 4  # scaling factor (number of sigmas covered by the grid)
+
+    [predGrid, predGridDelta, gridDimOld, xOld, Ppold] = gridCreation(np.vstack(meanX0), varX0, sFactor, nx, Npa)
+
+    meanX0 = np.vstack(meanX0)
+    pom = predGrid - np.matlib.repmat(meanX0, 1, N)
+    denominator = np.sqrt((2*np.pi)**nx * np.linalg.det(varX0))  # Gaussian normalising constant
+    pompom = np.sum(-0.5*np.multiply(pom.T@inv(varX0), pom.T), 1)  # elementwise multiplication
+    pomexp = np.exp(pompom)
+    predDensityProb = pomexp/denominator  # Adding probabilities to points
+    predDensityProb = predDensityProb/(sum(predDensityProb)*np.prod(predGridDelta))
+
+    priorPMF = PointMassState(state_vector=StateVectors(predGrid),
+                              weight=predDensityProb,
+                              grid_delta=predGridDelta,
+                              grid_dim=gridDimOld,
+                              center=xOld,
+                              eigVec=Ppold,
+                              Npa=Npa,
+                              timestamp=start_time)
+
+    matrixPMF = []
+
+    start_time = time.time()
+    track = Track()
+    for measurement in measurements:
+        prediction = pmfPredictor.predict(priorPMF, timestamp=measurement.timestamp)
+        hypothesis = SingleHypothesis(prediction, measurement)
+        post = pmfUpdater.update(hypothesis)
+        priorPMF = post
+        matrixPMF.append(post.mean)
+        # print(post.mean)
+
+    # Record the end time
+    end_time = time.time()
+
+    # Calculate the elapsed time
+    # print(end_time - start_time)
+
+    # matrixKF = []
+
+    # start_time = time.time()
+    # track = Track()
+    # for measurement in measurements:
+    #     prediction = predictorKF.predict(priorKF, timestamp=measurement.timestamp)
+    #     hypothesis = SingleHypothesis(prediction, measurement)
+    #     post = updaterKF.update(hypothesis)
+    #     priorKF = post
+    #     matrixKF.append(post.mean)
+    #     # print(post.mean)
+
+    # # Record the end time
+    # end_time = time.time()
+
+    # %%
+    # Run the tracker
+    # ^^^^^^^^^^^^^^^
+    # We now run the predict and update steps, propagating the collection of particles and resampling
+    # when told to (at every step).
+
+    matrixPF = []
+    start_time = time.time()
+    track = Track()
+    for measurement in measurements:
+        prediction = predictor.predict(prior, timestamp=measurement.timestamp)
+        hypothesis = SingleHypothesis(prediction, measurement)
+        post = updater.update(hypothesis)
+        # print(post.mean)
+        track.append(post)
+        matrixPF.append(post.mean)
+        prior = track[-1]
+
+    # Record the end time
+    end_time = time.time()
+
+    # Calculate the elapsed time
+    # print(end_time - start_time)
+
+    for ind in range(0, 20):
+        matrixTruePMF.append(np.ravel(np.vstack(matrixPMF[ind])-truth.states[ind].state_vector))
+        matrixTruePF.append(np.ravel(matrixPF[ind]-truth.states[ind].state_vector))
+        # matrixTrueKF.append(np.ravel(matrixKF[ind]-truth.states[ind].state_vector))
+
+
+def rmse(errors):
+    """
+    Calculate the Root Mean Square Error (RMSE) from a list of errors.
+
+    Args:
+        errors (list): List of errors.
+
+    Returns:
+        numpy.ndarray: RMSE value per state dimension.
+ """ + # Convert the list of errors into a numpy array for easier computation + errors_array = np.array(errors) + + # Square the errors + squared_errors = np.square(errors_array) + + # Calculate the mean squared error + mean_squared_error = np.mean(squared_errors,0) + + # Calculate the root mean squared error + rmse_value = np.sqrt(mean_squared_error) + + return rmse_value + + +print(rmse(matrixTruePF)) +print(rmse(matrixTruePMF)) +# print(rmse(matrixTrueKF)) + + + + diff --git a/terrain_aided_navigation.py.lprof b/terrain_aided_navigation.py.lprof new file mode 100644 index 0000000000000000000000000000000000000000..0152616d5d2d5da5e2f4c9066a958d29042d04b1 GIT binary patch literal 979 zcmb8t&ui0Q7zglW{keH?GQq(_g~_M`Tk+x#aJK7UBkwx9#$r2}ZMJ1w);0-Ax~VW( z=b=@j7EEDB5rq8{f}n>TJ&24Q^rBEah{qoNKJS`?-t48He4qDu-@I?$?Lcea*Ta9# z>9VRR1;f-=)Uskm3x7SjPTvg5k}cbo`^vpgbGO`x6SCEcs;yf5AsNeBifP3vvaMIw zv~m Date: Wed, 26 Jun 2024 13:58:11 +0200 Subject: [PATCH 02/16] refact --- stonesoup/functions/__init__.py | 184 +++++++----- stonesoup/predictor/pointMass.py | 219 ++++++-------- stonesoup/types/prediction.py | 14 +- stonesoup/types/state.py | 483 ++++++++++++++++++------------- stonesoup/updater/pointMass.py | 110 ++++--- 5 files changed, 550 insertions(+), 460 deletions(-) diff --git a/stonesoup/functions/__init__.py b/stonesoup/functions/__init__.py index db7a1e1c9..1060d4cd2 100644 --- a/stonesoup/functions/__init__.py +++ b/stonesoup/functions/__init__.py @@ -1,49 +1,50 @@ """Mathematical functions used within Stone Soup""" + import copy +import itertools import warnings import numpy as np +from numpy import linalg as LA +from numpy import matlib as matlib +from ..types.array import CovarianceMatrix, StateVector, StateVectors from ..types.numeric import Probability -from ..types.array import StateVector, StateVectors, CovarianceMatrix from ..types.state import State -import itertools -from numpy import linalg as LA -from numpy import matlib as matlib -def gridCreation(xp_aux,Pp_aux,sFactor,nx,Npa): - gridDim = np.zeros((nx,Npa[0])) +def gridCreation(xp_aux, Pp_aux, sFactor, nx, Npa): + gridDim = np.zeros((nx, Npa[0])) gridStep = np.zeros(nx) - eigVal,eigVect = LA.eig(Pp_aux) # eigenvalue and eigenvectors for setting up the grid - gridBound = np.sqrt(eigVal)*sFactor #Boundaries of grid - + eigVal, eigVect = LA.eig( + Pp_aux + ) # eigenvalue and eigenvectors for setting up the grid + gridBound = np.sqrt(eigVal) * sFactor # Boundaries of grid # Ensure the grid steps are in the right order - I = np.argsort(np.diag(Pp_aux)) - I = np.argsort(I) - + sortInd = np.argsort(np.diag(Pp_aux)) + sortInd = np.argsort(sortInd) + pom = np.sort(gridBound) - gridBound = pom[I] - + gridBound = pom[sortInd] + Ipom = np.argsort(gridBound) pom2 = eigVect[:, Ipom] - eigVect = pom2[:, I] + eigVect = pom2[:, sortInd] gridDim = [] # Reset gridDim for each cycle gridStep = [] # Reset gridStep for each cycle - for ind3 in range(0,nx): #Creation of propagated grid - # gridDim[ind3] = np.linspace(-gridBound[ind3], gridBound[ind3], Npa[ind3]) #New grid with middle in 0 - # gridStep[ind3] = np.absolute(gridDim[ind3][0] - gridDim[ind3][1]) #Grid step - gridDim.append(np.linspace(-gridBound[ind3], gridBound[ind3], Npa[ind3])) # New grid with middle in 0 + for ind3 in range(0, nx): # Creation of propagated grid + # New grid with middle in 0 + gridDim.append(np.linspace(-gridBound[ind3], gridBound[ind3], Npa[ind3])) gridStep.append(np.absolute(gridDim[ind3][0] - gridDim[ind3][1])) # Grid step - combvec_predGrid = 
From 7a3a6c2892d59ae64b2e4580f364bf09a50c0913 Mon Sep 17 00:00:00 2001
From: pesslovany <pesslovany@gmail.com>
Date: Wed, 26 Jun 2024 13:58:11 +0200
Subject: [PATCH 02/16] refact

---
 stonesoup/functions/__init__.py  | 184 +++++++-----
 stonesoup/predictor/pointMass.py | 219 ++++++--------
 stonesoup/types/prediction.py    |  14 +-
 stonesoup/types/state.py         | 483 ++++++++++++++++++-------------
 stonesoup/updater/pointMass.py   | 110 ++++---
 5 files changed, 550 insertions(+), 460 deletions(-)

diff --git a/stonesoup/functions/__init__.py b/stonesoup/functions/__init__.py
index db7a1e1c9..1060d4cd2 100644
--- a/stonesoup/functions/__init__.py
+++ b/stonesoup/functions/__init__.py
@@ -1,49 +1,50 @@
 """Mathematical functions used within Stone Soup"""
+
 import copy
+import itertools
 import warnings

 import numpy as np
+from numpy import linalg as LA
+from numpy import matlib as matlib

+from ..types.array import CovarianceMatrix, StateVector, StateVectors
 from ..types.numeric import Probability
-from ..types.array import StateVector, StateVectors, CovarianceMatrix
 from ..types.state import State
-import itertools
-from numpy import linalg as LA
-from numpy import matlib as matlib


-def gridCreation(xp_aux,Pp_aux,sFactor,nx,Npa):
-    gridDim = np.zeros((nx,Npa[0]))
+def gridCreation(xp_aux, Pp_aux, sFactor, nx, Npa):
+    gridDim = np.zeros((nx, Npa[0]))
     gridStep = np.zeros(nx)
-    eigVal,eigVect = LA.eig(Pp_aux) # eigenvalue and eigenvectors for setting up the grid
-    gridBound = np.sqrt(eigVal)*sFactor #Boundaries of grid
-
+    eigVal, eigVect = LA.eig(
+        Pp_aux
+    )  # eigenvalue and eigenvectors for setting up the grid
+    gridBound = np.sqrt(eigVal) * sFactor  # Boundaries of grid

     # Ensure the grid steps are in the right order
-    I = np.argsort(np.diag(Pp_aux))
-    I = np.argsort(I)
-
+    sortInd = np.argsort(np.diag(Pp_aux))
+    sortInd = np.argsort(sortInd)
+
     pom = np.sort(gridBound)
-    gridBound = pom[I]
-
+    gridBound = pom[sortInd]
+
     Ipom = np.argsort(gridBound)
     pom2 = eigVect[:, Ipom]
-    eigVect = pom2[:, I]
+    eigVect = pom2[:, sortInd]

     gridDim = []  # Reset gridDim for each cycle
     gridStep = []  # Reset gridStep for each cycle

-    for ind3 in range(0,nx): #Creation of propagated grid
-        # gridDim[ind3] = np.linspace(-gridBound[ind3], gridBound[ind3], Npa[ind3]) #New grid with middle in 0
-        # gridStep[ind3] = np.absolute(gridDim[ind3][0] - gridDim[ind3][1]) #Grid step
-        gridDim.append(np.linspace(-gridBound[ind3], gridBound[ind3], Npa[ind3])) # New grid with middle in 0
+    for ind3 in range(0, nx):  # Creation of propagated grid
+        # New grid with middle in 0
+        gridDim.append(np.linspace(-gridBound[ind3], gridBound[ind3], Npa[ind3]))
         gridStep.append(np.absolute(gridDim[ind3][0] - gridDim[ind3][1]))  # Grid step

     combvec_predGrid = np.array(list(itertools.product(*gridDim)))
-    predGrid_pom = np.dot(eigVect,combvec_predGrid.T)
-    size_pom = np.size(predGrid_pom,1)
-    predGrid = predGrid_pom + matlib.repmat(xp_aux,1,size_pom) #Grid rotation by eigenvectors and traslation to the counted unscented mean
-    predGridDelta = gridStep # Grid step size
-    return predGrid,predGridDelta,gridDim,xp_aux,eigVect
+    predGrid_pom = np.dot(eigVect, combvec_predGrid.T)
+    size_pom = np.size(predGrid_pom, 1)
+    # Grid rotation by eigenvectors and translation to the computed mean
+    predGrid = predGrid_pom + matlib.repmat(xp_aux, 1, size_pom)
+    predGridDelta = gridStep  # Grid step size
+    return predGrid, predGridDelta, gridDim, xp_aux, eigVect


 def tria(matrix):
@@ -65,9 +66,7 @@ def tria(matrix):
     _, upper_triangular = np.linalg.qr(matrix.T)
     lower_triangular = upper_triangular.T

-    index = [col
-             for col, val in enumerate(np.diag(lower_triangular))
-             if val < 0]
+    index = [col for col, val in enumerate(np.diag(lower_triangular)) if val < 0]

     lower_triangular[:, index] *= -1

@@ -98,8 +97,8 @@ def cholesky_eps(A, lower=False):
     L = np.zeros(A.shape)
     for i in range(A.shape[0]):
         for j in range(i):
-            L[i, j] = (A[i, j] - L[i, :]@L[j, :].T) / L[j, j]
-        val = A[i, i] - L[i, :]@L[i, :].T
+            L[i, j] = (A[i, j] - L[i, :] @ L[j, :].T) / L[j, j]
+        val = A[i, i] - L[i, :] @ L[i, :].T
         L[i, i] = np.sqrt(val) if val > eps else np.sqrt(eps)

     if lower:
@@ -130,13 +129,16 @@ def jacobian(fun, x, **kwargs):

     # For numerical reasons the step size needs to large enough. Aim for 1e-8
     # relative to spacing between floating point numbers for each dimension
-    delta = 1e8*np.spacing(x.state_vector.astype(np.float64).ravel())
+    delta = 1e8 * np.spacing(x.state_vector.astype(np.float64).ravel())
     # But at least 1e-8
     # TODO: Is this needed? If not, note special case at zero.
delta[delta < 1e-8] = 1e-8 x2 = copy.copy(x) # Create a clone of the input - x2.state_vector = np.tile(x.state_vector, ndim+1) + np.eye(ndim, ndim+1)*delta[:, np.newaxis] + x2.state_vector = ( + np.tile(x.state_vector, ndim + 1) + + np.eye(ndim, ndim + 1) * delta[:, np.newaxis] + ) x2.state_vector = x2.state_vector.view(StateVectors) F = fun(x2, **kwargs) @@ -207,10 +209,12 @@ def gauss2sigma(state, alpha=1.0, beta=2.0, kappa=None): sigma_points = sigma_points.astype(float) # Can't use in place addition/subtraction as casting issues may arise when mixing float/int - sigma_points[:, 1:(ndim_state + 1)] = \ - sigma_points[:, 1:(ndim_state + 1)] + sqrt_sigma*np.sqrt(c) - sigma_points[:, (ndim_state + 1):] = \ - sigma_points[:, (ndim_state + 1):] - sqrt_sigma*np.sqrt(c) + sigma_points[:, 1 : (ndim_state + 1)] = sigma_points[ + :, 1 : (ndim_state + 1) + ] + sqrt_sigma * np.sqrt(c) + sigma_points[:, (ndim_state + 1) :] = sigma_points[ + :, (ndim_state + 1) : + ] - sqrt_sigma * np.sqrt(c) # Put these sigma points into s State object list sigma_points_states = [] @@ -262,8 +266,14 @@ def sigma2gauss(sigma_points, mean_weights, covar_weights, covar_noise=None): return mean.view(StateVector), covar.view(CovarianceMatrix) -def unscented_transform(sigma_points_states, mean_weights, covar_weights, - fun, points_noise=None, covar_noise=None): +def unscented_transform( + sigma_points_states, + mean_weights, + covar_weights, + fun, + points_noise=None, + covar_noise=None, +): """ Apply the Unscented Transform to a set of sigma points @@ -306,24 +316,33 @@ def unscented_transform(sigma_points_states, mean_weights, covar_weights, An array containing the transformed sigma point covariance weights """ # Reconstruct the sigma_points matrix - sigma_points = StateVectors([ - sigma_points_state.state_vector for sigma_points_state in sigma_points_states]) + sigma_points = StateVectors( + [sigma_points_state.state_vector for sigma_points_state in sigma_points_states] + ) # Transform points through f if points_noise is None: - sigma_points_t = StateVectors([ - fun(sigma_points_state) for sigma_points_state in sigma_points_states]) + sigma_points_t = StateVectors( + [fun(sigma_points_state) for sigma_points_state in sigma_points_states] + ) else: - sigma_points_t = StateVectors([ - fun(sigma_points_state, points_noise) - for sigma_points_state, point_noise in zip(sigma_points_states, points_noise.T)]) + sigma_points_t = StateVectors( + [ + fun(sigma_points_state, points_noise) + for sigma_points_state, point_noise in zip( + sigma_points_states, points_noise.T + ) + ] + ) # Calculate mean and covariance approximation mean, covar = sigma2gauss(sigma_points_t, mean_weights, covar_weights, covar_noise) # Calculate cross-covariance cross_covar = ( - (sigma_points-sigma_points[:, 0:1]) @ np.diag(covar_weights) @ (sigma_points_t-mean).T + (sigma_points - sigma_points[:, 0:1]) + @ np.diag(covar_weights) + @ (sigma_points_t - mean).T ).view(CovarianceMatrix) return mean, covar, cross_covar, sigma_points_t, mean_weights, covar_weights @@ -487,7 +506,7 @@ def az_el_rg2cart(phi, theta, rho): """ x = rho * np.sin(phi) y = rho * np.sin(theta) - z = rho * np.sqrt(1.0 - np.sin(theta)**2 - np.sin(phi)**2) + z = rho * np.sqrt(1.0 - np.sin(theta) ** 2 - np.sin(phi) ** 2) return x, y, z @@ -521,9 +540,7 @@ def rotx(theta): c, s = np.cos(theta), np.sin(theta) zero = np.zeros_like(theta) one = np.ones_like(theta) - return np.array([[one, zero, zero], - [zero, c, -s], - [zero, s, c]]) + return np.array([[one, zero, zero], [zero, c, -s], 
[zero, s, c]]) def roty(theta): @@ -557,9 +574,7 @@ def roty(theta): c, s = np.cos(theta), np.sin(theta) zero = np.zeros_like(theta) one = np.ones_like(theta) - return np.array([[c, zero, s], - [zero, one, zero], - [-s, zero, c]]) + return np.array([[c, zero, s], [zero, one, zero], [-s, zero, c]]) def rotz(theta): @@ -593,9 +608,7 @@ def rotz(theta): c, s = np.cos(theta), np.sin(theta) zero = np.zeros_like(theta) one = np.ones_like(theta) - return np.array([[c, -s, zero], - [s, c, zero], - [zero, zero, one]]) + return np.array([[c, -s, zero], [s, c, zero], [zero, zero, one]]) def gm_sample(means, covars, size, weights=None): @@ -633,8 +646,12 @@ def gm_sample(means, covars, size, weights=None): weights = np.array([1 / len(means)] * len(means)) n_samples = np.random.multinomial(size, weights) - samples = np.vstack([np.random.multivariate_normal(mean.ravel(), covar, sample) - for (mean, covar, sample) in zip(means, covars, n_samples)]).T + samples = np.vstack( + [ + np.random.multivariate_normal(mean.ravel(), covar, sample) + for (mean, covar, sample) in zip(means, covars, n_samples) + ] + ).T return StateVectors(samples) @@ -669,7 +686,10 @@ def gm_reduce_single(means, covars, weights): # Calculate covar delta_means = means - mean - covar = np.sum(covars*weights, axis=2, dtype=np.float64) + weights*delta_means@delta_means.T + covar = ( + np.sum(covars * weights, axis=2, dtype=np.float64) + + weights * delta_means @ delta_means.T + ) return mean.view(StateVector), covar.view(CovarianceMatrix) @@ -689,7 +709,7 @@ def mod_bearing(x): Angle in radians in the range math: :math:`-\pi` to :math:`+\pi` """ - x = (x+np.pi) % (2.0*np.pi)-np.pi + x = (x + np.pi) % (2.0 * np.pi) - np.pi return x @@ -708,7 +728,7 @@ def mod_elevation(x): float Angle in radians in the range math: :math:`-\pi/2` to :math:`+\pi/2` """ - x = x % (2*np.pi) # limit to 2*pi + x = x % (2 * np.pi) # limit to 2*pi N = x // (np.pi / 2) # Count # of 90 deg multiples if N == 1: x = np.pi - x @@ -796,15 +816,19 @@ def dotproduct(a, b): """ if np.shape(a) != np.shape(b): - raise ValueError("Inputs must be (a collection of) column vectors of the same dimension") + raise ValueError( + "Inputs must be (a collection of) column vectors of the same dimension" + ) # Decide whether this is a StateVector or a StateVectors - if type(a) is StateVector and type(b) is StateVector: + if isinstance(a, StateVector) and isinstance(b, StateVector): return np.sum(a * b) - elif type(a) is StateVectors and type(b) is StateVectors: + elif isinstance(a, StateVectors) and isinstance(b, StateVectors): return np.atleast_2d(np.asarray(np.sum(a * b, axis=0))) else: - raise ValueError("Inputs must be `StateVector` or `StateVectors` and of the same type") + raise ValueError( + "Inputs must be `StateVector` or `StateVectors` and of the same type" + ) def sde_euler_maruyama_integration(fun, t_values, state_x0): @@ -832,7 +856,7 @@ def sde_euler_maruyama_integration(fun, t_values, state_x0): delta_t = next_t - t delta_w = np.random.normal(scale=np.sqrt(delta_t), size=(state_x.ndim, 1)) a, b = fun(state_x, t) - state_x.state_vector = state_x.state_vector + a*delta_t + b@delta_w + state_x.state_vector = state_x.state_vector + a * delta_t + b @ delta_w return state_x.state_vector @@ -870,13 +894,14 @@ def gauss2cubature(state, alpha=1.0): ndim_state = np.shape(state.state_vector)[0] sqrt_covar = np.linalg.cholesky(state.covar) - cuba_points = np.sqrt(alpha*ndim_state) * np.hstack((np.identity(ndim_state), - -np.identity(ndim_state))) + cuba_points = np.sqrt(alpha * 
ndim_state) * np.hstack( + (np.identity(ndim_state), -np.identity(ndim_state)) + ) if np.issubdtype(cuba_points.dtype, np.integer): cuba_points = cuba_points.astype(float) - cuba_points = sqrt_covar@cuba_points + state.mean + cuba_points = sqrt_covar @ cuba_points + state.mean return StateVectors(cuba_points) @@ -921,7 +946,7 @@ def cubature2gauss(cubature_points, covar_noise=None, alpha=1.0): mean = np.average(cubature_points, axis=1) sigma_mult = cubature_points @ cubature_points.T mean_mult = mean @ mean.T - covar = (1/alpha)*((1/m)*sigma_mult - mean_mult) + covar = (1 / alpha) * ((1 / m) * sigma_mult - mean_mult) if covar_noise is not None: covar = covar + covar_noise @@ -979,16 +1004,23 @@ def cubature_transform(state, fun, points_noise=None, covar_noise=None, alpha=1. cubature_points = gauss2cubature(state) if points_noise is None: - cubature_points_t = StateVectors([fun(State(cub_point)) for cub_point in cubature_points]) + cubature_points_t = StateVectors( + [fun(State(cub_point)) for cub_point in cubature_points] + ) else: - cubature_points_t = StateVectors([ - fun(State(cub_point), points_noise) - for cub_point, point_noise in zip(cubature_points, points_noise)]) + cubature_points_t = StateVectors( + [ + fun(State(cub_point), points_noise) + for cub_point, point_noise in zip(cubature_points, points_noise) + ] + ) mean, covar = cubature2gauss(cubature_points_t, covar_noise) - cross_covar = (1/alpha)*((1./(2*ndim_state))*cubature_points@cubature_points_t.T - - np.average(cubature_points, axis=1)@mean.T) + cross_covar = (1 / alpha) * ( + (1.0 / (2 * ndim_state)) * cubature_points @ cubature_points_t.T + - np.average(cubature_points, axis=1) @ mean.T + ) cross_covar = cross_covar.view(CovarianceMatrix) return mean, covar, cross_covar, cubature_points_t diff --git a/stonesoup/predictor/pointMass.py b/stonesoup/predictor/pointMass.py index 0c39032ea..d0014371f 100644 --- a/stonesoup/predictor/pointMass.py +++ b/stonesoup/predictor/pointMass.py @@ -1,35 +1,17 @@ -import copy -from typing import Sequence +from datetime import timedelta import numpy as np -from scipy.special import logsumexp -from ordered_set import OrderedSet +import plotly.io as pio +from scipy.interpolate import RegularGridInterpolator +from scipy.signal import fftconvolve +from stonesoup.functions import gridCreation +from stonesoup.types.state import PointMassState -from .base import Predictor -from ._utils import predict_lru_cache -from .kalman import KalmanPredictor, ExtendedKalmanPredictor from ..base import Property -from ..models.transition import TransitionModel -from ..types.prediction import Prediction -from ..types.state import GaussianState -from ..sampler import Sampler - from ..types.array import StateVectors +from .base import Predictor -from numpy.linalg import inv -from numpy import linalg as LA -import itertools -from scipy.interpolate import RegularGridInterpolator -from scipy.signal import fftconvolve -from scipy.linalg import inv -from stonesoup.types.state import PointMassState -from datetime import timedelta -import matplotlib.pyplot as plt -from stonesoup.functions import gridCreation -from mpl_toolkits.mplot3d import Axes3D -import plotly.graph_objects as go -import plotly.io as pio -pio.renderers.default='browser' +pio.renderers.default = "browser" class PointMassPredictor(Predictor): @@ -37,11 +19,10 @@ class PointMassPredictor(Predictor): An implementation of a Particle Filter predictor. 
""" - sFactor: float = Property( - default=4, - doc="How many sigma to cover by the grid") - #@profile + sFactor: float = Property(default=4, doc="How many sigma to cover by the grid") + + # @profile def predict(self, prior, timestamp=None, **kwargs): """Particle Filter prediction step @@ -63,127 +44,103 @@ def predict(self, prior, timestamp=None, **kwargs): time_interval = timestamp - prior.timestamp except TypeError: time_interval = None - + time_difference = timedelta(days=0, hours=0, minutes=0, seconds=0) if time_interval == time_difference: - predGrid = prior.state_vector, - predDensityProb = prior.weight # SLOW LINE + predGrid = (prior.state_vector,) + predDensityProb = prior.weight # SLOW LINE GridDelta = prior.grid_delta gridDimOld = prior.grid_dim xOld = prior.center Ppold = prior.eigVec else: - - F = self.transition_model.matrix(prior=prior, time_interval=time_interval,**kwargs) + + F = self.transition_model.matrix( + prior=prior, time_interval=time_interval, **kwargs + ) Q = self.transition_model.covar(time_interval=time_interval, **kwargs) - + invF = np.linalg.inv(F) invFT = np.linalg.inv(F.T) - FqF = invF@Q@invFT - matrixForEig = prior.covar() + FqF - - measGridNew, GridDeltaOld, gridDimOld, nothing, eigVect = gridCreation(prior.mean.reshape(-1,1),matrixForEig,self.sFactor,len(invF),prior.Npa); - + FqF = invF @ Q @ invFT + matrixForEig = prior.covar() + FqF + + measGridNew, GridDeltaOld, gridDimOld, nothing, eigVect = gridCreation( + prior.mean.reshape(-1, 1), + matrixForEig, + self.sFactor, + len(invF), + prior.Npa, + ) + # Interpolation - Fint = RegularGridInterpolator(prior.grid_dim, prior.weight.reshape(prior.Npa, order='C'), method="linear", bounds_error=False, fill_value=0) + Fint = RegularGridInterpolator( + prior.grid_dim, + prior.weight.reshape(prior.Npa, order="C"), + method="linear", + bounds_error=False, + fill_value=0, + ) inerpOn = np.dot(np.linalg.inv(prior.eigVec), (measGridNew - prior.center)) measPdfNew = Fint(inerpOn.T).T - -# ============================================================================= - # # Data for the first plot - # toPlot1 = prior.state_vector - # vals1 = prior.weight - - # # Data for the second plot - # toPlot2 = measGridNew - # vals2 = measPdfNew - - # # Create a figure - # fig = go.Figure() - - # # Plot the first set of data - # fig.add_trace(go.Scatter3d( - # x=toPlot1[0, :], y=toPlot1[1, :], z=vals1, - # mode='markers', - # marker=dict(size=5, color='blue', opacity=0.5), - # name='Measurement' - # )) - - # # Plot the second set of data - # fig.add_trace(go.Scatter3d( - # x=toPlot2[0, :], y=toPlot2[1, :], z=vals2, - # mode='markers', - # marker=dict(size=5, color='red', opacity=0.5), - # name='Interpolated' - # )) - - # # Update layout for better visibility and interactivity - # fig.update_layout( - # scene=dict( - # xaxis_title='X-axis', - # yaxis_title='Y-axis', - # zaxis_title='Values' - # ), - # title='Comparison of Prior and Measurement Data', - # margin=dict(l=0, r=0, t=40, b=0) # Adjust margins for better layout - # ) - - # # Show the plot - # fig.show() - -# ============================================================================= - - + # Predictive grid predGrid = np.dot(F, measGridNew) - + # Grid step size GridDelta = np.dot(F, GridDeltaOld) - + # ULTRA FAST PMF - filtDenDOTprodDeltas = np.dot(measPdfNew, np.prod(GridDeltaOld)) # measurement PDF * measurement PDF step size - filtDenDOTprodDeltasCub = np.reshape(filtDenDOTprodDeltas, prior.Npa, order='C') # Into physical space - - halfGrid = (np.ceil(predGrid.shape[1] / 
2) - 1).astype(int)
+
+            # Denominator for convolution in predictive step
+            predDenDenomW = np.sqrt((2 * np.pi) ** prior.ndim * np.linalg.det(Q))
+
+            pom = np.transpose(
+                predGrid[:, halfGrid][:, np.newaxis] - predGrid
+            )  # Middle row of the TPM matrix
+            TPMrow = (
+                np.exp(np.sum(-0.5 * pom @ np.linalg.inv(Q) * pom, axis=1))
+                / predDenDenomW
+            ).reshape(
+                1, -1, order="C"
+            )  # Middle row of the TPM matrix
+            TPMrowCubPom = np.reshape(
+                TPMrow, prior.Npa, order="C"
+            )  # Into physical space
+
             # Compute the convolution using scipy.signal.fftconvolve
-            convolution_result_complex = fftconvolve(filtDenDOTprodDeltasCub, TPMrowCubPom, mode='same')
-
+            convolution_result_complex = fftconvolve(
+                filtDenDOTprodDeltasCub, TPMrowCubPom, mode="same"
+            )
+
             # Take the real part of the convolution result to get a real-valued result
             convolution_result_real = np.real(convolution_result_complex).T
-
-
-            predDensityProb = np.reshape(convolution_result_real, (-1,1), order='F')
-            predDensityProb = predDensityProb / (np.sum(predDensityProb) * np.prod(GridDelta)) # Normalization (theoretically not needed)
-
-# =============================================================================
-#             toPlot = predGrid
-#             vals = predDensityProb
-#             fig = plt.figure()
-#             ax = fig.add_subplot(111, projection='3d')
-#             ax.scatter(toPlot[0,:],toPlot[1,:],vals)
-# =============================================================================
-
-
-        xOld = F@np.vstack(prior.mean);
-        Ppold = F@eigVect;
-
-        # plt.plot(abs(predDensityProb))
-
-
-        return PointMassState(state_vector=StateVectors(np.squeeze(predGrid)),
-                              weight=abs(predDensityProb), # SLOW LINE
-                              grid_delta = GridDelta,
-                              grid_dim = gridDimOld,
-                              center = xOld,
-                              eigVec = Ppold,
-                              Npa = prior.Npa,
-                              timestamp=timestamp)
+
+            predDensityProb = np.reshape(convolution_result_real, (-1, 1), order="F")
+            # Normalization (theoretically not needed)
+            predDensityProb = predDensityProb / (
+                np.sum(predDensityProb) * np.prod(GridDelta)
+            )
+
+        xOld = F @ np.vstack(prior.mean)
+        Ppold = F @ eigVect
+        return PointMassState(
+            state_vector=StateVectors(np.squeeze(predGrid)),
+            weight=abs(predDensityProb),
+            grid_delta=GridDelta,
+            grid_dim=gridDimOld,
+            center=xOld,
+            eigVec=Ppold,
+            Npa=prior.Npa,
+            timestamp=timestamp,
+        )
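The fftconvolve call above is the core of the point mass prediction: for a linear-Gaussian
transition the Chapman-Kolmogorov integral reduces to a convolution of the filtering density
with the process noise kernel on the grid, which the FFT evaluates in O(N log N) instead of
applying a dense N x N transition matrix. A one-dimensional sketch of the same idea (standalone,
with the state transition taken as identity for simplicity):

    import numpy as np
    from scipy.signal import fftconvolve

    x, dx = np.linspace(-10, 10, 201, retstep=True)
    filt = np.exp(-0.5 * x**2) / np.sqrt(2 * np.pi)           # filtering density N(0, 1)
    q = 0.5                                                   # process noise variance
    kern = np.exp(-0.5 * x**2 / q) / np.sqrt(2 * np.pi * q)   # transition kernel N(0, q)
    pred = fftconvolve(filt * dx, kern, mode='same')          # predictive density on the grid
    # pred approximates the N(0, 1 + q) density sampled at x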
diff --git a/stonesoup/types/prediction.py b/stonesoup/types/prediction.py
index 02bd7b8fc..f2319ef98 100644
--- a/stonesoup/types/prediction.py
+++ b/stonesoup/types/prediction.py
@@ -11,7 +11,7 @@
     BernoulliParticleState)
 from ..base import Property
 from ..models.transition.base import TransitionModel
-from ..types.state import CreatableFromState, CompositeState
+from ..types.state import CreatableFromState, CompositeState, PointMassState


 class Prediction(Type, CreatableFromState):
@@ -141,12 +141,24 @@ class ParticleStatePrediction(Prediction, ParticleState):

     This is a simple Particle state prediction object.
     """

+
+class PointMassStatePrediction(Prediction, PointMassState):
+    """PointMassStatePrediction type
+
+    This is a simple point mass state prediction object.
+    """
+

 class ParticleMeasurementPrediction(MeasurementPrediction, ParticleState):
     """MeasurementStatePrediction type

     This is a simple Particle measurement prediction object.
     """
+
+
+class PointMassMeasurementPrediction(MeasurementPrediction, PointMassState):
+    """PointMassMeasurementPrediction type
+
+    This is a simple point mass measurement prediction object.
+    """


 class MultiModelParticleStatePrediction(Prediction, MultiModelParticleState):
diff --git a/stonesoup/types/state.py b/stonesoup/types/state.py
index 26466d76a..9e22705d7 100644
--- a/stonesoup/types/state.py
+++ b/stonesoup/types/state.py
@@ -1,33 +1,36 @@
 import copy
 import datetime
+import typing
 import uuid
 from collections import abc
 from numbers import Integral
-from typing import MutableSequence, Any, Optional, Sequence, MutableMapping
-import typing
+from typing import Any, MutableMapping, MutableSequence, Optional, Sequence

 import numpy as np
 from scipy.stats import multivariate_normal

 from ..base import Property, clearable_cached_property
-from .array import StateVector, CovarianceMatrix, PrecisionMatrix, StateVectors
+from .array import CovarianceMatrix, PrecisionMatrix, StateVector, StateVectors
 from .base import Type
-from .particle import Particle, MultiModelParticle, RaoBlackwellisedParticle
 from .numeric import Probability
+from .particle import MultiModelParticle, Particle, RaoBlackwellisedParticle


 class State(Type):
     """State type.

     Most simple state type, which only has time and a state vector."""
+
     timestamp: datetime.datetime = Property(
-        default=None, doc="Timestamp of the state. Default None.")
-    state_vector: StateVector = Property(doc='State vector.')
+        default=None, doc="Timestamp of the state. Default None."
+    )
+    state_vector: StateVector = Property(doc="State vector.")

     def __init__(self, state_vector, *args, **kwargs):
         # Don't cast away subtype of state_vector if not necessary
-        if state_vector is not None \
-                and not isinstance(state_vector, (StateVector, StateVectors)):
+        if state_vector is not None and not isinstance(
+            state_vector, (StateVector, StateVectors)
+        ):
             state_vector = StateVector(state_vector)
         super().__init__(state_vector, *args, **kwargs)

@@ -37,8 +40,12 @@ def ndim(self):
         return self.state_vector.shape[0]

     @staticmethod
-    def from_state(state: 'State', *args: Any, target_type: Optional[typing.Type] = None,
-                   **kwargs: Any) -> 'State':
+    def from_state(
+        state: "State",
+        *args: Any,
+        target_type: Optional[typing.Type] = None,
+        **kwargs: Any,
+    ) -> "State":
         """Class utility function to create a new state (or compatible type) from an existing
         state. The type and properties of this new state are defined by `state` except for any
         explicitly overwritten via `args` and `kwargs`.
@@ -69,12 +76,14 @@ def from_state(state: 'State', *args: Any, target_type: Optional[typing.Type] = target_type = type(state) args_property_names = { - name for n, name in enumerate(target_type.properties) if n < len(args)} + name for n, name in enumerate(target_type.properties) if n < len(args) + } new_kwargs = { name: getattr(state, name) for name in type(state).properties.keys() & target_type.properties.keys() - if name not in args_property_names and name not in kwargs} + if name not in args_property_names and name not in kwargs + } new_kwargs.update(kwargs) @@ -91,11 +100,15 @@ def __init_subclass__(cls, **kwargs): # subclasses return if len(bases) != 2: - raise TypeError('A CreatableFromState subclass must have exactly two superclasses') + raise TypeError( + "A CreatableFromState subclass must have exactly two superclasses" + ) base_class, state_type = cls.__bases__ if not issubclass(base_class, CreatableFromState): - raise TypeError('The first superclass of a CreatableFromState subclass must be a ' - 'CreatableFromState (or a subclass)') + raise TypeError( + "The first superclass of a CreatableFromState subclass must be a " + "CreatableFromState (or a subclass)" + ) if not issubclass(state_type, State): # Non-state subclasses do not need adding to the class mapping, as they should not # be created from States @@ -107,11 +120,8 @@ def __init_subclass__(cls, **kwargs): @classmethod def from_state( - cls, - state: State, - *args: Any, - target_type: Optional[type] = None, - **kwargs: Any) -> 'State': + cls, state: State, *args: Any, target_type: Optional[type] = None, **kwargs: Any + ) -> "State": """ Return new object instance of suitable type from an existing `state`. The type returned can be explicitly specified using the `target_type` argument, otherwise @@ -157,27 +167,36 @@ def from_state( if isinstance(state, StateMutableSequence): state = state.state try: - state_type = next(type_ for type_ in type(state).mro() - if type_ in CreatableFromState.class_mapping[cls]) + state_type = next( + type_ + for type_ in type(state).mro() + if type_ in CreatableFromState.class_mapping[cls] + ) except StopIteration: - raise TypeError(f'{cls.__name__} type not defined for {type(state).__name__}') + raise TypeError( + f"{cls.__name__} type not defined for {type(state).__name__}" + ) if target_type is None: target_type = CreatableFromState.class_mapping[cls][state_type] return target_type.from_state(state, *args, **kwargs, target_type=target_type) - - + + class PointMassState(State): - - state_vector: StateVectors = Property(doc='State vectors.') - weight: MutableSequence[Probability] = Property(default=None, doc='Masses of grid points') - grid_delta: np.ndarray = Property(default=None, doc='Grid step per dim') - grid_dim: np.ndarray = Property(default=None, doc='Grid coordinates per dimension before rotation and translation') - center: np.ndarray = Property(default=None, doc='Center of the grid') - eigVec: np.ndarray = Property(default=None, doc='Eigenvectors of the grid') - Npa: np.ndarray = Property(default=None, doc='Points per dim') - - + + state_vector: StateVectors = Property(doc="State vectors.") + weight: MutableSequence[Probability] = Property( + default=None, doc="Masses of grid points" + ) + grid_delta: np.ndarray = Property(default=None, doc="Grid step per dim") + grid_dim: np.ndarray = Property( + default=None, + doc="Grid coordinates per dimension before rotation and translation", + ) + center: np.ndarray = Property(default=None, doc="Center of the grid") + eigVec: np.ndarray = 
Property(default=None, doc="Eigenvectors of the grid") + Npa: np.ndarray = Property(default=None, doc="Points per dim") + def __len__(self): return self.state_vector.shape[1] @@ -186,23 +205,22 @@ def ndim(self): """The number of dimensions represented by the state.""" return self.state_vector.shape[0] - @clearable_cached_property('state_vector') + @clearable_cached_property("state_vector") def mean(self): """Sample mean for particles""" return np.hstack(self.state_vector @ self.weight * np.prod(self.grid_delta)) - - #@profile - def covar(self): - # Measurement update covariance + # @profile + def covar(self): + # Measurement update covariance chip_ = self.state_vector - self.mean[:, np.newaxis] - chip_w = chip_ * self.weight.reshape(1, -1, order='C') - measVar = (chip_w @ chip_.T) * np.prod(self.grid_delta) + chip_w = chip_ * self.weight.reshape(1, -1, order="C") + measVar = (chip_w @ chip_.T) * np.prod(self.grid_delta) return measVar + State.register(PointMassState) # noqa: E305 - - + class ASDState(Type): """ASD State type @@ -210,15 +228,15 @@ class ASDState(Type): For the use of Accumulated State Densities. """ - multi_state_vector: StateVector = Property( - doc="State vector of all timestamps") + multi_state_vector: StateVector = Property(doc="State vector of all timestamps") timestamps: Sequence[datetime.datetime] = Property( - doc="List of all timestamps which have a state in the ASDState") + doc="List of all timestamps which have a state in the ASDState" + ) max_nstep: int = Property( - doc="Decides when the state is pruned in a prediction step. If 0 then there is no pruning") + doc="Decides when the state is pruned in a prediction step. If 0 then there is no pruning" + ) - def __init__(self, multi_state_vector, timestamps, - max_nstep=0, *args, **kwargs): + def __init__(self, multi_state_vector, timestamps, max_nstep=0, *args, **kwargs): if multi_state_vector is not None and timestamps is not None: multi_state_vector = StateVector(multi_state_vector) if not isinstance(timestamps, Sequence): @@ -230,18 +248,18 @@ def __getitem__(self, item): if isinstance(item, Integral): ndim = self.ndim start = item * ndim - end = None if item == -1 else (item+1) * ndim + end = None if item == -1 else (item + 1) * ndim state_slice = slice(start, end) state_vector = StateVector(self.multi_state_vector[state_slice]) timestamp = self.timestamps[item] return State(state_vector=state_vector, timestamp=timestamp) else: - raise TypeError(f'{type(self).__name__!r} only subscriptable by int') + raise TypeError(f"{type(self).__name__!r} only subscriptable by int") @property def state_vector(self): """The State vector of the newest timestamp""" - return self.multi_state_vector[0:self.ndim] + return self.multi_state_vector[0: self.ndim] @property def timestamp(self): @@ -258,12 +276,12 @@ def nstep(self): """Number of timesteps which are in the ASDState""" return len(self.timestamps) - @clearable_cached_property('multi_state_vector', 'timestamps') + @clearable_cached_property("multi_state_vector", "timestamps") def state(self): """A :class:`~.State` object representing latest timestamp""" return self[0] - @clearable_cached_property('multi_state_vector', 'timestamps') + @clearable_cached_property("multi_state_vector", "timestamps") def states(self): return [self[i] for i in range(self.nstep)] @@ -299,7 +317,8 @@ class StateMutableSequence(Type, abc.MutableSequence): states: MutableSequence[State] = Property( default=None, - doc="The initial list of states. 
Default `None` which initialises with empty list.") + doc="The initial list of states. Default `None` which initialises with empty list.", + ) def __init__(self, states=None, *args, **kwargs): if states is None: @@ -320,8 +339,9 @@ def __delitem__(self, index): def __getitem__(self, index): if isinstance(index, slice) and ( - isinstance(index.start, datetime.datetime) - or isinstance(index.stop, datetime.datetime)): + isinstance(index.start, datetime.datetime) + or isinstance(index.stop, datetime.datetime) + ): items = [] for state in self.states: try: @@ -331,16 +351,17 @@ def __getitem__(self, index): continue except TypeError as exc: raise TypeError( - 'both indices must be `datetime.datetime` objects for' - 'time slice') from exc + "both indices must be `datetime.datetime` objects for" + "time slice" + ) from exc items.append(state) - return StateMutableSequence(items[::index.step]) + return StateMutableSequence(items[:: index.step]) elif isinstance(index, datetime.datetime): for state in reversed(self.states): if state.timestamp == index: return state else: - raise IndexError('timestamp not found in states') + raise IndexError("timestamp not found in states") elif isinstance(index, slice): return StateMutableSequence(self.states.__getitem__(index)) else: @@ -365,7 +386,7 @@ def __getattribute__(self, name): else: # For non _ attributes, try to get the attribute from self.state instead of self. try: - my_state = Type.__getattribute__(self, 'state') + my_state = Type.__getattribute__(self, "state") return getattr(my_state, name) except AttributeError: # If we get the error about 'State' not having the attribute, then we want to @@ -415,7 +436,8 @@ class GaussianState(State): This is a simple Gaussian state object, which, as the name suggests, is described by a Gaussian state distribution. """ - covar: CovarianceMatrix = Property(doc='Covariance matrix of state.') + + covar: CovarianceMatrix = Property(doc="Covariance matrix of state.") def __init__(self, state_vector, covar, *args, **kwargs): # Don't cast away subtype of covar if not necessary @@ -423,8 +445,7 @@ def __init__(self, state_vector, covar, *args, **kwargs): covar = CovarianceMatrix(covar) super().__init__(state_vector, covar, *args, **kwargs) if self.state_vector.shape[0] != self.covar.shape[0]: - raise ValueError( - "state vector and covariance should have same dimensions") + raise ValueError("state vector and covariance should have same dimensions") @property def mean(self): @@ -440,8 +461,10 @@ class SqrtGaussianState(State): taste. No checks are undertaken to ensure that a sensible square root form has been chosen. """ - sqrt_covar: CovarianceMatrix = Property(doc="A square root form of the Gaussian covariance " - "matrix.") + + sqrt_covar: CovarianceMatrix = Property( + doc="A square root form of the Gaussian covariance " "matrix." + ) def __init__(self, state_vector, sqrt_covar, *args, **kwargs): sqrt_covar = CovarianceMatrix(sqrt_covar) @@ -452,10 +475,12 @@ def mean(self): """The state mean, equivalent to state vector""" return self.state_vector - @clearable_cached_property('sqrt_covar') + @clearable_cached_property("sqrt_covar") def covar(self): """The full covariance matrix.""" return self.sqrt_covar @ self.sqrt_covar.T + + GaussianState.register(SqrtGaussianState) # noqa: E305 @@ -468,22 +493,21 @@ class InformationState(State): covariance, respectively, of a Gaussian state. 
""" - precision: PrecisionMatrix = Property(doc='precision matrix of state.') - @clearable_cached_property('state_vector', 'precision') + precision: PrecisionMatrix = Property(doc="precision matrix of state.") + + @clearable_cached_property("state_vector", "precision") def gaussian_state(self): """The Gaussian state.""" - return GaussianState(self.mean, - self.covar, - self.timestamp) + return GaussianState(self.mean, self.covar, self.timestamp) - @clearable_cached_property('precision') + @clearable_cached_property("precision") def covar(self): """Covariance matrix, inverse of :attr:`precision` matrix.""" return np.linalg.inv(self.precision) - @clearable_cached_property('state_vector', 'precision') + @clearable_cached_property("state_vector", "precision") def mean(self): """Equivalent Gaussian mean""" return self.covar @ self.state_vector @@ -515,7 +539,8 @@ def from_gaussian_state(cls, gaussian_state, *args, **kwargs): state_vector=state_vector, precision=precision, timestamp=timestamp, - *args, **kwargs + *args, + **kwargs, ) @@ -525,12 +550,14 @@ class ASDGaussianState(ASDState): This is a simple Accumulated State Density Gaussian state object, which as the name suggests is described by a Gaussian state distribution. """ + multi_covar: CovarianceMatrix = Property(doc="Covariance of all timesteps") correlation_matrices: MutableSequence[MutableMapping[str, np.ndarray]] = Property( default=None, doc="Sequence of Correlation Matrices, consisting of :math:`P_{l|l}`, :math:`P_{l|l+1}` " - "and :math:`F_{l+1|l}` built in the Kalman predictor and Kalman updater, aligned to " - ":attr:`timestamps`") + "and :math:`F_{l+1|l}` built in the Kalman predictor and Kalman updater, aligned to " + ":attr:`timestamps`", + ) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -541,30 +568,32 @@ def __getitem__(self, item): if isinstance(item, Integral): ndim = self.ndim start = item * ndim - end = None if item == -1 else (item+1) * ndim + end = None if item == -1 else (item + 1) * ndim state_slice = slice(start, end) state_vector = StateVector(self.multi_state_vector[state_slice]) covar = CovarianceMatrix(self.multi_covar[state_slice, state_slice]) timestamp = self.timestamps[item] - return GaussianState(state_vector=state_vector, covar=covar, timestamp=timestamp) + return GaussianState( + state_vector=state_vector, covar=covar, timestamp=timestamp + ) else: - raise TypeError(f'{type(self).__name__!r} only subscriptable by int') + raise TypeError(f"{type(self).__name__!r} only subscriptable by int") @property def covar(self): - return self.multi_covar[:self.ndim, :self.ndim] + return self.multi_covar[: self.ndim, : self.ndim] @property def mean(self): """The state mean, equivalent to state vector""" return self.state_vector - @clearable_cached_property('multi_state_vector', 'multi_covar', 'timestamps') + @clearable_cached_property("multi_state_vector", "multi_covar", "timestamps") def state(self): """A :class:`~.GaussianState` object representing latest timestamp""" return super().state - @clearable_cached_property('multi_state_vector', 'multi_covar', 'timestamps') + @clearable_cached_property("multi_state_vector", "multi_covar", "timestamps") def states(self): return super().states @@ -575,14 +604,13 @@ class WeightedGaussianState(GaussianState): Gaussian State object with an associated weight. Used as components for a GaussianMixtureState. 
""" + weight: Probability = Property(default=0, doc="Weight of the Gaussian State.") - @clearable_cached_property('state_vector', 'covar') + @clearable_cached_property("state_vector", "covar") def gaussian_state(self): """The Gaussian state.""" - return GaussianState(self.state_vector, - self.covar, - timestamp=self.timestamp) + return GaussianState(self.state_vector, self.covar, timestamp=self.timestamp) @classmethod def from_gaussian_state(cls, gaussian_state, *args, copy=True, **kwargs): @@ -613,10 +641,7 @@ def from_gaussian_state(cls, gaussian_state, *args, copy=True, **kwargs): state_vector = state_vector.copy() covar = covar.copy() return cls( - state_vector=state_vector, - covar=covar, - timestamp=timestamp, - *args, **kwargs + state_vector=state_vector, covar=covar, timestamp=timestamp, *args, **kwargs ) @@ -626,10 +651,11 @@ class TaggedWeightedGaussianState(WeightedGaussianState): Gaussian State object with an associated weight and tag. Used as components for a GaussianMixtureState. """ + tag: str = Property(default=None, doc="Unique tag of the Gaussian State.") - BIRTH = 'birth' - '''Tag value used to signify birth component''' + BIRTH = "birth" + """Tag value used to signify birth component""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -643,6 +669,7 @@ class ASDWeightedGaussianState(ASDGaussianState): ASD Gaussian State object with an associated weight. Used as components for a GaussianMixtureState. """ + weight: Probability = Property(default=0, doc="Weight of the Gaussian State.") @@ -653,24 +680,34 @@ class ParticleState(State): distribution of particles """ - state_vector: StateVectors = Property(doc='State vectors.') - weight: MutableSequence[Probability] = Property(default=None, doc='Weights of particles') - log_weight: np.ndarray = Property(default=None, doc='Log weights of particles') - parent: 'ParticleState' = Property(default=None, doc='Parent particles') - particle_list: MutableSequence[Particle] = Property(default=None, - doc='List of Particle objects') - fixed_covar: CovarianceMatrix = Property(default=None, - doc='Fixed covariance value. Default `None`, where' - 'weighted sample covariance is then used.') + state_vector: StateVectors = Property(doc="State vectors.") + weight: MutableSequence[Probability] = Property( + default=None, doc="Weights of particles" + ) + log_weight: np.ndarray = Property(default=None, doc="Log weights of particles") + parent: "ParticleState" = Property(default=None, doc="Parent particles") + particle_list: MutableSequence[Particle] = Property( + default=None, doc="List of Particle objects" + ) + fixed_covar: CovarianceMatrix = Property( + default=None, + doc="Fixed covariance value. 
Default `None`, where" + "weighted sample covariance is then used.", + ) def __init__(self, *args, **kwargs): weight = next( - (val for name, val in zip(type(self).properties, args) if name == 'weight'), - kwargs.get('weight', None)) + (val for name, val in zip(type(self).properties, args) if name == "weight"), + kwargs.get("weight", None), + ) log_weight, idx = next( - ((val, idx) for idx, (name, val) in enumerate(zip(type(self).properties, args)) - if name == 'log_weight'), - (kwargs.get('log_weight', None), None)) + ( + (val, idx) + for idx, (name, val) in enumerate(zip(type(self).properties, args)) + if name == "log_weight" + ), + (kwargs.get("log_weight", None), None), + ) if weight is not None and log_weight is not None: raise ValueError("Cannot provide both weight and log weight") @@ -679,31 +716,40 @@ def __init__(self, *args, **kwargs): if idx is not None: args[idx] = log_weight else: - kwargs['log_weight'] = log_weight + kwargs["log_weight"] = log_weight super().__init__(*args, **kwargs) - if (self.particle_list is not None) and \ - (self.state_vector is not None or self.weight is not None): - raise ValueError("Use either a list of Particle objects or StateVectors and weights," - " but not both.") + if (self.particle_list is not None) and ( + self.state_vector is not None or self.weight is not None + ): + raise ValueError( + "Use either a list of Particle objects or StateVectors and weights," + " but not both." + ) if self.particle_list and isinstance(self.particle_list, list): - self.state_vector = \ - StateVectors([particle.state_vector for particle in self.particle_list]) - self.weight = \ - np.array([Probability(particle.weight) for particle in self.particle_list]) + self.state_vector = StateVectors( + [particle.state_vector for particle in self.particle_list] + ) + self.weight = np.array( + [Probability(particle.weight) for particle in self.particle_list] + ) parent_list = [particle.parent for particle in self.particle_list] if parent_list.count(None) == 0: self.parent = ParticleState(None, particle_list=parent_list) elif 0 < parent_list.count(None) < len(parent_list): - raise ValueError("Either all particles should have" - " parents or none of them should.") + raise ValueError( + "Either all particles should have" + " parents or none of them should." 
+ ) if self.parent: self.parent.parent = None # Removed to avoid using significant memory - if self.state_vector is not None and not isinstance(self.state_vector, StateVectors): + if self.state_vector is not None and not isinstance( + self.state_vector, StateVectors + ): self.state_vector = StateVectors(self.state_vector) def __getitem__(self, item): @@ -718,40 +764,57 @@ def __getitem__(self, item): log_weight = None if isinstance(item, int): - result = Particle(state_vector=self.state_vector[:, item], - weight=self.weight[item] if self.weight is not None else None, - parent=parent) + result = Particle( + state_vector=self.state_vector[:, item], + weight=self.weight[item] if self.weight is not None else None, + parent=parent, + ) else: # Allow for Prediction/Update sub-types - result = type(self).from_state(self, - state_vector=self.state_vector[:, item], - log_weight=log_weight, - parent=parent) + result = type(self).from_state( + self, + state_vector=self.state_vector[:, item], + log_weight=log_weight, + parent=parent, + ) return result @classmethod - def from_state(cls, state: 'State', *args: Any, target_type: Optional[typing.Type] = None, - **kwargs: Any) -> 'State': + def from_state( + cls, + state: "State", + *args: Any, + target_type: Optional[typing.Type] = None, + **kwargs: Any, + ) -> "State": # Handle default presence of both particle_list and weight once class has been created by # ignoring particle_list and weight (setting to None) if not provided. particle_list, particle_list_idx = next( - ((val, idx) for idx, (name, val) in enumerate(zip(cls.properties, args)) - if name == 'particle_list'), - (kwargs.get('particle_list', None), None)) + ( + (val, idx) + for idx, (name, val) in enumerate(zip(cls.properties, args)) + if name == "particle_list" + ), + (kwargs.get("particle_list", None), None), + ) if particle_list_idx is None: - kwargs['particle_list'] = particle_list + kwargs["particle_list"] = particle_list weight, weight_idx = next( - ((val, idx) for idx, (name, val) in enumerate(zip(cls.properties, args)) - if name == 'weight'), - (kwargs.get('weight', None), None)) + ( + (val, idx) + for idx, (name, val) in enumerate(zip(cls.properties, args)) + if name == "weight" + ), + (kwargs.get("weight", None), None), + ) if weight_idx is None: - kwargs['weight'] = weight + kwargs["weight"] = weight return super().from_state(state, *args, target_type=target_type, **kwargs) - @clearable_cached_property('state_vector', 'log_weight') + @clearable_cached_property("state_vector", "log_weight") def particles(self): """Sequence of individual :class:`~.Particle` objects.""" if self.particle_list is not None: @@ -766,14 +829,14 @@ def ndim(self): """The number of dimensions represented by the state.""" return self.state_vector.shape[0] - @clearable_cached_property('state_vector', 'log_weight') + @clearable_cached_property("state_vector", "log_weight") def mean(self): """Sample mean for particles""" if len(self) == 1: # No need to calculate mean return self.state_vector return np.average(self.state_vector, axis=1, weights=np.exp(self.log_weight)) - @clearable_cached_property('state_vector', 'log_weight', 'fixed_covar') + @clearable_cached_property("state_vector", "log_weight", "fixed_covar") def covar(self): """Sample covariance matrix for particles""" if self.fixed_covar is not None: @@ -786,22 +849,23 @@ def weight(self, value): self.log_weight = None else: self.log_weight = np.log(np.asarray(value, dtype=np.float64)) - self.__dict__['weight'] = np.asanyarray(value) + 
self.__dict__["weight"] = np.asanyarray(value) @weight.getter def weight(self): try: - return self.__dict__['weight'] + return self.__dict__["weight"] except KeyError: log_weight = self.log_weight if log_weight is None: return None weight = Probability.from_log_ufunc(log_weight) - self.__dict__['weight'] = weight + self.__dict__["weight"] = weight return weight + State.register(ParticleState) # noqa: E305 -ParticleState.log_weight._clear_cached.add('weight') +ParticleState.log_weight._clear_cached.add("weight") class MultiModelParticleState(ParticleState): @@ -814,13 +878,15 @@ class MultiModelParticleState(ParticleState): dynamic_model: np.ndarray = Property( default=None, - doc="Array of indices that identify which model is associated with each particle.") + doc="Array of indices that identify which model is associated with each particle.", + ) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if self.particle_list and isinstance(self.particle_list, list): - self.dynamic_model = \ - np.array([particle.dynamic_model for particle in self.particle_list]) + self.dynamic_model = np.array( + [particle.dynamic_model for particle in self.particle_list] + ) def __getitem__(self, item): if self.parent is not None: @@ -843,14 +909,17 @@ def __getitem__(self, item): state_vector=self.state_vector[:, item], weight=self.weight[item] if self.weight is not None else None, parent=parent, - dynamic_model=dynamic_model) + dynamic_model=dynamic_model, + ) else: # Allow for Prediction/Update sub-types - result = type(self).from_state(self, - state_vector=self.state_vector[:, item], - log_weight=log_weight, - parent=parent, - dynamic_model=dynamic_model) + result = type(self).from_state( + self, + state_vector=self.state_vector[:, item], + log_weight=log_weight, + parent=parent, + dynamic_model=dynamic_model, + ) return result @@ -859,14 +928,15 @@ class RaoBlackwellisedParticleState(ParticleState): model_probabilities: np.ndarray = Property( default=None, doc="2d NumPy array containing probability of particle belong to particular model. " - "Shape (n-models, m-particles)." 
+ "Shape (n-models, m-particles).", ) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if self.particle_list and isinstance(self.particle_list, list): - self.model_probabilities = \ - np.column_stack([particle.model_probabilities for particle in self.particle_list]) + self.model_probabilities = np.column_stack( + [particle.model_probabilities for particle in self.particle_list] + ) def __getitem__(self, item): if self.parent is not None: @@ -889,14 +959,17 @@ def __getitem__(self, item): state_vector=self.state_vector[:, item], weight=self.weight[item] if self.weight is not None else None, parent=parent, - model_probabilities=model_probabilities) + model_probabilities=model_probabilities, + ) else: # Allow for Prediction/Update sub-types - result = type(self).from_state(self, - state_vector=self.state_vector[:, item], - log_weight=log_weight, - parent=parent, - model_probabilities=model_probabilities) + result = type(self).from_state( + self, + state_vector=self.state_vector[:, item], + log_weight=log_weight, + parent=parent, + model_probabilities=model_probabilities, + ) return result @@ -916,8 +989,7 @@ class BernoulliParticleState(ParticleState): """ existence_probability: Probability = Property( - default=None, - doc="Target existence probability estimate" + default=None, doc="Target existence probability estimate" ) def __init__(self, *args, **kwargs): @@ -940,9 +1012,9 @@ def __getitem__(self, item): existence_probability = None if isinstance(item, int): - result = Particle(state_vector=self.state_vector[:, item], - weight=weight, - parent=parent) + result = Particle( + state_vector=self.state_vector[:, item], weight=weight, parent=parent + ) else: # Allow for Prediction/Update sub-types result = type(self).from_state( @@ -950,7 +1022,8 @@ def __getitem__(self, item): state_vector=self.state_vector[:, item], parent=parent, particle_list=None, - existence_probability=existence_probability) + existence_probability=existence_probability, + ) return result @@ -970,10 +1043,12 @@ class EnsembleState(State): """ - state_vector: StateVectors = Property(doc="An ensemble of state vectors which represent the " - "state") + state_vector: StateVectors = Property( + doc="An ensemble of state vectors which represent the " "state" + ) timestamp: datetime.datetime = Property( - default=None, doc="Timestamp of the state. Default None.") + default=None, doc="Timestamp of the state. Default None." 
+ ) @classmethod def from_gaussian_state(cls, gaussian_state, num_vectors, **kwargs): @@ -996,9 +1071,11 @@ def from_gaussian_state(cls, gaussian_state, num_vectors, **kwargs): covar = gaussian_state.covar timestamp = gaussian_state.timestamp - return cls(state_vector=cls.generate_ensemble(mean, covar, num_vectors), - timestamp=timestamp, - **kwargs) + return cls( + state_vector=cls.generate_ensemble(mean, covar, num_vectors), + timestamp=timestamp, + **kwargs, + ) @staticmethod def generate_ensemble(mean, covar, num_vectors): @@ -1026,7 +1103,8 @@ def generate_ensemble(mean, covar, num_vectors): mean = StateVector(mean) ndim = mean.shape[0] vectors = np.atleast_2d( - multivariate_normal.rvs(np.zeros(ndim), covar, num_vectors)) + multivariate_normal.rvs(np.zeros(ndim), covar, num_vectors) + ) if ndim > 1: vectors = vectors.T @@ -1037,22 +1115,23 @@ def num_vectors(self): """Number of columns in state ensemble""" return np.shape(self.state_vector)[1] - @clearable_cached_property('state_vector') + @clearable_cached_property("state_vector") def mean(self): """The state mean, numerically equivalent to state vector""" return np.average(self.state_vector, axis=1) - @clearable_cached_property('state_vector') + @clearable_cached_property("state_vector") def covar(self): """Sample covariance matrix for ensemble""" return np.cov(self.state_vector) - @clearable_cached_property('state_vector') + @clearable_cached_property("state_vector") def sqrt_covar(self): """sqrt of sample covariance matrix for ensemble, useful for some EnKF algorithms""" - return ((self.state_vector-np.tile(self.mean, self.num_vectors)) - / np.sqrt(self.num_vectors - 1)) + return (self.state_vector - np.tile(self.mean, self.num_vectors)) / np.sqrt( + self.num_vectors - 1 + ) class CategoricalState(State): @@ -1063,13 +1142,16 @@ class CategoricalState(State): of discrete categories :math:`\Phi = \{\phi^m|m\in \mathbf{N}, m\le M\}` for some finite :math:`M`.""" - categories: Sequence[float] = Property(doc="Category names. Defaults to a list of integers.", - default=None) + categories: Sequence[float] = Property( + doc="Category names. Defaults to a list of integers.", default=None + ) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.state_vector = self.state_vector / np.sum(self.state_vector) # normalise state vector + self.state_vector = self.state_vector / np.sum( + self.state_vector + ) # normalise state vector if self.categories is None: self.categories = list(map(str, range(self.ndim))) @@ -1080,9 +1162,11 @@ def __init__(self, *args, **kwargs): ) def __str__(self): - strings = [f"P({category}) = {p}" - for category, p in zip(self.categories, self.state_vector)] - string = ',\n'.join(strings) + strings = [ + f"P({category}) = {p}" + for category, p in zip(self.categories, self.state_vector) + ] + string = ",\n".join(strings) return string @property @@ -1100,11 +1184,13 @@ class CompositeState(Type): sub_states: Sequence[State] = Property( doc="Sequence of sub-states comprising the composite state. All sub-states must have " - "matching timestamp. Must not be empty.") + "matching timestamp. Must not be empty." + ) default_timestamp: datetime.datetime = Property( default=None, doc="Default timestamp if no sub-states exist to attain timestamp from. 
Defaults to " - "`None`, whereby sub-states will be required to have timestamps.") + "`None`, whereby sub-states will be required to have timestamps.", + ) def __init__(self, *args, **kwargs): @@ -1125,17 +1211,22 @@ def _check_timestamp(self): self._timestamp = None - sub_timestamps = {sub_state.timestamp - for sub_state in self.sub_states - if sub_state.timestamp} + sub_timestamps = { + sub_state.timestamp for sub_state in self.sub_states if sub_state.timestamp + } if len(sub_timestamps) > 1: raise ValueError("All sub-states must share the same timestamp if defined") - if (sub_timestamps and self.default_timestamp - and not sub_timestamps == {self.default_timestamp}): - raise ValueError("Sub-state timestamps and default timestamp must be the same if " - "defined") + if ( + sub_timestamps + and self.default_timestamp + and not sub_timestamps == {self.default_timestamp} + ): + raise ValueError( + "Sub-state timestamps and default timestamp must be the same if " + "defined" + ) if sub_timestamps: self.default_timestamp = sub_timestamps.pop() diff --git a/stonesoup/updater/pointMass.py b/stonesoup/updater/pointMass.py index c218ba392..6ef1cb371 100644 --- a/stonesoup/updater/pointMass.py +++ b/stonesoup/updater/pointMass.py @@ -1,26 +1,16 @@ -import copy from functools import lru_cache from typing import Callable -import warnings - import numpy as np -from scipy.linalg import inv -from scipy.special import logsumexp - -from .base import Updater -from .kalman import KalmanUpdater, ExtendedKalmanUpdater +from scipy.stats import multivariate_normal +from stonesoup.types.state import PointMassState from ..base import Property -from ..functions import cholesky_eps, sde_euler_maruyama_integration -from ..predictor.particle import MultiModelPredictor, RaoBlackwellisedMultiModelPredictor -from ..resampler import Resampler from ..regulariser import Regulariser +from ..resampler import Resampler from ..types.prediction import ( - Prediction, ParticleMeasurementPrediction, GaussianStatePrediction, MeasurementPrediction) -from ..types.update import ParticleStateUpdate, Update -from scipy.stats import multivariate_normal -from stonesoup.types.state import PointMassState - -import matplotlib.pyplot as plt + MeasurementPrediction, +) +from ..types.update import Update +from .base import Updater class PointMassUpdater(Updater): @@ -33,33 +23,35 @@ class PointMassUpdater(Updater): called every time, and it's up to the resampler to decide if resampling is required). """ - sFactor: float = Property( - default=3, - doc="How many sigma to cover by the grid") - resampler: Resampler = Property(default=None, doc='Resampler to prevent particle degeneracy') + + sFactor: float = Property(default=3, doc="How many sigma to cover by the grid") + resampler: Resampler = Property( + default=None, doc="Resampler to prevent particle degeneracy" + ) regulariser: Regulariser = Property( default=None, - doc='Regulariser to prevent particle impoverishment. The regulariser ' - 'is normally used after resampling. If a :class:`~.Resampler` is defined, ' - 'then regularisation will only take place if the particles have been ' - 'resampled. If the :class:`~.Resampler` is not defined but a ' - ':class:`~.Regulariser` is, then regularisation will be conducted under the ' - 'assumption that the user intends for this to occur.') + doc="Regulariser to prevent particle impoverishment. The regulariser " + "is normally used after resampling. 
If a :class:`~.Resampler` is defined, " + "then regularisation will only take place if the particles have been " + "resampled. If the :class:`~.Resampler` is not defined but a " + ":class:`~.Regulariser` is, then regularisation will be conducted under the " + "assumption that the user intends for this to occur.", + ) constraint_func: Callable = Property( default=None, doc="Callable, user defined function for applying " - "constraints to the states. This is done by setting the weights " - "of particles to 0 for particles that are not correctly constrained. " - "This function provides indices of the unconstrained particles and " - "should accept a :class:`~.ParticleState` object and return an array-like " - "object of logical indices. " + "constraints to the states. This is done by setting the weights " + "of particles to 0 for particles that are not correctly constrained. " + "This function provides indices of the unconstrained particles and " + "should accept a :class:`~.ParticleState` object and return an array-like " + "object of logical indices. ", ) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - #@profile + # @profile def update(self, hypothesis, **kwargs): - """Particle Filter update step + """Point mass filter update step @@ -73,43 +65,48 @@ def update(self, hypothesis, **kwargs): ------- - : :class:`~.ParticleState` + : :class:`~.PointMassState` The state posterior - """ + """ predicted_state = Update.from_state( state=hypothesis.prediction, hypothesis=hypothesis, - timestamp=hypothesis.prediction.timestamp + timestamp=hypothesis.prediction.timestamp, ) if hypothesis.measurement.measurement_model is None: measurement_model = self.measurement_model else: measurement_model = hypothesis.measurement.measurement_model - R = measurement_model.covar() - - x = measurement_model.function(predicted_state) # SLOW LINE 92% guess it is just StoneSoup - pdf_value = multivariate_normal.pdf(x.T,np.ravel(hypothesis.measurement.state_vector), R) + + R = measurement_model.covar() + + x = measurement_model.function( + predicted_state + ) + pdf_value = multivariate_normal.pdf( + x.T, np.ravel(hypothesis.measurement.state_vector), R + ) new_weight = np.ravel(hypothesis.prediction.weight) * np.ravel(pdf_value) - - new_weight = new_weight/(np.prod(hypothesis.prediction.grid_delta)*sum(new_weight)) - - predicted_state = PointMassState(state_vector=hypothesis.prediction.state_vector, - weight=new_weight, - grid_delta = hypothesis.prediction.grid_delta, - grid_dim = hypothesis.prediction.grid_dim, - center = hypothesis.prediction.center, - eigVec = hypothesis.prediction.eigVec, - Npa = hypothesis.prediction.Npa, - timestamp = hypothesis.prediction.timestamp) - # plt.plot(new_weight) + + new_weight = new_weight / ( + np.prod(hypothesis.prediction.grid_delta) * sum(new_weight) + ) + + predicted_state = PointMassState( + state_vector=hypothesis.prediction.state_vector, + weight=new_weight, + grid_delta=hypothesis.prediction.grid_delta, + grid_dim=hypothesis.prediction.grid_dim, + center=hypothesis.prediction.center, + eigVec=hypothesis.prediction.eigVec, + Npa=hypothesis.prediction.Npa, + timestamp=hypothesis.prediction.timestamp, + ) return predicted_state @lru_cache() - def predict_measurement(self, state_prediction, measurement_model=None, - **kwargs): + def predict_measurement(self, state_prediction, measurement_model=None, **kwargs): if measurement_model is None: measurement_model = self.measurement_model @@ -117,6 +114,7 @@ def predict_measurement(self, state_prediction, measurement_model=None, new_state_vector = measurement_model.function(state_prediction, **kwargs) return
MeasurementPrediction.from_state( - state_prediction, state_vector=new_state_vector, timestamp=state_prediction.timestamp) - - + state_prediction, + state_vector=new_state_vector, + timestamp=state_prediction.timestamp, + ) From 13ac8e84fc9047b02feaf20b67cffdd0847f912d Mon Sep 17 00:00:00 2001 From: pesslovany <43780444+pesslovany@users.noreply.github.com> Date: Wed, 26 Jun 2024 14:09:27 +0200 Subject: [PATCH 03/16] Update nonlinear.py Deleted unwanted changes --- stonesoup/models/measurement/nonlinear.py | 85 ----------------------- 1 file changed, 85 deletions(-) diff --git a/stonesoup/models/measurement/nonlinear.py b/stonesoup/models/measurement/nonlinear.py index f129a297b..b72127ee2 100644 --- a/stonesoup/models/measurement/nonlinear.py +++ b/stonesoup/models/measurement/nonlinear.py @@ -20,91 +20,6 @@ from ...types.state import State - -class TerrainAidedNavigation(): - - def __init__(self,interpolator,noise_covar,mapping): - self.interpolator = interpolator - self.noise_covar = noise_covar - self.mapping = mapping - - @property - def ndim_meas(self) -> int: - """ndim_meas getter method - - Returns - ------- - :class:`int` - The number of measurement dimensions - """ - return 1 - - def function(self, state, noise=False, **kwargs) -> StateVector: - - out = self.interpolator(state.state_vector[self.mapping,:].T) - if isinstance(noise, bool) or noise is None: - if noise: - noise = np.random.normal([0], self.noise_covar, out.size) - out = out + noise - - - # Return the interpolated measurements with added noise - return out - - def covar(self, **kwargs) -> CovarianceMatrix: - """Returns the measurement model noise covariance matrix. - - Returns - ------- - :class:`~.CovarianceMatrix` of shape\ - (:py:attr:`~ndim_meas`, :py:attr:`~ndim_meas`) - The measurement noise covariance. - """ - - return self.noise_covar - - def logpdf(self, state1: State, state2: State, **kwargs) -> Union[float, np.ndarray]: - r"""Model log pdf/likelihood evaluation function - - Evaluates the pdf/likelihood of ``state1``, given the state - ``state2`` which is passed to :meth:`function()`. - - In mathematical terms, this can be written as: - - .. math:: - - p = p(y_t | x_t) = \mathcal{N}(y_t; x_t, Q) - - where :math:`y_t` = ``state_vector1``, :math:`x_t` = ``state_vector2`` - and :math:`Q` = :attr:`covar`. - - Parameters - ---------- - state1 : State - state2 : State - - Returns - ------- - : float or :class:`~.numpy.ndarray` - The log likelihood of ``state1``, given ``state2`` - """ - covar = self.covar(**kwargs) - - - # Calculate difference before to handle custom types (mean defaults to zero) - # This is required as log pdf coverts arrays to floats - likelihood = np.atleast_1d( - multivariate_normal.logpdf((state1.state_vector - self.function(state2, **kwargs)).T, - cov=covar)) - - if len(likelihood) == 1: - likelihood = likelihood[0] - - return likelihood - - - - class CombinedReversibleGaussianMeasurementModel(ReversibleModel, GaussianModel, MeasurementModel): r"""Combine multiple models into a single model by stacking them. 
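The pieces introduced so far compose as follows. This is a minimal sketch, not authoritative: it assumes the gridCreation unpack order exercised by test_gridCreation added later in this series, uses a stock LinearGaussian model as a stand-in for the terrain-aided measurement model, and picks arbitrary illustrative numbers.

import datetime

import numpy as np

from stonesoup.functions import gridCreation
from stonesoup.models.measurement.linear import LinearGaussian
from stonesoup.types.array import StateVectors
from stonesoup.types.detection import Detection
from stonesoup.types.hypothesis import SingleHypothesis
from stonesoup.types.prediction import Prediction
from stonesoup.types.state import PointMassState
from stonesoup.updater.pointMass import PointMassUpdater

now = datetime.datetime.now()
Npa = np.array([31, 31])  # grid points per dimension, odd

# Grid of 31x31 points covering sFactor=4 sigma around an illustrative
# 2D position estimate (unpack order as in test_gridCreation)
grid, grid_delta, grid_dim, centre, eig_vec = gridCreation(
    np.array([[36569.0], [55581.0]]), np.diag([90.0, 160.0]), 4, 2, Npa)

# Uniform prior masses, normalised so the density integrates to one
weights = np.ones(grid.shape[1])
weights /= np.sum(weights) * np.prod(grid_delta)

prior = PointMassState(
    state_vector=StateVectors(grid), weight=weights, grid_delta=grid_delta,
    grid_dim=grid_dim, center=centre, eigVec=eig_vec, Npa=Npa, timestamp=now)

# Stand-in linear position measurement; the terrain-aided example instead
# uses an interpolator-based model
updater = PointMassUpdater(
    LinearGaussian(ndim_state=2, mapping=(0, 1), noise_covar=np.eye(2)))

# Prediction.from_state is assumed to resolve to PointMassStatePrediction
# through the CreatableFromState mapping registered in types/prediction.py
prediction = Prediction.from_state(prior, timestamp=now)
measurement = Detection(np.array([[36570.0], [55580.0]]), timestamp=now)
posterior = updater.update(SingleHypothesis(prediction, measurement))
print(posterior.mean, posterior.covar())

The measurement update renormalises the weights so the posterior density again integrates to one over the grid; time prediction between measurements is handled by the predictor in stonesoup/predictor/pointMass.py.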
From 1dc8e5abf9b1d5a27b70bd1ab77b7da1f68ba49a Mon Sep 17 00:00:00 2001 From: pesslovany Date: Wed, 26 Jun 2024 14:13:50 +0200 Subject: [PATCH 04/16] flake8 changes --- stonesoup/functions/__init__.py | 8 ++++---- stonesoup/predictor/pointMass.py | 2 -- stonesoup/types/prediction.py | 6 ++++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/stonesoup/functions/__init__.py b/stonesoup/functions/__init__.py index 1060d4cd2..ef2319fa8 100644 --- a/stonesoup/functions/__init__.py +++ b/stonesoup/functions/__init__.py @@ -209,11 +209,11 @@ def gauss2sigma(state, alpha=1.0, beta=2.0, kappa=None): sigma_points = sigma_points.astype(float) # Can't use in place addition/subtraction as casting issues may arise when mixing float/int - sigma_points[:, 1 : (ndim_state + 1)] = sigma_points[ - :, 1 : (ndim_state + 1) + sigma_points[:, 1: (ndim_state + 1)] = sigma_points[ + :, 1: (ndim_state + 1) ] + sqrt_sigma * np.sqrt(c) - sigma_points[:, (ndim_state + 1) :] = sigma_points[ - :, (ndim_state + 1) : + sigma_points[:, (ndim_state + 1):] = sigma_points[ + :, (ndim_state + 1): ] - sqrt_sigma * np.sqrt(c) # Put these sigma points into s State object list diff --git a/stonesoup/predictor/pointMass.py b/stonesoup/predictor/pointMass.py index d0014371f..b4bdbf142 100644 --- a/stonesoup/predictor/pointMass.py +++ b/stonesoup/predictor/pointMass.py @@ -129,11 +129,9 @@ def predict(self, prior, timestamp=None, **kwargs): np.sum(predDensityProb) * np.prod(GridDelta) ) - xOld = F @ np.vstack(prior.mean) Ppold = F @ eigVect - return PointMassState( state_vector=StateVectors(np.squeeze(predGrid)), weight=abs(predDensityProb), diff --git a/stonesoup/types/prediction.py b/stonesoup/types/prediction.py index f2319ef98..1df6f0180 100644 --- a/stonesoup/types/prediction.py +++ b/stonesoup/types/prediction.py @@ -141,19 +141,21 @@ class ParticleStatePrediction(Prediction, ParticleState): This is a simple Particle state prediction object. """ + class PointMassStatePrediction(Prediction, PointMassState): """PointMassStatePrediction type This is a simple point mass state prediction object. """ + class ParticleMeasurementPrediction(MeasurementPrediction, ParticleState): """MeasurementStatePrediction type This is a simple Particle measurement prediction object.
""" - - + + class PointMassMeasurementPrediction(MeasurementPrediction, PointMassState): """MeasurementStatePrediction type From 1442c8b85c60296332beb0c89d842c33563fa8a4 Mon Sep 17 00:00:00 2001 From: pesslovany <43780444+pesslovany@users.noreply.github.com> Date: Wed, 26 Jun 2024 14:16:08 +0200 Subject: [PATCH 05/16] flake8 --- stonesoup/models/measurement/nonlinear.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/stonesoup/models/measurement/nonlinear.py b/stonesoup/models/measurement/nonlinear.py index b72127ee2..d8048d846 100644 --- a/stonesoup/models/measurement/nonlinear.py +++ b/stonesoup/models/measurement/nonlinear.py @@ -17,7 +17,6 @@ from ...types.angle import Bearing, Elevation, Azimuth from ..base import LinearModel, GaussianModel, ReversibleModel from .base import MeasurementModel -from ...types.state import State class CombinedReversibleGaussianMeasurementModel(ReversibleModel, GaussianModel, MeasurementModel): @@ -280,8 +279,6 @@ def rvs(self, num_samples=1, **kwargs) -> Union[StateVector, StateVectors]: out = super().rvs(num_samples, **kwargs) out = np.array([[Elevation(0.)], [Bearing(0.)], [0.]]) + out return out - - class CartesianToBearingRange(NonLinearGaussianMeasurement, ReversibleModel): @@ -427,7 +424,6 @@ def rvs(self, num_samples=1, **kwargs) -> Union[StateVector, StateVectors]: return out - class CartesianToElevationBearing(NonLinearGaussianMeasurement): r"""This is a class implementation of a time-invariant measurement model, \ where measurements are assumed to be received in the form of bearing \ From 1baf1bf42c3c3c494cfb3ff8b5c3ace7f488d41c Mon Sep 17 00:00:00 2001 From: pesslovany Date: Wed, 26 Jun 2024 18:39:46 +0200 Subject: [PATCH 06/16] added test for functions and pointmass --- stonesoup/functions/tests/test_functions.py | 302 +- stonesoup/plotter.py | 3024 ------------------- stonesoup/types/tests/test_state.py | 461 ++- stonesoup/updater/tests/test_kalman.py | 5 + stonesoup/updater/tests/test_pointmass.py | 135 + 5 files changed, 650 insertions(+), 3277 deletions(-) delete mode 100644 stonesoup/plotter.py create mode 100644 stonesoup/updater/tests/test_pointmass.py diff --git a/stonesoup/functions/tests/test_functions.py b/stonesoup/functions/tests/test_functions.py index 057bb7bde..5ced14ae1 100644 --- a/stonesoup/functions/tests/test_functions.py +++ b/stonesoup/functions/tests/test_functions.py @@ -1,22 +1,60 @@ -import pytest import numpy as np +import pytest from numpy import deg2rad -from scipy.linalg import cholesky, LinAlgError +from numpy import linalg as LA from pytest import approx, raises +from scipy.linalg import LinAlgError, cholesky +from ...types.array import CovarianceMatrix, Matrix, StateVector, StateVectors +from ...types.state import GaussianState, State from .. 
import ( - cholesky_eps, jacobian, gm_reduce_single, mod_bearing, mod_elevation, gauss2sigma, - rotx, roty, rotz, cart2sphere, cart2angles, pol2cart, sphere2cart, dotproduct, gm_sample, - gauss2cubature, cubature2gauss, cubature_transform) -from ...types.array import StateVector, StateVectors, Matrix, CovarianceMatrix -from ...types.state import State, GaussianState + cart2angles, + cart2sphere, + cholesky_eps, + cubature2gauss, + cubature_transform, + dotproduct, + gauss2cubature, + gauss2sigma, + gm_reduce_single, + gm_sample, + gridCreation, + jacobian, + mod_bearing, + mod_elevation, + pol2cart, + rotx, + roty, + rotz, + sphere2cart, +) + + +def test_gridCreation(): + nx = 4 + meanX0 = np.array([36569, 50, 55581, 50]) # mean value + varX0 = np.diag([90, 5, 160, 5]) # variance + Npa = np.array([31, 31, 27, 27]) # must be ODD! + sFactor = 4 # scaling factor (number of sigmas covered by the grid) + + [predGrid, predGridDelta, gridDimOld, xOld, Ppold] = gridCreation( + np.vstack(meanX0), varX0, sFactor, nx, Npa + ) + + mean_diffs = np.array([np.mean(np.diff(sublist)) for sublist in gridDimOld]) + + eigVal, eigVect = LA.eig(varX0) + + assert np.allclose(meanX0, np.mean(predGrid, axis=1), 0, atol=1.0e-1) + assert np.all(meanX0 == xOld.ravel()) + assert np.all(np.argsort(predGridDelta) == np.argsort(np.diag(varX0))) + assert np.allclose(mean_diffs, predGridDelta, 0, atol=1e-10) + assert np.all(eigVect == Ppold) def test_cholesky_eps(): - matrix = np.array([[0.4, -0.2, 0.1], - [0.3, 0.1, -0.2], - [-0.3, 0.0, 0.4]]) - matrix = matrix@matrix.T + matrix = np.array([[0.4, -0.2, 0.1], [0.3, 0.1, -0.2], [-0.3, 0.0, 0.4]]) + matrix = matrix @ matrix.T cholesky_matrix = cholesky(matrix) @@ -26,51 +64,54 @@ def test_cholesky_eps(): def test_cholesky_eps_bad(): matrix = np.array( - [[ 0.05201447, 0.02882126, -0.00569971, -0.00733617], # noqa: E201 - [ 0.02882126, 0.01642966, -0.00862847, -0.00673035], # noqa: E201 - [-0.00569971, -0.00862847, 0.06570757, 0.03251551], - [-0.00733617, -0.00673035, 0.03251551, 0.01648615]]) + [ + [0.05201447, 0.02882126, -0.00569971, -0.00733617], # noqa: E201 + [0.02882126, 0.01642966, -0.00862847, -0.00673035], # noqa: E201 + [-0.00569971, -0.00862847, 0.06570757, 0.03251551], + [-0.00733617, -0.00673035, 0.03251551, 0.01648615], + ] + ) with raises(LinAlgError): cholesky(matrix) cholesky_eps(matrix) def test_jacobian(): - """ jacobian function test """ + """jacobian function test""" # State related variables state_mean = StateVector([[3.0], [1.0]]) def f(x): - return np.array([[1, 1], [0, 1]])@x.state_vector + return np.array([[1, 1], [0, 1]]) @ x.state_vector jac = jacobian(f, State(state_mean)) assert np.allclose(jac, np.array([[1, 1], [0, 1]])) def test_jacobian2(): - """ jacobian function test """ + """jacobian function test""" # Sample functions to compute Jacobian on def fun(x): - """ function for testing scalars i.e. scalar input, scalar output""" - return 2*x.state_vector**2 + """function for testing scalars i.e. 
scalar input, scalar output""" + return 2 * x.state_vector**2 def fun1d(ins): - """ test function with vector input, scalar output""" - out = 2*ins.state_vector[0, :]+3*ins.state_vector[1, :] + """test function with vector input, scalar output""" + out = 2 * ins.state_vector[0, :] + 3 * ins.state_vector[1, :] return np.atleast_2d(out) def fun2d(vec): - """ test function with 2d input and 2d output""" + """test function with 2d input and 2d output""" out = np.empty(vec.state_vector.shape) - out[0, :] = 2*vec.state_vector[0, :]**2 + 3*vec.state_vector[1, :]**2 - out[1, :] = 2*vec.state_vector[0, :]+3*vec.state_vector[1, :] + out[0, :] = 2 * vec.state_vector[0, :] ** 2 + 3 * vec.state_vector[1, :] ** 2 + out[1, :] = 2 * vec.state_vector[0, :] + 3 * vec.state_vector[1, :] return out x = 3 jac = jacobian(fun, State(StateVector([[x]]))) - assert np.allclose(jac, 4*x) + assert np.allclose(jac, 4 * x) x = StateVector([[1], [2]]) # Tolerance value to use to test if arrays are equal tol = 1.0e-5 jac = jacobian(fun1d, State(x)) T = np.array([2.0, 3.0]) - FOM = np.where(np.abs(jac-T) > tol) + FOM = np.where(np.abs(jac - T) > tol) # Check # of array elements bigger than tol assert len(FOM[0]) == 0 jac = jacobian(fun2d, State(x)) - T = np.array([[4.0*x[0], 6*x[1]], - [2, 3]]) + T = np.array([[4.0 * x[0], 6 * x[1]], [2, 3]]) FOM = np.where(np.abs(jac - T) > tol) # Check # of array elements bigger than tol assert len(FOM[0]) == 0 def test_jacobian_param(): """jacobian function test""" # Sample functions to compute Jacobian on def fun(x, value=0.0): - """ function for jabcobian parameter passing""" - return value*x.state_vector + """function for jacobian parameter passing""" + return value * x.state_vector x = 4 value = 2.0 jac = jacobian(fun, State(StateVector([[x]])), value=value) assert np.allclose(value, jac) def test_jacobian_large_values(): # State related variables - state = State(StateVector([[1E10], [1.0]])) + state = State(StateVector([[1e10], [1.0]])) def f(x): return x.state_vector**2 jac = jacobian(f, state) assert np.allclose(jac, np.array([[2e10, 0.0], [0.0, 2.0]])) def test_gm_reduce_single(): - means = StateVectors([StateVector([1, 2]), StateVector([3, 4]), StateVector([5, 6])]) - covars = np.stack([[[1, 1], [1, 0.7]], - [[1.2, 1.4], [1.3, 2]], - [[2, 1.4], [1.2, 1.2]]], axis=2) + means = StateVectors( + [StateVector([1, 2]), StateVector([3, 4]), StateVector([5, 6])] + ) + covars = np.stack( + [[[1, 1], [1, 0.7]], [[1.2, 1.4], [1.3, 2]], [[2, 1.4], [1.2, 1.2]]], axis=2 + ) weights = np.array([1, 2, 5]) mean, covar = gm_reduce_single(means, covars, weights) assert np.allclose(mean, np.array([[4], [5]])) - assert np.allclose(covar, np.array([[3.675, 3.35], - [3.2, 3.3375]])) + assert np.allclose(covar, np.array([[3.675, 3.35], [3.2, 3.3375]])) # Test handling of means as array instead of StateVectors mean, covar = gm_reduce_single(means.view(np.ndarray), covars, weights) assert np.allclose(mean, np.array([[4], [5]])) - assert np.allclose(covar, np.array([[3.675, 3.35], - [3.2, 3.3375]])) + assert np.allclose(covar, np.array([[3.675, 3.35], [3.2, 3.3375]])) def test_bearing(): - bearing_in = [10., 170., 190., 260., 280., 350., 705] + bearing_in = [10.0, 170.0, 190.0, 260.0, 280.0, 350.0, 705] rad_in = deg2rad(bearing_in) - bearing_out = [10., 170., -170., -100., -80., -10., -15.]
+ bearing_out = [10.0, 170.0, -170.0, -100.0, -80.0, -10.0, -15.0] rad_out = deg2rad(bearing_out) for ind, val in enumerate(rad_in): @@ -150,23 +190,17 @@ def test_bearing(): def test_elevation(): - elev_in = [10., 80., 110., 170., 190., 260., 280] + elev_in = [10.0, 80.0, 110.0, 170.0, 190.0, 260.0, 280] rad_in = deg2rad(elev_in) - elev_out = [10., 80., 70., 10., -10., -80., -80.] + elev_out = [10.0, 80.0, 70.0, 10.0, -10.0, -80.0, -80.0] rad_out = deg2rad(elev_out) for ind, val in enumerate(rad_in): assert rad_out[ind] == approx(mod_elevation(val)) -@pytest.mark.parametrize( - "mean", - [ - 1, # int - 1.0 # float - ] -) +@pytest.mark.parametrize("mean", [1, 1.0]) # int # float def test_gauss2sigma(mean): covar = 2.0 state = GaussianState([[mean]], [[covar]]) @@ -174,15 +208,18 @@ def test_gauss2sigma(mean): sigma_points_states, mean_weights, covar_weights = gauss2sigma(state, kappa=0) for n, sigma_point_state in zip((0, 1, -1), sigma_points_states): - assert sigma_point_state.state_vector[0, 0] == approx(mean + n*covar**0.5) + assert sigma_point_state.state_vector[0, 0] == approx(mean + n * covar**0.5) def test_gauss2sigma_bad_covar(): covar = np.array( - [[ 0.05201447, 0.02882126, -0.00569971, -0.00733617], # noqa: E201 - [ 0.02882126, 0.01642966, -0.00862847, -0.00673035], # noqa: E201 - [-0.00569971, -0.00862847, 0.06570757, 0.03251551], - [-0.00733617, -0.00673035, 0.03251551, 0.01648615]]) + [ + [0.05201447, 0.02882126, -0.00569971, -0.00733617], # noqa: E201 + [0.02882126, 0.01642966, -0.00862847, -0.00673035], # noqa: E201 + [-0.00569971, -0.00862847, 0.06570757, 0.03251551], + [-0.00733617, -0.00673035, 0.03251551, 0.01648615], + ] + ) state = GaussianState([[0], [0], [0], [0]], covar) with pytest.warns(UserWarning, match="Matrix is not positive definite"): @@ -202,7 +239,7 @@ def test_gauss2sigma_bad_covar(): np.array([np.pi / 8]), np.array([-np.pi / 8]), ) - ] + ], ) def test_rotations(angle): @@ -210,28 +247,28 @@ def test_rotations(angle): zero = np.zeros_like(angle) one = np.ones_like(angle) - assert np.array_equal(rotx(angle), np.array([[one, zero, zero], - [zero, c, -s], - [zero, s, c]])) - assert np.array_equal(roty(angle), np.array([[c, zero, s], - [zero, one, zero], - [-s, zero, c]])) - assert np.array_equal(rotz(angle), np.array([[c, -s, zero], - [s, c, zero], - [zero, zero, one]])) + assert np.array_equal( + rotx(angle), np.array([[one, zero, zero], [zero, c, -s], [zero, s, c]]) + ) + assert np.array_equal( + roty(angle), np.array([[c, zero, s], [zero, one, zero], [-s, zero, c]]) + ) + assert np.array_equal( + rotz(angle), np.array([[c, -s, zero], [s, c, zero], [zero, zero, one]]) + ) @pytest.mark.parametrize( "x, y, z", [ # Cartesian values - (1., 0., 0.), - (0., 1., 0.), - (0., 0., 1.), - (1., 1., 0.), - (1., 0., 1.), - (0., 1., 1.), - (1., 1., 1.) - ] + (1.0, 0.0, 0.0), + (0.0, 1.0, 0.0), + (0.0, 0.0, 1.0), + (1.0, 1.0, 0.0), + (1.0, 0.0, 1.0), + (0.0, 1.0, 1.0), + (1.0, 1.0, 1.0), + ], ) def test_cart_sphere_inversions(x, y, z): @@ -258,17 +295,21 @@ def test_cart_sphere_inversions(x, y, z): (StateVector([-1, 0]), StateVector([1, -2, 3])), (Matrix([[1, 0], [0, 1]]), Matrix([[3, 1], [1, -3]])), (StateVectors([[1, 0], [0, 1]]), StateVectors([[3, 1], [1, -3]])), - (StateVectors([[1, 0], [0, 1]]), StateVector([3, 1])) - ] + (StateVectors([[1, 0], [0, 1]]), StateVector([3, 1])), + ], ) def test_dotproduct(state_vector1, state_vector2): # Test that they raise the right error if not 1d, i.e. 
vectors - if type(state_vector1) is not type(state_vector2): + if not isinstance(state_vector1, type(state_vector2)): with pytest.raises(ValueError): dotproduct(state_vector1, state_vector2) - elif type(state_vector1) is not StateVectors and type(state_vector2) is not StateVectors and \ - type(state_vector2) is not StateVector and type(state_vector1) is not StateVector: + elif ( + not isinstance(state_vector1, StateVectors) + and not isinstance(state_vector2, StateVectors) + and not isinstance(state_vector2, StateVector) + and not isinstance(state_vector1, StateVector) + ): with pytest.raises(ValueError): dotproduct(state_vector1, state_vector2) else: @@ -282,8 +323,10 @@ def test_dotproduct(state_vector1, state_vector2): for a_i, b_i in zip(state_vector1, state_vector2): out += a_i * b_i - assert np.allclose(dotproduct(state_vector1, state_vector2), - np.reshape(out, np.shape(dotproduct(state_vector1, state_vector2)))) + assert np.allclose( + dotproduct(state_vector1, state_vector2), + np.reshape(out, np.shape(dotproduct(state_vector1, state_vector2))), + ) @pytest.mark.parametrize( @@ -292,48 +335,66 @@ def test_dotproduct(state_vector1, state_vector2): ( [np.array([10, 10]), np.array([20, 20]), np.array([30, 30])], # means [np.eye(2), np.eye(2), np.eye(2)], # covars - np.array([1/3]*3), # weights - 20 # size - ), ( + np.array([1 / 3] * 3), # weights + 20, # size + ), + ( StateVectors(np.array([[20, 30, 40, 50], [20, 30, 40, 50]])), # means [np.eye(2), np.eye(2), np.eye(2), np.eye(2)], # covars - np.array([1/4]*4), # weights - 20 # size - ), ( + np.array([1 / 4] * 4), # weights + 20, # size + ), + ( [np.array([10, 10]), np.array([20, 20]), np.array([30, 30])], # means np.array([np.eye(2), np.eye(2), np.eye(2)]), # covars - np.array([1/3]*3), # weights - 20 # size - ), ( - [StateVector(np.array([10, 10])), StateVector(np.array([20, 20])), - StateVector(np.array([30, 30]))], # means + np.array([1 / 3] * 3), # weights + 20, # size + ), + ( + [ + StateVector(np.array([10, 10])), + StateVector(np.array([20, 20])), + StateVector(np.array([30, 30])), + ], # means [np.eye(2), np.eye(2), np.eye(2)], # covars - np.array([1/3]*3), # weights - 20 # size - ), ( + np.array([1 / 3] * 3), # weights + 20, # size + ), + ( StateVector(np.array([10, 10])), # means [np.eye(2)], # covars np.array([1]), # weights - 20 # size - ), ( + 20, # size + ), + ( np.array([10, 10]), # means [np.eye(2)], # covars np.array([1]), # weights - 20 # size - ), ( + 20, # size + ), + ( [np.array([10, 10]), np.array([20, 20]), np.array([30, 30])], # means [np.eye(2), np.eye(2), np.eye(2)], # covars None, # weights - 20 # size - ), ( + 20, # size + ), + ( StateVectors(np.array([[20, 30, 40, 50], [20, 30, 40, 50]])), # means [np.eye(2), np.eye(2), np.eye(2), np.eye(2)], # covars None, # weights - 20 # size - ) - ], ids=["mean_list", "mean_statevectors", "3d_covar_array", "mean_statevector_list", - "single_statevector_mean", "single_ndarray_mean", "no_weight_mean_list", - "no_weight_mean_statevectors"] + 20, # size + ), + ], + ids=[ + "mean_list", + "mean_statevectors", + "3d_covar_array", + "mean_statevector_list", + "single_statevector_mean", + "single_ndarray_mean", + "no_weight_mean_list", + "no_weight_mean_statevectors", + ], ) def test_gm_sample(means, covars, weights, size): samples = gm_sample(means, covars, size, weights=weights) @@ -352,11 +413,19 @@ def test_gm_sample(means, covars, weights, size): [ (StateVector([0]), CovarianceMatrix([[1]]), None), (StateVector([-7, 5]), CovarianceMatrix([[1.1, -0.04], [-0.04, 1.2]]), 
2.0), - (StateVector([12, -4, 0, 5]), CovarianceMatrix([[0.7, 0.04, -0.02, 0], - [0.04, 1.1, 0.09, 0.06], - [-0.02, 0.09, 0.9, -0.01], - [0, 0.06, -0.01, 1.1]]), 0.7) - ] + ( + StateVector([12, -4, 0, 5]), + CovarianceMatrix( + [ + [0.7, 0.04, -0.02, 0], + [0.04, 1.1, 0.09, 0.06], + [-0.02, 0.09, 0.9, -0.01], + [0, 0.06, -0.01, 1.1], + ] + ), + 0.7, + ), + ], ) def test_cubature_transform(mean, covar, alp): @@ -369,12 +438,15 @@ def identity_function(inpu): if alp is None: cub_pts = gauss2cubature(instate) outsv, outcovar = cubature2gauss(cub_pts) - mean, covar, cross_covar, cubature_points = cubature_transform(instate, identity_function) + mean, covar, cross_covar, cubature_points = cubature_transform( + instate, identity_function + ) else: cub_pts = gauss2cubature(instate, alpha=alp) outsv, outcovar = cubature2gauss(cub_pts, alpha=alp) - mean, covar, cross_covar, cubature_points = cubature_transform(instate, identity_function, - alpha=alp) + mean, covar, cross_covar, cubature_points = cubature_transform( + instate, identity_function, alpha=alp + ) assert np.allclose(outsv, instate.state_vector) assert np.allclose(outcovar, instate.covar) diff --git a/stonesoup/plotter.py b/stonesoup/plotter.py deleted file mode 100644 index 528597cce..000000000 --- a/stonesoup/plotter.py +++ /dev/null @@ -1,3024 +0,0 @@ -import warnings -from abc import ABC, abstractmethod -from datetime import datetime, timedelta -from enum import IntEnum -from itertools import chain -from typing import Collection, Iterable, Union, List, Optional, Tuple, Dict - -import numpy as np -from matplotlib import animation as animation -from matplotlib import pyplot as plt -from matplotlib.legend_handler import HandlerPatch -from matplotlib.lines import Line2D -from matplotlib.patches import Ellipse -from mergedeep import merge -from scipy.integrate import quad -from scipy.optimize import brentq -from scipy.stats import kde -try: - from plotly import colors -except ImportError: - colors = None -try: - import plotly.graph_objects as go -except ImportError: - go = None - -from .base import Base, Property -from .models.base import LinearModel, Model -from .types import detection -from .types.array import StateVector -from .types.groundtruth import GroundTruthPath -from .types.metric import SingleTimeMetric -from .types.state import State, StateMutableSequence -from .types.update import Update - - -class Dimension(IntEnum): - """Dimension Enum class for specifying plotting parameters in the Plotter class. - Used to sanitize inputs for the dimension attribute of Plotter(). 
- - Attributes - ---------- - TWO: int - Specifies 2D plotting for Plotter object - THREE: int - Specifies 3D plotting for Plotter object - """ - ONE = 1 # 1D plotting mode (plot state over time in Plotterly) - TWO = 2 # 2D plotting mode (original plotter.py functionality) - THREE = 3 # 3D plotting mode - - -class _Plotter(ABC): - - @abstractmethod - def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwargs): - raise NotImplementedError - - @abstractmethod - def plot_measurements(self, measurements, mapping, measurement_model=None, - measurements_label="Measurements", **kwargs): - raise NotImplementedError - - @abstractmethod - def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_label="Tracks", - **kwargs): - raise NotImplementedError - - @abstractmethod - def plot_sensors(self, sensors, mapping, sensor_label="Sensors", **kwargs): - raise NotImplementedError - - def _conv_measurements(self, measurements, mapping, measurement_model=None, - convert_measurements=True) -> \ - Tuple[Dict[detection.Detection, StateVector], Dict[detection.Clutter, StateVector]]: - conv_detections = {} - conv_clutter = {} - for state in measurements: - meas_model = state.measurement_model # measurement_model from detections - if meas_model is None: - meas_model = measurement_model # measurement_model from input - - if not convert_measurements: - state_vec = state.state_vector[mapping, :] - elif isinstance(meas_model, LinearModel): - model_matrix = meas_model.matrix() - inv_model_matrix = np.linalg.pinv(model_matrix) - state_vec = (inv_model_matrix @ state.state_vector)[mapping, :] - elif isinstance(meas_model, Model): - try: - state_vec = meas_model.inverse_function(state)[mapping, :] - except (NotImplementedError, AttributeError): - warnings.warn('Nonlinear measurement model used with no inverse ' - 'function available') - continue - else: - warnings.warn('Measurement model type not specified for all detections') - continue - - if isinstance(state, detection.Clutter): - # Plot clutter - conv_clutter[state] = (*state_vec, ) - - elif isinstance(state, detection.Detection): - # Plot detections - conv_detections[state] = (*state_vec, ) - else: - warnings.warn(f'Unknown type {type(state)}') - continue - return conv_detections, conv_clutter - - -class Plotter(_Plotter): - """Plotting class for building graphs of Stone Soup simulations using matplotlib - - A plotting class which is used to simplify the process of plotting ground truths, - measurements, clutter and tracks. Tracks can be plotted with uncertainty ellipses or - particles if required. Legends are automatically generated with each plot. - Three dimensional plots can be created using the optional dimension parameter. - - Parameters - ---------- - dimension: enum \'Dimension\' - Optional parameter to specify 2D or 3D plotting. Default is 2D plotting. - plot_timeseries: bool - Specify whether data to be plotted is time series data. Default False - \\*\\*kwargs: dict - Additional arguments to be passed to plot function. For example, figsize (Default is - (10, 6)). 
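A short sketch of the constructor parameters described above (``figsize`` was simply
forwarded to ``matplotlib.pyplot.figure``; the values here are made-up examples):

    plotter = Plotter(dimension=2, figsize=(12, 8))
    plotter.ax.set_title("Example scenario")  # fig and ax exposed as attributes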
- - Attributes - ---------- - fig: matplotlib.figure.Figure - Generated figure for graphs to be plotted on - ax: matplotlib.axes.Axes - Generated axes for graphs to be plotted on - legend_dict: dict - Dictionary of legend handles as :class:`matplotlib.legend_handler.HandlerBase` - and labels as str - """ - - def __init__(self, dimension=Dimension.TWO, **kwargs): - figure_kwargs = {"figsize": (10, 6)} - figure_kwargs.update(kwargs) - if isinstance(dimension, type(Dimension.TWO)): - self.dimension = dimension - elif isinstance(dimension, int): - self.dimension = Dimension(dimension) - else: - raise TypeError("%s is an unsupported type for \'dimension\'; " - "expected type %s" % (type(dimension), type(Dimension.TWO))) - # Generate plot axes - self.fig = plt.figure(**figure_kwargs) - if self.dimension is Dimension.TWO: # 2D axes - self.ax = self.fig.add_subplot(1, 1, 1) - self.ax.axis('equal') - else: # 3D axes - self.ax = self.fig.add_subplot(111, projection='3d') - self.ax.axis('auto') - self.ax.set_zlabel("$z$") - self.ax.set_xlabel("$x$") - self.ax.set_ylabel("$y$") - - # Create empty dictionary for legend handles and labels - dict used to - # prevent multiple entries with the same label from displaying on legend - # This is new compared to plotter.py - self.legend_dict = {} # create an empty dictionary to hold legend entries - - def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwargs): - """Plots ground truth(s) - - Plots each ground truth path passed in to :attr:`truths` and generates a legend - automatically. Ground truths are plotted as dashed lines with default colors. - - Users can change linestyle, color and marker using keyword arguments. Any changes - will apply to all ground truths. - - Parameters - ---------- - truths : Collection of :class:`~.GroundTruthPath` - Collection of ground truths which will be plotted. If not a collection and instead a - single :class:`~.GroundTruthPath` type, the argument is modified to be a set to allow - for iteration. - mapping: list - List of items specifying the mapping of the position components of the state space. - truths_label: str - Label for truth data. Default is "Ground Truth" - \\*\\*kwargs: dict - Additional arguments to be passed to plot function. Default is ``linestyle="--"``. - - Returns - ------- - : list of :class:`matplotlib.artist.Artist` - List of artists that have been added to the axis. 
- """ - truths_kwargs = dict(linestyle="--") - truths_kwargs.update(kwargs) - if not isinstance(truths, Collection) or isinstance(truths, StateMutableSequence): - truths = {truths} # Make a set of length 1 - - artists = [] - for truth in truths: - if self.dimension is Dimension.TWO: # plots the ground truths in xy - artists.extend( - self.ax.plot([state.state_vector[mapping[0]] for state in truth], - [state.state_vector[mapping[1]] for state in truth], - **truths_kwargs)) - elif self.dimension is Dimension.THREE: # plots the ground truths in xyz - artists.extend( - self.ax.plot3D([state.state_vector[mapping[0]] for state in truth], - [state.state_vector[mapping[1]] for state in truth], - [state.state_vector[mapping[2]] for state in truth], - **truths_kwargs)) - else: - raise NotImplementedError('Unsupported dimension type for truth plotting') - # Generate legend items - if "color" in kwargs: - colour = kwargs["color"] - else: - colour = "black" - truths_handle = Line2D([], [], linestyle=truths_kwargs['linestyle'], color=colour) - self.legend_dict[truths_label] = truths_handle - # Generate legend - artists.append(self.ax.legend(handles=self.legend_dict.values(), - labels=self.legend_dict.keys())) - return artists - - def plot_measurements(self, measurements, mapping, measurement_model=None, - measurements_label="Measurements", convert_measurements=True, **kwargs): - """Plots measurements - - Plots detections and clutter, generating a legend automatically. Detections are plotted as - blue circles by default unless the detection type is clutter. - If the detection type is :class:`~.Clutter` it is plotted as a yellow 'tri-up' marker. - - Users can change the color and marker of detections using keyword arguments but not for - clutter detections. - - Parameters - ---------- - measurements : Collection of :class:`~.Detection` - Detections which will be plotted. If measurements is a set of lists it is flattened. - mapping: list - List of items specifying the mapping of the position components of the state space. - measurement_model : :class:`~.Model`, optional - User-defined measurement model to be used in finding measurement state inverses if - they cannot be found from the measurements themselves. - measurements_label : str - Label for the measurements. Default is "Measurements". - convert_measurements : bool - Should the measurements be converted from measurement space to state space before - being plotted. Default is True - \\*\\*kwargs: dict - Additional arguments to be passed to plot function for detections. Defaults are - ``marker='o'`` and ``color='b'``. - - Returns - ------- - : list of :class:`matplotlib.artist.Artist` - List of artists that have been added to the axis. - """ - - measurement_kwargs = dict(marker='o', color='b') - measurement_kwargs.update(kwargs) - - if not isinstance(measurements, Collection): - measurements = {measurements} # Make a set of length 1 - - if any(isinstance(item, set) for item in measurements): - measurements_set = chain.from_iterable(measurements) # Flatten into one set - else: - measurements_set = measurements - - plot_detections, plot_clutter = self._conv_measurements(measurements_set, - mapping, - measurement_model, - convert_measurements) - - artists = [] - if plot_detections: - detection_array = np.array(list(plot_detections.values())) - # *detection_array.T unpacks detection_array by columns - # (same as passing in detection_array[:,0], detection_array[:,1], etc...) 
- artists.append(self.ax.scatter(*detection_array.T, **measurement_kwargs)) - measurements_handle = Line2D([], [], linestyle='', **measurement_kwargs) - - # Generate legend items for measurements - self.legend_dict[measurements_label] = measurements_handle - - if plot_clutter: - clutter_kwargs = kwargs.copy() - clutter_kwargs.update(dict(marker='2')) - clutter_array = np.array(list(plot_clutter.values())) - artists.append(self.ax.scatter(*clutter_array.T, **clutter_kwargs)) - clutter_handle = Line2D([], [], linestyle='', **clutter_kwargs) - clutter_label = "Clutter" - - # Generate legend items for clutter - self.legend_dict[clutter_label] = clutter_handle - - # Generate legend - artists.append(self.ax.legend(handles=self.legend_dict.values(), - labels=self.legend_dict.keys())) - return artists - - def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_label="Tracks", - err_freq=1, same_color=False, **kwargs): - """Plots track(s) - - Plots each track generated, generating a legend automatically. If ``uncertainty=True`` - and is being plotted in 2D, error ellipses are plotted. If being plotted in - 3D, uncertainty bars are plotted every :attr:`err_freq` measurement, default - plots uncertainty bars at every track step. Tracks are plotted as solid - lines with point markers and default colors. Uncertainty bars are plotted - with a default color which is the same for all tracks. - - Users can change linestyle, color and marker using keyword arguments. Uncertainty metrics - will also be plotted with the user defined colour and any changes will apply to all tracks. - - Parameters - ---------- - tracks : Collection of :class:`~.Track` - Collection of tracks which will be plotted. If not a collection, and instead a single - :class:`~.Track` type, the argument is modified to be a set to allow for iteration. - mapping: list - List of items specifying the mapping of the position - components of the state space. - uncertainty : bool - If True, function plots uncertainty ellipses or bars. - particle : bool - If True, function plots particles. - track_label: str - Label to apply to all tracks for legend. - err_freq: int - Frequency of error bar plotting on tracks. Default value is 1, meaning - error bars are plotted at every track step. - same_color: bool - Should all the tracks have the same color. Default False - \\*\\*kwargs: dict - Additional arguments to be passed to plot function. Defaults are ``linestyle="-"``, - ``marker='s'`` for :class:`~.Update` and ``marker='o'`` for other states. - - Returns - ------- - : list of :class:`matplotlib.artist.Artist` - List of artists that have been added to the axis. 
- """ - - tracks_kwargs = dict(linestyle='-', marker="s", color=None) - tracks_kwargs.update(kwargs) - if not isinstance(tracks, Collection) or isinstance(tracks, StateMutableSequence): - tracks = {tracks} # Make a set of length 1 - - # Plot tracks - artists = [] - track_colors = {} - for track in tracks: - # Get indexes for Update and non-Update states for styling markers - update_indexes = [] - not_update_indexes = [] - for n, state in enumerate(track): - if isinstance(state, Update): - update_indexes.append(n) - else: - not_update_indexes.append(n) - - data = np.concatenate( - [(getattr(state, 'mean', state.state_vector)[mapping, :]) - for state in track], - axis=1) - - line = self.ax.plot( - *data, - markevery=update_indexes, - **tracks_kwargs) - artists.extend(line) - if not_update_indexes: - artists.extend(self.ax.plot( - *data[:, not_update_indexes], - marker="o" if "marker" not in kwargs else kwargs['marker'], - linestyle='', - color=plt.getp(line[0], 'color'))) - track_colors[track] = plt.getp(line[0], 'color') - if same_color: - tracks_kwargs['color'] = plt.getp(line[0], 'color') - - if tracks: # If no tracks `line` won't be defined - # Assuming a single track or all plotted as the same colour then the following will - # work. Otherwise will just render the final track colour. - tracks_kwargs['color'] = plt.getp(line[0], 'color') - - # Generate legend items for track - track_handle = Line2D([], [], linestyle=tracks_kwargs['linestyle'], - marker=tracks_kwargs['marker'], color=tracks_kwargs['color']) - self.legend_dict[track_label] = track_handle - if uncertainty: - if self.dimension is Dimension.TWO: - # Plot uncertainty ellipses - for track in tracks: - HH = np.eye(track.ndim)[mapping, :] # Get position mapping matrix - check = err_freq - 1 # plot the first one - for state in track: - check += 1 - if check % err_freq: - continue - w, v = np.linalg.eig(HH @ state.covar @ HH.T) - if np.iscomplexobj(w) or np.iscomplexobj(v): - warnings.warn("Can not plot uncertainty for all states due to complex " - "eignevalues or eigenvectors", UserWarning) - continue - max_ind = np.argmax(w) - min_ind = np.argmin(w) - orient = np.arctan2(v[1, max_ind], v[0, max_ind]) - ellipse = Ellipse(xy=state.mean[mapping[:2], 0], - width=2 * np.sqrt(w[max_ind]), - height=2 * np.sqrt(w[min_ind]), - angle=np.rad2deg(orient), alpha=0.2, - color=track_colors[track]) - self.ax.add_artist(ellipse) - artists.append(ellipse) - - # Generate legend items for uncertainty ellipses - ellipse_handle = Ellipse((0.5, 0.5), 0.5, 0.5, alpha=0.2, - color=tracks_kwargs['color']) - ellipse_label = "Uncertainty" - self.legend_dict[ellipse_label] = ellipse_handle - # Generate legend - artists.append(self.ax.legend(handles=self.legend_dict.values(), - labels=self.legend_dict.keys(), - handler_map={Ellipse: _HandlerEllipse()})) - else: - # Plot 3D error bars on tracks - for track in tracks: - HH = np.eye(track.ndim)[mapping, :] # Get position mapping matrix - check = err_freq - for state in track: - if not check % err_freq: - w, v = np.linalg.eig(HH @ state.covar @ HH.T) - - xl = state.state_vector[mapping[0]] - yl = state.state_vector[mapping[1]] - zl = state.state_vector[mapping[2]] - - x_err = w[0] - y_err = w[1] - z_err = w[2] - - artists.extend( - self.ax.plot3D([xl+x_err, xl-x_err], [yl, yl], [zl, zl], - marker="_", color=tracks_kwargs['color'])) - artists.extend( - self.ax.plot3D([xl, xl], [yl+y_err, yl-y_err], [zl, zl], - marker="_", color=tracks_kwargs['color'])) - artists.extend( - self.ax.plot3D([xl, xl], [yl, yl], 
[zl+z_err, zl-z_err], - marker="_", color=tracks_kwargs['color'])) - check += 1 - - if particle: - if self.dimension is Dimension.TWO: - # Plot particles - for track in tracks: - for state in track: - data = state.state_vector[mapping[:2], :] - artists.extend(self.ax.plot(data[0], data[1], linestyle='', marker=".", - markersize=1, alpha=0.5)) - - # Generate legend items for particles - particle_handle = Line2D([], [], linestyle='', color="black", marker='.', - markersize=1) - particle_label = "Particles" - self.legend_dict[particle_label] = particle_handle - # Generate legend - artists.append(self.ax.legend(handles=self.legend_dict.values(), - labels=self.legend_dict.keys())) - else: - raise NotImplementedError("""Particle plotting is not currently supported for - 3D visualization""") - - else: - artists.append(self.ax.legend(handles=self.legend_dict.values(), - labels=self.legend_dict.keys())) - - return artists - - def plot_sensors(self, sensors, mapping=None, sensor_label="Sensors", **kwargs): - """Plots sensor(s) - - Plots sensors. Users can change the color and marker of sensors using keyword - arguments. Default is a black 'x' marker. - - Parameters - ---------- - sensors : Collection of :class:`~.Sensor` - Sensors to plot - mapping: list - List of items specifying the mapping of the position components of the - sensor's position. Default is either [0, 1] or [0, 1, 2] depending on `self.dimension` - sensor_label: str - Label to apply to all sensors for legend. - \\*\\*kwargs: dict - Additional arguments to be passed to plot function for sensors. Defaults are - ``marker='x'`` and ``color='black'``. - - Returns - ------- - : list of :class:`matplotlib.artist.Artist` - List of artists that have been added to the axis. - """ - - sensor_kwargs = dict(marker='x', color='black') - sensor_kwargs.update(kwargs) - - if not isinstance(sensors, Collection): - sensors = {sensors} # Make a set of length 1 - - if mapping is None: - mapping = list(range(self.dimension)) - - artists = [] - for sensor in sensors: - if self.dimension is Dimension.TWO: # plots the sensors in xy - artists.append(self.ax.scatter(sensor.position[mapping[0]], - sensor.position[mapping[1]], - **sensor_kwargs)) - elif self.dimension is Dimension.THREE: # plots the sensors in xyz - artists.extend(self.ax.plot3D(sensor.position[mapping[0]], - sensor.position[mapping[1]], - sensor.position[mapping[2]], - **sensor_kwargs)) - else: - raise NotImplementedError('Unsupported dimension type for sensor plotting') - self.legend_dict[sensor_label] = Line2D([], [], linestyle='', **sensor_kwargs) - artists.append(self.ax.legend(handles=self.legend_dict.values(), - labels=self.legend_dict.keys())) - return artists - - def set_equal_3daxis(self, axes=None): - """Plots minimum/maximum points with no linestyle to increase the plotting region to - simulate `.ax.axis('equal')` from matplotlib 2d plots which is not possible using 3d - projection. - - Parameters - ---------- - axes: list - List of dimension index specifying the equal axes, equal x and y = [0,1]. - Default is x,y [0,1]. 
- """ - if not axes: - axes = [0, 1] - if self.dimension is Dimension.THREE: - min_xyz = [0, 0, 0] - max_xyz = [0, 0, 0] - for n in range(3): - for line in self.ax.lines: - min_xyz[n] = np.min([min_xyz[n], *line.get_data_3d()[n]]) - max_xyz[n] = np.max([max_xyz[n], *line.get_data_3d()[n]]) - - extremes = np.max([x - y for x, y in zip(max_xyz, min_xyz)]) - equal_axes = [0, 0, 0] - for i in axes: - equal_axes[i] = 1 - lower = ([np.mean([x, y]) for x, y in zip(max_xyz, min_xyz)] - extremes/2) * equal_axes - upper = ([np.mean([x, y]) for x, y in zip(max_xyz, min_xyz)] + extremes/2) * equal_axes - ghosts = GroundTruthPath(states=[State(state_vector=lower), - State(state_vector=upper)]) - - self.ax.plot3D([state.state_vector[0] for state in ghosts], - [state.state_vector[1] for state in ghosts], - [state.state_vector[2] for state in ghosts], - linestyle="") - - def plot_density(self, state_sequences: Collection[StateMutableSequence], - index: Union[int, None] = -1, - mapping=(0, 2), n_bins=300, **kwargs): - """ - - Parameters - ---------- - state_sequences : a collection of :class:`~.StateMutableSequence` - Set of tracks which will be plotted. If not a set, and instead a single - :class:`~.Track` type, the argument is modified to be a set to allow for iteration. - index: int - Which index of the StateMutableSequences should be plotted. - Default value is '-1' which is the last state in the sequences. - index can be set to None if all indices of the sequence should be included in the plot - mapping: list - List of 2 items specifying the mapping of the x and y components of the state space. - n_bins : int - Size of the bins used to group the data - \\*\\*kwargs: dict - Additional arguments to be passed to pcolormesh function. - """ - if len(state_sequences) == 0: - raise ValueError("Skipping plotting density due to state_sequences being empty.") - if index is None: # Plot all states in the sequence - x = np.array([a_state.state_vector[mapping[0]] - for a_state_sequence in state_sequences - for a_state in a_state_sequence]) - y = np.array([a_state.state_vector[mapping[1]] - for a_state_sequence in state_sequences - for a_state in a_state_sequence]) - else: # Only plot one state out of the sequences - x = np.array([a_state_sequence.states[index].state_vector[mapping[0]] - for a_state_sequence in state_sequences]) - y = np.array([a_state_sequence.states[index].state_vector[mapping[1]] - for a_state_sequence in state_sequences]) - if np.allclose(x, y, atol=1e-10): - raise ValueError("Skipping plotting density due to x and y values are the same. " - "This leads to a singular matrix in the kde function.") - # Evaluate a gaussian kde on a regular grid of n_bins x n_bins over data extents - k = kde.gaussian_kde([x, y]) - xi, yi = np.mgrid[x.min():x.max():n_bins * 1j, y.min():y.max():n_bins * 1j] - zi = k(np.vstack([xi.flatten(), yi.flatten()])) - - # Make the plot - self.ax.pcolormesh(xi, yi, zi.reshape(xi.shape), shading='auto', **kwargs) - - # Ellipse legend patch (used in Tutorial 3) - @staticmethod - def ellipse_legend(ax, label_list, color_list, **kwargs): - """Adds an ellipse patch to the legend on the axes. One patch added for each item in - `label_list` with the corresponding color from `color_list`. 
- - Parameters - ---------- - ax : matplotlib.axes.Axes - Looks at the plot axes defined - label_list : list of str - Takes in list of strings intended to label ellipses in legend - color_list : list of str - Takes in list of colors corresponding to string/label - Must be the same length as label_list - \\*\\*kwargs: dict - Additional arguments to be passed to plot function. Default is ``alpha=0.2``. - """ - - ellipse_kwargs = dict(alpha=0.2) - ellipse_kwargs.update(kwargs) - - legend = ax.legend(handler_map={Ellipse: _HandlerEllipse()}) - handles, labels = ax.get_legend_handles_labels() - for color in color_list: - handle = Ellipse((0.5, 0.5), 0.5, 0.5, color=color, **ellipse_kwargs) - handles.append(handle) - for label in label_list: - labels.append(label) - legend._legend_box = None - legend._init_legend_box(handles, labels) - legend._set_loc(legend._loc) - legend.set_title(legend.get_title().get_text()) - - -class _HandlerEllipse(HandlerPatch): - def create_artists(self, legend, orig_handle, - xdescent, ydescent, width, height, fontsize, trans): - center = 0.5*width - 0.5*xdescent, 0.5*height - 0.5*ydescent - p = Ellipse(xy=center, width=width + xdescent, - height=height + ydescent) - self.update_prop(p, orig_handle, legend) - p.set_transform(trans) - return [p] - - -class MetricPlotter(ABC): - """Class for plotting Stone Soup metrics using matplotlib - - A plotting class which is used to simplify the process of plotting metrics. - Legends are automatically generated with each plot. - - """ - def __init__(self): - self.fig = None - self.axes = None - self.plottable_metrics = list() - - def plot_metrics(self, metrics, generator_names=None, metric_names=None, - combine_plots=True, **kwargs): - """Plots metrics - - Plots each plottable metric passed in to :attr:`metrics` across a series of subplots - and generates legend(s) automatically. Metrics are plotted as lines with default colors. - - Users can change linestyle, color and marker or other features using keyword arguments. - Any changes will apply to all metrics. - - Parameters - ---------- - metrics : dict of :class:`~.Metric` - Dictionary of generated metrics to be plotted. - generator_names: list of str - Generator(s) to extract specific metrics from :attr:`metrics` for plotting. - Default None to take all metrics. - metric_names: list of str - Specific metric(s) to extract from :class:`~.MetricGenerator` for plotting. - Default None to take all metrics in generators. - combine_plots: bool - Plot metrics of same type on the same subplot. Default True. - \\*\\*kwargs: dict - Additional arguments to be passed to plot function. Default is ``linestyle="-"``. - - Returns - ------- - : :class:`matplotlib.pyplot.figure` - Figure containing subplots displaying all plottable metrics. 
- """ - for metric_dict in metrics.values(): - for metric_name, metric in metric_dict.items(): - if isinstance(metric.value, List) \ - and all(isinstance(x, SingleTimeMetric) for x in metric.value): - self.plottable_metrics.append(metric_name) - - metrics_kwargs = dict(linestyle="-") - metrics_kwargs.update(kwargs) - - generator_names = list(metrics.keys()) if generator_names is None else generator_names - - # warning for user input metrics that will not be plotted - if metric_names is not None: - for metric_name in metric_names: - if metric_name not in self.plottable_metrics: - warnings.warn(f"{metric_name} " - f"is not a plottable metric and will not be plotted") - else: - metric_names = self.extract_metric_types(metrics) - - metrics_to_plot = self._extract_plottable_metrics(metrics, generator_names, metric_names) - - if combine_plots: - self.combine_plots(metrics_to_plot, metrics_kwargs) - else: - self.plot_separately(metrics_to_plot, metrics_kwargs) - - def _extract_plottable_metrics(self, metrics, generator_names, metric_names): - """ - Extract all plottable metrics from dict of generated metrics. - - Parameters - ---------- - metrics: dict of :class:`~.Metric` - Dictionary of generated metrics. - generator_names: list of str - Generator(s) to extract specific metrics from :attr:`metrics` for plotting. - metric_names: list of str - Specific metric(s) to extract from :class:`~.MetricGenerator` for plotting. - - Returns - ------- - : dict - Dict of all plottable metrics. - """ - metrics_dict = dict() - - for generator_name in generator_names: - for metric_name in metric_names: - if metric_name in metrics[generator_name].keys() and \ - metric_name in self.plottable_metrics: - if generator_name not in metrics_dict.keys(): - metrics_dict[generator_name] = \ - {metric_name: metrics[generator_name][metric_name]} - else: - metrics_dict[generator_name][metric_name] = \ - metrics[generator_name][metric_name] - - return metrics_dict - - def _count_subplots(self, metrics_to_plot, combine_plots): - """ - Calculate number of subplots needed to plot all metrics. - - Parameters - ---------- - metrics_to_plot: dict of :class:`~.Metric` - Dictionary of metrics to be plotted. - combine_plots: bool - Specifies whether same metric types should be plotted on same subplot. - - Returns - ------- - : int - Number of subplots to generate. - """ - if combine_plots: - metric_types = self.extract_metric_types(metrics_to_plot) - number_of_subplots = len(metric_types) - - else: - number_of_subplots = 0 - for generator in metrics_to_plot.keys(): - number_of_subplots += len(metrics_to_plot[generator]) - - return number_of_subplots - - @staticmethod - def extract_metric_types(metrics): - """ - Identify the different types of metric held in dict of metrics. - - Parameters - ---------- - metrics: dict of :class:`~.Metric` - Dictionary of metrics. - - Returns - ------- - : list - Sorted list of types of metric - """ - metric_types = set() - for generator in metrics.keys(): - for metric_key in metrics[generator].keys(): - metric_types.add(metric_key) - - metric_types = list(metric_types) - metric_types.sort() - - return metric_types - - def combine_plots(self, metrics_to_plot, metrics_kwargs): - """ - Generates one subplot for each different metric type and plots metrics of the same - type on same subplot. Metrics are plotted over time. - - Parameters - ---------- - metrics_to_plot: dict of :class:`~.Metric` - Dictionary of metrics to plot. - metrics_kwargs: dict - Keyword arguments to be passed to plot function. 
- - Returns - ------- - : :class:`matplotlib.pyplot.figure` - Figure containing subplots displaying metrics. - """ - # determine how many plots required - equal to number of metric types - number_of_subplots = self._count_subplots(metrics_to_plot, True) - - # initialise each subplot - self.fig, axes = plt.subplots(number_of_subplots, figsize=(10, 6*number_of_subplots)) - self.fig.subplots_adjust(hspace=0.3) - - # extract data for each subplot and plot it - metric_types = self.extract_metric_types(metrics_to_plot) - - self.axes = axes if isinstance(axes, Iterable) else [axes] - - # generate colour map for lines to be plotted - if 'color' not in metrics_kwargs.keys(): - colour_map = plt.cm.rainbow(np.linspace(0, 1, len(metrics_to_plot.keys()))) - else: - colour_map = metrics_kwargs['color'] - metrics_kwargs.pop('color') - - for metric_type, axis in zip(list(metric_types), self.axes): - artists = [] - legend_dict = {} - - colour_map_copy = iter(colour_map.copy()) - - for generator in metrics_to_plot.keys(): - for metric in metrics_to_plot[generator].keys(): - if metric == metric_type: - colour = next(colour_map_copy) - metric_values = metrics_to_plot[generator][metric].value - artists.extend(axis.plot([_.timestamp for _ in metric_values], - [_.value for _ in metric_values], - color=colour, - **metrics_kwargs)) - - metric_handle = Line2D([], [], linestyle=metrics_kwargs['linestyle'], - color=colour) - legend_dict[generator] = metric_handle - - # Generate legend - artists.append(axis.legend(handles=legend_dict.values(), - labels=legend_dict.keys())) - - y_label = metric_type.split(' at times')[0] - artists.extend(axis.set(title=metric_type.split(' at times')[0], - xlabel="Time", ylabel=y_label)) - - def plot_separately(self, metrics_to_plot, metrics_kwargs): - """ - Generates one subplot for each different individual metric and plots metric - values over time. - - Parameters - ---------- - metrics_to_plot: dict of :class:`~.Metric` - Dictionary of metrics to plot. - metrics_kwargs: dict - Keyword arguments to be passed to plot function. - - Returns - ------- - : :class:`matplotlib.pyplot.figure` - Figure containing subplots displaying metrics. - """ - metrics_kwargs['color'] = metrics_kwargs['color'] if \ - 'color' in metrics_kwargs.keys() else 'blue' - - # determine how many plots required - equal to number of metrics within the generators - number_of_subplots = self._count_subplots(metrics_to_plot, False) - - # initialise each plot - self.fig, axes = plt.subplots(number_of_subplots, figsize=(10, 6*number_of_subplots)) - self.fig.subplots_adjust(hspace=0.3) - - # extract data for each plot and plot it - all_metrics = {} - for generator in metrics_to_plot.keys(): - for metric in list(metrics_to_plot[generator].keys()): - all_metrics[f'{generator}: {metric}'] = metrics_to_plot[generator][metric] - - self.axes = axes if isinstance(axes, Iterable) else [axes] - - for metric, axis in zip(all_metrics.keys(), self.axes): - y_label = str(all_metrics[metric].title).split(' at times')[0] - axis.set(title=str(all_metrics[metric].title), xlabel='Time', ylabel=y_label) - metric_values = all_metrics[metric].value - axis.plot([_.timestamp for _ in metric_values], - [_.value for _ in metric_values], - **metrics_kwargs) - - # Generate legend - metric_handle = Line2D([], [], linestyle=metrics_kwargs['linestyle'], - color=metrics_kwargs['color']) - axis.legend(handles=[metric_handle], - labels=[metric.split(' at times')[0]]) - - def set_fig_title(self, title): - """ - Set title for the figure. 
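Continuing the sketch above, the two title helpers defined here would then be applied as
(titles are illustrative):

    metric_plotter.set_fig_title("Tracking performance")
    metric_plotter.set_ax_title(["OSPA distance", "SIAP completeness"])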
-
-        Parameters
-        ----------
-        title: str
-            Figure title text.
-
-        Returns
-        -------
-        Text instance of figure title.
-        """
-        self.fig.suptitle(t=title)
-
-    def set_ax_title(self, titles):
-        """
-        Set axis titles for each axis in figure.
-
-        Parameters
-        ----------
-        titles: list of str
-            List of strings for title text for each axis.
-
-        Returns
-        -------
-        Text instance of axis titles.
-        """
-        for axis, title in zip(self.axes, titles):
-            axis.set(title=title)
-
-
-class Plotterly(_Plotter):
-    """Plotting class for building graphs of Stone Soup simulations using plotly
-
-    A plotting class which is used to simplify the process of plotting ground truths,
-    measurements, clutter and tracks. Tracks can be plotted with uncertainty ellipses or
-    particles if required. Legends are automatically generated with each plot.
-    Three-dimensional plots can be created using the optional dimension parameter.
-
-    Parameters
-    ----------
-    dimension: enum \'Dimension\'
-        Optional parameter to specify 1D, 2D, or 3D plotting.
-    axis_labels: list
-        Optional parameter to specify the axis labels for non-xy dimensions. Default None, i.e.,
-        "x" and "y".
-    \\*\\*kwargs: dict
-        Additional arguments to be passed to the Plotly.graph_objects Figure.
-
-    Attributes
-    ----------
-    fig: plotly.graph_objects.Figure
-        Generated figure to display graphs.
-    """
-    def __init__(self, dimension=Dimension.TWO, axis_labels=None, **kwargs):
-        if dimension != Dimension.ONE:
-            if not axis_labels:
-                axis_labels = ["x", "y"]
-        else:
-            if axis_labels:
-                if len(axis_labels) == 1:
-                    axis_labels = ["Time", axis_labels[0]]
-            else:
-                axis_labels = ["Time", "x"]
-        if go is None:
-            raise RuntimeError("Usage of Plotterly plotter requires installation of `plotly`")
-
-        self.dimension = Dimension(dimension)  # allows 1, 2, 3,
-        # Dimension(1), Dimension(2) or Dimension(3)
-
-        from plotly import colors
-        layout_kwargs = dict(
-            xaxis_title=axis_labels[0],
-            yaxis_title=axis_labels[1],
-            colorway=colors.qualitative.Plotly,  # Needed to match colours later.
-        )
-
-        if self.dimension == 3:
-            layout_kwargs.update(dict(scene_aspectmode='data'))  # auto shapes fig to fit data well
-
-        merge(layout_kwargs, kwargs)
-
-        # Generate plot axes
-        self.fig = go.Figure(layout=layout_kwargs)
-
-    @staticmethod
-    def _format_state_text(state):
-        text = []
-        text.append(type(state).__name__)
-        text.append(getattr(state, 'mean', state.state_vector))
-        text.append(state.timestamp)
-        text.extend([f"{key}: {value}" for key, value in getattr(state, 'metadata', {}).items()])
-
-        return "<br>".join((str(t) for t in text))
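# For orientation, a minimal sketch of how the Plotterly class removed in this hunk
# was typically driven ("truths", "measurements" and "tracks" stand for previously
# generated Stone Soup objects; the mapping assumes an [x, vx, y, vy] state):
#
#     plotter = Plotterly(dimension=2)
#     plotter.plot_ground_truths(truths, mapping=[0, 2])
#     plotter.plot_measurements(measurements, mapping=[0, 2])
#     plotter.plot_tracks(tracks, mapping=[0, 2], uncertainty=True)
#     plotter.fig.show()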
".join((str(t) for t in text)) - - def _check_mapping(self, mapping): - if len(mapping) == 0: - raise ValueError("No indices provided in mapping.") - elif len(mapping) != self.dimension: - raise TypeError("Plotter dimension is not same as the mapping dimension.") - - def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwargs): - """Plots ground truth(s) - - Plots each ground truth path passed in to :attr:`truths` and generates a legend - automatically. Ground truths are plotted as dashed lines with default colors. - - Users can change line style, color and marker using keyword arguments. Any changes - will apply to all ground truths. - - Parameters - ---------- - truths : Collection of :class:`~.GroundTruthPath` - Collection of ground truths which will be plotted. If not a collection, - and instead a single :class:`~.GroundTruthPath` type, the argument is modified to be a - set to allow for iteration. - mapping: list - List of items specifying the mapping of the position components of the state space. - truths_label: str - Label for truth data. Default is "Ground Truth" - \\*\\*kwargs: dict - Additional arguments to be passed to scatter function. Default is - ``line=dict(dash="dash")``. - """ - if not isinstance(truths, Collection) or isinstance(truths, StateMutableSequence): - truths = {truths} - - self._check_mapping(mapping) # ensure mapping is compatible with plotter dimension - - truths_kwargs = dict( - mode="lines", line=dict(dash="dash"), legendgroup=truths_label, legendrank=100, - name=truths_label) - - if self.dimension == 3: # make ground truth line thicker so easier to see in 3d plot - truths_kwargs.update(dict(line=dict(width=8, dash="longdashdot"))) - - merge(truths_kwargs, kwargs) - add_legend = truths_kwargs['legendgroup'] not in {trace.legendgroup - for trace in self.fig.data} - - for truth in truths: - scatter_kwargs = truths_kwargs.copy() - if add_legend: - scatter_kwargs['showlegend'] = True - add_legend = False - else: - scatter_kwargs['showlegend'] = False - - if self.dimension == 1: - self.fig.add_scatter( - x=[state.timestamp for state in truth], - y=[state.state_vector[mapping[0]] for state in truth], - text=[self._format_state_text(state) for state in truth], - **scatter_kwargs) - - elif self.dimension == 2: - self.fig.add_scatter( - x=[state.state_vector[mapping[0]] for state in truth], - y=[state.state_vector[mapping[1]] for state in truth], - text=[self._format_state_text(state) for state in truth], - **scatter_kwargs) - - elif self.dimension == 3: - self.fig.add_scatter3d( - x=[state.state_vector[mapping[0]] for state in truth], - y=[state.state_vector[mapping[1]] for state in truth], - z=[state.state_vector[mapping[2]] for state in truth], - text=[self._format_state_text(state) for state in truth], - **scatter_kwargs) - - def plot_measurements(self, measurements, mapping, measurement_model=None, - measurements_label="Measurements", convert_measurements=True, **kwargs): - """Plots measurements - - Plots detections and clutter, generating a legend automatically. Detections are plotted as - blue circles by default unless the detection type is clutter. - If the detection type is :class:`~.Clutter` it is plotted as a yellow 'tri-up' marker. - - Users can change the color and marker of detections using keyword arguments but not for - clutter detections. - - Parameters - ---------- - measurements : Collection of :class:`~.Detection` - Detections which will be plotted. If measurements is a set of lists it is flattened. 
-        mapping: list
-            List of items specifying the mapping of the position components of the state space.
-        measurement_model : :class:`~.Model`, optional
-            User-defined measurement model to be used in finding measurement state inverses if
-            they cannot be found from the measurements themselves.
-        measurements_label : str
-            Label for the measurements. Default is "Measurements".
-        convert_measurements: bool
-            Should the measurements be converted from measurement space to state space before
-            being plotted. Default is True
-        \\*\\*kwargs: dict
-            Additional arguments to be passed to scatter function for detections. Defaults are
-            ``marker=dict(color="#636EFA")``.
-        """
-
-        if not isinstance(measurements, Collection):
-            measurements = {measurements}
-
-        if any(isinstance(item, set) for item in measurements):
-            measurements_set = chain.from_iterable(measurements)  # Flatten into one set
-        else:
-            measurements_set = set(measurements)
-
-        self._check_mapping(mapping)
-
-        plot_detections, plot_clutter = self._conv_measurements(measurements_set,
-                                                                mapping,
-                                                                measurement_model,
-                                                                convert_measurements)
-
-        if plot_detections:
-            name = measurements_label + "<br>(Detections)"
(Detections)" - measurement_kwargs = dict( - mode='markers', marker=dict(color='#636EFA'), - name=name, legendgroup=name, legendrank=200) - - if self.dimension == 3: # make markers smaller in 3d plot - measurement_kwargs.update(dict(marker=dict(size=4, color='#636EFA'))) - - merge(measurement_kwargs, kwargs) - if measurement_kwargs['legendgroup'] not in {trace.legendgroup - for trace in self.fig.data}: - measurement_kwargs['showlegend'] = True - else: - measurement_kwargs['showlegend'] = False - detection_array = np.asarray(list(plot_detections.values()), dtype=np.float64) - - if self.dimension == 1: - self.fig.add_scatter( - x=[state.timestamp for state in plot_detections.keys()], - y=detection_array[:, 0], - text=[self._format_state_text(state) for state in plot_detections.keys()], - **measurement_kwargs, - ) - elif self.dimension == 2: - self.fig.add_scatter( - x=detection_array[:, 0], - y=detection_array[:, 1], - text=[self._format_state_text(state) for state in plot_detections.keys()], - **measurement_kwargs, - ) - elif self.dimension == 3: - self.fig.add_scatter3d( - x=detection_array[:, 0], - y=detection_array[:, 1], - z=detection_array[:, 2], - text=[self._format_state_text(state) for state in plot_detections.keys()], - **measurement_kwargs, - ) - - if plot_clutter: - name = measurements_label + "
(Clutter)" - clutter_kwargs = dict( - mode='markers', marker=dict(symbol="star-triangle-up", color='#FECB52'), - name=name, legendgroup=name, legendrank=210) - - if self.dimension == 3: # update - star-triangle-up not in 3d plotly - measurement_kwargs.update(dict(marker=dict(size=4, symbol="diamond", - color='#FECB52'))) - - merge(clutter_kwargs, kwargs) - if clutter_kwargs['legendgroup'] not in {trace.legendgroup - for trace in self.fig.data}: - clutter_kwargs['showlegend'] = True - else: - clutter_kwargs['showlegend'] = False - clutter_array = np.asarray(list(plot_clutter.values()), dtype=np.float64) - - if self.dimension == 1: - self.fig.add_scatter( - x=[state.timestamp for state in plot_clutter.keys()], - y=clutter_array[:, 0], - text=[self._format_state_text(state) for state in plot_clutter.keys()], - **clutter_kwargs, - ) - elif self.dimension == 2: - self.fig.add_scatter( - x=clutter_array[:, 0], - y=clutter_array[:, 1], - text=[self._format_state_text(state) for state in plot_clutter.keys()], - **clutter_kwargs, - ) - elif self.dimension == 3: - self.fig.add_scatter3d( - x=clutter_array[:, 0], - y=clutter_array[:, 1], - z=clutter_array[:, 2], - text=[self._format_state_text(state) for state in plot_clutter.keys()], - **clutter_kwargs, - ) - - def get_next_color(self): - """ - Find the colour of the next plot. This approach to getting colour isn't ideal, but should - work in most cases... - Returns - ------- - dist : str - Hex string for a colour - """ - # Find how many sequences have been plotted so far. The current plot has already been added - # to fig.data, so -1 is needed - figure_index = len(self.fig.data) - 1 - - # Get the list of colours used for plotting - colorway = self.fig.layout.colorway - max_index = len(colorway) - - # Use the modulo operator to limit the colour index to limits of the colorway. - # If figure_index > max_index then colours will be reused - color_index = figure_index % max_index - return colorway[color_index] - - def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_label="Tracks", - ellipse_points=30, err_freq=1, same_color=False, **kwargs): - """Plots track(s) - - Plots each track generated, generating a legend automatically. If ``uncertainty=True`` - error ellipses are plotted. - Tracks are plotted as solid lines with point markers and default colors. - - Users can change line style, color and marker using keyword arguments. - - Parameters - ---------- - tracks : Collection of :class:`~.Track` - Collection of tracks which will be plotted. If not a collection, and instead a single - :class:`~.Track` type, the argument is modified to be a set to allow for iteration. - mapping: list - List of items specifying the mapping of the position - components of the state space. - uncertainty : bool - If True, function plots uncertainty ellipses. - particle : bool - If True, function plots particles. - track_label: str - Label to apply to all tracks for legend. - ellipse_points: int - Number of points for polygon approximating ellipse shape - err_freq: int - Frequency of error bar plotting on tracks. Default value is 1, meaning - error bars are plotted at every track step. - same_color: bool - Should all the tracks have the same colour. Default False - \\*\\*kwargs: dict - Additional arguments to be passed to scatter function. Defaults are - ``marker=dict(symbol='square')`` for :class:`~.Update` and - ``marker=dict(symbol='circle')`` for other states. 
- """ - if not isinstance(tracks, Collection) or isinstance(tracks, StateMutableSequence): - tracks = {tracks} # Make a set of length 1 - - self._check_mapping(mapping) # check size of mapping against dimension of plotter - - # Plot tracks - track_colors = {} - track_kwargs = dict(mode='markers+lines', legendgroup=track_label, legendrank=300) - - if self.dimension == 3: # change visuals to work well in 3d - track_kwargs.update(dict(line=dict(width=7)), marker=dict(size=4)) - merge(track_kwargs, kwargs) - add_legend = track_kwargs['legendgroup'] not in {trace.legendgroup - for trace in self.fig.data} - - if same_color: - color = track_kwargs.get('marker', {}).get('color') or \ - track_kwargs.get('line', {}).get('color') - - # Set the colour if it hasn't already been set - if color is None: - track_kwargs['marker'] = track_kwargs.get('marker', {}) - track_kwargs['marker']['color'] = self.get_next_color() - - for track in tracks: - scatter_kwargs = track_kwargs.copy() - scatter_kwargs['name'] = track.id - if add_legend: - scatter_kwargs['name'] = track_label - scatter_kwargs['showlegend'] = True - add_legend = False - else: - scatter_kwargs['showlegend'] = False - scatter_kwargs['marker'] = scatter_kwargs.get('marker', {}).copy() - if 'symbol' not in scatter_kwargs['marker']: - scatter_kwargs['marker']['symbol'] = [ - 'square' if isinstance(state, Update) else 'circle' for state in track] - - if len(self.fig.data) > 0: - track_colors[track] = (self.fig.data[-1].line.color - or self.fig.data[-1].marker.color - or self.get_next_color()) - else: - track_colors[track] = self.get_next_color() - - if self.dimension == 1: # plot 1D tracks - - if uncertainty or particle: - raise NotImplementedError - - self.fig.add_scatter( - x=[state.timestamp for state in track], - y=[float(getattr(state, 'mean', state.state_vector)[mapping[0]]) - for state in track], - text=[self._format_state_text(state) for state in track], - **scatter_kwargs) - - elif self.dimension == 2: # plot 2D tracks - - self.fig.add_scatter( - x=[float(getattr(state, 'mean', state.state_vector)[mapping[0]]) - for state in track], - y=[float(getattr(state, 'mean', state.state_vector)[mapping[1]]) - for state in track], - text=[self._format_state_text(state) for state in track], - **scatter_kwargs) - - elif self.dimension == 3: # plot 3D tracks - - if particle: - raise NotImplementedError - - # create empty error arrays - err_x = np.array([np.nan for _ in range(len(track))], dtype=float) - err_y = np.array([np.nan for _ in range(len(track))], dtype=float) - err_z = np.array([np.nan for _ in range(len(track))], dtype=float) - - if uncertainty: # find x,y,z error bars for relevant states - - for count, state in enumerate(track): - - if not count % err_freq: # ie count % err_freq = 0 - HH = np.eye(track.ndim)[mapping, :] # Get position mapping matrix - cov = HH @ state.covar @ HH.T - - err_x[count] = np.sqrt(cov[0, 0]) - err_y[count] = np.sqrt(cov[1, 1]) - err_z[count] = np.sqrt(cov[2, 2]) - - self.fig.add_scatter3d( - x=[float(getattr(state, 'mean', state.state_vector)[mapping[0]]) - for state in track], - error_x=dict(type='data', thickness=10, width=3, array=err_x), - - y=[float(getattr(state, 'mean', state.state_vector)[mapping[1]]) - for state in track], - error_y=dict(type='data', thickness=10, width=3, array=err_y), - - z=[float(getattr(state, 'mean', state.state_vector)[mapping[2]]) - for state in track], - error_z=dict(type='data', thickness=10, width=3, array=err_z), - # note that 3D error thickness seems to be broken in Plotly - - 
-                    text=[self._format_state_text(state) for state in track],
-                    **scatter_kwargs)
-
-                track_colors[track] = (self.fig.data[-1].line.color
-                                       or self.fig.data[-1].marker.color
-                                       or self.get_next_color())
-
-        # earlier checking means this only applies to 2D.
-        if uncertainty and self.dimension == 2:
-            name = track_kwargs['legendgroup'] + "<br>(Ellipses)"
(Ellipses)" - add_legend = name not in {trace.legendgroup for trace in self.fig.data} - for track in tracks: - ellipse_kwargs = dict( - mode='none', fill='toself', fillcolor=track_colors[track], - opacity=0.2, hoverinfo='skip', - legendgroup=name, name=name, - legendrank=track_kwargs['legendrank'] + 10) - for state in track: - points = self._generate_ellipse_points(state, mapping, ellipse_points) - if add_legend: - ellipse_kwargs['showlegend'] = True - add_legend = False - else: - ellipse_kwargs['showlegend'] = False - - self.fig.add_scatter(x=points[0, :], y=points[1, :], **ellipse_kwargs) - - if particle and self.dimension == 2: - name = track_kwargs['legendgroup'] + "
(Particles)" - add_legend = name not in {trace.legendgroup for trace in self.fig.data} - for track in tracks: - for state in track: - particle_kwargs = dict( - mode='markers', marker=dict(size=2), - opacity=0.4, hoverinfo='skip', - legendgroup=name, name=name, - legendrank=track_kwargs['legendrank'] + 20) - if add_legend: - particle_kwargs['showlegend'] = True - add_legend = False - else: - particle_kwargs['showlegend'] = False - data = state.state_vector[mapping[:2], :] - self.fig.add_scattergl(x=data[0], y=data[1], **particle_kwargs) - - @staticmethod - def _generate_ellipse_points(state, mapping, n_points=30): - """Generate error ellipse points for given state and mapping""" - HH = np.eye(state.ndim)[mapping, :] # Get position mapping matrix - w, v = np.linalg.eig(HH @ state.covar @ HH.T) - max_ind = np.argmax(w) - min_ind = np.argmin(w) - orient = np.arctan2(v[1, max_ind], v[0, max_ind]) - a = np.sqrt(w[max_ind]) - b = np.sqrt(w[min_ind]) - m = 1 - (b**2 / a**2) - - def func(x): - return np.sqrt(1 - (m**2 * np.sin(x)**2)) - - def func2(z): - return quad(func, 0, z)[0] - - c = 4 * a * func2(np.pi / 2) - - points = [] - for n in range(n_points): - def func3(x): - return n/n_points*c - a*func2(x) - - points.append((brentq(func3, 0, 2 * np.pi, xtol=1e-4))) - - c, s = np.cos(orient), np.sin(orient) - rotational_matrix = np.array(((c, -s), (s, c))) - points.append(points[0]) - points = np.array([[a * np.sin(i), b * np.cos(i)] for i in points]) - points = rotational_matrix @ points.T - return points + state.mean[mapping[:2], :] - - def plot_sensors(self, sensors, mapping=[0, 1], sensor_label="Sensors", **kwargs): - """Plots sensor(s) - - Plots sensors. Users can change the color and marker of sensors using keyword - arguments. Default is a black 'x' marker. - - Parameters - ---------- - sensors : Collection of :class:`~.Sensor` - Sensors to plot - mapping: list - List of items specifying the mapping of the position - components of the sensor's position. - sensor_label: str - Label to apply to all sensors for legend. - \\*\\*kwargs: dict - Additional arguments to be passed to scatter function for sensors. Defaults are - ``marker=dict(symbol='x', color='black')``. - """ - - if not isinstance(sensors, Collection): - sensors = {sensors} - - self._check_mapping(mapping) # ensure mapping is compatible with plotter dimension - - if self.dimension == 1 or self.dimension == 3: - raise NotImplementedError - - sensor_kwargs = dict(mode='markers', marker=dict(symbol='x', color='black'), - legendgroup=sensor_label, legendrank=50) - merge(sensor_kwargs, kwargs) - - sensor_kwargs['name'] = sensor_label - if sensor_kwargs['legendgroup'] not in {trace.legendgroup - for trace in self.fig.data}: - sensor_kwargs['showlegend'] = True - else: - sensor_kwargs['showlegend'] = True - - sensor_xy = np.array([sensor.position[mapping, 0] for sensor in sensors]) - self.fig.add_scatter(x=sensor_xy[:, 0], y=sensor_xy[:, 1], **sensor_kwargs) - - def hide_plot_traces(self, items_to_hide=None): - """Hide Plot Traces - - This function allows plotting items to be invisible as default. Users can toggle the plot - trace to visible. - - Parameters - ---------- - items_to_hide : Iterable[str] - The legend label (`legendgroups`) for the plot traces that should be invisible as - default. If left as ``None`` no traces will be shown. 
- """ - for fig_data in self.fig.data: - if items_to_hide is None or fig_data.legendgroup in items_to_hide: - fig_data.visible = "legendonly" - else: - fig_data.visible = None - - def show_plot_traces(self, items_to_show=None): - """Show Plot Traces - - This function allows specific plotting items to be shown as default. All labels not - mentioned in `items_to_show` will be invisible and can be manually toggled on. - - Parameters - ---------- - items_to_show : Iterable[str] - The legend label (`legendgroups`) for the plot traces that should be shown as - default. If left as ``None`` all traces will be shown. - """ - for fig_data in self.fig.data: - if items_to_show is None or fig_data.legendgroup in items_to_show: - fig_data.visible = None - else: - fig_data.visible = "legendonly" - - -class PolarPlotterly(_Plotter): - - def __init__(self, dimension=Dimension.TWO, **kwargs): - if go is None: - raise RuntimeError("Usage of Plotterly plotter requires installation of `plotly`") - if isinstance(dimension, type(Dimension.TWO)): - self.dimension = dimension - elif isinstance(dimension, int): - self.dimension = Dimension(dimension) - else: - raise TypeError("%s is an unsupported type for \'dimension\'; " - "expected type %s" % (type(dimension), type(Dimension.TWO))) - if self.dimension != dimension.TWO: - raise TypeError("Only 2D plotting currently supported") - - layout_kwargs = dict() - layout_kwargs.update(kwargs) - - # Generate plot axes - self.fig = go.Figure(layout=layout_kwargs) - - def plot_state_sequence(self, state_sequences, angle_mapping: int, range_mapping: int = None, - label="", **kwargs): - """Plots state sequence(s) - - Plots each state sequence passed in to :attr:`state_sequences` and generates a legend - automatically. - - Users can change line style, color and marker using keyword arguments. Any changes - will apply to all ground truths. - - Parameters - ---------- - state_sequences : Collection of :class:`~.StateMutableSequence` - Collection of state sequences which will be plotted. If not a collection, - and instead a single :class:`~.StateMutableSequence` type, the argument is modified - to be a set to allow for iteration. - angle_mapping: int - Specifying the mapping of the angular component of the state space to be plotted. - range_mapping: int - Specifying the mapping of the range component of the state space to be plotted. If - `None`, the angular component will be plotted against time. - label: str - Label for truth data. - \\*\\*kwargs: dict - Additional arguments to be passed to scatter function. Default is - ``mode=marker``. - The default unit for the angular component is radians. This can be changed to degrees - with the keyword argument ``thetaunit='degrees'``. 
- """ - - if not isinstance(state_sequences, Collection) \ - or isinstance(state_sequences, StateMutableSequence): - state_sequences = {state_sequences} - - plotting_kwargs = dict( - mode="markers", legendgroup=label, legendrank=200, - name=label, thetaunit="radians") - merge(plotting_kwargs, kwargs) - add_legend = plotting_kwargs['legendgroup'] not in {trace.legendgroup - for trace in self.fig.data} - - for state_sequence in state_sequences: - if range_mapping is None: - r = [state.timestamp for state in state_sequence] - else: - r = [float(state.state_vector[range_mapping]) for state in state_sequence] - bearings = [float(state.state_vector[angle_mapping]) for state in state_sequence] - - scatter_kwargs = plotting_kwargs.copy() - if add_legend: - scatter_kwargs['showlegend'] = True - add_legend = False - else: - scatter_kwargs['showlegend'] = False - - polar_plot = go.Scatterpolar( - r=r, - theta=bearings, **scatter_kwargs) - self.fig.add_trace(polar_plot) - - def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwargs): - """Plots ground truth(s) - - Plots each ground truth path passed in to :attr:`truths` and generates a legend - automatically. Ground truths are plotted as dashed lines with default colors. - - Users can change line style, color and marker using keyword arguments. Any changes - will apply to all ground truths. - - Parameters - ---------- - truths : Collection of :class:`~.GroundTruthPath` - Collection of ground truths which will be plotted. If not a collection, - and instead a single :class:`~.GroundTruthPath` type, the argument is modified to be a - set to allow for iteration. - mapping: list - List of items specifying the mapping of the position components of the state space. - truths_label: str - Label for truth data. Default is "Ground Truth". - \\*\\*kwargs: dict - Additional arguments to be passed to scatter function. Default is - ``line=dict(dash="dash")``. - """ - truths_kwargs = dict(mode="lines", line=dict(dash="dash"), legendrank=100) - merge(truths_kwargs, kwargs) - angle_mapping = mapping[0] - if len(mapping) > 1: - range_mapping = mapping[1] - else: - range_mapping = None - self.plot_state_sequence(state_sequences=truths, angle_mapping=angle_mapping, - range_mapping=range_mapping, label=truths_label, **truths_kwargs) - - def plot_measurements(self, measurements, mapping, measurement_model=None, - measurements_label="Measurements", convert_measurements=True, **kwargs): - """Plots measurements - - Plots detections and clutter, generating a legend automatically. Detections are plotted as - blue circles by default unless the detection type is clutter. - If the detection type is :class:`~.Clutter` it is plotted as a yellow 'tri-up' marker. - - Users can change the color and marker of detections using keyword arguments but not for - clutter detections. - - Parameters - ---------- - measurements : Collection of :class:`~.Detection` - Detections which will be plotted. If measurements is a set of lists it is flattened. - mapping: list - List of items specifying the mapping of the position components of the state space. - measurement_model : :class:`~.Model`, optional - User-defined measurement model to be used in finding measurement state inverses if - they cannot be found from the measurements themselves. - measurements_label : str - Label for the measurements. Default is "Measurements". - convert_measurements: bool - Should the measurements be converted before being plotted. Default is True. 
-        \\*\\*kwargs: dict
-            Additional arguments to be passed to scatter function for detections. Defaults are
-            ``marker=dict(color="#636EFA")``.
-        """
-
-        if not isinstance(measurements, Collection):
-            measurements = {measurements}
-
-        if any(isinstance(item, set) for item in measurements):
-            measurements_set = chain.from_iterable(measurements)  # Flatten into one set
-        else:
-            measurements_set = set(measurements)
-
-        plot_detections, plot_clutter = self._conv_measurements(measurements_set,
-                                                                mapping,
-                                                                measurement_model,
-                                                                convert_measurements)
-
-        angle_mapping = 0
-        if len(mapping) > 1:
-            range_mapping = 1
-        else:
-            range_mapping = None
-
-        if plot_detections:
-            name = measurements_label + "<br>(Detections)"
-            measurement_kwargs = dict(mode='markers', marker=dict(color='#636EFA'), legendrank=200)
-            merge(measurement_kwargs, kwargs)
-            plotting_data = [State(state_vector=plotting_state_vector,
-                                   timestamp=det.timestamp)
-                             for det, plotting_state_vector in plot_detections.items()]
-
-            self.plot_state_sequence(state_sequences=[plotting_data], angle_mapping=angle_mapping,
-                                     range_mapping=range_mapping, label=name,
-                                     **measurement_kwargs)
-
-        if plot_clutter:
-            name = measurements_label + "<br>
(Clutter)" - clutter_kwargs = dict(mode='markers', legendrank=210, - marker=dict(symbol="star-triangle-up", color='#FECB52')) - merge(clutter_kwargs, kwargs) - plotting_data = [State(state_vector=plotting_state_vector, - timestamp=det.timestamp) - for det, plotting_state_vector in plot_clutter.items()] - - self.plot_state_sequence(state_sequences=[plotting_data], angle_mapping=angle_mapping, - range_mapping=range_mapping, label=name, - **clutter_kwargs) - - def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_label="Tracks", - **kwargs): - """Plots track(s) - - Plots each track generated, generating a legend automatically. If ``uncertainty=True`` - error ellipses are plotted. - Tracks are plotted as solid lines with point markers and default colors. - - Users can change line style, color and marker using keyword arguments. - - Parameters - ---------- - tracks : Collection of :class:`~.Track` - Collection of tracks which will be plotted. If not a collection, and instead a single - :class:`~.Track` type, the argument is modified to be a set to allow for iteration. - mapping: list - List of items specifying the mapping of the position - components of the state space. - uncertainty : bool - If True, function plots uncertainty ellipses. - particle : bool - If True, function plots particles. - track_label: str - Label to apply to all tracks for legend. - \\*\\*kwargs: dict - Additional arguments to be passed to scatter function. Defaults are - ``mode='markers+lines'``. - """ - if uncertainty or particle: - raise NotImplementedError - - track_kwargs = dict(mode='markers+lines', legendrank=300) - merge(track_kwargs, kwargs) - angle_mapping = mapping[0] - if len(mapping) > 1: - range_mapping = mapping[1] - else: - range_mapping = None - self.plot_state_sequence(state_sequences=tracks, angle_mapping=angle_mapping, - range_mapping=range_mapping, label=track_label, **track_kwargs) - - def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs): - raise NotImplementedError - - -class _AnimationPlotterDataClass(Base): - plotting_data = Property(Iterable[State]) - plotting_label: str = Property() - plotting_keyword_arguments: dict = Property() - - -class AnimationPlotter(_Plotter): - - def __init__(self, dimension=Dimension.TWO, x_label: str = "$x$", y_label: str = "$y$", - title: str = None, legend_kwargs: dict = {}, **kwargs): - - self.figure_kwargs = {"figsize": (10, 6)} - self.figure_kwargs.update(kwargs) - if dimension != Dimension.TWO: - raise NotImplementedError - - self.legend_kwargs = dict() - self.legend_kwargs.update(legend_kwargs) - - self.x_label: str = x_label - self.y_label: str = y_label - - if title: - title += "\n" - self.title: str = title - - self.plotting_data: List[_AnimationPlotterDataClass] = [] - - self.animation_output: animation.FuncAnimation = None - - def run(self, - times_to_plot: List[datetime] = None, - plot_item_expiry: Optional[timedelta] = None, - **kwargs): - """Run the animation - - Parameters - ---------- - times_to_plot : List of :class:`~.datetime` - List of datetime objects of when to refresh and draw the animation. Default `None`, - where unique timestamps of data will be used. - plot_item_expiry: :class:`~.timedelta`, Optional - Describes how long states will remain present in the figure. 
Default value of None - means data is shown indefinitely - \\*\\*kwargs: dict - Additional arguments to be passed to the animation.FuncAnimation function - """ - if times_to_plot is None: - times_to_plot = sorted({ - state.timestamp - for plotting_data in self.plotting_data - for state in plotting_data.plotting_data}) - - self.animation_output = self.run_animation( - times_to_plot=times_to_plot, - data=self.plotting_data, - plot_item_expiry=plot_item_expiry, - x_label=self.x_label, - y_label=self.y_label, - figure_kwargs=self.figure_kwargs, - legend_kwargs=self.legend_kwargs, - animation_input_kwargs=kwargs, - plot_title=self.title - ) - return self.animation_output - - def save(self, filename='example.mp4', **kwargs): - """Save the animation - - Parameters - ---------- - filename : str - filename of animation file - \\*\\*kwargs: dict - Additional arguments to be passed to the animation.save function - """ - if self.animation_output is None: - raise ValueError("Animation hasn't been run yet. Therefore there is no animation to " - "save") - - self.animation_output.save(filename, **kwargs) - - def plot_ground_truths(self, truths, mapping: List[int], truths_label: str = "Ground Truth", - **kwargs): - """Plots ground truth(s) - - Plots each ground truth path passed in to :attr:`truths` and generates a legend - automatically. Ground truths are plotted as dashed lines with default colors. - - Users can change linestyle, color and marker using keyword arguments. Any changes - will apply to all ground truths. - - Parameters - ---------- - truths : Collection of :class:`~.GroundTruthPath` - Collection of ground truths which will be plotted. If not a collection and instead a - single :class:`~.GroundTruthPath` type, the argument is modified to be a set to allow - for iteration. - mapping: list - List of items specifying the mapping of the position components of the state space. - truths_label: str - Label for truth data. Default is "Ground Truth" - \\*\\*kwargs: dict - Additional arguments to be passed to plot function. Default is ``linestyle="--"``. - """ - - truths_kwargs = dict(linestyle="--") - truths_kwargs.update(kwargs) - self.plot_state_mutable_sequence(truths, mapping, truths_label, **truths_kwargs) - - def plot_tracks(self, tracks, mapping: List[int], uncertainty=False, particle=False, - track_label="Tracks", **kwargs): - """Plots track(s) - - Plots each track generated, generating a legend automatically. Tracks are plotted as solid - lines with point markers and default colors. Users can change linestyle, color and marker - using keyword arguments. - - Parameters - ---------- - tracks : Collection of :class:`~.Track` - Collection of tracks which will be plotted. If not a collection, and instead a single - :class:`~.Track` type, the argument is modified to be a set to allow for iteration. - mapping: list - List of items specifying the mapping of the position - components of the state space. - uncertainty : bool - Currently not implemented. If True, an error is raised - particle : bool - Currently not implemented. If True, an error is raised - track_label: str - Label to apply to all tracks for legend. - \\*\\*kwargs: dict - Additional arguments to be passed to plot function. Defaults are ``linestyle="-"``, - ``marker='s'`` for :class:`~.Update` and ``marker='o'`` for other states. 
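AnimationPlotter defers all drawing to Matplotlib's FuncAnimation, as run_animation further below shows. A self-contained sketch of that pattern, with hypothetical data (saving to mp4 requires ffmpeg, as with AnimationPlotter.save)::

    import matplotlib.pyplot as plt
    from matplotlib import animation

    fig, ax = plt.subplots()
    (line,) = ax.plot([], [], linestyle="-", marker="o")
    ax.set_xlim(0, 9)
    ax.set_ylim(0, 4)

    xs = list(range(10))
    ys = [n % 4 for n in range(10)]

    def update(frame_index):
        # Reveal one more state per frame, much as update_animation
        # re-draws each Line2D from the states inside the time window.
        line.set_data(xs[:frame_index + 1], ys[:frame_index + 1])
        return [line]

    anim = animation.FuncAnimation(fig, update, frames=len(xs),
                                   blit=False, repeat=False, interval=50)
    anim.save("example.mp4", writer="ffmpeg")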
- """ - if uncertainty or particle: - raise NotImplementedError - - tracks_kwargs = dict(linestyle='-', marker="s", color=None) - tracks_kwargs.update(kwargs) - self.plot_state_mutable_sequence(tracks, mapping, track_label, **tracks_kwargs) - - def plot_state_mutable_sequence(self, state_mutable_sequences, mapping: List[int], label: str, - **plotting_kwargs): - """Plots State Mutable Sequence - - Parameters - ---------- - state_mutable_sequences : Collection of :class:`~.StateMutableSequence` - Collection of states to be plotted - mapping: list - List of items specifying the mapping of the position components of the state space. - label : str - User-defined measurement model to be used in finding measurement state inverses if - they cannot be found from the measurements themselves. - \\*\\*kwargs: dict - Additional arguments to be passed to plot function for states. - """ - - if not isinstance(state_mutable_sequences, Collection) or \ - isinstance(state_mutable_sequences, StateMutableSequence): - state_mutable_sequences = {state_mutable_sequences} # Make a set of length 1 - - for idx, state_mutable_sequence in enumerate(state_mutable_sequences): - if idx == 0: - this_plotting_label = label - else: - this_plotting_label = None - - self.plotting_data.append(_AnimationPlotterDataClass( - plotting_data=[State(state_vector=[state.state_vector[mapping[0]], - state.state_vector[mapping[1]]], - timestamp=state.timestamp) - for state in state_mutable_sequence], - plotting_label=this_plotting_label, - plotting_keyword_arguments=plotting_kwargs - )) - - def plot_measurements(self, measurements, mapping, measurement_model=None, - measurements_label="", convert_measurements=True, **kwargs): - """Plots measurements - - Plots detections and clutter, generating a legend automatically. Detections are plotted as - blue circles by default unless the detection type is clutter. - If the detection type is :class:`~.Clutter` it is plotted as a yellow 'tri-up' marker. - - Users can change the color and marker of detections using keyword arguments but not for - clutter detections. - - Parameters - ---------- - measurements : Collection of :class:`~.Detection` - Detections which will be plotted. If measurements is a set of lists it is flattened. - mapping: list - List of items specifying the mapping of the position components of the state space. - measurement_model : :class:`~.Model`, optional - User-defined measurement model to be used in finding measurement state inverses if - they cannot be found from the measurements themselves. - measurements_label: str - Label for measurements. Default will be "Detections" or "Clutter" - convert_measurements: bool - Should the measurements be converted from measurement space to state space before - being plotted. Default is True - \\*\\*kwargs: dict - Additional arguments to be passed to plot function for detections. Defaults are - ``marker='o'`` and ``color='b'``. 
- """ - - measurement_kwargs = dict(marker='o', color='b') - measurement_kwargs.update(kwargs) - - if not isinstance(measurements, Collection): - measurements = {measurements} # Make a set of length 1 - - if any(isinstance(item, set) for item in measurements): - measurements_set = chain.from_iterable(measurements) # Flatten into one set - else: - measurements_set = measurements - - plot_detections, plot_clutter = self._conv_measurements(measurements_set, - mapping, - measurement_model, - convert_measurements) - - if measurements_label != "": - measurements_label = measurements_label + " " - - if plot_detections: - detection_kwargs = dict(linestyle='', marker='o', color='b') - detection_kwargs.update(kwargs) - self.plotting_data.append(_AnimationPlotterDataClass( - plotting_data=[State(state_vector=plotting_state_vector, - timestamp=detection.timestamp) - for detection, plotting_state_vector in plot_detections.items()], - plotting_label=measurements_label + "Detections", - plotting_keyword_arguments=detection_kwargs - )) - - if plot_clutter: - clutter_kwargs = dict(linestyle='', marker='2', color='y') - clutter_kwargs.update(kwargs) - self.plotting_data.append(_AnimationPlotterDataClass( - plotting_data=[State(state_vector=plotting_state_vector, - timestamp=detection.timestamp) - for detection, plotting_state_vector in plot_clutter.items()], - plotting_label=measurements_label + "Clutter", - plotting_keyword_arguments=clutter_kwargs - )) - - def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs): - raise NotImplementedError - - @classmethod - def run_animation(cls, - times_to_plot: List[datetime], - data: Iterable[_AnimationPlotterDataClass], - plot_item_expiry: Optional[timedelta] = None, - axis_padding: float = 0.1, - figure_kwargs: dict = {}, - animation_input_kwargs: dict = {}, - legend_kwargs: dict = {}, - x_label: str = "$x$", - y_label: str = "$y$", - plot_title: str = None - ) -> animation.FuncAnimation: - """ - Parameters - ---------- - times_to_plot : Iterable[datetime] - All the times that the plotter should plot - data : Iterable[datetime] - All the data that should be plotted - plot_item_expiry: timedelta - How long a state should be displayed for. Default value of None - means data is shown indefinitely - axis_padding: float - How much extra space should be given around the edge of the plot - figure_kwargs: dict - Keyword arguments for the pyplot figure function. See matplotlib.pyplot.figure for more - details - animation_input_kwargs: dict - Keyword arguments for FuncAnimation class. See matplotlib.animation.FuncAnimation for - more details. Default values are: blit=False, repeat=False, interval=50 - legend_kwargs: dict - Keyword arguments for the pyplot legend function. 
See matplotlib.pyplot.legend for more - details - x_label: str - Label for the x axis - y_label: str - Label for the y axis - plot_title: str - Title for the plot - - Returns - ------- - : animation.FuncAnimation - Animation object - """ - - animation_kwargs = dict(blit=False, repeat=False, interval=50) # milliseconds - animation_kwargs.update(animation_input_kwargs) - - fig1 = plt.figure(**figure_kwargs) - - the_lines = [] - plotting_data = [] - legends_key = [] - - for a_plot_object in data: - if a_plot_object.plotting_data is not None: - the_data = np.array( - [a_state.state_vector for a_state in a_plot_object.plotting_data]) - if len(the_data) == 0: - continue - the_lines.append( - plt.plot([], # the_data[:1, 0], - [], # the_data[:1, 1], - **a_plot_object.plotting_keyword_arguments)[0]) - - legends_key.append(a_plot_object.plotting_label) - plotting_data.append(a_plot_object.plotting_data) - - if axis_padding: - [x_limits, y_limits] = [ - [min(state.state_vector[idx] for line in data for state in line.plotting_data), - max(state.state_vector[idx] for line in data for state in line.plotting_data)] - for idx in [0, 1]] - - for axis_limits in [x_limits, y_limits]: - limit_padding = axis_padding * (axis_limits[1] - axis_limits[0]) - # The casting to float to ensure the limits contain do not contain angle classes - axis_limits[0] = float(axis_limits[0] - limit_padding) - axis_limits[1] = float(axis_limits[1] + limit_padding) - - plt.xlim(x_limits) - plt.ylim(y_limits) - else: - plt.axis('equal') - - plt.xlabel(x_label) - plt.ylabel(y_label) - - lines_with_legend = [line for line, label in zip(the_lines, legends_key) - if label is not None] - plt.legend(lines_with_legend, [label for label in legends_key if label is not None], - **legend_kwargs) - - if plot_item_expiry is None: - min_plot_time = min(state.timestamp - for line in data - for state in line.plotting_data) - min_plot_times = [min_plot_time] * len(times_to_plot) - else: - min_plot_times = [time - plot_item_expiry for time in times_to_plot] - - line_ani = animation.FuncAnimation(fig1, cls.update_animation, - frames=len(times_to_plot), - fargs=(the_lines, plotting_data, min_plot_times, - times_to_plot, plot_title), - **animation_kwargs) - - plt.draw() - - return line_ani - - @staticmethod - def update_animation(index: int, lines: List[Line2D], data_list: List[List[State]], - start_times: List[datetime], end_times: List[datetime], title: str): - """ - Parameters - ---------- - index : int - Which index of the start_times and end_times should be used - lines : List[Line2D] - The data that will be plotted, to be plotted. 
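The axis_padding logic in run_animation amounts to widening each axis by a fixed fraction of its data range. As a worked sketch, using the default fraction of 0.1::

    def padded_limits(values, axis_padding=0.1):
        # Cast to float so angle-like types cannot leak into the
        # limits, matching the comment in run_animation above.
        low, high = float(min(values)), float(max(values))
        pad = axis_padding * (high - low)
        return low - pad, high + pad

    print(padded_limits([2.0, 5.0, 9.0]))  # (1.3, 9.7)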
- data_list : List[List[State]] - All the data that should be plotted - start_times : List[datetime] - lowest (earliest) time for an item to be plotted - end_times : List[datetime] - highest (latest) time for an item to be plotted - title: str - Title for the plot - - Returns - ------- - : List[Line2D] - The data that will be plotted - """ - - min_time = start_times[index] - max_time = end_times[index] - - if title is None: - title = "" - plt.title(title + str(max_time)) - for i, data_source in enumerate(data_list): - - if data_source is not None: - the_data = np.array([a_state.state_vector for a_state in data_source - if min_time <= a_state.timestamp <= max_time]) - if the_data.size > 0: - lines[i].set_data(the_data[:, 0], - the_data[:, 1]) - else: - lines[i].set_data([], - []) - return lines - - -class AnimatedPlotterly(_Plotter): - """ - Class for a 2D animated plotter that uses Plotly graph objects rather than matplotlib. - This gives the user the ability to see how tracking works through time, while being - able to interact with tracks, truths, etc, in the same way that is enabled by - Plotly static plots. - - Simplifies the process of plotting ground truths, measurements, clutter, and tracks. - Tracks can be plotted with uncertainty ellipses or particles if required. Legends - are automatically generated with each plot. - - Parameters - ---------- - timesteps: Collection - Collection of equally-spaced timesteps. Each animation frame is a timestep. - tail_length: float - Percentage of sim time for which previous values will still be displayed for. - Value can be between 0 and 1. Default is 0.3. - equal_size: bool - Makes x and y axes equal when figure is resized. Default is False. - sim_duration: int - Time taken to run animation (s). Default is 6 - \\*\\*kwargs - Additional arguments to be passed in the initialisation. - - Attributes - ---------- - - """ - - def __init__(self, timesteps, tail_length=0.3, equal_size=False, - sim_duration=6, **kwargs): - """ - Initialise the figure and checks that inputs are correctly formatted. - Creates an empty frame for each timestep, and configures - the buttons and slider. - - - """ - if go is None or colors is None: - raise RuntimeError("Usage of Plotterly plotter requires installation of `plotly`") - - self.equal_size = equal_size - - # checking that there are multiple timesteps - if len(timesteps) < 2: - raise ValueError("Must be at least 2 timesteps for animation.") - - # checking that timesteps are evenly spaced - time_spaces = np.unique(np.diff(timesteps)) - - # gives the unique values of time gaps between timesteps. If this contains more than - # one value, then timesteps are not all evenly spaced which is an issue. - if len(time_spaces) != 1: - warnings.warn("Timesteps are not equally spaced, so the passage of time is not linear") - self.timesteps = timesteps - - # checking input to tail_length - if tail_length > 1 or tail_length < 0: - raise ValueError("Tail length should be between 0 and 1") - self.tail_length = tail_length - - # checking sim_duration - if sim_duration <= 0: - raise ValueError("Simulation duration must be positive") - - # time window is calculated as sim_length * tail_length. 
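The spacing check and tail-length window used by AnimatedPlotterly can be reproduced in a few lines: np.diff over a list of datetimes yields timedeltas, so a single unique value means the steps are even. A sketch with made-up timesteps::

    from datetime import datetime, timedelta

    import numpy as np

    timesteps = [datetime(2024, 6, 11) + timedelta(seconds=s) for s in range(4)]
    time_spaces = np.unique(np.diff(timesteps))
    assert len(time_spaces) == 1  # evenly spaced, so no warning needed

    tail_length = 0.3
    # Window of past data kept visible: 30% of the full simulation span.
    time_window = (timesteps[-1] - timesteps[0]) * tail_length
    print(time_window)  # 0:00:00.900000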
This is - # the window of time for which past plots are still visible - self.time_window = (timesteps[-1] - timesteps[0]) * tail_length - - self.colorway = colors.qualitative.Plotly[1:] # plotting colours - - self.all_masks = dict() # dictionary to be filled up later - - self.plotting_function_called = False # keeps track if anything has been plotted or not - # so that only the first data plotted will override the default axis max and mins. - - self.fig = go.Figure() - - layout_kwargs = dict( - xaxis=dict(title=dict(text="x")), - yaxis=dict(title=dict(text="y")), - colorway=self.colorway, # Needed to match colours later. - height=550, - autosize=True - ) - # layout_kwargs.update(kwargs) - self.fig.update_layout(layout_kwargs) - - # initialise frames according to simulation timesteps - self.fig.frames = [dict( - name=str(time), - data=[], - traces=[] - ) for time in timesteps] - - self.fig.update_xaxes(range=[0, 10]) - self.fig.update_yaxes(range=[0, 10]) - - frame_duration = sim_duration * 1000 / len(self.fig.frames) - - # if the gap between timesteps is greater than a day, it isn't necessary - # to display hour and minute information, so remove this to give a cleaner display. - # a and b are used in the slider steps label later - if time_spaces[0] >= timedelta(days=1): - start_cut_off = None - end_cut_off = 10 - - # if the simulation is over a day long, display all information which - # looks clunky but is necessary - elif timesteps[-1] - timesteps[0] > timedelta(days=1): - start_cut_off = None - end_cut_off = None - - # otherwise, remove day information and just show - # hours, mins, etc. which is cleaner to look at - else: - start_cut_off = 11 - end_cut_off = None - - # create button and slider - updatemenus = [dict(type='buttons', - buttons=[{ - "args": [None, - {"frame": {"duration": frame_duration, "redraw": True}, - "fromcurrent": True, "transition": {"duration": 0}}], - "label": "Play", - "method": "animate" - }, { - "args": [[None], {"frame": {"duration": 0, "redraw": True}, - "mode": "immediate", - "transition": {"duration": 0}}], - "label": "Stop", - "method": "animate" - }], - direction='left', - pad=dict(r=10, t=75), - showactive=True, x=0.1, y=0, xanchor='right', yanchor='top') - ] - sliders = [{'yanchor': 'top', - 'xanchor': 'left', - 'currentvalue': {'font': {'size': 16}, 'prefix': 'Time: ', 'visible': True, - 'xanchor': 'right'}, - 'transition': {'duration': frame_duration, 'easing': 'linear'}, - 'pad': {'b': 10, 't': 50}, - 'len': 0.9, 'x': 0.1, 'y': 0, - 'steps': [{'args': [[frame.name], { - 'frame': {'duration': 1.0, 'easing': 'linear', 'redraw': True}, - 'transition': {'duration': 0, 'easing': 'linear'}}], - 'label': frame.name[start_cut_off: end_cut_off], - 'method': 'animate'} for frame in - self.fig.frames - ]}] - self.fig.update_layout(updatemenus=updatemenus, sliders=sliders) - self.fig.update_layout(kwargs) - - def show(self): - """ - Display the animation. - """ - return self.fig - - def _resize(self, data, type="track"): - """ - Reshape figure so that everything is in view. - - Parameters - ---------- - - data: - Collection of values that are being added to the figure. - Will be a list if coming from plot_ground_Truths or - plot_tracks, but will be a dictionary if coming from plot_measurements. - """ - - # fill in all data. 
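The frame and button machinery configured above follows Plotly's standard animation pattern: one go.Frame per timestep plus an updatemenus Play button. A minimal, self-contained version::

    import plotly.graph_objects as go

    frames = [go.Frame(name=str(k), data=[go.Scatter(x=[0, k], y=[0, k])])
              for k in range(5)]
    fig = go.Figure(data=[go.Scatter(x=[0], y=[0])], frames=frames)
    fig.update_layout(updatemenus=[dict(
        type="buttons",
        buttons=[dict(label="Play", method="animate",
                      args=[None, {"frame": {"duration": 200, "redraw": True},
                                   "fromcurrent": True}])])])
    fig.show()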
If there is no data, fill all_x, all_y with current axis limits - if not data: - all_x = list(self.fig.layout.xaxis.range) - all_y = list(self.fig.layout.xaxis.range) - else: - all_x = list() - all_y = list() - - # fill in data - if type == "measurements": - - for key, item in data.items(): - all_x.extend(data[key]["x"]) - all_y.extend(data[key]["y"]) - - elif type in ("ground_truth", "tracks"): - - for n, _ in enumerate(data): - all_x.extend(data[n]["x"]) - all_y.extend(data[n]["y"]) - - elif type == "sensor": - sensor_xy = np.array([sensor.position[[0, 1], 0] for sensor in data]) - all_x.extend(sensor_xy[:, 0]) - all_y.extend(sensor_xy[:, 1]) - - elif type == "particle_or_uncertainty": - # data comes in format of list of dictionaries. Each dictionary contains 'x' and 'y', - # which are a list of lists. - for dictionary in data: - for x_values in dictionary["x"]: - all_x.extend([np.nanmax(x_values), np.nanmin(x_values)]) - for y_values in dictionary["y"]: - all_y.extend([np.nanmax(y_values), np.nanmin(y_values)]) - - xmax = max(all_x) - ymax = max(all_y) - xmin = min(all_x) - ymin = min(all_y) - - if self.equal_size: - xmax = ymax = max(xmax, ymax) - xmin = ymin = min(xmin, ymin) - - # if it's first time plotting data, want to ensure plotter is bound to that data - # and not the default values. Issues arise if the initial plotted data is much - # smaller than the default 0 to 10 values. - if not self.plotting_function_called: - - self.fig.update_xaxes(range=[xmin, xmax]) - self.fig.update_yaxes(range=[ymin, ymax]) - - # need to check if it's actually necessary to resize or not - if xmax >= self.fig.layout.xaxis.range[1] or xmin <= self.fig.layout.xaxis.range[0]: - - xmax = max(xmax, self.fig.layout.xaxis.range[1]) - xmin = min(xmin, self.fig.layout.xaxis.range[0]) - xrange = xmax - xmin - - # update figure while adding a small buffer to the mins and maxes - self.fig.update_xaxes(range=[xmin - xrange / 20, xmax + xrange / 20]) - - if ymax >= self.fig.layout.yaxis.range[1] or ymin <= self.fig.layout.yaxis.range[0]: - - ymax = max(ymax, self.fig.layout.yaxis.range[1]) - ymin = min(ymin, self.fig.layout.yaxis.range[0]) - yrange = ymax - ymin - - self.fig.update_yaxes(range=[ymin - yrange / 20, ymax + yrange / 20]) - - def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", - resize=True, **kwargs): - - """Plots ground truth(s) - - Plots each ground truth path passed in to :attr:`truths` and generates a legend - automatically. Ground truths are plotted as dashed lines with default colors. - - Users can change linestyle, color and marker using keyword arguments. Any changes - will apply to all ground truths. - - Parameters - ---------- - truths : Collection of :class:`~.GroundTruthPath` - Collection of ground truths which will be plotted. If not a collection and instead a - single :class:`~.GroundTruthPath` type, the argument is modified to be a set to allow - for iteration. - mapping: list - List of items specifying the mapping of the position components of the state space. - truths_label: str - Name of ground truths in legend/plot - resize: bool - if True, will resize figure to ensure that ground truths are in view - \\*\\*kwargs: dict - Additional arguments to be passed to plot function. Default is ``linestyle="--"``. 
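The resizing rule in _resize only ever grows the axes, padding each end by a twentieth of the resulting range. The same rule as a standalone sketch::

    def grow_axis(current_range, new_min, new_max):
        # Expand only if the new data falls outside the current range,
        # then pad both ends by range/20, as _resize does above.
        low = min(float(new_min), current_range[0])
        high = max(float(new_max), current_range[1])
        span = high - low
        return [low - span / 20, high + span / 20]

    print(grow_axis([0.0, 10.0], -3.0, 12.0))  # [-3.75, 12.75]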
- - """ - - if not isinstance(truths, Collection) or isinstance(truths, StateMutableSequence): - truths = {truths} # Make a set of length 1 - - data = [dict() for _ in truths] # put all data into one place for later plotting - for n, truth in enumerate(truths): - - # initialise arrays that go inside the dictionary - data[n].update(x=np.zeros(len(truth)), - y=np.zeros(len(truth)), - time=np.array([0 for _ in range(len(truth))], dtype=object), - time_str=np.array([0 for _ in range(len(truth))], dtype=object), - type=np.array([0 for _ in range(len(truth))], dtype=object)) - - for k, state in enumerate(truth): - # fill the arrays here - data[n]["x"][k] = state.state_vector[mapping[0]] - data[n]["y"][k] = state.state_vector[mapping[1]] - data[n]["time"][k] = state.timestamp - data[n]["time_str"][k] = str(state.timestamp) - data[n]["type"][k] = type(state).__name__ - - trace_base = len(self.fig.data) # number of traces currently in the animation - - # add a trace that keeps the legend up for the entire simulation (will remain - # even if no truths are present), then add a trace for each truth in the simulation. - # initialise keyword arguments, then add them to the traces - truth_kwargs = dict(x=[], y=[], mode="lines", hoverinfo='none', legendgroup=truths_label, - line=dict(dash="dash", color=self.colorway[0]), legendrank=100, - name=truths_label, showlegend=True) - merge(truth_kwargs, kwargs) - # legend dummy trace - self.fig.add_trace(go.Scatter(truth_kwargs)) - - # we don't want the legend for any of the actual traces - truth_kwargs.update({"showlegend": False}) - - for n, _ in enumerate(truths): - # change the colour of each truth and include n in its name - merge(truth_kwargs, dict(line=dict(color=self.colorway[n % len(self.colorway)]))) - merge(truth_kwargs, kwargs) - self.fig.add_trace(go.Scatter(truth_kwargs)) # add to traces - - for frame in self.fig.frames: - - # get current fig data and traces - data_ = list(frame.data) - traces_ = list(frame.traces) - - # convert string to datetime object - frame_time = datetime.fromisoformat(frame.name) - cutoff_time = (frame_time - self.time_window) - - # for the legend - data_.append(go.Scatter(x=[0, 0], y=[0, 0])) - traces_.append(trace_base) - - for n, truth in enumerate(truths): - # all truth points that come at or before the frame time - t_upper = [data[n]["time"] <= frame_time] - - # only select detections that come after the time cut-off - t_lower = [data[n]["time"] >= cutoff_time] - - # put together - mask = np.logical_and(t_upper, t_lower) - - # find x, y, time, and type - truth_x = data[n]["x"][tuple(mask)] - # add in np.inf to ensure traces are present for every timestep - truth_x = np.append(truth_x, [np.inf]) - truth_y = data[n]["y"][tuple(mask)] - truth_y = np.append(truth_y, [np.inf]) - times = data[n]["time_str"][tuple(mask)] - - data_.append(go.Scatter(x=truth_x, - y=truth_y, - meta=times, - hovertemplate='GroundTruthState' + - '
<br>(%{x}, %{y})' +
-                                                     '<br>
Time: %{meta}')) - - traces_.append(trace_base + n + 1) # append data to correct trace - - frame.data = data_ - frame.traces = traces_ - - if resize: - self._resize(data, type="ground_truth") - - # we have called a plotting function so update flag (gets used in _resize) - self.plotting_function_called = True - - def plot_measurements(self, measurements, mapping, measurement_model=None, - resize=True, measurements_label="Measurements", - convert_measurements=True, **kwargs): - """Plots measurements - - Plots detections and clutter, generating a legend automatically. Detections are plotted as - blue circles by default unless the detection type is clutter. - If the detection type is :class:`~.Clutter` it is plotted as a yellow 'tri-up' marker. - - Users can change the color and marker of detections using keyword arguments but not for - clutter detections. - - Parameters - ---------- - measurements : Collection of :class:`~.Detection` - Detections which will be plotted. If measurements is a set of lists it is flattened. - mapping: list - List of items specifying the mapping of the position components of the state space. - measurement_model : :class:`~.Model`, optional - User-defined measurement model to be used in finding measurement state inverses if - they cannot be found from the measurements themselves. - resize: bool - If True, will resize figure to ensure measurements are in view - measurements_label : str - Label for the measurements. Default is "Measurements". - convert_measurements : bool - Should the measurements be converted from measurement space to state space before - being plotted. Default is True - \\*\\*kwargs: dict - Additional arguments to be passed to scatter function for detections. Defaults are - ``marker=dict(color="#636EFA")``. - """ - - if not isinstance(measurements, Collection): - measurements = {measurements} # Make a set of length 1 - - if any(isinstance(item, set) for item in measurements): - measurements_set = chain.from_iterable(measurements) # Flatten into one set - else: - measurements_set = measurements - plot_detections, plot_clutter = self._conv_measurements(measurements_set, - mapping, - measurement_model, - convert_measurements) - plot_combined = {'Detection': plot_detections, - 'Clutter': plot_clutter} # for later reference - - # this dictionary will store all the plotting data that we need - # from the detections and clutter into numpy arrays that we can easily - # access to plot - combined_data = dict() - - # only add clutter or detections to plot if necessary - if plot_detections: - combined_data.update(dict(Detection=dict())) - if plot_clutter: - combined_data.update(dict(Clutter=dict())) - - # initialise combined_data - for key in combined_data.keys(): - length = len(plot_combined[key]) - combined_data[key].update({ - "x": np.zeros(length), - "y": np.zeros(length), - "time": np.array([0 for _ in range(length)], dtype=object), - "time_str": np.array([0 for _ in range(length)], dtype=object), - "type": np.array([0 for _ in range(length)], dtype=object)}) - - # and now fill in the data - - for key in combined_data.keys(): - for n, det in enumerate(plot_combined[key]): - x, y = list(plot_combined[key].values())[n] - combined_data[key]["x"][n] = x - combined_data[key]["y"][n] = y - combined_data[key]["time"][n] = det.timestamp - combined_data[key]["time_str"][n] = str(det.timestamp) - combined_data[key]["type"][n] = type(det).__name__ - - # get number of traces currently in fig - trace_base = len(self.fig.data) - - # initialise detections - name = 
measurements_label + "<br>
(Detections)" - measurement_kwargs = dict(x=[], y=[], mode='markers', - name=name, - legendgroup=name, - legendrank=200, showlegend=True, - marker=dict(color="#636EFA"), hoverinfo='none') - merge(measurement_kwargs, kwargs) - - self.fig.add_trace(go.Scatter(measurement_kwargs)) # trace for legend - - measurement_kwargs.update({"showlegend": False}) - self.fig.add_trace(go.Scatter(measurement_kwargs)) # trace for plotting - - # change necessary kwargs to initialise clutter trace - name = measurements_label + "
(Clutter)" - clutter_kwargs = dict(x=[], y=[], mode='markers', - name=name, - legendgroup=name, - legendrank=300, showlegend=True, - marker=dict(symbol="star-triangle-up", color='#FECB52'), - hoverinfo='none') - merge(clutter_kwargs, kwargs) - - self.fig.add_trace(go.Scatter(clutter_kwargs)) # trace for plotting clutter - - # add data to frames - for frame in self.fig.frames: - - data_ = list(frame.data) - traces_ = list(frame.traces) - - # add blank data to ensure detection legend stays in place - data_.append(go.Scatter(x=[-np.inf, np.inf], y=[-np.inf, np.inf])) - traces_.append(trace_base) # ensure data is added to correct trace - - frame_time = datetime.fromisoformat(frame.name) # convert string to datetime object - - # time at which dets will disappear from the fig - cutoff_time = (frame_time - self.time_window) - - for j, key in enumerate(combined_data.keys()): - # only select measurements that arrive by the time of the current frame - t_upper = [combined_data[key]["time"] <= frame_time] - - # only select detections that come after the time cut-off - t_lower = [combined_data[key]["time"] >= cutoff_time] - - # put them together to create the final mask - mask = np.logical_and(t_upper, t_lower) - - # find x and y points for true detections and clutter - det_x = combined_data[key]["x"][tuple(mask)] - det_x = np.append(det_x, [np.inf]) - det_y = combined_data[key]["y"][tuple(mask)] - det_y = np.append(det_y, [np.inf]) - det_times = combined_data[key]["time_str"][tuple(mask)] - - data_.append(go.Scatter(x=det_x, - y=det_y, - meta=det_times, - hovertemplate=f'{key}' + - '
<br>(%{x}, %{y})' +
-                                                      '<br>
Time: %{meta}')) - traces_.append(trace_base + j + 1) - - frame.data = data_ # update the figure - frame.traces = traces_ - - if resize: - self._resize(combined_data, "measurements") - - # we have called a plotting function so update flag (gets used in resize) - self.plotting_function_called = True - - def plot_tracks(self, tracks, mapping, uncertainty=False, resize=True, - particle=False, plot_history=False, ellipse_points=30, - track_label="Tracks", **kwargs): - """ - Plots each track generated, generating a legend automatically. If 'uncertainty=True', - error ellipses are plotted. Tracks are plotted as solid lines with point markers - and default colours. - - Users can change linestyle, color, and marker using keyword arguments. Uncertainty metrics - will also be plotted with the user defined colour and any changes will apply to all tracks. - - Parameters - ---------- - tracks: Collection of :class '~Track' - Collection of tracks which will be plotted. If not a collection, and instead a single - :class:'~Track' type, the argument is modified to be a set to allow for iteration - - mapping: list - List of items specifying the mapping of the position - components of the state space - uncertainty: bool - If True, function plots uncertainty ellipses - resize: bool - If True, plotter will change bounds so that tracks are in view - particle: bool - If True, function plots particles - plot_history: bool - If true, plots all particles and uncertainty ellipses up to current time step - ellipse_points: int - Number of points for polygon approximating ellipse shape - track_label: str - Label to apply to all tracks for legend - \\*\\*kwargs: dict - Additional arguments to be passed to plot function. Defaults are ``linestyle="-"``, - ``marker='s'`` for :class:`~.Update` and ``marker='o'`` for other states. - - Returns - ------- - """ - - if not isinstance(tracks, Collection) or isinstance(tracks, StateMutableSequence): - tracks = {tracks} # Make a set of length 1 - - # So that we can plot tracks for both the current time and for some previous times, - # we put plotting data for each track into a dictionary so that it can be easily - # accessed later. - data = [dict() for _ in tracks] - - for n, track in enumerate(tracks): # sum up means - accounts for particle filter - - xydata = np.concatenate( - [(getattr(state, 'mean', state.state_vector)[mapping, :]) - for state in track], - axis=1) - - # initialise arrays that go inside the dictionary - data[n].update(x=xydata[0], - y=xydata[1], - time=np.array([0 for _ in range(len(track))], dtype=object), - time_str=np.array([0 for _ in range(len(track))], dtype=object), - type=np.array([0 for _ in range(len(track))], dtype=object)) - - for k, state in enumerate(track): - # fill the arrays here - data[n]["time"][k] = state.timestamp - data[n]["time_str"][k] = str(state.timestamp) - data[n]["type"][k] = type(state).__name__ - - trace_base = len(self.fig.data) # number of traces - - # add dummy trace for legend for track - - track_kwargs = dict(x=[], y=[], mode="markers+lines", line=dict(color=self.colorway[2]), - legendgroup=track_label, legendrank=400, name=track_label, - showlegend=True) - track_kwargs.update(kwargs) - self.fig.add_trace(go.Scatter(track_kwargs)) - - # and initialise traces for every track. 
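The per-frame masks assembled in the track loop below (and in plot_measurements above) are plain boolean time-window tests. A sketch with hypothetical timestamps::

    from datetime import datetime, timedelta

    import numpy as np

    times = np.array([datetime(2024, 6, 11) + timedelta(seconds=s)
                      for s in range(6)], dtype=object)
    frame_time = times[4]
    cutoff_time = frame_time - timedelta(seconds=2)  # the tail window

    # Keep states that exist by frame_time but are newer than the
    # cut-off, the same logical_and used for every animation frame.
    mask = np.logical_and(times <= frame_time, times >= cutoff_time)
    print(mask)  # [False False True True True False]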
Need to change a few kwargs: - track_kwargs.update({'showlegend': False}) - - for k, _ in enumerate(tracks): - # update track colours - track_kwargs.update({'line': dict(color=self.colorway[(k + 2) % len(self.colorway)])}) - track_kwargs.update(kwargs) - self.fig.add_trace(go.Scatter(track_kwargs)) - - for frame in self.fig.frames: - # get current fig data and traces - data_ = list(frame.data) - traces_ = list(frame.traces) - - # convert string to datetime object - frame_time = datetime.fromisoformat(frame.name) - - self.all_masks[frame_time] = dict() # save mask for later use - cutoff_time = (frame_time - self.time_window) - # add blank data to ensure legend stays in place - data_.append(go.Scatter(x=[-np.inf, np.inf], y=[-np.inf, np.inf])) - traces_.append(trace_base) # ensure data is added to correct trace - - for n, track in enumerate(tracks): - - # all track points that come at or before the frame time - t_upper = [data[n]["time"] <= frame_time] - # only select detections that come after the time cut-off - t_lower = [data[n]["time"] >= cutoff_time] - - # put together - mask = np.logical_and(t_upper, t_lower) - - # put into dictionary for later use - if plot_history: - self.all_masks[frame_time][n] = np.logical_and(t_upper, t_lower) - else: - self.all_masks[frame_time][n] = [data[n]["time"] == frame_time] - - # find x, y, time, and type - track_x = data[n]["x"][tuple(mask)] - # add np.inf to plot so that the traces are present for entire simulation - track_x = np.append(track_x, [np.inf]) - - # repeat for y - track_y = data[n]["y"][tuple(mask)] - track_y = np.append(track_y, [np.inf]) - track_type = data[n]["type"][tuple(mask)] - times = data[n]["time_str"][tuple(mask)] - - data_.append(go.Scatter(x=track_x, # plot track - y=track_y, - meta=track_type, - customdata=times, - hovertemplate='%{meta}' + - '
<br>(%{x}, %{y})' +
-                                                  '<br>
Time: %{customdata}')) - - traces_.append(trace_base + n + 1) # add to correct trace - - frame.data = data_ - frame.traces = traces_ - - if resize: - self._resize(data, "tracks") - - if uncertainty: # plot ellipses - name = f'{track_label}
<br>Uncertainty'
-            uncertainty_kwargs = dict(x=[], y=[], legendgroup=name, fill='toself',
-                                      fillcolor=self.colorway[2],
-                                      opacity=0.2, legendrank=500, name=name,
-                                      hoverinfo='skip',
-                                      mode='none', showlegend=True)
-            uncertainty_kwargs.update(kwargs)
-
-            # dummy trace for legend for uncertainty
-            self.fig.add_trace(go.Scatter(uncertainty_kwargs))
-
-            # and an uncertainty ellipse trace for each track
-            uncertainty_kwargs.update({'showlegend': False})
-            for k, _ in enumerate(tracks):
-                uncertainty_kwargs.update(
-                    {'fillcolor': self.colorway[(k + 2) % len(self.colorway)]})
-                uncertainty_kwargs.update(kwargs)
-                self.fig.add_trace(go.Scatter(uncertainty_kwargs))
-
-            # following function finds uncertainty data points and plots them
-            self._plot_particles_and_ellipses(tracks, mapping, resize, method="uncertainty")
-
-        if particle:  # plot particles
-
-            # initialise traces. One for legend and one per track
-            name = f'{track_label}<br>
Particles' - particle_kwargs = dict(mode='markers', marker=dict(size=2, color=self.colorway[2]), - opacity=0.4, - hoverinfo='skip', legendgroup=name, name=name, - legendrank=520, showlegend=True) - # apply any keyword arguments - particle_kwargs.update(kwargs) - self.fig.add_trace(go.Scatter(particle_kwargs)) # legend trace - - particle_kwargs.update({"showlegend": False}) - - for k, track in enumerate(tracks): # trace for each track - - particle_kwargs.update( - {'marker': dict(size=2, color=self.colorway[(k + 2) % len(self.colorway)])}) - particle_kwargs.update(kwargs) - self.fig.add_trace(go.Scatter(particle_kwargs)) - - self._plot_particles_and_ellipses(tracks, mapping, resize, method="particles") - - # we have called a plotting function so update flag - self.plotting_function_called = True - - def _plot_particles_and_ellipses(self, tracks, mapping, resize, method="uncertainty"): - - """ - The logic for plotting uncertainty ellipses and particles is nearly identical, - so it is put into one function. - - Parameters - ---------- - tracks: Collection of :class '~Track' - Collection of tracks which will be plotted. If not a collection, and instead a single - :class:'~Track' type, the argument is modified to be a set to allow for iteration - mapping: list - List of items specifying the mapping of the position components of the state space. - method: str - Can either be "uncertainty" or "particles". Depends on what the function is plotting. - """ - - data = [dict() for _ in tracks] - trace_base = len(self.fig.data) - for n, track in enumerate(tracks): - - # initialise arrays that store particle/ellipse for later plotting - data[n].update(x=np.array([0 for _ in range(len(track))], dtype=object), - y=np.array([0 for _ in range(len(track))], dtype=object)) - - for k, state in enumerate(track): - - # find data points - if method == "uncertainty": - - data_x, data_y = Plotterly._generate_ellipse_points(state, mapping) - data_x = list(data_x) - data_y = list(data_y) - data_x.append(np.nan) # necessary to draw multiple ellipses at once - data_y.append(np.nan) - data[n]["x"][k] = data_x - data[n]["y"][k] = data_y - - elif method == "particles": - - data_xy = state.state_vector[mapping[:2], :] - data[n]["x"][k] = data_xy[0] - data[n]["y"][k] = data_xy[1] - - else: - raise ValueError("Should be 'uncertainty' or 'particles'") - - for frame in self.fig.frames: - - frame_time = datetime.fromisoformat(frame.name) - - data_ = list(frame.data) # current data in frame - traces_ = list(frame.traces) # current traces in frame - - data_.append(go.Scatter(x=[-np.inf], y=[np.inf])) # add empty data for legend trace - traces_.append(trace_base - len(tracks) - 1) # ensure correct trace - - for n, track in enumerate(tracks): - # now plot the data - _x = list(chain(*data[n]["x"][tuple(self.all_masks[frame_time][n])])) - _y = list(chain(*data[n]["y"][tuple(self.all_masks[frame_time][n])])) - _x.append(np.inf) - _y.append(np.inf) - data_.append(go.Scatter(x=_x, y=_y)) - traces_.append(trace_base - len(tracks) + n) - - frame.data = data_ - frame.traces = traces_ - - if resize: - self._resize(data, type="particle_or_uncertainty") - - def plot_sensors(self, sensors, sensor_label="Sensors", resize=True, **kwargs): - """Plots sensor(s) - - Plots sensors. Users can change the color and marker of detections using keyword - arguments. Default is a black 'x' marker. Currently only works for stationary - sensors. 
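The uncertainty path above leans on Plotterly._generate_ellipse_points. One common construction for such ellipse points, shown here as an assumed sketch rather than the actual implementation: scale the unit circle by the covariance's principal axes and rotate it into place::

    import numpy as np

    def ellipse_points(mean, covar, n_points=30):
        # Principal axes of the 2x2 position covariance.
        vals, vecs = np.linalg.eigh(covar)
        angles = np.linspace(0, 2 * np.pi, n_points)
        unit_circle = np.vstack([np.cos(angles), np.sin(angles)])
        # Scale by sigma along each axis, rotate, then re-centre.
        pts = vecs @ (np.sqrt(vals)[:, None] * unit_circle)
        return pts[0] + mean[0], pts[1] + mean[1]

    xs, ys = ellipse_points(np.array([0.0, 0.0]),
                            np.array([[4.0, 1.0], [1.0, 2.0]]))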
- - Parameters - ---------- - sensors : Collection of :class:`~.Sensor` - Sensors to plot - sensor_label: str - Label to apply to all tracks for legend. - \\*\\*kwargs: dict - Additional arguments to be passed to scatter function for detections. Defaults are - ``marker=dict(symbol='x', color='black')``. - """ - if not isinstance(sensors, Collection): - sensors = {sensors} - - # don't run any of this if there is no data input - if sensors: - trace_base = len(self.fig.data) # number of traces currently in figure - sensor_kwargs = dict(mode='markers', marker=dict(symbol='x', color='black'), - legendgroup=sensor_label, legendrank=50, - name=sensor_label, showlegend=True) - merge(sensor_kwargs, kwargs) - - self.fig.add_trace(go.Scatter(sensor_kwargs)) # initialises trace - - # sensor position - sensor_xy = np.array([sensor.position[[0, 1], 0] for sensor in sensors]) - if resize: - self._resize(sensors, "sensor") - - for frame in self.fig.frames: # the plotting bit - traces_ = list(frame.traces) - data_ = list(frame.data) - - data_.append(go.Scatter(x=sensor_xy[:, 0], y=sensor_xy[:, 1])) - traces_.append(trace_base) - - frame.traces = traces_ - frame.data = data_ - - # we have called a plotting function so update flag (used in _resize) - self.plotting_function_called = True diff --git a/stonesoup/types/tests/test_state.py b/stonesoup/types/tests/test_state.py index 4b64cbee9..8284c781f 100644 --- a/stonesoup/types/tests/test_state.py +++ b/stonesoup/types/tests/test_state.py @@ -1,21 +1,39 @@ import copy import datetime +import time import numpy as np import pytest import scipy.linalg +from numpy.linalg import inv +from ...base import Property +from ...functions import gridCreation from ..angle import Bearing -from ..array import StateVector, StateVectors, CovarianceMatrix +from ..array import CovarianceMatrix, StateVector, StateVectors from ..groundtruth import GroundTruthState from ..numeric import Probability from ..particle import Particle -from ..state import CreatableFromState -from ..state import State, GaussianState, ParticleState, EnsembleState, \ - StateMutableSequence, WeightedGaussianState, SqrtGaussianState, CategoricalState, \ - CompositeState, InformationState, ASDState, ASDGaussianState, ASDWeightedGaussianState, \ - MultiModelParticleState, RaoBlackwellisedParticleState, BernoulliParticleState -from ...base import Property +from ..state import ( + ASDGaussianState, + ASDState, + ASDWeightedGaussianState, + BernoulliParticleState, + CategoricalState, + CompositeState, + CreatableFromState, + EnsembleState, + GaussianState, + InformationState, + MultiModelParticleState, + ParticleState, + PointMassState, + RaoBlackwellisedParticleState, + SqrtGaussianState, + State, + StateMutableSequence, + WeightedGaussianState, +) def test_state(): @@ -39,16 +57,23 @@ def test_state_invalid_vector(): def test_gaussianstate(): - """ GaussianState Type test """ + """GaussianState Type test""" with pytest.raises(TypeError): GaussianState() mean = StateVector([[-1.8513], [0.9994], [0], [0]]) * 1e4 - covar = CovarianceMatrix([[2.2128, 0, 0, 0], - [0.0002, 2.2130, 0, 0], - [0.3897, -0.00004, 0.0128, 0], - [0, 0.3897, 0.0013, 0.0135]]) * 1e3 + covar = ( + CovarianceMatrix( + [ + [2.2128, 0, 0, 0], + [0.0002, 2.2130, 0, 0], + [0.3897, -0.00004, 0.0128, 0], + [0, 0.3897, 0.0013, 0.0135], + ] + ) + * 1e3 + ) timestamp = datetime.datetime.now() # Test state initiation without timestamp @@ -67,16 +92,23 @@ def test_gaussianstate(): def test_informationstate(): - """ InformationState Type test """ + 
"""InformationState Type test""" with pytest.raises(TypeError): InformationState() mean = StateVector([[-1.8513], [0.9994], [0], [0]]) * 1e4 - covar = CovarianceMatrix([[2.2128, 0, 0, 0], - [0.0002, 2.2130, 0, 0], - [0.3897, -0.00004, 0.0128, 0], - [0, 0.3897, 0.0013, 0.0135]]) * 1e3 + covar = ( + CovarianceMatrix( + [ + [2.2128, 0, 0, 0], + [0.0002, 2.2130, 0, 0], + [0.3897, -0.00004, 0.0128, 0], + [0, 0.3897, 0.0013, 0.0135], + ] + ) + * 1e3 + ) timestamp = datetime.datetime.now() information_matrix = np.linalg.inv(covar) @@ -90,9 +122,9 @@ def test_informationstate(): # Test state initiation with timestamp state = InformationState(information_state, information_matrix, timestamp) - assert (np.allclose(mean, state.mean)) - assert (np.allclose(covar, state.covar)) - assert (state.timestamp == timestamp) + assert np.allclose(mean, state.mean) + assert np.allclose(covar, state.covar) + assert state.timestamp == timestamp # Testing from_gaussian state method gs = GaussianState(mean, covar) @@ -117,10 +149,17 @@ def test_sqrtgaussianstate(): """Test the square root Gaussian Type""" mean = np.array([[-1.8513], [0.9994], [0], [0]]) * 1e4 - covar = np.array([[2.2128, 0.1, 0.03, 0.01], - [0.1, 2.2130, 0.03, 0.02], - [0.03, 0.03, 2.123, 0.01], - [0.01, 0.02, 0.01, 2.012]]) * 1e3 + covar = ( + np.array( + [ + [2.2128, 0.1, 0.03, 0.01], + [0.1, 2.2130, 0.03, 0.02], + [0.03, 0.03, 2.123, 0.01], + [0.01, 0.02, 0.01, 2.012], + ] + ) + * 1e3 + ) timestamp = datetime.datetime.now() # Test that a lower triangular matrix returned when 'full' covar is passed @@ -129,8 +168,12 @@ def test_sqrtgaussianstate(): assert np.array_equal(state.sqrt_covar, lower_covar) assert np.allclose(state.covar, covar, 0, atol=1e-10) assert np.allclose(state.sqrt_covar @ state.sqrt_covar.T, covar, 0, atol=1e-10) - assert np.allclose(state.sqrt_covar @ state.sqrt_covar.T, lower_covar @ lower_covar.T, 0, - atol=1e-10) + assert np.allclose( + state.sqrt_covar @ state.sqrt_covar.T, + lower_covar @ lower_covar.T, + 0, + atol=1e-10, + ) # Test that a general square root matrix is also a solution general_covar = scipy.linalg.sqrtm(covar) @@ -180,9 +223,13 @@ def test_particlestate(): # Create 10 1d particles: [[0,0,0,0,0,100,100,100,100,100]] # with equal weight num_particles = 10 - weight = Probability(1/num_particles) - particles = StateVectors(np.concatenate( - (np.tile([[0]], num_particles//2), np.tile([[100]], num_particles//2)), axis=1)) + weight = Probability(1 / num_particles) + particles = StateVectors( + np.concatenate( + (np.tile([[0]], num_particles // 2), np.tile([[100]], num_particles // 2)), + axis=1, + ) + ) weights = np.tile(weight, num_particles) # Test state without timestamp @@ -201,9 +248,15 @@ def test_particlestate(): # [[0,0,0,0,0,100,100,100,100,100], # [0,0,0,0,0,200,200,200,200,200]] # use same weights - particles = StateVectors(np.concatenate((np.tile([[0], [0]], num_particles//2), - np.tile([[100], [200]], num_particles//2)), - axis=1)) + particles = StateVectors( + np.concatenate( + ( + np.tile([[0], [0]], num_particles // 2), + np.tile([[100], [200]], num_particles // 2), + ), + axis=1, + ) + ) state = ParticleState(particles, weight=weights) assert isinstance(state, State) @@ -213,23 +266,31 @@ def test_particlestate(): assert state.ndim == 2 # Create ParticleState from state vectors, weights and particle list - state_vector_array = np.concatenate((np.tile([[0], [0]], num_particles // 2), - np.tile([[100], [200]], num_particles // 2)), - axis=1) - state_vector_gen = ([state_vector_array[0][particle], - 
state_vector_array[1][particle]] for particle in range(num_particles)) + state_vector_array = np.concatenate( + ( + np.tile([[0], [0]], num_particles // 2), + np.tile([[100], [200]], num_particles // 2), + ), + axis=1, + ) + state_vector_gen = ( + [state_vector_array[0][particle], state_vector_array[1][particle]] + for particle in range(num_particles) + ) weight = Probability(1 / num_particles) - particle_list = [Particle(state_vector, - weight=weight) for state_vector in state_vector_gen] + particle_list = [ + Particle(state_vector, weight=weight) for state_vector in state_vector_gen + ] with pytest.raises(ValueError): ParticleState(particles, particle_list=particle_list, weight=weight) parent_list = particle_list - particle_list2 = [Particle([0, 0], - weight=weight, - parent=parent) for parent in parent_list] - state = ParticleState(None, particle_list=particle_list2, - timestamp=timestamp, fixed_covar=[1, 1]) + particle_list2 = [ + Particle([0, 0], weight=weight, parent=parent) for parent in parent_list + ] + state = ParticleState( + None, particle_list=particle_list2, timestamp=timestamp, fixed_covar=[1, 1] + ) assert isinstance(state.parent, ParticleState) assert state.covar == [1, 1] @@ -243,8 +304,14 @@ def test_particlestate(): @pytest.mark.parametrize( - 'particle_class', [ParticleState, MultiModelParticleState, RaoBlackwellisedParticleState, - BernoulliParticleState]) + "particle_class", + [ + ParticleState, + MultiModelParticleState, + RaoBlackwellisedParticleState, + BernoulliParticleState, + ], +) def test_particle_get_item(particle_class): with pytest.raises(TypeError): particle_class() @@ -252,9 +319,13 @@ def test_particle_get_item(particle_class): # Create 10 1d particles: [[0,0,0,0,0,100,100,100,100,100]] # with equal weight num_particles = 10 - weight = Probability(1/num_particles) - particles = StateVectors(np.concatenate( - (np.tile([[0]], num_particles//2), np.tile([[100]], num_particles//2)), axis=1)) + weight = Probability(1 / num_particles) + particles = StateVectors( + np.concatenate( + (np.tile([[0]], num_particles // 2), np.tile([[100]], num_particles // 2)), + axis=1, + ) + ) weights = np.tile(weight, num_particles) timestamp = datetime.datetime.now() @@ -264,8 +335,8 @@ def test_particle_get_item(particle_class): assert np.allclose(state[0].state_vector, StateVector([[0]])) assert np.allclose(state[-1].state_vector, StateVector([[100]])) - assert pytest.approx(1/num_particles) == state[0].weight - assert pytest.approx(1/num_particles) == state[-1].weight + assert pytest.approx(1 / num_particles) == state[0].weight + assert pytest.approx(1 / num_particles) == state[-1].weight assert np.allclose(state[0].parent.state_vector, state[0].state_vector) assert np.allclose(state[-1].parent.state_vector, state[-1].state_vector) @@ -284,11 +355,19 @@ def test_particlestate_weighted(): # Half particles at high weight at 0 # Create 10 1d particles: [[0,0,0,0,0,100,100,100,100,100]] # with different weights this time. First half have 0.75 and the second half 0.25. 
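The expected values in this weighted test follow directly from the weighted moment formulas; for this configuration the weighted mean is 25 and the weighted variance 1875, as a quick numeric check confirms::

    import numpy as np

    x = np.array([0.0] * 5 + [100.0] * 5)          # particle positions
    w = np.array([0.75 / 5] * 5 + [0.25 / 5] * 5)  # weights sum to 1

    mean = np.sum(w * x)               # 0.75*0 + 0.25*100 = 25.0
    var = np.sum(w * (x - mean) ** 2)  # 0.75*625 + 0.25*5625 = 1875.0
    print(mean, var)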
- particles = StateVectors([np.concatenate( - (np.tile(0, num_particles // 2), np.tile(100, num_particles // 2)))]) + particles = StateVectors( + [ + np.concatenate( + (np.tile(0, num_particles // 2), np.tile(100, num_particles // 2)) + ) + ] + ) weights = np.concatenate( - (np.tile(Probability(0.75 / (num_particles / 2)), num_particles // 2), - np.tile(Probability(0.25 / (num_particles / 2)), num_particles // 2))) + ( + np.tile(Probability(0.75 / (num_particles / 2)), num_particles // 2), + np.tile(Probability(0.25 / (num_particles / 2)), num_particles // 2), + ) + ) # Check particles sum to 1 still assert pytest.approx(1) == sum(weight for weight in weights) @@ -310,24 +389,34 @@ def test_particlestate_angle(): # There's interplay between the Bearing and Probability types and resulting # rounding approximations can fail the test. particles = StateVectors( - np.concatenate((np.tile([[Bearing(np.pi + 0.1)], [-10.0]], num_particles//2), - np.tile([[Bearing(np.pi - 0.1)], [20.0]], num_particles//2)), axis=1)) + np.concatenate( + ( + np.tile([[Bearing(np.pi + 0.1)], [-10.0]], num_particles // 2), + np.tile([[Bearing(np.pi - 0.1)], [20.0]], num_particles // 2), + ), + axis=1, + ) + ) - weight = Probability(1/num_particles) + weight = Probability(1 / num_particles) weights = np.tile(weight, num_particles) # Test state without timestamp state = ParticleState(particles, weight=weights) - assert np.allclose(state.mean, StateVector([[np.pi], [5.]])) + assert np.allclose(state.mean, StateVector([[np.pi], [5.0]])) assert np.allclose(state.covar, CovarianceMatrix([[0.01, -1.5], [-1.5, 225]])) def test_particlestate_cache(): num_particles = 10 - weight = Probability(1/num_particles) - particles = StateVectors(np.concatenate( - (np.tile([[0]], num_particles//2), np.tile([[100]], num_particles//2)), axis=1)) + weight = Probability(1 / num_particles) + particles = StateVectors( + np.concatenate( + (np.tile([[0]], num_particles // 2), np.tile([[100]], num_particles // 2)), + axis=1, + ) + ) weights = np.tile(weight, num_particles) state = ParticleState(particles, weight=weights) @@ -349,12 +438,18 @@ def test_particlestate_cache(): @pytest.mark.parametrize( - 'particle_class', [ParticleState, MultiModelParticleState, RaoBlackwellisedParticleState, - BernoulliParticleState]) + "particle_class", + [ + ParticleState, + MultiModelParticleState, + RaoBlackwellisedParticleState, + BernoulliParticleState, + ], +) def test_particle_parent_parent(particle_class): - state1 = ParticleState([[1, 2, 3]], weight=np.full((3, ), 1/3)) - state2 = ParticleState([[2, 3, 1]], weight=np.full((3, ), 1/3), parent=state1) - state3 = ParticleState([[3, 1, 2]], weight=np.full((3, ), 1/3), parent=state2) + state1 = ParticleState([[1, 2, 3]], weight=np.full((3,), 1 / 3)) + state2 = ParticleState([[2, 3, 1]], weight=np.full((3,), 1 / 3), parent=state1) + state3 = ParticleState([[3, 1, 2]], weight=np.full((3,), 1 / 3), parent=state2) assert state2.parent is state1 assert state3.parent is state2 @@ -400,8 +495,9 @@ def test_ensemblestate(): # 1 Dimensional test_mean_1d = np.array([0]) test_covar_1d = np.array([1]) - ensemble1d = state.generate_ensemble(mean=test_mean_1d, - covar=test_covar_1d, num_vectors=5) + ensemble1d = state.generate_ensemble( + mean=test_mean_1d, covar=test_covar_1d, num_vectors=5 + ) assert np.shape(ensemble1d) == (1, 5) assert isinstance(ensemble1d, StateVectors) @@ -409,8 +505,9 @@ def test_ensemblestate(): # Lets pass in a state vector mean(as opposed to an array) while we're at it test_mean_2d = 
StateVector([1, 1]) test_covar_2d = CovarianceMatrix(np.eye(2)) - ensemble2d = state.generate_ensemble(mean=test_mean_2d, - covar=test_covar_2d, num_vectors=5) + ensemble2d = state.generate_ensemble( + mean=test_mean_2d, covar=test_covar_2d, num_vectors=5 + ) assert np.shape(ensemble2d) == (2, 5) assert isinstance(ensemble2d, StateVectors) @@ -439,15 +536,15 @@ def test_state_mutable_sequence_state(): timestamp = datetime.datetime(2018, 1, 1, 14) delta = datetime.timedelta(minutes=1) sequence = StateMutableSequence( - [State(state_vector, timestamp=timestamp+delta*n) - for n in range(10)]) + [State(state_vector, timestamp=timestamp + delta * n) for n in range(10)] + ) assert sequence.state is sequence.states[-1] assert np.array_equal(sequence.state_vector, state_vector) - assert sequence.timestamp == timestamp + delta*9 + assert sequence.timestamp == timestamp + delta * 9 del sequence[-1] - assert sequence.timestamp == timestamp + delta*8 + assert sequence.timestamp == timestamp + delta * 8 def test_state_mutable_sequence_slice(): @@ -455,8 +552,8 @@ def test_state_mutable_sequence_slice(): timestamp = datetime.datetime(2018, 1, 1, 14) delta = datetime.timedelta(minutes=1) sequence = StateMutableSequence( - [State(state_vector, timestamp=timestamp+delta*n) - for n in range(10)]) + [State(state_vector, timestamp=timestamp + delta * n) for n in range(10)] + ) assert isinstance(sequence[timestamp:], StateMutableSequence) assert isinstance(sequence[5:], StateMutableSequence) @@ -465,11 +562,11 @@ def test_state_mutable_sequence_slice(): assert len(sequence[timestamp:]) == 10 assert len(sequence[:timestamp]) == 0 - assert len(sequence[timestamp+delta*5:]) == 5 - assert len(sequence[:timestamp+delta*5]) == 5 - assert len(sequence[timestamp+delta*4:timestamp+delta*6]) == 2 - assert len(sequence[timestamp+delta*2:timestamp+delta*8:3]) == 2 - assert len(sequence[timestamp+delta*1:][:timestamp+delta*2]) == 1 + assert len(sequence[timestamp + delta * 5:]) == 5 + assert len(sequence[: timestamp + delta * 5]) == 5 + assert len(sequence[timestamp + delta * 4: timestamp + delta * 6]) == 2 + assert len(sequence[timestamp + delta * 2: timestamp + delta * 8: 3]) == 2 + assert len(sequence[timestamp + delta * 1:][: timestamp + delta * 2]) == 1 assert sequence[timestamp] == sequence.states[0] @@ -489,7 +586,7 @@ def test_state_mutable_sequence_slice(): sequence[timestamp:1] with pytest.raises(IndexError): - sequence[timestamp-delta] + sequence[timestamp - delta] def test_state_mutable_sequence_sequence_init(): @@ -498,8 +595,10 @@ def test_state_mutable_sequence_sequence_init(): timestamp = datetime.datetime(2018, 1, 1, 14) delta = datetime.timedelta(minutes=1) sequence = StateMutableSequence( - StateMutableSequence([State(state_vector, timestamp=timestamp + delta * n) - for n in range(10)])) + StateMutableSequence( + [State(state_vector, timestamp=timestamp + delta * n) for n in range(10)] + ) + ) assert not isinstance(sequence.states, list) @@ -529,10 +628,12 @@ def complicated_attribute(self): if self.test_property == 3: return self.test_property else: - raise AttributeError('Custom error message') + raise AttributeError("Custom error message") timestamp = datetime.datetime.now() - test_obj = TestSMS(states=State(state_vector=StateVector([1, 2, 3]), timestamp=timestamp)) + test_obj = TestSMS( + states=State(state_vector=StateVector([1, 2, 3]), timestamp=timestamp) + ) # First check no errors on assigned vars test_obj.test_method() @@ -546,11 +647,14 @@ def complicated_attribute(self): assert 
test_obj.timestamp == timestamp # Now check that the right error messages are raised on missing attributes - with pytest.raises(AttributeError, match="'TestSMS' object has no attribute 'missing_method'"): + with pytest.raises( + AttributeError, match="'TestSMS' object has no attribute 'missing_method'" + ): test_obj.missing_method() - with pytest.raises(AttributeError, match="'TestSMS' object has no attribute " - "'missing_variable'"): + with pytest.raises( + AttributeError, match="'TestSMS' object has no attribute " "'missing_variable'" + ): _ = test_obj.missing_variable # And check custom error messages are not swallowed @@ -569,8 +673,8 @@ def test_state_mutable_sequence_copy(): timestamp = datetime.datetime(2018, 1, 1, 14) delta = datetime.timedelta(minutes=1) sequence = StateMutableSequence( - [State(state_vector, timestamp=timestamp+delta*n) - for n in range(10)]) + [State(state_vector, timestamp=timestamp + delta * n) for n in range(10)] + ) sequence2 = copy.copy(sequence) @@ -588,7 +692,7 @@ def test_from_state(): states = [ State(**kwargs), GaussianState(**kwargs, covar=np.eye(4)), - GroundTruthState(**kwargs, metadata={"colour": "blue"}) + GroundTruthState(**kwargs, metadata={"colour": "blue"}), ] for use_sequence in (False, True): @@ -649,9 +753,13 @@ def test_from_state(): def test_creatable_from_state_error(): class SubclassCfs(CreatableFromState): pass - with pytest.raises(TypeError, - match='The first superclass of a CreatableFromState subclass must be a ' - 'CreatableFromState \\(or a subclass\\)'): + + with pytest.raises( + TypeError, + match="The first superclass of a CreatableFromState subclass must be a " + "CreatableFromState \\(or a subclass\\)", + ): + class SubSubclassCfs(State, SubclassCfs): pass @@ -660,8 +768,12 @@ class SubSubclassCfs(State, SubclassCfs): def test_creatable_from_state_multi_base_error(): class SubclassCfs(CreatableFromState): pass - with pytest.raises(TypeError, - match='A CreatableFromState subclass must have exactly two superclasses'): + + with pytest.raises( + TypeError, + match="A CreatableFromState subclass must have exactly two superclasses", + ): + class SubSubclassCfs(State, StateMutableSequence, SubclassCfs): pass @@ -669,9 +781,13 @@ class SubSubclassCfs(State, StateMutableSequence, SubclassCfs): def test_categorical_state(): # Test mismatched number of category names - with pytest.raises(ValueError, match="ndim of 3 does not match number of categories 4"): - CategoricalState(state_vector=StateVector([50, 60, 90]), - categories=['red', 'green', 'blue', 'yellow']) + with pytest.raises( + ValueError, match="ndim of 3 does not match number of categories 4" + ): + CategoricalState( + state_vector=StateVector([50, 60, 90]), + categories=["red", "green", "blue", "yellow"], + ) state = CategoricalState(state_vector=StateVector([50, 60, 90])) @@ -679,32 +795,41 @@ def test_categorical_state(): state.state_vector == [0.25, 0.3, 0.45] # Test default category names - assert state.categories == ['0', '1', '2'] + assert state.categories == ["0", "1", "2"] # Test string assert str(state) == "P(0) = 0.25,\nP(1) = 0.3,\nP(2) = 0.45" # Test category - assert state.category == '2' + assert state.category == "2" def test_composite_state_timestamp(): - with pytest.raises(ValueError, - match="All sub-states must share the same timestamp if defined"): + with pytest.raises( + ValueError, match="All sub-states must share the same timestamp if defined" + ): CompositeState([State([0], timestamp=1), State([0], timestamp=2)]) - with pytest.raises(ValueError, - 
match="Sub-state timestamps and default timestamp must be the same if " - "defined"): + with pytest.raises( + ValueError, + match="Sub-state timestamps and default timestamp must be the same if " + "defined", + ): CompositeState([State([0], timestamp=1)], default_timestamp=2) - with pytest.raises(ValueError, - match="Sub-state timestamps and default timestamp must be the same if " - "defined"): - CompositeState([State([0], timestamp=1), State([0], timestamp=1)], default_timestamp=2) + with pytest.raises( + ValueError, + match="Sub-state timestamps and default timestamp must be the same if " + "defined", + ): + CompositeState( + [State([0], timestamp=1), State([0], timestamp=1)], default_timestamp=2 + ) for i in range(1, 4): assert CompositeState(i * [State([0], timestamp=1)]).timestamp == 1 - assert CompositeState(i * [State([0], timestamp=1)], - default_timestamp=1).timestamp == 1 + assert ( + CompositeState(i * [State([0], timestamp=1)], default_timestamp=1).timestamp + == 1 + ) assert CompositeState(i * [State([0])]).timestamp is None @@ -720,8 +845,10 @@ def test_composite_state(): state = CompositeState(sub_states) # Test state vectors - for actual, expected in zip(state.state_vectors, - [StateVector([0, 1]), StateVector([2]), StateVector([3, 4])]): + for actual, expected in zip( + state.state_vectors, + [StateVector([0, 1]), StateVector([2]), StateVector([3, 4])], + ): assert (actual == expected).all() # Test state vector @@ -761,8 +888,7 @@ def test_asd_state(): timestamp1 = datetime.datetime.now() timestamp2 = datetime.datetime.now() state_vector = np.array([[0], [1], [2], [3]]) - state = ASDState(state_vector, - timestamps=[timestamp1, timestamp2], max_nstep=10) + state = ASDState(state_vector, timestamps=[timestamp1, timestamp2], max_nstep=10) assert state.timestamps == [timestamp1, timestamp2] assert np.array_equal(state.multi_state_vector, state_vector) assert np.array_equal(state.state_vector, state_vector[0:2]) @@ -784,16 +910,23 @@ def test_asd_state(): def test_asd_gaussian_state(): - """ GaussianState Type test """ + """GaussianState Type test""" with pytest.raises(TypeError): ASDGaussianState() mean = np.array([[-1.8513], [0.9994], [0], [0]]) * 1e4 - covar = np.array([[2.2128, 0, 0, 0], - [0.0002, 2.2130, 0, 0], - [0.3897, -0.00004, 0.0128, 0], - [0, 0.3897, 0.0013, 0.0135]]) * 1e3 + covar = ( + np.array( + [ + [2.2128, 0, 0, 0], + [0.0002, 2.2130, 0, 0], + [0.3897, -0.00004, 0.0128, 0], + [0, 0.3897, 0.0013, 0.0135], + ] + ) + * 1e3 + ) timestamp = datetime.datetime.now() # Test state initiation without timestamp @@ -811,24 +944,32 @@ def test_asd_gaussian_state(): timestamp1 = datetime.datetime.now() timestamp2 = datetime.datetime.now() state_vector = np.array([[0], [1], [2], [3], [4], [5], [6], [7]]) - covar = np.array([[2.2128, 0, 0, 0, 2.2128, 0, 0, 0], - [0.0002, 2.2130, 0, 0, 0.0002, 2.2130, 0, 0], - [0.3897, -0.00004, 0.0128, 0, 0.3897, -0.00004, - 0.0128, 0], - [0, 0.3897, 0.0013, 0.0135, 0, 0.3897, 0.0013, 0.0135], - [2.2128, 0, 0, 0, 2.2128, 0, 0, 0], - [0.0002, 2.2130, 0, 0, 0.0002, 2.2130, 0, 0], - [0.3897, -0.00004, 0.0128, 0, - 0.3897, -0.00004, 0.0128, 0], - [0, 0.3897, 0.0013, 0.0135, 0, 0.3897, 0.0013, 0.0135] - ]) * 1e3 - state = ASDGaussianState(state_vector, multi_covar=covar, - timestamps=[timestamp1, timestamp2], max_nstep=10) + covar = ( + np.array( + [ + [2.2128, 0, 0, 0, 2.2128, 0, 0, 0], + [0.0002, 2.2130, 0, 0, 0.0002, 2.2130, 0, 0], + [0.3897, -0.00004, 0.0128, 0, 0.3897, -0.00004, 0.0128, 0], + [0, 0.3897, 0.0013, 0.0135, 0, 0.3897, 0.0013, 
0.0135],
+                [2.2128, 0, 0, 0, 2.2128, 0, 0, 0],
+                [0.0002, 2.2130, 0, 0, 0.0002, 2.2130, 0, 0],
+                [0.3897, -0.00004, 0.0128, 0, 0.3897, -0.00004, 0.0128, 0],
+                [0, 0.3897, 0.0013, 0.0135, 0, 0.3897, 0.0013, 0.0135],
+            ]
+        )
+        * 1e3
+    )
+    state = ASDGaussianState(
+        state_vector,
+        multi_covar=covar,
+        timestamps=[timestamp1, timestamp2],
+        max_nstep=10,
+    )
     assert state.timestamps == [timestamp1, timestamp2]
     assert state.timestamp == timestamp1
     assert np.array_equal(state_vector[0:4], state.mean)
     assert np.array_equal(covar, state.multi_covar)
-    assert state.ndim == state_vector.shape[0]/2
+    assert state.ndim == state_vector.shape[0] / 2
     assert state.nstep == 2
     assert state.max_nstep == 10
 
@@ -855,5 +996,49 @@ def test_asd_weighted_gaussian_state():
     timestamp = datetime.datetime.now()
     a = ASDWeightedGaussianState(
-        mean, multi_covar=covar, weight=weight, timestamps=[timestamp])
+        mean, multi_covar=covar, weight=weight, timestamps=[timestamp]
+    )
     assert a.weight == weight
+
+
+def test_pointmassstate():
+    nx = 4
+    meanX0 = np.array([36569, 50, 55581, 50])  # mean value
+    varX0 = np.diag([90, 5, 160, 5])  # variance
+    Npa = np.array(
+        [31, 31, 27, 27]
+    )  # number of points per axis; must be odd for the FFT-based filter
+    N = np.prod(Npa)  # number of points - total
+    sFactor = 4  # scaling factor (number of sigmas covered by the grid)
+
+    [predGrid, predGridDelta, gridDimOld, xOld, Ppold] = gridCreation(
+        np.vstack(meanX0), varX0, sFactor, nx, Npa
+    )
+
+    meanX0 = np.vstack(meanX0)
+    pom = predGrid - np.tile(meanX0, (1, N))  # np.tile replaces deprecated np.matlib.repmat
+    denominator = np.sqrt((2 * np.pi) ** nx) * np.linalg.det(varX0)
+    pompom = np.sum(
+        -0.5 * np.multiply(pom.T @ np.linalg.inv(varX0), pom.T), 1
+    )  # elementwise multiplication
+    pomexp = np.exp(pompom)
+    predDensityProb = pomexp / denominator  # Adding probabilities to points
+    predDensityProb = predDensityProb / (sum(predDensityProb) * np.prod(predGridDelta))
+
+    start_time = time.time()
+
+    priorPMF = PointMassState(
+        state_vector=StateVectors(predGrid),
+        weight=predDensityProb,
+        grid_delta=predGridDelta,
+        grid_dim=gridDimOld,
+        center=xOld,
+        eigVec=Ppold,
+        Npa=Npa,
+        timestamp=start_time,
+    )
+
+    assert np.allclose(priorPMF.mean, meanX0.ravel(), 0, 1e-2)
+    assert np.allclose(priorPMF.covar(), varX0, 0, 1e-1)
+    assert priorPMF.ndim == nx
+    assert len(priorPMF) == N
diff --git a/stonesoup/updater/tests/test_kalman.py b/stonesoup/updater/tests/test_kalman.py
index 332ee7d87..84717595f 100644
--- a/stonesoup/updater/tests/test_kalman.py
+++ b/stonesoup/updater/tests/test_kalman.py
@@ -272,3 +272,8 @@ def test_schmidtkalman():
 
     assert np.allclose(update.mean, sk_update.mean)
     assert np.allclose(update.covar, sk_update.covar)
+
+
+if __name__ == "__main__":
+    import pytest
+    pytest.main(['-v', __file__])
diff --git a/stonesoup/updater/tests/test_pointmass.py b/stonesoup/updater/tests/test_pointmass.py
new file mode 100644
index 000000000..c5009830b
--- /dev/null
+++ b/stonesoup/updater/tests/test_pointmass.py
@@ -0,0 +1,135 @@
+"""Test for updater.pointMass module"""
+
+import time
+
+import numpy as np
+import pytest
+
+from functions import gridCreation
+from stonesoup.models.measurement.linear import LinearGaussian
+from stonesoup.types.array import StateVectors
+from stonesoup.types.detection import Detection
+from stonesoup.types.hypothesis import SingleHypothesis
+from stonesoup.types.prediction import (
+    GaussianMeasurementPrediction, PointMassStatePrediction)
+from stonesoup.types.state import GaussianState
+from stonesoup.updater.pointMass import PointMassUpdater
+
+
+@pytest.fixture(params=[PointMassUpdater])
+def updater(request):
+    updater_class = request.param
+    measurement_model = LinearGaussian(
+        ndim_state=2, mapping=[0], noise_covar=np.array([[0.04]]))
+    return updater_class(measurement_model)
+
+
+def test_pointmass(updater):
+    # Initial condition - Gaussian
+    nx = 2
+    meanX0 = np.array([36569, 55581])  # mean value
+    varX0 = np.diag([90, 160])  # variance
+    Npa = np.array([31, 31])  # number of points per axis; must be odd for the FFT
+    N = np.prod(Npa)  # number of points - total
+    sFactor = 4  # scaling factor (number of sigmas covered by the grid)
+
+    [predGrid, predGridDelta, gridDimOld, xOld, Ppold] = gridCreation(
+        np.vstack(meanX0), varX0, sFactor, nx, Npa)
+
+    # Evaluate the initial Gaussian density on the grid and normalise it so the
+    # point masses (times the grid cell volume) sum to one
+    meanX0 = np.vstack(meanX0)
+    pom = predGrid - np.tile(meanX0, (1, N))  # np.tile replaces deprecated np.matlib.repmat
+    denominator = np.sqrt((2 * np.pi) ** nx) * np.linalg.det(varX0)
+    pompom = np.sum(-0.5 * np.multiply(pom.T @ np.linalg.inv(varX0), pom.T), 1)
+    pomexp = np.exp(pompom)
+    predDensityProb = pomexp / denominator  # Adding probabilities to points
+    predDensityProb = predDensityProb / (sum(predDensityProb) * np.prod(predGridDelta))
+
+    start_time = time.time()
+
+    prediction = PointMassStatePrediction(state_vector=StateVectors(predGrid),
+                                          weight=predDensityProb,
+                                          grid_delta=predGridDelta,
+                                          grid_dim=gridDimOld,
+                                          center=xOld,
+                                          eigVec=Ppold,
+                                          Npa=Npa,
+                                          timestamp=start_time)
+    measurement = Detection(np.array([[-6.23]]))
+
+    measurement_model = updater.measurement_model
+
+    # Calculate evaluation variables; note PointMassState.covar is a method,
+    # not a property, hence the calls below
+    eval_measurement_prediction = GaussianMeasurementPrediction(
+        measurement_model.matrix() @ prediction.mean,
+        measurement_model.matrix() @ prediction.covar()
+        @ measurement_model.matrix().T
+        + measurement_model.covar(),
+        cross_covar=prediction.covar() @ measurement_model.matrix().T)
+    kalman_gain = eval_measurement_prediction.cross_covar @ np.linalg.inv(
+        eval_measurement_prediction.covar)
+    eval_posterior = GaussianState(
+        prediction.mean
+        + kalman_gain @ (measurement.state_vector
+                         - eval_measurement_prediction.mean),
+        prediction.covar()
+        - kalman_gain @ eval_measurement_prediction.covar @ kalman_gain.T)
+
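+    # For a linear-Gaussian measurement model the exact Bayesian posterior is
+    # Gaussian and is given by the Kalman update computed above:
+    #   S = H P H' + R,  K = P H' inv(S),
+    #   m_post = m + K (z - H m),  P_post = P - K S K'
+    # so the point-mass updater is asserted against this reference below.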
+    # Get and assert measurement prediction
+    measurement_prediction = updater.predict_measurement(prediction)
+    assert np.allclose(measurement_prediction.mean,
+                       eval_measurement_prediction.mean,
+                       0, atol=1.e-14)
+    assert np.allclose(measurement_prediction.covar,
+                       eval_measurement_prediction.covar,
+                       0, atol=1.e-14)
+    assert np.allclose(measurement_prediction.cross_covar,
+                       eval_measurement_prediction.cross_covar,
+                       0, atol=1.e-13)
+
+    # Perform and assert state update (without measurement prediction)
+    posterior = updater.update(SingleHypothesis(
+        prediction=prediction,
+        measurement=measurement))
+    assert np.allclose(posterior.mean, eval_posterior.mean, 0, atol=1.e-14)
+    assert np.allclose(posterior.covar(), eval_posterior.covar, 0, atol=1.e-13)
+    assert np.array_equal(posterior.hypothesis.prediction, prediction)
+    assert np.allclose(
+        posterior.hypothesis.measurement_prediction.state_vector,
+        measurement_prediction.state_vector, 0, atol=1.e-14)
+    assert np.allclose(posterior.hypothesis.measurement_prediction.covar,
+                       measurement_prediction.covar, 0, atol=1.e-14)
+    assert np.array_equal(posterior.hypothesis.measurement, measurement)
+    assert posterior.timestamp == prediction.timestamp
+
+    # Perform and assert state update (with measurement prediction supplied)
+    posterior = updater.update(SingleHypothesis(
+        prediction=prediction,
+        measurement=measurement,
+        measurement_prediction=measurement_prediction))
+    assert np.allclose(posterior.mean, eval_posterior.mean, 0, atol=1.e-14)
+    assert np.allclose(posterior.covar(), eval_posterior.covar, 0, atol=1.e-13)
+    assert np.array_equal(posterior.hypothesis.prediction, prediction)
+    assert np.allclose(
+        posterior.hypothesis.measurement_prediction.state_vector,
+        measurement_prediction.state_vector, 0, atol=1.e-14)
+    assert np.allclose(posterior.hypothesis.measurement_prediction.covar,
+                       measurement_prediction.covar, 0, atol=1.e-14)
+    assert np.array_equal(posterior.hypothesis.measurement, measurement)
+    assert posterior.timestamp == prediction.timestamp
+
+
+if __name__ == "__main__":
+    pytest.main(['-v', __file__])
\ No newline at end of file

From 714a78a350159c7c01069f0d50b8f0f41b556852 Mon Sep 17 00:00:00 2001
From: pesslovany
Date: Wed, 26 Jun 2024 18:52:58 +0200
Subject: [PATCH 07/16] fixed deletion of plotter

---
 stonesoup/plotter.py                      | 3024 +++++++++++++++++++++
 stonesoup/updater/tests/test_kalman.py    |    4 +-
 stonesoup/updater/tests/test_pointmass.py |  135 -
 3 files changed, 3026 insertions(+), 137 deletions(-)
 create mode 100644 stonesoup/plotter.py
 delete mode 100644 stonesoup/updater/tests/test_pointmass.py

diff --git a/stonesoup/plotter.py b/stonesoup/plotter.py
new file mode 100644
index 000000000..528597cce
--- /dev/null
+++ b/stonesoup/plotter.py
@@ -0,0 +1,3024 @@
+import warnings
+from abc import ABC, abstractmethod
+from datetime import datetime, timedelta
+from enum import IntEnum
+from itertools import chain
+from typing import Collection, Iterable, Union, List, Optional, Tuple, Dict
+
+import numpy as np
+from matplotlib import animation as animation
+from matplotlib import pyplot as plt
+from matplotlib.legend_handler import HandlerPatch
+from matplotlib.lines import Line2D
+from matplotlib.patches import Ellipse
+from mergedeep import merge
+from scipy.integrate import quad
+from scipy.optimize import brentq
+from scipy.stats import kde
+try:
+    from plotly import colors
+except ImportError:
+    colors = None
+try:
+    import plotly.graph_objects as go
+except ImportError:
+    go = None
+
+from .base import Base, Property
+from .models.base import LinearModel, Model
+from .types import detection
+from .types.array import StateVector
+from .types.groundtruth import GroundTruthPath
+from .types.metric import SingleTimeMetric
+from .types.state import State, StateMutableSequence
+from .types.update import Update
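+
+# NOTE: plotly is an optional dependency: the guarded imports above fall back
+# to None, and Plotterly raises a RuntimeError at construction if the package
+# is not available.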
+
+
+class Dimension(IntEnum):
+    """Dimension Enum class for specifying plotting parameters in the Plotter class.
+    Used to sanitize inputs for the dimension attribute of Plotter().
+
+    Attributes
+    ----------
+    ONE: int
+        Specifies 1D plotting for Plotter object (state over time)
+    TWO: int
+        Specifies 2D plotting for Plotter object
+    THREE: int
+        Specifies 3D plotting for Plotter object
+    """
+    ONE = 1  # 1D plotting mode (plot state over time in Plotterly)
+    TWO = 2  # 2D plotting mode (original plotter.py functionality)
+    THREE = 3  # 3D plotting mode
+
+
+class _Plotter(ABC):
+
+    @abstractmethod
+    def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def plot_measurements(self, measurements, mapping, measurement_model=None,
+                          measurements_label="Measurements", **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_label="Tracks",
+                    **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def plot_sensors(self, sensors, mapping, sensor_label="Sensors", **kwargs):
+        raise NotImplementedError
+
+    def _conv_measurements(self, measurements, mapping, measurement_model=None,
+                           convert_measurements=True) -> \
+            Tuple[Dict[detection.Detection, StateVector], Dict[detection.Clutter, StateVector]]:
+        conv_detections = {}
+        conv_clutter = {}
+        for state in measurements:
+            meas_model = state.measurement_model  # measurement_model from detections
+            if meas_model is None:
+                meas_model = measurement_model  # measurement_model from input
+
+            if not convert_measurements:
+                state_vec = state.state_vector[mapping, :]
+            elif isinstance(meas_model, LinearModel):
+                model_matrix = meas_model.matrix()
+                inv_model_matrix = np.linalg.pinv(model_matrix)
+                state_vec = (inv_model_matrix @ state.state_vector)[mapping, :]
+            elif isinstance(meas_model, Model):
+                try:
+                    state_vec = meas_model.inverse_function(state)[mapping, :]
+                except (NotImplementedError, AttributeError):
+                    warnings.warn('Nonlinear measurement model used with no inverse '
+                                  'function available')
+                    continue
+            else:
+                warnings.warn('Measurement model type not specified for all detections')
+                continue
+
+            if isinstance(state, detection.Clutter):
+                # Plot clutter
+                conv_clutter[state] = (*state_vec, )
+
+            elif isinstance(state, detection.Detection):
+                # Plot detections
+                conv_detections[state] = (*state_vec, )
+            else:
+                warnings.warn(f'Unknown type {type(state)}')
+                continue
+        return conv_detections, conv_clutter
+
+
+class Plotter(_Plotter):
+    """Plotting class for building graphs of Stone Soup simulations using matplotlib
+
+    A plotting class which is used to simplify the process of plotting ground truths,
+    measurements, clutter and tracks. Tracks can be plotted with uncertainty ellipses or
+    particles if required. Legends are automatically generated with each plot.
+    Three dimensional plots can be created using the optional dimension parameter.
+
+    Parameters
+    ----------
+    dimension: enum \'Dimension\'
+        Optional parameter to specify 2D or 3D plotting. Default is 2D plotting.
+    plot_timeseries: bool
+        Specify whether data to be plotted is time series data. Default False
+    \\*\\*kwargs: dict
+        Additional arguments to be passed to plot function. For example, figsize (Default is
+ + Attributes + ---------- + fig: matplotlib.figure.Figure + Generated figure for graphs to be plotted on + ax: matplotlib.axes.Axes + Generated axes for graphs to be plotted on + legend_dict: dict + Dictionary of legend handles as :class:`matplotlib.legend_handler.HandlerBase` + and labels as str + """ + + def __init__(self, dimension=Dimension.TWO, **kwargs): + figure_kwargs = {"figsize": (10, 6)} + figure_kwargs.update(kwargs) + if isinstance(dimension, type(Dimension.TWO)): + self.dimension = dimension + elif isinstance(dimension, int): + self.dimension = Dimension(dimension) + else: + raise TypeError("%s is an unsupported type for \'dimension\'; " + "expected type %s" % (type(dimension), type(Dimension.TWO))) + # Generate plot axes + self.fig = plt.figure(**figure_kwargs) + if self.dimension is Dimension.TWO: # 2D axes + self.ax = self.fig.add_subplot(1, 1, 1) + self.ax.axis('equal') + else: # 3D axes + self.ax = self.fig.add_subplot(111, projection='3d') + self.ax.axis('auto') + self.ax.set_zlabel("$z$") + self.ax.set_xlabel("$x$") + self.ax.set_ylabel("$y$") + + # Create empty dictionary for legend handles and labels - dict used to + # prevent multiple entries with the same label from displaying on legend + # This is new compared to plotter.py + self.legend_dict = {} # create an empty dictionary to hold legend entries + + def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwargs): + """Plots ground truth(s) + + Plots each ground truth path passed in to :attr:`truths` and generates a legend + automatically. Ground truths are plotted as dashed lines with default colors. + + Users can change linestyle, color and marker using keyword arguments. Any changes + will apply to all ground truths. + + Parameters + ---------- + truths : Collection of :class:`~.GroundTruthPath` + Collection of ground truths which will be plotted. If not a collection and instead a + single :class:`~.GroundTruthPath` type, the argument is modified to be a set to allow + for iteration. + mapping: list + List of items specifying the mapping of the position components of the state space. + truths_label: str + Label for truth data. Default is "Ground Truth" + \\*\\*kwargs: dict + Additional arguments to be passed to plot function. Default is ``linestyle="--"``. + + Returns + ------- + : list of :class:`matplotlib.artist.Artist` + List of artists that have been added to the axis. 
+ """ + truths_kwargs = dict(linestyle="--") + truths_kwargs.update(kwargs) + if not isinstance(truths, Collection) or isinstance(truths, StateMutableSequence): + truths = {truths} # Make a set of length 1 + + artists = [] + for truth in truths: + if self.dimension is Dimension.TWO: # plots the ground truths in xy + artists.extend( + self.ax.plot([state.state_vector[mapping[0]] for state in truth], + [state.state_vector[mapping[1]] for state in truth], + **truths_kwargs)) + elif self.dimension is Dimension.THREE: # plots the ground truths in xyz + artists.extend( + self.ax.plot3D([state.state_vector[mapping[0]] for state in truth], + [state.state_vector[mapping[1]] for state in truth], + [state.state_vector[mapping[2]] for state in truth], + **truths_kwargs)) + else: + raise NotImplementedError('Unsupported dimension type for truth plotting') + # Generate legend items + if "color" in kwargs: + colour = kwargs["color"] + else: + colour = "black" + truths_handle = Line2D([], [], linestyle=truths_kwargs['linestyle'], color=colour) + self.legend_dict[truths_label] = truths_handle + # Generate legend + artists.append(self.ax.legend(handles=self.legend_dict.values(), + labels=self.legend_dict.keys())) + return artists + + def plot_measurements(self, measurements, mapping, measurement_model=None, + measurements_label="Measurements", convert_measurements=True, **kwargs): + """Plots measurements + + Plots detections and clutter, generating a legend automatically. Detections are plotted as + blue circles by default unless the detection type is clutter. + If the detection type is :class:`~.Clutter` it is plotted as a yellow 'tri-up' marker. + + Users can change the color and marker of detections using keyword arguments but not for + clutter detections. + + Parameters + ---------- + measurements : Collection of :class:`~.Detection` + Detections which will be plotted. If measurements is a set of lists it is flattened. + mapping: list + List of items specifying the mapping of the position components of the state space. + measurement_model : :class:`~.Model`, optional + User-defined measurement model to be used in finding measurement state inverses if + they cannot be found from the measurements themselves. + measurements_label : str + Label for the measurements. Default is "Measurements". + convert_measurements : bool + Should the measurements be converted from measurement space to state space before + being plotted. Default is True + \\*\\*kwargs: dict + Additional arguments to be passed to plot function for detections. Defaults are + ``marker='o'`` and ``color='b'``. + + Returns + ------- + : list of :class:`matplotlib.artist.Artist` + List of artists that have been added to the axis. + """ + + measurement_kwargs = dict(marker='o', color='b') + measurement_kwargs.update(kwargs) + + if not isinstance(measurements, Collection): + measurements = {measurements} # Make a set of length 1 + + if any(isinstance(item, set) for item in measurements): + measurements_set = chain.from_iterable(measurements) # Flatten into one set + else: + measurements_set = measurements + + plot_detections, plot_clutter = self._conv_measurements(measurements_set, + mapping, + measurement_model, + convert_measurements) + + artists = [] + if plot_detections: + detection_array = np.array(list(plot_detections.values())) + # *detection_array.T unpacks detection_array by columns + # (same as passing in detection_array[:,0], detection_array[:,1], etc...) 
+ artists.append(self.ax.scatter(*detection_array.T, **measurement_kwargs)) + measurements_handle = Line2D([], [], linestyle='', **measurement_kwargs) + + # Generate legend items for measurements + self.legend_dict[measurements_label] = measurements_handle + + if plot_clutter: + clutter_kwargs = kwargs.copy() + clutter_kwargs.update(dict(marker='2')) + clutter_array = np.array(list(plot_clutter.values())) + artists.append(self.ax.scatter(*clutter_array.T, **clutter_kwargs)) + clutter_handle = Line2D([], [], linestyle='', **clutter_kwargs) + clutter_label = "Clutter" + + # Generate legend items for clutter + self.legend_dict[clutter_label] = clutter_handle + + # Generate legend + artists.append(self.ax.legend(handles=self.legend_dict.values(), + labels=self.legend_dict.keys())) + return artists + + def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_label="Tracks", + err_freq=1, same_color=False, **kwargs): + """Plots track(s) + + Plots each track generated, generating a legend automatically. If ``uncertainty=True`` + and is being plotted in 2D, error ellipses are plotted. If being plotted in + 3D, uncertainty bars are plotted every :attr:`err_freq` measurement, default + plots uncertainty bars at every track step. Tracks are plotted as solid + lines with point markers and default colors. Uncertainty bars are plotted + with a default color which is the same for all tracks. + + Users can change linestyle, color and marker using keyword arguments. Uncertainty metrics + will also be plotted with the user defined colour and any changes will apply to all tracks. + + Parameters + ---------- + tracks : Collection of :class:`~.Track` + Collection of tracks which will be plotted. If not a collection, and instead a single + :class:`~.Track` type, the argument is modified to be a set to allow for iteration. + mapping: list + List of items specifying the mapping of the position + components of the state space. + uncertainty : bool + If True, function plots uncertainty ellipses or bars. + particle : bool + If True, function plots particles. + track_label: str + Label to apply to all tracks for legend. + err_freq: int + Frequency of error bar plotting on tracks. Default value is 1, meaning + error bars are plotted at every track step. + same_color: bool + Should all the tracks have the same color. Default False + \\*\\*kwargs: dict + Additional arguments to be passed to plot function. Defaults are ``linestyle="-"``, + ``marker='s'`` for :class:`~.Update` and ``marker='o'`` for other states. + + Returns + ------- + : list of :class:`matplotlib.artist.Artist` + List of artists that have been added to the axis. 
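+
+        Example
+        -------
+        A minimal sketch (``tracks`` is an assumed collection of :class:`~.Track`
+        objects with position components at indices 0 and 2)::
+
+            plotter = Plotter()
+            plotter.plot_tracks(tracks, mapping=[0, 2], uncertainty=True)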
+ """ + + tracks_kwargs = dict(linestyle='-', marker="s", color=None) + tracks_kwargs.update(kwargs) + if not isinstance(tracks, Collection) or isinstance(tracks, StateMutableSequence): + tracks = {tracks} # Make a set of length 1 + + # Plot tracks + artists = [] + track_colors = {} + for track in tracks: + # Get indexes for Update and non-Update states for styling markers + update_indexes = [] + not_update_indexes = [] + for n, state in enumerate(track): + if isinstance(state, Update): + update_indexes.append(n) + else: + not_update_indexes.append(n) + + data = np.concatenate( + [(getattr(state, 'mean', state.state_vector)[mapping, :]) + for state in track], + axis=1) + + line = self.ax.plot( + *data, + markevery=update_indexes, + **tracks_kwargs) + artists.extend(line) + if not_update_indexes: + artists.extend(self.ax.plot( + *data[:, not_update_indexes], + marker="o" if "marker" not in kwargs else kwargs['marker'], + linestyle='', + color=plt.getp(line[0], 'color'))) + track_colors[track] = plt.getp(line[0], 'color') + if same_color: + tracks_kwargs['color'] = plt.getp(line[0], 'color') + + if tracks: # If no tracks `line` won't be defined + # Assuming a single track or all plotted as the same colour then the following will + # work. Otherwise will just render the final track colour. + tracks_kwargs['color'] = plt.getp(line[0], 'color') + + # Generate legend items for track + track_handle = Line2D([], [], linestyle=tracks_kwargs['linestyle'], + marker=tracks_kwargs['marker'], color=tracks_kwargs['color']) + self.legend_dict[track_label] = track_handle + if uncertainty: + if self.dimension is Dimension.TWO: + # Plot uncertainty ellipses + for track in tracks: + HH = np.eye(track.ndim)[mapping, :] # Get position mapping matrix + check = err_freq - 1 # plot the first one + for state in track: + check += 1 + if check % err_freq: + continue + w, v = np.linalg.eig(HH @ state.covar @ HH.T) + if np.iscomplexobj(w) or np.iscomplexobj(v): + warnings.warn("Can not plot uncertainty for all states due to complex " + "eignevalues or eigenvectors", UserWarning) + continue + max_ind = np.argmax(w) + min_ind = np.argmin(w) + orient = np.arctan2(v[1, max_ind], v[0, max_ind]) + ellipse = Ellipse(xy=state.mean[mapping[:2], 0], + width=2 * np.sqrt(w[max_ind]), + height=2 * np.sqrt(w[min_ind]), + angle=np.rad2deg(orient), alpha=0.2, + color=track_colors[track]) + self.ax.add_artist(ellipse) + artists.append(ellipse) + + # Generate legend items for uncertainty ellipses + ellipse_handle = Ellipse((0.5, 0.5), 0.5, 0.5, alpha=0.2, + color=tracks_kwargs['color']) + ellipse_label = "Uncertainty" + self.legend_dict[ellipse_label] = ellipse_handle + # Generate legend + artists.append(self.ax.legend(handles=self.legend_dict.values(), + labels=self.legend_dict.keys(), + handler_map={Ellipse: _HandlerEllipse()})) + else: + # Plot 3D error bars on tracks + for track in tracks: + HH = np.eye(track.ndim)[mapping, :] # Get position mapping matrix + check = err_freq + for state in track: + if not check % err_freq: + w, v = np.linalg.eig(HH @ state.covar @ HH.T) + + xl = state.state_vector[mapping[0]] + yl = state.state_vector[mapping[1]] + zl = state.state_vector[mapping[2]] + + x_err = w[0] + y_err = w[1] + z_err = w[2] + + artists.extend( + self.ax.plot3D([xl+x_err, xl-x_err], [yl, yl], [zl, zl], + marker="_", color=tracks_kwargs['color'])) + artists.extend( + self.ax.plot3D([xl, xl], [yl+y_err, yl-y_err], [zl, zl], + marker="_", color=tracks_kwargs['color'])) + artists.extend( + self.ax.plot3D([xl, xl], [yl, yl], 
[zl+z_err, zl-z_err], + marker="_", color=tracks_kwargs['color'])) + check += 1 + + if particle: + if self.dimension is Dimension.TWO: + # Plot particles + for track in tracks: + for state in track: + data = state.state_vector[mapping[:2], :] + artists.extend(self.ax.plot(data[0], data[1], linestyle='', marker=".", + markersize=1, alpha=0.5)) + + # Generate legend items for particles + particle_handle = Line2D([], [], linestyle='', color="black", marker='.', + markersize=1) + particle_label = "Particles" + self.legend_dict[particle_label] = particle_handle + # Generate legend + artists.append(self.ax.legend(handles=self.legend_dict.values(), + labels=self.legend_dict.keys())) + else: + raise NotImplementedError("""Particle plotting is not currently supported for + 3D visualization""") + + else: + artists.append(self.ax.legend(handles=self.legend_dict.values(), + labels=self.legend_dict.keys())) + + return artists + + def plot_sensors(self, sensors, mapping=None, sensor_label="Sensors", **kwargs): + """Plots sensor(s) + + Plots sensors. Users can change the color and marker of sensors using keyword + arguments. Default is a black 'x' marker. + + Parameters + ---------- + sensors : Collection of :class:`~.Sensor` + Sensors to plot + mapping: list + List of items specifying the mapping of the position components of the + sensor's position. Default is either [0, 1] or [0, 1, 2] depending on `self.dimension` + sensor_label: str + Label to apply to all sensors for legend. + \\*\\*kwargs: dict + Additional arguments to be passed to plot function for sensors. Defaults are + ``marker='x'`` and ``color='black'``. + + Returns + ------- + : list of :class:`matplotlib.artist.Artist` + List of artists that have been added to the axis. + """ + + sensor_kwargs = dict(marker='x', color='black') + sensor_kwargs.update(kwargs) + + if not isinstance(sensors, Collection): + sensors = {sensors} # Make a set of length 1 + + if mapping is None: + mapping = list(range(self.dimension)) + + artists = [] + for sensor in sensors: + if self.dimension is Dimension.TWO: # plots the sensors in xy + artists.append(self.ax.scatter(sensor.position[mapping[0]], + sensor.position[mapping[1]], + **sensor_kwargs)) + elif self.dimension is Dimension.THREE: # plots the sensors in xyz + artists.extend(self.ax.plot3D(sensor.position[mapping[0]], + sensor.position[mapping[1]], + sensor.position[mapping[2]], + **sensor_kwargs)) + else: + raise NotImplementedError('Unsupported dimension type for sensor plotting') + self.legend_dict[sensor_label] = Line2D([], [], linestyle='', **sensor_kwargs) + artists.append(self.ax.legend(handles=self.legend_dict.values(), + labels=self.legend_dict.keys())) + return artists + + def set_equal_3daxis(self, axes=None): + """Plots minimum/maximum points with no linestyle to increase the plotting region to + simulate `.ax.axis('equal')` from matplotlib 2d plots which is not possible using 3d + projection. + + Parameters + ---------- + axes: list + List of dimension index specifying the equal axes, equal x and y = [0,1]. + Default is x,y [0,1]. 
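+
+        Example
+        -------
+        A minimal sketch, giving the x and y axes of a 3D plot equal scale::
+
+            plotter = Plotter(dimension=Dimension.THREE)
+            plotter.set_equal_3daxis(axes=[0, 1])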
+ """ + if not axes: + axes = [0, 1] + if self.dimension is Dimension.THREE: + min_xyz = [0, 0, 0] + max_xyz = [0, 0, 0] + for n in range(3): + for line in self.ax.lines: + min_xyz[n] = np.min([min_xyz[n], *line.get_data_3d()[n]]) + max_xyz[n] = np.max([max_xyz[n], *line.get_data_3d()[n]]) + + extremes = np.max([x - y for x, y in zip(max_xyz, min_xyz)]) + equal_axes = [0, 0, 0] + for i in axes: + equal_axes[i] = 1 + lower = ([np.mean([x, y]) for x, y in zip(max_xyz, min_xyz)] - extremes/2) * equal_axes + upper = ([np.mean([x, y]) for x, y in zip(max_xyz, min_xyz)] + extremes/2) * equal_axes + ghosts = GroundTruthPath(states=[State(state_vector=lower), + State(state_vector=upper)]) + + self.ax.plot3D([state.state_vector[0] for state in ghosts], + [state.state_vector[1] for state in ghosts], + [state.state_vector[2] for state in ghosts], + linestyle="") + + def plot_density(self, state_sequences: Collection[StateMutableSequence], + index: Union[int, None] = -1, + mapping=(0, 2), n_bins=300, **kwargs): + """ + + Parameters + ---------- + state_sequences : a collection of :class:`~.StateMutableSequence` + Set of tracks which will be plotted. If not a set, and instead a single + :class:`~.Track` type, the argument is modified to be a set to allow for iteration. + index: int + Which index of the StateMutableSequences should be plotted. + Default value is '-1' which is the last state in the sequences. + index can be set to None if all indices of the sequence should be included in the plot + mapping: list + List of 2 items specifying the mapping of the x and y components of the state space. + n_bins : int + Size of the bins used to group the data + \\*\\*kwargs: dict + Additional arguments to be passed to pcolormesh function. + """ + if len(state_sequences) == 0: + raise ValueError("Skipping plotting density due to state_sequences being empty.") + if index is None: # Plot all states in the sequence + x = np.array([a_state.state_vector[mapping[0]] + for a_state_sequence in state_sequences + for a_state in a_state_sequence]) + y = np.array([a_state.state_vector[mapping[1]] + for a_state_sequence in state_sequences + for a_state in a_state_sequence]) + else: # Only plot one state out of the sequences + x = np.array([a_state_sequence.states[index].state_vector[mapping[0]] + for a_state_sequence in state_sequences]) + y = np.array([a_state_sequence.states[index].state_vector[mapping[1]] + for a_state_sequence in state_sequences]) + if np.allclose(x, y, atol=1e-10): + raise ValueError("Skipping plotting density due to x and y values are the same. " + "This leads to a singular matrix in the kde function.") + # Evaluate a gaussian kde on a regular grid of n_bins x n_bins over data extents + k = kde.gaussian_kde([x, y]) + xi, yi = np.mgrid[x.min():x.max():n_bins * 1j, y.min():y.max():n_bins * 1j] + zi = k(np.vstack([xi.flatten(), yi.flatten()])) + + # Make the plot + self.ax.pcolormesh(xi, yi, zi.reshape(xi.shape), shading='auto', **kwargs) + + # Ellipse legend patch (used in Tutorial 3) + @staticmethod + def ellipse_legend(ax, label_list, color_list, **kwargs): + """Adds an ellipse patch to the legend on the axes. One patch added for each item in + `label_list` with the corresponding color from `color_list`. 
+ + Parameters + ---------- + ax : matplotlib.axes.Axes + Looks at the plot axes defined + label_list : list of str + Takes in list of strings intended to label ellipses in legend + color_list : list of str + Takes in list of colors corresponding to string/label + Must be the same length as label_list + \\*\\*kwargs: dict + Additional arguments to be passed to plot function. Default is ``alpha=0.2``. + """ + + ellipse_kwargs = dict(alpha=0.2) + ellipse_kwargs.update(kwargs) + + legend = ax.legend(handler_map={Ellipse: _HandlerEllipse()}) + handles, labels = ax.get_legend_handles_labels() + for color in color_list: + handle = Ellipse((0.5, 0.5), 0.5, 0.5, color=color, **ellipse_kwargs) + handles.append(handle) + for label in label_list: + labels.append(label) + legend._legend_box = None + legend._init_legend_box(handles, labels) + legend._set_loc(legend._loc) + legend.set_title(legend.get_title().get_text()) + + +class _HandlerEllipse(HandlerPatch): + def create_artists(self, legend, orig_handle, + xdescent, ydescent, width, height, fontsize, trans): + center = 0.5*width - 0.5*xdescent, 0.5*height - 0.5*ydescent + p = Ellipse(xy=center, width=width + xdescent, + height=height + ydescent) + self.update_prop(p, orig_handle, legend) + p.set_transform(trans) + return [p] + + +class MetricPlotter(ABC): + """Class for plotting Stone Soup metrics using matplotlib + + A plotting class which is used to simplify the process of plotting metrics. + Legends are automatically generated with each plot. + + """ + def __init__(self): + self.fig = None + self.axes = None + self.plottable_metrics = list() + + def plot_metrics(self, metrics, generator_names=None, metric_names=None, + combine_plots=True, **kwargs): + """Plots metrics + + Plots each plottable metric passed in to :attr:`metrics` across a series of subplots + and generates legend(s) automatically. Metrics are plotted as lines with default colors. + + Users can change linestyle, color and marker or other features using keyword arguments. + Any changes will apply to all metrics. + + Parameters + ---------- + metrics : dict of :class:`~.Metric` + Dictionary of generated metrics to be plotted. + generator_names: list of str + Generator(s) to extract specific metrics from :attr:`metrics` for plotting. + Default None to take all metrics. + metric_names: list of str + Specific metric(s) to extract from :class:`~.MetricGenerator` for plotting. + Default None to take all metrics in generators. + combine_plots: bool + Plot metrics of same type on the same subplot. Default True. + \\*\\*kwargs: dict + Additional arguments to be passed to plot function. Default is ``linestyle="-"``. + + Returns + ------- + : :class:`matplotlib.pyplot.figure` + Figure containing subplots displaying all plottable metrics. 
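+
+        Example
+        -------
+        A minimal sketch (``metrics`` is an assumed dict of generated metrics,
+        keyed by generator name, e.g. as produced by a metric manager)::
+
+            plotter = MetricPlotter()
+            plotter.plot_metrics(metrics, combine_plots=True)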
+ """ + for metric_dict in metrics.values(): + for metric_name, metric in metric_dict.items(): + if isinstance(metric.value, List) \ + and all(isinstance(x, SingleTimeMetric) for x in metric.value): + self.plottable_metrics.append(metric_name) + + metrics_kwargs = dict(linestyle="-") + metrics_kwargs.update(kwargs) + + generator_names = list(metrics.keys()) if generator_names is None else generator_names + + # warning for user input metrics that will not be plotted + if metric_names is not None: + for metric_name in metric_names: + if metric_name not in self.plottable_metrics: + warnings.warn(f"{metric_name} " + f"is not a plottable metric and will not be plotted") + else: + metric_names = self.extract_metric_types(metrics) + + metrics_to_plot = self._extract_plottable_metrics(metrics, generator_names, metric_names) + + if combine_plots: + self.combine_plots(metrics_to_plot, metrics_kwargs) + else: + self.plot_separately(metrics_to_plot, metrics_kwargs) + + def _extract_plottable_metrics(self, metrics, generator_names, metric_names): + """ + Extract all plottable metrics from dict of generated metrics. + + Parameters + ---------- + metrics: dict of :class:`~.Metric` + Dictionary of generated metrics. + generator_names: list of str + Generator(s) to extract specific metrics from :attr:`metrics` for plotting. + metric_names: list of str + Specific metric(s) to extract from :class:`~.MetricGenerator` for plotting. + + Returns + ------- + : dict + Dict of all plottable metrics. + """ + metrics_dict = dict() + + for generator_name in generator_names: + for metric_name in metric_names: + if metric_name in metrics[generator_name].keys() and \ + metric_name in self.plottable_metrics: + if generator_name not in metrics_dict.keys(): + metrics_dict[generator_name] = \ + {metric_name: metrics[generator_name][metric_name]} + else: + metrics_dict[generator_name][metric_name] = \ + metrics[generator_name][metric_name] + + return metrics_dict + + def _count_subplots(self, metrics_to_plot, combine_plots): + """ + Calculate number of subplots needed to plot all metrics. + + Parameters + ---------- + metrics_to_plot: dict of :class:`~.Metric` + Dictionary of metrics to be plotted. + combine_plots: bool + Specifies whether same metric types should be plotted on same subplot. + + Returns + ------- + : int + Number of subplots to generate. + """ + if combine_plots: + metric_types = self.extract_metric_types(metrics_to_plot) + number_of_subplots = len(metric_types) + + else: + number_of_subplots = 0 + for generator in metrics_to_plot.keys(): + number_of_subplots += len(metrics_to_plot[generator]) + + return number_of_subplots + + @staticmethod + def extract_metric_types(metrics): + """ + Identify the different types of metric held in dict of metrics. + + Parameters + ---------- + metrics: dict of :class:`~.Metric` + Dictionary of metrics. + + Returns + ------- + : list + Sorted list of types of metric + """ + metric_types = set() + for generator in metrics.keys(): + for metric_key in metrics[generator].keys(): + metric_types.add(metric_key) + + metric_types = list(metric_types) + metric_types.sort() + + return metric_types + + def combine_plots(self, metrics_to_plot, metrics_kwargs): + """ + Generates one subplot for each different metric type and plots metrics of the same + type on same subplot. Metrics are plotted over time. + + Parameters + ---------- + metrics_to_plot: dict of :class:`~.Metric` + Dictionary of metrics to plot. + metrics_kwargs: dict + Keyword arguments to be passed to plot function. 
+ + Returns + ------- + : :class:`matplotlib.pyplot.figure` + Figure containing subplots displaying metrics. + """ + # determine how many plots required - equal to number of metric types + number_of_subplots = self._count_subplots(metrics_to_plot, True) + + # initialise each subplot + self.fig, axes = plt.subplots(number_of_subplots, figsize=(10, 6*number_of_subplots)) + self.fig.subplots_adjust(hspace=0.3) + + # extract data for each subplot and plot it + metric_types = self.extract_metric_types(metrics_to_plot) + + self.axes = axes if isinstance(axes, Iterable) else [axes] + + # generate colour map for lines to be plotted + if 'color' not in metrics_kwargs.keys(): + colour_map = plt.cm.rainbow(np.linspace(0, 1, len(metrics_to_plot.keys()))) + else: + colour_map = metrics_kwargs['color'] + metrics_kwargs.pop('color') + + for metric_type, axis in zip(list(metric_types), self.axes): + artists = [] + legend_dict = {} + + colour_map_copy = iter(colour_map.copy()) + + for generator in metrics_to_plot.keys(): + for metric in metrics_to_plot[generator].keys(): + if metric == metric_type: + colour = next(colour_map_copy) + metric_values = metrics_to_plot[generator][metric].value + artists.extend(axis.plot([_.timestamp for _ in metric_values], + [_.value for _ in metric_values], + color=colour, + **metrics_kwargs)) + + metric_handle = Line2D([], [], linestyle=metrics_kwargs['linestyle'], + color=colour) + legend_dict[generator] = metric_handle + + # Generate legend + artists.append(axis.legend(handles=legend_dict.values(), + labels=legend_dict.keys())) + + y_label = metric_type.split(' at times')[0] + artists.extend(axis.set(title=metric_type.split(' at times')[0], + xlabel="Time", ylabel=y_label)) + + def plot_separately(self, metrics_to_plot, metrics_kwargs): + """ + Generates one subplot for each different individual metric and plots metric + values over time. + + Parameters + ---------- + metrics_to_plot: dict of :class:`~.Metric` + Dictionary of metrics to plot. + metrics_kwargs: dict + Keyword arguments to be passed to plot function. + + Returns + ------- + : :class:`matplotlib.pyplot.figure` + Figure containing subplots displaying metrics. + """ + metrics_kwargs['color'] = metrics_kwargs['color'] if \ + 'color' in metrics_kwargs.keys() else 'blue' + + # determine how many plots required - equal to number of metrics within the generators + number_of_subplots = self._count_subplots(metrics_to_plot, False) + + # initialise each plot + self.fig, axes = plt.subplots(number_of_subplots, figsize=(10, 6*number_of_subplots)) + self.fig.subplots_adjust(hspace=0.3) + + # extract data for each plot and plot it + all_metrics = {} + for generator in metrics_to_plot.keys(): + for metric in list(metrics_to_plot[generator].keys()): + all_metrics[f'{generator}: {metric}'] = metrics_to_plot[generator][metric] + + self.axes = axes if isinstance(axes, Iterable) else [axes] + + for metric, axis in zip(all_metrics.keys(), self.axes): + y_label = str(all_metrics[metric].title).split(' at times')[0] + axis.set(title=str(all_metrics[metric].title), xlabel='Time', ylabel=y_label) + metric_values = all_metrics[metric].value + axis.plot([_.timestamp for _ in metric_values], + [_.value for _ in metric_values], + **metrics_kwargs) + + # Generate legend + metric_handle = Line2D([], [], linestyle=metrics_kwargs['linestyle'], + color=metrics_kwargs['color']) + axis.legend(handles=[metric_handle], + labels=[metric.split(' at times')[0]]) + + def set_fig_title(self, title): + """ + Set title for the figure. 
+
+        Parameters
+        ----------
+        title: str
+            Figure title text.
+
+        Returns
+        -------
+        Text instance of figure title.
+        """
+        self.fig.suptitle(t=title)
+
+    def set_ax_title(self, titles):
+        """
+        Set axis titles for each axis in figure.
+
+        Parameters
+        ----------
+        titles: list of str
+            List of strings for title text for each axis.
+
+        Returns
+        -------
+        Text instance of axis titles.
+        """
+        for axis, title in zip(self.axes, titles):
+            axis.set(title=title)
+
+
+class Plotterly(_Plotter):
+    """Plotting class for building graphs of Stone Soup simulations using plotly
+
+    A plotting class which is used to simplify the process of plotting ground truths,
+    measurements, clutter and tracks. Tracks can be plotted with uncertainty ellipses or
+    particles if required. Legends are automatically generated with each plot.
+    Three-dimensional plots can be created using the optional dimension parameter.
+
+    Parameters
+    ----------
+    dimension: enum \'Dimension\'
+        Optional parameter to specify 1D, 2D, or 3D plotting.
+    axis_labels: list
+        Optional parameter to specify the axis labels for non-xy dimensions. Default None, i.e.,
+        "x" and "y".
+    \\*\\*kwargs: dict
+        Additional arguments to be passed to the Plotly.graph_objects Figure.
+
+    Attributes
+    ----------
+    fig: plotly.graph_objects.Figure
+        Generated figure to display graphs.
+    """
+    def __init__(self, dimension=Dimension.TWO, axis_labels=None, **kwargs):
+        if dimension != Dimension.ONE:
+            if not axis_labels:
+                axis_labels = ["x", "y"]
+        else:
+            if axis_labels:
+                if len(axis_labels) == 1:
+                    axis_labels = ["Time", axis_labels[0]]
+            else:
+                axis_labels = ["Time", "x"]
+        if go is None:
+            raise RuntimeError("Usage of Plotterly plotter requires installation of `plotly`")
+
+        self.dimension = Dimension(dimension)  # allows 1, 2, 3,
+        # Dimension(1), Dimension(2) or Dimension(3)
+
+        from plotly import colors
+        layout_kwargs = dict(
+            xaxis_title=axis_labels[0],
+            yaxis_title=axis_labels[1],
+            colorway=colors.qualitative.Plotly,  # Needed to match colours later.
+        )
+
+        if self.dimension == 3:
+            layout_kwargs.update(dict(scene_aspectmode='data'))  # auto shapes fig to fit data well
+
+        merge(layout_kwargs, kwargs)
+
+        # Generate plot axes
+        self.fig = go.Figure(layout=layout_kwargs)
+
+    @staticmethod
+    def _format_state_text(state):
+        text = []
+        text.append(type(state).__name__)
+        text.append(getattr(state, 'mean', state.state_vector))
+        text.append(state.timestamp)
+        text.extend([f"{key}: {value}" for key, value in getattr(state, 'metadata', {}).items()])
+
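+        # NB: plotly renders hover text as HTML, so the "<br>" separators below
+        # place each field on its own line in the tooltip.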
".join((str(t) for t in text)) + + def _check_mapping(self, mapping): + if len(mapping) == 0: + raise ValueError("No indices provided in mapping.") + elif len(mapping) != self.dimension: + raise TypeError("Plotter dimension is not same as the mapping dimension.") + + def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwargs): + """Plots ground truth(s) + + Plots each ground truth path passed in to :attr:`truths` and generates a legend + automatically. Ground truths are plotted as dashed lines with default colors. + + Users can change line style, color and marker using keyword arguments. Any changes + will apply to all ground truths. + + Parameters + ---------- + truths : Collection of :class:`~.GroundTruthPath` + Collection of ground truths which will be plotted. If not a collection, + and instead a single :class:`~.GroundTruthPath` type, the argument is modified to be a + set to allow for iteration. + mapping: list + List of items specifying the mapping of the position components of the state space. + truths_label: str + Label for truth data. Default is "Ground Truth" + \\*\\*kwargs: dict + Additional arguments to be passed to scatter function. Default is + ``line=dict(dash="dash")``. + """ + if not isinstance(truths, Collection) or isinstance(truths, StateMutableSequence): + truths = {truths} + + self._check_mapping(mapping) # ensure mapping is compatible with plotter dimension + + truths_kwargs = dict( + mode="lines", line=dict(dash="dash"), legendgroup=truths_label, legendrank=100, + name=truths_label) + + if self.dimension == 3: # make ground truth line thicker so easier to see in 3d plot + truths_kwargs.update(dict(line=dict(width=8, dash="longdashdot"))) + + merge(truths_kwargs, kwargs) + add_legend = truths_kwargs['legendgroup'] not in {trace.legendgroup + for trace in self.fig.data} + + for truth in truths: + scatter_kwargs = truths_kwargs.copy() + if add_legend: + scatter_kwargs['showlegend'] = True + add_legend = False + else: + scatter_kwargs['showlegend'] = False + + if self.dimension == 1: + self.fig.add_scatter( + x=[state.timestamp for state in truth], + y=[state.state_vector[mapping[0]] for state in truth], + text=[self._format_state_text(state) for state in truth], + **scatter_kwargs) + + elif self.dimension == 2: + self.fig.add_scatter( + x=[state.state_vector[mapping[0]] for state in truth], + y=[state.state_vector[mapping[1]] for state in truth], + text=[self._format_state_text(state) for state in truth], + **scatter_kwargs) + + elif self.dimension == 3: + self.fig.add_scatter3d( + x=[state.state_vector[mapping[0]] for state in truth], + y=[state.state_vector[mapping[1]] for state in truth], + z=[state.state_vector[mapping[2]] for state in truth], + text=[self._format_state_text(state) for state in truth], + **scatter_kwargs) + + def plot_measurements(self, measurements, mapping, measurement_model=None, + measurements_label="Measurements", convert_measurements=True, **kwargs): + """Plots measurements + + Plots detections and clutter, generating a legend automatically. Detections are plotted as + blue circles by default unless the detection type is clutter. + If the detection type is :class:`~.Clutter` it is plotted as a yellow 'tri-up' marker. + + Users can change the color and marker of detections using keyword arguments but not for + clutter detections. + + Parameters + ---------- + measurements : Collection of :class:`~.Detection` + Detections which will be plotted. If measurements is a set of lists it is flattened. 
+        mapping: list
+            List of items specifying the mapping of the position components of the state space.
+        measurement_model : :class:`~.Model`, optional
+            User-defined measurement model to be used in finding measurement state inverses if
+            they cannot be found from the measurements themselves.
+        measurements_label : str
+            Label for the measurements. Default is "Measurements".
+        convert_measurements: bool
+            Should the measurements be converted from measurement space to state space before
+            being plotted. Default is True.
+        \\*\\*kwargs: dict
+            Additional arguments to be passed to scatter function for detections. Defaults are
+            ``marker=dict(color="#636EFA")``.
+        """
+
+        if not isinstance(measurements, Collection):
+            measurements = {measurements}
+
+        if any(isinstance(item, set) for item in measurements):
+            measurements_set = chain.from_iterable(measurements)  # Flatten into one set
+        else:
+            measurements_set = set(measurements)
+
+        self._check_mapping(mapping)
+
+        plot_detections, plot_clutter = self._conv_measurements(measurements_set,
+                                                                mapping,
+                                                                measurement_model,
+                                                                convert_measurements)
+
+        if plot_detections:
+            name = measurements_label + "<br>(Detections)"
+            measurement_kwargs = dict(
+                mode='markers', marker=dict(color='#636EFA'),
+                name=name, legendgroup=name, legendrank=200)
+
+            if self.dimension == 3:  # make markers smaller in 3d plot
+                measurement_kwargs.update(dict(marker=dict(size=4, color='#636EFA')))
+
+            merge(measurement_kwargs, kwargs)
+            if measurement_kwargs['legendgroup'] not in {trace.legendgroup
+                                                         for trace in self.fig.data}:
+                measurement_kwargs['showlegend'] = True
+            else:
+                measurement_kwargs['showlegend'] = False
+            detection_array = np.asarray(list(plot_detections.values()), dtype=np.float64)
+
+            if self.dimension == 1:
+                self.fig.add_scatter(
+                    x=[state.timestamp for state in plot_detections.keys()],
+                    y=detection_array[:, 0],
+                    text=[self._format_state_text(state) for state in plot_detections.keys()],
+                    **measurement_kwargs,
+                )
+            elif self.dimension == 2:
+                self.fig.add_scatter(
+                    x=detection_array[:, 0],
+                    y=detection_array[:, 1],
+                    text=[self._format_state_text(state) for state in plot_detections.keys()],
+                    **measurement_kwargs,
+                )
+            elif self.dimension == 3:
+                self.fig.add_scatter3d(
+                    x=detection_array[:, 0],
+                    y=detection_array[:, 1],
+                    z=detection_array[:, 2],
+                    text=[self._format_state_text(state) for state in plot_detections.keys()],
+                    **measurement_kwargs,
+                )
+
+        if plot_clutter:
+            name = measurements_label + "<br>(Clutter)"
+            clutter_kwargs = dict(
+                mode='markers', marker=dict(symbol="star-triangle-up", color='#FECB52'),
+                name=name, legendgroup=name, legendrank=210)
+
+            if self.dimension == 3:  # update - star-triangle-up not in 3d plotly
+                clutter_kwargs.update(dict(marker=dict(size=4, symbol="diamond",
+                                                       color='#FECB52')))
+
+            merge(clutter_kwargs, kwargs)
+            if clutter_kwargs['legendgroup'] not in {trace.legendgroup
+                                                     for trace in self.fig.data}:
+                clutter_kwargs['showlegend'] = True
+            else:
+                clutter_kwargs['showlegend'] = False
+            clutter_array = np.asarray(list(plot_clutter.values()), dtype=np.float64)
+
+            if self.dimension == 1:
+                self.fig.add_scatter(
+                    x=[state.timestamp for state in plot_clutter.keys()],
+                    y=clutter_array[:, 0],
+                    text=[self._format_state_text(state) for state in plot_clutter.keys()],
+                    **clutter_kwargs,
+                )
+            elif self.dimension == 2:
+                self.fig.add_scatter(
+                    x=clutter_array[:, 0],
+                    y=clutter_array[:, 1],
+                    text=[self._format_state_text(state) for state in plot_clutter.keys()],
+                    **clutter_kwargs,
+                )
+            elif self.dimension == 3:
+                self.fig.add_scatter3d(
+                    x=clutter_array[:, 0],
+                    y=clutter_array[:, 1],
+                    z=clutter_array[:, 2],
+                    text=[self._format_state_text(state) for state in plot_clutter.keys()],
+                    **clutter_kwargs,
+                )
+
+    def get_next_color(self):
+        """
+        Find the colour of the next plot. This approach to getting colour isn't ideal, but should
+        work in most cases.
+
+        Returns
+        -------
+        color : str
+            Hex string for a colour
+        """
+        # Find how many sequences have been plotted so far. The current plot has already been
+        # added to fig.data, so -1 is needed
+        figure_index = len(self.fig.data) - 1
+
+        # Get the list of colours used for plotting
+        colorway = self.fig.layout.colorway
+        max_index = len(colorway)
+
+        # Use the modulo operator to limit the colour index to limits of the colorway.
+        # If figure_index > max_index then colours will be reused
+        color_index = figure_index % max_index
+        return colorway[color_index]
+
+    def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_label="Tracks",
+                    ellipse_points=30, err_freq=1, same_color=False, **kwargs):
+        """Plots track(s)
+
+        Plots each track generated, generating a legend automatically. If ``uncertainty=True``,
+        error ellipses are plotted.
+        Tracks are plotted as solid lines with point markers and default colors.
+
+        Users can change line style, color and marker using keyword arguments.
+
+        Parameters
+        ----------
+        tracks : Collection of :class:`~.Track`
+            Collection of tracks which will be plotted. If not a collection, and instead a single
+            :class:`~.Track` type, the argument is modified to be a set to allow for iteration.
+        mapping: list
+            List of items specifying the mapping of the position
+            components of the state space.
+        uncertainty : bool
+            If True, function plots uncertainty ellipses.
+        particle : bool
+            If True, function plots particles.
+        track_label: str
+            Label to apply to all tracks for legend.
+        ellipse_points: int
+            Number of points for polygon approximating ellipse shape.
+        err_freq: int
+            Frequency of error bar plotting on tracks. Default value is 1, meaning
+            error bars are plotted at every track step.
+        same_color: bool
+            Should all the tracks have the same colour. Default False.
+        \\*\\*kwargs: dict
+            Additional arguments to be passed to scatter function. Defaults are
+            ``marker=dict(symbol='square')`` for :class:`~.Update` and
+            ``marker=dict(symbol='circle')`` for other states.
+ """ + if not isinstance(tracks, Collection) or isinstance(tracks, StateMutableSequence): + tracks = {tracks} # Make a set of length 1 + + self._check_mapping(mapping) # check size of mapping against dimension of plotter + + # Plot tracks + track_colors = {} + track_kwargs = dict(mode='markers+lines', legendgroup=track_label, legendrank=300) + + if self.dimension == 3: # change visuals to work well in 3d + track_kwargs.update(dict(line=dict(width=7)), marker=dict(size=4)) + merge(track_kwargs, kwargs) + add_legend = track_kwargs['legendgroup'] not in {trace.legendgroup + for trace in self.fig.data} + + if same_color: + color = track_kwargs.get('marker', {}).get('color') or \ + track_kwargs.get('line', {}).get('color') + + # Set the colour if it hasn't already been set + if color is None: + track_kwargs['marker'] = track_kwargs.get('marker', {}) + track_kwargs['marker']['color'] = self.get_next_color() + + for track in tracks: + scatter_kwargs = track_kwargs.copy() + scatter_kwargs['name'] = track.id + if add_legend: + scatter_kwargs['name'] = track_label + scatter_kwargs['showlegend'] = True + add_legend = False + else: + scatter_kwargs['showlegend'] = False + scatter_kwargs['marker'] = scatter_kwargs.get('marker', {}).copy() + if 'symbol' not in scatter_kwargs['marker']: + scatter_kwargs['marker']['symbol'] = [ + 'square' if isinstance(state, Update) else 'circle' for state in track] + + if len(self.fig.data) > 0: + track_colors[track] = (self.fig.data[-1].line.color + or self.fig.data[-1].marker.color + or self.get_next_color()) + else: + track_colors[track] = self.get_next_color() + + if self.dimension == 1: # plot 1D tracks + + if uncertainty or particle: + raise NotImplementedError + + self.fig.add_scatter( + x=[state.timestamp for state in track], + y=[float(getattr(state, 'mean', state.state_vector)[mapping[0]]) + for state in track], + text=[self._format_state_text(state) for state in track], + **scatter_kwargs) + + elif self.dimension == 2: # plot 2D tracks + + self.fig.add_scatter( + x=[float(getattr(state, 'mean', state.state_vector)[mapping[0]]) + for state in track], + y=[float(getattr(state, 'mean', state.state_vector)[mapping[1]]) + for state in track], + text=[self._format_state_text(state) for state in track], + **scatter_kwargs) + + elif self.dimension == 3: # plot 3D tracks + + if particle: + raise NotImplementedError + + # create empty error arrays + err_x = np.array([np.nan for _ in range(len(track))], dtype=float) + err_y = np.array([np.nan for _ in range(len(track))], dtype=float) + err_z = np.array([np.nan for _ in range(len(track))], dtype=float) + + if uncertainty: # find x,y,z error bars for relevant states + + for count, state in enumerate(track): + + if not count % err_freq: # ie count % err_freq = 0 + HH = np.eye(track.ndim)[mapping, :] # Get position mapping matrix + cov = HH @ state.covar @ HH.T + + err_x[count] = np.sqrt(cov[0, 0]) + err_y[count] = np.sqrt(cov[1, 1]) + err_z[count] = np.sqrt(cov[2, 2]) + + self.fig.add_scatter3d( + x=[float(getattr(state, 'mean', state.state_vector)[mapping[0]]) + for state in track], + error_x=dict(type='data', thickness=10, width=3, array=err_x), + + y=[float(getattr(state, 'mean', state.state_vector)[mapping[1]]) + for state in track], + error_y=dict(type='data', thickness=10, width=3, array=err_y), + + z=[float(getattr(state, 'mean', state.state_vector)[mapping[2]]) + for state in track], + error_z=dict(type='data', thickness=10, width=3, array=err_z), + # note that 3D error thickness seems to be broken in Plotly + + 
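+                    # err_x/err_y/err_z are NaN everywhere except at every err_freq-th
+                    # state, so Plotly draws error bars only at those states and skips
+                    # the NaN entries.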
+                    text=[self._format_state_text(state) for state in track],
+                    **scatter_kwargs)
+
+            track_colors[track] = (self.fig.data[-1].line.color
+                                   or self.fig.data[-1].marker.color
+                                   or self.get_next_color())
+
+        # earlier checking means this only applies to 2D.
+        if uncertainty and self.dimension == 2:
+            name = track_kwargs['legendgroup'] + "<br>(Ellipses)"
+            add_legend = name not in {trace.legendgroup for trace in self.fig.data}
+            for track in tracks:
+                ellipse_kwargs = dict(
+                    mode='none', fill='toself', fillcolor=track_colors[track],
+                    opacity=0.2, hoverinfo='skip',
+                    legendgroup=name, name=name,
+                    legendrank=track_kwargs['legendrank'] + 10)
+                for state in track:
+                    points = self._generate_ellipse_points(state, mapping, ellipse_points)
+                    if add_legend:
+                        ellipse_kwargs['showlegend'] = True
+                        add_legend = False
+                    else:
+                        ellipse_kwargs['showlegend'] = False
+
+                    self.fig.add_scatter(x=points[0, :], y=points[1, :], **ellipse_kwargs)
+
+        if particle and self.dimension == 2:
+            name = track_kwargs['legendgroup'] + "<br>(Particles)"
+            add_legend = name not in {trace.legendgroup for trace in self.fig.data}
+            for track in tracks:
+                for state in track:
+                    particle_kwargs = dict(
+                        mode='markers', marker=dict(size=2),
+                        opacity=0.4, hoverinfo='skip',
+                        legendgroup=name, name=name,
+                        legendrank=track_kwargs['legendrank'] + 20)
+                    if add_legend:
+                        particle_kwargs['showlegend'] = True
+                        add_legend = False
+                    else:
+                        particle_kwargs['showlegend'] = False
+                    data = state.state_vector[mapping[:2], :]
+                    self.fig.add_scattergl(x=data[0], y=data[1], **particle_kwargs)
+
+    @staticmethod
+    def _generate_ellipse_points(state, mapping, n_points=30):
+        """Generate error ellipse points for given state and mapping"""
+        HH = np.eye(state.ndim)[mapping, :]  # Get position mapping matrix
+        w, v = np.linalg.eig(HH @ state.covar @ HH.T)
+        max_ind = np.argmax(w)
+        min_ind = np.argmin(w)
+        orient = np.arctan2(v[1, max_ind], v[0, max_ind])
+        a = np.sqrt(w[max_ind])
+        b = np.sqrt(w[min_ind])
+        m = 1 - (b**2 / a**2)
+
+        def func(x):
+            return np.sqrt(1 - (m**2 * np.sin(x)**2))
+
+        def func2(z):
+            return quad(func, 0, z)[0]
+
+        c = 4 * a * func2(np.pi / 2)
+
+        points = []
+        for n in range(n_points):
+            def func3(x):
+                return n/n_points*c - a*func2(x)
+
+            points.append(brentq(func3, 0, 2 * np.pi, xtol=1e-4))
+
+        c, s = np.cos(orient), np.sin(orient)
+        rotational_matrix = np.array(((c, -s), (s, c)))
+        points.append(points[0])
+        points = np.array([[a * np.sin(i), b * np.cos(i)] for i in points])
+        points = rotational_matrix @ points.T
+        return points + state.mean[mapping[:2], :]
+
+    def plot_sensors(self, sensors, mapping=[0, 1], sensor_label="Sensors", **kwargs):
+        """Plots sensor(s)
+
+        Plots sensors. Users can change the color and marker of sensors using keyword
+        arguments. Default is a black 'x' marker.
+
+        Parameters
+        ----------
+        sensors : Collection of :class:`~.Sensor`
+            Sensors to plot.
+        mapping: list
+            List of items specifying the mapping of the position
+            components of the sensor's position.
+        sensor_label: str
+            Label to apply to all sensors for legend.
+        \\*\\*kwargs: dict
+            Additional arguments to be passed to scatter function for sensors. Defaults are
+            ``marker=dict(symbol='x', color='black')``.
+        """
+
+        if not isinstance(sensors, Collection):
+            sensors = {sensors}
+
+        self._check_mapping(mapping)  # ensure mapping is compatible with plotter dimension
+
+        if self.dimension == 1 or self.dimension == 3:
+            raise NotImplementedError
+
+        sensor_kwargs = dict(mode='markers', marker=dict(symbol='x', color='black'),
+                             legendgroup=sensor_label, legendrank=50)
+        merge(sensor_kwargs, kwargs)
+
+        sensor_kwargs['name'] = sensor_label
+        if sensor_kwargs['legendgroup'] not in {trace.legendgroup
+                                                for trace in self.fig.data}:
+            sensor_kwargs['showlegend'] = True
+        else:
+            sensor_kwargs['showlegend'] = False
+
+        sensor_xy = np.array([sensor.position[mapping, 0] for sensor in sensors])
+        self.fig.add_scatter(x=sensor_xy[:, 0], y=sensor_xy[:, 1], **sensor_kwargs)
+
+    def hide_plot_traces(self, items_to_hide=None):
+        """Hide Plot Traces
+
+        This function allows plotting items to be invisible by default. Users can toggle the plot
+        trace to visible.
+
+        Parameters
+        ----------
+        items_to_hide : Iterable[str]
+            The legend label (`legendgroups`) for the plot traces that should be invisible by
+            default. If left as ``None`` no traces will be shown.
+        """
+        for fig_data in self.fig.data:
+            if items_to_hide is None or fig_data.legendgroup in items_to_hide:
+                fig_data.visible = "legendonly"
+            else:
+                fig_data.visible = None
+
+    def show_plot_traces(self, items_to_show=None):
+        """Show Plot Traces
+
+        This function allows specific plotting items to be shown by default. All labels not
+        mentioned in `items_to_show` will be invisible and can be manually toggled on.
+
+        Parameters
+        ----------
+        items_to_show : Iterable[str]
+            The legend label (`legendgroups`) for the plot traces that should be shown by
+            default. If left as ``None`` all traces will be shown.
+        """
+        for fig_data in self.fig.data:
+            if items_to_show is None or fig_data.legendgroup in items_to_show:
+                fig_data.visible = None
+            else:
+                fig_data.visible = "legendonly"
+
+
+class PolarPlotterly(_Plotter):
+
+    def __init__(self, dimension=Dimension.TWO, **kwargs):
+        if go is None:
+            raise RuntimeError("Usage of PolarPlotterly plotter requires installation of "
+                               "`plotly`")
+        if isinstance(dimension, type(Dimension.TWO)):
+            self.dimension = dimension
+        elif isinstance(dimension, int):
+            self.dimension = Dimension(dimension)
+        else:
+            raise TypeError("%s is an unsupported type for \'dimension\'; "
+                            "expected type %s" % (type(dimension), type(Dimension.TWO)))
+        if self.dimension != Dimension.TWO:
+            raise TypeError("Only 2D plotting currently supported")
+
+        layout_kwargs = dict()
+        layout_kwargs.update(kwargs)
+
+        # Generate plot axes
+        self.fig = go.Figure(layout=layout_kwargs)
+
+    def plot_state_sequence(self, state_sequences, angle_mapping: int, range_mapping: int = None,
+                            label="", **kwargs):
+        """Plots state sequence(s)
+
+        Plots each state sequence passed in to :attr:`state_sequences` and generates a legend
+        automatically.
+
+        Users can change line style, color and marker using keyword arguments. Any changes
+        will apply to all state sequences.
+
+        Parameters
+        ----------
+        state_sequences : Collection of :class:`~.StateMutableSequence`
+            Collection of state sequences which will be plotted. If not a collection,
+            and instead a single :class:`~.StateMutableSequence` type, the argument is modified
+            to be a set to allow for iteration.
+        angle_mapping: int
+            Specifying the mapping of the angular component of the state space to be plotted.
+        range_mapping: int
+            Specifying the mapping of the range component of the state space to be plotted. If
+            `None`, the angular component will be plotted against time.
+        label: str
+            Label for the state sequence data.
+        \\*\\*kwargs: dict
+            Additional arguments to be passed to scatter function. Default is
+            ``mode='markers'``.
+            The default unit for the angular component is radians. This can be changed to degrees
+            with the keyword argument ``thetaunit='degrees'``.
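+
+        Example
+        -------
+        A minimal sketch, assuming ``track`` is an existing state sequence whose state
+        vectors hold bearing at index 0 and range at index 1:
+
+        >>> plotter = PolarPlotterly()
+        >>> plotter.plot_state_sequence(track, angle_mapping=0, range_mapping=1,
+        ...                             label="Bearing-range track")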
+ """ + + if not isinstance(state_sequences, Collection) \ + or isinstance(state_sequences, StateMutableSequence): + state_sequences = {state_sequences} + + plotting_kwargs = dict( + mode="markers", legendgroup=label, legendrank=200, + name=label, thetaunit="radians") + merge(plotting_kwargs, kwargs) + add_legend = plotting_kwargs['legendgroup'] not in {trace.legendgroup + for trace in self.fig.data} + + for state_sequence in state_sequences: + if range_mapping is None: + r = [state.timestamp for state in state_sequence] + else: + r = [float(state.state_vector[range_mapping]) for state in state_sequence] + bearings = [float(state.state_vector[angle_mapping]) for state in state_sequence] + + scatter_kwargs = plotting_kwargs.copy() + if add_legend: + scatter_kwargs['showlegend'] = True + add_legend = False + else: + scatter_kwargs['showlegend'] = False + + polar_plot = go.Scatterpolar( + r=r, + theta=bearings, **scatter_kwargs) + self.fig.add_trace(polar_plot) + + def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwargs): + """Plots ground truth(s) + + Plots each ground truth path passed in to :attr:`truths` and generates a legend + automatically. Ground truths are plotted as dashed lines with default colors. + + Users can change line style, color and marker using keyword arguments. Any changes + will apply to all ground truths. + + Parameters + ---------- + truths : Collection of :class:`~.GroundTruthPath` + Collection of ground truths which will be plotted. If not a collection, + and instead a single :class:`~.GroundTruthPath` type, the argument is modified to be a + set to allow for iteration. + mapping: list + List of items specifying the mapping of the position components of the state space. + truths_label: str + Label for truth data. Default is "Ground Truth". + \\*\\*kwargs: dict + Additional arguments to be passed to scatter function. Default is + ``line=dict(dash="dash")``. + """ + truths_kwargs = dict(mode="lines", line=dict(dash="dash"), legendrank=100) + merge(truths_kwargs, kwargs) + angle_mapping = mapping[0] + if len(mapping) > 1: + range_mapping = mapping[1] + else: + range_mapping = None + self.plot_state_sequence(state_sequences=truths, angle_mapping=angle_mapping, + range_mapping=range_mapping, label=truths_label, **truths_kwargs) + + def plot_measurements(self, measurements, mapping, measurement_model=None, + measurements_label="Measurements", convert_measurements=True, **kwargs): + """Plots measurements + + Plots detections and clutter, generating a legend automatically. Detections are plotted as + blue circles by default unless the detection type is clutter. + If the detection type is :class:`~.Clutter` it is plotted as a yellow 'tri-up' marker. + + Users can change the color and marker of detections using keyword arguments but not for + clutter detections. + + Parameters + ---------- + measurements : Collection of :class:`~.Detection` + Detections which will be plotted. If measurements is a set of lists it is flattened. + mapping: list + List of items specifying the mapping of the position components of the state space. + measurement_model : :class:`~.Model`, optional + User-defined measurement model to be used in finding measurement state inverses if + they cannot be found from the measurements themselves. + measurements_label : str + Label for the measurements. Default is "Measurements". + convert_measurements: bool + Should the measurements be converted before being plotted. Default is True. 
+        \\*\\*kwargs: dict
+            Additional arguments to be passed to scatter function for detections. Defaults are
+            ``marker=dict(color="#636EFA")``.
+        """
+
+        if not isinstance(measurements, Collection):
+            measurements = {measurements}
+
+        if any(isinstance(item, set) for item in measurements):
+            measurements_set = chain.from_iterable(measurements)  # Flatten into one set
+        else:
+            measurements_set = set(measurements)
+
+        plot_detections, plot_clutter = self._conv_measurements(measurements_set,
+                                                                mapping,
+                                                                measurement_model,
+                                                                convert_measurements)
+
+        angle_mapping = 0
+        if len(mapping) > 1:
+            range_mapping = 1
+        else:
+            range_mapping = None
+
+        if plot_detections:
+            name = measurements_label + "<br>(Detections)"
+            measurement_kwargs = dict(mode='markers', marker=dict(color='#636EFA'), legendrank=200)
+            merge(measurement_kwargs, kwargs)
+            plotting_data = [State(state_vector=plotting_state_vector,
+                                   timestamp=det.timestamp)
+                             for det, plotting_state_vector in plot_detections.items()]
+
+            self.plot_state_sequence(state_sequences=[plotting_data], angle_mapping=angle_mapping,
+                                     range_mapping=range_mapping, label=name,
+                                     **measurement_kwargs)
+
+        if plot_clutter:
+            name = measurements_label + "
(Clutter)" + clutter_kwargs = dict(mode='markers', legendrank=210, + marker=dict(symbol="star-triangle-up", color='#FECB52')) + merge(clutter_kwargs, kwargs) + plotting_data = [State(state_vector=plotting_state_vector, + timestamp=det.timestamp) + for det, plotting_state_vector in plot_clutter.items()] + + self.plot_state_sequence(state_sequences=[plotting_data], angle_mapping=angle_mapping, + range_mapping=range_mapping, label=name, + **clutter_kwargs) + + def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_label="Tracks", + **kwargs): + """Plots track(s) + + Plots each track generated, generating a legend automatically. If ``uncertainty=True`` + error ellipses are plotted. + Tracks are plotted as solid lines with point markers and default colors. + + Users can change line style, color and marker using keyword arguments. + + Parameters + ---------- + tracks : Collection of :class:`~.Track` + Collection of tracks which will be plotted. If not a collection, and instead a single + :class:`~.Track` type, the argument is modified to be a set to allow for iteration. + mapping: list + List of items specifying the mapping of the position + components of the state space. + uncertainty : bool + If True, function plots uncertainty ellipses. + particle : bool + If True, function plots particles. + track_label: str + Label to apply to all tracks for legend. + \\*\\*kwargs: dict + Additional arguments to be passed to scatter function. Defaults are + ``mode='markers+lines'``. + """ + if uncertainty or particle: + raise NotImplementedError + + track_kwargs = dict(mode='markers+lines', legendrank=300) + merge(track_kwargs, kwargs) + angle_mapping = mapping[0] + if len(mapping) > 1: + range_mapping = mapping[1] + else: + range_mapping = None + self.plot_state_sequence(state_sequences=tracks, angle_mapping=angle_mapping, + range_mapping=range_mapping, label=track_label, **track_kwargs) + + def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs): + raise NotImplementedError + + +class _AnimationPlotterDataClass(Base): + plotting_data = Property(Iterable[State]) + plotting_label: str = Property() + plotting_keyword_arguments: dict = Property() + + +class AnimationPlotter(_Plotter): + + def __init__(self, dimension=Dimension.TWO, x_label: str = "$x$", y_label: str = "$y$", + title: str = None, legend_kwargs: dict = {}, **kwargs): + + self.figure_kwargs = {"figsize": (10, 6)} + self.figure_kwargs.update(kwargs) + if dimension != Dimension.TWO: + raise NotImplementedError + + self.legend_kwargs = dict() + self.legend_kwargs.update(legend_kwargs) + + self.x_label: str = x_label + self.y_label: str = y_label + + if title: + title += "\n" + self.title: str = title + + self.plotting_data: List[_AnimationPlotterDataClass] = [] + + self.animation_output: animation.FuncAnimation = None + + def run(self, + times_to_plot: List[datetime] = None, + plot_item_expiry: Optional[timedelta] = None, + **kwargs): + """Run the animation + + Parameters + ---------- + times_to_plot : List of :class:`~.datetime` + List of datetime objects of when to refresh and draw the animation. Default `None`, + where unique timestamps of data will be used. + plot_item_expiry: :class:`~.timedelta`, Optional + Describes how long states will remain present in the figure. 
Default value of None + means data is shown indefinitely + \\*\\*kwargs: dict + Additional arguments to be passed to the animation.FuncAnimation function + """ + if times_to_plot is None: + times_to_plot = sorted({ + state.timestamp + for plotting_data in self.plotting_data + for state in plotting_data.plotting_data}) + + self.animation_output = self.run_animation( + times_to_plot=times_to_plot, + data=self.plotting_data, + plot_item_expiry=plot_item_expiry, + x_label=self.x_label, + y_label=self.y_label, + figure_kwargs=self.figure_kwargs, + legend_kwargs=self.legend_kwargs, + animation_input_kwargs=kwargs, + plot_title=self.title + ) + return self.animation_output + + def save(self, filename='example.mp4', **kwargs): + """Save the animation + + Parameters + ---------- + filename : str + filename of animation file + \\*\\*kwargs: dict + Additional arguments to be passed to the animation.save function + """ + if self.animation_output is None: + raise ValueError("Animation hasn't been run yet. Therefore there is no animation to " + "save") + + self.animation_output.save(filename, **kwargs) + + def plot_ground_truths(self, truths, mapping: List[int], truths_label: str = "Ground Truth", + **kwargs): + """Plots ground truth(s) + + Plots each ground truth path passed in to :attr:`truths` and generates a legend + automatically. Ground truths are plotted as dashed lines with default colors. + + Users can change linestyle, color and marker using keyword arguments. Any changes + will apply to all ground truths. + + Parameters + ---------- + truths : Collection of :class:`~.GroundTruthPath` + Collection of ground truths which will be plotted. If not a collection and instead a + single :class:`~.GroundTruthPath` type, the argument is modified to be a set to allow + for iteration. + mapping: list + List of items specifying the mapping of the position components of the state space. + truths_label: str + Label for truth data. Default is "Ground Truth" + \\*\\*kwargs: dict + Additional arguments to be passed to plot function. Default is ``linestyle="--"``. + """ + + truths_kwargs = dict(linestyle="--") + truths_kwargs.update(kwargs) + self.plot_state_mutable_sequence(truths, mapping, truths_label, **truths_kwargs) + + def plot_tracks(self, tracks, mapping: List[int], uncertainty=False, particle=False, + track_label="Tracks", **kwargs): + """Plots track(s) + + Plots each track generated, generating a legend automatically. Tracks are plotted as solid + lines with point markers and default colors. Users can change linestyle, color and marker + using keyword arguments. + + Parameters + ---------- + tracks : Collection of :class:`~.Track` + Collection of tracks which will be plotted. If not a collection, and instead a single + :class:`~.Track` type, the argument is modified to be a set to allow for iteration. + mapping: list + List of items specifying the mapping of the position + components of the state space. + uncertainty : bool + Currently not implemented. If True, an error is raised + particle : bool + Currently not implemented. If True, an error is raised + track_label: str + Label to apply to all tracks for legend. + \\*\\*kwargs: dict + Additional arguments to be passed to plot function. Defaults are ``linestyle="-"``, + ``marker='s'`` for :class:`~.Update` and ``marker='o'`` for other states. 
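+
+        Example
+        -------
+        A minimal sketch of the intended workflow (``truths`` and ``tracks`` are assumed to
+        exist, and ``timedelta`` is ``datetime.timedelta``):
+
+        >>> plotter = AnimationPlotter(x_label="$x$ (m)", y_label="$y$ (m)")
+        >>> plotter.plot_ground_truths(truths, mapping=[0, 2])
+        >>> plotter.plot_tracks(tracks, mapping=[0, 2])
+        >>> plotter.run(plot_item_expiry=timedelta(seconds=10))
+        >>> plotter.save("animation.mp4")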
+ """ + if uncertainty or particle: + raise NotImplementedError + + tracks_kwargs = dict(linestyle='-', marker="s", color=None) + tracks_kwargs.update(kwargs) + self.plot_state_mutable_sequence(tracks, mapping, track_label, **tracks_kwargs) + + def plot_state_mutable_sequence(self, state_mutable_sequences, mapping: List[int], label: str, + **plotting_kwargs): + """Plots State Mutable Sequence + + Parameters + ---------- + state_mutable_sequences : Collection of :class:`~.StateMutableSequence` + Collection of states to be plotted + mapping: list + List of items specifying the mapping of the position components of the state space. + label : str + User-defined measurement model to be used in finding measurement state inverses if + they cannot be found from the measurements themselves. + \\*\\*kwargs: dict + Additional arguments to be passed to plot function for states. + """ + + if not isinstance(state_mutable_sequences, Collection) or \ + isinstance(state_mutable_sequences, StateMutableSequence): + state_mutable_sequences = {state_mutable_sequences} # Make a set of length 1 + + for idx, state_mutable_sequence in enumerate(state_mutable_sequences): + if idx == 0: + this_plotting_label = label + else: + this_plotting_label = None + + self.plotting_data.append(_AnimationPlotterDataClass( + plotting_data=[State(state_vector=[state.state_vector[mapping[0]], + state.state_vector[mapping[1]]], + timestamp=state.timestamp) + for state in state_mutable_sequence], + plotting_label=this_plotting_label, + plotting_keyword_arguments=plotting_kwargs + )) + + def plot_measurements(self, measurements, mapping, measurement_model=None, + measurements_label="", convert_measurements=True, **kwargs): + """Plots measurements + + Plots detections and clutter, generating a legend automatically. Detections are plotted as + blue circles by default unless the detection type is clutter. + If the detection type is :class:`~.Clutter` it is plotted as a yellow 'tri-up' marker. + + Users can change the color and marker of detections using keyword arguments but not for + clutter detections. + + Parameters + ---------- + measurements : Collection of :class:`~.Detection` + Detections which will be plotted. If measurements is a set of lists it is flattened. + mapping: list + List of items specifying the mapping of the position components of the state space. + measurement_model : :class:`~.Model`, optional + User-defined measurement model to be used in finding measurement state inverses if + they cannot be found from the measurements themselves. + measurements_label: str + Label for measurements. Default will be "Detections" or "Clutter" + convert_measurements: bool + Should the measurements be converted from measurement space to state space before + being plotted. Default is True + \\*\\*kwargs: dict + Additional arguments to be passed to plot function for detections. Defaults are + ``marker='o'`` and ``color='b'``. 
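+
+        Example
+        -------
+        A minimal sketch, assuming ``all_measurements`` is a list of per-timestep sets of
+        :class:`~.Detection` objects (nested sets are flattened automatically):
+
+        >>> plotter.plot_measurements(all_measurements, mapping=[0, 2])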
+ """ + + measurement_kwargs = dict(marker='o', color='b') + measurement_kwargs.update(kwargs) + + if not isinstance(measurements, Collection): + measurements = {measurements} # Make a set of length 1 + + if any(isinstance(item, set) for item in measurements): + measurements_set = chain.from_iterable(measurements) # Flatten into one set + else: + measurements_set = measurements + + plot_detections, plot_clutter = self._conv_measurements(measurements_set, + mapping, + measurement_model, + convert_measurements) + + if measurements_label != "": + measurements_label = measurements_label + " " + + if plot_detections: + detection_kwargs = dict(linestyle='', marker='o', color='b') + detection_kwargs.update(kwargs) + self.plotting_data.append(_AnimationPlotterDataClass( + plotting_data=[State(state_vector=plotting_state_vector, + timestamp=detection.timestamp) + for detection, plotting_state_vector in plot_detections.items()], + plotting_label=measurements_label + "Detections", + plotting_keyword_arguments=detection_kwargs + )) + + if plot_clutter: + clutter_kwargs = dict(linestyle='', marker='2', color='y') + clutter_kwargs.update(kwargs) + self.plotting_data.append(_AnimationPlotterDataClass( + plotting_data=[State(state_vector=plotting_state_vector, + timestamp=detection.timestamp) + for detection, plotting_state_vector in plot_clutter.items()], + plotting_label=measurements_label + "Clutter", + plotting_keyword_arguments=clutter_kwargs + )) + + def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs): + raise NotImplementedError + + @classmethod + def run_animation(cls, + times_to_plot: List[datetime], + data: Iterable[_AnimationPlotterDataClass], + plot_item_expiry: Optional[timedelta] = None, + axis_padding: float = 0.1, + figure_kwargs: dict = {}, + animation_input_kwargs: dict = {}, + legend_kwargs: dict = {}, + x_label: str = "$x$", + y_label: str = "$y$", + plot_title: str = None + ) -> animation.FuncAnimation: + """ + Parameters + ---------- + times_to_plot : Iterable[datetime] + All the times that the plotter should plot + data : Iterable[datetime] + All the data that should be plotted + plot_item_expiry: timedelta + How long a state should be displayed for. Default value of None + means data is shown indefinitely + axis_padding: float + How much extra space should be given around the edge of the plot + figure_kwargs: dict + Keyword arguments for the pyplot figure function. See matplotlib.pyplot.figure for more + details + animation_input_kwargs: dict + Keyword arguments for FuncAnimation class. See matplotlib.animation.FuncAnimation for + more details. Default values are: blit=False, repeat=False, interval=50 + legend_kwargs: dict + Keyword arguments for the pyplot legend function. 
See matplotlib.pyplot.legend for more + details + x_label: str + Label for the x axis + y_label: str + Label for the y axis + plot_title: str + Title for the plot + + Returns + ------- + : animation.FuncAnimation + Animation object + """ + + animation_kwargs = dict(blit=False, repeat=False, interval=50) # milliseconds + animation_kwargs.update(animation_input_kwargs) + + fig1 = plt.figure(**figure_kwargs) + + the_lines = [] + plotting_data = [] + legends_key = [] + + for a_plot_object in data: + if a_plot_object.plotting_data is not None: + the_data = np.array( + [a_state.state_vector for a_state in a_plot_object.plotting_data]) + if len(the_data) == 0: + continue + the_lines.append( + plt.plot([], # the_data[:1, 0], + [], # the_data[:1, 1], + **a_plot_object.plotting_keyword_arguments)[0]) + + legends_key.append(a_plot_object.plotting_label) + plotting_data.append(a_plot_object.plotting_data) + + if axis_padding: + [x_limits, y_limits] = [ + [min(state.state_vector[idx] for line in data for state in line.plotting_data), + max(state.state_vector[idx] for line in data for state in line.plotting_data)] + for idx in [0, 1]] + + for axis_limits in [x_limits, y_limits]: + limit_padding = axis_padding * (axis_limits[1] - axis_limits[0]) + # The casting to float to ensure the limits contain do not contain angle classes + axis_limits[0] = float(axis_limits[0] - limit_padding) + axis_limits[1] = float(axis_limits[1] + limit_padding) + + plt.xlim(x_limits) + plt.ylim(y_limits) + else: + plt.axis('equal') + + plt.xlabel(x_label) + plt.ylabel(y_label) + + lines_with_legend = [line for line, label in zip(the_lines, legends_key) + if label is not None] + plt.legend(lines_with_legend, [label for label in legends_key if label is not None], + **legend_kwargs) + + if plot_item_expiry is None: + min_plot_time = min(state.timestamp + for line in data + for state in line.plotting_data) + min_plot_times = [min_plot_time] * len(times_to_plot) + else: + min_plot_times = [time - plot_item_expiry for time in times_to_plot] + + line_ani = animation.FuncAnimation(fig1, cls.update_animation, + frames=len(times_to_plot), + fargs=(the_lines, plotting_data, min_plot_times, + times_to_plot, plot_title), + **animation_kwargs) + + plt.draw() + + return line_ani + + @staticmethod + def update_animation(index: int, lines: List[Line2D], data_list: List[List[State]], + start_times: List[datetime], end_times: List[datetime], title: str): + """ + Parameters + ---------- + index : int + Which index of the start_times and end_times should be used + lines : List[Line2D] + The data that will be plotted, to be plotted. 
+ data_list : List[List[State]] + All the data that should be plotted + start_times : List[datetime] + lowest (earliest) time for an item to be plotted + end_times : List[datetime] + highest (latest) time for an item to be plotted + title: str + Title for the plot + + Returns + ------- + : List[Line2D] + The data that will be plotted + """ + + min_time = start_times[index] + max_time = end_times[index] + + if title is None: + title = "" + plt.title(title + str(max_time)) + for i, data_source in enumerate(data_list): + + if data_source is not None: + the_data = np.array([a_state.state_vector for a_state in data_source + if min_time <= a_state.timestamp <= max_time]) + if the_data.size > 0: + lines[i].set_data(the_data[:, 0], + the_data[:, 1]) + else: + lines[i].set_data([], + []) + return lines + + +class AnimatedPlotterly(_Plotter): + """ + Class for a 2D animated plotter that uses Plotly graph objects rather than matplotlib. + This gives the user the ability to see how tracking works through time, while being + able to interact with tracks, truths, etc, in the same way that is enabled by + Plotly static plots. + + Simplifies the process of plotting ground truths, measurements, clutter, and tracks. + Tracks can be plotted with uncertainty ellipses or particles if required. Legends + are automatically generated with each plot. + + Parameters + ---------- + timesteps: Collection + Collection of equally-spaced timesteps. Each animation frame is a timestep. + tail_length: float + Percentage of sim time for which previous values will still be displayed for. + Value can be between 0 and 1. Default is 0.3. + equal_size: bool + Makes x and y axes equal when figure is resized. Default is False. + sim_duration: int + Time taken to run animation (s). Default is 6 + \\*\\*kwargs + Additional arguments to be passed in the initialisation. + + Attributes + ---------- + + """ + + def __init__(self, timesteps, tail_length=0.3, equal_size=False, + sim_duration=6, **kwargs): + """ + Initialise the figure and checks that inputs are correctly formatted. + Creates an empty frame for each timestep, and configures + the buttons and slider. + + + """ + if go is None or colors is None: + raise RuntimeError("Usage of Plotterly plotter requires installation of `plotly`") + + self.equal_size = equal_size + + # checking that there are multiple timesteps + if len(timesteps) < 2: + raise ValueError("Must be at least 2 timesteps for animation.") + + # checking that timesteps are evenly spaced + time_spaces = np.unique(np.diff(timesteps)) + + # gives the unique values of time gaps between timesteps. If this contains more than + # one value, then timesteps are not all evenly spaced which is an issue. + if len(time_spaces) != 1: + warnings.warn("Timesteps are not equally spaced, so the passage of time is not linear") + self.timesteps = timesteps + + # checking input to tail_length + if tail_length > 1 or tail_length < 0: + raise ValueError("Tail length should be between 0 and 1") + self.tail_length = tail_length + + # checking sim_duration + if sim_duration <= 0: + raise ValueError("Simulation duration must be positive") + + # time window is calculated as sim_length * tail_length. 
This is + # the window of time for which past plots are still visible + self.time_window = (timesteps[-1] - timesteps[0]) * tail_length + + self.colorway = colors.qualitative.Plotly[1:] # plotting colours + + self.all_masks = dict() # dictionary to be filled up later + + self.plotting_function_called = False # keeps track if anything has been plotted or not + # so that only the first data plotted will override the default axis max and mins. + + self.fig = go.Figure() + + layout_kwargs = dict( + xaxis=dict(title=dict(text="x")), + yaxis=dict(title=dict(text="y")), + colorway=self.colorway, # Needed to match colours later. + height=550, + autosize=True + ) + # layout_kwargs.update(kwargs) + self.fig.update_layout(layout_kwargs) + + # initialise frames according to simulation timesteps + self.fig.frames = [dict( + name=str(time), + data=[], + traces=[] + ) for time in timesteps] + + self.fig.update_xaxes(range=[0, 10]) + self.fig.update_yaxes(range=[0, 10]) + + frame_duration = sim_duration * 1000 / len(self.fig.frames) + + # if the gap between timesteps is greater than a day, it isn't necessary + # to display hour and minute information, so remove this to give a cleaner display. + # a and b are used in the slider steps label later + if time_spaces[0] >= timedelta(days=1): + start_cut_off = None + end_cut_off = 10 + + # if the simulation is over a day long, display all information which + # looks clunky but is necessary + elif timesteps[-1] - timesteps[0] > timedelta(days=1): + start_cut_off = None + end_cut_off = None + + # otherwise, remove day information and just show + # hours, mins, etc. which is cleaner to look at + else: + start_cut_off = 11 + end_cut_off = None + + # create button and slider + updatemenus = [dict(type='buttons', + buttons=[{ + "args": [None, + {"frame": {"duration": frame_duration, "redraw": True}, + "fromcurrent": True, "transition": {"duration": 0}}], + "label": "Play", + "method": "animate" + }, { + "args": [[None], {"frame": {"duration": 0, "redraw": True}, + "mode": "immediate", + "transition": {"duration": 0}}], + "label": "Stop", + "method": "animate" + }], + direction='left', + pad=dict(r=10, t=75), + showactive=True, x=0.1, y=0, xanchor='right', yanchor='top') + ] + sliders = [{'yanchor': 'top', + 'xanchor': 'left', + 'currentvalue': {'font': {'size': 16}, 'prefix': 'Time: ', 'visible': True, + 'xanchor': 'right'}, + 'transition': {'duration': frame_duration, 'easing': 'linear'}, + 'pad': {'b': 10, 't': 50}, + 'len': 0.9, 'x': 0.1, 'y': 0, + 'steps': [{'args': [[frame.name], { + 'frame': {'duration': 1.0, 'easing': 'linear', 'redraw': True}, + 'transition': {'duration': 0, 'easing': 'linear'}}], + 'label': frame.name[start_cut_off: end_cut_off], + 'method': 'animate'} for frame in + self.fig.frames + ]}] + self.fig.update_layout(updatemenus=updatemenus, sliders=sliders) + self.fig.update_layout(kwargs) + + def show(self): + """ + Display the animation. + """ + return self.fig + + def _resize(self, data, type="track"): + """ + Reshape figure so that everything is in view. + + Parameters + ---------- + + data: + Collection of values that are being added to the figure. + Will be a list if coming from plot_ground_Truths or + plot_tracks, but will be a dictionary if coming from plot_measurements. + """ + + # fill in all data. 
If there is no data, fill all_x, all_y with current axis limits + if not data: + all_x = list(self.fig.layout.xaxis.range) + all_y = list(self.fig.layout.xaxis.range) + else: + all_x = list() + all_y = list() + + # fill in data + if type == "measurements": + + for key, item in data.items(): + all_x.extend(data[key]["x"]) + all_y.extend(data[key]["y"]) + + elif type in ("ground_truth", "tracks"): + + for n, _ in enumerate(data): + all_x.extend(data[n]["x"]) + all_y.extend(data[n]["y"]) + + elif type == "sensor": + sensor_xy = np.array([sensor.position[[0, 1], 0] for sensor in data]) + all_x.extend(sensor_xy[:, 0]) + all_y.extend(sensor_xy[:, 1]) + + elif type == "particle_or_uncertainty": + # data comes in format of list of dictionaries. Each dictionary contains 'x' and 'y', + # which are a list of lists. + for dictionary in data: + for x_values in dictionary["x"]: + all_x.extend([np.nanmax(x_values), np.nanmin(x_values)]) + for y_values in dictionary["y"]: + all_y.extend([np.nanmax(y_values), np.nanmin(y_values)]) + + xmax = max(all_x) + ymax = max(all_y) + xmin = min(all_x) + ymin = min(all_y) + + if self.equal_size: + xmax = ymax = max(xmax, ymax) + xmin = ymin = min(xmin, ymin) + + # if it's first time plotting data, want to ensure plotter is bound to that data + # and not the default values. Issues arise if the initial plotted data is much + # smaller than the default 0 to 10 values. + if not self.plotting_function_called: + + self.fig.update_xaxes(range=[xmin, xmax]) + self.fig.update_yaxes(range=[ymin, ymax]) + + # need to check if it's actually necessary to resize or not + if xmax >= self.fig.layout.xaxis.range[1] or xmin <= self.fig.layout.xaxis.range[0]: + + xmax = max(xmax, self.fig.layout.xaxis.range[1]) + xmin = min(xmin, self.fig.layout.xaxis.range[0]) + xrange = xmax - xmin + + # update figure while adding a small buffer to the mins and maxes + self.fig.update_xaxes(range=[xmin - xrange / 20, xmax + xrange / 20]) + + if ymax >= self.fig.layout.yaxis.range[1] or ymin <= self.fig.layout.yaxis.range[0]: + + ymax = max(ymax, self.fig.layout.yaxis.range[1]) + ymin = min(ymin, self.fig.layout.yaxis.range[0]) + yrange = ymax - ymin + + self.fig.update_yaxes(range=[ymin - yrange / 20, ymax + yrange / 20]) + + def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", + resize=True, **kwargs): + + """Plots ground truth(s) + + Plots each ground truth path passed in to :attr:`truths` and generates a legend + automatically. Ground truths are plotted as dashed lines with default colors. + + Users can change linestyle, color and marker using keyword arguments. Any changes + will apply to all ground truths. + + Parameters + ---------- + truths : Collection of :class:`~.GroundTruthPath` + Collection of ground truths which will be plotted. If not a collection and instead a + single :class:`~.GroundTruthPath` type, the argument is modified to be a set to allow + for iteration. + mapping: list + List of items specifying the mapping of the position components of the state space. + truths_label: str + Name of ground truths in legend/plot + resize: bool + if True, will resize figure to ensure that ground truths are in view + \\*\\*kwargs: dict + Additional arguments to be passed to plot function. Default is ``linestyle="--"``. 
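+
+        Example
+        -------
+        A minimal sketch, assuming ``timesteps`` is the list of simulation times passed to
+        the constructor and ``truths`` is a set of :class:`~.GroundTruthPath` objects:
+
+        >>> plotter = AnimatedPlotterly(timesteps, tail_length=0.3)
+        >>> plotter.plot_ground_truths(truths, mapping=[0, 2])
+        >>> plotter.show()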
+ + """ + + if not isinstance(truths, Collection) or isinstance(truths, StateMutableSequence): + truths = {truths} # Make a set of length 1 + + data = [dict() for _ in truths] # put all data into one place for later plotting + for n, truth in enumerate(truths): + + # initialise arrays that go inside the dictionary + data[n].update(x=np.zeros(len(truth)), + y=np.zeros(len(truth)), + time=np.array([0 for _ in range(len(truth))], dtype=object), + time_str=np.array([0 for _ in range(len(truth))], dtype=object), + type=np.array([0 for _ in range(len(truth))], dtype=object)) + + for k, state in enumerate(truth): + # fill the arrays here + data[n]["x"][k] = state.state_vector[mapping[0]] + data[n]["y"][k] = state.state_vector[mapping[1]] + data[n]["time"][k] = state.timestamp + data[n]["time_str"][k] = str(state.timestamp) + data[n]["type"][k] = type(state).__name__ + + trace_base = len(self.fig.data) # number of traces currently in the animation + + # add a trace that keeps the legend up for the entire simulation (will remain + # even if no truths are present), then add a trace for each truth in the simulation. + # initialise keyword arguments, then add them to the traces + truth_kwargs = dict(x=[], y=[], mode="lines", hoverinfo='none', legendgroup=truths_label, + line=dict(dash="dash", color=self.colorway[0]), legendrank=100, + name=truths_label, showlegend=True) + merge(truth_kwargs, kwargs) + # legend dummy trace + self.fig.add_trace(go.Scatter(truth_kwargs)) + + # we don't want the legend for any of the actual traces + truth_kwargs.update({"showlegend": False}) + + for n, _ in enumerate(truths): + # change the colour of each truth and include n in its name + merge(truth_kwargs, dict(line=dict(color=self.colorway[n % len(self.colorway)]))) + merge(truth_kwargs, kwargs) + self.fig.add_trace(go.Scatter(truth_kwargs)) # add to traces + + for frame in self.fig.frames: + + # get current fig data and traces + data_ = list(frame.data) + traces_ = list(frame.traces) + + # convert string to datetime object + frame_time = datetime.fromisoformat(frame.name) + cutoff_time = (frame_time - self.time_window) + + # for the legend + data_.append(go.Scatter(x=[0, 0], y=[0, 0])) + traces_.append(trace_base) + + for n, truth in enumerate(truths): + # all truth points that come at or before the frame time + t_upper = [data[n]["time"] <= frame_time] + + # only select detections that come after the time cut-off + t_lower = [data[n]["time"] >= cutoff_time] + + # put together + mask = np.logical_and(t_upper, t_lower) + + # find x, y, time, and type + truth_x = data[n]["x"][tuple(mask)] + # add in np.inf to ensure traces are present for every timestep + truth_x = np.append(truth_x, [np.inf]) + truth_y = data[n]["y"][tuple(mask)] + truth_y = np.append(truth_y, [np.inf]) + times = data[n]["time_str"][tuple(mask)] + + data_.append(go.Scatter(x=truth_x, + y=truth_y, + meta=times, + hovertemplate='GroundTruthState' + + '
(%{x}, %{y})' + + '
Time: %{meta}')) + + traces_.append(trace_base + n + 1) # append data to correct trace + + frame.data = data_ + frame.traces = traces_ + + if resize: + self._resize(data, type="ground_truth") + + # we have called a plotting function so update flag (gets used in _resize) + self.plotting_function_called = True + + def plot_measurements(self, measurements, mapping, measurement_model=None, + resize=True, measurements_label="Measurements", + convert_measurements=True, **kwargs): + """Plots measurements + + Plots detections and clutter, generating a legend automatically. Detections are plotted as + blue circles by default unless the detection type is clutter. + If the detection type is :class:`~.Clutter` it is plotted as a yellow 'tri-up' marker. + + Users can change the color and marker of detections using keyword arguments but not for + clutter detections. + + Parameters + ---------- + measurements : Collection of :class:`~.Detection` + Detections which will be plotted. If measurements is a set of lists it is flattened. + mapping: list + List of items specifying the mapping of the position components of the state space. + measurement_model : :class:`~.Model`, optional + User-defined measurement model to be used in finding measurement state inverses if + they cannot be found from the measurements themselves. + resize: bool + If True, will resize figure to ensure measurements are in view + measurements_label : str + Label for the measurements. Default is "Measurements". + convert_measurements : bool + Should the measurements be converted from measurement space to state space before + being plotted. Default is True + \\*\\*kwargs: dict + Additional arguments to be passed to scatter function for detections. Defaults are + ``marker=dict(color="#636EFA")``. + """ + + if not isinstance(measurements, Collection): + measurements = {measurements} # Make a set of length 1 + + if any(isinstance(item, set) for item in measurements): + measurements_set = chain.from_iterable(measurements) # Flatten into one set + else: + measurements_set = measurements + plot_detections, plot_clutter = self._conv_measurements(measurements_set, + mapping, + measurement_model, + convert_measurements) + plot_combined = {'Detection': plot_detections, + 'Clutter': plot_clutter} # for later reference + + # this dictionary will store all the plotting data that we need + # from the detections and clutter into numpy arrays that we can easily + # access to plot + combined_data = dict() + + # only add clutter or detections to plot if necessary + if plot_detections: + combined_data.update(dict(Detection=dict())) + if plot_clutter: + combined_data.update(dict(Clutter=dict())) + + # initialise combined_data + for key in combined_data.keys(): + length = len(plot_combined[key]) + combined_data[key].update({ + "x": np.zeros(length), + "y": np.zeros(length), + "time": np.array([0 for _ in range(length)], dtype=object), + "time_str": np.array([0 for _ in range(length)], dtype=object), + "type": np.array([0 for _ in range(length)], dtype=object)}) + + # and now fill in the data + + for key in combined_data.keys(): + for n, det in enumerate(plot_combined[key]): + x, y = list(plot_combined[key].values())[n] + combined_data[key]["x"][n] = x + combined_data[key]["y"][n] = y + combined_data[key]["time"][n] = det.timestamp + combined_data[key]["time_str"][n] = str(det.timestamp) + combined_data[key]["type"][n] = type(det).__name__ + + # get number of traces currently in fig + trace_base = len(self.fig.data) + + # initialise detections + name = 
measurements_label + "<br>(Detections)"
+        measurement_kwargs = dict(x=[], y=[], mode='markers',
+                                  name=name,
+                                  legendgroup=name,
+                                  legendrank=200, showlegend=True,
+                                  marker=dict(color="#636EFA"), hoverinfo='none')
+        merge(measurement_kwargs, kwargs)
+
+        self.fig.add_trace(go.Scatter(measurement_kwargs))  # trace for legend
+
+        measurement_kwargs.update({"showlegend": False})
+        self.fig.add_trace(go.Scatter(measurement_kwargs))  # trace for plotting
+
+        # change necessary kwargs to initialise clutter trace
+        name = measurements_label + "<br>(Clutter)"
+        clutter_kwargs = dict(x=[], y=[], mode='markers',
+                              name=name,
+                              legendgroup=name,
+                              legendrank=300, showlegend=True,
+                              marker=dict(symbol="star-triangle-up", color='#FECB52'),
+                              hoverinfo='none')
+        merge(clutter_kwargs, kwargs)
+
+        self.fig.add_trace(go.Scatter(clutter_kwargs))  # trace for plotting clutter
+
+        # add data to frames
+        for frame in self.fig.frames:
+
+            data_ = list(frame.data)
+            traces_ = list(frame.traces)
+
+            # add blank data to ensure detection legend stays in place
+            data_.append(go.Scatter(x=[-np.inf, np.inf], y=[-np.inf, np.inf]))
+            traces_.append(trace_base)  # ensure data is added to correct trace
+
+            frame_time = datetime.fromisoformat(frame.name)  # convert string to datetime object
+
+            # time at which dets will disappear from the fig
+            cutoff_time = (frame_time - self.time_window)
+
+            for j, key in enumerate(combined_data.keys()):
+                # only select measurements that arrive by the time of the current frame
+                t_upper = [combined_data[key]["time"] <= frame_time]
+
+                # only select detections that come after the time cut-off
+                t_lower = [combined_data[key]["time"] >= cutoff_time]
+
+                # put them together to create the final mask
+                mask = np.logical_and(t_upper, t_lower)
+
+                # find x and y points for true detections and clutter
+                det_x = combined_data[key]["x"][tuple(mask)]
+                det_x = np.append(det_x, [np.inf])
+                det_y = combined_data[key]["y"][tuple(mask)]
+                det_y = np.append(det_y, [np.inf])
+                det_times = combined_data[key]["time_str"][tuple(mask)]
+
+                data_.append(go.Scatter(x=det_x,
+                                        y=det_y,
+                                        meta=det_times,
+                                        hovertemplate=f'{key}' +
+                                                      '<br>(%{x}, %{y})' +
+                                                      '
Time: %{meta}')) + traces_.append(trace_base + j + 1) + + frame.data = data_ # update the figure + frame.traces = traces_ + + if resize: + self._resize(combined_data, "measurements") + + # we have called a plotting function so update flag (gets used in resize) + self.plotting_function_called = True + + def plot_tracks(self, tracks, mapping, uncertainty=False, resize=True, + particle=False, plot_history=False, ellipse_points=30, + track_label="Tracks", **kwargs): + """ + Plots each track generated, generating a legend automatically. If 'uncertainty=True', + error ellipses are plotted. Tracks are plotted as solid lines with point markers + and default colours. + + Users can change linestyle, color, and marker using keyword arguments. Uncertainty metrics + will also be plotted with the user defined colour and any changes will apply to all tracks. + + Parameters + ---------- + tracks: Collection of :class '~Track' + Collection of tracks which will be plotted. If not a collection, and instead a single + :class:'~Track' type, the argument is modified to be a set to allow for iteration + + mapping: list + List of items specifying the mapping of the position + components of the state space + uncertainty: bool + If True, function plots uncertainty ellipses + resize: bool + If True, plotter will change bounds so that tracks are in view + particle: bool + If True, function plots particles + plot_history: bool + If true, plots all particles and uncertainty ellipses up to current time step + ellipse_points: int + Number of points for polygon approximating ellipse shape + track_label: str + Label to apply to all tracks for legend + \\*\\*kwargs: dict + Additional arguments to be passed to plot function. Defaults are ``linestyle="-"``, + ``marker='s'`` for :class:`~.Update` and ``marker='o'`` for other states. + + Returns + ------- + """ + + if not isinstance(tracks, Collection) or isinstance(tracks, StateMutableSequence): + tracks = {tracks} # Make a set of length 1 + + # So that we can plot tracks for both the current time and for some previous times, + # we put plotting data for each track into a dictionary so that it can be easily + # accessed later. + data = [dict() for _ in tracks] + + for n, track in enumerate(tracks): # sum up means - accounts for particle filter + + xydata = np.concatenate( + [(getattr(state, 'mean', state.state_vector)[mapping, :]) + for state in track], + axis=1) + + # initialise arrays that go inside the dictionary + data[n].update(x=xydata[0], + y=xydata[1], + time=np.array([0 for _ in range(len(track))], dtype=object), + time_str=np.array([0 for _ in range(len(track))], dtype=object), + type=np.array([0 for _ in range(len(track))], dtype=object)) + + for k, state in enumerate(track): + # fill the arrays here + data[n]["time"][k] = state.timestamp + data[n]["time_str"][k] = str(state.timestamp) + data[n]["type"][k] = type(state).__name__ + + trace_base = len(self.fig.data) # number of traces + + # add dummy trace for legend for track + + track_kwargs = dict(x=[], y=[], mode="markers+lines", line=dict(color=self.colorway[2]), + legendgroup=track_label, legendrank=400, name=track_label, + showlegend=True) + track_kwargs.update(kwargs) + self.fig.add_trace(go.Scatter(track_kwargs)) + + # and initialise traces for every track. 
+        # Need to change a few kwargs:
+        track_kwargs.update({'showlegend': False})
+
+        for k, _ in enumerate(tracks):
+            # update track colours
+            track_kwargs.update({'line': dict(color=self.colorway[(k + 2) % len(self.colorway)])})
+            track_kwargs.update(kwargs)
+            self.fig.add_trace(go.Scatter(track_kwargs))
+
+        for frame in self.fig.frames:
+            # get current fig data and traces
+            data_ = list(frame.data)
+            traces_ = list(frame.traces)
+
+            # convert string to datetime object
+            frame_time = datetime.fromisoformat(frame.name)
+
+            self.all_masks[frame_time] = dict()  # save mask for later use
+            cutoff_time = (frame_time - self.time_window)
+            # add blank data to ensure legend stays in place
+            data_.append(go.Scatter(x=[-np.inf, np.inf], y=[-np.inf, np.inf]))
+            traces_.append(trace_base)  # ensure data is added to correct trace
+
+            for n, track in enumerate(tracks):
+
+                # all track points that come at or before the frame time
+                t_upper = [data[n]["time"] <= frame_time]
+                # only select detections that come after the time cut-off
+                t_lower = [data[n]["time"] >= cutoff_time]
+
+                # put together
+                mask = np.logical_and(t_upper, t_lower)
+
+                # put into dictionary for later use
+                if plot_history:
+                    self.all_masks[frame_time][n] = np.logical_and(t_upper, t_lower)
+                else:
+                    self.all_masks[frame_time][n] = [data[n]["time"] == frame_time]
+
+                # find x, y, time, and type
+                track_x = data[n]["x"][tuple(mask)]
+                # add np.inf to plot so that the traces are present for entire simulation
+                track_x = np.append(track_x, [np.inf])
+
+                # repeat for y
+                track_y = data[n]["y"][tuple(mask)]
+                track_y = np.append(track_y, [np.inf])
+                track_type = data[n]["type"][tuple(mask)]
+                times = data[n]["time_str"][tuple(mask)]
+
+                data_.append(go.Scatter(x=track_x,  # plot track
+                                        y=track_y,
+                                        meta=track_type,
+                                        customdata=times,
+                                        hovertemplate='%{meta}' +
+                                        '<br>(%{x}, %{y})' +
+                                        '<br>Time: %{customdata}'))
+
+                traces_.append(trace_base + n + 1)  # add to correct trace
+
+            frame.data = data_
+            frame.traces = traces_
+
+        if resize:
+            self._resize(data, "tracks")
+
+        if uncertainty:  # plot ellipses
+            name = f'{track_label}<br>Uncertainty'
+            uncertainty_kwargs = dict(x=[], y=[], legendgroup=name, fill='toself',
+                                      fillcolor=self.colorway[2],
+                                      opacity=0.2, legendrank=500, name=name,
+                                      hoverinfo='skip',
+                                      mode='none', showlegend=True)
+            uncertainty_kwargs.update(kwargs)
+
+            # dummy trace for legend for uncertainty
+            self.fig.add_trace(go.Scatter(uncertainty_kwargs))
+
+            # and an uncertainty ellipse trace for each track
+            uncertainty_kwargs.update({'showlegend': False})
+            for k, _ in enumerate(tracks):
+                uncertainty_kwargs.update(
+                    {'fillcolor': self.colorway[(k + 2) % len(self.colorway)]})
+                uncertainty_kwargs.update(kwargs)
+                self.fig.add_trace(go.Scatter(uncertainty_kwargs))
+
+            # following function finds uncertainty data points and plots them
+            self._plot_particles_and_ellipses(tracks, mapping, resize, method="uncertainty")
+
+        if particle:  # plot particles
+
+            # initialise traces. One for legend and one per track
+            name = f'{track_label}<br>Particles'
+            particle_kwargs = dict(mode='markers', marker=dict(size=2, color=self.colorway[2]),
+                                   opacity=0.4,
+                                   hoverinfo='skip', legendgroup=name, name=name,
+                                   legendrank=520, showlegend=True)
+            # apply any keyword arguments
+            particle_kwargs.update(kwargs)
+            self.fig.add_trace(go.Scatter(particle_kwargs))  # legend trace
+
+            particle_kwargs.update({"showlegend": False})
+
+            for k, track in enumerate(tracks):  # trace for each track
+
+                particle_kwargs.update(
+                    {'marker': dict(size=2, color=self.colorway[(k + 2) % len(self.colorway)])})
+                particle_kwargs.update(kwargs)
+                self.fig.add_trace(go.Scatter(particle_kwargs))
+
+            self._plot_particles_and_ellipses(tracks, mapping, resize, method="particles")
+
+        # we have called a plotting function so update flag
+        self.plotting_function_called = True
+
+    def _plot_particles_and_ellipses(self, tracks, mapping, resize, method="uncertainty"):
+        """
+        The logic for plotting uncertainty ellipses and particles is nearly identical,
+        so it is put into one function.
+
+        Parameters
+        ----------
+        tracks: Collection of :class:`~.Track`
+            Collection of tracks which will be plotted. If not a collection, and instead a
+            single :class:`~.Track` type, the argument is modified to be a set to allow for
+            iteration
+        mapping: list
+            List of items specifying the mapping of the position components of the state space.
+        method: str
+            Either "uncertainty" or "particles", depending on what is being plotted.
+        """
+
+        data = [dict() for _ in tracks]
+        trace_base = len(self.fig.data)
+        for n, track in enumerate(tracks):
+
+            # initialise arrays that store particle/ellipse for later plotting
+            data[n].update(x=np.array([0 for _ in range(len(track))], dtype=object),
+                           y=np.array([0 for _ in range(len(track))], dtype=object))
+
+            for k, state in enumerate(track):
+
+                # find data points
+                if method == "uncertainty":
+
+                    data_x, data_y = Plotterly._generate_ellipse_points(state, mapping)
+                    data_x = list(data_x)
+                    data_y = list(data_y)
+                    data_x.append(np.nan)  # necessary to draw multiple ellipses at once
+                    data_y.append(np.nan)
+                    data[n]["x"][k] = data_x
+                    data[n]["y"][k] = data_y
+
+                elif method == "particles":
+
+                    data_xy = state.state_vector[mapping[:2], :]
+                    data[n]["x"][k] = data_xy[0]
+                    data[n]["y"][k] = data_xy[1]
+
+                else:
+                    raise ValueError("Method should be 'uncertainty' or 'particles'")
+
+        for frame in self.fig.frames:
+
+            frame_time = datetime.fromisoformat(frame.name)
+
+            data_ = list(frame.data)  # current data in frame
+            traces_ = list(frame.traces)  # current traces in frame
+
+            data_.append(go.Scatter(x=[-np.inf], y=[np.inf]))  # add empty data for legend trace
+            traces_.append(trace_base - len(tracks) - 1)  # ensure correct trace
+
+            for n, track in enumerate(tracks):
+                # now plot the data
+                _x = list(chain(*data[n]["x"][tuple(self.all_masks[frame_time][n])]))
+                _y = list(chain(*data[n]["y"][tuple(self.all_masks[frame_time][n])]))
+                _x.append(np.inf)
+                _y.append(np.inf)
+                data_.append(go.Scatter(x=_x, y=_y))
+                traces_.append(trace_base - len(tracks) + n)
+
+            frame.data = data_
+            frame.traces = traces_
+
+        if resize:
+            self._resize(data, type="particle_or_uncertainty")
+
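The `np.nan` values appended after each ellipse above are what allow a single Plotly trace to draw several disjoint, filled outlines at once: NaN breaks both the line and the fill. A minimal standalone sketch of that technique, using hypothetical data and independent of the patch itself:

import numpy as np
import plotly.graph_objects as go

# Two closed outlines drawn by one Scatter trace, separated by NaN.
theta = np.linspace(0, 2 * np.pi, 30)
x = np.concatenate([np.cos(theta), [np.nan], 2 + np.cos(theta)])
y = np.concatenate([np.sin(theta), [np.nan], 2 + np.sin(theta)])
go.Figure(go.Scatter(x=x, y=y, fill='toself', mode='none')).show()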
+    def plot_sensors(self, sensors, sensor_label="Sensors", resize=True, **kwargs):
+        """Plots sensor(s)
+
+        Plots sensors. Users can change the color and marker of sensors using keyword
+        arguments. Default is a black 'x' marker. Currently only works for stationary
+        sensors.
+
+        Parameters
+        ----------
+        sensors : Collection of :class:`~.Sensor`
+            Sensors to plot
+        sensor_label: str
+            Label to apply to all sensors for legend.
+        \\*\\*kwargs: dict
+            Additional arguments to be passed to scatter function for sensors. Defaults are
+            ``marker=dict(symbol='x', color='black')``.
+        """
+        if not isinstance(sensors, Collection):
+            sensors = {sensors}
+
+        # don't run any of this if there is no data input
+        if sensors:
+            trace_base = len(self.fig.data)  # number of traces currently in figure
+            sensor_kwargs = dict(mode='markers', marker=dict(symbol='x', color='black'),
+                                 legendgroup=sensor_label, legendrank=50,
+                                 name=sensor_label, showlegend=True)
+            merge(sensor_kwargs, kwargs)
+
+            self.fig.add_trace(go.Scatter(sensor_kwargs))  # initialises trace
+
+            # sensor position
+            sensor_xy = np.array([sensor.position[[0, 1], 0] for sensor in sensors])
+            if resize:
+                self._resize(sensors, "sensor")
+
+            for frame in self.fig.frames:  # the plotting bit
+                traces_ = list(frame.traces)
+                data_ = list(frame.data)
+
+                data_.append(go.Scatter(x=sensor_xy[:, 0], y=sensor_xy[:, 1]))
+                traces_.append(trace_base)
+
+                frame.traces = traces_
+                frame.data = data_
+
+        # we have called a plotting function so update flag (used in _resize)
+        self.plotting_function_called = True
diff --git a/stonesoup/updater/tests/test_kalman.py b/stonesoup/updater/tests/test_kalman.py
index 84717595f..b39301c54 100644
--- a/stonesoup/updater/tests/test_kalman.py
+++ b/stonesoup/updater/tests/test_kalman.py
@@ -272,8 +272,8 @@ def test_schmidtkalman():
 
     assert np.allclose(update.mean, sk_update.mean)
     assert np.allclose(update.covar, sk_update.covar)
-
-
+
+
 if __name__ == "__main__":
     import pytest
     pytest.main(['-v', __file__])
diff --git a/stonesoup/updater/tests/test_pointmass.py b/stonesoup/updater/tests/test_pointmass.py deleted file mode 100644 index c5009830b..000000000 --- a/stonesoup/updater/tests/test_pointmass.py +++ /dev/null @@ -1,135 +0,0 @@ -"""Test for updater.kalman module""" - -import pytest -import numpy as np - -from stonesoup.models.measurement.linear import LinearGaussian -from stonesoup.types.detection import Detection -from stonesoup.types.hypothesis import SingleHypothesis -from stonesoup.types.prediction import ( - GaussianStatePrediction, GaussianMeasurementPrediction, PointMassStatePrediction, PointMassMeasurementPrediction) -from stonesoup.types.state import GaussianState -from stonesoup.updater.pointMass import PointMassUpdater - -from stonesoup.models.transition.linear import KnownTurnRate -from datetime import datetime -from datetime import timedelta -from functions import gridCreation -from stonesoup.types.array import StateVectors -import time - - -@pytest.fixture(params=[PointMassUpdater]) - - -def updater(request): - updater_class = request.param - measurement_model = LinearGaussian( - ndim_state=2, mapping=[0], noise_covar=np.array([[0.04]])) - return updater_class(measurement_model) - -def test_pointmass(updater): - - time_difference = timedelta(days=0, hours=0, minutes=0, seconds=1) - - - # Initial condition - Gaussian - nx = 2 - meanX0 = np.array([36569, 55581]) # mean value - varX0 = np.diag([90, 160]) # variance - Npa = np.array([31, 31]) # 33 number of points per axis, for FFT must be ODD!!!!
- N = np.prod(Npa) # number of points - total - sFactor = 4 # scaling factor (number of sigmas covered by the grid) - - - [predGrid, predGridDelta, gridDimOld, xOld, Ppold] = gridCreation(np.vstack(meanX0),varX0,sFactor,nx,Npa) - - meanX0 = np.vstack(meanX0) - pom = predGrid-np.matlib.repmat(meanX0,1,N) - denominator = np.sqrt((2*np.pi)**nx)*np.linalg.det(varX0) - pompom = np.sum(-0.5*np.multiply(pom.T@np.inv(varX0),pom.T),1) #elementwise multiplication - pomexp = np.exp(pompom) - predDensityProb = pomexp/denominator # Adding probabilities to points - predDensityProb = predDensityProb/(sum(predDensityProb)*np.prod(predGridDelta)) - - - start_time = time.time() - - prediction = PointMassStatePrediction(state_vector=StateVectors(predGrid), - weight=predDensityProb, - grid_delta = predGridDelta, - grid_dim = gridDimOld, - center = xOld, - eigVec = Ppold, - Npa = Npa, - timestamp=start_time), - prediction = prediction[0] - measurement = Detection(np.array([[-6.23]])) - - measurement_model=updater.measurement_model - - # Calculate evaluation variables - eval_measurement_prediction = GaussianMeasurementPrediction( - measurement_model.matrix(time_difference) @ prediction.mean, - measurement_model.matrix(time_difference) @ prediction.covar - @ measurement_model.matrix(time_difference).T - + measurement_model.covar(time_difference), - cross_covar=prediction.covar @ measurement_model.matrix(time_difference).T) - kalman_gain = eval_measurement_prediction.cross_covar @ np.linalg.inv( - eval_measurement_prediction.covar) - eval_posterior = GaussianState( - prediction.mean - + kalman_gain @ (measurement.state_vector - - eval_measurement_prediction.mean), - prediction.covar - - kalman_gain@eval_measurement_prediction.covar @ kalman_gain.T) - - - # Get and assert measurement prediction - measurement_prediction = updater.predict_measurement(prediction) - assert np.allclose(measurement_prediction.mean, - eval_measurement_prediction.mean, - 0, atol=1.e-14) - assert np.allclose(measurement_prediction.covar, - eval_measurement_prediction.covar, - 0, atol=1.e-14) - assert np.allclose(measurement_prediction.cross_covar, - eval_measurement_prediction.cross_covar, - 0, atol=1.e-13) - - # Perform and assert state update (without measurement prediction) - posterior = updater.update(SingleHypothesis( - prediction=prediction, - measurement=measurement)) - assert np.allclose(posterior.mean, eval_posterior.mean, 0, atol=1.e-14) - assert np.allclose(posterior.covar, eval_posterior.covar, 0, atol=1.e-13) - assert np.array_equal(posterior.hypothesis.prediction, prediction) - assert np.allclose( - posterior.hypothesis.measurement_prediction.state_vector, - measurement_prediction.state_vector, 0, atol=1.e-14) - assert np.allclose(posterior.hypothesis.measurement_prediction.covar, - measurement_prediction.covar, 0, atol=1.e-14) - assert np.array_equal(posterior.hypothesis.measurement, measurement) - assert posterior.timestamp == prediction.timestamp - - # Perform and assert state update - posterior = updater.update(SingleHypothesis( - prediction=prediction, - measurement=measurement, - measurement_prediction=measurement_prediction)) - assert np.allclose(posterior.mean, eval_posterior.mean, 0, atol=1.e-14) - assert np.allclose(posterior.covar, eval_posterior.covar, 0, atol=1.e-13) - assert np.array_equal(posterior.hypothesis.prediction, prediction) - assert np.allclose( - posterior.hypothesis.measurement_prediction.state_vector, - measurement_prediction.state_vector, 0, atol=1.e-14) - assert 
np.allclose(posterior.hypothesis.measurement_prediction.covar, - measurement_prediction.covar, 0, atol=1.e-14) - assert np.array_equal(posterior.hypothesis.measurement, measurement) - assert posterior.timestamp == prediction.timestamp - - - -if __name__ == "__main__": - import pytest - pytest.main(['-v', __file__]) \ No newline at end of file

From 870ccb695f0069c97966769679580fd1e0ac125c Mon Sep 17 00:00:00 2001
From: pesslovany
Date: Sun, 30 Jun 2024 11:53:48 +0200
Subject: [PATCH 08/16] added test for updater points mass

---
 stonesoup/updater/tests/test_pointmass.py | 87 +++++++++++++++++++++++
 1 file changed, 87 insertions(+)
 create mode 100644 stonesoup/updater/tests/test_pointmass.py

diff --git a/stonesoup/updater/tests/test_pointmass.py b/stonesoup/updater/tests/test_pointmass.py
new file mode 100644
index 000000000..142862b43
--- /dev/null
+++ b/stonesoup/updater/tests/test_pointmass.py
@@ -0,0 +1,87 @@
+"""Test for updater.pointMass module"""
+
+import datetime
+
+import numpy as np
+import pytest
+from numpy.linalg import inv
+from stonesoup.functions import gridCreation
+from stonesoup.models.measurement.linear import LinearGaussian
+from stonesoup.types.array import StateVectors
+from stonesoup.types.detection import Detection
+from stonesoup.types.groundtruth import GroundTruthPath, GroundTruthState
+from stonesoup.types.hypothesis import SingleHypothesis
+from stonesoup.types.state import PointMassState
+from stonesoup.updater.pointMass import PointMassUpdater
+
+
+def test_pointmass():
+    start_time = datetime.datetime.now().replace(microsecond=0)
+    truth = GroundTruthPath(
+        [GroundTruthState([36569, 50, 55581, 50], timestamp=start_time)]
+    )
+    matrix = np.array(
+        [
+            [1, 0],
+            [0, 1],
+        ]
+    )
+    measurement_model = LinearGaussian(ndim_state=4, mapping=(0, 2), noise_covar=matrix)
+    measurements = []
+    measurement = measurement_model.function(truth, noise=True)
+    measurements.append(
+        Detection(
+            measurement, timestamp=truth.timestamp, measurement_model=measurement_model
+        )
+    )
+
+    # Initial condition - Gaussian
+    nx = 4
+    meanX0 = np.array([36569, 50, 55581, 50])  # mean value
+    varX0 = np.diag([90, 5, 160, 5])  # variance
+    Npa = np.array(
+        [31, 31, 27, 27]
+    )  # number of points per axis; must be odd for the FFT-based prediction
+    N = np.prod(Npa)  # number of points - total
+    sFactor = 4  # scaling factor (number of sigmas covered by the grid)
+
+    [predGrid, predGridDelta, gridDimOld, xOld, Ppold] = gridCreation(
+        np.vstack(meanX0), varX0, sFactor, nx, Npa
+    )
+
+    meanX0 = np.vstack(meanX0)
+    pom = predGrid - np.tile(meanX0, (1, N))  # np.tile replaces deprecated np.matlib.repmat
+    denominator = np.sqrt((2 * np.pi) ** nx) * np.linalg.det(varX0)
+    pompom = np.sum(
+        -0.5 * np.multiply(pom.T @ inv(varX0), pom.T), 1
+    )  # elementwise multiplication
+    pomexp = np.exp(pompom)
+    predDensityProb = pomexp / denominator  # Adding probabilities to points
+    predDensityProb = predDensityProb / (sum(predDensityProb) * np.prod(predGridDelta))
+
+    priorPMF = PointMassState(
+        state_vector=StateVectors(predGrid),
+        weight=predDensityProb,
+        grid_delta=predGridDelta,
+        grid_dim=gridDimOld,
+        center=xOld,
+        eigVec=Ppold,
+        Npa=Npa,
+        timestamp=start_time,
+    )
+    pmfUpdater = PointMassUpdater(measurement_model)
+    for measurement in measurements:
+        hypothesis = SingleHypothesis(priorPMF, measurement)
+        post = pmfUpdater.update(hypothesis)
+        assert np.all(post.state_vector == StateVectors(predGrid))
+        assert np.all(post.grid_delta == predGridDelta)
+        assert np.all(post.grid_dim == gridDimOld)
+        assert np.all(post.center == xOld)
+        assert np.all(post.eigVec == Ppold)
+        assert np.all(post.Npa == Npa)
+        assert post.timestamp == start_time
+        assert np.isclose(np.sum(post.weight * np.prod(post.grid_delta)), 1, 1e-2)
+
+
+if __name__ == "__main__":
+    pytest.main(["-v", __file__])
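The test above builds the prior grid weights by evaluating the Gaussian density by hand. Note that its `denominator` leaves `det(varX0)` outside the square root, but since that constant is the same at every grid point it cancels in the final renormalisation, so the weights come out correct. An equivalent, shorter construction, shown only as a cross-check and reusing the test's `predGrid`, `meanX0`, `varX0` and `predGridDelta`:

import numpy as np
from scipy.stats import multivariate_normal

# Evaluate the Gaussian at every grid point (columns of predGrid), then
# normalise so that sum(weights) * prod(grid steps) == 1.
w = multivariate_normal.pdf(predGrid.T, mean=np.ravel(meanX0), cov=varX0)
w = w / (np.sum(w) * np.prod(predGridDelta))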
From 8eb51da5b51ae594bfb45fda206bf282c682b48b Mon Sep 17 00:00:00 2001
From: pesslovany <43780444+pesslovany@users.noreply.github.com>
Date: Mon, 1 Jul 2024 14:19:37 +0200
Subject: [PATCH 09/16] Added test for predictor

---
 stonesoup/.DS_Store                         | Bin 12292 -> 12292 bytes
 stonesoup/functions/__init__.py             |   2 +-
 stonesoup/predictor/tests/test_pointmass.py |  79 ++++++++++++++++++++
 stonesoup/types/state.py                    |   1 +
 stonesoup/updater/.DS_Store                 | Bin 6148 -> 6148 bytes
 stonesoup/updater/tests/.DS_Store           | Bin 0 -> 6148 bytes
 6 files changed, 81 insertions(+), 1 deletion(-)
 create mode 100644 stonesoup/predictor/tests/test_pointmass.py
 create mode 100644 stonesoup/updater/tests/.DS_Store

diff --git a/stonesoup/.DS_Store b/stonesoup/.DS_Store
index 90155ce47b69117b5fe1dc7eb911a0d80fdaa9fc..2e5c42ea18768b5dd64976ed56cffe4b8ab6344a 100644
GIT binary patch (binary deltas omitted)

diff --git a/stonesoup/functions/__init__.py b/stonesoup/functions/__init__.py
index ef2319fa8..6ef9941ed 100644
--- a/stonesoup/functions/__init__.py
+++ b/stonesoup/functions/__init__.py
@@ -26,9 +26,9 @@ def gridCreation(xp_aux, Pp_aux, sFactor, nx, Npa):
         sortInd = np.argsort(sortInd)
         pom = np.sort(gridBound)
+        Ipom = np.argsort(gridBound)
         gridBound = pom[sortInd]
-        Ipom = np.argsort(gridBound)
         pom2 = eigVect[:, Ipom]
         eigVect = pom2[:, sortInd]
         gridDim = []  # Reset gridDim for each cycle
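The reordering in this hunk matters because the old code computed `Ipom` only after `gridBound` had been overwritten with re-ordered values, so the permutation no longer described the original axis order. In the simplest case, taking argsort of an already-sorted array just returns the identity. A small illustration in plain numpy:

import numpy as np

gridBound = np.array([3.0, 1.0, 2.0])
pom = np.sort(gridBound)      # [1. 2. 3.]
Ipom = np.argsort(gridBound)  # [1 2 0], permutation of the original order (fixed code)
np.argsort(pom)               # [0 1 2], the identity: the ordering information is lost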
diff --git a/stonesoup/predictor/tests/test_pointmass.py b/stonesoup/predictor/tests/test_pointmass.py
new file mode 100644
index 000000000..aacc3cb0d
--- /dev/null
+++ b/stonesoup/predictor/tests/test_pointmass.py
@@ -0,0 +1,79 @@
+"""Test for predictor.pointMass module"""
+
+import datetime
+from datetime import timedelta
+
+import numpy as np
+import pytest
+from numpy.linalg import inv
+from stonesoup.functions import gridCreation
+from stonesoup.models.transition.linear import KnownTurnRate
+from stonesoup.predictor.kalman import KalmanPredictor
+from stonesoup.predictor.pointMass import PointMassPredictor
+from stonesoup.types.array import StateVectors
+from stonesoup.types.state import GaussianState, PointMassState
+
+
+def test_pointmass():
+    start_time = datetime.datetime.now().replace(microsecond=0)
+    transition_model = KnownTurnRate(
+        turn_noise_diff_coeffs=[2, 2], turn_rate=np.deg2rad(30)
+    )
+    time_difference = timedelta(days=0, hours=0, minutes=0, seconds=1)
+
+    # Initial condition - Gaussian
+    nx = 4
+    meanX0 = np.array([36569, 50, 55581, 50])  # mean value
+    varX0 = np.diag([90, 5, 160, 5.1])  # variance
+    Npa = np.array(
+        [33, 33, 33, 33]
+    )  # number of points per axis; must be odd for the FFT-based prediction
+    N = np.prod(Npa)  # number of points - total
+    sFactor = 4  # scaling factor (number of sigmas covered by the grid)
+
+    [predGrid, predGridDelta, gridDimOld, xOld, Ppold] = gridCreation(
+        np.vstack(meanX0), varX0, sFactor, nx, Npa
+    )
+
+    predictorKF = KalmanPredictor(transition_model)
+    priorKF = GaussianState(meanX0, varX0, timestamp=start_time)
+    prediction = predictorKF.predict(priorKF, timestamp=start_time + time_difference)
+
+    meanX0 = np.vstack(meanX0)
+    pom = predGrid - np.tile(meanX0, (1, N))  # np.tile replaces deprecated np.matlib.repmat
+    denominator = np.sqrt((2 * np.pi) ** nx) * np.linalg.det(varX0)
+    pompom = np.sum(
+        -0.5 * np.multiply(pom.T @ inv(varX0), pom.T), 1
+    )  # elementwise multiplication
+    pomexp = np.exp(pompom)
+    predDensityProb = pomexp / denominator  # Adding probabilities to points
+    predDensityProb = predDensityProb / (sum(predDensityProb) * np.prod(predGridDelta))
+
+    priorPMF = PointMassState(
+        state_vector=StateVectors(predGrid),
+        weight=predDensityProb,
+        grid_delta=predGridDelta,
+        grid_dim=gridDimOld,
+        center=xOld,
+        eigVec=Ppold,
+        Npa=Npa,
+        timestamp=start_time,
+    )
+    pmfPredictor = PointMassPredictor(transition_model)
+    predictionPMF = pmfPredictor.predict(
+        priorPMF, timestamp=start_time + time_difference
+    )
+    assert np.allclose(predictionPMF.mean, np.ravel(prediction.mean), atol=1)
+    assert np.allclose(predictionPMF.covar(), prediction.covar, atol=2)
+    assert np.all(predictionPMF.Npa == Npa)
+    assert np.all(np.argsort(predictionPMF.grid_delta) == np.argsort(np.diag(varX0)))
+    assert np.allclose(
+        predictionPMF.center, transition_model.matrix(time_difference) @ xOld, atol=1e-1
+    )
+    assert np.isclose(
+        np.sum(predictionPMF.weight * np.prod(predictionPMF.grid_delta)), 1, atol=1e-1
+    )
+
+
+if __name__ == "__main__":
+    pytest.main(["-v", __file__])
diff --git a/stonesoup/types/state.py b/stonesoup/types/state.py
index fc47269a1..00f0df339 100644
--- a/stonesoup/types/state.py
+++ b/stonesoup/types/state.py
@@ -217,6 +217,7 @@ def covar(self):
         chip_ = self.state_vector - self.mean[:, np.newaxis]
         chip_w = chip_ * self.weight.reshape(1, -1, order="C")
         measVar = (chip_w @ chip_.T) * np.prod(self.grid_delta)
+        measVar = CovarianceMatrix(measVar)
         return measVar
 
diff --git a/stonesoup/updater/.DS_Store b/stonesoup/updater/.DS_Store
index d64b7b319ab02a5252d73c776ebfa1d1b5e21e5c..faf04181201f6d905be0ec048f05d10644cf4797 100644
GIT binary patch (binary deltas omitted)

diff --git a/stonesoup/updater/tests/.DS_Store b/stonesoup/updater/tests/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6
GIT binary patch (binary data omitted)
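Besides the tests, this patch wraps the result of `PointMassState.covar` in `CovarianceMatrix`. The computation in the hunk above is a weighted sample covariance over the grid, where the cell volume `prod(grid_delta)` turns density values into probability masses. A self-contained numeric sketch of the same arithmetic, on a hypothetical two-dimensional grid:

import numpy as np

grid = np.array([[0.0, 1.0, 2.0],
                 [0.0, 1.0, 4.0]])  # columns are grid points
w = np.array([0.2, 0.5, 0.3])       # density values; sum(w) * prod(delta) == 1
delta = np.array([1.0, 1.0])        # grid step per dimension

mean = grid @ w * np.prod(delta)          # weighted mean of the grid points
chip = grid - mean[:, np.newaxis]         # centred points
covar = (chip * w) @ chip.T * np.prod(delta)  # weighted covariance, as in covar()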
From: pesslovany <43780444+pesslovany@users.noreply.github.com>
Date: Mon, 1 Jul 2024 15:56:30 +0200
Subject: [PATCH 10/16] Added lines to tests. Reverted plotter.

---
 stonesoup/plotter.py                        |  6 +++---
 stonesoup/predictor/pointMass.py            |  7 ++-----
 stonesoup/predictor/tests/test_pointmass.py | 11 +++++++++++
 stonesoup/updater/pointMass.py              | 15 ++++++---------
 4 files changed, 22 insertions(+), 17 deletions(-)

diff --git a/stonesoup/plotter.py b/stonesoup/plotter.py
index 528597cce..7428411ec 100644
--- a/stonesoup/plotter.py
+++ b/stonesoup/plotter.py
@@ -1192,8 +1192,8 @@ def plot_measurements(self, measurements, mapping, measurement_model=None,
                                   name=name, legendgroup=name, legendrank=210)
         if self.dimension == 3:  # update - star-triangle-up not in 3d plotly
-            measurement_kwargs.update(dict(marker=dict(size=4, symbol="diamond",
-                                                       color='#FECB52')))
+            clutter_kwargs.update(dict(marker=dict(size=4, symbol="diamond",
+                                                   color='#FECB52')))
         merge(clutter_kwargs, kwargs)
 
         if clutter_kwargs['legendgroup'] not in {trace.legendgroup
@@ -3021,4 +3021,4 @@ def plot_sensors(self, sensors, sensor_label="Sensors", resize=True, **kwargs):
                 frame.data = data_
 
         # we have called a plotting function so update flag (used in _resize)
-        self.plotting_function_called = True
+        self.plotting_function_called = True
\ No newline at end of file
diff --git a/stonesoup/predictor/pointMass.py b/stonesoup/predictor/pointMass.py
index b4bdbf142..2f5d83980 100644
--- a/stonesoup/predictor/pointMass.py
+++ b/stonesoup/predictor/pointMass.py
@@ -40,15 +40,12 @@ def predict(self, prior, timestamp=None, **kwargs):
             The predicted state
         """
         # Compute time_interval
-        try:
-            time_interval = timestamp - prior.timestamp
-        except TypeError:
-            time_interval = None
+        time_interval = timestamp - prior.timestamp
 
         time_difference = timedelta(days=0, hours=0, minutes=0, seconds=0)
 
         if time_interval == time_difference:
             predGrid = (prior.state_vector,)
-            predDensityProb = prior.weight # SLOW LINE
+            predDensityProb = prior.weight
             GridDelta = prior.grid_delta
             gridDimOld = prior.grid_dim
             xOld = prior.center
diff --git a/stonesoup/predictor/tests/test_pointmass.py b/stonesoup/predictor/tests/test_pointmass.py
index aacc3cb0d..d33d796bb 100644
--- a/stonesoup/predictor/tests/test_pointmass.py
+++ b/stonesoup/predictor/tests/test_pointmass.py
@@ -63,6 +63,9 @@ def test_pointmass():
     predictionPMF = pmfPredictor.predict(
         priorPMF, timestamp=start_time + time_difference
     )
+    predictionPMFnoTime = pmfPredictor.predict(
+        priorPMF, timestamp=start_time
+    )
     assert np.allclose(predictionPMF.mean, np.ravel(prediction.mean), atol=1)
     assert np.allclose(predictionPMF.covar(), prediction.covar, atol=2)
     assert np.all(predictionPMF.Npa == Npa)
@@ -73,6 +76,14 @@ def test_pointmass():
     assert np.isclose(
         np.sum(predictionPMF.weight * np.prod(predictionPMF.grid_delta)), 1, atol=1e-1
     )
+    assert np.all(priorPMF.state_vector == predictionPMFnoTime.state_vector)
+    assert np.all(priorPMF.weight == predictionPMFnoTime.weight)
+    assert np.all(priorPMF.grid_delta == predictionPMFnoTime.grid_delta)
+    assert np.all(priorPMF.grid_dim == predictionPMFnoTime.grid_dim)
+    assert np.all(priorPMF.center == predictionPMFnoTime.center)
+    assert np.all(priorPMF.eigVec == predictionPMFnoTime.eigVec)
+    assert np.all(priorPMF.Npa == predictionPMFnoTime.Npa)
+    assert np.all(priorPMF.timestamp == predictionPMFnoTime.timestamp)
 
 
 if __name__ == "__main__":
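With the try/except removed, `predict` now assumes both `prior.timestamp` and the `timestamp` argument are datetimes, and the new `predictionPMFnoTime` assertions exercise the zero-interval branch. A simplified, runnable sketch of that guard, with hypothetical timestamps standing in for the state objects:

from datetime import datetime, timedelta

prior_timestamp = datetime(2024, 7, 1, 12, 0, 0)
timestamp = datetime(2024, 7, 1, 12, 0, 0)  # same instant: nothing to propagate
if timestamp - prior_timestamp == timedelta(seconds=0):
    # the prior grid, weights and grid geometry pass through unchanged
    print("zero interval: prior returned as the prediction")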
diff --git a/stonesoup/updater/pointMass.py b/stonesoup/updater/pointMass.py
index 6ef1cb371..6c8bc12a8 100644
--- a/stonesoup/updater/pointMass.py
+++ b/stonesoup/updater/pointMass.py
@@ -53,7 +53,7 @@ def __init__(self, *args, **kwargs):
 
     # @profile
     def update(self, hypothesis, **kwargs):
-        """Particle Filter update step
+        """Point mass update step
 
         Parameters
         ----------
@@ -73,24 +73,21 @@ def update(self, hypothesis, **kwargs):
             timestamp=hypothesis.prediction.timestamp,
         )
 
-        if hypothesis.measurement.measurement_model is None:
-            measurement_model = self.measurement_model
-        else:
-            measurement_model = hypothesis.measurement.measurement_model
+        measurement_model = hypothesis.measurement.measurement_model
 
-        R = measurement_model.covar()
+        R = measurement_model.covar() # Noise
         x = measurement_model.function(
             predicted_state
-        )
+        ) # State to measurement space
         pdf_value = multivariate_normal.pdf(
             x.T, np.ravel(hypothesis.measurement.state_vector), R
-        )
+        ) # likelihood
         new_weight = np.ravel(hypothesis.prediction.weight) * np.ravel(pdf_value)
 
         new_weight = new_weight / (
             np.prod(hypothesis.prediction.grid_delta) * sum(new_weight)
-        )
+        ) # Normalization
 
         predicted_state = PointMassState(
             state_vector=hypothesis.prediction.state_vector,

From 7af6e21c50e6c2e82e1ba1c1a400599bc02da39e Mon Sep 17 00:00:00 2001
From: pesslovany <43780444+pesslovany@users.noreply.github.com>
Date: Mon, 1 Jul 2024 15:59:48 +0200
Subject: [PATCH 11/16] Fixed flake8

---
 stonesoup/plotter.py           | 2 +-
 stonesoup/updater/pointMass.py | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/stonesoup/plotter.py b/stonesoup/plotter.py
index 7428411ec..26ae3b07e 100644
--- a/stonesoup/plotter.py
+++ b/stonesoup/plotter.py
@@ -3021,4 +3021,4 @@ def plot_sensors(self, sensors, sensor_label="Sensors", resize=True, **kwargs):
                 frame.data = data_
 
         # we have called a plotting function so update flag (used in _resize)
-        self.plotting_function_called = True
\ No newline at end of file
+        self.plotting_function_called = True
diff --git a/stonesoup/updater/pointMass.py b/stonesoup/updater/pointMass.py
index 6c8bc12a8..34d4df38a 100644
--- a/stonesoup/updater/pointMass.py
+++ b/stonesoup/updater/pointMass.py
@@ -75,19 +75,19 @@ def update(self, hypothesis, **kwargs):
 
         measurement_model = hypothesis.measurement.measurement_model
 
-        R = measurement_model.covar() # Noise
+        R = measurement_model.covar()  # Noise
         x = measurement_model.function(
             predicted_state
-        ) # State to measurement space
+        )  # State to measurement space
         pdf_value = multivariate_normal.pdf(
             x.T, np.ravel(hypothesis.measurement.state_vector), R
-        ) # likelihood
+        )  # likelihood
         new_weight = np.ravel(hypothesis.prediction.weight) * np.ravel(pdf_value)
 
         new_weight = new_weight / (
             np.prod(hypothesis.prediction.grid_delta) * sum(new_weight)
-        ) # Normalization
+        )  # Normalization
 
         predicted_state = PointMassState(
             state_vector=hypothesis.prediction.state_vector,
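Taken together, the update step refined over these two patches is a plain grid-based Bayes rule: each weight is multiplied by the measurement likelihood at its grid point and the result is renormalised so the density integrates to one over the grid. A one-dimensional numeric sketch of the same computation, assuming an identity measurement model and not part of the patch itself:

import numpy as np
from scipy.stats import multivariate_normal

grid = np.linspace(-5.0, 5.0, 11)[np.newaxis, :]  # 1 x N grid of points
delta = np.array([1.0])                           # grid step
w = np.full(11, 1.0 / 11.0)                       # prior density values on the grid
z, R = 0.5, np.array([[0.04]])                    # measurement and noise covariance

like = multivariate_normal.pdf(grid.T, mean=[z], cov=R)  # likelihood at each point
w = w * like                                             # prior times likelihood
w = w / (np.prod(delta) * np.sum(w))                     # sum(w) * prod(delta) == 1 again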
From 99001283fd7988605ba65a6411a608c75158cf5e Mon Sep 17 00:00:00 2001
From: pesslovany <43780444+pesslovany@users.noreply.github.com>
Date: Tue, 2 Jul 2024 09:43:08 +0200
Subject: [PATCH 12/16] Added documentation

---
 stonesoup/functions/__init__.py  | 31 +++++++++++++++++
 stonesoup/predictor/pointMass.py |  8 ++---
 stonesoup/types/state.py         |  4 +++
 stonesoup/updater/pointMass.py   | 59 ++++----------------------
 4 files changed, 45 insertions(+), 57 deletions(-)

diff --git a/stonesoup/functions/__init__.py b/stonesoup/functions/__init__.py
index 6ef9941ed..29fba8bd6 100644
--- a/stonesoup/functions/__init__.py
+++ b/stonesoup/functions/__init__.py
@@ -14,6 +14,37 @@ def gridCreation(xp_aux, Pp_aux, sFactor, nx, Npa):
+    """Grid for the point-mass filter
+
+    Create a PMF grid based on the center, the covariance matrix, and the number
+    of sigmas to cover.
+
+    Parameters
+    ----------
+    xp_aux : numpy.ndarray
+        `nx` by `1` center of the grid
+    Pp_aux : numpy.ndarray
+        `nx` by `nx` covariance matrix
+    sFactor : int
+        Scaling parameter for the size of the grid
+    nx : int
+        Dimension of the grid
+    Npa : numpy.ndarray
+        `nx` by `1` number of points per axis of the grid
+
+    Returns
+    -------
+    predGrid : numpy.ndarray
+        `nx` by `prod(Npa)` predictive grid
+    predGridDelta : list
+        grid step per dimension
+    gridDim : list of numpy.ndarrays
+        grid coordinates per dimension before rotation and translation
+    xp_aux : numpy.ndarray
+        grid center
+    eigVect : numpy.ndarray
+        eigenvectors describing the rotation of the grid
+
+    """
     gridDim = np.zeros((nx, Npa[0]))
     gridStep = np.zeros(nx)
     eigVal, eigVect = LA.eig(
diff --git a/stonesoup/predictor/pointMass.py b/stonesoup/predictor/pointMass.py
index 2f5d83980..dcb7ba59b 100644
--- a/stonesoup/predictor/pointMass.py
+++ b/stonesoup/predictor/pointMass.py
@@ -23,16 +23,16 @@ class PointMassPredictor(Predictor):
     sFactor: float = Property(default=4, doc="How many sigma to cover by the grid")
 
     # @profile
-    def predict(self, prior, timestamp=None, **kwargs):
-        """Particle Filter prediction step
+    def predict(self, prior, timestamp=1, **kwargs):
+        """Point Mass Filter prediction step
 
         Parameters
         ----------
-        prior : :class:`~.ParticleState`
+        prior : :class:`~.PointMassState`
             A prior state object
         timestamp: :class:`datetime.datetime`, optional
             A timestamp signifying when the prediction is performed
-            (the default is `None`)
+            (the default is ``1``)
 
         Returns
         -------
diff --git a/stonesoup/types/state.py b/stonesoup/types/state.py
index 00f0df339..7cc2e5777 100644
--- a/stonesoup/types/state.py
+++ b/stonesoup/types/state.py
@@ -184,6 +184,10 @@ def from_state(
 
 
 class PointMassState(State):
+    """PointMassState type
+
+    State type for the Lagrangian point-mass filter.
+    """
 
     state_vector: StateVectors = Property(doc="State vectors.")
     weight: MutableSequence[Probability] = Property(
diff --git a/stonesoup/updater/pointMass.py b/stonesoup/updater/pointMass.py
index 34d4df38a..c49b0094d 100644
--- a/stonesoup/updater/pointMass.py
+++ b/stonesoup/updater/pointMass.py
@@ -1,52 +1,19 @@
-from functools import lru_cache
-from typing import Callable
+
 
 import numpy as np
 from scipy.stats import multivariate_normal
 from stonesoup.types.state import PointMassState
 
 from ..base import Property
-from ..regulariser import Regulariser
-from ..resampler import Resampler
-from ..types.prediction import (
-    MeasurementPrediction,
-)
 from ..types.update import Update
 from .base import Updater
 
 
 class PointMassUpdater(Updater):
-    """Particle Updater
+    """Point mass Updater
 
-    Perform an update by multiplying particle weights by PDF of measurement
-    model (either :attr:`~.Detection.measurement_model` or
-    :attr:`measurement_model`), and normalising the weights. If provided, a
-    :attr:`resampler` will be used to take a new sample of particles (this is
-    called every time, and it's up to the resampler to decide if resampling is
-    required).
+    Perform an update by multiplying grid point weights by the PDF of the
+    measurement model
     """
-
-    sFactor: float = Property(default=3, doc="How many sigma to cover by the grid")
-    resampler: Resampler = Property(
-        default=None, doc="Resampler to prevent particle degeneracy"
-    )
-    regulariser: Regulariser = Property(
-        default=None,
-        doc="Regulariser to prevent particle impoverishment. The regulariser "
-        "is normally used after resampling. If a :class:`~.Resampler` is defined, "
-        "then regularisation will only take place if the particles have been "
-        "resampled. If the :class:`~.Resampler` is not defined but a "
-        ":class:`~.Regulariser` is, then regularisation will be conducted under the "
-        "assumption that the user intends for this to occur.",
-    )
-
-    constraint_func: Callable = Property(
-        default=None,
-        doc="Callable, user defined function for applying "
-        "constraints to the states. This is done by setting the weights "
-        "of particles to 0 for particles that are not correctly constrained. "
-        "This function provides indices of the unconstrained particles and "
-        "should accept a :class:`~.ParticleState` object and return an array-like "
-        "object of logical indices. ",
-    )
+    sFactor: float = Property(default=4, doc="How many sigma to cover by the grid")
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -63,7 +30,7 @@ def update(self, hypothesis, **kwargs):
 
         Returns
         -------
-        : :class:`~.ParticleState`
+        : :class:`~.PointMassState`
             The state posterior
         """
 
@@ -101,17 +68,3 @@ def update(self, hypothesis, **kwargs):
         )
 
         return predicted_state
-
-    @lru_cache()
-    def predict_measurement(self, state_prediction, measurement_model=None, **kwargs):
-
-        if measurement_model is None:
-            measurement_model = self.measurement_model
-
-        new_state_vector = measurement_model.function(state_prediction, **kwargs)
-
-        return MeasurementPrediction.from_state(
-            state_prediction,
-            state_vector=new_state_vector,
-            timestamp=state_prediction.timestamp,
-        )

From e3933a79eec1da0c5fad3e761cc2d1842d48f370 Mon Sep 17 00:00:00 2001
From: pesslovany <43780444+pesslovany@users.noreply.github.com>
Date: Tue, 2 Jul 2024 09:55:23 +0200
Subject: [PATCH 13/16] Removing unwanted files

---
 .DS_Store                              | Bin 6148 -> 0 bytes
 .gitignore                             |   3 +
 01_terrain_aided_navigation.py         | 309 -------------------------
 02_ParticleFilter.py                   | 284 -----------------------
 docs/.DS_Store                         | Bin 6148 -> 0 bytes
 docs/examples/.DS_Store                | Bin 8196 -> 0 bytes
 docs/source/.DS_Store                  | Bin 6148 -> 0 bytes
 docs/tutorials/.DS_Store               | Bin 6148 -> 0 bytes
 pokus1.py                              |  45 ----
 stonesoup/.DS_Store                    | Bin 12292 -> 0 bytes
 stonesoup/dataassociator/.DS_Store     | Bin 6148 -> 0 bytes
 stonesoup/functions/.DS_Store          | Bin 6148 -> 0 bytes
 stonesoup/gater/.DS_Store              | Bin 6148 -> 0 bytes
 stonesoup/hypothesiser/.DS_Store       | Bin 6148 -> 0 bytes
 stonesoup/initiator/.DS_Store          | Bin 6148 -> 0 bytes
 stonesoup/measures/.DS_Store           | Bin 6148 -> 0 bytes
 stonesoup/models/.DS_Store             | Bin 6148 -> 0 bytes
 stonesoup/models/measurement/.DS_Store | Bin 6148 -> 0 bytes
 stonesoup/movable/.DS_Store            | Bin 6148 -> 0 bytes
 stonesoup/platform/.DS_Store           | Bin 6148 -> 0 bytes
 stonesoup/predictor/.DS_Store          | Bin 6148 -> 0 bytes
 stonesoup/regulariser/.DS_Store        | Bin 6148 -> 0 bytes
 stonesoup/resampler/.DS_Store          | Bin 6148 -> 0 bytes
 stonesoup/sampler/.DS_Store            | Bin 6148 -> 0 bytes
 stonesoup/sensor/.DS_Store             | Bin 6148 -> 0 bytes
 stonesoup/tracker/.DS_Store            | Bin 6148 -> 0 bytes
 stonesoup/types/.DS_Store              | Bin 6148 -> 0 bytes
 stonesoup/updater/.DS_Store            | Bin 6148 -> 0 bytes
 stonesoup/updater/tests/.DS_Store      | Bin 6148 -> 0 bytes
 terrain_aided_navigation.py            | 309 -------------------------
 terrain_aided_navigation.py.lprof      | Bin 979 -> 0 bytes
 31 files changed, 3 insertions(+), 947 deletions(-)
 delete mode 100644 .DS_Store
 delete mode 100644 01_terrain_aided_navigation.py
 delete mode 100644 02_ParticleFilter.py
 delete mode 100644 docs/.DS_Store
 delete mode 100644 docs/examples/.DS_Store
 delete mode 100644 docs/source/.DS_Store
 delete mode 100644 docs/tutorials/.DS_Store
 delete mode 100644 pokus1.py
 delete mode 100644 stonesoup/.DS_Store
 delete mode 100644 stonesoup/dataassociator/.DS_Store
 delete mode 100644 stonesoup/functions/.DS_Store
 delete mode 100644 stonesoup/gater/.DS_Store
 delete mode 100644 stonesoup/hypothesiser/.DS_Store
 delete mode 100644 stonesoup/initiator/.DS_Store
 delete mode 100644 stonesoup/measures/.DS_Store
 delete mode 100644 stonesoup/models/.DS_Store
 delete mode 100644 stonesoup/models/measurement/.DS_Store
 delete mode 100644 stonesoup/movable/.DS_Store
 delete mode 100644 stonesoup/platform/.DS_Store
 delete mode 100644 stonesoup/predictor/.DS_Store
 delete mode 100644 stonesoup/regulariser/.DS_Store
 delete mode 100644 stonesoup/resampler/.DS_Store
 delete mode 100644 stonesoup/sampler/.DS_Store
 delete mode 100644 stonesoup/sensor/.DS_Store
 delete mode 100644 stonesoup/tracker/.DS_Store
 delete mode 100644 stonesoup/types/.DS_Store
 delete mode 100644 stonesoup/updater/.DS_Store
 delete mode 100644 stonesoup/updater/tests/.DS_Store
 delete mode 100644 terrain_aided_navigation.py
 delete mode 100644 terrain_aided_navigation.py.lprof

diff --git a/.DS_Store b/.DS_Store
deleted file mode 100644
index 08a30397bf998b9bd0d3fa6910467d2626341266..0000000000000000000000000000000000000000
GIT binary patch (binary data omitted)

diff --git a/.gitignore b/.gitignore
index bcad6d2a9..7c5a556e1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,3 +18,6 @@ htmlcov/
 
 #Animation plotting example
 docs/examples/plotting/example_animation.gif
+
+#Mac files
+.DS_Store
diff --git a/01_terrain_aided_navigation.py b/01_terrain_aided_navigation.py deleted file mode 100644 index d4c06d405..000000000 --- a/01_terrain_aided_navigation.py +++ /dev/null @@ -1,309 +0,0 @@ -#!/usr/bin/env python - -# ===================================== -# 4 - Sampling methods: particle filter -# ===================================== -# """ - - - -# %% -# -# Nearly-constant velocity example -# -------------------------------- -# We continue in the same vein as the previous tutorials. -# -# Ground truth -# ^^^^^^^^^^^^ -# Import the necessary libraries -import numpy as np -import matplotlib.pyplot as plt -import time - -from datetime import datetime -from datetime import timedelta - - -# Initialise Stone Soup ground-truth and transition models.
-from stonesoup.models.transition.linear import CombinedLinearGaussianTransitionModel, \ - ConstantVelocity -from stonesoup.models.transition.linear import KnownTurnRate -from stonesoup.types.groundtruth import GroundTruthPath, GroundTruthState -from stonesoup.types.detection import Detection -from stonesoup.models.measurement.nonlinear import TerrainAidedNavigation -from stonesoup.models.measurement.linear import LinearGaussian -from scipy.interpolate import RegularGridInterpolator -from stonesoup.predictor.particle import ParticlePredictor -from stonesoup.resampler.particle import ESSResampler -from stonesoup.resampler.particle import MultinomialResampler -from stonesoup.updater.particle import ParticleUpdater -from stonesoup.functions import gridCreation -from numpy.linalg import inv -from stonesoup.types.state import PointMassState -from stonesoup.types.hypothesis import SingleHypothesis -from stonesoup.types.track import Track -from stonesoup.types.state import GaussianState - -from stonesoup.predictor.pointMass import PointMassPredictor -from stonesoup.updater.pointMass import PointMassUpdater -from scipy.stats import multivariate_normal - -from stonesoup.predictor.kalman import KalmanPredictor -from stonesoup.updater.kalman import KalmanUpdater - -from stonesoup.types.numeric import Probability # Similar to a float type -from stonesoup.types.state import ParticleState -from stonesoup.types.array import StateVectors -import json - - - -# Initialize arrays to store RMSE values -matrixTruePMF = [] -matrixTruePF = [] -matrixTrueKF = [] -MC = 10 -for mc in range(0,MC): - print(mc) - start_time = datetime.now().replace(microsecond=0) - - # %% - - #np.random.seed(1991) - - # %% - - - transition_model = KnownTurnRate(turn_noise_diff_coeffs = [2, 2], turn_rate = np.deg2rad(30)) - - # This needs to be done in other way - time_difference = timedelta(days=0, hours=0, minutes=0, seconds=1) - - - timesteps = [start_time] - truth = GroundTruthPath([GroundTruthState([36569, 50, 55581, 50], timestamp=start_time)]) - - # %% - # Create the truth path - for k in range(1, 20): - timesteps.append(start_time+timedelta(seconds=k)) - truth.append(GroundTruthState( - transition_model.function(truth[k-1], noise=True, time_interval=timedelta(seconds=1)), - timestamp=timesteps[k])) - - - # %% - # Initialise the bearing, range sensor using the appropriate measurement model. 
- - # Open the JSON file - with open('/Users/matoujak/Desktop/file.json', 'r') as file: - # Load JSON data - data = json.load(file) - - map_x = data['x'] - map_y = data['y'] - map_z = data['z'] - - map_x = np.array(map_x) - map_y = np.array(map_y) - map_z = np.matrix(map_z) - - - interpolator = RegularGridInterpolator((map_x[0,:],map_y[:,0]), map_z) - - - - measurement_model = TerrainAidedNavigation(interpolator,noise_covar = 1, mapping=(0, 2)) - # matrix = np.array([ - # [1, 0], - # [0, 1], - # ]) - # measurement_model = LinearGaussian(ndim_state = 4, mapping = (0, 2), noise_covar = matrix) - - # %% - # Populate the measurement array - measurements = [] - for state in truth: - measurement = measurement_model.function(state, noise=True) - measurements.append(Detection(measurement, timestamp=state.timestamp, - measurement_model=measurement_model)) - - - - predictor = ParticlePredictor(transition_model) - resampler = MultinomialResampler() - updater = ParticleUpdater(measurement_model, resampler) - - - predictorKF = KalmanPredictor(transition_model) - updaterKF = KalmanUpdater(measurement_model) - - - - # %% - # Initialise a prior - # ^^^^^^^^^^^^^^^^^^ - # To start we create a prior estimate. This is a :class:`~.ParticleState` which describes - # the state as a distribution of particles using :class:`~.StateVectors` and weights. - # This is sampled from the Gaussian distribution (using the same parameters we - # had in the previous examples). - - number_particles = 10000 - - # Sample from the prior Gaussian distribution - samples = multivariate_normal.rvs(np.array([36569, 50, 55581, 50]), - np.diag([90, 5, 160, 5]), - size=number_particles) - - # Create prior particle state. - prior = ParticleState(state_vector=StateVectors(samples.T), - weight=np.array([Probability(1/number_particles)]*number_particles), - timestamp=start_time) - - priorKF = GaussianState([[36569], [50], [55581], [50]], np.diag([90, 5, 160, 5]), timestamp=start_time) - - # %% PMF prior - - pmfPredictor = PointMassPredictor(transition_model) - pmfUpdater = PointMassUpdater(measurement_model) - # Initial condition - Gaussian - nx = 4 - meanX0 = np.array([36569, 50, 55581, 50]) # mean value - varX0 = np.diag([90, 5, 160, 5]) # variance - Npa = np.array([31, 31, 27, 27]) # 33 number of points per axis, for FFT must be ODD!!!! 
- N = np.prod(Npa) # number of points - total - sFactor = 4 # scaling factor (number of sigmas covered by the grid) - - - [predGrid, predGridDelta, gridDimOld, xOld, Ppold] = gridCreation(np.vstack(meanX0),varX0,sFactor,nx,Npa) - - meanX0 = np.vstack(meanX0) - pom = predGrid-np.matlib.repmat(meanX0,1,N) - denominator = np.sqrt((2*np.pi)**nx)*np.linalg.det(varX0) - pompom = np.sum(-0.5*np.multiply(pom.T@inv(varX0),pom.T),1) #elementwise multiplication - pomexp = np.exp(pompom) - predDensityProb = pomexp/denominator # Adding probabilities to points - predDensityProb = predDensityProb/(sum(predDensityProb)*np.prod(predGridDelta)) - - priorPMF = PointMassState(state_vector=StateVectors(predGrid), - weight=predDensityProb, - grid_delta = predGridDelta, - grid_dim = gridDimOld, - center = xOld, - eigVec = Ppold, - Npa = Npa, - timestamp=start_time) - - F = transition_model.matrix(prior=prior, time_interval=time_difference) - Q = transition_model.covar(time_interval=time_difference) - - - - priorPMF = PointMassState(state_vector=StateVectors(predGrid), - weight=predDensityProb, - grid_delta = predGridDelta, - grid_dim = gridDimOld, - center = xOld, - eigVec = Ppold, - Npa = Npa, - timestamp=start_time) - - - matrixPMF = [] - - start_time = time.time() - track = Track() - for measurement in measurements: - prediction = pmfPredictor.predict(priorPMF, timestamp=measurement.timestamp) - hypothesis = SingleHypothesis(prediction, measurement) - post = pmfUpdater.update(hypothesis) - priorPMF = post - matrixPMF.append(post.mean) - # print(post.mean) - - # Record the end time - end_time = time.time() - - # Calculate the elapsed time - # print(end_time - start_time) - - - # matrixKF = [] - - # start_time = time.time() - # track = Track() - # for measurement in measurements: - # prediction = predictorKF.predict(priorKF, timestamp=measurement.timestamp) - # hypothesis = SingleHypothesis(prediction, measurement) - # post = updaterKF.update(hypothesis) - # priorKF = post - # matrixKF.append(post.mean) - # # print(post.mean) - - # # Record the end time - # end_time = time.time() - - # %% - # Run the tracker - # ^^^^^^^^^^^^^^^ - # We now run the predict and update steps, propagating the collection of particles and resampling - # when told to (at every step). - - matrixPF = [] - start_time = time.time() - track = Track() - for measurement in measurements: - prediction = predictor.predict(prior, timestamp=measurement.timestamp) - hypothesis = SingleHypothesis(prediction, measurement) - post = updater.update(hypothesis) - # print(post.mean) - track.append(post) - matrixPF.append(post.mean) - prior = track[-1] - - # Record the end time - end_time = time.time() - - # Calculate the elapsed time - # print(end_time - start_time) - - - for ind in range(0,20): - matrixTruePMF.append(np.ravel(np.vstack(matrixPMF[ind])-truth.states[ind].state_vector)) - matrixTruePF.append(np.ravel(matrixPF[ind]-truth.states[ind].state_vector)) - # matrixTrueKF.append(np.ravel(matrixKF[ind]-truth.states[ind].state_vector)) - - -def rmse(errors): - """ - Calculate the Root Mean Square Error (RMSE) from a list of errors. - - Args: - errors (list): List of errors. - - Returns: - float: RMSE value. 
- """ - # Convert the list of errors into a numpy array for easier computation - errors_array = np.array(errors) - - # Square the errors - squared_errors = np.square(errors_array) - - # Calculate the mean squared error - mean_squared_error = np.mean(squared_errors,0) - - # Calculate the root mean squared error - rmse_value = np.sqrt(mean_squared_error) - - return rmse_value - - -print(rmse(matrixTruePF)) -print(rmse(matrixTruePMF)) -# print(rmse(matrixTrueKF)) - - - - diff --git a/02_ParticleFilter.py b/02_ParticleFilter.py deleted file mode 100644 index 020aa578c..000000000 --- a/02_ParticleFilter.py +++ /dev/null @@ -1,284 +0,0 @@ -#!/usr/bin/env python - -""" -===================================== -4 - Sampling methods: particle filter -===================================== -""" - - - -# %% -# -# Nearly-constant velocity example -# -------------------------------- -# We continue in the same vein as the previous tutorials. -# -# Ground truth -# ^^^^^^^^^^^^ -# Import the necessary libraries - -import numpy as np -import matplotlib.pyplot as plt -import time - -from datetime import datetime -from datetime import timedelta - - -from stonesoup.functions import gridCreation -from numpy.linalg import inv -from stonesoup.types.state import PointMassState -from stonesoup.types.hypothesis import SingleHypothesis -from stonesoup.types.track import Track - -from stonesoup.predictor.pointMass import PointMassPredictor -from stonesoup.updater.pointMass import PointMassUpdater - -from stonesoup.models.transition.linear import CombinedLinearGaussianTransitionModel, \ - ConstantVelocity -from stonesoup.types.groundtruth import GroundTruthPath, GroundTruthState -from stonesoup.models.measurement.nonlinear import CartesianToBearingRange -from stonesoup.types.detection import Detection - - -from stonesoup.predictor.particle import ParticlePredictor - -from stonesoup.resampler.particle import ESSResampler - -from stonesoup.updater.particle import ParticleUpdater - - -from scipy.stats import multivariate_normal - -from stonesoup.types.numeric import Probability # Similar to a float type -from stonesoup.types.state import ParticleState -from stonesoup.types.array import StateVectors - - - -timePMF = [] -timePF = [] - -# %% - -#np.random.seed(1991) - -# %% -# Initialise Stone Soup ground-truth and transition models. - -kf = 10 -matrixTruePMF = [] -matrixTruePF = [] -MC = 10 -for mc in range(0,MC): - print(mc) - start_time = datetime.now().replace(microsecond=0) - transition_model = CombinedLinearGaussianTransitionModel([ConstantVelocity(1), - ConstantVelocity(1)]) - - # This needs to be done in other way - time_difference = timedelta(days=0, hours=0, minutes=0, seconds=1) - - - timesteps = [start_time] - truth = GroundTruthPath([GroundTruthState([48, 1, -5, 1], timestamp=start_time)]) - - # %% - # Create the truth path - for k in range(1, kf): - timesteps.append(start_time+timedelta(seconds=k)) - truth.append(GroundTruthState( - transition_model.function(truth[k-1], noise=True, time_interval=timedelta(seconds=1)), - timestamp=timesteps[k])) - - - # %% - # Initialise the bearing, range sensor using the appropriate measurement model. 
- - - sensor_x = 50 - sensor_y = 0 - - measurement_model = CartesianToBearingRange( - ndim_state=4, - mapping=(0, 2), - noise_covar=np.diag([np.radians(0.1), 0.1]), - translation_offset=np.array([[sensor_x], [sensor_y]]) - ) - - # %% - # Populate the measurement array - measurements = [] - for state in truth: - measurement = measurement_model.function(state, noise=True) - measurements.append(Detection(measurement, timestamp=state.timestamp, - measurement_model=measurement_model)) - - - - resampler = ESSResampler() - updater = ParticleUpdater(measurement_model, resampler) - predictor = ParticlePredictor(transition_model) - updater = ParticleUpdater(measurement_model, resampler) - # %% - # Initialise a prior - # ^^^^^^^^^^^^^^^^^^ - # To start we create a prior estimate. This is a :class:`~.ParticleState` which describes - # the state as a distribution of particles using :class:`~.StateVectors` and weights. - # This is sampled from the Gaussian distribution (using the same parameters we - # had in the previous examples). - - - number_particles = 130000 - - # Sample from the prior Gaussian distribution - samples = multivariate_normal.rvs(np.array([48, 1, -5, 1]), - np.diag([1, 0.5, 1, 0.5]), - size=number_particles) - - # Create prior particle state. - prior = ParticleState(state_vector=StateVectors(samples.T), - weight=np.array([Probability(1/number_particles)]*number_particles), - timestamp=start_time) - - # %% PMF prior - - pmfPredictor = PointMassPredictor(transition_model) - pmfUpdater = PointMassUpdater(measurement_model) - # Initial condition - Gaussian - nx = 4 - meanX0 = np.array([48, 1, -5, 1]) # mean value - varX0 = np.diag([1, 0.5, 1, 0.5]) # variance - Npa = np.array([19, 19, 19, 19]) # number of points per axis, for FFT must be ODD!!!! - N = np.prod(Npa) # number of points - total - sFactor = 4 # scaling factor (number of sigmas covered by the grid) - - - [predGrid, predGridDelta, gridDimOld, xOld, Ppold] = gridCreation(np.vstack(meanX0),varX0,sFactor,nx,Npa) - - meanX0 = np.vstack(meanX0) - pom = predGrid-np.matlib.repmat(meanX0,1,N) - denominator = np.sqrt((2*np.pi)**nx)*np.linalg.det(varX0) - pompom = np.sum(-0.5*np.multiply(pom.T@inv(varX0),pom.T),1) #elementwise multiplication - pomexp = np.exp(pompom) - predDensityProb = pomexp/denominator # Adding probabilities to points - predDensityProb = predDensityProb/(sum(predDensityProb)*np.prod(predGridDelta)) - - priorPMF = PointMassState(state_vector=StateVectors(predGrid), - weight=predDensityProb, - grid_delta = predGridDelta, - grid_dim = gridDimOld, - center = xOld, - eigVec = Ppold, - Npa = Npa, - timestamp=start_time) - - F = transition_model.matrix(prior=prior, time_interval=time_difference) - Q = transition_model.covar(time_interval=time_difference) - - FqF = np.linalg.inv(F)@Q@np.linalg.inv(F.T) - - - priorPMF = PointMassState(state_vector=StateVectors(predGrid), - weight=predDensityProb, - grid_delta = predGridDelta, - grid_dim = gridDimOld, - center = xOld, - eigVec = Ppold, - Npa = Npa, - timestamp=start_time) - - - matrixPMF = [] - - start_time = time.time() - track = Track() - for measurement in measurements: - prediction = pmfPredictor.predict(priorPMF, timestamp=measurement.timestamp) - hypothesis = SingleHypothesis(prediction, measurement) - post = pmfUpdater.update(hypothesis) - priorPMF = post - - matrixPMF.append(post.mean) - #print(post.mean) - - # Record the end time - end_time = time.time() - - # Calculate the elapsed time - timePMF.append(end_time - start_time) - - - # %% - # Run the tracker - # 
^^^^^^^^^^^^^^^ - # We now run the predict and update steps, propagating the collection of particles and resampling - # when told to (at every step). - - matrixPF = [] - start_time = time.time() - track = Track() - for measurement in measurements: - prediction = predictor.predict(prior, timestamp=measurement.timestamp) - hypothesis = SingleHypothesis(prediction, measurement) - post = updater.update(hypothesis) - #print(post.mean) - track.append(post) - matrixPF.append(post.mean) - prior = track[-1] - - # Record the end time - end_time = time.time() - - # Calculate the elapsed time - timePF.append(end_time - start_time) - - - for ind in range(0,kf): - matrixTruePF.append(np.ravel(matrixPF[ind]-truth.states[ind].state_vector)) - # print(np.vstack(matrixPF[ind])) - - for ind in range(0,kf): - matrixTruePMF.append(np.ravel(np.vstack(matrixPMF[ind])-truth.states[ind].state_vector)) - # print(np.vstack(matrixPMF[ind])) - - - # for ind in range(0,kf): - # print(truth.states[ind].state_vector) - - - -def rmse(errors): - """ - Calculate the Root Mean Square Error (RMSE) from a list of errors. - - Args: - errors (list): List of errors. - - Returns: - float: RMSE value. - """ - # Convert the list of errors into a numpy array for easier computation - errors_array = np.array(errors) - - # Square the errors - squared_errors = np.square(errors_array) - - # Calculate the mean squared error - mean_squared_error = np.mean(squared_errors) - - # Calculate the root mean squared error - rmse_value = np.sqrt(mean_squared_error) - - return rmse_value - - -print('PF rmse', rmse(matrixTruePF)) -print('PMF rmse', rmse(matrixTruePMF)) -print('PF s', np.mean(timePF)) -print('PMF s', np.mean(timePMF))
diff --git a/docs/.DS_Store b/docs/.DS_Store
deleted file mode 100644
index 28ed39947384e21cde2abf87863c6a86232d68a9..0000000000000000000000000000000000000000
GIT binary patch (binary data omitted)

diff --git a/docs/examples/.DS_Store b/docs/examples/.DS_Store
deleted file mode 100644
index 20f8592cc74c91040103ae54139faa141aefc235..0000000000000000000000000000000000000000
GIT binary patch (binary data omitted)

diff --git a/docs/source/.DS_Store b/docs/source/.DS_Store
deleted file mode 100644
index f7bfcf8ee86195c1907d789a9377b6e378ddb3d0..0000000000000000000000000000000000000000
GIT binary patch (binary data omitted)
diff --git a/docs/.DS_Store b/docs/.DS_Store
deleted file mode 100644
index 28ed39947384e21cde2abf87863c6a86232d68a9..0000000000000000000000000000000000000000
[binary .DS_Store patch data omitted]
diff --git a/docs/examples/.DS_Store b/docs/examples/.DS_Store
deleted file mode 100644
index 20f8592cc74c91040103ae54139faa141aefc235..0000000000000000000000000000000000000000
[binary .DS_Store patch data omitted]
diff --git a/docs/source/.DS_Store b/docs/source/.DS_Store
deleted file mode 100644
index f7bfcf8ee86195c1907d789a9377b6e378ddb3d0..0000000000000000000000000000000000000000
[binary .DS_Store patch data omitted]
diff --git a/pokus1.py b/pokus1.py
deleted file mode 100644
index 1f2bca21c..000000000
--- a/pokus1.py
+++ /dev/null
@@ -1,45 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-Created on Tue Apr 2 08:53:07 2024
-
-@author: matoujak
-"""
-
-from stonesoup.types.state import PointMassState
-from stonesoup.functions import gridCreationFFT
-import numpy as np
-from numpy.linalg import inv
-from stonesoup.types.array import StateVectors
-from datetime import datetime
-import numpy.matlib
-start_time = datetime.now().replace(microsecond=0)
-
-
-# Initial condition - Gaussian
-nx = 2
-meanX0 = np.array([20, 5]) # mean value
-varX0 = np.array([[0.1, 0], [0, 0.1]]) # variance
-Npa = np.array([21, 21]) # number of points per axis, for FFT must be ODD!!!!
-N = np.prod(Npa) # number of points - total
-sFactor = 4 # scaling factor (number of sigmas covered by the grid)
-
-
-[predGrid, gridDimOld, predGridDelta] = gridCreationFFT(np.vstack(meanX0),varX0,sFactor,nx,Npa)
-
-meanX0 = np.vstack(meanX0)
-pom = predGrid-np.matlib.repmat(meanX0,1,N)
-denominator = np.sqrt((2*np.pi)**nx)*np.linalg.det(varX0)
-pompom = np.sum(-0.5*np.multiply(pom.T@inv(varX0),pom.T),1) #elementwise multiplication
-pomexp = np.exp(pompom)
-predDensityProb = pomexp/denominator # Adding probabilities to points
-predDensityProb = predDensityProb/(sum(predDensityProb)*np.prod(predGridDelta))
-
-np.hstack(predGrid @ predDensityProb * np.prod(predGridDelta))
-
-prior = PointMassState(state_vector=StateVectors(predGrid),
-                       weight=predDensityProb,
-                       grid_delta = predGridDelta,
-                       timestamp=start_time)
-
-a = prior.covar
\ No newline at end of file
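The deleted pokus1.py builds its grid prior by evaluating the Gaussian density by hand (the pom/pompom/pomexp block). The same computation can be sketched with scipy, assuming an axis-aligned 21-point-per-axis grid like the one gridCreationFFT produces; this is an illustration, not the Stone Soup API:

import numpy as np
from scipy.stats import multivariate_normal

# Axis-aligned grid covering 4 sigmas around the mean, 21 points per axis
mean = np.array([20.0, 5.0])
cov = np.diag([0.1, 0.1])
axes = [np.linspace(m - 4*np.sqrt(v), m + 4*np.sqrt(v), 21)
        for m, v in zip(mean, np.diag(cov))]
delta = np.array([a[1] - a[0] for a in axes])        # grid step per axis
grid = np.array(np.meshgrid(*axes)).reshape(2, -1)   # 2 x N matrix of grid points

# Gaussian density on the grid, renormalised so it integrates to one over the grid
pdf = multivariate_normal.pdf(grid.T, mean=mean, cov=cov)
pdf /= pdf.sum() * np.prod(delta)

print(grid @ pdf * np.prod(delta))   # discrete mean; should be close to `mean`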
diff --git a/stonesoup/.DS_Store b/stonesoup/.DS_Store
deleted file mode 100644
index 2e5c42ea18768b5dd64976ed56cffe4b8ab6344a..0000000000000000000000000000000000000000
[binary .DS_Store patch data omitted]
diff --git a/stonesoup/gater/.DS_Store b/stonesoup/gater/.DS_Store
deleted file mode 100644
index b0033fdaa6638caa70cb43a004108816e15632c1..0000000000000000000000000000000000000000
[binary .DS_Store patch data omitted]
diff --git a/stonesoup/initiator/.DS_Store b/stonesoup/initiator/.DS_Store
deleted file mode 100644
index ae5c9b07d20ce2e2d58f1ceb4febc88360ad15ac..0000000000000000000000000000000000000000
[binary .DS_Store patch data omitted]
diff --git a/stonesoup/measures/.DS_Store b/stonesoup/measures/.DS_Store
deleted file mode 100644
index 7c425774bb63cfd36857d176fc1bd12a2b5aa329..0000000000000000000000000000000000000000
[binary .DS_Store patch data omitted]
diff --git a/stonesoup/models/.DS_Store b/stonesoup/models/.DS_Store
deleted file mode 100644
index f5a947452f8f91f00a2df45ffd298f9322bfc5b8..0000000000000000000000000000000000000000
[binary .DS_Store patch data omitted]
diff --git a/stonesoup/models/measurement/.DS_Store b/stonesoup/models/measurement/.DS_Store
deleted file mode 100644
index 5cf8e9813069d86bfb75238fec42626d2b6d6ff8..0000000000000000000000000000000000000000
[binary .DS_Store patch data omitted]
diff --git a/stonesoup/movable/.DS_Store b/stonesoup/movable/.DS_Store
deleted file mode 100644
index 70344b36f1ef73e1e28bd626cdb1c6336f8bf9fc..0000000000000000000000000000000000000000
[binary .DS_Store patch data omitted]
diff --git a/stonesoup/platform/.DS_Store b/stonesoup/platform/.DS_Store
deleted file mode 100644
index 9a53a118d848d81354a3dbc853f5701d6c18c5d5..0000000000000000000000000000000000000000
[binary .DS_Store patch data omitted]
diff --git a/stonesoup/predictor/.DS_Store b/stonesoup/predictor/.DS_Store
deleted file mode 100644
index bf511d63fcf003677aa416fe974c96266d529912..0000000000000000000000000000000000000000
[binary .DS_Store patch data omitted]
diff --git a/stonesoup/resampler/.DS_Store b/stonesoup/resampler/.DS_Store
deleted file mode 100644
index a059a4de56d06a9fa7a14b0a2294b323f3cf4b9f..0000000000000000000000000000000000000000
[binary .DS_Store patch data omitted]
diff --git a/stonesoup/tracker/.DS_Store b/stonesoup/tracker/.DS_Store
deleted file mode 100644
index 7647a2d276261c7e59843107c654da772a39e269..0000000000000000000000000000000000000000
[binary .DS_Store patch data omitted]
diff --git a/stonesoup/updater/.DS_Store b/stonesoup/updater/.DS_Store
deleted file mode 100644
index faf04181201f6d905be0ec048f05d10644cf4797..0000000000000000000000000000000000000000
[binary .DS_Store patch data omitted]
diff --git a/stonesoup/updater/tests/.DS_Store b/stonesoup/updater/tests/.DS_Store
deleted file mode 100644
index 5008ddfcf53c02e82d7eee2e57c38e5672ef89f6..0000000000000000000000000000000000000000
[binary .DS_Store patch data omitted]

From: pesslovany <43780444+pesslovany@users.noreply.github.com>
Date: Wed, 30 Oct 2024 16:06:15 +0100
Subject: [PATCH 14/16] Changes based on review. Making the code work again and
 fixing pep8 issues.
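As context for the functions.py helpers this commit deletes (the full file follows below): pmfUpdateFFT computes the point-mass time update as a discrete convolution of the filtering density with a stationary Gaussian transition kernel, evaluated with scipy.signal.fftconvolve. A minimal 1-D sketch of that idea, with made-up grid and noise parameters rather than anything from the deleted code:

import numpy as np
from scipy.signal import fftconvolve

x = np.linspace(-5, 5, 101)                    # filtering grid
dx = x[1] - x[0]
filt = np.exp(-0.5 * x**2)                     # unnormalised filtering density
filt /= filt.sum() * dx                        # normalise over the grid

q = 0.5                                        # process-noise variance (made up)
kernel = np.exp(-0.5 * x**2 / q) / np.sqrt(2 * np.pi * q)

# Time update: (filtering density * grid step) convolved with the transition
# kernel, then renormalised; the same steps pmfUpdateFFT performs on 2-D grids
pred = fftconvolve(filt * dx, kernel, mode='same')
pred /= pred.sum() * dx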
--- functions.py | 199 ---------------------------- stonesoup/functions/__init__.py | 9 +- stonesoup/predictor/pointMass.py | 11 +- stonesoup/types/prediction.py | 6 +- stonesoup/types/state.py | 1 - stonesoup/types/tests/test_state.py | 3 +- stonesoup/updater/pointMass.py | 29 +++- 7 files changed, 35 insertions(+), 223 deletions(-) delete mode 100644 functions.py diff --git a/functions.py b/functions.py deleted file mode 100644 index d7c2b9030..000000000 --- a/functions.py +++ /dev/null @@ -1,199 +0,0 @@ -import numpy as np -from scipy.stats import mvn -import itertools -from numpy import linalg as LA -import scipy.special as sciSpec -from scipy.interpolate import RegularGridInterpolator -from scipy.signal import fftconvolve - - -def boxvertex(n, bound): - bound = np.flipud(bound) - vertices = np.zeros((2**n, n)) - for k in range(2**n): - for d in range(n): - if k & (1 << d): - vertices[k, d] = bound[d] - else: - vertices[k, d] = -bound[d] - return vertices - - - -def measPdfPrepFFT(measPdf, gridDimOld, predMeanEst, predVarEst, F, sFactor, nx, Npa, k): - # Setup the measurement grid - eigVal, eigVect = np.linalg.eig(predVarEst) # eigenvalue and eigenvectors, for setting up the grid - gridBoundWant = np.sqrt(eigVal) * sFactor # Wanted boundaries of pred grid - gridBoundWantCorners = np.dot(boxvertex(nx, gridBoundWant), eigVect.T).T + predMeanEst # Wanted corner of predictive grid - gridBoundWantCorners = np.dot(np.linalg.inv(F), gridBoundWantCorners) # Back to filtering space - maxF = np.max(gridBoundWantCorners, axis=1) # Min/Max meas corners - minF = np.min(gridBoundWantCorners, axis=1) - gridDim = [] - gridStep = np.zeros((nx, 1)) - for ind3 in range(nx): # Creation of filtering grid so that it creates wanted predictive grid - gridDim.append(np.linspace(minF[ind3], maxF[ind3], Npa)) - gridStep[ind3] = abs(gridDim[ind3][0] - gridDim[ind3][1]) - measGridNew = np.array(np.meshgrid(*gridDim)).reshape(nx, -1, order='C') - - - GridDelta = gridStep # Grid step size - GridDelta = np.squeeze(GridDelta) - - # Interpolation - Fint = RegularGridInterpolator(gridDimOld, measPdf.reshape(Npa, Npa, order='F'), method="linear", bounds_error=False, fill_value=0) - if k == 0: - filtGridInterpInvTrsf = measGridNew.T - else: - filtGridInterpInvTrsf = np.dot(np.linalg.inv(F), measGridNew).T - measPdf = Fint(filtGridInterpInvTrsf) - - gridDimOld = gridDim - - # # Unpack x, y coordinates from measGrid - # x_coords, y_coords = measGridNew - - # # Plot the data as a scatter plot - # plt.figure() - # plt.scatter(x_coords, y_coords, c=measPdf, cmap='viridis') - - - return measPdf, gridDimOld, GridDelta, measGridNew - - - -def pmfUpdateFFT(F, measPdf, measGridNew, GridDelta, k, Npa, invQ, predDenDenomW, nx): - # Predictive grid - predGrid = np.dot(F, measGridNew) - - # Grid step size - GridDelta[:, k+1] = np.dot(F, GridDelta[:, k]) - - # ULTRA FAST PMF - filtDenDOTprodDeltas = np.dot(measPdf, np.prod(GridDelta[:, k])) # measurement PDF * measurement PDF step size - filtDenDOTprodDeltasCub = np.reshape(filtDenDOTprodDeltas, (Npa, Npa), order='C') # Into physical space - - halfGrid = (np.ceil(predGrid.shape[1] / 2)-1).astype(int) - - pom = np.transpose(predGrid[:, halfGrid][:, np.newaxis] - predGrid) # Middle row of the TPM matrix - TPMrow = (np.exp(np.sum(-0.5 * pom @ invQ * pom, axis=1)) / predDenDenomW).reshape(1, -1, order='C') # Middle row of the TPM matrix - TPMrowCubPom = np.reshape(TPMrow, (Npa, Npa), order='F') # Into physical space - - # Compute the convolution using scipy.signal.fftconvolve - 
convolution_result_complex = fftconvolve(filtDenDOTprodDeltasCub, TPMrowCubPom, mode='same')
-
-    # Take the real part of the convolution result to get a real-valued result
-    convolution_result_real = np.real(convolution_result_complex).T
-
-
-    predDensityProb = np.reshape(convolution_result_real, (-1,1), order='F')
-    predDensityProb = predDensityProb / (np.sum(predDensityProb) * np.prod(GridDelta[:, k+1])) # Normalization (theoretically not needed)
-
-
-    return predDensityProb, predGrid, GridDelta
-
-
-def ukfUpdate(measVar,nx,kappa,measMean,ffunct,k,Q):
-    # UKF prediction for grid placement
-    S = np.linalg.cholesky(measVar) # lower Cholesky factor
-    decomp = np.sqrt(nx+kappa)*S
-    rep = np.matlib.repmat(measMean.T,2*nx,1).T + np.c_[decomp,-decomp] # concatenate
-    chi = np.c_[measMean, rep]
-    wUKF = np.array(np.c_[kappa,np.matlib.repmat(0.5,1,2*nx)])/(nx+kappa) # weights
-
-    Y = ffunct(chi, np.zeros((nx,1)),k)
-    xp_aux = Y @ wUKF.T
-    Ydiff = Y - xp_aux
-    Pp_aux = np.multiply(Ydiff,np.matlib.repmat(wUKF,nx,1))@Ydiff.T+Q.T # UKF prediction var
-    return xp_aux,Pp_aux
-
-
-def gridCreationFFT(xp_aux, Pp_aux, sFactor, nx, Npa):
-    # Boundaries of grid
-    gridBound = np.sqrt(np.diag(Pp_aux)) * sFactor
-
-    # Creation of propagated grid
-    gridDim = []
-    gridStep = np.zeros((nx, 1))
-    for ind3 in range(nx):
-        gridDim.append(np.linspace(-gridBound[ind3], gridBound[ind3], Npa) + xp_aux[ind3])
-        gridStep[ind3] = abs(gridDim[ind3][0] - gridDim[ind3][1])
-
-    # Grid translated to the computed unscented mean (no eigenvector rotation in the FFT variant)
-    predGrid = np.array(np.meshgrid(*gridDim)).reshape(nx, -1, order='C')
-
-    # Grid step size
-    predGridDelta = np.squeeze(gridStep)
-
-    return predGrid, gridDim, predGridDelta
-
-
-def gridCreation(xp_aux,Pp_aux,sFactor,nx,Npa):
-    gridDim = np.zeros((nx,Npa))
-    gridStep = np.zeros(nx)
-    eigVal,eigVect = LA.eig(Pp_aux) # eigenvalues and eigenvectors for setting up the grid
-    gridBound = np.sqrt(eigVal)*sFactor # Boundaries of grid
-
-    for ind3 in range(0,nx): # Creation of propagated grid
-        gridDim[ind3] = np.linspace(-gridBound[ind3], gridBound[ind3], Npa) # New grid centred on 0
-        gridStep[ind3] = np.absolute(gridDim[ind3][0] - gridDim[ind3][1]) # Grid step
-
-    combvec_predGrid = np.array(list(itertools.product(*gridDim)))
-    predGrid_pom = np.dot(combvec_predGrid,eigVect).T
-    size_pom = np.size(predGrid_pom,1)
-    predGrid = predGrid_pom + np.matlib.repmat(xp_aux,1,size_pom) # Grid rotation by eigenvectors and translation to the computed unscented mean
-    predGridDelta = gridStep # Grid step size
-    return predGrid,predGridDelta
-
-
-def pmfMeas(predGrid,nz,k,z,invR,predDenDenomV,predDensityProb,predGridDelta,hfunct):
-    predThrMeasEq = hfunct(predGrid,np.zeros((nz,1)),k+1) # Prediction density grid through measurement EQ
-    pom = np.matlib.repmat(z,np.size(predThrMeasEq.T,0),1)-predThrMeasEq.T # Measurement - measurementEQ(Grid)
-    citatel = np.exp(np.sum(-0.5*np.multiply(pom @ invR,pom),1))
-    filterDensityNoNorm = np.multiply(citatel / predDenDenomV ,predDensityProb.T)
-    filterDensityNoNorm = filterDensityNoNorm.T
-    measPdf = (filterDensityNoNorm / np.sum(np.prod(predGridDelta)*filterDensityNoNorm,0))
-    return measPdf
-
-def pmfUpdateSTD(measGrid,measPdf,predGridDelta,ffunct,predGrid,nx,k,invQ,predDenDenomW,N):
-    fitDenDOTprodDeltas = measPdf*np.prod(predGridDelta[:,k]) # measurement PDF * measurement PDF step size
-    gridNext = ffunct(measGrid,np.zeros((nx,1)),k+1) # Old grid through dynamics
-
-    predDensityProb = np.zeros((N,1))
-    for ind2 in range(0,N): # Over number of states of the prediction grid
-        pom = (predGrid[:,ind2].T-(gridNext).T)
-        suma = np.sum(-0.5*np.multiply(pom@invQ,pom),1)
-        predDensityProb[ind2,0] = (np.exp(suma)/predDenDenomW).T@fitDenDOTprodDeltas
-    predDensityProb = predDensityProb/(np.sum(predDensityProb)*np.prod(predGridDelta[:,k+1])) # Normalization (theoretically not needed)
-    return predDensityProb
-
-
-def pmfUpdateDWC(invF,predGrid,measGrid,predGridDelta,Qa,cnormHere,measPdf,N,k):
-    predDensityProb = np.zeros((N,N))
-    for i in range(0,N): # Unnecessary for cycle, kept for clearer understanding
-        ma = invF @ predGrid[:,i]
-        for n in range(0,N):
-            lowerBound = np.array([measGrid[:,n]-predGridDelta[:,k]/2]).T # boundary of rectangular region M
-            upperBound = np.array([measGrid[:,n]+predGridDelta[:,k]/2]).T
-            cdfAct = mvn.mvnun(lowerBound,upperBound,ma,Qa)[0] # Integral calculation
-            predDensityProb[i,n] = cnormHere*cdfAct*measPdf[n] # Predictive density
-    predDensityProb = np.sum(predDensityProb,1)
-    predDensityProb = predDensityProb/(np.sum(predDensityProb)*np.prod(predGridDelta[:,k+1])) # Normalization (theoretically not needed)
-    return predDensityProb
-
-
-def pmfUpdateDiagDWC(measGrid,N,predGridDelta,F,predGrid,Q,s2,measPdf,normDiagDWC,k):
-    predDensityProb = np.zeros((N,N))
-    bound = np.zeros((np.size(measGrid,0),np.size(measGrid,0))) # boundary of rectangular region M
-
-    for i in range(0,N):
-        for n in range(0,N):
-            bound = np.array([measGrid[:,n]-predGridDelta[:,k]/2,measGrid[:,n]+ predGridDelta[:,k]/2]).T
-            pom = np.array([ np.divide(-F@bound[:,0] + predGrid[:,i],np.sqrt(np.diag(Q))), np.divide(-F@bound[:,1] + predGrid[:,i],np.sqrt(np.diag(Q))) ]).T # DIMENSIONS DO NOT MATCH!!!!!!!!!!
-            erfAct = np.prod((0.5 - 0.5*sciSpec.erf(pom[:,1]/s2)) - (0.5 - 0.5*sciSpec.erf(pom[:,0]/s2)))
-            predDensityProb[i,n] = measPdf[n] * normDiagDWC * erfAct # Predictive density
-    predDensityProb = np.sum(predDensityProb,1)
-    predDensityProb = predDensityProb/(np.sum(predDensityProb)*np.prod(predGridDelta[:,k+1])) # Normalization (theoretically not needed)
-    return predDensityProb
-
-
-
diff --git a/stonesoup/functions/__init__.py b/stonesoup/functions/__init__.py
index 29fba8bd6..33145d73b 100644
--- a/stonesoup/functions/__init__.py
+++ b/stonesoup/functions/__init__.py
@@ -38,15 +38,14 @@ def gridCreation(xp_aux, Pp_aux, sFactor, nx, Npa):
     predGridDelta : list
         grid step per dimension
     gridDim : list of numpy.ndarrays
-        grid coordinates per dimension before rotation and translation 
+        grid coordinates per dimension before rotation and translation
     xp_aux : numpy.ndarray
         grid center
     eigVect : numpy.ndarray
         eigenvectors describing the rotation of the grid
-    
+
     """
-    gridDim = np.zeros((nx, Npa[0]))
-    gridStep = np.zeros(nx)
+
     eigVal, eigVect = LA.eig(
         Pp_aux
     )  # eigenvalue and eigenvectors for setting up the grid
@@ -72,7 +71,7 @@ def gridCreation(xp_aux, Pp_aux, sFactor, nx, Npa):
     combvec_predGrid = np.array(list(itertools.product(*gridDim)))
     predGrid_pom = np.dot(eigVect, combvec_predGrid.T)
     size_pom = np.size(predGrid_pom, 1)
-    # Grid rotation by eigenvectors and traslation to the counted unscented mean
+    # Grid rotation by eigenvectors and translation to the computed unscented mean
     predGrid = predGrid_pom + matlib.repmat(xp_aux, 1, size_pom)
     predGridDelta = gridStep  # Grid step size
     return predGrid, predGridDelta, gridDim, xp_aux, eigVect
diff --git a/stonesoup/predictor/pointMass.py b/stonesoup/predictor/pointMass.py
index dcb7ba59b..47090d20f 100644
--- a/stonesoup/predictor/pointMass.py
+++ b/stonesoup/predictor/pointMass.py
@@ -15,15 +15,14 @@ class PointMassPredictor(Predictor):
-
"""ParticlePredictor class + """PointMassPredictor class - An implementation of a Particle Filter predictor. + An implementation of a Point Mass Filter predictor. """ - sFactor: float = Property(default=4, doc="How many sigma to cover by the grid") + sFactor: float = Property(default=4., doc="How many sigma to cover by the grid") - # @profile - def predict(self, prior, timestamp=1, **kwargs): + def predict(self, prior, timestamp=None, **kwargs): """Point Mass Filter prediction step Parameters @@ -36,7 +35,7 @@ def predict(self, prior, timestamp=1, **kwargs): Returns ------- - : :class:`~.ParticleStatePrediction` + : :class:`~.PointMassStatePrediction` The predicted state """ # Compute time_interval diff --git a/stonesoup/types/prediction.py b/stonesoup/types/prediction.py index 7ca7907b3..a971598f8 100644 --- a/stonesoup/types/prediction.py +++ b/stonesoup/types/prediction.py @@ -143,9 +143,9 @@ class ParticleStatePrediction(Prediction, ParticleState): class PointMassStatePrediction(Prediction, PointMassState): - """ParticleStatePrediction type + """PointMassStatePrediction type - This is a simple Particle state prediction object. + This is a simple Point mass state prediction object. """ @@ -159,7 +159,7 @@ class ParticleMeasurementPrediction(MeasurementPrediction, ParticleState): class PointMassMeasurementPrediction(MeasurementPrediction, PointMassState): """MeasurementStatePrediction type - This is a simple Particle measurement prediction object. + This is a simple Point mass measurement prediction object. """ diff --git a/stonesoup/types/state.py b/stonesoup/types/state.py index fd00229b1..c44cb90e8 100644 --- a/stonesoup/types/state.py +++ b/stonesoup/types/state.py @@ -215,7 +215,6 @@ def mean(self): """Sample mean for particles""" return np.hstack(self.state_vector @ self.weight * np.prod(self.grid_delta)) - # @profile def covar(self): # Measurement update covariance chip_ = self.state_vector - self.mean[:, np.newaxis] diff --git a/stonesoup/types/tests/test_state.py b/stonesoup/types/tests/test_state.py index 5fde575d5..8cd4f0854 100644 --- a/stonesoup/types/tests/test_state.py +++ b/stonesoup/types/tests/test_state.py @@ -1046,7 +1046,7 @@ def test_pointmassstate(): assert priorPMF.ndim == nx assert priorPMF.__len__() == N - + def test_kernel_particle_state(): number_particles = 5 weights = np.array([1 / number_particles] * number_particles) @@ -1068,4 +1068,3 @@ def test_kernel_particle_state(): assert number_particles == len(prior) assert 4 == prior.ndim assert np.array_equal(state_vector @ weights[:, np.newaxis], prior.mean) - diff --git a/stonesoup/updater/pointMass.py b/stonesoup/updater/pointMass.py index c49b0094d..78cf13e02 100644 --- a/stonesoup/updater/pointMass.py +++ b/stonesoup/updater/pointMass.py @@ -1,8 +1,11 @@ - +from functools import lru_cache import numpy as np from scipy.stats import multivariate_normal from stonesoup.types.state import PointMassState from ..base import Property +from ..types.prediction import ( + MeasurementPrediction, +) from ..types.update import Update from .base import Updater @@ -13,12 +16,12 @@ class PointMassUpdater(Updater): Perform an update by multiplying grid points weights by PDF of measurement model """ - sFactor: float = Property(default=4, doc="How many sigma to cover by the grid") + + sFactor: float = Property(default=4.0, doc="How many sigma to cover by the grid") def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # @profile def update(self, hypothesis, **kwargs): """Point mass update step @@ -30,7 +33,7 
@@ def update(self, hypothesis, **kwargs):

         Returns
         -------
-        : :class:`~.PointMassState`
+        : :class:`~.ParticleMeasurementPrediction`
             The state posterior

@@ -44,9 +47,7 @@ def update(self, hypothesis, **kwargs):
         R = measurement_model.covar()  # Noise

-        x = measurement_model.function(
-            predicted_state
-        )  # State to measurement space
+        x = measurement_model.function(predicted_state)  # State to measurement space
         pdf_value = multivariate_normal.pdf(
             x.T, np.ravel(hypothesis.measurement.state_vector), R
         )  # likelihood
@@ -68,3 +69,17 @@ def update(self, hypothesis, **kwargs):
         )

         return predicted_state
+
+    @lru_cache()
+    def predict_measurement(self, state_prediction, measurement_model=None, **kwargs):
+
+        if measurement_model is None:
+            measurement_model = self.measurement_model
+
+        new_state_vector = measurement_model.function(state_prediction, **kwargs)
+
+        return MeasurementPrediction.from_state(
+            state_prediction,
+            state_vector=new_state_vector,
+            timestamp=state_prediction.timestamp,
+        )

From 0645cc3b1b5c142abfd8007ae797e17f1a15b6c2 Mon Sep 17 00:00:00 2001
From: pesslovany <43780444+pesslovany@users.noreply.github.com>
Date: Wed, 30 Oct 2024 16:17:16 +0100
Subject: [PATCH 15/16] Fixing flake8 issues after merge

---
 stonesoup/platform/base.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/stonesoup/platform/base.py b/stonesoup/platform/base.py
index f9ff74758..f8b5dcb53 100644
--- a/stonesoup/platform/base.py
+++ b/stonesoup/platform/base.py
@@ -1,10 +1,8 @@
-import uuid
 from collections.abc import MutableSequence

 from ..base import Property, Base
 from ..movable import Movable, FixedMovable, MovingMovable, MultiTransitionMovable
 from ..sensor.sensor import Sensor
-from ..types.groundtruth import GroundTruthPath


 class Platform(Base):

From a7f3bc6ae4acf72a1161e3e1a513ce00f3c29e1d Mon Sep 17 00:00:00 2001
From: pesslovany <43780444+pesslovany@users.noreply.github.com>
Date: Thu, 31 Oct 2024 12:02:28 +0100
Subject: [PATCH 16/16] Changed names of the functions to maintain Python
 naming conventions

---
 stonesoup/predictor/{pointMass.py => pointmass.py} | 1 -
 stonesoup/predictor/tests/test_pointmass.py        | 2 +-
 stonesoup/updater/{pointMass.py => pointmass.py}   | 0
 stonesoup/updater/tests/test_pointmass.py          | 2 +-
 4 files changed, 2 insertions(+), 3 deletions(-)
 rename stonesoup/predictor/{pointMass.py => pointmass.py} (99%)
 rename stonesoup/updater/{pointMass.py => pointmass.py} (100%)

diff --git a/stonesoup/predictor/pointMass.py b/stonesoup/predictor/pointmass.py
similarity index 99%
rename from stonesoup/predictor/pointMass.py
rename to stonesoup/predictor/pointmass.py
index 47090d20f..d217a308d 100644
--- a/stonesoup/predictor/pointMass.py
+++ b/stonesoup/predictor/pointmass.py
@@ -31,7 +31,6 @@ def predict(self, prior, timestamp=None, **kwargs):
             A prior state object
         timestamp: :class:`datetime.datetime`, optional
             A timestamp signifying when the prediction is performed
-            (the default is `1`)

         Returns
         -------
diff --git a/stonesoup/predictor/tests/test_pointmass.py b/stonesoup/predictor/tests/test_pointmass.py
index d33d796bb..54d40d3d4 100644
--- a/stonesoup/predictor/tests/test_pointmass.py
+++ b/stonesoup/predictor/tests/test_pointmass.py
@@ -9,7 +9,7 @@
 from stonesoup.functions import gridCreation
 from stonesoup.models.transition.linear import KnownTurnRate
 from stonesoup.predictor.kalman import KalmanPredictor
-from stonesoup.predictor.pointMass import PointMassPredictor
+from stonesoup.predictor.pointmass import PointMassPredictor
 from stonesoup.types.array import StateVectors
 from stonesoup.types.state import GaussianState, PointMassState
diff --git a/stonesoup/updater/pointMass.py b/stonesoup/updater/pointmass.py
similarity index 100%
rename from stonesoup/updater/pointMass.py
rename to stonesoup/updater/pointmass.py
diff --git a/stonesoup/updater/tests/test_pointmass.py b/stonesoup/updater/tests/test_pointmass.py
index 142862b43..d097a50da 100644
--- a/stonesoup/updater/tests/test_pointmass.py
+++ b/stonesoup/updater/tests/test_pointmass.py
@@ -12,7 +12,7 @@
 from stonesoup.types.groundtruth import GroundTruthPath, GroundTruthState
 from stonesoup.types.hypothesis import SingleHypothesis
 from stonesoup.types.state import PointMassState
-from stonesoup.updater.pointMass import PointMassUpdater
+from stonesoup.updater.pointmass import PointMassUpdater


 def test_pointmass():
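To close the series out, a minimal sketch of how the renamed modules fit together, modelled on the example deleted earlier in the series. transition_model, measurement_model, measurements, and start_time are assumed to already exist as they did there; this is an illustration, not a tested snippet. Using scipy's pdf for the prior weights also sidesteps the hand-rolled normalising constant from that example (its missing square root on the determinant was harmless only because of the renormalisation over the grid):

import numpy as np
from scipy.stats import multivariate_normal

from stonesoup.functions import gridCreation
from stonesoup.predictor.pointmass import PointMassPredictor
from stonesoup.types.array import StateVectors
from stonesoup.types.hypothesis import SingleHypothesis
from stonesoup.types.state import PointMassState
from stonesoup.updater.pointmass import PointMassUpdater

predictor = PointMassPredictor(transition_model)
updater = PointMassUpdater(measurement_model)

# Gaussian prior evaluated on a rotated grid, as in the deleted example
nx = 4
mean0 = np.array([48, 1, -5, 1])
cov0 = np.diag([1, 0.5, 1, 0.5])
Npa = np.array([19, 19, 19, 19])   # points per axis
grid, delta, grid_dim, center, eig_vec = gridCreation(
    np.vstack(mean0), cov0, 4, nx, Npa)

weight = multivariate_normal.pdf(grid.T, mean=mean0, cov=cov0)
weight /= weight.sum() * np.prod(delta)   # normalise over the grid

prior = PointMassState(state_vector=StateVectors(grid), weight=weight,
                       grid_delta=delta, grid_dim=grid_dim, center=center,
                       eigVec=eig_vec, Npa=Npa, timestamp=start_time)

# Predict-update cycle, one step per measurement
for measurement in measurements:
    prediction = predictor.predict(prior, timestamp=measurement.timestamp)
    prior = updater.update(SingleHypothesis(prediction, measurement))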