diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index e5bc43a762..068a2031c6 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -220,7 +220,7 @@ TEST(HybridBayesNet, Optimize) { /* ****************************************************************************/ // Test Bayes net error -TEST(HybridBayesNet, logProbability) { +TEST(HybridBayesNet, Pruning) { Switching s(3); HybridBayesNet::shared_ptr posterior = @@ -228,25 +228,22 @@ TEST(HybridBayesNet, logProbability) { EXPECT_LONGS_EQUAL(5, posterior->size()); HybridValues delta = posterior->optimize(); - auto actualTree = posterior->logProbability(delta.continuous()); + auto actualTree = posterior->evaluate(delta.continuous()); + // Regression test on density tree. std::vector discrete_keys = {{M(0), 2}, {M(1), 2}}; - std::vector leaves = {1.8101301, 3.0128899, 2.8784032, 2.9825507}; + std::vector leaves = {6.1112424, 20.346113, 17.785849, 19.738098}; AlgebraicDecisionTree expected(discrete_keys, leaves); - - // regression EXPECT(assert_equal(expected, actualTree, 1e-6)); - // logProbability on pruned Bayes net + // Prune and get probabilities auto prunedBayesNet = posterior->prune(2); - auto prunedTree = prunedBayesNet.logProbability(delta.continuous()); + auto prunedTree = prunedBayesNet.evaluate(delta.continuous()); - std::vector pruned_leaves = {2e50, 3.0128899, 2e50, 2.9825507}; + // Regression test on pruned logProbability tree + std::vector pruned_leaves = {0.0, 20.346113, 0.0, 19.738098}; AlgebraicDecisionTree expected_pruned(discrete_keys, pruned_leaves); - - // regression - // TODO(dellaert): fix pruning, I have no insight in this code. - // EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6)); + EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6)); // Verify logProbability computation and check specific logProbability value const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}}; @@ -261,8 +258,9 @@ TEST(HybridBayesNet, logProbability) { logProbability += posterior->at(4)->asDiscrete()->logProbability(hybridValues); - EXPECT_DOUBLES_EQUAL(logProbability, actualTree(discrete_values), 1e-9); - EXPECT_DOUBLES_EQUAL(logProbability, prunedTree(discrete_values), 1e-9); + double density = exp(logProbability); + EXPECT_DOUBLES_EQUAL(density, actualTree(discrete_values), 1e-9); + EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9); EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues), 1e-9); }