Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SMC-PHD components #798

Merged
merged 11 commits into from
Feb 28, 2024
6 changes: 6 additions & 0 deletions docs/source/stonesoup.hypothesiser.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ Hypothesiser
.. automodule:: stonesoup.hypothesiser.base
:show-inheritance:

Simple
------

.. automodule:: stonesoup.hypothesiser.simple
:show-inheritance:

Distance
--------

Expand Down
98 changes: 98 additions & 0 deletions stonesoup/hypothesiser/simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import datetime

from typing import Set

from stonesoup.base import Property
from stonesoup.hypothesiser import Hypothesiser
from stonesoup.predictor import Predictor
from stonesoup.types.detection import MissedDetection, Detection
from stonesoup.types.hypothesis import SingleHypothesis
from stonesoup.types.multihypothesis import MultipleHypothesis
from stonesoup.types.track import Track
from stonesoup.updater import Updater


class SimpleHypothesiser(Hypothesiser):
"""Simple Hypothesiser class

Generate track predictions at detection times and create hypotheses for
each detection, as well as a missed detection hypothesis.
"""
predictor: Predictor = Property(doc="Predict tracks to detection times")
updater: Updater = Property(
default=None,
doc="Updater used to get measurement prediction. Only required if "
"`predict_measurement` is `True`. Default is `None`")
check_timestamp: bool = Property(
default=True,
doc="Check that all detections have the same timestamp. Default is `True`")
predict_measurement: bool = Property(
default=False,
doc="Predict measurement for each detection. Default is `True`")

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.predict_measurement and self.updater is None:
raise ValueError("Updater must be provided if `predict_measurement` is `True`")

def hypothesise(self, track: Track, detections: Set[Detection], timestamp: datetime.datetime,
**kwargs) -> MultipleHypothesis:
""" Evaluate and return all track association hypotheses.

For a given track and a set of N available detections, return a
MultipleHypothesis object with N+1 detections (first detection is
a 'MissedDetection').

Parameters
----------
track : Track
The track object to hypothesise on
detections : set of :class:`~.Detection`
The available detections
timestamp : datetime.datetime
A timestamp used when evaluating the state and measurement
predictions. Note that if a given detection has a non-empty
timestamp, then prediction will be performed according to
the timestamp of the detection.

Returns
-------
: :class:`~.MultipleHypothesis`
A container of :class:`~SingleHypothesis` objects

"""

if self.check_timestamp:
# Check to make sure all detections are obtained from the same time
timestamps = {detection.timestamp for detection in detections}
if len(timestamps) > 1:
raise ValueError("All detections must have the same timestamp")

hypotheses = []

# Common state prediction
prediction = self.predictor.predict(track, timestamp=timestamp, **kwargs)

# Missed detection hypothesis
hypotheses.append(
SingleHypothesis(prediction, MissedDetection(timestamp=timestamp))
)

# True detection hypotheses
for detection in detections:

# Re-evaluate prediction
prediction = self.predictor.predict(track, timestamp=detection.timestamp, **kwargs)

# Compute measurement prediction
if self.predict_measurement:
measurement_prediction = self.updater.predict_measurement(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Codecov reported as not covered by test, which seems odd when test below should be setting predict_measurement to True and False.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice spot. I have added an extra test case to make sure we reach that part of the code.

prediction, timestamp=detection.timestamp, **kwargs)
else:
measurement_prediction = None

hypotheses.append(
SingleHypothesis(prediction, detection, measurement_prediction)
)

return MultipleHypothesis(hypotheses)
55 changes: 55 additions & 0 deletions stonesoup/hypothesiser/tests/test_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import datetime
import numpy as np
import pytest

from stonesoup.hypothesiser.simple import SimpleHypothesiser
from stonesoup.types.detection import Detection
from stonesoup.types.state import GaussianState
from stonesoup.types.track import Track


@pytest.mark.parametrize(
"check_timestamp, predict_measurement",
[
(True, False),
(False, True),
(False, False)
]
)
def test_simple(predictor, updater, check_timestamp, predict_measurement):
timestamp = datetime.datetime.now()
track = Track([GaussianState(np.array([[0]]), np.array([[1]]), timestamp)])

# Create 3 detections, 2 of which are at the same time
detection1 = Detection(np.array([[2]]), timestamp)
detection2 = Detection(np.array([[3]]), timestamp + datetime.timedelta(seconds=1))
detection3 = Detection(np.array([[10]]), timestamp + datetime.timedelta(seconds=1))
detections = {detection1, detection2, detection3}

hypothesiser = SimpleHypothesiser(predictor, updater, check_timestamp, predict_measurement)

if check_timestamp:
# Detection 1 has different timestamp to Detections 2 and 3, so this should raise an error
with pytest.raises(ValueError):
hypothesiser.hypothesise(track, detections, timestamp)
return

hypotheses = hypothesiser.hypothesise(track, detections, timestamp)

# There are 3 hypotheses - Detection 1, Detection 2, Detection 3, Missed Detection
assert len(hypotheses) == 4

# There is a missed detection hypothesis
assert any(not hypothesis.measurement for hypothesis in hypotheses)

if predict_measurement:
# Each true hypothesis has a measurement prediction
true_hypotheses = [hypothesis for hypothesis in hypotheses if hypothesis]
assert all(hypothesis.measurement_prediction is not None for hypothesis in true_hypotheses)
else:
assert all(hypothesis.measurement_prediction is None for hypothesis in hypotheses)


def test_invalid_simple_arguments(predictor):
with pytest.raises(ValueError):
SimpleHypothesiser(predictor, updater=None, predict_measurement=True)
142 changes: 142 additions & 0 deletions stonesoup/predictor/particle.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
from enum import Enum
from typing import Sequence

import numpy as np
Expand All @@ -10,6 +11,8 @@
from .kalman import KalmanPredictor, ExtendedKalmanPredictor
from ..base import Property
from ..models.transition import TransitionModel
from ..sampler.particle import ParticleSampler
from ..types.numeric import Probability
from ..types.prediction import Prediction
from ..types.state import GaussianState
from ..sampler import Sampler
Expand Down Expand Up @@ -383,3 +386,142 @@ def get_detections(prior):
detections |= {hypothesis.measurement}

return detections


class SMCPHDBirthSchemeEnum(Enum):
"""SMC-PHD Birth scheme enumeration"""
EXPANSION = 'expansion' #: Expansion birth scheme
MIXTURE = 'mixture' #: Mixture birth scheme


class SMCPHDPredictor(Predictor):
"""Sequential Monte Carlo Probability Hypothesis Density (SMC-PHD) Predictor class

An implementation of a particle predictor that propagates only the first-order moment (i.e. the
Probability Hypothesis Density) of the multi-target state density based on [#phd]_.

.. note::

- It is assumed that the proposal distribution is the same as the dynamics
- Target "spawning" is not implemented

Parameters
----------

References
----------
.. [#phd] Ba-Ngu Vo, S. Singh and A. Doucet, "Sequential monte carlo implementation of the phd
filter for multi-target tracking," Sixth International Conference of Information
Fusion, 2003. Proceedings of the, Cairns, QLD, Australia, 2003, pp. 792-799,
doi: 10.1109/ICIF.2003.177320

"""
death_probability: Probability = Property(
doc="The probability of death per unit time. This is used to calculate the probability "
r"of survival as :math:`1 - \exp(-\lambda \Delta t)` where :math:`\lambda` is the "
r"probability of death and :math:`\Delta t` is the time interval")
birth_probability: Probability = Property(
doc="Probability of target birth. In the current implementation, this is used to calculate"
"the number of birth particles, as per the explanation under :py:attr:`~birth_scheme`")
birth_rate: float = Property(
doc="The expected number of new/born targets at each iteration. This is used to calculate"
"the weight of the birth particles")
birth_sampler: ParticleSampler = Property(
doc="Sampler object used for sampling birth particles. The weight of the sampled birth "
"particles is ignored and calculated internally based on the :py:attr:`~birth_rate` "
"and number of particles")
birth_func_num_samples_field: str = Property(
default='num_samples',
doc="The field name of the number of samples parameter for the birth sampler. This is "
"required since the number of samples required for the birth sampler may be vary"
"between iterations. Default is ``'num_samples'``")
birth_scheme: SMCPHDBirthSchemeEnum = Property(
default=SMCPHDBirthSchemeEnum.EXPANSION,
doc="The scheme for birth particles. Options are ``'expansion'`` | ``'mixture'``. Default "
"is ``'expansion'``.\n\n"
" - The ``'expansion'`` scheme follows the implementation of [#phd]_, meaning that "
"birth particles are appended to the list of surviving particles, where the number of "
"birth particles is computed as :math:`P_b N` where :math:`P_b` is the birth "
"probability and :math:`N` is the number of particles.\n"
" - The ``'mixture'`` scheme draws from a binomial distribution, with probability "
":math:`P_b`, for each particle to decide if it gets replaced by a birth particle. "
"The weights of the particles are then updated as a mixture of the survival and birth "
"probabilities."
)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Ensure birth scheme is a valid BirthSchemeEnum
self.birth_scheme = SMCPHDBirthSchemeEnum(self.birth_scheme)

@predict_lru_cache()
def predict(self, prior, timestamp=None, **kwargs):
""" SMC-PHD prediction step

Parameters
----------
prior: :class:`~.ParticleState`
The prior state
timestamp: :class:`datetime.datetime`
The time at which to predict the next state

Returns
-------
: :class:`~.ParticleStatePrediction`
The predicted state
"""
num_samples = len(prior)
log_prior_weights = prior.log_weight
time_interval = timestamp - prior.timestamp

# Predict surviving particles forward
pred_particles_sv = self.transition_model.function(prior,
time_interval=time_interval,
noise=True)

# Calculate probability of survival
log_prob_survive = -float(self.death_probability) * time_interval.total_seconds()

# Perform birth and update weights
if self.birth_scheme == SMCPHDBirthSchemeEnum.EXPANSION:
# Expansion birth scheme, as described in [1]
# Compute number of birth particles (J_k) as a fraction of the number of particles
num_birth = round(float(self.birth_probability) * num_samples)

# Sample birth particles
birth_particles = self.birth_sampler.sample(
params={self.birth_func_num_samples_field: num_birth}, timestamp=timestamp)
# Ensure birth weights are uniform and scaled by birth rate
birth_particles.log_weight = np.full((num_birth,), np.log(self.birth_rate / num_birth))

# Surviving particle weights
log_pred_weights = log_prob_survive + log_prior_weights

# Append birth particles to predicted ones
pred_particles_sv = StateVectors(
np.concatenate((pred_particles_sv, birth_particles.state_vector), axis=1))
log_pred_weights = np.concatenate((log_pred_weights, birth_particles.log_weight))
else:
# Flip a coin for each particle to decide if it gets replaced by a birth particle
birth_inds = np.flatnonzero(
np.random.binomial(1, float(self.birth_probability), num_samples)
)

# Sample birth particles and replace in original state vector matrix
num_birth = len(birth_inds)
birth_particles = self.birth_sampler.sample(
params={self.birth_func_num_samples_field: num_birth}, timestamp=timestamp)
# Replace particles in the state vector matrix
pred_particles_sv[:, birth_inds] = birth_particles.state_vector

# Process weights
prob_survive = np.exp(log_prob_survive)
birth_weight = self.birth_rate / num_samples
log_pred_weights = np.log(prob_survive + birth_weight) + log_prior_weights

prediction = Prediction.from_state(prior, state_vector=pred_particles_sv,
log_weight=log_pred_weights,
timestamp=timestamp, particle_list=None,
transition_model=self.transition_model)

return prediction
Loading