diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index 13e84c5b87..d725ceac87 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -14,7 +14,7 @@ import unittest import numpy as np -from gtsam import DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol +from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol from gtsam.utils.test_case import GtsamTestCase OrderingType = Ordering.OrderingType @@ -250,7 +250,7 @@ def make_key(character, index, cardinality): # Ensure that the stationary distribution is positive and normalized stationary_dist /= np.sum(stationary_dist) - expected = stationary_dist.flatten() + expected = DecisionTreeFactor(X[chain_length-1], stationary_dist.flatten()) # The transition matrix parsed by DiscreteConditional is a row-wise CPT transitions = transitions.T @@ -271,16 +271,8 @@ def make_key(character, index, cardinality): # Get the DiscreteConditional representing the marginal on the last factor last_marginal = sum_product.at(chain_length - 1) - # Extract the actual marginal probabilities - assignment = DiscreteValues() - marginal_probs = [] - for i in range(num_states): - assignment[X[chain_length - 1][0]] = i - marginal_probs.append(last_marginal(assignment)) - marginal_probs = np.array(marginal_probs) - # Ensure marginal probabilities are close to the stationary distribution - self.gtsamAssertEquals(expected, marginal_probs) + self.gtsamAssertEquals(expected, last_marginal) if __name__ == "__main__": unittest.main()