From 9eea6cf21af426598b41a72bed8bc8022bf9149c Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 25 Jan 2022 17:15:52 -0500 Subject: [PATCH 1/3] Added sumProduct as a convenient alias --- gtsam/discrete/DiscreteFactorGraph.cpp | 17 +++++++++++++++ gtsam/discrete/DiscreteFactorGraph.h | 21 +++++++++++++++++-- .../tests/testDiscreteFactorGraph.cpp | 10 +++++++++ 3 files changed, 46 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index f8e1b4bb89..b4b65f885d 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -144,6 +144,23 @@ namespace gtsam { boost::dynamic_pointer_cast(lookup), max); } + /* ************************************************************************ */ + // sumProduct is just an alias for regular eliminateSequential. + DiscreteBayesNet DiscreteFactorGraph::sumProduct( + OptionalOrderingType orderingType) const { + gttic(DiscreteFactorGraph_sumProduct); + auto bayesNet = BaseEliminateable::eliminateSequential(orderingType); + return *bayesNet; + } + + DiscreteLookupDAG DiscreteFactorGraph::sumProduct( + const Ordering& ordering) const { + gttic(DiscreteFactorGraph_sumProduct); + auto bayesNet = + BaseEliminateable::eliminateSequential(ordering, EliminateForMPE); + return DiscreteLookupDAG::FromBayesNet(*bayesNet); + } + /* ************************************************************************ */ // The max-product solution below is a bit clunky: the elimination machinery // does not allow for differently *typed* versions of elimination, so we diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 1ba39ff9d0..2e9b40823f 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -132,11 +132,28 @@ class GTSAM_EXPORT DiscreteFactorGraph const std::string& s = "DiscreteFactorGraph", const KeyFormatter& formatter = DefaultKeyFormatter) const override; + /** + * @brief Implement the sum-product algorithm + * + * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM + * @return DiscreteBayesNet encoding posterior P(X|Z) + */ + DiscreteBayesNet sumProduct( + OptionalOrderingType orderingType = boost::none) const; + + /** + * @brief Implement the sum-product algorithm + * + * @param ordering + * @return DiscreteBayesNet encoding posterior P(X|Z) + */ + DiscreteLookupDAG sumProduct(const Ordering& ordering) const; + /** * @brief Implement the max-product algorithm * * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM - * @return DiscreteLookupDAG::shared_ptr DAG with lookup tables + * @return DiscreteLookupDAG DAG with lookup tables */ DiscreteLookupDAG maxProduct( OptionalOrderingType orderingType = boost::none) const; @@ -145,7 +162,7 @@ class GTSAM_EXPORT DiscreteFactorGraph * @brief Implement the max-product algorithm * * @param ordering - * @return DiscreteLookupDAG::shared_ptr `DAG with lookup tables + * @return DiscreteLookupDAG `DAG with lookup tables */ DiscreteLookupDAG maxProduct(const Ordering& ordering) const; diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index f4819dab54..63f5b73194 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -154,6 +154,16 @@ TEST(DiscreteFactorGraph, test) { auto actualMPE = graph.optimize(); EXPECT(assert_equal(mpe, actualMPE)); EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression + + // Test sumProduct alias with all orderings: + auto mpeProbability = expectedBayesNet(mpe); + EXPECT_DOUBLES_EQUAL(0.28125, mpeProbability, 1e-5); // regression + for (Ordering::OrderingType orderingType : + {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL, + Ordering::CUSTOM}) { + auto bayesNet = graph.sumProduct(orderingType); + EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5); + } } /* ************************************************************************* */ From 09fa002bd76ab98abfc32b0b1579e6642710124e Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 25 Jan 2022 17:31:49 -0500 Subject: [PATCH 2/3] Python --- gtsam/discrete/discrete.i | 5 ++ gtsam/nonlinear/nonlinear.i | 5 ++ .../gtsam/tests/test_DiscreteFactorGraph.py | 52 ++++++++++++++++--- 3 files changed, 55 insertions(+), 7 deletions(-) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 3f2c3e0602..0dcbcc1cfc 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -277,7 +277,12 @@ class DiscreteFactorGraph { double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteValues optimize() const; + gtsam::DiscreteLookupDAG sumProduct(); + gtsam::DiscreteLookupDAG sumProduct(gtsam::Ordering::OrderingType type); + gtsam::DiscreteLookupDAG sumProduct(const gtsam::Ordering& ordering); + gtsam::DiscreteLookupDAG maxProduct(); + gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type); gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering); gtsam::DiscreteBayesNet eliminateSequential(); diff --git a/gtsam/nonlinear/nonlinear.i b/gtsam/nonlinear/nonlinear.i index b6ab086c45..a6883d38b8 100644 --- a/gtsam/nonlinear/nonlinear.i +++ b/gtsam/nonlinear/nonlinear.i @@ -111,6 +111,11 @@ size_t mrsymbolIndex(size_t key); #include class Ordering { + /// Type of ordering to use + enum OrderingType { + COLAMD, METIS, NATURAL, CUSTOM + }; + // Standard Constructors and Named Constructors Ordering(); Ordering(const gtsam::Ordering& other); diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index 1ba145e096..ef85fc7534 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -13,9 +13,11 @@ import unittest -from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues +from gtsam import DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering from gtsam.utils.test_case import GtsamTestCase +OrderingType = Ordering.OrderingType + class TestDiscreteFactorGraph(GtsamTestCase): """Tests for Discrete Factor Graphs.""" @@ -108,14 +110,50 @@ def test_MPE(self): graph.add([C, A], "0.2 0.8 0.3 0.7") graph.add([C, B], "0.1 0.9 0.4 0.6") - actualMPE = graph.optimize() + # We know MPE + mpe = DiscreteValues() + mpe[0] = 0 + mpe[1] = 1 + mpe[2] = 1 - expectedMPE = DiscreteValues() - expectedMPE[0] = 0 - expectedMPE[1] = 1 - expectedMPE[2] = 1 + # Use maxProduct + dag = graph.maxProduct(OrderingType.COLAMD) + actualMPE = dag.argmax() self.assertEqual(list(actualMPE.items()), - list(expectedMPE.items())) + list(mpe.items())) + + # All in one + actualMPE2 = graph.optimize() + self.assertEqual(list(actualMPE2.items()), + list(mpe.items())) + + def test_sumProduct(self): + """Test sumProduct.""" + + # Declare a bunch of keys + C, A, B = (0, 2), (1, 2), (2, 2) + + # Create Factor graph + graph = DiscreteFactorGraph() + graph.add([C, A], "0.2 0.8 0.3 0.7") + graph.add([C, B], "0.1 0.9 0.4 0.6") + + # We know MPE + mpe = DiscreteValues() + mpe[0] = 0 + mpe[1] = 1 + mpe[2] = 1 + + # Use default sumProduct + bayesNet = graph.sumProduct() + mpeProbability = bayesNet(mpe) + self.assertAlmostEqual(mpeProbability, 0.36) # regression + + # Use sumProduct + for ordering_type in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL, + OrderingType.CUSTOM]: + bayesNet = graph.sumProduct(ordering_type) + self.assertEqual(bayesNet(mpe), mpeProbability) if __name__ == "__main__": From d6b977927e204c2817d2f3e23295a05e435f94da Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 25 Jan 2022 23:47:53 -0500 Subject: [PATCH 3/3] Fix return type --- gtsam/discrete/DiscreteFactorGraph.cpp | 15 ++++++--------- gtsam/discrete/DiscreteFactorGraph.h | 2 +- gtsam/discrete/discrete.i | 6 +++--- gtsam/discrete/tests/testDiscreteFactorGraph.cpp | 5 +++++ 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index b4b65f885d..ebcac445c5 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -149,16 +149,15 @@ namespace gtsam { DiscreteBayesNet DiscreteFactorGraph::sumProduct( OptionalOrderingType orderingType) const { gttic(DiscreteFactorGraph_sumProduct); - auto bayesNet = BaseEliminateable::eliminateSequential(orderingType); + auto bayesNet = eliminateSequential(orderingType); return *bayesNet; } - DiscreteLookupDAG DiscreteFactorGraph::sumProduct( + DiscreteBayesNet DiscreteFactorGraph::sumProduct( const Ordering& ordering) const { gttic(DiscreteFactorGraph_sumProduct); - auto bayesNet = - BaseEliminateable::eliminateSequential(ordering, EliminateForMPE); - return DiscreteLookupDAG::FromBayesNet(*bayesNet); + auto bayesNet = eliminateSequential(ordering); + return *bayesNet; } /* ************************************************************************ */ @@ -170,16 +169,14 @@ namespace gtsam { DiscreteLookupDAG DiscreteFactorGraph::maxProduct( OptionalOrderingType orderingType) const { gttic(DiscreteFactorGraph_maxProduct); - auto bayesNet = - BaseEliminateable::eliminateSequential(orderingType, EliminateForMPE); + auto bayesNet = eliminateSequential(orderingType, EliminateForMPE); return DiscreteLookupDAG::FromBayesNet(*bayesNet); } DiscreteLookupDAG DiscreteFactorGraph::maxProduct( const Ordering& ordering) const { gttic(DiscreteFactorGraph_maxProduct); - auto bayesNet = - BaseEliminateable::eliminateSequential(ordering, EliminateForMPE); + auto bayesNet = eliminateSequential(ordering, EliminateForMPE); return DiscreteLookupDAG::FromBayesNet(*bayesNet); } diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 2e9b40823f..f962b1802d 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -147,7 +147,7 @@ class GTSAM_EXPORT DiscreteFactorGraph * @param ordering * @return DiscreteBayesNet encoding posterior P(X|Z) */ - DiscreteLookupDAG sumProduct(const Ordering& ordering) const; + DiscreteBayesNet sumProduct(const Ordering& ordering) const; /** * @brief Implement the max-product algorithm diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 0dcbcc1cfc..2582869019 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -277,9 +277,9 @@ class DiscreteFactorGraph { double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteValues optimize() const; - gtsam::DiscreteLookupDAG sumProduct(); - gtsam::DiscreteLookupDAG sumProduct(gtsam::Ordering::OrderingType type); - gtsam::DiscreteLookupDAG sumProduct(const gtsam::Ordering& ordering); + gtsam::DiscreteBayesNet sumProduct(); + gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type); + gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering); gtsam::DiscreteLookupDAG maxProduct(); gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type); diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index 63f5b73194..0a7d869ec5 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -158,6 +158,11 @@ TEST(DiscreteFactorGraph, test) { // Test sumProduct alias with all orderings: auto mpeProbability = expectedBayesNet(mpe); EXPECT_DOUBLES_EQUAL(0.28125, mpeProbability, 1e-5); // regression + + // Using custom ordering + DiscreteBayesNet bayesNet = graph.sumProduct(ordering); + EXPECT_DOUBLES_EQUAL(mpeProbability, bayesNet(mpe), 1e-5); + for (Ordering::OrderingType orderingType : {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL, Ordering::CUSTOM}) {