Skip to content

Commit

Permalink
Merge pull request #985 from borglab/featue/wrap_discrete_BT
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert authored Dec 24, 2021
2 parents 2a2c4ef + b29b0ea commit fb3f00d
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 0 deletions.
16 changes: 16 additions & 0 deletions gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,28 @@ class DiscreteBayesNet {
};

#include <gtsam/discrete/DiscreteBayesTree.h>
class DiscreteBayesTreeClique {
DiscreteBayesTreeClique();
DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional);
const gtsam::DiscreteConditional* conditional() const;
bool isRoot() const;
void printSignature(
const string& s = "Clique: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
double evaluate(const gtsam::DiscreteValues& values) const;
};

class DiscreteBayesTree {
DiscreteBayesTree();
void print(string s = "DiscreteBayesTree\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const;

size_t size() const;
bool empty() const;
const DiscreteBayesTreeClique* operator[](size_t j) const;

string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void saveGraph(string s,
Expand Down
25 changes: 25 additions & 0 deletions python/gtsam/tests/test_DiscreteBayesTree.dot
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
digraph G{
0[label="8,12,14"];
0->1
1[label="0 : 8,12"];
0->2
2[label="1 : 8,12"];
0->3
3[label="9 : 12,14"];
3->4
4[label="2 : 9,12"];
3->5
5[label="3 : 9,12"];
0->6
6[label="10,13 : 14"];
6->7
7[label="4 : 10,13"];
6->8
8[label="5 : 10,13"];
6->9
9[label="11 : 13,14"];
9->10
10[label="6 : 11,13"];
9->11
11[label="7 : 11,13"];
}
89 changes: 89 additions & 0 deletions python/gtsam/tests/test_DiscreteBayesTree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved
See LICENSE for the license information
Unit tests for Discrete Bayes trees.
Author: Frank Dellaert
"""

# pylint: disable=no-name-in-module, invalid-name

import unittest

from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
DiscreteConditional, DiscreteFactorGraph, DiscreteKeys,
Ordering)
from gtsam.utils.test_case import GtsamTestCase


def P(*args):
""" Create a DiscreteKeys instances from a variable number of DiscreteKey pairs."""
# TODO: We can make life easier by providing variable argument functions in C++ itself.
dks = DiscreteKeys()
for key in args:
dks.push_back(key)
return dks


class TestDiscreteBayesNet(GtsamTestCase):
"""Tests for Discrete Bayes Nets."""

def test_elimination(self):
"""Test Multifrontal elimination."""

# Define DiscreteKey pairs.
keys = [(j, 2) for j in range(15)]

# Create thin-tree Bayesnet.
bayesNet = DiscreteBayesNet()

bayesNet.add(keys[0], P(keys[8], keys[12]), "2/3 1/4 3/2 4/1")
bayesNet.add(keys[1], P(keys[8], keys[12]), "4/1 2/3 3/2 1/4")
bayesNet.add(keys[2], P(keys[9], keys[12]), "1/4 8/2 2/3 4/1")
bayesNet.add(keys[3], P(keys[9], keys[12]), "1/4 2/3 3/2 4/1")

bayesNet.add(keys[4], P(keys[10], keys[13]), "2/3 1/4 3/2 4/1")
bayesNet.add(keys[5], P(keys[10], keys[13]), "4/1 2/3 3/2 1/4")
bayesNet.add(keys[6], P(keys[11], keys[13]), "1/4 3/2 2/3 4/1")
bayesNet.add(keys[7], P(keys[11], keys[13]), "1/4 2/3 3/2 4/1")

bayesNet.add(keys[8], P(keys[12], keys[14]), "T 1/4 3/2 4/1")
bayesNet.add(keys[9], P(keys[12], keys[14]), "4/1 2/3 F 1/4")
bayesNet.add(keys[10], P(keys[13], keys[14]), "1/4 3/2 2/3 4/1")
bayesNet.add(keys[11], P(keys[13], keys[14]), "1/4 2/3 3/2 4/1")

bayesNet.add(keys[12], P(keys[14]), "3/1 3/1")
bayesNet.add(keys[13], P(keys[14]), "1/3 3/1")

bayesNet.add(keys[14], P(), "1/3")

# Create a factor graph out of the Bayes net.
factorGraph = DiscreteFactorGraph(bayesNet)

# Create a BayesTree out of the factor graph.
ordering = Ordering()
for j in range(15):
ordering.push_back(j)
bayesTree = factorGraph.eliminateMultifrontal(ordering)

# Uncomment these for visualization:
# print(bayesTree)
# for key in range(15):
# bayesTree[key].printSignature()
# bayesTree.saveGraph("test_DiscreteBayesTree.dot")

self.assertFalse(bayesTree.empty())
self.assertEqual(12, bayesTree.size())

# The root is P( 8 12 14), we can retrieve it by key:
root = bayesTree[8]
self.assertIsInstance(root, DiscreteBayesTreeClique)
self.assertTrue(root.isRoot())
self.assertIsInstance(root.conditional(), DiscreteConditional)


if __name__ == "__main__":
unittest.main()

0 comments on commit fb3f00d

Please sign in to comment.