Skip to content

Commit

Permalink
Merge pull request #1080 from A-acuto/new_pf_proposal
Browse files Browse the repository at this point in the history
Particle filter proposal implementation
  • Loading branch information
sdhiscocks authored Oct 1, 2024
2 parents aa609b0 + daae749 commit a004fd9
Show file tree
Hide file tree
Showing 8 changed files with 330 additions and 20 deletions.
14 changes: 14 additions & 0 deletions docs/source/stonesoup.proposal.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
Proposal
===========

.. automodule:: stonesoup.proposal
:no-members:

.. automodule:: stonesoup.proposal.base
:show-inheritance:

Simple
------

.. automodule:: stonesoup.proposal.simple
:show-inheritance:
1 change: 1 addition & 0 deletions docs/source/stonesoup.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ Algorithm Components
stonesoup.mixturereducer
stonesoup.models
stonesoup.predictor
stonesoup.proposal
stonesoup.regulariser
stonesoup.resampler
stonesoup.sampler
Expand Down
2 changes: 1 addition & 1 deletion stonesoup/hypothesiser/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def hypothesise(self, track, detections, timestamp, **kwargs):

# Re-evaluate prediction
prediction = self.predictor.predict(
track, timestamp=detection.timestamp, **kwargs)
track, timestamp=detection.timestamp, measurement=detection, **kwargs)

# Compute measurement prediction and distance measure
measurement_prediction = self.updater.predict_measurement(
Expand Down
35 changes: 22 additions & 13 deletions stonesoup/predictor/particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from .kalman import KalmanPredictor, ExtendedKalmanPredictor
from ..base import Property
from ..models.transition import TransitionModel
from ..proposal.base import Proposal
from ..proposal.simple import PriorAsProposal
from ..sampler.particle import ParticleSampler
from ..types.numeric import Probability
from ..types.prediction import Prediction
Expand All @@ -25,9 +27,18 @@ class ParticlePredictor(Predictor):
An implementation of a Particle Filter predictor.
"""
proposal: Proposal = Property(
default=None,
doc="A proposal object that generates samples from the proposal distribution. If `None`,"
"the transition model is used to generate samples.")

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.proposal is None:
self.proposal = PriorAsProposal(self.transition_model)

@predict_lru_cache()
def predict(self, prior, timestamp=None, **kwargs):
def predict(self, prior, timestamp=None, measurement=None, **kwargs):
"""Particle Filter prediction step
Parameters
Expand All @@ -37,31 +48,29 @@ def predict(self, prior, timestamp=None, **kwargs):
timestamp: :class:`datetime.datetime`, optional
A timestamp signifying when the prediction is performed
(the default is `None`)
measurement: :class:`~.Detection`, optional
measurement used in the Kalman Filter proposal to update
the prediction
(the default is `None`)
Returns
-------
: :class:`~.ParticleStatePrediction`
The predicted state
"""

# Compute time_interval
try:
time_interval = timestamp - prior.timestamp
except TypeError:
# TypeError: (timestamp or prior.timestamp) is None
time_interval = None

new_state_vector = self.transition_model.function(
prior,
noise=True,
time_interval=time_interval,
**kwargs)

return Prediction.from_state(prior,
parent=prior,
state_vector=new_state_vector,
timestamp=timestamp,
transition_model=self.transition_model,
prior=prior)
return self.proposal.rvs(prior,
noise=True,
time_interval=time_interval,
measurement=measurement,
**kwargs)


class ParticleFlowKalmanPredictor(ParticlePredictor):
Expand Down
16 changes: 16 additions & 0 deletions stonesoup/proposal/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from abc import abstractmethod
from stonesoup.base import Base


class Proposal(Base):

@abstractmethod
def rvs(self, *args, **kwargs):
r"""Proposal noise/sample generation function
Generates samples from the proposal.
Parameters
----------
state: :class:`~.State`
The state to generate samples from.
"""
raise NotImplementedError
140 changes: 140 additions & 0 deletions stonesoup/proposal/simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from typing import Union

import numpy as np
from scipy.stats import multivariate_normal as mvn

from stonesoup.base import Property
from stonesoup.models.transition import TransitionModel
from stonesoup.proposal.base import Proposal
from stonesoup.types.array import StateVector, StateVectors
from stonesoup.types.detection import Detection
from stonesoup.types.state import State, GaussianState, SqrtGaussianState
from stonesoup.types.prediction import Prediction
from stonesoup.updater.base import Updater
from stonesoup.predictor.base import Predictor
from stonesoup.predictor.kalman import SqrtKalmanPredictor
from stonesoup.types.hypothesis import SingleHypothesis


class PriorAsProposal(Proposal):
"""Proposal that uses the dynamics model as the importance density.
This proposal uses the dynamics model to predict the next state, and then
uses the predicted state as the prior for the measurement model.
"""
transition_model: TransitionModel = Property(
doc="The transition model used to make the prediction")

def rvs(self, prior: State, measurement=None, time_interval=None,
**kwargs) -> Union[StateVector, StateVectors]:
"""Generate samples from the proposal.
Parameters
----------
state: :class:`~.State`
The state to generate samples from.
Returns
-------
: :class:`~.ParticlePrediction` with samples drawn from the updated proposal
"""

if measurement is not None:
timestamp = measurement.timestamp
time_interval = measurement.timestamp - prior.timestamp
else:
timestamp = prior.timestamp + time_interval

new_state_vector = self.transition_model.function(prior,
time_interval=time_interval,
**kwargs)
return Prediction.from_state(prior,
parent=prior,
state_vector=new_state_vector,
timestamp=timestamp,
transition_model=self.transition_model,
prior=prior)


class KFasProposal(Proposal):
"""This proposal uses the kalman filter prediction and update steps to
generate new set of particles and weights
"""
predictor: Predictor = Property(
doc="predictor to use the various values")
updater: Updater = Property(
doc="Updater used for update the values")

def rvs(self, prior: State, measurement: Detection = None, time_interval=None,
**kwargs):
"""Generate samples from the proposal.
Use the kalman filter predictor and updater to create a new distribution
Parameters
----------
state: :class:`~.State`
The state to generate samples from.
measurement: :class:`~.Detection`
the measurement that is used in the update step of the kalman prediction,
(the default is `None`)
time_interval: :class:`datetime.time_delta`
time interval of the prediction is needed to propagate the states
Returns
-------
: :class:`~.ParticlePrediction`
"""

if measurement is not None:
timestamp = measurement.timestamp
time_interval = measurement.timestamp - prior.timestamp
else:
timestamp = prior.timestamp + time_interval

if time_interval.total_seconds() == 0:
return Prediction.from_state(prior,
parent=prior,
state_vector=prior.state_vector,
timestamp=prior.timestamp,
transition_model=self.predictor.transition_model,
prior=prior)

prior_cls = GaussianState # Default
if isinstance(self.predictor, SqrtKalmanPredictor):
prior_cls = SqrtGaussianState

# Null covariance for the particles
null_covar = np.zeros_like(prior.covar)

predictions = [
self.predictor.predict(
prior_cls(particle_sv, null_covar, prior.timestamp),
timestamp=timestamp)
for particle_sv in prior.state_vector]

if measurement is not None:
updates = [self.updater.update(SingleHypothesis(prediction, measurement))
for prediction in predictions]
else:
updates = predictions # keep the prediction

# Draw the samples
samples = np.array([state.state_vector.reshape(-1) +
mvn.rvs(cov=state.covar).T
for state in updates])

# Compute the log of q(x_k|x_{k-1}, y_k)
post_log_weights = np.array([mvn.logpdf(sample - update.state_vector.reshape(-1),
cov=update.covar)
for sample, update in zip(samples, updates)])

pred_state = Prediction.from_state(prior,
parent=prior,
state_vector=StateVectors(samples.T),
timestamp=timestamp,
transition_model=self.predictor.transition_model,
prior=prior)

prior_log_weights = self.predictor.transition_model.logpdf(pred_state, prior,
time_interval=time_interval)

pred_state.log_weight = (pred_state.log_weight + prior_log_weights - post_log_weights)

return pred_state
129 changes: 129 additions & 0 deletions stonesoup/proposal/tests/test_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import itertools

import datetime
import numpy as np

# Import the proposals
from stonesoup.proposal.simple import PriorAsProposal, KFasProposal
from stonesoup.models.transition.linear import ConstantVelocity
from stonesoup.types.particle import Particle
from stonesoup.types.prediction import ParticleStatePrediction
from stonesoup.predictor.kalman import KalmanPredictor
from stonesoup.updater.kalman import KalmanUpdater
from stonesoup.types.state import ParticleState, GaussianState
from stonesoup.predictor.particle import ParticlePredictor
from stonesoup.types.detection import Detection
from stonesoup.models.measurement.linear import LinearGaussian
from stonesoup.types.hypothesis import SingleHypothesis


def test_prior_proposal():
# test that the proposal as prior and basic PF implementation
# yield same results, since they are driven by the transition model

# Initialise a transition model
cv = ConstantVelocity(noise_diff_coeff=0)

# Define time related variables
timestamp = datetime.datetime.now()
timediff = 2 # 2 sec
new_timestamp = timestamp + datetime.timedelta(seconds=timediff)
time_interval = new_timestamp - timestamp

num_particles = 9 # Number of particles

# Define prior state
prior_particles = [Particle(np.array([[i], [j]]), 1/num_particles)
for i, j in itertools.product([10, 20, 30], [10, 20, 30])]
prior = ParticleState(None, particle_list=prior_particles, timestamp=timestamp)

# predictors prior and standard stone soup
predictor_prior = ParticlePredictor(cv,
proposal=PriorAsProposal(cv))

# Check that the predictor without prior specified works with the prior as
# proposal
predictor_base = ParticlePredictor(cv)

# basic transition model evaluations
eval_particles = [Particle(cv.matrix(timestamp=new_timestamp,
time_interval=time_interval)
@ particle.state_vector,
1 / 9)
for particle in prior_particles]
eval_mean = np.mean(np.hstack([i.state_vector for i in eval_particles]),
axis=1).reshape(2, 1)

# construct the evaluation prediction
eval_prediction = ParticleStatePrediction(None, new_timestamp, particle_list=eval_particles)

prediction_base = predictor_base.predict(prior, timestamp=new_timestamp)
prediction_prior = predictor_prior.predict(prior, timestamp=new_timestamp)

assert np.all([eval_prediction.state_vector[:, i] ==
prediction_base.state_vector[:, i] for i in range(9)])
assert np.all([prediction_base.weight[i] == 1 / 9 for i in range(9)])

assert np.allclose(prediction_prior.mean, eval_mean)
assert prediction_prior.timestamp == new_timestamp
assert np.all([eval_prediction.state_vector[:, i] ==
prediction_prior.state_vector[:, i] for i in range(9)])
assert np.all([prediction_prior.weight[i] == 1 / 9 for i in range(9)])


def test_kf_proposal():

# Initialise a transition model
cv = ConstantVelocity(noise_diff_coeff=1)

# initialise the measurement model
lg = LinearGaussian(ndim_state=2,
mapping=[0],
noise_covar=np.diag([1]))

# Define time related variables
timestamp = datetime.datetime.now()
timediff = 2 # 2 sec
new_timestamp = timestamp + datetime.timedelta(seconds=timediff)
time_interval = new_timestamp - timestamp

num_particles = 9 # Number of particles

# Define prior state
prior_particles = [Particle(np.array([[i], [j]]), 1/num_particles)
for i, j in itertools.product([1, 2, 3], [1, 2, 3])]

prior = ParticleState(None, particle_list=prior_particles, timestamp=timestamp)

# null covariance for the predictions
null_covar = np.zeros_like(prior.covar)
prior_kf = GaussianState(prior.mean, null_covar, prior.timestamp)

# Kalman filter components
kf_predictor = KalmanPredictor(cv)
kf_updater = KalmanUpdater(lg)

# perform the kalman filter update
prediction = kf_predictor.predict(prior_kf, timestamp=new_timestamp)

# state prediction
new_state = GaussianState(state_vector=cv.function(prior_kf, noise=True,
time_interval=time_interval),
covar=np.diag([1, 1]),
timestamp=new_timestamp)

detection = Detection(lg.function(new_state,
noise=True),
timestamp=new_timestamp,
measurement_model=lg)

eval_state = kf_updater.update(SingleHypothesis(prediction, detection))

proposal = KFasProposal(KalmanPredictor(cv),
KalmanUpdater(lg))
# particle proposal
particle_proposal = proposal.rvs(prior, measurement=detection, time_interval=time_interval)

assert particle_proposal.state_vector.shape == prior.state_vector.shape
assert np.allclose(particle_proposal.mean, eval_state.state_vector, rtol=1)
assert particle_proposal.timestamp == new_timestamp
Loading

0 comments on commit a004fd9

Please sign in to comment.