diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 87fdca72c9..9c53b3b70f 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -92,12 +92,28 @@ class DiscreteBayesNet { }; #include +class DiscreteBayesTreeClique { + DiscreteBayesTreeClique(); + DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional); + const gtsam::DiscreteConditional* conditional() const; + bool isRoot() const; + void printSignature( + const string& s = "Clique: ", + const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; + double evaluate(const gtsam::DiscreteValues& values) const; +}; + class DiscreteBayesTree { DiscreteBayesTree(); void print(string s = "DiscreteBayesTree\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const; + + size_t size() const; + bool empty() const; + const DiscreteBayesTreeClique* operator[](size_t j) const; + string dot(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; void saveGraph(string s, diff --git a/python/gtsam/tests/test_DiscreteBayesTree.dot b/python/gtsam/tests/test_DiscreteBayesTree.dot new file mode 100644 index 0000000000..d7cf7d9bc0 --- /dev/null +++ b/python/gtsam/tests/test_DiscreteBayesTree.dot @@ -0,0 +1,25 @@ +digraph G{ +0[label="8,12,14"]; +0->1 +1[label="0 : 8,12"]; +0->2 +2[label="1 : 8,12"]; +0->3 +3[label="9 : 12,14"]; +3->4 +4[label="2 : 9,12"]; +3->5 +5[label="3 : 9,12"]; +0->6 +6[label="10,13 : 14"]; +6->7 +7[label="4 : 10,13"]; +6->8 +8[label="5 : 10,13"]; +6->9 +9[label="11 : 13,14"]; +9->10 +10[label="6 : 11,13"]; +9->11 +11[label="7 : 11,13"]; +} \ No newline at end of file diff --git a/python/gtsam/tests/test_DiscreteBayesTree.py b/python/gtsam/tests/test_DiscreteBayesTree.py new file mode 100644 index 0000000000..d87734de99 --- /dev/null +++ b/python/gtsam/tests/test_DiscreteBayesTree.py @@ -0,0 +1,89 @@ +""" +GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Bayes trees. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique, + DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, + Ordering) +from gtsam.utils.test_case import GtsamTestCase + + +def P(*args): + """ Create a DiscreteKeys instances from a variable number of DiscreteKey pairs.""" + # TODO: We can make life easier by providing variable argument functions in C++ itself. + dks = DiscreteKeys() + for key in args: + dks.push_back(key) + return dks + + +class TestDiscreteBayesNet(GtsamTestCase): + """Tests for Discrete Bayes Nets.""" + + def test_elimination(self): + """Test Multifrontal elimination.""" + + # Define DiscreteKey pairs. + keys = [(j, 2) for j in range(15)] + + # Create thin-tree Bayesnet. + bayesNet = DiscreteBayesNet() + + bayesNet.add(keys[0], P(keys[8], keys[12]), "2/3 1/4 3/2 4/1") + bayesNet.add(keys[1], P(keys[8], keys[12]), "4/1 2/3 3/2 1/4") + bayesNet.add(keys[2], P(keys[9], keys[12]), "1/4 8/2 2/3 4/1") + bayesNet.add(keys[3], P(keys[9], keys[12]), "1/4 2/3 3/2 4/1") + + bayesNet.add(keys[4], P(keys[10], keys[13]), "2/3 1/4 3/2 4/1") + bayesNet.add(keys[5], P(keys[10], keys[13]), "4/1 2/3 3/2 1/4") + bayesNet.add(keys[6], P(keys[11], keys[13]), "1/4 3/2 2/3 4/1") + bayesNet.add(keys[7], P(keys[11], keys[13]), "1/4 2/3 3/2 4/1") + + bayesNet.add(keys[8], P(keys[12], keys[14]), "T 1/4 3/2 4/1") + bayesNet.add(keys[9], P(keys[12], keys[14]), "4/1 2/3 F 1/4") + bayesNet.add(keys[10], P(keys[13], keys[14]), "1/4 3/2 2/3 4/1") + bayesNet.add(keys[11], P(keys[13], keys[14]), "1/4 2/3 3/2 4/1") + + bayesNet.add(keys[12], P(keys[14]), "3/1 3/1") + bayesNet.add(keys[13], P(keys[14]), "1/3 3/1") + + bayesNet.add(keys[14], P(), "1/3") + + # Create a factor graph out of the Bayes net. + factorGraph = DiscreteFactorGraph(bayesNet) + + # Create a BayesTree out of the factor graph. + ordering = Ordering() + for j in range(15): + ordering.push_back(j) + bayesTree = factorGraph.eliminateMultifrontal(ordering) + + # Uncomment these for visualization: + # print(bayesTree) + # for key in range(15): + # bayesTree[key].printSignature() + # bayesTree.saveGraph("test_DiscreteBayesTree.dot") + + self.assertFalse(bayesTree.empty()) + self.assertEqual(12, bayesTree.size()) + + # The root is P( 8 12 14), we can retrieve it by key: + root = bayesTree[8] + self.assertIsInstance(root, DiscreteBayesTreeClique) + self.assertTrue(root.isRoot()) + self.assertIsInstance(root.conditional(), DiscreteConditional) + + +if __name__ == "__main__": + unittest.main()