Skip to content

Commit

Permalink
Merge pull request #681 from ekhunter123/add_chernoff_updater
Browse files Browse the repository at this point in the history
Add ChernoffUpdater class and track feeder
  • Loading branch information
sdhiscocks authored Jul 28, 2022
2 parents c098a63 + ddeee1b commit 0e8a394
Show file tree
Hide file tree
Showing 6 changed files with 320 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/source/stonesoup.feeder.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,8 @@ Image
.. automodule:: stonesoup.feeder.image
:show-inheritance:

Track
-----

.. automodule:: stonesoup.feeder.track
:show-inheritance:
6 changes: 6 additions & 0 deletions docs/source/stonesoup.updater.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,9 @@ Composite

.. automodule:: stonesoup.updater.composite
:show-inheritance:

Chernoff
--------

.. automodule:: stonesoup.updater.chernoff
:show-inheritance:
50 changes: 50 additions & 0 deletions stonesoup/feeder/tests/test_track.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
import pytest
import numpy as np

from ...types.state import GaussianState
from ...types.track import Track
from ..track import Tracks2GaussianDetectionFeeder

t1 = Track(GaussianState([1, 1, 1, 1], np.diag([2, 2, 2, 2]), timestamp=2))
t2 = Track([GaussianState([1, 1, 1, 1], np.diag([2, 2, 2, 2]), timestamp=1),
GaussianState([2, 1, 2, 1], np.diag([2, 2, 2, 2]), timestamp=2)])
t3 = Track([GaussianState([1, 1], np.diag([2, 2]), timestamp=0),
GaussianState([2, 1], np.diag([2, 2]), timestamp=1),
GaussianState([3, 1], np.diag([2, 2]), timestamp=2)])
t4 = Track(GaussianState([1, 0, 1, 0, 1, 0], np.diag([2, 2, 2, 2, 2, 2]), timestamp=2))


@pytest.mark.parametrize(
"tracks",
[
([t1]),
([t1, t2]),
([t2, t3]),
([t1, t2, t3, t4])
]
)
def test_Track2GaussianDetectionFeeder(tracks):

# Make feeder and get detections
reader = [(tracks[0][-1].timestamp, tracks)]
feeder = Tracks2GaussianDetectionFeeder(reader=reader)
time, detections = next(feeder.data_gen())

# Check that there are the right number of detections
assert len(detections) == len(tracks)

# Check that the correct state was taken from each track
assert np.all([detections[i].timestamp == tracks[i][-1].timestamp for i in range(len(tracks))])
assert np.all([detections[i].timestamp == time for i in range(len(tracks))])

# Check that the dimension of each detection is correct
assert np.all([len(detections[i].state_vector) == len(tracks[i][-1].state_vector)
for i in range(len(tracks))])
assert np.all([len(detection.state_vector) == detection.measurement_model.ndim
for detection in detections])

# Check that the detection has the correct mean and covariance
for i in range(len(tracks)):
assert np.all(detections[i].state_vector == tracks[i][-1].state_vector)
assert np.all(detections[i].covar == tracks[i][-1].covar)
34 changes: 34 additions & 0 deletions stonesoup/feeder/track.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
import numpy as np

from stonesoup.types.detection import GaussianDetection
from stonesoup.feeder.base import DetectionFeeder
from stonesoup.models.measurement.linear import LinearGaussian
from ..buffered_generator import BufferedGenerator


class Tracks2GaussianDetectionFeeder(DetectionFeeder):
'''
Feeder consumes Track objects and outputs GaussianDetection objects.
At each time step, the :attr:`Reader` feeds in a set of live tracks. The feeder takes the most
recent state from each of those tracks, and turn them into a set of
:class:`~.GaussianDetection` objects. Each detection is given a :class:`~.LinearGaussian`
measurement model whose covariance is equal to the state covariance. The feeder assumes that
the tracks are all live, that is each track has a state at the most recent time step.
'''
@BufferedGenerator.generator_method
def data_gen(self):
for time, tracks in self.reader:
detections = []
for track in tracks:
dim = len(track.state.state_vector)
detections.append(
GaussianDetection.from_state(
track.state,
measurement_model=LinearGaussian(
dim, list(range(dim)), np.asarray(track.covar)),
target_type=GaussianDetection)
)

yield time, detections
131 changes: 131 additions & 0 deletions stonesoup/updater/chernoff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-
import numpy as np

from ..base import Property
from .base import Updater
from ..types.prediction import MeasurementPrediction
from ..types.update import Update


class ChernoffUpdater(Updater):
r"""A class which performs state updates using the Chernoff fusion rule. In this context,
measurements come in the form of states with a mean and covariance instead of just the
traditional mean. The measurements are expected to come as :class:`~.GaussianDetection`
objects.
The Chernoff fusion rule is written as
.. math::
p_{\omega}(x_{k}) = \frac{p_{1}(x_{k})^{\omega}p_{2}(x_{k})^{1-\omega}}
{\int p_{1}(x)^{\omega}p_{2}(x)^{1-\omega} \mathrm{d} x}
where :math:`\omega` is a weighting parameter in the range :math:`(0,1]`, which can be found
using an optimization algorithm.
In situations where :math:`p_1(x)` and :math:`p_2(x)` are multivariate Gaussian distributions,
the above formula is equal to the Covariance Intersection Algorithm from Julier and Uhlmann.
Let :math:`(a,A)` and :math:`(b,B)` be the means and covariances of the measurement and
prediction respectively. The Covariance Intersection Algorithm was reformulated for use in
Bayesian state estimation by Clark et al, yielding the update formulas for the covariance,
mean, and innovation:
.. math::
D &= \left ( \omega A^{-1} + (1-\omega)B^{-1} \right )\\
d &= D \left ( \omega A^{-1}a + (1-\omega)B^{-1}b \right )\\
V &= \frac{A}{1-\omega} + \frac{B}{\omega}
In filters where gating is required, the gating region can be written using the innovation
covariance matrix as:
.. math::
\mathcal{V}(\gamma) = \left\{ (a,A) : (a-b)^T V^{-1} (a-b) \leq \gamma \right\}
Note: If you have tracks that you would like to use as measurements for this updater, the
:class:`~.Tracks2GaussianDetectionFeeder` class can be used to convert the tracks to the
appropriate format.
References
----------
[1] Hurley, M. B., “An information theoretic justification for covariance intersection and its
generalization,” in [Proceedings of the Fifth International Conference on Information Fusion.
FUSION 2002.(IEEE Cat. No. 02EX5997) ], 1, 505–511, IEEE (2002).
https://ieeexplore.ieee.org/document/1021196.
[2] Julier, S., Uhlmann, J., and Durrant-Whyte, H. F., “A new method for the nonlinear
transformation of means and covariances in filters and estimators,” IEEE Transactions on
automatic control 45(3), 477–482 (2000).
https://ieeexplore.ieee.org/abstract/document/847726/similar#similar.
[3] Clark, D. E. and Campbell, M. A., “Integrating covariance intersection into Bayesian
multi-target tracking filters,” preprint on TechRxiv. submitted to IEEE Transactions on
Aerospace and Electronic Systems .
"""

omega: float = Property(
default=0.5,
doc="A weighting parameter in the range :math:`(0,1]`")

def predict_measurement(self, predicted_state, measurement_model=None, **kwargs):
'''
This function predicts the measurement in situations where the predicted state consists
of a covariance and state vector.
'''

measurement_model = self._check_measurement_model(measurement_model)

# The innovation covariance uses the noise covariance from the measurement model
state_covar_m = measurement_model.noise_covar
innov_covar = 1/(1-self.omega)*state_covar_m + 1/self.omega*predicted_state.covar

# The predicted measurement and measurement cross covariance can be taken from
# the predicted state
predicted_meas = predicted_state.state_vector
meas_cross_cov = predicted_state.covar

# Combine everything into a GaussianMeasurementPrediction object
return MeasurementPrediction.from_state(predicted_state, predicted_meas, innov_covar,
predicted_state.timestamp,
cross_covar=meas_cross_cov)

def update(self, hypothesis, force_symmetric_covariance=False, **kwargs):
'''
Given a hypothesis, calculate the posterior mean and covariance
'''

# Get the predicted state out of the hypothesis. These are 'B' and 'b', the
# covariance and mean of the predicted Gaussian
predicted_covar = hypothesis.prediction.covar
predicted_mean = hypothesis.prediction.state_vector

# Extract the vector and covariance from the measurement. These are 'A' and 'a', the
# covariance and mean of the Gaussian measurement.
measurement_covar = hypothesis.measurement.covar
measurement_mean = hypothesis.measurement.state_vector

# Predict the measurement if it is not already done
if hypothesis.measurement_prediction is None:
hypothesis.measurement_prediction = self.predict_measurement(
hypothesis.prediction,
measurement_model=hypothesis.measurement.measurement_model,
**kwargs
)

# Calculate the updated mean and covariance from covariance intersection
posterior_covariance = np.linalg.inv(self.omega*np.linalg.inv(measurement_covar) +
(1-self.omega)*np.linalg.inv(predicted_covar))
posterior_mean = posterior_covariance @ (self.omega*np.linalg.inv(measurement_covar)
@ measurement_mean +
(1-self.omega)*np.linalg.inv(predicted_covar)
@ predicted_mean)

# Optionally force the posterior covariance to be a symmetric matrix
if force_symmetric_covariance:
posterior_covariance = \
(posterior_covariance + posterior_covariance.T)/2

# Return the updated state
return Update.from_state(hypothesis.prediction, posterior_mean, posterior_covariance,
hypothesis, hypothesis.measurement.timestamp)
94 changes: 94 additions & 0 deletions stonesoup/updater/tests/test_chernoff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
"""Test for updater.chernoff module"""
import pytest
import numpy as np

from stonesoup.models.measurement.linear import LinearGaussian
from stonesoup.types.detection import GaussianDetection
from stonesoup.types.hypothesis import SingleHypothesis
from stonesoup.types.prediction import (
GaussianStatePrediction, GaussianMeasurementPrediction)
from stonesoup.types.state import GaussianState
from stonesoup.updater.chernoff import ChernoffUpdater


@pytest.mark.parametrize(
"UpdaterClass, measurement_model, prediction, measurement, omega",
[
( # Chernoff Updater
ChernoffUpdater,
LinearGaussian(ndim_state=2, mapping=[0, 1],
noise_covar=np.array([[0.04, 0.04]])),
GaussianStatePrediction(np.array([[-6.45], [0.7]]),
np.array([[4.1123, 0.0013],
[0.0013, 0.0365]])),
GaussianDetection(state_vector=np.array([[-6.23, 0.83]]),
covar=np.diag([0.75, 1.2])),
0.5
)
],
ids=["standard"]
)
def test_chernoff(UpdaterClass, measurement_model, prediction, measurement, omega):

# Calculate evaluation variables
innov_cov = 1/(1-omega)*measurement_model.noise_covar + 1/omega*prediction.covar
eval_measurement_prediction = GaussianMeasurementPrediction(
measurement_model.matrix() @ prediction.mean,
innov_cov,
cross_covar=prediction.covar @ measurement_model.matrix().T)

posterior_cov = np.linalg.inv(omega*np.linalg.inv(measurement.covar) +
(1-omega)*np.linalg.inv(prediction.covar))
posterior_mean = posterior_cov@(omega*np.linalg.inv(measurement.covar) @
measurement.state_vector + (1-omega) *
np.linalg.inv(prediction.covar)@prediction.state_vector)
eval_posterior = GaussianState(
posterior_mean,
posterior_cov)

# Initialise a Chernoff updater
updater = UpdaterClass(measurement_model=measurement_model, omega=omega)

# Get and assert measurement prediction
measurement_prediction = updater.predict_measurement(prediction)
assert(np.allclose(measurement_prediction.mean,
eval_measurement_prediction.mean,
0, atol=1.e-14))
assert(np.allclose(measurement_prediction.covar,
eval_measurement_prediction.covar,
0, atol=1.e-14))
assert(np.allclose(measurement_prediction.cross_covar,
eval_measurement_prediction.cross_covar,
0, atol=1.e-14))

# Perform and assert state update (without measurement prediction)
posterior = updater.update(SingleHypothesis(
prediction=prediction,
measurement=measurement))
assert(np.allclose(posterior.mean, eval_posterior.mean, 0, atol=1.e-14))
assert(np.allclose(posterior.covar, eval_posterior.covar, 0, atol=1.e-14))
assert(np.array_equal(posterior.hypothesis.prediction, prediction))
assert (np.allclose(
posterior.hypothesis.measurement_prediction.state_vector,
measurement_prediction.state_vector, 0, atol=1.e-14))
assert (np.allclose(posterior.hypothesis.measurement_prediction.covar,
measurement_prediction.covar, 0, atol=1.e-14))
assert(np.array_equal(posterior.hypothesis.measurement, measurement))
assert(posterior.timestamp == prediction.timestamp)

# Perform and assert state update
posterior = updater.update(SingleHypothesis(
prediction=prediction,
measurement=measurement,
measurement_prediction=measurement_prediction))
assert(np.allclose(posterior.mean, eval_posterior.mean, 0, atol=1.e-14))
assert(np.allclose(posterior.covar, eval_posterior.covar, 0, atol=1.e-14))
assert(np.array_equal(posterior.hypothesis.prediction, prediction))
assert (np.allclose(
posterior.hypothesis.measurement_prediction.state_vector,
measurement_prediction.state_vector, 0, atol=1.e-14))
assert (np.allclose(posterior.hypothesis.measurement_prediction.covar,
measurement_prediction.covar, 0, atol=1.e-14))
assert(np.array_equal(posterior.hypothesis.measurement, measurement))
assert(posterior.timestamp == prediction.timestamp)

0 comments on commit 0e8a394

Please sign in to comment.