Skip to content

Commit

Permalink
Merge pull request #971 from borglab/feature/notebook_dot
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert authored Dec 21, 2021
2 parents 5e3db76 + e6ca595 commit 168a67d
Show file tree
Hide file tree
Showing 23 changed files with 1,206 additions and 289 deletions.
4 changes: 2 additions & 2 deletions gtsam/base/tests/testMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ TEST(Matrix, stack )
{
Matrix A = (Matrix(2, 2) << -5.0, 3.0, 00.0, -5.0).finished();
Matrix B = (Matrix(3, 2) << -0.5, 2.1, 1.1, 3.4, 2.6, 7.1).finished();
Matrix AB = stack(2, &A, &B);
Matrix AB = gtsam::stack(2, &A, &B);
Matrix C(5, 2);
for (int i = 0; i < 2; i++)
for (int j = 0; j < 2; j++)
Expand All @@ -187,7 +187,7 @@ TEST(Matrix, stack )
std::vector<gtsam::Matrix> matrices;
matrices.push_back(A);
matrices.push_back(B);
Matrix AB2 = stack(matrices);
Matrix AB2 = gtsam::stack(matrices);
EQUALITY(C,AB2);
}

Expand Down
20 changes: 15 additions & 5 deletions gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,9 @@ namespace gtsam {
void dot(std::ostream& os, bool showZero) const override {
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_
<< "\"]\n";
for (size_t i = 0; i < branches_.size(); i++) {
NodePtr branch = branches_[i];
size_t B = branches_.size();
for (size_t i = 0; i < B; i++) {
const NodePtr& branch = branches_[i];

// Check if zero
if (!showZero) {
Expand All @@ -258,8 +259,10 @@ namespace gtsam {
}

os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
if (i == 0) os << " [style=dashed]";
if (i > 1) os << " [style=bold]";
if (B == 2) {
if (i == 0) os << " [style=dashed]";
if (i > 1) os << " [style=bold]";
}
os << std::endl;
branch->dot(os, showZero);
}
Expand Down Expand Up @@ -671,7 +674,14 @@ namespace gtsam {
int result = system(
("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str());
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed");
}
}

template<typename L, typename Y>
std::string DecisionTree<L, Y>::dot(bool showZero) const {
std::stringstream ss;
dot(ss, showZero);
return ss.str();
}

/*********************************************************************************/

Expand Down
3 changes: 3 additions & 0 deletions gtsam/discrete/DecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ namespace gtsam {
/** output to graphviz format, open a file */
void dot(const std::string& name, bool showZero = true) const;

/** output to graphviz format string */
std::string dot(bool showZero = true) const;

/// @name Advanced Interface
/// @{

Expand Down
26 changes: 22 additions & 4 deletions gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
string dot(bool showZero = false) const;
};

#include <gtsam/discrete/DiscreteConditional.h>
Expand All @@ -52,8 +52,6 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
const gtsam::DecisionTreeFactor& marginal,
const gtsam::Ordering& orderedKeys);
size_t size() const; // TODO(dellaert): why do I have to repeat???
double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
void print(string s = "Discrete Conditional\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
Expand Down Expand Up @@ -82,6 +80,8 @@ class DiscreteBayesNet {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void saveGraph(string s,
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
Expand All @@ -98,9 +98,19 @@ class DiscreteBayesTree {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const;
string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void saveGraph(string s,
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
double operator()(const gtsam::DiscreteValues& values) const;
};

#include <gtsam/inference/DotWriter.h>
class DotWriter {
DotWriter();
};

#include <gtsam/discrete/DiscreteFactorGraph.h>
class DiscreteFactorGraph {
DiscreteFactorGraph();
Expand All @@ -117,7 +127,15 @@ class DiscreteFactorGraph {

void print(string s = "") const;
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;


string dot(const gtsam::DotWriter& dotWriter = gtsam::DotWriter(),
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void saveGraph(string s,
const gtsam::DotWriter& dotWriter = gtsam::DotWriter(),
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;

gtsam::DecisionTreeFactor product() const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;
Expand Down
25 changes: 24 additions & 1 deletion gtsam/discrete/tests/testDiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ TEST(DiscreteBayesNet, Asia) {
}

/* ************************************************************************* */
TEST_UNSAFE(DiscreteBayesNet, Sugar) {
TEST(DiscreteBayesNet, Sugar) {
DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2);

DiscreteBayesNet bn;
Expand All @@ -149,6 +149,29 @@ TEST_UNSAFE(DiscreteBayesNet, Sugar) {
bn.add(C | S = "1/1/2 5/2/3");
}

/* ************************************************************************* */
TEST(DiscreteBayesNet, Dot) {
DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2),
Either(5, 2);

DiscreteBayesNet fragment;
fragment.add(Asia % "99/1");
fragment.add(Smoking % "50/50");

fragment.add(Tuberculosis | Asia = "99/1 95/5");
fragment.add(LungCancer | Smoking = "99/1 90/10");
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");

string actual = fragment.dot();
EXPECT(actual ==
"digraph G{\n"
"0->3\n"
"4->6\n"
"3->5\n"
"6->5\n"
"}");
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down
Loading

0 comments on commit 168a67d

Please sign in to comment.