diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 60f3017f4f..d2e05927a3 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -179,5 +179,6 @@ namespace gtsam { }; // AlgebraicDecisionTree +template struct traits> : public Testable> {}; } // namespace gtsam diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 4f3e3f7f14..2607a80ef5 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -34,12 +34,13 @@ namespace gtsam { /* ******************************************************************************** */ DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials) : - DiscreteFactor(keys.indices()), Potentials(keys, potentials) { + DiscreteFactor(keys.indices()), ADT(potentials), + cardinalities_(keys.cardinalities()) { } /* *************************************************************************/ DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) : - DiscreteFactor(c.keys()), Potentials(c) { + DiscreteFactor(c.keys()), AlgebraicDecisionTree(c), cardinalities_(c.cardinalities_) { } /* ************************************************************************* */ @@ -48,16 +49,24 @@ namespace gtsam { return false; } else { - const DecisionTreeFactor& f(static_cast(other)); - return Potentials::equals(f, tol); + const auto& f(static_cast(other)); + return ADT::equals(f, tol); } } + /* ************************************************************************* */ + double DecisionTreeFactor::safe_div(const double &a, const double &b) { + // The use for safe_div is when we divide the product factor by the sum + // factor. If the product or sum is zero, we accord zero probability to the + // event. + return (a == 0 || b == 0) ? 0 : (a / b); + } + /* ************************************************************************* */ void DecisionTreeFactor::print(const string& s, const KeyFormatter& formatter) const { cout << s; - Potentials::print("Potentials:",formatter); + ADT::print("Potentials:",formatter); } /* ************************************************************************* */ @@ -162,20 +171,20 @@ namespace gtsam { void DecisionTreeFactor::dot(std::ostream& os, const KeyFormatter& keyFormatter, bool showZero) const { - Potentials::dot(os, keyFormatter, valueFormatter, showZero); + ADT::dot(os, keyFormatter, valueFormatter, showZero); } /** output to graphviz format, open a file */ void DecisionTreeFactor::dot(const std::string& name, const KeyFormatter& keyFormatter, bool showZero) const { - Potentials::dot(name, keyFormatter, valueFormatter, showZero); + ADT::dot(name, keyFormatter, valueFormatter, showZero); } /** output to graphviz format string */ std::string DecisionTreeFactor::dot(const KeyFormatter& keyFormatter, bool showZero) const { - return Potentials::dot(keyFormatter, valueFormatter, showZero); + return ADT::dot(keyFormatter, valueFormatter, showZero); } /* ************************************************************************* */ @@ -209,5 +218,15 @@ namespace gtsam { return ss.str(); } + DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const vector &table) : + DiscreteFactor(keys.indices()), AlgebraicDecisionTree(keys, table), + cardinalities_(keys.cardinalities()) { + } + + DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const string &table) : + DiscreteFactor(keys.indices()), AlgebraicDecisionTree(keys, table), + cardinalities_(keys.cardinalities()) { + } + /* ************************************************************************* */ } // namespace gtsam diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index f8832c2237..f7c50d5b5f 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -19,7 +19,8 @@ #pragma once #include -#include +#include +#include #include #include @@ -35,7 +36,7 @@ namespace gtsam { /** * A discrete probabilistic factor */ - class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public Potentials { + class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public AlgebraicDecisionTree { public: @@ -43,6 +44,10 @@ namespace gtsam { typedef DecisionTreeFactor This; typedef DiscreteFactor Base; ///< Typedef to base class typedef boost::shared_ptr shared_ptr; + typedef AlgebraicDecisionTree ADT; + + protected: + std::map cardinalities_; public: @@ -55,11 +60,11 @@ namespace gtsam { /** Constructor from Indices, Ordering, and AlgebraicDecisionDiagram */ DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials); - /** Constructor from Indices and (string or doubles) */ - template - DecisionTreeFactor(const DiscreteKeys& keys, SOURCE table) : - DiscreteFactor(keys.indices()), Potentials(keys, table) { - } + /** Constructor from doubles */ + DecisionTreeFactor(const DiscreteKeys& keys, const std::vector& table); + + /** Constructor from string */ + DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table); /// Single-key specialization template @@ -71,7 +76,7 @@ namespace gtsam { : DecisionTreeFactor(DiscreteKeys{key}, row) {} /** Construct from a DiscreteConditional type */ - DecisionTreeFactor(const DiscreteConditional& c); + explicit DecisionTreeFactor(const DiscreteConditional& c); /// @} /// @name Testable @@ -90,7 +95,7 @@ namespace gtsam { /// Value is just look up in AlgebraicDecisonTree double operator()(const DiscreteValues& values) const override { - return Potentials::operator()(values); + return ADT::operator()(values); } /// multiply two factors @@ -98,6 +103,10 @@ namespace gtsam { return apply(f, ADT::Ring::mul); } + static double safe_div(const double& a, const double& b); + + size_t cardinality(Key j) const { return cardinalities_.at(j);} + /// divide by factor f (safely) DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { return apply(f, safe_div); diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 8ee93eb774..951c0b6cab 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -80,7 +80,7 @@ void DiscreteConditional::print(const string& s, } } cout << ")"; - Potentials::print(""); + ADT::print(""); cout << endl; } diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 7ce3dc9308..4c2e964fda 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -128,7 +128,7 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor, /// Evaluate, just look up in AlgebraicDecisonTree double operator()(const DiscreteValues& values) const override { - return Potentials::operator()(values); + return ADT::operator()(values); } /** Convert to a factor */ diff --git a/gtsam/discrete/Potentials.cpp b/gtsam/discrete/Potentials.cpp deleted file mode 100644 index 057b6a2655..0000000000 --- a/gtsam/discrete/Potentials.cpp +++ /dev/null @@ -1,96 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file Potentials.cpp - * @date March 24, 2011 - * @author Frank Dellaert - */ - -#include -#include - -#include - -#include - -using namespace std; - -namespace gtsam { - -/* ************************************************************************* */ -double Potentials::safe_div(const double& a, const double& b) { - // cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b)); - // The use for safe_div is when we divide the product factor by the sum - // factor. If the product or sum is zero, we accord zero probability to the - // event. - return (a == 0 || b == 0) ? 0 : (a / b); -} - -/* ******************************************************************************** - */ -Potentials::Potentials() : ADT(1.0) {} - -/* ******************************************************************************** - */ -Potentials::Potentials(const DiscreteKeys& keys, const ADT& decisionTree) - : ADT(decisionTree), cardinalities_(keys.cardinalities()) {} - -/* ************************************************************************* */ -bool Potentials::equals(const Potentials& other, double tol) const { - return ADT::equals(other, tol); -} - -/* ************************************************************************* */ -void Potentials::print(const string& s, const KeyFormatter& formatter) const { - cout << s << "\n Cardinalities: { "; - for (const std::pair& key : cardinalities_) - cout << formatter(key.first) << ":" << key.second << ", "; - cout << "}" << endl; - ADT::print(" ", formatter); -} -// -// /* ************************************************************************* */ -// template -// void Potentials::remapIndices(const P& remapping) { -// // Permute the _cardinalities (TODO: Inefficient Consider Improving) -// DiscreteKeys keys; -// map ordering; -// -// // Get the original keys from cardinalities_ -// for(const DiscreteKey& key: cardinalities_) -// keys & key; -// -// // Perform Permutation -// for(DiscreteKey& key: keys) { -// ordering[key.first] = remapping[key.first]; -// key.first = ordering[key.first]; -// } -// -// // Change *this -// AlgebraicDecisionTree permuted((*this), ordering); -// *this = permuted; -// cardinalities_ = keys.cardinalities(); -// } -// -// /* ************************************************************************* */ -// void Potentials::permuteWithInverse(const Permutation& inversePermutation) { -// remapIndices(inversePermutation); -// } -// -// /* ************************************************************************* */ -// void Potentials::reduceWithInverse(const internal::Reduction& inverseReduction) { -// remapIndices(inverseReduction); -// } - - /* ************************************************************************* */ - -} // namespace gtsam diff --git a/gtsam/discrete/Potentials.h b/gtsam/discrete/Potentials.h deleted file mode 100644 index 856b928168..0000000000 --- a/gtsam/discrete/Potentials.h +++ /dev/null @@ -1,97 +0,0 @@ -/* ---------------------------------------------------------------------------- - - * GTSAM Copyright 2010, Georgia Tech Research Corporation, - * Atlanta, Georgia 30332-0415 - * All Rights Reserved - * Authors: Frank Dellaert, et al. (see THANKS for the full author list) - - * See LICENSE for the license information - - * -------------------------------------------------------------------------- */ - -/** - * @file Potentials.h - * @date March 24, 2011 - * @author Frank Dellaert - */ - -#pragma once - -#include -#include -#include - -#include -#include - -namespace gtsam { - - /** - * A base class for both DiscreteFactor and DiscreteConditional - */ - class GTSAM_EXPORT Potentials: public AlgebraicDecisionTree { - - public: - - typedef AlgebraicDecisionTree ADT; - - protected: - - /// Cardinality for each key, used in combine - std::map cardinalities_; - - /** Constructor from ColumnIndex, and ADT */ - Potentials(const ADT& potentials) : - ADT(potentials) { - } - - // Safe division for probabilities - static double safe_div(const double& a, const double& b); - -// // Apply either a permutation or a reduction -// template -// void remapIndices(const P& remapping); - - public: - - /** Default constructor for I/O */ - Potentials(); - - /** Constructor from Indices and ADT */ - Potentials(const DiscreteKeys& keys, const ADT& decisionTree); - - /** Constructor from Indices and (string or doubles) */ - template - Potentials(const DiscreteKeys& keys, SOURCE table) : - ADT(keys, table), cardinalities_(keys.cardinalities()) { - } - - // Testable - bool equals(const Potentials& other, double tol = 1e-9) const; - void print(const std::string& s = "Potentials: ", - const KeyFormatter& formatter = DefaultKeyFormatter) const; - - size_t cardinality(Key j) const { return cardinalities_.at(j);} - -// /** -// * @brief Permutes the keys in Potentials -// * -// * This permutes the Indices and performs necessary re-ordering of ADD. -// * This is virtual so that derived types e.g. DecisionTreeFactor can -// * re-implement it. -// */ -// GTSAM_EXPORT virtual void permuteWithInverse(const Permutation& inversePermutation); -// -// /** -// * Apply a reduction, which is a remapping of variable indices. -// */ -// GTSAM_EXPORT virtual void reduceWithInverse(const internal::Reduction& inverseReduction); - - }; // Potentials - -// traits -template<> struct traits : public Testable {}; -template<> struct traits : public Testable {}; - - -} // namespace gtsam diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 251978c99b..0686b3920c 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -41,21 +41,23 @@ using namespace gtsam; static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2); +using ADT = AlgebraicDecisionTree; + /* ************************************************************************* */ TEST(DiscreteBayesNet, bayesNet) { DiscreteBayesNet bayesNet; DiscreteKey Parent(0, 2), Child(1, 2); auto prior = boost::make_shared(Parent % "6/4"); - CHECK(assert_equal(Potentials::ADT({Parent}, "0.6 0.4"), - (Potentials::ADT)*prior)); + CHECK(assert_equal(ADT({Parent}, "0.6 0.4"), + (ADT)*prior)); bayesNet.push_back(prior); auto conditional = boost::make_shared(Child | Parent = "7/3 8/2"); EXPECT_LONGS_EQUAL(1, *(conditional->beginFrontals())); - Potentials::ADT expected(Child & Parent, "0.7 0.8 0.3 0.2"); - CHECK(assert_equal(expected, (Potentials::ADT)*conditional)); + ADT expected(Child & Parent, "0.7 0.8 0.3 0.2"); + CHECK(assert_equal(expected, (ADT)*conditional)); bayesNet.push_back(conditional); DiscreteFactorGraph fg(bayesNet); diff --git a/gtsam/inference/BayesTree.h b/gtsam/inference/BayesTree.h index 68a45a014a..a32b3ce22b 100644 --- a/gtsam/inference/BayesTree.h +++ b/gtsam/inference/BayesTree.h @@ -143,7 +143,7 @@ namespace gtsam { const Nodes& nodes() const { return nodes_; } /** Access node by variable */ - const sharedNode operator[](Key j) const { return nodes_.at(j); } + sharedClique operator[](Key j) const { return nodes_.at(j); } /** return root cliques */ const Roots& roots() const { return roots_; } diff --git a/gtsam_unstable/discrete/Scheduler.cpp b/gtsam_unstable/discrete/Scheduler.cpp index e34613c3b3..f166405932 100644 --- a/gtsam_unstable/discrete/Scheduler.cpp +++ b/gtsam_unstable/discrete/Scheduler.cpp @@ -130,9 +130,9 @@ void Scheduler::addStudentSpecificConstraints(size_t i, // get all constraints then specialize to slot size_t dummyIndex = maxNrStudents_ * 3 + maxNrStudents_; DiscreteKey dummy(dummyIndex, nrTimeSlots()); - Potentials::ADT p(dummy & areaKey, + AlgebraicDecisionTree p(dummy & areaKey, available_); // available_ is Doodle string - Potentials::ADT q = p.choose(dummyIndex, *slot); + auto q = p.choose(dummyIndex, *slot); CSP::add(areaKey, q); } else { DiscreteKeys keys {s.key_, areaKey};