Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes in gtsam to support hybrid branches #1055

Merged
merged 16 commits into from
Jan 23, 2022
13 changes: 13 additions & 0 deletions gtsam/base/utilities.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#include <gtsam/base/utilities.h>

namespace gtsam {

std::string RedirectCout::str() const {
return ssBuffer_.str();
}

RedirectCout::~RedirectCout() {
std::cout.rdbuf(coutBuffer_);
}

}
12 changes: 6 additions & 6 deletions gtsam/base/utilities.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#pragma once

#include <string>
#include <iostream>
#include <sstream>

namespace gtsam {
/**
* For Python __str__().
Expand All @@ -12,14 +16,10 @@ struct RedirectCout {
RedirectCout() : ssBuffer_(), coutBuffer_(std::cout.rdbuf(ssBuffer_.rdbuf())) {}

/// return the string
std::string str() const {
return ssBuffer_.str();
}
std::string str() const;

/// destructor -- redirect stdout buffer to its original buffer
~RedirectCout() {
std::cout.rdbuf(coutBuffer_);
}
~RedirectCout();

private:
std::stringstream ssBuffer_;
Expand Down
2 changes: 1 addition & 1 deletion gtsam/discrete/AlgebraicDecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ namespace gtsam {
const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const {
auto valueFormatter = [](const double& v) {
return (boost::format("%4.2g") % v).str();
return (boost::format("%4.4g") % v).str();
};
Base::print(s, labelFormatter, valueFormatter);
}
Expand Down
6 changes: 5 additions & 1 deletion gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ namespace gtsam {
using MXChoice = typename DecisionTree<M, X>::Choice;
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
if (!choice) throw std::invalid_argument(
"DecisionTree::Convert: Invalid NodePtr");
"DecisionTree::convertFrom: Invalid NodePtr");

// get new label
const M oldLabel = choice->label();
Expand Down Expand Up @@ -634,6 +634,8 @@ namespace gtsam {

using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
if (!choice)
throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr");
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
}
};
Expand Down Expand Up @@ -663,6 +665,8 @@ namespace gtsam {

using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
if (!choice)
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
for (size_t i = 0; i < choice->nrChoices(); i++) {
choices[choice->label()] = i; // Set assignment for label to i
(*this)(choice->branches()[i]); // recurse!
Expand Down
9 changes: 8 additions & 1 deletion gtsam/discrete/DecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ namespace gtsam {
* Y = function range (any algebra), e.g., bool, int, double
*/
template<typename L, typename Y>
class GTSAM_EXPORT DecisionTree {
class DecisionTree {

protected:
/// Default method for comparison of two objects of type Y.
Expand Down Expand Up @@ -340,4 +340,11 @@ namespace gtsam {
return f.apply(g, op);
}

/// unzip a DecisionTree if its leaves are `std::pair`
template<typename L, typename T1, typename T2>
std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(const DecisionTree<L, std::pair<T1, T2> > &input) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, @ProfFan , this should probably have been a static method and be called Unzip. Or a method, and then it would not take an argument, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, just looking for consensus, I think I would change it to static.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I see now it is neither: it is a free function :-/ Which is correctly named, so never mind! Will re-format is all.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want to reopen CI?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'll just do the required one and merge

return std::make_pair(DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
DecisionTree<L, T2>(input, [](std::pair<T1, T2> i) { return i.second; }));
}

} // namespace gtsam
14 changes: 13 additions & 1 deletion gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ namespace gtsam {
for (auto&& key : keys())
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
cout << " ]" << endl;
ADT::print("Potentials:", formatter);
ADT::print("", formatter);
}

/* ************************************************************************* */
Expand Down Expand Up @@ -168,6 +168,18 @@ namespace gtsam {
return result;
}

/* ************************************************************************* */
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}

/* ************************************************************************* */
static std::string valueFormatter(const double& v) {
return (boost::format("%4.2g") % v).str();
Expand Down
3 changes: 3 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ namespace gtsam {
/// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const;

/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;

/// @}
/// @name Wrapper support
/// @{
Expand Down
47 changes: 47 additions & 0 deletions gtsam/discrete/DiscreteFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,59 @@
* @author Frank Dellaert
*/

#include <gtsam/base/Vector.h>
#include <gtsam/discrete/DiscreteFactor.h>

#include <cmath>
#include <sstream>

using namespace std;

namespace gtsam {

/* ************************************************************************* */
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
double maxLogProb = -std::numeric_limits<double>::infinity();
for (size_t i = 0; i < logProbs.size(); i++) {
double logProb = logProbs[i];
if ((logProb != std::numeric_limits<double>::infinity()) &&
logProb > maxLogProb) {
maxLogProb = logProb;
}
}

// After computing the max = "Z" of the log probabilities L_i, we compute
// the log of the normalizing constant, log S, where S = sum_j exp(L_j - Z).
double total = 0.0;
for (size_t i = 0; i < logProbs.size(); i++) {
double probPrime = exp(logProbs[i] - maxLogProb);
total += probPrime;
}
double logTotal = log(total);

// Now we compute the (normalized) probability (for each i):
// p_i = exp(L_i - Z - log S)
double checkNormalization = 0.0;
std::vector<double> probs;
for (size_t i = 0; i < logProbs.size(); i++) {
double prob = exp(logProbs[i] - maxLogProb - logTotal);
probs.push_back(prob);
checkNormalization += prob;
}

// Numerical tolerance for floating point comparisons
double tol = 1e-9;

if (!gtsam::fpEqual(checkNormalization, 1.0, tol)) {
std::string errMsg =
std::string("expNormalize failed to normalize probabilities. ") +
std::string("Expected normalization constant = 1.0. Got value: ") +
std::to_string(checkNormalization) +
std::string(
"\n This could have resulted from numerical overflow/underflow.");
throw std::logic_error(errMsg);
}
return probs;
}

} // namespace gtsam
20 changes: 20 additions & 0 deletions gtsam/discrete/DiscreteFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,24 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
// traits
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};


/**
* @brief Normalize a set of log probabilities.
*
* Normalizing a set of log probabilities in a numerically stable way is
* tricky. To avoid overflow/underflow issues, we compute the largest
* (finite) log probability and subtract it from each log probability before
* normalizing. This comes from the observation that if:
* p_i = exp(L_i) / ( sum_j exp(L_j) ),
* Then,
* p_i = exp(Z) exp(L_i - Z) / (exp(Z) sum_j exp(L_j - Z)),
* = exp(L_i - Z) / ( sum_j exp(L_j - Z) )
*
* Setting Z = max_j L_j, we can avoid numerical issues that arise when all
* of the (unnormalized) log probabilities are either very large or very
* small.
*/
std::vector<double> expNormalize(const std::vector<double> &logProbs);


}// namespace gtsam
18 changes: 16 additions & 2 deletions gtsam/discrete/DiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,25 @@ namespace gtsam {
/* ************************************************************************* */
KeySet DiscreteFactorGraph::keys() const {
KeySet keys;
for(const sharedFactor& factor: *this)
if (factor) keys.insert(factor->begin(), factor->end());
for (const sharedFactor& factor : *this) {
if (factor) keys.insert(factor->begin(), factor->end());
}
return keys;
}

/* ************************************************************************* */
DiscreteKeys DiscreteFactorGraph::discreteKeys() const {
DiscreteKeys result;
for (auto&& factor : *this) {
if (auto p = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
DiscreteKeys factor_keys = p->discreteKeys();
result.insert(result.end(), factor_keys.begin(), factor_keys.end());
}
}

return result;
}

/* ************************************************************************* */
DecisionTreeFactor DiscreteFactorGraph::product() const {
DecisionTreeFactor result;
Expand Down
3 changes: 3 additions & 0 deletions gtsam/discrete/DiscreteFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ class GTSAM_EXPORT DiscreteFactorGraph
/** Return the set of variables involved in the factors (set union) */
KeySet keys() const;

/// Return the DiscreteKeys in this factor graph.
DiscreteKeys discreteKeys() const;

/** return product of all factors as a single factor */
DecisionTreeFactor product() const;

Expand Down
9 changes: 7 additions & 2 deletions gtsam/discrete/DiscreteKey.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
namespace gtsam {

/**
* Key type for discrete conditionals
* Includes name and cardinality
* Key type for discrete variables.
* Includes Key and cardinality.
*/
using DiscreteKey = std::pair<Key,size_t>;

Expand All @@ -45,6 +45,11 @@ namespace gtsam {
/// Construct from a key
explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); }

/// Construct from cardinalities.
explicit DiscreteKeys(std::map<Key, size_t> cardinalities) {
for (auto&& kv : cardinalities) emplace_back(kv);
}

/// Construct from a vector of keys
DiscreteKeys(const std::vector<DiscreteKey>& keys) :
std::vector<DiscreteKey>(keys) {
Expand Down
2 changes: 2 additions & 0 deletions gtsam/discrete/DiscreteMarginals.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class GTSAM_EXPORT DiscreteMarginals {

public:

DiscreteMarginals() {}

/** Construct a marginals class.
* @param graph The factor graph defining the full joint density on all variables.
*/
Expand Down
25 changes: 25 additions & 0 deletions gtsam/discrete/tests/testDecisionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,31 @@ TEST(DecisionTree, labels) {
EXPECT_LONGS_EQUAL(2, labels.size());
}

/* ******************************************************************************** */
// Test retrieving all labels.
TEST(DecisionTree, unzip) {
using DTP = DecisionTree<string, std::pair<int, string>>;
using DT1 = DecisionTree<string, int>;
using DT2 = DecisionTree<string, string>;

// Create small two-level tree
string A("A"), B("B"), C("C");
DTP tree(B,
DTP(A, {0, "zero"}, {1, "one"}),
DTP(A, {2, "two"}, {1337, "l33t"})
);

DT1 dt1;
DT2 dt2;
std::tie(dt1, dt2) = unzip(tree);

DT1 tree1(B, DT1(A, 0, 1), DT1(A, 2, 1337));
DT2 tree2(B, DT2(A, "zero", "one"), DT2(A, "two", "l33t"));

EXPECT(tree1.equals(dt1));
EXPECT(tree2.equals(dt2));
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down
1 change: 0 additions & 1 deletion gtsam/inference/Factor.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ typedef FastSet<FactorIndex> FactorIndexSet;

/// @}

public:
/// @name Advanced Interface
/// @{

Expand Down
9 changes: 7 additions & 2 deletions gtsam/inference/FactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ class FactorGraph {
/** Collection of factors */
FastVector<sharedFactor> factors_;

/// Check exact equality of the factor pointers. Useful for derived ==.
bool isEqual(const FactorGraph& other) const {
return factors_ == other.factors_;
}

/// @name Standard Constructors
/// @{

Expand Down Expand Up @@ -290,11 +295,11 @@ class FactorGraph {
/// @name Testable
/// @{

/// print out graph
/// Print out graph to std::cout, with optional key formatter.
virtual void print(const std::string& s = "FactorGraph",
const KeyFormatter& formatter = DefaultKeyFormatter) const;

/** Check equality */
/// Check equality up to tolerance.
bool equals(const This& fg, double tol = 1e-9) const;
/// @}

Expand Down
4 changes: 2 additions & 2 deletions gtsam/inference/MetisIndex-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
namespace gtsam {

/* ************************************************************************* */
template<class FACTOR>
void MetisIndex::augment(const FactorGraph<FACTOR>& factors) {
template<class FACTORGRAPH>
void MetisIndex::augment(const FACTORGRAPH& factors) {
std::map<int32_t, std::set<int32_t> > iAdjMap; // Stores a set of keys that are adjacent to key x, with adjMap.first
std::map<int32_t, std::set<int32_t> >::iterator iAdjMapIt;
std::set<Key> keySet;
Expand Down
8 changes: 4 additions & 4 deletions gtsam/inference/MetisIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ class GTSAM_EXPORT MetisIndex {
nKeys_(0) {
}

template<class FG>
MetisIndex(const FG& factorGraph) :
template<class FACTORGRAPH>
MetisIndex(const FACTORGRAPH& factorGraph) :
nKeys_(0) {
augment(factorGraph);
}
Expand All @@ -78,8 +78,8 @@ class GTSAM_EXPORT MetisIndex {
* Augment the variable index with new factors. This can be used when
* solving problems incrementally.
*/
template<class FACTOR>
void augment(const FactorGraph<FACTOR>& factors);
template<class FACTORGRAPH>
void augment(const FACTORGRAPH& factors);

const std::vector<int32_t>& xadj() const {
return xadj_;
Expand Down
Loading