diff --git a/gtsam/base/tests/testMatrix.cpp b/gtsam/base/tests/testMatrix.cpp index a7c2187059..7802f27e1d 100644 --- a/gtsam/base/tests/testMatrix.cpp +++ b/gtsam/base/tests/testMatrix.cpp @@ -173,7 +173,7 @@ TEST(Matrix, stack ) { Matrix A = (Matrix(2, 2) << -5.0, 3.0, 00.0, -5.0).finished(); Matrix B = (Matrix(3, 2) << -0.5, 2.1, 1.1, 3.4, 2.6, 7.1).finished(); - Matrix AB = stack(2, &A, &B); + Matrix AB = gtsam::stack(2, &A, &B); Matrix C(5, 2); for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) @@ -187,7 +187,7 @@ TEST(Matrix, stack ) std::vector matrices; matrices.push_back(A); matrices.push_back(B); - Matrix AB2 = stack(matrices); + Matrix AB2 = gtsam::stack(matrices); EQUALITY(C,AB2); } diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 439889ebfc..f6a64f11fc 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -248,8 +248,9 @@ namespace gtsam { void dot(std::ostream& os, bool showZero) const override { os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_ << "\"]\n"; - for (size_t i = 0; i < branches_.size(); i++) { - NodePtr branch = branches_[i]; + size_t B = branches_.size(); + for (size_t i = 0; i < B; i++) { + const NodePtr& branch = branches_[i]; // Check if zero if (!showZero) { @@ -258,8 +259,10 @@ namespace gtsam { } os << "\"" << this->id() << "\" -> \"" << branch->id() << "\""; - if (i == 0) os << " [style=dashed]"; - if (i > 1) os << " [style=bold]"; + if (B == 2) { + if (i == 0) os << " [style=dashed]"; + if (i > 1) os << " [style=bold]"; + } os << std::endl; branch->dot(os, showZero); } @@ -671,7 +674,14 @@ namespace gtsam { int result = system( ("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str()); if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed"); -} + } + + template + std::string DecisionTree::dot(bool showZero) const { + std::stringstream ss; + dot(ss, showZero); + return ss.str(); + } /*********************************************************************************/ diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 0ee0b8be0c..1f76f4ca32 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -198,6 +198,9 @@ namespace gtsam { /** output to graphviz format, open a file */ void dot(const std::string& name, bool showZero = true) const; + /** output to graphviz format string */ + std::string dot(bool showZero = true) const; + /// @name Advanced Interface /// @{ diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index daea84e70d..87fdca72c9 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -38,7 +38,7 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; - double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? + string dot(bool showZero = false) const; }; #include @@ -52,8 +52,6 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { DiscreteConditional(const gtsam::DecisionTreeFactor& joint, const gtsam::DecisionTreeFactor& marginal, const gtsam::Ordering& orderedKeys); - size_t size() const; // TODO(dellaert): why do I have to repeat??? - double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat??? void print(string s = "Discrete Conditional\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; @@ -82,6 +80,8 @@ class DiscreteBayesNet { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const; + string dot(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; @@ -98,9 +98,19 @@ class DiscreteBayesTree { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const; + string dot(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + void saveGraph(string s, + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; double operator()(const gtsam::DiscreteValues& values) const; }; +#include +class DotWriter { + DotWriter(); +}; + #include class DiscreteFactorGraph { DiscreteFactorGraph(); @@ -117,7 +127,15 @@ class DiscreteFactorGraph { void print(string s = "") const; bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const; - + + string dot(const gtsam::DotWriter& dotWriter = gtsam::DotWriter(), + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + void saveGraph(string s, + const gtsam::DotWriter& dotWriter = gtsam::DotWriter(), + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + gtsam::DecisionTreeFactor product() const; double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteValues optimize() const; diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index cb50dd05f2..f0c7d37281 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -135,7 +135,7 @@ TEST(DiscreteBayesNet, Asia) { } /* ************************************************************************* */ -TEST_UNSAFE(DiscreteBayesNet, Sugar) { +TEST(DiscreteBayesNet, Sugar) { DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2); DiscreteBayesNet bn; @@ -149,6 +149,29 @@ TEST_UNSAFE(DiscreteBayesNet, Sugar) { bn.add(C | S = "1/1/2 5/2/3"); } +/* ************************************************************************* */ +TEST(DiscreteBayesNet, Dot) { + DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2), + Either(5, 2); + + DiscreteBayesNet fragment; + fragment.add(Asia % "99/1"); + fragment.add(Smoking % "50/50"); + + fragment.add(Tuberculosis | Asia = "99/1 95/5"); + fragment.add(LungCancer | Smoking = "99/1 90/10"); + fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); + + string actual = fragment.dot(); + EXPECT(actual == + "digraph G{\n" + "0->3\n" + "4->6\n" + "3->5\n" + "6->5\n" + "}"); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index 73f345151e..d9f5f5df79 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -26,76 +26,89 @@ using namespace boost::assign; #include +#include #include using namespace std; using namespace gtsam; - -static bool debug = false; +static constexpr bool debug = false; /* ************************************************************************* */ +struct TestFixture { + vector keys; + DiscreteBayesNet bayesNet; + boost::shared_ptr bayesTree; + + /** + * Create a thin-tree Bayesnet, a la Jean-Guillaume Durand (former student), + * and then create the Bayes tree from it. + */ + TestFixture() { + // Define variables. + for (int i = 0; i < 15; i++) { + DiscreteKey key_i(i, 2); + keys.push_back(key_i); + } -TEST_UNSAFE(DiscreteBayesTree, ThinTree) { - const int nrNodes = 15; - const size_t nrStates = 2; + // Create thin-tree Bayesnet. + bayesNet.add(keys[14] % "1/3"); - // define variables - vector key; - for (int i = 0; i < nrNodes; i++) { - DiscreteKey key_i(i, nrStates); - key.push_back(key_i); - } + bayesNet.add(keys[13] | keys[14] = "1/3 3/1"); + bayesNet.add(keys[12] | keys[14] = "3/1 3/1"); - // create a thin-tree Bayesnet, a la Jean-Guillaume - DiscreteBayesNet bayesNet; - bayesNet.add(key[14] % "1/3"); + bayesNet.add((keys[11] | keys[13], keys[14]) = "1/4 2/3 3/2 4/1"); + bayesNet.add((keys[10] | keys[13], keys[14]) = "1/4 3/2 2/3 4/1"); + bayesNet.add((keys[9] | keys[12], keys[14]) = "4/1 2/3 F 1/4"); + bayesNet.add((keys[8] | keys[12], keys[14]) = "T 1/4 3/2 4/1"); - bayesNet.add(key[13] | key[14] = "1/3 3/1"); - bayesNet.add(key[12] | key[14] = "3/1 3/1"); + bayesNet.add((keys[7] | keys[11], keys[13]) = "1/4 2/3 3/2 4/1"); + bayesNet.add((keys[6] | keys[11], keys[13]) = "1/4 3/2 2/3 4/1"); + bayesNet.add((keys[5] | keys[10], keys[13]) = "4/1 2/3 3/2 1/4"); + bayesNet.add((keys[4] | keys[10], keys[13]) = "2/3 1/4 3/2 4/1"); - bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1"); - bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1"); - bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4"); - bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1"); + bayesNet.add((keys[3] | keys[9], keys[12]) = "1/4 2/3 3/2 4/1"); + bayesNet.add((keys[2] | keys[9], keys[12]) = "1/4 8/2 2/3 4/1"); + bayesNet.add((keys[1] | keys[8], keys[12]) = "4/1 2/3 3/2 1/4"); + bayesNet.add((keys[0] | keys[8], keys[12]) = "2/3 1/4 3/2 4/1"); - bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1"); - bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1"); - bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4"); - bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1"); + // Create a BayesTree out of the Bayes net. + bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal(); + } +}; - bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1"); - bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1"); - bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4"); - bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1"); +/* ************************************************************************* */ +TEST(DiscreteBayesTree, ThinTree) { + const TestFixture self; + const auto& keys = self.keys; if (debug) { - GTSAM_PRINT(bayesNet); - bayesNet.saveGraph("/tmp/discreteBayesNet.dot"); + GTSAM_PRINT(self.bayesNet); + self.bayesNet.saveGraph("/tmp/discreteBayesNet.dot"); } // create a BayesTree out of a Bayes net - auto bayesTree = DiscreteFactorGraph(bayesNet).eliminateMultifrontal(); if (debug) { - GTSAM_PRINT(*bayesTree); - bayesTree->saveGraph("/tmp/discreteBayesTree.dot"); + GTSAM_PRINT(*self.bayesTree); + self.bayesTree->saveGraph("/tmp/discreteBayesTree.dot"); } // Check frontals and parents for (size_t i : {13, 14, 9, 3, 2, 8, 1, 0, 10, 5, 4}) { - auto clique_i = (*bayesTree)[i]; + auto clique_i = (*self.bayesTree)[i]; EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals())); } - auto R = bayesTree->roots().front(); + auto R = self.bayesTree->roots().front(); // Check whether BN and BT give the same answer on all configurations - vector allPosbValues = cartesianProduct( - key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] & - key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]); + vector allPosbValues = + cartesianProduct(keys[0] & keys[1] & keys[2] & keys[3] & keys[4] & + keys[5] & keys[6] & keys[7] & keys[8] & keys[9] & + keys[10] & keys[11] & keys[12] & keys[13] & keys[14]); for (size_t i = 0; i < allPosbValues.size(); ++i) { DiscreteValues x = allPosbValues[i]; - double expected = bayesNet.evaluate(x); - double actual = bayesTree->evaluate(x); + double expected = self.bayesNet.evaluate(x); + double actual = self.bayesTree->evaluate(x); DOUBLES_EQUAL(expected, actual, 1e-9); } @@ -107,7 +120,7 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0; for (size_t i = 0; i < allPosbValues.size(); ++i) { DiscreteValues x = allPosbValues[i]; - double px = bayesTree->evaluate(x); + double px = self.bayesTree->evaluate(x); for (size_t i = 0; i < 15; i++) if (x[i]) marginals[i] += px; if (x[12] && x[14]) { @@ -141,46 +154,46 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { DiscreteValues all1 = allPosbValues.back(); // check separator marginal P(S0) - auto clique = (*bayesTree)[0]; + auto clique = (*self.bayesTree)[0]; DiscreteFactorGraph separatorMarginal0 = clique->separatorMarginal(EliminateDiscrete); DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9); // check separator marginal P(S9), should be P(14) - clique = (*bayesTree)[9]; + clique = (*self.bayesTree)[9]; DiscreteFactorGraph separatorMarginal9 = clique->separatorMarginal(EliminateDiscrete); DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9); // check separator marginal of root, should be empty - clique = (*bayesTree)[11]; + clique = (*self.bayesTree)[11]; DiscreteFactorGraph separatorMarginal11 = clique->separatorMarginal(EliminateDiscrete); LONGS_EQUAL(0, separatorMarginal11.size()); // check shortcut P(S9||R) to root - clique = (*bayesTree)[9]; + clique = (*self.bayesTree)[9]; DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete); LONGS_EQUAL(1, shortcut.size()); DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); // check shortcut P(S8||R) to root - clique = (*bayesTree)[8]; + clique = (*self.bayesTree)[8]; shortcut = clique->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint_11_12_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); // check shortcut P(S2||R) to root - clique = (*bayesTree)[2]; + clique = (*self.bayesTree)[2]; shortcut = clique->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint_9_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); // check shortcut P(S0||R) to root - clique = (*bayesTree)[0]; + clique = (*self.bayesTree)[0]; shortcut = clique->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); // calculate all shortcuts to root - DiscreteBayesTree::Nodes cliques = bayesTree->nodes(); + DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes(); for (auto clique : cliques) { DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete); if (debug) { @@ -192,7 +205,7 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { // Check all marginals DiscreteFactor::shared_ptr marginalFactor; for (size_t i = 0; i < 15; i++) { - marginalFactor = bayesTree->marginalFactor(i, EliminateDiscrete); + marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete); double actual = (*marginalFactor)(all1); DOUBLES_EQUAL(marginals[i], actual, 1e-9); } @@ -200,30 +213,60 @@ TEST_UNSAFE(DiscreteBayesTree, ThinTree) { DiscreteBayesNet::shared_ptr actualJoint; // Check joint P(8, 2) - actualJoint = bayesTree->jointBayesNet(8, 2, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(8, 2, EliminateDiscrete); DOUBLES_EQUAL(joint82, actualJoint->evaluate(all1), 1e-9); // Check joint P(1, 2) - actualJoint = bayesTree->jointBayesNet(1, 2, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(1, 2, EliminateDiscrete); DOUBLES_EQUAL(joint12, actualJoint->evaluate(all1), 1e-9); // Check joint P(2, 4) - actualJoint = bayesTree->jointBayesNet(2, 4, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(2, 4, EliminateDiscrete); DOUBLES_EQUAL(joint24, actualJoint->evaluate(all1), 1e-9); // Check joint P(4, 5) - actualJoint = bayesTree->jointBayesNet(4, 5, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(4, 5, EliminateDiscrete); DOUBLES_EQUAL(joint45, actualJoint->evaluate(all1), 1e-9); // Check joint P(4, 6) - actualJoint = bayesTree->jointBayesNet(4, 6, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(4, 6, EliminateDiscrete); DOUBLES_EQUAL(joint46, actualJoint->evaluate(all1), 1e-9); // Check joint P(4, 11) - actualJoint = bayesTree->jointBayesNet(4, 11, EliminateDiscrete); + actualJoint = self.bayesTree->jointBayesNet(4, 11, EliminateDiscrete); DOUBLES_EQUAL(joint_4_11, actualJoint->evaluate(all1), 1e-9); } +/* ************************************************************************* */ +TEST(DiscreteBayesTree, Dot) { + const TestFixture self; + string actual = self.bayesTree->dot(); + EXPECT(actual == + "digraph G{\n" + "0[label=\"13,11,6,7\"];\n" + "0->1\n" + "1[label=\"14 : 11,13\"];\n" + "1->2\n" + "2[label=\"9,12 : 14\"];\n" + "2->3\n" + "3[label=\"3 : 9,12\"];\n" + "2->4\n" + "4[label=\"2 : 9,12\"];\n" + "2->5\n" + "5[label=\"8 : 12,14\"];\n" + "5->6\n" + "6[label=\"1 : 8,12\"];\n" + "5->7\n" + "7[label=\"0 : 8,12\"];\n" + "1->8\n" + "8[label=\"10 : 13,14\"];\n" + "8->9\n" + "9[label=\"5 : 10,13\"];\n" + "8->10\n" + "10[label=\"4 : 10,13\"];\n" + "}"); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index cca9ac69ec..32117bd25c 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -359,6 +359,31 @@ cout << unicorns; } #endif +/* ************************************************************************* */ +TEST(DiscreteFactorGraph, Dot) { + // Declare a bunch of keys + DiscreteKey C(0, 2), A(1, 2), B(2, 2); + + // Create Factor graph + DiscreteFactorGraph graph; + graph.add(C & A, "0.2 0.8 0.3 0.7"); + graph.add(C & B, "0.1 0.9 0.4 0.6"); + + string actual = graph.dot(); + string expected = + "graph {\n" + " size=\"5,5\";\n" + "\n" + " var0[label=\"0\"];\n" + " var1[label=\"1\"];\n" + " var2[label=\"2\"];\n" + "\n" + " var0--var1;\n" + " var0--var2;\n" + "}\n"; + EXPECT(actual == expected); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/inference/BayesNet-inst.h b/gtsam/inference/BayesNet-inst.h index a737622585..be34b2928f 100644 --- a/gtsam/inference/BayesNet-inst.h +++ b/gtsam/inference/BayesNet-inst.h @@ -35,21 +35,39 @@ void BayesNet::print( /* ************************************************************************* */ template -void BayesNet::saveGraph(const std::string& s, - const KeyFormatter& keyFormatter) const { - std::ofstream of(s.c_str()); - of << "digraph G{\n"; - - for (auto conditional : boost::adaptors::reverse(*this)) { - typename CONDITIONAL::Frontals frontals = conditional->frontals(); - Key me = frontals.front(); - typename CONDITIONAL::Parents parents = conditional->parents(); - for (Key p : parents) - of << keyFormatter(p) << "->" << keyFormatter(me) << std::endl; +void BayesNet::dot(std::ostream& os, + const KeyFormatter& keyFormatter) const { + os << "digraph G{\n"; + + for (auto conditional : *this) { + auto frontals = conditional->frontals(); + const Key me = frontals.front(); + auto parents = conditional->parents(); + for (const Key& p : parents) + os << keyFormatter(p) << "->" << keyFormatter(me) << "\n"; } - of << "}"; + os << "}"; + std::flush(os); +} + +/* ************************************************************************* */ +template +std::string BayesNet::dot(const KeyFormatter& keyFormatter) const { + std::stringstream ss; + dot(ss, keyFormatter); + return ss.str(); +} + +/* ************************************************************************* */ +template +void BayesNet::saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter) const { + std::ofstream of(filename.c_str()); + dot(of, keyFormatter); of.close(); } +/* ************************************************************************* */ + } // namespace gtsam diff --git a/gtsam/inference/BayesNet.h b/gtsam/inference/BayesNet.h index 938278d5ad..f987ad51be 100644 --- a/gtsam/inference/BayesNet.h +++ b/gtsam/inference/BayesNet.h @@ -64,11 +64,21 @@ namespace gtsam { /// @} - /// @name Standard Interface + /// @name Graph Display /// @{ - void saveGraph(const std::string& s, + /// Output to graphviz format, stream version. + void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// Output to graphviz format string. + std::string dot( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// output to file with graphviz format. + void saveGraph(const std::string& filename, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// @} }; } diff --git a/gtsam/inference/BayesTree-inst.h b/gtsam/inference/BayesTree-inst.h index 5b53a57193..9b937fefb5 100644 --- a/gtsam/inference/BayesTree-inst.h +++ b/gtsam/inference/BayesTree-inst.h @@ -63,20 +63,40 @@ namespace gtsam { } /* ************************************************************************* */ - template - void BayesTree::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const { - if (roots_.empty()) throw std::invalid_argument("the root of Bayes tree has not been initialized!"); - std::ofstream of(s.c_str()); - of<< "digraph G{\n"; - for(const sharedClique& root: roots_) - saveGraph(of, root, keyFormatter); - of<<"}"; + template + void BayesTree::dot(std::ostream& os, + const KeyFormatter& keyFormatter) const { + if (roots_.empty()) + throw std::invalid_argument( + "the root of Bayes tree has not been initialized!"); + os << "digraph G{\n"; + for (const sharedClique& root : roots_) dot(os, root, keyFormatter); + os << "}"; + std::flush(os); + } + + /* ************************************************************************* */ + template + std::string BayesTree::dot(const KeyFormatter& keyFormatter) const { + std::stringstream ss; + dot(ss, keyFormatter); + return ss.str(); + } + + /* ************************************************************************* */ + template + void BayesTree::saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter) const { + std::ofstream of(filename.c_str()); + dot(of, keyFormatter); of.close(); } /* ************************************************************************* */ - template - void BayesTree::saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& indexFormatter, int parentnum) const { + template + void BayesTree::dot(std::ostream& s, sharedClique clique, + const KeyFormatter& indexFormatter, + int parentnum) const { static int num = 0; bool first = true; std::stringstream out; @@ -107,7 +127,7 @@ namespace gtsam { for (sharedClique c : clique->children) { num++; - saveGraph(s, c, indexFormatter, parentnum); + dot(s, c, indexFormatter, parentnum); } } diff --git a/gtsam/inference/BayesTree.h b/gtsam/inference/BayesTree.h index cc003d8dcb..68a45a014a 100644 --- a/gtsam/inference/BayesTree.h +++ b/gtsam/inference/BayesTree.h @@ -182,13 +182,20 @@ namespace gtsam { */ sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate& function = EliminationTraitsType::DefaultEliminate) const; - /** - * Read only with side effects - */ + /// @name Graph Display + /// @{ + + /// Output to graphviz format, stream version. + void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; - /** saves the Tree to a text file in GraphViz format */ - void saveGraph(const std::string& s, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// Output to graphviz format string. + std::string dot( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// output to file with graphviz format. + void saveGraph(const std::string& filename, + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// @} /// @name Advanced Interface /// @{ @@ -236,8 +243,8 @@ namespace gtsam { protected: /** private helper method for saving the Tree to a text file in GraphViz format */ - void saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& keyFormatter, - int parentnum = 0) const; + void dot(std::ostream &s, sharedClique clique, const KeyFormatter& keyFormatter, + int parentnum = 0) const; /** Gather data on a single clique */ void getCliqueData(sharedClique clique, BayesTreeCliqueData* stats) const; @@ -249,7 +256,7 @@ namespace gtsam { void fillNodesIndex(const sharedClique& subtree); // Friend JunctionTree because it directly fills roots and nodes index. - template friend class EliminatableClusterTree; + template friend class EliminatableClusterTree; private: /** Serialization function */ diff --git a/gtsam/inference/DotWriter.cpp b/gtsam/inference/DotWriter.cpp new file mode 100644 index 0000000000..fb3ea05054 --- /dev/null +++ b/gtsam/inference/DotWriter.cpp @@ -0,0 +1,93 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DotWriter.cpp + * @brief Graphviz formatting for factor graphs. + * @author Frank Dellaert + * @date December, 2021 + */ + +#include +#include + +#include + +using namespace std; + +namespace gtsam { + +void DotWriter::writePreamble(ostream* os) const { + *os << "graph {\n"; + *os << " size=\"" << figureWidthInches << "," << figureHeightInches + << "\";\n\n"; +} + +void DotWriter::DrawVariable(Key key, const KeyFormatter& keyFormatter, + const boost::optional& position, + ostream* os) { + // Label the node with the label from the KeyFormatter + *os << " var" << key << "[label=\"" << keyFormatter(key) << "\""; + if (position) { + *os << ", pos=\"" << position->x() << "," << position->y() << "!\""; + } + *os << "];\n"; +} + +void DotWriter::DrawFactor(size_t i, const boost::optional& position, + ostream* os) { + *os << " factor" << i << "[label=\"\", shape=point"; + if (position) { + *os << ", pos=\"" << position->x() << "," << position->y() << "!\""; + } + *os << "];\n"; +} + +void DotWriter::ConnectVariables(Key key1, Key key2, ostream* os) { + *os << " var" << key1 << "--" + << "var" << key2 << ";\n"; +} + +void DotWriter::ConnectVariableFactor(Key key, size_t i, ostream* os) { + *os << " var" << key << "--" + << "factor" << i << ";\n"; +} + +void DotWriter::processFactor(size_t i, const KeyVector& keys, + const boost::optional& position, + ostream* os) const { + if (plotFactorPoints) { + if (binaryEdges && keys.size() == 2) { + ConnectVariables(keys[0], keys[1], os); + } else { + // Create dot for the factor. + DrawFactor(i, position, os); + + // Make factor-variable connections + if (connectKeysToFactor) { + for (Key key : keys) { + ConnectVariableFactor(key, i, os); + } + } + } + } else { + // just connect variables in a clique + for (Key key1 : keys) { + for (Key key2 : keys) { + if (key2 > key1) { + ConnectVariables(key1, key2, os); + } + } + } + } +} + +} // namespace gtsam diff --git a/gtsam/inference/DotWriter.h b/gtsam/inference/DotWriter.h new file mode 100644 index 0000000000..8f36f3ce64 --- /dev/null +++ b/gtsam/inference/DotWriter.h @@ -0,0 +1,69 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DotWriter.h + * @brief Graphviz formatter + * @author Frank Dellaert + * @date December, 2021 + */ + +#pragma once + +#include +#include +#include + +#include + +namespace gtsam { + +/// Graphviz formatter. +struct GTSAM_EXPORT DotWriter { + double figureWidthInches; ///< The figure width on paper in inches + double figureHeightInches; ///< The figure height on paper in inches + bool plotFactorPoints; ///< Plots each factor as a dot between the variables + bool connectKeysToFactor; ///< Draw a line from each key within a factor to + ///< the dot of the factor + bool binaryEdges; ///< just use non-dotted edges for binary factors + + DotWriter() + : figureWidthInches(5), + figureHeightInches(5), + plotFactorPoints(true), + connectKeysToFactor(true), + binaryEdges(true) {} + + /// Write out preamble, including size. + void writePreamble(std::ostream* os) const; + + /// Create a variable dot fragment. + static void DrawVariable(Key key, const KeyFormatter& keyFormatter, + const boost::optional& position, + std::ostream* os); + + /// Create factor dot. + static void DrawFactor(size_t i, const boost::optional& position, + std::ostream* os); + + /// Connect two variables. + static void ConnectVariables(Key key1, Key key2, std::ostream* os); + + /// Connect variable and factor. + static void ConnectVariableFactor(Key key, size_t i, std::ostream* os); + + /// Draw a single factor, specified by its index i and its variable keys. + void processFactor(size_t i, const KeyVector& keys, + const boost::optional& position, + std::ostream* os) const; +}; + +} // namespace gtsam diff --git a/gtsam/inference/FactorGraph-inst.h b/gtsam/inference/FactorGraph-inst.h index 166ae41f41..06b71036c7 100644 --- a/gtsam/inference/FactorGraph-inst.h +++ b/gtsam/inference/FactorGraph-inst.h @@ -26,6 +26,7 @@ #include #include #include // for cout :-( +#include #include #include @@ -125,4 +126,48 @@ FactorIndices FactorGraph::add_factors(const CONTAINER& factors, return newFactorIndices; } +/* ************************************************************************* */ +template +void FactorGraph::dot(std::ostream& os, const DotWriter& writer, + const KeyFormatter& keyFormatter) const { + writer.writePreamble(&os); + + // Create nodes for each variable in the graph + for (Key key : keys()) { + writer.DrawVariable(key, keyFormatter, boost::none, &os); + } + os << "\n"; + + // Create factors and variable connections + for (size_t i = 0; i < size(); ++i) { + const auto& factor = at(i); + if (factor) { + const KeyVector& factorKeys = factor->keys(); + writer.processFactor(i, factorKeys, boost::none, &os); + } + } + + os << "}\n"; + std::flush(os); +} + +/* ************************************************************************* */ +template +std::string FactorGraph::dot(const DotWriter& writer, + const KeyFormatter& keyFormatter) const { + std::stringstream ss; + dot(ss, writer, keyFormatter); + return ss.str(); +} + +/* ************************************************************************* */ +template +void FactorGraph::saveGraph(const std::string& filename, + const DotWriter& writer, + const KeyFormatter& keyFormatter) const { + std::ofstream of(filename.c_str()); + dot(of, writer, keyFormatter); + of.close(); +} + } // namespace gtsam diff --git a/gtsam/inference/FactorGraph.h b/gtsam/inference/FactorGraph.h index e337e3249f..a4c293b648 100644 --- a/gtsam/inference/FactorGraph.h +++ b/gtsam/inference/FactorGraph.h @@ -22,9 +22,10 @@ #pragma once +#include +#include #include #include -#include #include // for Eigen::aligned_allocator @@ -36,6 +37,7 @@ #include #include #include +#include namespace gtsam { /// Define collection type: @@ -371,6 +373,23 @@ class FactorGraph { return factors_.erase(first, last); } + /// @} + /// @name Graph Display + /// @{ + + /// Output to graphviz format, stream version. + void dot(std::ostream& os, const DotWriter& writer = DotWriter(), + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// Output to graphviz format string. + std::string dot(const DotWriter& writer = DotWriter(), + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// output to file with graphviz format. + void saveGraph(const std::string& filename, + const DotWriter& writer = DotWriter(), + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// @} /// @name Advanced Interface /// @{ diff --git a/gtsam/linear/GaussianFactorGraph.cpp b/gtsam/linear/GaussianFactorGraph.cpp index 664aeff6d0..72eb107d09 100644 --- a/gtsam/linear/GaussianFactorGraph.cpp +++ b/gtsam/linear/GaussianFactorGraph.cpp @@ -379,7 +379,7 @@ namespace gtsam { gttic(Compute_minimizing_step_size); // Compute minimizing step size - double step = -gradientSqNorm / dot(Rg, Rg); + double step = -gradientSqNorm / gtsam::dot(Rg, Rg); gttoc(Compute_minimizing_step_size); gttic(Compute_point); diff --git a/gtsam/nonlinear/GraphvizFormatting.cpp b/gtsam/nonlinear/GraphvizFormatting.cpp new file mode 100644 index 0000000000..c37f07c8a8 --- /dev/null +++ b/gtsam/nonlinear/GraphvizFormatting.cpp @@ -0,0 +1,136 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file GraphvizFormatting.cpp + * @brief Graphviz formatter for NonlinearFactorGraph + * @author Frank Dellaert + * @date December, 2021 + */ + +#include +#include + +// TODO(frank): nonlinear should not depend on geometry: +#include +#include + +#include + +namespace gtsam { + +Vector2 GraphvizFormatting::findBounds(const Values& values, + const KeySet& keys) const { + Vector2 min; + min.x() = std::numeric_limits::infinity(); + min.y() = std::numeric_limits::infinity(); + for (const Key& key : keys) { + if (values.exists(key)) { + boost::optional xy = operator()(values.at(key)); + if (xy) { + if (xy->x() < min.x()) min.x() = xy->x(); + if (xy->y() < min.y()) min.y() = xy->y(); + } + } + } + return min; +} + +boost::optional GraphvizFormatting::operator()( + const Value& value) const { + Vector3 t; + if (const GenericValue* p = + dynamic_cast*>(&value)) { + t << p->value().x(), p->value().y(), 0; + } else if (const GenericValue* p = + dynamic_cast*>(&value)) { + t << p->value().x(), p->value().y(), 0; + } else if (const GenericValue* p = + dynamic_cast*>(&value)) { + t = p->value().translation(); + } else if (const GenericValue* p = + dynamic_cast*>(&value)) { + t = p->value(); + } else { + return boost::none; + } + double x, y; + switch (paperHorizontalAxis) { + case X: + x = t.x(); + break; + case Y: + x = t.y(); + break; + case Z: + x = t.z(); + break; + case NEGX: + x = -t.x(); + break; + case NEGY: + x = -t.y(); + break; + case NEGZ: + x = -t.z(); + break; + default: + throw std::runtime_error("Invalid enum value"); + } + switch (paperVerticalAxis) { + case X: + y = t.x(); + break; + case Y: + y = t.y(); + break; + case Z: + y = t.z(); + break; + case NEGX: + y = -t.x(); + break; + case NEGY: + y = -t.y(); + break; + case NEGZ: + y = -t.z(); + break; + default: + throw std::runtime_error("Invalid enum value"); + } + return Vector2(x, y); +} + +// Return affinely transformed variable position if it exists. +boost::optional GraphvizFormatting::variablePos(const Values& values, + const Vector2& min, + Key key) const { + if (!values.exists(key)) return boost::none; + boost::optional xy = operator()(values.at(key)); + if (xy) { + xy->x() = scale * (xy->x() - min.x()); + xy->y() = scale * (xy->y() - min.y()); + } + return xy; +} + +// Return affinely transformed factor position if it exists. +boost::optional GraphvizFormatting::factorPos(const Vector2& min, + size_t i) const { + if (factorPositions.size() == 0) return boost::none; + auto it = factorPositions.find(i); + if (it == factorPositions.end()) return boost::none; + auto pos = it->second; + return Vector2(scale * (pos.x() - min.x()), scale * (pos.y() - min.y())); +} + +} // namespace gtsam diff --git a/gtsam/nonlinear/GraphvizFormatting.h b/gtsam/nonlinear/GraphvizFormatting.h new file mode 100644 index 0000000000..c36b09a8fc --- /dev/null +++ b/gtsam/nonlinear/GraphvizFormatting.h @@ -0,0 +1,69 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file GraphvizFormatting.h + * @brief Graphviz formatter for NonlinearFactorGraph + * @author Frank Dellaert + * @date December, 2021 + */ + +#pragma once + +#include + +namespace gtsam { + +class Values; +class Value; + +/** + * Formatting options and functions for saving a NonlinearFactorGraph instance + * in GraphViz format. + */ +struct GTSAM_EXPORT GraphvizFormatting : public DotWriter { + /// World axes to be assigned to paper axes + enum Axis { X, Y, Z, NEGX, NEGY, NEGZ }; + + Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal + ///< paper axis + Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper + ///< axis + double scale; ///< Scale all positions to reduce / increase density + bool mergeSimilarFactors; ///< Merge multiple factors that have the same + ///< connectivity + + /// (optional for each factor) Manually specify factor "dot" positions: + std::map factorPositions; + + /// Default constructor sets up robot coordinates. Paper horizontal is robot + /// Y, paper vertical is robot X. Default figure size of 5x5 in. + GraphvizFormatting() + : paperHorizontalAxis(Y), + paperVerticalAxis(X), + scale(1), + mergeSimilarFactors(false) {} + + // Find bounds + Vector2 findBounds(const Values& values, const KeySet& keys) const; + + /// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3 + boost::optional operator()(const Value& value) const; + + /// Return affinely transformed variable position if it exists. + boost::optional variablePos(const Values& values, const Vector2& min, + Key key) const; + + /// Return affinely transformed factor position if it exists. + boost::optional factorPos(const Vector2& min, size_t i) const; +}; + +} // namespace gtsam diff --git a/gtsam/nonlinear/NonlinearFactorGraph.cpp b/gtsam/nonlinear/NonlinearFactorGraph.cpp index 8e4cf277c2..7c6d2e6cf0 100644 --- a/gtsam/nonlinear/NonlinearFactorGraph.cpp +++ b/gtsam/nonlinear/NonlinearFactorGraph.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include // for GTSAM_USE_TBB @@ -35,7 +36,6 @@ #include #include -#include using namespace std; @@ -91,89 +91,25 @@ bool NonlinearFactorGraph::equals(const NonlinearFactorGraph& other, double tol) } /* ************************************************************************* */ -void NonlinearFactorGraph::saveGraph(std::ostream &stm, const Values& values, - const GraphvizFormatting& formatting, - const KeyFormatter& keyFormatter) const -{ - stm << "graph {\n"; - stm << " size=\"" << formatting.figureWidthInches << "," << - formatting.figureHeightInches << "\";\n\n"; +void NonlinearFactorGraph::dot(std::ostream& os, const Values& values, + const GraphvizFormatting& writer, + const KeyFormatter& keyFormatter) const { + writer.writePreamble(&os); + // Find bounds (imperative) KeySet keys = this->keys(); - - // Local utility function to extract x and y coordinates - struct { boost::optional operator()( - const Value& value, const GraphvizFormatting& graphvizFormatting) - { - Vector3 t; - if (const GenericValue* p = dynamic_cast*>(&value)) { - t << p->value().x(), p->value().y(), 0; - } else if (const GenericValue* p = dynamic_cast*>(&value)) { - t << p->value().x(), p->value().y(), 0; - } else if (const GenericValue* p = dynamic_cast*>(&value)) { - t = p->value().translation(); - } else if (const GenericValue* p = dynamic_cast*>(&value)) { - t = p->value(); - } else { - return boost::none; - } - double x, y; - switch (graphvizFormatting.paperHorizontalAxis) { - case GraphvizFormatting::X: x = t.x(); break; - case GraphvizFormatting::Y: x = t.y(); break; - case GraphvizFormatting::Z: x = t.z(); break; - case GraphvizFormatting::NEGX: x = -t.x(); break; - case GraphvizFormatting::NEGY: x = -t.y(); break; - case GraphvizFormatting::NEGZ: x = -t.z(); break; - default: throw std::runtime_error("Invalid enum value"); - } - switch (graphvizFormatting.paperVerticalAxis) { - case GraphvizFormatting::X: y = t.x(); break; - case GraphvizFormatting::Y: y = t.y(); break; - case GraphvizFormatting::Z: y = t.z(); break; - case GraphvizFormatting::NEGX: y = -t.x(); break; - case GraphvizFormatting::NEGY: y = -t.y(); break; - case GraphvizFormatting::NEGZ: y = -t.z(); break; - default: throw std::runtime_error("Invalid enum value"); - } - return Point2(x,y); - }} getXY; - - // Find bounds - double minX = numeric_limits::infinity(), maxX = -numeric_limits::infinity(); - double minY = numeric_limits::infinity(), maxY = -numeric_limits::infinity(); - for (const Key& key : keys) { - if (values.exists(key)) { - boost::optional xy = getXY(values.at(key), formatting); - if(xy) { - if(xy->x() < minX) - minX = xy->x(); - if(xy->x() > maxX) - maxX = xy->x(); - if(xy->y() < minY) - minY = xy->y(); - if(xy->y() > maxY) - maxY = xy->y(); - } - } - } + Vector2 min = writer.findBounds(values, keys); // Create nodes for each variable in the graph - for(Key key: keys){ - // Label the node with the label from the KeyFormatter - stm << " var" << key << "[label=\"" << keyFormatter(key) << "\""; - if(values.exists(key)) { - boost::optional xy = getXY(values.at(key), formatting); - if(xy) - stm << ", pos=\"" << formatting.scale*(xy->x() - minX) << "," << formatting.scale*(xy->y() - minY) << "!\""; - } - stm << "];\n"; + for (Key key : keys) { + auto position = writer.variablePos(values, min, key); + writer.DrawVariable(key, keyFormatter, position, &os); } - stm << "\n"; + os << "\n"; - if (formatting.mergeSimilarFactors) { + if (writer.mergeSimilarFactors) { // Remove duplicate factors - std::set structure; + std::set structure; for (const sharedFactor& factor : factors_) { if (factor) { KeyVector factorKeys = factor->keys(); @@ -184,86 +120,40 @@ void NonlinearFactorGraph::saveGraph(std::ostream &stm, const Values& values, // Create factors and variable connections size_t i = 0; - for(const KeyVector& factorKeys: structure){ - // Make each factor a dot - stm << " factor" << i << "[label=\"\", shape=point"; - { - map::const_iterator pos = formatting.factorPositions.find(i); - if(pos != formatting.factorPositions.end()) - stm << ", pos=\"" << formatting.scale*(pos->second.x() - minX) << "," - << formatting.scale*(pos->second.y() - minY) << "!\""; - } - stm << "];\n"; - - // Make factor-variable connections - for(Key key: factorKeys) { - stm << " var" << key << "--" << "factor" << i << ";\n"; - } - - ++ i; + for (const KeyVector& factorKeys : structure) { + writer.processFactor(i++, factorKeys, boost::none, &os); } } else { // Create factors and variable connections - for(size_t i = 0; i < size(); ++i) { + for (size_t i = 0; i < size(); ++i) { const NonlinearFactor::shared_ptr& factor = at(i); - // If null pointer, move on to the next - if (!factor) { - continue; - } - - if (formatting.plotFactorPoints) { - const KeyVector& keys = factor->keys(); - if (formatting.binaryEdges && keys.size() == 2) { - stm << " var" << keys[0] << "--" - << "var" << keys[1] << ";\n"; - } else { - // Make each factor a dot - stm << " factor" << i << "[label=\"\", shape=point"; - { - map::const_iterator pos = - formatting.factorPositions.find(i); - if (pos != formatting.factorPositions.end()) - stm << ", pos=\"" << formatting.scale * (pos->second.x() - minX) - << "," << formatting.scale * (pos->second.y() - minY) - << "!\""; - } - stm << "];\n"; - - // Make factor-variable connections - if (formatting.connectKeysToFactor && factor) { - for (Key key : *factor) { - stm << " var" << key << "--" - << "factor" << i << ";\n"; - } - } - } - } else { - Key k; - bool firstTime = true; - for (Key key : *this->at(i)) { - if (firstTime) { - k = key; - firstTime = false; - continue; - } - stm << " var" << key << "--" - << "var" << k << ";\n"; - k = key; - } + if (factor) { + const KeyVector& factorKeys = factor->keys(); + writer.processFactor(i, factorKeys, writer.factorPos(min, i), &os); } } } - stm << "}\n"; + os << "}\n"; + std::flush(os); +} + +/* ************************************************************************* */ +std::string NonlinearFactorGraph::dot( + const Values& values, const GraphvizFormatting& writer, + const KeyFormatter& keyFormatter) const { + std::stringstream ss; + dot(ss, values, writer, keyFormatter); + return ss.str(); } /* ************************************************************************* */ void NonlinearFactorGraph::saveGraph( - const std::string& file, const Values& values, - const GraphvizFormatting& graphvizFormatting, + const std::string& filename, const Values& values, + const GraphvizFormatting& writer, const KeyFormatter& keyFormatter) const { - std::ofstream of(file); - saveGraph(of, values, graphvizFormatting, keyFormatter); + std::ofstream of(filename); + dot(of, values, writer, keyFormatter); of.close(); } diff --git a/gtsam/nonlinear/NonlinearFactorGraph.h b/gtsam/nonlinear/NonlinearFactorGraph.h index 61cbbafb98..c958d63023 100644 --- a/gtsam/nonlinear/NonlinearFactorGraph.h +++ b/gtsam/nonlinear/NonlinearFactorGraph.h @@ -23,6 +23,7 @@ #include #include +#include #include #include @@ -41,32 +42,6 @@ namespace gtsam { template class ExpressionFactor; - /** - * Formatting options when saving in GraphViz format using - * NonlinearFactorGraph::saveGraph. - */ - struct GTSAM_EXPORT GraphvizFormatting { - enum Axis { X, Y, Z, NEGX, NEGY, NEGZ }; ///< World axes to be assigned to paper axes - Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal paper axis - Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper axis - double figureWidthInches; ///< The figure width on paper in inches - double figureHeightInches; ///< The figure height on paper in inches - double scale; ///< Scale all positions to reduce / increase density - bool mergeSimilarFactors; ///< Merge multiple factors that have the same connectivity - bool plotFactorPoints; ///< Plots each factor as a dot between the variables - bool connectKeysToFactor; ///< Draw a line from each key within a factor to the dot of the factor - bool binaryEdges; ///< just use non-dotted edges for binary factors - std::map factorPositions; ///< (optional for each factor) Manually specify factor "dot" positions. - /// Default constructor sets up robot coordinates. Paper horizontal is robot Y, - /// paper vertical is robot X. Default figure size of 5x5 in. - GraphvizFormatting() : - paperHorizontalAxis(Y), paperVerticalAxis(X), - figureWidthInches(5), figureHeightInches(5), scale(1), - mergeSimilarFactors(false), plotFactorPoints(true), - connectKeysToFactor(true), binaryEdges(true) {} - }; - - /** * A non-linear factor graph is a graph of non-Gaussian, i.e. non-linear factors, * which derive from NonlinearFactor. The values structures are typically (in SAM) more general @@ -115,21 +90,6 @@ namespace gtsam { /** Test equality */ bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const; - /// Write the graph in GraphViz format for visualization - void saveGraph(std::ostream& stm, const Values& values = Values(), - const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(), - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; - - /** - * Write the graph in GraphViz format to file for visualization. - * - * This is a wrapper friendly version since wrapped languages don't have - * access to C++ streams. - */ - void saveGraph(const std::string& file, const Values& values = Values(), - const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(), - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; - /** unnormalized error, \f$ 0.5 \sum_i (h_i(X_i)-z)^2/\sigma^2 \f$ in the most common case */ double error(const Values& values) const; @@ -246,7 +206,33 @@ namespace gtsam { emplace_shared>(key, prior, covariance); } - private: + /// @name Graph Display + /// @{ + + using FactorGraph::dot; + using FactorGraph::saveGraph; + + /// Output to graphviz format, stream version, with Values/extra options. + void dot( + std::ostream& os, const Values& values, + const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(), + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// Output to graphviz format string, with Values/extra options. + std::string dot( + const Values& values, + const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(), + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// output to file with graphviz format, with Values/extra options. + void saveGraph( + const std::string& filename, const Values& values, + const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(), + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// @} + + private: /** * Linearize from Scatter rather than from Ordering. Made private because @@ -275,6 +261,14 @@ namespace gtsam { Values GTSAM_DEPRECATED updateCholesky(const Values& values, boost::none_t, const Dampen& dampen = nullptr) const {return updateCholesky(values, dampen);} + + /** \deprecated */ + void GTSAM_DEPRECATED saveGraph( + std::ostream& os, const Values& values = Values(), + const GraphvizFormatting& graphvizFormatting = GraphvizFormatting(), + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { + dot(os, values, graphvizFormatting, keyFormatter); + } #endif }; diff --git a/python/gtsam/notebooks/DiscreteBayesTree.ipynb b/python/gtsam/notebooks/DiscreteBayesTree.ipynb new file mode 100644 index 0000000000..066c31d6a8 --- /dev/null +++ b/python/gtsam/notebooks/DiscreteBayesTree.ipynb @@ -0,0 +1,200 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# The Discrete Bayes Tree\n", + "\n", + "An example of building a Bayes net, then eliminating it into a Bayes tree. Mirrors the code in `testDiscreteBayesTree.cpp` .\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from gtsam import DiscreteBayesTree, DiscreteBayesNet, DiscreteKeys, DiscreteFactorGraph, Ordering\n", + "from gtsam.symbol_shorthand import S\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def P(*args):\n", + " \"\"\" Create a DiscreteKeys instances from a variable number of DiscreteKey pairs.\"\"\"\n", + " #TODO: We can make life easier by providing variable argument functions in C++ itself.\n", + " dks = DiscreteKeys()\n", + " for key in args:\n", + " dks.push_back(key)\n", + " return dks" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import graphviz\n", + "class show(graphviz.Source):\n", + " \"\"\" Display an object with a dot method as a graph.\"\"\"\n", + "\n", + " def __init__(self, obj):\n", + " \"\"\"Construct from object with 'dot' method.\"\"\"\n", + " # This small class takes an object, calls its dot function, and uses the\n", + " # resulting string to initialize a graphviz.Source instance. This in turn\n", + " # has a _repr_mimebundle_ method, which then renders it in the notebook.\n", + " super().__init__(obj.dot())" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\nG\n\n\n\n8\n\n8\n\n\n\n0\n\n0\n\n\n\n8->0\n\n\n\n\n\n1\n\n1\n\n\n\n8->1\n\n\n\n\n\n12\n\n12\n\n\n\n12->8\n\n\n\n\n\n12->0\n\n\n\n\n\n12->1\n\n\n\n\n\n9\n\n9\n\n\n\n12->9\n\n\n\n\n\n2\n\n2\n\n\n\n12->2\n\n\n\n\n\n3\n\n3\n\n\n\n12->3\n\n\n\n\n\n9->2\n\n\n\n\n\n9->3\n\n\n\n\n\n10\n\n10\n\n\n\n4\n\n4\n\n\n\n10->4\n\n\n\n\n\n5\n\n5\n\n\n\n10->5\n\n\n\n\n\n13\n\n13\n\n\n\n13->10\n\n\n\n\n\n13->4\n\n\n\n\n\n13->5\n\n\n\n\n\n11\n\n11\n\n\n\n13->11\n\n\n\n\n\n6\n\n6\n\n\n\n13->6\n\n\n\n\n\n7\n\n7\n\n\n\n13->7\n\n\n\n\n\n11->6\n\n\n\n\n\n11->7\n\n\n\n\n\n14\n\n14\n\n\n\n14->8\n\n\n\n\n\n14->12\n\n\n\n\n\n14->9\n\n\n\n\n\n14->10\n\n\n\n\n\n14->13\n\n\n\n\n\n14->11\n\n\n\n\n\n", + "text/plain": [ + "<__main__.show at 0x109c615b0>" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Define DiscreteKey pairs.\n", + "keys = [(j, 2) for j in range(15)]\n", + "\n", + "# Create thin-tree Bayesnet.\n", + "bayesNet = DiscreteBayesNet()\n", + "\n", + "\n", + "bayesNet.add(keys[0], P(keys[8], keys[12]), \"2/3 1/4 3/2 4/1\")\n", + "bayesNet.add(keys[1], P(keys[8], keys[12]), \"4/1 2/3 3/2 1/4\")\n", + "bayesNet.add(keys[2], P(keys[9], keys[12]), \"1/4 8/2 2/3 4/1\")\n", + "bayesNet.add(keys[3], P(keys[9], keys[12]), \"1/4 2/3 3/2 4/1\")\n", + "\n", + "bayesNet.add(keys[4], P(keys[10], keys[13]), \"2/3 1/4 3/2 4/1\")\n", + "bayesNet.add(keys[5], P(keys[10], keys[13]), \"4/1 2/3 3/2 1/4\")\n", + "bayesNet.add(keys[6], P(keys[11], keys[13]), \"1/4 3/2 2/3 4/1\")\n", + "bayesNet.add(keys[7], P(keys[11], keys[13]), \"1/4 2/3 3/2 4/1\")\n", + "\n", + "bayesNet.add(keys[8], P(keys[12], keys[14]), \"T 1/4 3/2 4/1\")\n", + "bayesNet.add(keys[9], P(keys[12], keys[14]), \"4/1 2/3 F 1/4\")\n", + "bayesNet.add(keys[10], P(keys[13], keys[14]), \"1/4 3/2 2/3 4/1\")\n", + "bayesNet.add(keys[11], P(keys[13], keys[14]), \"1/4 2/3 3/2 4/1\")\n", + "\n", + "bayesNet.add(keys[12], P(keys[14]), \"3/1 3/1\")\n", + "bayesNet.add(keys[13], P(keys[14]), \"1/3 3/1\")\n", + "\n", + "bayesNet.add(keys[14], P(), \"1/3\")\n", + "\n", + "show(bayesNet)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DiscreteValues{0: 1, 1: 1, 2: 0, 3: 1, 4: 1, 5: 1, 6: 0, 7: 1, 8: 0, 9: 0, 10: 0, 11: 0, 12: 1, 13: 1, 14: 0}\n", + "DiscreteValues{0: 0, 1: 1, 2: 0, 3: 0, 4: 1, 5: 0, 6: 0, 7: 0, 8: 1, 9: 1, 10: 0, 11: 1, 12: 0, 13: 0, 14: 1}\n", + "DiscreteValues{0: 1, 1: 0, 2: 1, 3: 1, 4: 0, 5: 0, 6: 1, 7: 0, 8: 1, 9: 0, 10: 1, 11: 1, 12: 0, 13: 1, 14: 0}\n", + "DiscreteValues{0: 1, 1: 1, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1, 7: 1, 8: 0, 9: 1, 10: 0, 11: 0, 12: 1, 13: 0, 14: 1}\n", + "DiscreteValues{0: 0, 1: 0, 2: 1, 3: 0, 4: 1, 5: 1, 6: 1, 7: 0, 8: 1, 9: 1, 10: 0, 11: 1, 12: 0, 13: 0, 14: 1}\n" + ] + } + ], + "source": [ + "# Sample Bayes net (needs conditionals added in elimination order!)\n", + "for i in range(5):\n", + " print(bayesNet.sample())" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n\n\n\nvar0\n\n0\n\n\n\nfactor0\n\n\n\n\nvar0--factor0\n\n\n\n\nvar1\n\n1\n\n\n\nfactor1\n\n\n\n\nvar1--factor1\n\n\n\n\nvar2\n\n2\n\n\n\nfactor2\n\n\n\n\nvar2--factor2\n\n\n\n\nvar3\n\n3\n\n\n\nfactor3\n\n\n\n\nvar3--factor3\n\n\n\n\nvar4\n\n4\n\n\n\nfactor4\n\n\n\n\nvar4--factor4\n\n\n\n\nvar5\n\n5\n\n\n\nfactor5\n\n\n\n\nvar5--factor5\n\n\n\n\nvar6\n\n6\n\n\n\nfactor6\n\n\n\n\nvar6--factor6\n\n\n\n\nvar7\n\n7\n\n\n\nfactor7\n\n\n\n\nvar7--factor7\n\n\n\n\nvar8\n\n8\n\n\n\nvar8--factor0\n\n\n\n\nvar8--factor1\n\n\n\n\nfactor8\n\n\n\n\nvar8--factor8\n\n\n\n\nvar9\n\n9\n\n\n\nvar9--factor2\n\n\n\n\nvar9--factor3\n\n\n\n\nfactor9\n\n\n\n\nvar9--factor9\n\n\n\n\nvar10\n\n10\n\n\n\nvar10--factor4\n\n\n\n\nvar10--factor5\n\n\n\n\nfactor10\n\n\n\n\nvar10--factor10\n\n\n\n\nvar11\n\n11\n\n\n\nvar11--factor6\n\n\n\n\nvar11--factor7\n\n\n\n\nfactor11\n\n\n\n\nvar11--factor11\n\n\n\n\nvar12\n\n12\n\n\n\nvar14\n\n14\n\n\n\nvar12--var14\n\n\n\n\nvar12--factor0\n\n\n\n\nvar12--factor1\n\n\n\n\nvar12--factor2\n\n\n\n\nvar12--factor3\n\n\n\n\nvar12--factor8\n\n\n\n\nvar12--factor9\n\n\n\n\nvar13\n\n13\n\n\n\nvar13--var14\n\n\n\n\nvar13--factor4\n\n\n\n\nvar13--factor5\n\n\n\n\nvar13--factor6\n\n\n\n\nvar13--factor7\n\n\n\n\nvar13--factor10\n\n\n\n\nvar13--factor11\n\n\n\n\nvar14--factor8\n\n\n\n\nvar14--factor9\n\n\n\n\nvar14--factor10\n\n\n\n\nvar14--factor11\n\n\n\n\nfactor14\n\n\n\n\nvar14--factor14\n\n\n\n\n", + "text/plain": [ + "<__main__.show at 0x109c61f10>" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create a factor graph out of the Bayes net.\n", + "factorGraph = DiscreteFactorGraph(bayesNet)\n", + "show(factorGraph)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\nG\n\n\n\n0\n\n8,12,14\n\n\n\n1\n\n0 : 8,12\n\n\n\n0->1\n\n\n\n\n\n2\n\n1 : 8,12\n\n\n\n0->2\n\n\n\n\n\n3\n\n9 : 12,14\n\n\n\n0->3\n\n\n\n\n\n6\n\n10,13 : 14\n\n\n\n0->6\n\n\n\n\n\n4\n\n2 : 9,12\n\n\n\n3->4\n\n\n\n\n\n5\n\n3 : 9,12\n\n\n\n3->5\n\n\n\n\n\n7\n\n4 : 10,13\n\n\n\n6->7\n\n\n\n\n\n8\n\n5 : 10,13\n\n\n\n6->8\n\n\n\n\n\n9\n\n11 : 13,14\n\n\n\n6->9\n\n\n\n\n\n10\n\n6 : 11,13\n\n\n\n9->10\n\n\n\n\n\n11\n\n7 : 11,13\n\n\n\n9->11\n\n\n\n\n\n", + "text/plain": [ + "<__main__.show at 0x109c61b50>" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create a BayesTree out of the factor graph.\n", + "ordering = Ordering()\n", + "for j in range(15): ordering.push_back(j)\n", + "bayesTree = factorGraph.eliminateMultifrontal(ordering)\n", + "show(bayesTree)" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + }, + "kernelspec": { + "display_name": "Python 3.8.9 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/python/gtsam/notebooks/DiscreteSwitching.ipynb b/python/gtsam/notebooks/DiscreteSwitching.ipynb new file mode 100644 index 0000000000..0707cbd3bc --- /dev/null +++ b/python/gtsam/notebooks/DiscreteSwitching.ipynb @@ -0,0 +1,176 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# A Discrete Switching System\n", + "\n", + "A la MHS, but all discrete.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from gtsam import DiscreteBayesNet, DiscreteKeys, DiscreteFactorGraph, Ordering\n", + "from gtsam.symbol_shorthand import S\n", + "from gtsam.symbol_shorthand import M\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def P(*args):\n", + " \"\"\" Create a DiscreteKeys instances from a variable number of DiscreteKey pairs.\"\"\"\n", + " # TODO: We can make life easier by providing variable argument functions in C++ itself.\n", + " dks = DiscreteKeys()\n", + " for key in args:\n", + " dks.push_back(key)\n", + " return dks\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import graphviz\n", + "\n", + "\n", + "class show(graphviz.Source):\n", + " \"\"\" Display an object with a dot method as a graph.\"\"\"\n", + "\n", + " def __init__(self, obj):\n", + " \"\"\"Construct from object with 'dot' method.\"\"\"\n", + " # This small class takes an object, calls its dot function, and uses the\n", + " # resulting string to initialize a graphviz.Source instance. This in turn\n", + " # has a _repr_mimebundle_ method, which then renders it in the notebook.\n", + " super().__init__(obj.dot())\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\nG\n\n\n\ns1\n\ns1\n\n\n\ns2\n\ns2\n\n\n\ns1->s2\n\n\n\n\n\ns3\n\ns3\n\n\n\ns2->s3\n\n\n\n\n\nm1\n\nm1\n\n\n\nm1->s2\n\n\n\n\n\ns4\n\ns4\n\n\n\ns3->s4\n\n\n\n\n\nm2\n\nm2\n\n\n\nm2->s3\n\n\n\n\n\ns5\n\ns5\n\n\n\ns4->s5\n\n\n\n\n\nm3\n\nm3\n\n\n\nm3->s4\n\n\n\n\n\nm4\n\nm4\n\n\n\nm4->s5\n\n\n\n\n\n", + "text/plain": [ + "<__main__.show at 0x119a80d90>" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nrStates = 3\n", + "K = 5\n", + "\n", + "bayesNet = DiscreteBayesNet()\n", + "for k in range(1, K):\n", + " key = S(k), nrStates\n", + " key_plus = S(k+1), nrStates\n", + " mode = M(k), 2\n", + " bayesNet.add(key_plus, P(key, mode), \"1/1/1 1/2/1 3/2/3 1/1/1 1/2/1 3/2/3\")\n", + "\n", + "show(bayesNet)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n\n\n\nvar7854277750134145025\n\nm1\n\n\n\nfactor0\n\n\n\n\nvar7854277750134145025--factor0\n\n\n\n\nvar7854277750134145026\n\nm2\n\n\n\nfactor1\n\n\n\n\nvar7854277750134145026--factor1\n\n\n\n\nvar7854277750134145027\n\nm3\n\n\n\nfactor2\n\n\n\n\nvar7854277750134145027--factor2\n\n\n\n\nvar7854277750134145028\n\nm4\n\n\n\nfactor3\n\n\n\n\nvar7854277750134145028--factor3\n\n\n\n\nvar8286623314361712641\n\ns1\n\n\n\nvar8286623314361712641--factor0\n\n\n\n\nvar8286623314361712642\n\ns2\n\n\n\nvar8286623314361712642--factor0\n\n\n\n\nvar8286623314361712642--factor1\n\n\n\n\nvar8286623314361712643\n\ns3\n\n\n\nvar8286623314361712643--factor1\n\n\n\n\nvar8286623314361712643--factor2\n\n\n\n\nvar8286623314361712644\n\ns4\n\n\n\nvar8286623314361712644--factor2\n\n\n\n\nvar8286623314361712644--factor3\n\n\n\n\nvar8286623314361712645\n\ns5\n\n\n\nvar8286623314361712645--factor3\n\n\n\n\n", + "text/plain": [ + "<__main__.show at 0x119a80820>" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create a factor graph out of the Bayes net.\n", + "factorGraph = DiscreteFactorGraph(bayesNet)\n", + "show(factorGraph)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Position 0: s1, s2, s3, s4, s5, m1, m2, m3, m4\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": "\n\n\n\n\n\nG\n\n\n\n0\n\ns4,s5,m1,m2,m3,m4\n\n\n\n1\n\ns3 : m1,m2,m3,s4\n\n\n\n0->1\n\n\n\n\n\n2\n\ns2 : m1,m2,s3\n\n\n\n1->2\n\n\n\n\n\n3\n\ns1 : m1,s2\n\n\n\n2->3\n\n\n\n\n\n", + "text/plain": [ + "<__main__.show at 0x119a76b80>" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create a BayesTree out of the factor graph.\n", + "ordering = Ordering()\n", + "# First eliminate \"continuous\" states in time order\n", + "for k in range(1, K+1):\n", + " ordering.push_back(S(k))\n", + "for k in range(1, K):\n", + " ordering.push_back(M(k))\n", + "print(ordering)\n", + "bayesTree = factorGraph.eliminateMultifrontal(ordering)\n", + "show(bayesTree)\n" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + }, + "kernelspec": { + "display_name": "Python 3.8.9 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/testNonlinearFactorGraph.cpp b/tests/testNonlinearFactorGraph.cpp index fdb080a63b..4dec08f45c 100644 --- a/tests/testNonlinearFactorGraph.cpp +++ b/tests/testNonlinearFactorGraph.cpp @@ -15,6 +15,7 @@ * @brief testNonlinearFactorGraph * @author Carlos Nieto * @author Christian Potthast + * @author Frank Dellaert */ #include @@ -285,6 +286,7 @@ TEST(testNonlinearFactorGraph, addPrior) { EXPECT(0 != graph.error(values)); } +/* ************************************************************************* */ TEST(NonlinearFactorGraph, printErrors) { const NonlinearFactorGraph fg = createNonlinearFactorGraph(); @@ -309,6 +311,53 @@ TEST(NonlinearFactorGraph, printErrors) for (bool visit : visited) EXPECT(visit==true); } +/* ************************************************************************* */ +TEST(NonlinearFactorGraph, dot) { + string expected = + "graph {\n" + " size=\"5,5\";\n" + "\n" + " var7782220156096217089[label=\"l1\"];\n" + " var8646911284551352321[label=\"x1\"];\n" + " var8646911284551352322[label=\"x2\"];\n" + "\n" + " factor0[label=\"\", shape=point];\n" + " var8646911284551352321--factor0;\n" + " var8646911284551352321--var8646911284551352322;\n" + " var8646911284551352321--var7782220156096217089;\n" + " var8646911284551352322--var7782220156096217089;\n" + "}\n"; + + const NonlinearFactorGraph fg = createNonlinearFactorGraph(); + string actual = fg.dot(); + EXPECT(actual == expected); +} + +/* ************************************************************************* */ +TEST(NonlinearFactorGraph, dot_extra) { + string expected = + "graph {\n" + " size=\"5,5\";\n" + "\n" + " var7782220156096217089[label=\"l1\", pos=\"0,0!\"];\n" + " var8646911284551352321[label=\"x1\", pos=\"1,0!\"];\n" + " var8646911284551352322[label=\"x2\", pos=\"1,1.5!\"];\n" + "\n" + " factor0[label=\"\", shape=point];\n" + " var8646911284551352321--factor0;\n" + " var8646911284551352321--var8646911284551352322;\n" + " var8646911284551352321--var7782220156096217089;\n" + " var8646911284551352322--var7782220156096217089;\n" + "}\n"; + + const NonlinearFactorGraph fg = createNonlinearFactorGraph(); + const Values c = createValues(); + + stringstream ss; + fg.dot(ss, c); + EXPECT(ss.str() == expected); +} + /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr); } /* ************************************************************************* */