Skip to content

Commit

Permalink
Updating SquaredGaussianHellinger for catch when negative, and checki…
Browse files Browse the repository at this point in the history
…ng attribute to use state.mean
  • Loading branch information
jwragg-dstl committed Jul 1, 2022
1 parent a4c0cc3 commit 261a1cf
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions stonesoup/measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,18 +253,28 @@ def __call__(self, state1, state2):
"""
if hasattr(state1, 'mean'):
state_vector1 = state1.mean
else:
state_vector1 = state1.state_vector

if hasattr(state2, 'mean'):
state_vector2 = state2.mean
else:
state_vector2 = state2.state_vector

if self.mapping is not None:
mu1 = state1.state_vector[self.mapping, :]
mu2 = state2.state_vector[self.mapping2, :]
mu1 = state_vector1[self.mapping, :]
mu2 = state_vector2[self.mapping2, :]

# extract the mapped covariance data
rows = np.array(self.mapping, dtype=np.intp)
columns = np.array(self.mapping, dtype=np.intp)
sigma1 = state1.covar[rows[:, np.newaxis], columns]
sigma2 = state2.covar[rows[:, np.newaxis], columns]
else:
mu1 = state1.state_vector
mu2 = state2.state_vector
mu1 = state_vector1
mu2 = state_vector2
sigma1 = state1.covar
sigma2 = state2.covar

Expand All @@ -276,6 +286,12 @@ def __call__(self, state1, state2):
denominator = np.linalg.det(sigma1_plus_sigma2/2)
squared_hellinger = 1 - np.sqrt(numerator/denominator)*np.exp(epsilon)
squared_hellinger = squared_hellinger.item()

if -1e-10 < squared_hellinger < 0.0:
squared_hellinger = 0.0
elif squared_hellinger < 0.0: # pragma: no cover
raise ValueError("Measure shouldn't be less than 0") # this should be impossible

return squared_hellinger


Expand Down

0 comments on commit 261a1cf

Please sign in to comment.