diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 308b2f9cac..841f90fe26 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -168,7 +168,7 @@ namespace gtsam { /// Render as markdown table. std::string _repr_markdown_( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; /// @} diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index 219f2d93e6..e50f4586f5 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -38,7 +38,7 @@ namespace gtsam { double DiscreteBayesNet::evaluate(const DiscreteValues & values) const { // evaluate all conditionals and multiply double result = 1.0; - for(DiscreteConditional::shared_ptr conditional: *this) + for(const DiscreteConditional::shared_ptr& conditional: *this) result *= (*conditional)(values); return result; } @@ -61,5 +61,15 @@ namespace gtsam { return result; } + /* ************************************************************************* */ + std::string DiscreteBayesNet::_repr_markdown_( + const KeyFormatter& keyFormatter) const { + using std::endl; + std::stringstream ss; + ss << "`DiscreteBayesNet` of size " << size() << endl << endl; + for(const DiscreteConditional::shared_ptr& conditional: *this) + ss << conditional->_repr_markdown_(keyFormatter) << endl; + return ss.str(); + } /* ************************************************************************* */ } // namespace diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index 2d92b72e8f..5eb656b3b5 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -13,6 +13,7 @@ * @file DiscreteBayesNet.h * @date Feb 15, 2011 * @author Duy-Nguyen Ta + * @author Frank dellaert */ #pragma once @@ -97,6 +98,14 @@ namespace gtsam { DiscreteValues sample() const; ///@} + /// @name Wrapper support + /// @{ + + /// Render as markdown table. + std::string _repr_markdown_( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// @} private: /** Serialization function */ diff --git a/gtsam/discrete/DiscreteBayesTree.cpp b/gtsam/discrete/DiscreteBayesTree.cpp index 48413405ac..09a1f47aac 100644 --- a/gtsam/discrete/DiscreteBayesTree.cpp +++ b/gtsam/discrete/DiscreteBayesTree.cpp @@ -55,8 +55,21 @@ namespace gtsam { return result; } -} // \namespace gtsam - - - + /* **************************************************************************/ + std::string DiscreteBayesTree::_repr_markdown_( + const KeyFormatter& keyFormatter) const { + using std::endl; + std::stringstream ss; + ss << "`DiscreteBayesTree` of size " << nodes_.size() << endl << endl; + auto visitor = [&](const DiscreteBayesTreeClique::shared_ptr& clique, + size_t& indent) { + ss << "\n" << clique->conditional()->_repr_markdown_(keyFormatter); + return indent + 1; + }; + size_t indent; + treeTraversal::DepthFirstForest(*this, indent, visitor); + return ss.str(); + } + /* **************************************************************************/ + } // namespace gtsam diff --git a/gtsam/discrete/DiscreteBayesTree.h b/gtsam/discrete/DiscreteBayesTree.h index 42ec7d4173..675b951edf 100644 --- a/gtsam/discrete/DiscreteBayesTree.h +++ b/gtsam/discrete/DiscreteBayesTree.h @@ -72,6 +72,8 @@ class GTSAM_EXPORT DiscreteBayesTree typedef DiscreteBayesTree This; typedef boost::shared_ptr shared_ptr; + /// @name Standard interface + /// @{ /** Default constructor, creates an empty Bayes tree */ DiscreteBayesTree() {} @@ -82,10 +84,19 @@ class GTSAM_EXPORT DiscreteBayesTree double evaluate(const DiscreteValues& values) const; //** (Preferred) sugar for the above for given DiscreteValues */ - double operator()(const DiscreteValues & values) const { + double operator()(const DiscreteValues& values) const { return evaluate(values); } + /// @} + /// @name Wrapper support + /// @{ + + /// Render as markdown table. + std::string _repr_markdown_( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// @} }; } // namespace gtsam diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 7ed962188e..b02d095758 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -229,11 +229,22 @@ std::string DiscreteConditional::_repr_markdown_( // Print out signature. ss << " $P("; - for(Key key: frontals()) - ss << keyFormatter(key); - if (nrParents() > 0) - ss << "|"; bool first = true; + for (Key key : frontals()) { + if (!first) ss << ","; + ss << keyFormatter(key); + first = false; + } + if (nrParents() == 0) { + // We have no parents, call factor method. + ss << ")$:" << std::endl; + ss << DecisionTreeFactor::_repr_markdown_(); + return ss.str(); + } + + // We have parents, continue signature and do custom print. + ss << "|"; + first = true; for (Key parent : parents()) { if (!first) ss << ","; ss << keyFormatter(parent); @@ -256,9 +267,8 @@ std::string DiscreteConditional::_repr_markdown_( pairs.emplace_back(key, k); n *= k; } - size_t nrParents = size() - nrFrontals_; std::vector> slatnorf(pairs.rbegin(), - pairs.rend() - nrParents); + pairs.rend() - nrParents()); const auto frontal_assignments = cartesianProduct(slatnorf); for (const auto& a : frontal_assignments) { for (it = beginFrontals(); it != endFrontals(); ++it) ss << a.at(*it); @@ -268,7 +278,7 @@ std::string DiscreteConditional::_repr_markdown_( // Print out separator with alignment hints. ss << "|"; - for (size_t j = 0; j < nrParents + n; j++) ss << ":-:|"; + for (size_t j = 0; j < nrParents() + n; j++) ss << ":-:|"; ss << "\n"; // Print out all rows. diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index ad21151a89..b76e4f65fb 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -172,7 +172,7 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor, /// Render as markdown table. std::string _repr_markdown_( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; /// @} }; diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index e2be94b94a..f046e5e44a 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -88,6 +88,14 @@ class GTSAM_EXPORT DiscreteFactor: public Factor { virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; + /// @} + /// @name Wrapper support + /// @{ + + /// Render as markdown table. + virtual std::string _repr_markdown_( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const = 0; + /// @} }; // DiscreteFactor diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 77127ac304..129ab4dae8 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -129,6 +129,18 @@ namespace gtsam { return std::make_pair(cond, sum); } -/* ************************************************************************* */ -} // namespace + /* ************************************************************************* */ + std::string DiscreteFactorGraph::_repr_markdown_( + const KeyFormatter& keyFormatter) const { + using std::endl; + std::stringstream ss; + ss << "`DiscreteFactorGraph` of size " << size() << endl << endl; + for (size_t i = 0; i < factors_.size(); i++) { + ss << "factor " << i << ":\n"; + ss << factors_[i]->_repr_markdown_(keyFormatter) << endl; + } + return ss.str(); + } + /* ************************************************************************* */ + } // namespace gtsam diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index ff0aaef19e..616d7c7d2a 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -154,6 +154,14 @@ public EliminateableFactorGraph { // /** Apply a reduction, which is a remapping of variable indices. */ // GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction); + /// @name Wrapper support + /// @{ + + /// Render as markdown table. + std::string _repr_markdown_( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// @} }; // \ DiscreteFactorGraph /// traits diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 4618073fa3..a2377dc59d 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -93,6 +93,8 @@ class DiscreteBayesNet { double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteValues optimize() const; gtsam::DiscreteValues sample() const; + string _repr_markdown_(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; }; #include @@ -124,6 +126,9 @@ class DiscreteBayesTree { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; double operator()(const gtsam::DiscreteValues& values) const; + + string _repr_markdown_(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; }; #include @@ -164,6 +169,9 @@ class DiscreteFactorGraph { gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering); gtsam::DiscreteBayesTree eliminateMultifrontal(); gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering); + + string _repr_markdown_(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; }; } // namespace gtsam diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index f0c7d37281..827b7d2485 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -38,6 +38,9 @@ using namespace boost::assign; using namespace std; using namespace gtsam; +static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), + LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2); + /* ************************************************************************* */ TEST(DiscreteBayesNet, bayesNet) { DiscreteBayesNet bayesNet; @@ -71,8 +74,6 @@ TEST(DiscreteBayesNet, bayesNet) { /* ************************************************************************* */ TEST(DiscreteBayesNet, Asia) { DiscreteBayesNet asia; - DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2), - Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2); asia.add(Asia % "99/1"); asia.add(Smoking % "50/50"); @@ -151,9 +152,6 @@ TEST(DiscreteBayesNet, Sugar) { /* ************************************************************************* */ TEST(DiscreteBayesNet, Dot) { - DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2), - Either(5, 2); - DiscreteBayesNet fragment; fragment.add(Asia % "99/1"); fragment.add(Smoking % "50/50"); @@ -172,6 +170,32 @@ TEST(DiscreteBayesNet, Dot) { "}"); } +/* ************************************************************************* */ +// Check markdown representation looks as expected. +TEST(DiscreteBayesNet, markdown) { + DiscreteBayesNet fragment; + fragment.add(Asia % "99/1"); + fragment.add(Smoking | Asia = "8/2 7/3"); + + string expected = + "`DiscreteBayesNet` of size 2\n" + "\n" + " $P(Asia)$:\n" + "|0|value|\n" + "|:-:|:-:|\n" + "|0|0.99|\n" + "|1|0.01|\n" + "\n" + " $P(Smoking|Asia)$:\n" + "|Asia|0|1|\n" + "|:-:|:-:|:-:|\n" + "|0|0.8|0.2|\n" + "|1|0.7|0.3|\n\n"; + auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; }; + string actual = fragment._repr_markdown_(formatter); + EXPECT(actual == expected); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 698268e843..964a33926e 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -110,13 +110,15 @@ TEST(DiscreteConditional, Combine) { /* ************************************************************************* */ // Check markdown representation looks as expected, no parents. TEST(DiscreteConditional, markdown_prior) { - DiscreteKey A(Symbol('x', 1), 2); - DiscreteConditional conditional(A % "1/3"); + DiscreteKey A(Symbol('x', 1), 3); + DiscreteConditional conditional(A % "1/2/2"); string expected = " $P(x1)$:\n" - "|0|1|\n" + "|x1|value|\n" "|:-:|:-:|\n" - "|0.25|0.75|\n"; + "|0|0.2|\n" + "|1|0.4|\n" + "|2|0.4|\n"; string actual = conditional._repr_markdown_(); EXPECT(actual == expected); } diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index 32117bd25c..f1fd26af4f 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -361,11 +361,9 @@ cout << unicorns; /* ************************************************************************* */ TEST(DiscreteFactorGraph, Dot) { - // Declare a bunch of keys - DiscreteKey C(0, 2), A(1, 2), B(2, 2); - // Create Factor graph DiscreteFactorGraph graph; + DiscreteKey C(0, 2), A(1, 2), B(2, 2); graph.add(C & A, "0.2 0.8 0.3 0.7"); graph.add(C & B, "0.1 0.9 0.4 0.6"); @@ -384,6 +382,44 @@ TEST(DiscreteFactorGraph, Dot) { EXPECT(actual == expected); } +/* ************************************************************************* */ +// Check markdown representation looks as expected. +TEST(DiscreteFactorGraph, markdown) { + // Create Factor graph + DiscreteFactorGraph graph; + DiscreteKey C(0, 2), A(1, 2), B(2, 2); + graph.add(C & A, "0.2 0.8 0.3 0.7"); + graph.add(C & B, "0.1 0.9 0.4 0.6"); + + string expected = + "`DiscreteFactorGraph` of size 2\n" + "\n" + "factor 0:\n" + "|C|A|value|\n" + "|:-:|:-:|:-:|\n" + "|0|0|0.2|\n" + "|0|1|0.8|\n" + "|1|0|0.3|\n" + "|1|1|0.7|\n" + "\n" + "factor 1:\n" + "|C|B|value|\n" + "|:-:|:-:|:-:|\n" + "|0|0|0.1|\n" + "|0|1|0.9|\n" + "|1|0|0.4|\n" + "|1|1|0.6|\n\n"; + vector names{"C", "A", "B"}; + auto formatter = [names](Key key) { return names[key]; }; + string actual = graph._repr_markdown_(formatter); + EXPECT(actual == expected); + + // Make sure values are correctly displayed. + DiscreteValues values; + values[0] = 1; + values[1] = 0; + EXPECT_DOUBLES_EQUAL(0.3, graph[0]->operator()(values), 1e-9); +} /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h index 0a05bbfd2e..e772c54df5 100644 --- a/gtsam_unstable/discrete/Constraint.h +++ b/gtsam_unstable/discrete/Constraint.h @@ -22,6 +22,7 @@ #include #include +#include #include namespace gtsam { @@ -81,6 +82,16 @@ class GTSAM_EXPORT Constraint : public DiscreteFactor { /// Partially apply known values, domain version virtual shared_ptr partiallyApply(const Domains&) const = 0; /// @} + /// @name Wrapper support + /// @{ + + /// Render as markdown table. + std::string _repr_markdown_( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override { + return (boost::format("`Constraint` on %1% variables\n") % (size())).str(); + } + + /// @} }; // DiscreteFactor diff --git a/python/gtsam/notebooks/DiscreteSwitching.ipynb b/python/gtsam/notebooks/DiscreteSwitching.ipynb index 0707cbd3bc..6872e78c80 100644 --- a/python/gtsam/notebooks/DiscreteSwitching.ipynb +++ b/python/gtsam/notebooks/DiscreteSwitching.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -22,7 +22,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -37,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -57,21 +57,9 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "image/svg+xml": "\n\n\n\n\n\nG\n\n\n\ns1\n\ns1\n\n\n\ns2\n\ns2\n\n\n\ns1->s2\n\n\n\n\n\ns3\n\ns3\n\n\n\ns2->s3\n\n\n\n\n\nm1\n\nm1\n\n\n\nm1->s2\n\n\n\n\n\ns4\n\ns4\n\n\n\ns3->s4\n\n\n\n\n\nm2\n\nm2\n\n\n\nm2->s3\n\n\n\n\n\ns5\n\ns5\n\n\n\ns4->s5\n\n\n\n\n\nm3\n\nm3\n\n\n\nm3->s4\n\n\n\n\n\nm4\n\nm4\n\n\n\nm4->s5\n\n\n\n\n\n", - "text/plain": [ - "<__main__.show at 0x119a80d90>" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "nrStates = 3\n", "K = 5\n", @@ -81,59 +69,36 @@ " key = S(k), nrStates\n", " key_plus = S(k+1), nrStates\n", " mode = M(k), 2\n", - " bayesNet.add(key_plus, P(key, mode), \"1/1/1 1/2/1 3/2/3 1/1/1 1/2/1 3/2/3\")\n", + " bayesNet.add(key_plus, P(mode, key), \"9/1/0 1/8/1 0/1/9 1/9/0 0/1/9 9/0/1\")\n", "\n", - "show(bayesNet)\n" + "bayesNet" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "image/svg+xml": "\n\n\n\n\n\n\n\n\nvar7854277750134145025\n\nm1\n\n\n\nfactor0\n\n\n\n\nvar7854277750134145025--factor0\n\n\n\n\nvar7854277750134145026\n\nm2\n\n\n\nfactor1\n\n\n\n\nvar7854277750134145026--factor1\n\n\n\n\nvar7854277750134145027\n\nm3\n\n\n\nfactor2\n\n\n\n\nvar7854277750134145027--factor2\n\n\n\n\nvar7854277750134145028\n\nm4\n\n\n\nfactor3\n\n\n\n\nvar7854277750134145028--factor3\n\n\n\n\nvar8286623314361712641\n\ns1\n\n\n\nvar8286623314361712641--factor0\n\n\n\n\nvar8286623314361712642\n\ns2\n\n\n\nvar8286623314361712642--factor0\n\n\n\n\nvar8286623314361712642--factor1\n\n\n\n\nvar8286623314361712643\n\ns3\n\n\n\nvar8286623314361712643--factor1\n\n\n\n\nvar8286623314361712643--factor2\n\n\n\n\nvar8286623314361712644\n\ns4\n\n\n\nvar8286623314361712644--factor2\n\n\n\n\nvar8286623314361712644--factor3\n\n\n\n\nvar8286623314361712645\n\ns5\n\n\n\nvar8286623314361712645--factor3\n\n\n\n\n", - "text/plain": [ - "<__main__.show at 0x119a80820>" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], + "source": [ + "show(bayesNet)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# Create a factor graph out of the Bayes net.\n", "factorGraph = DiscreteFactorGraph(bayesNet)\n", - "show(factorGraph)\n" + "show(factorGraph)" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Position 0: s1, s2, s3, s4, s5, m1, m2, m3, m4\n", - "\n" - ] - }, - { - "data": { - "image/svg+xml": "\n\n\n\n\n\nG\n\n\n\n0\n\ns4,s5,m1,m2,m3,m4\n\n\n\n1\n\ns3 : m1,m2,m3,s4\n\n\n\n0->1\n\n\n\n\n\n2\n\ns2 : m1,m2,s3\n\n\n\n1->2\n\n\n\n\n\n3\n\ns1 : m1,s2\n\n\n\n2->3\n\n\n\n\n\n", - "text/plain": [ - "<__main__.show at 0x119a76b80>" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# Create a BayesTree out of the factor graph.\n", "ordering = Ordering()\n", @@ -144,8 +109,22 @@ " ordering.push_back(M(k))\n", "print(ordering)\n", "bayesTree = factorGraph.eliminateMultifrontal(ordering)\n", - "show(bayesTree)\n" + "bayesTree" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "show(bayesTree)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] } ], "metadata": {