Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streamlined dot methods #971

Merged
merged 11 commits into from
Dec 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions gtsam/base/tests/testMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ TEST(Matrix, stack )
{
Matrix A = (Matrix(2, 2) << -5.0, 3.0, 00.0, -5.0).finished();
Matrix B = (Matrix(3, 2) << -0.5, 2.1, 1.1, 3.4, 2.6, 7.1).finished();
Matrix AB = stack(2, &A, &B);
Matrix AB = gtsam::stack(2, &A, &B);
Matrix C(5, 2);
for (int i = 0; i < 2; i++)
for (int j = 0; j < 2; j++)
Expand All @@ -187,7 +187,7 @@ TEST(Matrix, stack )
std::vector<gtsam::Matrix> matrices;
matrices.push_back(A);
matrices.push_back(B);
Matrix AB2 = stack(matrices);
Matrix AB2 = gtsam::stack(matrices);
EQUALITY(C,AB2);
}

Expand Down
20 changes: 15 additions & 5 deletions gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,9 @@ namespace gtsam {
void dot(std::ostream& os, bool showZero) const override {
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_
<< "\"]\n";
for (size_t i = 0; i < branches_.size(); i++) {
NodePtr branch = branches_[i];
size_t B = branches_.size();
for (size_t i = 0; i < B; i++) {
const NodePtr& branch = branches_[i];

// Check if zero
if (!showZero) {
Expand All @@ -258,8 +259,10 @@ namespace gtsam {
}

os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
if (i == 0) os << " [style=dashed]";
if (i > 1) os << " [style=bold]";
if (B == 2) {
if (i == 0) os << " [style=dashed]";
if (i > 1) os << " [style=bold]";
}
os << std::endl;
branch->dot(os, showZero);
}
Expand Down Expand Up @@ -671,7 +674,14 @@ namespace gtsam {
int result = system(
("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str());
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed");
}
}

template<typename L, typename Y>
std::string DecisionTree<L, Y>::dot(bool showZero) const {
std::stringstream ss;
dot(ss, showZero);
return ss.str();
}

/*********************************************************************************/

Expand Down
3 changes: 3 additions & 0 deletions gtsam/discrete/DecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ namespace gtsam {
/** output to graphviz format, open a file */
void dot(const std::string& name, bool showZero = true) const;

/** output to graphviz format string */
std::string dot(bool showZero = true) const;

/// @name Advanced Interface
/// @{

Expand Down
26 changes: 22 additions & 4 deletions gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
string dot(bool showZero = false) const;
};

#include <gtsam/discrete/DiscreteConditional.h>
Expand All @@ -52,8 +52,6 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
const gtsam::DecisionTreeFactor& marginal,
const gtsam::Ordering& orderedKeys);
size_t size() const; // TODO(dellaert): why do I have to repeat???
double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
void print(string s = "Discrete Conditional\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
Expand Down Expand Up @@ -82,6 +80,8 @@ class DiscreteBayesNet {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void saveGraph(string s,
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
Expand All @@ -98,9 +98,19 @@ class DiscreteBayesTree {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const;
string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void saveGraph(string s,
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
double operator()(const gtsam::DiscreteValues& values) const;
};

#include <gtsam/inference/DotWriter.h>
class DotWriter {
DotWriter();
};

#include <gtsam/discrete/DiscreteFactorGraph.h>
class DiscreteFactorGraph {
DiscreteFactorGraph();
Expand All @@ -117,7 +127,15 @@ class DiscreteFactorGraph {

void print(string s = "") const;
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;


string dot(const gtsam::DotWriter& dotWriter = gtsam::DotWriter(),
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void saveGraph(string s,
const gtsam::DotWriter& dotWriter = gtsam::DotWriter(),
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;

gtsam::DecisionTreeFactor product() const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;
Expand Down
25 changes: 24 additions & 1 deletion gtsam/discrete/tests/testDiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ TEST(DiscreteBayesNet, Asia) {
}

/* ************************************************************************* */
TEST_UNSAFE(DiscreteBayesNet, Sugar) {
TEST(DiscreteBayesNet, Sugar) {
DiscreteKey T(0, 2), L(1, 2), E(2, 2), C(8, 3), S(7, 2);

DiscreteBayesNet bn;
Expand All @@ -149,6 +149,29 @@ TEST_UNSAFE(DiscreteBayesNet, Sugar) {
bn.add(C | S = "1/1/2 5/2/3");
}

/* ************************************************************************* */
TEST(DiscreteBayesNet, Dot) {
DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2),
Either(5, 2);

DiscreteBayesNet fragment;
fragment.add(Asia % "99/1");
fragment.add(Smoking % "50/50");

fragment.add(Tuberculosis | Asia = "99/1 95/5");
fragment.add(LungCancer | Smoking = "99/1 90/10");
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");

string actual = fragment.dot();
EXPECT(actual ==
"digraph G{\n"
"0->3\n"
"4->6\n"
"3->5\n"
"6->5\n"
"}");
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down
Loading