forked from dstl/Stone-Soup
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_particle.py
161 lines (146 loc) · 8.84 KB
/
test_particle.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import numpy as np
import datetime
import pytest
from ...types.state import ParticleState
from ...types.particle import Particle
from ...types.hypothesis import SingleHypothesis
from ...types.prediction import ParticleStatePrediction, ParticleMeasurementPrediction
from ...models.measurement.linear import LinearGaussian
from ...models.transition.linear import CombinedLinearGaussianTransitionModel, ConstantVelocity
from ...types.detection import Detection
from ...types.update import Update, ParticleStateUpdate
from ..particle import MCMCRegulariser
def dummy_constraint_function(particles):
    """Flag particles for the regulariser's constraint test.

    A particle is flagged when the entry in row 1 of its column of the state vector
    is strictly greater than 20.

    :param particles: object exposing a 2-D ``state_vector`` whose columns are the
        individual particles.
    :return: boolean array with one entry per particle column, ``True`` where flagged.
    """
    second_component = particles.state_vector[1, :]
    flagged = second_component > 20
    return flagged
@pytest.mark.parametrize(
    "transition_model, model_flag, constraint_func",
    [
        (
            CombinedLinearGaussianTransitionModel([ConstantVelocity([0.05])]),  # transition_model
            False,  # model_flag
            None  # constraint_function
        ),
        (
            CombinedLinearGaussianTransitionModel([ConstantVelocity([0.05])]),  # transition_model
            True,  # model_flag
            None  # constraint_function
        ),
        (
            None,  # transition_model
            False,  # model_flag
            None  # constraint_function
        ),
        (
            CombinedLinearGaussianTransitionModel([ConstantVelocity([0.05])]),  # transition_model
            False,  # model_flag
            dummy_constraint_function  # constraint_function
        )
    ],
    ids=["with_transition_model_init", "without_transition_model_init", "no_transition_model",
         "with_constraint_function"]
)
def test_regulariser(transition_model, model_flag, constraint_func):
    """Exercise ``MCMCRegulariser.regularise`` on a 3x3 grid of particles.

    :param transition_model: model used to propagate the particles one second forward
        (``None`` skips propagation). Unless ``model_flag`` is set, it is also passed to
        the regulariser.
    :param model_flag: when ``True`` the regulariser is built without a transition model.
        The prediction still carries one, which the regulariser presumably falls back
        to (TODO confirm).
    :param constraint_func: optional function forwarded to the regulariser as
        ``constraint_func``.
    """
    # Nine equally weighted particles on a 3x3 grid: every combination of
    # first component in (10, 20, 30) and second component in (10, 20, 30).
    particles = ParticleState(
        state_vector=None,
        particle_list=[Particle(np.array([[first], [second]]), 1 / 9)
                       for first in (10, 20, 30)
                       for second in (10, 20, 30)])
    timestamp = datetime.datetime.now()
    if transition_model is not None:
        new_state_vector = transition_model.function(particles,
                                                     noise=True,
                                                     time_interval=datetime.timedelta(seconds=1))
    else:
        new_state_vector = particles.state_vector
    prediction = ParticleStatePrediction(new_state_vector,
                                         timestamp=timestamp,
                                         transition_model=transition_model)
    measurement_model = LinearGaussian(ndim_state=2, mapping=(0, 1), noise_covar=np.eye(2))
    measurement = Detection(state_vector=np.array([[5], [7]]),
                            timestamp=timestamp, measurement_model=measurement_model)
    hypothesis = SingleHypothesis(prediction=prediction,
                                  measurement=measurement,
                                  measurement_prediction=None)
    # The posterior is built directly from the prediction rather than by running an
    # updater: inside the updater the regulariser is applied before the updated state
    # has taken its final update type.
    state_update = Update.from_state(state=prediction,
                                     hypothesis=hypothesis,
                                     timestamp=timestamp+datetime.timedelta(seconds=1))
    # Non-uniform weights (they sum to 1).
    state_update.weight = np.array([1/6, 5/48, 5/48, 5/48, 5/48, 5/48, 5/48, 5/48, 5/48])
    if model_flag:
        regulariser = MCMCRegulariser(constraint_func=constraint_func)
    else:
        regulariser = MCMCRegulariser(transition_model=transition_model,
                                      constraint_func=constraint_func)
    # state check
    new_particles = regulariser.regularise(prediction, state_update)
    # Check the shape of the new state vector
    assert new_particles.state_vector.shape == state_update.state_vector.shape
    # Check weights are unchanged. Every weight must match, not just one of them
    # (the previous builtin ``any`` passed on a single match). Cast to float and
    # compare with a tolerance so exp/log round-off cannot cause a spurious failure.
    assert np.allclose(np.asarray(new_particles.weight, dtype=float),
                       np.asarray(state_update.weight, dtype=float))
    # Check that the timestamp is the same
    assert new_particles.timestamp == state_update.timestamp
    # Check that moved particles have been reverted back to original states if constrained
    if constraint_func is not None:
        # Mask of prediction particles flagged by the constraint function. Their
        # regularised states are expected to equal the prior states.
        # NOTE(review): this assumes a particle flagged here cannot be accepted after
        # being moved. Confirm against MCMCRegulariser's constraint handling.
        indx = constraint_func(prediction)
        assert np.all(new_particles.state_vector[:, indx] == prediction.state_vector[:, indx])
    # Type checks: inputs that are not ParticleState instances must be rejected
    with pytest.raises(TypeError, match="Only ParticleState type is supported!"):
        regulariser.regularise(particles.particle_list, state_update)
    with pytest.raises(Exception, match="Only ParticleState type is supported!"):
        regulariser.regularise(particles, state_update.particle_list)
def test_no_measurement():
    """Regularise a posterior whose hypothesis carries no measurement.

    Checks that the returned state keeps the posterior's state-vector shape, weights
    and timestamp.
    """
    # Nine equally weighted particles on a 3x3 grid: every combination of
    # first component in (10, 20, 30) and second component in (10, 20, 30).
    particles = ParticleState(
        state_vector=None,
        particle_list=[Particle(np.array([[first], [second]]), 1 / 9)
                       for first in (10, 20, 30)
                       for second in (10, 20, 30)])
    timestamp = datetime.datetime.now()
    prediction = ParticleStatePrediction(None, particle_list=particles.particle_list,
                                         timestamp=timestamp)
    # Given the list of Particle objects, consistent with the prediction above and the
    # update below (previously the ParticleState itself was passed here).
    meas_pred = ParticleMeasurementPrediction(None, particle_list=particles.particle_list,
                                              timestamp=timestamp)
    state_update = ParticleStateUpdate(None, SingleHypothesis(prediction=prediction,
                                                              measurement=None,
                                                              measurement_prediction=meas_pred),
                                       particle_list=particles.particle_list, timestamp=timestamp)
    regulariser = MCMCRegulariser()
    new_particles = regulariser.regularise(particles, state_update)
    # Check the shape of the new state vector
    assert new_particles.state_vector.shape == state_update.state_vector.shape
    # Check weights are unchanged. Every weight must match, not just one of them
    # (the previous builtin ``any`` passed on a single match).
    assert np.allclose(np.asarray(new_particles.weight, dtype=float),
                       np.asarray(state_update.weight, dtype=float))
    # Check that the timestamp is the same
    assert new_particles.timestamp == state_update.timestamp