Skip to content

Commit

Permalink
Simply sum-product test in Python.
Browse files Browse the repository at this point in the history
  • Loading branch information
keevindoherty committed Feb 12, 2023
1 parent 9fa2d30 commit 92443f5
Showing 1 changed file with 3 additions and 11 deletions.
14 changes: 3 additions & 11 deletions python/gtsam/tests/test_DiscreteFactorGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import unittest

import numpy as np
from gtsam import DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol
from gtsam.utils.test_case import GtsamTestCase

OrderingType = Ordering.OrderingType
Expand Down Expand Up @@ -250,7 +250,7 @@ def make_key(character, index, cardinality):

# Ensure that the stationary distribution is positive and normalized
stationary_dist /= np.sum(stationary_dist)
expected = stationary_dist.flatten()
expected = DecisionTreeFactor(X[chain_length-1], stationary_dist.flatten())

# The transition matrix parsed by DiscreteConditional is a row-wise CPT
transitions = transitions.T
Expand All @@ -271,16 +271,8 @@ def make_key(character, index, cardinality):
# Get the DiscreteConditional representing the marginal on the last factor
last_marginal = sum_product.at(chain_length - 1)

# Extract the actual marginal probabilities
assignment = DiscreteValues()
marginal_probs = []
for i in range(num_states):
assignment[X[chain_length - 1][0]] = i
marginal_probs.append(last_marginal(assignment))
marginal_probs = np.array(marginal_probs)

# Ensure marginal probabilities are close to the stationary distribution
self.gtsamAssertEquals(expected, marginal_probs)
self.gtsamAssertEquals(expected, last_marginal)

if __name__ == "__main__":
unittest.main()

0 comments on commit 92443f5

Please sign in to comment.