Skip to content

Commit

Permalink
Merge pull request #855 from timothy-glover/create_constrained_update…
Browse files Browse the repository at this point in the history
…r_and_regulariser

Introduce ability to constrain particle states in ParticleUpdater and MCMCRegulariser
  • Loading branch information
sdhiscocks authored Sep 29, 2023
2 parents f24090c + 5273f07 commit 900538e
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 9 deletions.
17 changes: 16 additions & 1 deletion stonesoup/regulariser/particle.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import numpy as np
from scipy.stats import multivariate_normal, uniform
from typing import Sequence
from typing import Sequence, Callable

from .base import Regulariser
from ..functions import cholesky_eps
Expand Down Expand Up @@ -32,6 +32,16 @@ class MCMCRegulariser(Regulariser):

transition_model: TransitionModel = Property(doc="Transition model used for prediction",
default=None)
constraint_func: Callable = Property(
default=None,
doc="Callable, user defined function for applying "
"constraints to particle states. This is done by reverting particles "
"that are moved to a state outside of the defined constraints "
"back to the state prior to the move step. Particle states that are "
"input are assumed to be constrained. This function provides indices "
"of the unconstrained particles and should accept a :class:`~.ParticleState` "
"object and return an array-like object of logical indices. "
)

def regularise(self, prior, posterior):
"""Regularise the particles
Expand Down Expand Up @@ -84,6 +94,11 @@ def regularise(self, prior, posterior):
moved_particles.state_vector = moved_particles.state_vector + \
hopt * cholesky_eps(covar_est) @ np.random.randn(ndim, nparticles)

# Apply constraints if defined
if self.constraint_func is not None:
part_indx = self.constraint_func(moved_particles)
moved_particles.state_vector[:, part_indx] = posterior.state_vector[:, part_indx]

# Evaluate likelihoods
part_diff = moved_particles.state_vector - transitioned_prior.state_vector
move_likelihood = multivariate_normal.logpdf(part_diff.T,
Expand Down
29 changes: 24 additions & 5 deletions stonesoup/regulariser/tests/test_particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,39 @@
from ..particle import MCMCRegulariser


def dummy_constraint_function(particles):
part_indx = particles.state_vector[1, :] > 20
return part_indx


@pytest.mark.parametrize(
"transition_model, model_flag",
"transition_model, model_flag, constraint_func",
[
(
CombinedLinearGaussianTransitionModel([ConstantVelocity([0.05])]), # transition_model
False, # model_flag
None # constraint_function
),
(
CombinedLinearGaussianTransitionModel([ConstantVelocity([0.05])]), # transition_model
True, # model_flag
None # constraint_function
),
(
None, # transition_model
False, # model_flag
None # constraint_function
),
(
CombinedLinearGaussianTransitionModel([ConstantVelocity([0.05])]), # transition_model
False, # model_flag
dummy_constraint_function # constraint_function
)
],
ids=["with_transition_model_init", "without_transition_model_init", "no_transition_model"]
ids=["with_transition_model_init", "without_transition_model_init", "no_transition_model",
"with_constraint_function"]
)
def test_regulariser(transition_model, model_flag):
def test_regulariser(transition_model, model_flag, constraint_func):
particles = ParticleState(state_vector=None, particle_list=[Particle(np.array([[10], [10]]),
1 / 9),
Particle(np.array([[10], [20]]),
Expand Down Expand Up @@ -78,9 +92,10 @@ def test_regulariser(transition_model, model_flag):
state_update.weight = np.array([1/6, 5/48, 5/48, 5/48, 5/48, 5/48, 5/48, 5/48, 5/48])

if model_flag:
regulariser = MCMCRegulariser()
regulariser = MCMCRegulariser(constraint_func=constraint_func)
else:
regulariser = MCMCRegulariser(transition_model=transition_model)
regulariser = MCMCRegulariser(transition_model=transition_model,
constraint_func=constraint_func)

# state check
new_particles = regulariser.regularise(prediction, state_update)
Expand All @@ -90,6 +105,10 @@ def test_regulariser(transition_model, model_flag):
assert any(new_particles.weight == state_update.weight)
# Check that the timestamp is the same
assert new_particles.timestamp == state_update.timestamp
# Check that moved particles have been reverted back to original states if constrained
if constraint_func is not None:
indx = constraint_func(prediction) # likely unconstrained particles
assert np.all(new_particles.state_vector[:, indx] == prediction.state_vector[:, indx])

# list check3
with pytest.raises(TypeError) as e:
Expand Down
16 changes: 16 additions & 0 deletions stonesoup/updater/particle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
from functools import lru_cache
from typing import Callable

import numpy as np
from scipy.linalg import inv
Expand Down Expand Up @@ -32,6 +33,16 @@ class ParticleUpdater(Updater):
regulariser: Regulariser = Property(default=None, doc="Regulariser to prevent particle "
"impoverishment")

constraint_func: Callable = Property(
default=None,
doc="Callable, user defined function for applying "
"constraints to the states. This is done by setting the weights "
"of particles to 0 for particles that are not correctly constrained. "
"This function provides indices of the unconstrained particles and "
"should accept a :class:`~.ParticleState` object and return an array-like "
"object of logical indices. "
)

def update(self, hypothesis, **kwargs):
"""Particle Filter update step
Expand Down Expand Up @@ -61,6 +72,11 @@ def update(self, hypothesis, **kwargs):
new_weight = predicted_state.log_weight + measurement_model.logpdf(
hypothesis.measurement, predicted_state, **kwargs)

# Apply constraints if defined
if self.constraint_func is not None:
part_indx = self.constraint_func(predicted_state)
new_weight[part_indx] = -1*np.inf

# Normalise the weights
new_weight -= logsumexp(new_weight)

Expand Down
19 changes: 16 additions & 3 deletions stonesoup/updater/tests/test_particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,17 @@
from ...sampler.detection import SwitchingDetectionSampler, GaussianDetectionParticleSampler


def dummy_constraint_function(particles):
part_indx = particles.state_vector[1, :] > 20
return part_indx


@pytest.fixture(params=(
ParticleUpdater,
partial(ParticleUpdater, resampler=SystematicResampler()),
GromovFlowParticleUpdater,
GromovFlowKalmanParticleUpdater))
GromovFlowKalmanParticleUpdater,
partial(ParticleUpdater, constraint_func=dummy_constraint_function)))
def updater(request):
updater_class = request.param
measurement_model = LinearGaussian(
Expand Down Expand Up @@ -71,12 +77,19 @@ def test_particle(updater):
# Don't know what the particles will exactly be due to randomness so check
# some obvious properties

assert np.all(weight == 1 / 9 for weight in updated_state.weight)
if hasattr(updater, 'constraint_func') and updater.constraint_func is not None:
indx = dummy_constraint_function(prediction)
assert np.all(updated_state.weight[indx] == 0)

assert np.isclose(np.sum(updated_state.weight.astype(np.float_)), 1.0, rtol=1e-5)
assert updated_state.timestamp == timestamp
assert updated_state.hypothesis.measurement_prediction == measurement_prediction
assert updated_state.hypothesis.prediction == prediction
assert updated_state.hypothesis.measurement == measurement
assert np.allclose(updated_state.mean, StateVectors([[20.0], [20.0]]), rtol=2e-2)
if hasattr(updater, 'constraint_func') and updater.constraint_func is not None:
assert np.allclose(updated_state.mean, StateVectors([[20.0], [15.0]]), rtol=2e-2)
else:
assert np.allclose(updated_state.mean, StateVectors([[20.0], [20.0]]), rtol=2e-2)


def test_bernoulli_particle():
Expand Down

0 comments on commit 900538e

Please sign in to comment.