Merge pull request #37 from varunagrawal/debug-pruning
varunagrawal authored Mar 22, 2022
2 parents b1499a4 + 092536e commit 242e84c
Showing 6 changed files with 442 additions and 332 deletions.
6 changes: 4 additions & 2 deletions gtsam/discrete/DecisionTree-inl.h
@@ -633,10 +633,12 @@ namespace gtsam {
using LY = DecisionTree<L, Y>;

// ugliness below because apparently we can't have templated virtual
// functions If leaf, apply unary conversion "op" and create a unique leaf
// functions
// If leaf, apply unary conversion "op" and create a unique leaf
using MXLeaf = typename DecisionTree<M, X>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f)) {
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
}

// Check if Choice
using MXChoice = typename DecisionTree<M, X>::Choice;
6 changes: 6 additions & 0 deletions gtsam/discrete/DecisionTree.h
@@ -261,6 +261,12 @@ namespace gtsam {
template <typename Func>
void visitWith(Func f) const;

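/// Return the number of leaves in the tree.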
size_t nrLeaves() {
size_t total = 0;
visit([&total](const Y& node) { total += 1; });
return total;
}

/**
* @brief Fold a binary function over the tree, returning accumulator.
*
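As a quick illustration of the new nrLeaves() helper (not part of this commit), the sketch below builds a tree with a single binary choice and counts its leaves; it assumes the DecisionTree<L, Y>(label, y1, y2) constructor and pulls in DecisionTree-inl.h for the template definitions:

#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DecisionTree.h>

#include <iostream>
#include <string>

int main() {
  // One binary choice on label "b" with two leaves, 0.1 and 0.9.
  gtsam::DecisionTree<std::string, double> tree("b", 0.1, 0.9);
  // nrLeaves() visits every leaf and counts it.
  std::cout << "leaves: " << tree.nrLeaves() << std::endl;  // expected: 2
  return 0;
}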
27 changes: 27 additions & 0 deletions gtsam/hybrid/GaussianHybridFactorGraph.cpp
@@ -16,6 +16,8 @@
* @date January 2022
*/

#include <gtsam/base/utilities.h>

#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/hybrid/DCGaussianMixtureFactor.h>
@@ -66,6 +68,7 @@ static Sum& operator+=(Sum& sum, const GaussianFactor::shared_ptr& factor) {
}

Sum GaussianHybridFactorGraph::sum() const {
gttic_(Sum);
// "sum" all factors, gathering into GaussianFactorGraph
DCGaussianMixtureFactor::Sum sum;
for (auto&& dcFactor : dcGraph()) {
@@ -112,6 +115,8 @@ ostream& operator<<(ostream& os,
pair<AbstractConditional::shared_ptr, boost::shared_ptr<Factor>>
EliminateHybrid(const GaussianHybridFactorGraph& factors,
const Ordering& ordering) {

ordering.print("\nEliminating:");
// STEP 1: SUM
// Create a new decision tree with all factors gathered at leaves.
Sum sum = factors.sum();
@@ -128,6 +133,7 @@ EliminateHybrid(const GaussianHybridFactorGraph& factors,
// TODO(fan): Now let's assume that all continuous will be eliminated first!
// Here sum is null if remaining are all discrete
if (sum.empty()) {
gttic_(DFG);
// Not sure if this is the correct thing, but anyway!
DiscreteFactorGraph dfg;
dfg.push_back(factors.discreteGraph());
@@ -152,6 +158,7 @@ EliminateHybrid(const GaussianHybridFactorGraph& factors,
KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)?
auto eliminate = [&](const GaussianFactorGraph& graph)
-> GaussianFactorGraph::EliminationResult {
gttic_(Eliminate);
if (graph.empty()) return {nullptr, nullptr};
auto result = EliminatePreferCholesky(graph, ordering);
if (keysOfEliminated.empty())
@@ -161,10 +168,30 @@ EliminateHybrid(const GaussianHybridFactorGraph& factors,
if (keysOfSeparator.empty()) keysOfSeparator = result.second->keys();
return result;
};

auto valueFormatter = [&](const GaussianFactorGraph &v) {
auto printCapture = [&](const GaussianFactorGraph &p) {
RedirectCout rd;
p.print("", DefaultKeyFormatter);
std::string s = rd.str();
return s;
};

std::string format_template = "Gaussian factor graph with %d factors:\n%s\n";
return (boost::format(format_template) % v.size() % printCapture(v)).str();
};
sum.print(">>>>>>>>>>>>>\n", DefaultKeyFormatter, valueFormatter);

gttic_(EliminationResult);
std::cout << ">>>>>>> nrLeaves in `sum`: " << sum.nrLeaves() << std::endl;
DecisionTree<Key, EliminationPair> eliminationResults(sum, eliminate);
// std::cout << "Elimination done!!!!!!!\n\n" << std::endl;
gttoc_(EliminationResult);

gttic_(Leftover);
// STEP 3: Create result
auto pair = unzip(eliminationResults);

const GaussianMixture::Conditionals& conditionals = pair.first;
const DCGaussianMixtureFactor::Factors& separatorFactors = pair.second;

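The valueFormatter above captures print() output into a string via gtsam::RedirectCout from gtsam/base/utilities.h (the new include at the top of this file). A minimal sketch of that capture idiom, independent of this commit and assuming RedirectCout's usual scope-based buffering and str() accessor:

#include <gtsam/base/utilities.h>
#include <gtsam/linear/GaussianFactorGraph.h>

#include <iostream>
#include <string>

int main() {
  gtsam::GaussianFactorGraph graph;  // an empty graph is enough to show the idiom
  std::string captured;
  {
    gtsam::RedirectCout rd;          // from here on, std::cout goes into rd's buffer
    graph.print("captured graph: ");
    captured = rd.str();             // grab the buffered text
  }                                  // rd's destructor restores std::cout
  std::cout << captured << std::endl;
  return 0;
}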
44 changes: 22 additions & 22 deletions gtsam/hybrid/GaussianMixture.cpp
@@ -74,28 +74,28 @@ void GaussianMixture::print(const std::string &s,
std::cout << "]";
std::cout << "{\n";

auto valueFormatter = [&](const GaussianFactor::shared_ptr &v) {
auto printCapture = [&](const GaussianFactor::shared_ptr &p) {
RedirectCout rd;
p->print("", keyFormatter);
std::string s = rd.str();
return s;
};

std::string format_template = "Gaussian factor on %d keys: \n%s\n";

if (auto hessianFactor = boost::dynamic_pointer_cast<HessianFactor>(v)) {
format_template = "Hessian factor on %d keys: \n%s\n";
}

if (auto jacobianFactor = boost::dynamic_pointer_cast<JacobianFactor>(v)) {
format_template = "Jacobian factor on %d keys: \n%s\n";
}
if (!v) return std::string {"nullptr\n"};
return (boost::format(format_template) % v->size() % printCapture(v)).str();
};

factors_.print("", keyFormatter, valueFormatter);
// auto valueFormatter = [&](const GaussianFactor::shared_ptr &v) {
// auto printCapture = [&](const GaussianFactor::shared_ptr &p) {
// RedirectCout rd;
// p->print("", keyFormatter);
// std::string s = rd.str();
// return s;
// };

// std::string format_template = "Gaussian factor on %d keys: \n%s\n";

// if (auto hessianFactor = boost::dynamic_pointer_cast<HessianFactor>(v)) {
// format_template = "Hessian factor on %d keys: \n%s\n";
// }

// if (auto jacobianFactor = boost::dynamic_pointer_cast<JacobianFactor>(v)) {
// format_template = "Jacobian factor on %d keys: \n%s\n";
// }
// if (!v) return std::string {"nullptr\n"};
// return (boost::format(format_template) % v->size() % printCapture(v)).str();
// };

// factors_.print("", keyFormatter, valueFormatter);
std::cout << "}\n";
}

44 changes: 32 additions & 12 deletions gtsam/hybrid/IncrementalHybrid.cpp
@@ -72,48 +72,67 @@ void IncrementalHybrid::update(GaussianHybridFactorGraph graph,
}
}

gttic_(Elimination);
// Eliminate partially.
HybridBayesNet::shared_ptr bayesNetFragment;
graph.print("\n>>>>");
std::cout << "\n\n" << std::endl;
auto result = graph.eliminatePartialSequential(ordering);
bayesNetFragment = result.first;
remainingFactorGraph_ = *result.second;

gttoc_(Elimination);

// Add the partial bayes net to the posterior bayes net.
hybridBayesNet_.push_back<HybridBayesNet>(*bayesNetFragment);

// Prune
if (maxNrLeaves) {
const auto N = *maxNrLeaves;

// Check if discreteGraph is empty. Possible if no discrete variables.
if (remainingFactorGraph_.discreteGraph().empty()) return;

const auto lastDensity =
boost::dynamic_pointer_cast<GaussianMixture>(hybridBayesNet_.back());

// Check if discreteGraph exists. Possible that `update` had no DCFactors or Discrete Factors.
if (remainingFactorGraph_.discreteGraph().size() == 0) return;

auto discreteFactor = boost::dynamic_pointer_cast<DecisionTreeFactor>(
remainingFactorGraph_.discreteGraph().at(0));

std::cout << "Initial number of leaves: " << discreteFactor->nrLeaves()
<< std::endl;

// Let's assume that the structure of the last discrete density will be the
// same as the last continuous
std::vector<double> probabilities;
// TODO(fan): The number of probabilities can be lower than the actual
// number of choices
discreteFactor->visit(
[&](const double &prob) { probabilities.emplace_back(prob); });

if (probabilities.size() < N) return;
// Nothing to prune if we already have at most N leaves.
if (probabilities.size() <= N) return;

std::nth_element(probabilities.begin(), probabilities.begin() + N,
probabilities.end(), std::greater<double>{});
std::sort(probabilities.begin(), probabilities.end(),
std::greater<double>{});

auto thresholdValue = probabilities[N - 1];
double threshold = probabilities[N - 1];

// Now threshold
auto threshold = [thresholdValue](const double &value) {
return value < thresholdValue ? 0.0 : value;
// Now threshold the decision tree
size_t total = 0;
auto thresholdFunc = [threshold, &total, N](const double &value) {
if (value < threshold || total >= N) {
return 0.0;
} else {
total += 1;
return value;
}
};
DecisionTree<Key, double> thresholded(*discreteFactor, threshold);
DecisionTree<Key, double> thresholded(*discreteFactor, thresholdFunc);
size_t nrRemainingLeaves = 0;
thresholded.visit([&nrRemainingLeaves](const double &d) {
if (d > 0) nrRemainingLeaves += 1;
});
std::cout << "Leaves after pruning: " << nrRemainingLeaves << std::endl;

// Create a new factor with pruned tree
// DecisionTreeFactor newFactor(discreteFactor->discreteKeys(),
@@ -142,6 +161,7 @@ void IncrementalHybrid::update(GaussianHybridFactorGraph graph,
hybridBayesNet_.atGaussian(hybridBayesNet_.size() - 1)->factors_ =
prunedConditionalsTree;
}
tictoc_print_();
}

/* ************************************************************************* */
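For reference, the pruning logic introduced above boils down to: sort the leaf probabilities in descending order, take the N-th largest as the threshold, and zero out every value that falls below the threshold or beyond the first N kept (the counter handles ties at the threshold). A self-contained sketch of that selection rule on a plain vector, using a hypothetical helper name rather than the IncrementalHybrid API:

#include <algorithm>
#include <functional>
#include <iostream>
#include <vector>

// Zero out all but (at most) the N largest probabilities, mirroring the
// threshold-plus-counter logic used above. Hypothetical helper for illustration.
std::vector<double> pruneToMaxLeaves(std::vector<double> probs, size_t N) {
  if (probs.size() <= N) return probs;  // nothing to prune

  std::vector<double> sorted = probs;
  std::sort(sorted.begin(), sorted.end(), std::greater<double>{});
  const double threshold = sorted[N - 1];  // N-th largest value

  size_t kept = 0;
  for (double& p : probs) {
    if (p < threshold || kept >= N) {
      p = 0.0;    // pruned leaf
    } else {
      kept += 1;  // kept leaf; the counter caps ties at the threshold
    }
  }
  return probs;
}

int main() {
  const auto pruned = pruneToMaxLeaves({0.4, 0.1, 0.3, 0.1, 0.1}, 2);
  for (double p : pruned) std::cout << p << " ";  // prints: 0.4 0 0.3 0 0
  std::cout << std::endl;
  return 0;
}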