diff --git a/stonesoup/measures.py b/stonesoup/measures.py index 4ee890705..f67ab7ef1 100644 --- a/stonesoup/measures.py +++ b/stonesoup/measures.py @@ -253,9 +253,19 @@ def __call__(self, state1, state2): """ + if hasattr(state1, 'mean'): + state_vector1 = state1.mean + else: + state_vector1 = state1.state_vector + + if hasattr(state2, 'mean'): + state_vector2 = state2.mean + else: + state_vector2 = state2.state_vector + if self.mapping is not None: - mu1 = state1.state_vector[self.mapping, :] - mu2 = state2.state_vector[self.mapping2, :] + mu1 = state_vector1[self.mapping, :] + mu2 = state_vector2[self.mapping2, :] # extract the mapped covariance data rows = np.array(self.mapping, dtype=np.intp) @@ -263,8 +273,8 @@ def __call__(self, state1, state2): sigma1 = state1.covar[rows[:, np.newaxis], columns] sigma2 = state2.covar[rows[:, np.newaxis], columns] else: - mu1 = state1.state_vector - mu2 = state2.state_vector + mu1 = state_vector1 + mu2 = state_vector2 sigma1 = state1.covar sigma2 = state2.covar @@ -276,6 +286,12 @@ def __call__(self, state1, state2): denominator = np.linalg.det(sigma1_plus_sigma2/2) squared_hellinger = 1 - np.sqrt(numerator/denominator)*np.exp(epsilon) squared_hellinger = squared_hellinger.item() + + if -1e-10 < squared_hellinger < 0.0: + squared_hellinger = 0.0 + elif squared_hellinger < 0.0: # pragma: no cover + raise ValueError("Measure shouldn't be less than 0") # this should be impossible + return squared_hellinger