Skip to content

Commit

Permalink
Merge pull request #930 from borglab/fix/formatting/discrete
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert authored Nov 18, 2021
2 parents 13b0136 + aebcf07 commit 770fda9
Show file tree
Hide file tree
Showing 16 changed files with 1,274 additions and 1,348 deletions.
170 changes: 84 additions & 86 deletions gtsam_unstable/discrete/AllDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,107 +5,105 @@
* @author Frank Dellaert
*/

#include <gtsam_unstable/discrete/Domain.h>
#include <gtsam_unstable/discrete/AllDiff.h>
#include <gtsam/base/Testable.h>
#include <gtsam_unstable/discrete/AllDiff.h>
#include <gtsam_unstable/discrete/Domain.h>

#include <boost/make_shared.hpp>

namespace gtsam {

/* ************************************************************************* */
AllDiff::AllDiff(const DiscreteKeys& dkeys) :
Constraint(dkeys.indices()) {
for(const DiscreteKey& dkey: dkeys)
cardinalities_.insert(dkey);
}
/* ************************************************************************* */
AllDiff::AllDiff(const DiscreteKeys& dkeys) : Constraint(dkeys.indices()) {
for (const DiscreteKey& dkey : dkeys) cardinalities_.insert(dkey);
}

/* ************************************************************************* */
void AllDiff::print(const std::string& s,
const KeyFormatter& formatter) const {
std::cout << s << "AllDiff on ";
for (Key dkey: keys_)
std::cout << formatter(dkey) << " ";
std::cout << std::endl;
}
/* ************************************************************************* */
void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const {
std::cout << s << "AllDiff on ";
for (Key dkey : keys_) std::cout << formatter(dkey) << " ";
std::cout << std::endl;
}

/* ************************************************************************* */
double AllDiff::operator()(const Values& values) const {
std::set < size_t > taken; // record values taken by keys
for(Key dkey: keys_) {
size_t value = values.at(dkey); // get the value for that key
if (taken.count(value)) return 0.0;// check if value alreday taken
taken.insert(value);// if not, record it as taken and keep checking
}
return 1.0;
/* ************************************************************************* */
double AllDiff::operator()(const Values& values) const {
std::set<size_t> taken; // record values taken by keys
for (Key dkey : keys_) {
size_t value = values.at(dkey); // get the value for that key
if (taken.count(value)) return 0.0; // check if value alreday taken
taken.insert(value); // if not, record it as taken and keep checking
}
return 1.0;
}

/* ************************************************************************* */
DecisionTreeFactor AllDiff::toDecisionTreeFactor() const {
// We will do this by converting the allDif into many BinaryAllDiff constraints
DecisionTreeFactor converted;
size_t nrKeys = keys_.size();
for (size_t i1 = 0; i1 < nrKeys; i1++)
for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) {
BinaryAllDiff binary12(discreteKey(i1),discreteKey(i2));
converted = converted * binary12.toDecisionTreeFactor();
}
return converted;
}
/* ************************************************************************* */
DecisionTreeFactor AllDiff::toDecisionTreeFactor() const {
// We will do this by converting the allDif into many BinaryAllDiff
// constraints
DecisionTreeFactor converted;
size_t nrKeys = keys_.size();
for (size_t i1 = 0; i1 < nrKeys; i1++)
for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) {
BinaryAllDiff binary12(discreteKey(i1), discreteKey(i2));
converted = converted * binary12.toDecisionTreeFactor();
}
return converted;
}

/* ************************************************************************* */
DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const {
// TODO: can we do this more efficiently?
return toDecisionTreeFactor() * f;
}
/* ************************************************************************* */
DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const {
// TODO: can we do this more efficiently?
return toDecisionTreeFactor() * f;
}

/* ************************************************************************* */
bool AllDiff::ensureArcConsistency(size_t j, std::vector<Domain>& domains) const {
// Though strictly not part of allDiff, we check for
// a value in domains[j] that does not occur in any other connected domain.
// If found, we make this a singleton...
// TODO: make a new constraint where this really is true
Domain& Dj = domains[j];
if (Dj.checkAllDiff(keys_, domains)) return true;
/* ************************************************************************* */
bool AllDiff::ensureArcConsistency(size_t j,
std::vector<Domain>& domains) const {
// Though strictly not part of allDiff, we check for
// a value in domains[j] that does not occur in any other connected domain.
// If found, we make this a singleton...
// TODO: make a new constraint where this really is true
Domain& Dj = domains[j];
if (Dj.checkAllDiff(keys_, domains)) return true;

// Check all other domains for singletons and erase corresponding values
// This is the same as arc-consistency on the equivalent binary constraints
bool changed = false;
for(Key k: keys_)
if (k != j) {
const Domain& Dk = domains[k];
if (Dk.isSingleton()) { // check if singleton
size_t value = Dk.firstValue();
if (Dj.contains(value)) {
Dj.erase(value); // erase value if true
changed = true;
}
// Check all other domains for singletons and erase corresponding values
// This is the same as arc-consistency on the equivalent binary constraints
bool changed = false;
for (Key k : keys_)
if (k != j) {
const Domain& Dk = domains[k];
if (Dk.isSingleton()) { // check if singleton
size_t value = Dk.firstValue();
if (Dj.contains(value)) {
Dj.erase(value); // erase value if true
changed = true;
}
}
return changed;
}
}
return changed;
}

/* ************************************************************************* */
Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const {
DiscreteKeys newKeys;
// loop over keys and add them only if they do not appear in values
for(Key k: keys_)
if (values.find(k) == values.end()) {
newKeys.push_back(DiscreteKey(k,cardinalities_.at(k)));
}
return boost::make_shared<AllDiff>(newKeys);
}
/* ************************************************************************* */
Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const {
DiscreteKeys newKeys;
// loop over keys and add them only if they do not appear in values
for (Key k : keys_)
if (values.find(k) == values.end()) {
newKeys.push_back(DiscreteKey(k, cardinalities_.at(k)));
}
return boost::make_shared<AllDiff>(newKeys);
}

/* ************************************************************************* */
Constraint::shared_ptr AllDiff::partiallyApply(
const std::vector<Domain>& domains) const {
DiscreteFactor::Values known;
for(Key k: keys_) {
const Domain& Dk = domains[k];
if (Dk.isSingleton())
known[k] = Dk.firstValue();
}
return partiallyApply(known);
/* ************************************************************************* */
Constraint::shared_ptr AllDiff::partiallyApply(
const std::vector<Domain>& domains) const {
DiscreteFactor::Values known;
for (Key k : keys_) {
const Domain& Dk = domains[k];
if (Dk.isSingleton()) known[k] = Dk.firstValue();
}
return partiallyApply(known);
}

/* ************************************************************************* */
} // namespace gtsam
/* ************************************************************************* */
} // namespace gtsam
112 changes: 56 additions & 56 deletions gtsam_unstable/discrete/AllDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,71 +7,71 @@

#pragma once

#include <gtsam_unstable/discrete/BinaryAllDiff.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam_unstable/discrete/BinaryAllDiff.h>

namespace gtsam {

/**
* General AllDiff constraint
* Returns 1 if values for all keys are different, 0 otherwise
* DiscreteFactors are all awkward in that they have to store two types of keys:
* for each variable we have a Key and an Key. In this factor, we
* keep the Indices locally, and the Indices are stored in IndexFactor.
*/
class GTSAM_UNSTABLE_EXPORT AllDiff: public Constraint {

std::map<Key,size_t> cardinalities_;

DiscreteKey discreteKey(size_t i) const {
Key j = keys_[i];
return DiscreteKey(j,cardinalities_.at(j));
}

public:

/// Constructor
AllDiff(const DiscreteKeys& dkeys);

// print
void print(const std::string& s = "",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;

/// equals
bool equals(const DiscreteFactor& other, double tol) const override {
if(!dynamic_cast<const AllDiff*>(&other))
return false;
else {
const AllDiff& f(static_cast<const AllDiff&>(other));
return cardinalities_.size() == f.cardinalities_.size()
&& std::equal(cardinalities_.begin(), cardinalities_.end(),
f.cardinalities_.begin());
}
/**
* General AllDiff constraint
* Returns 1 if values for all keys are different, 0 otherwise
* DiscreteFactors are all awkward in that they have to store two types of keys:
* for each variable we have a Key and an Key. In this factor, we
* keep the Indices locally, and the Indices are stored in IndexFactor.
*/
class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
std::map<Key, size_t> cardinalities_;

DiscreteKey discreteKey(size_t i) const {
Key j = keys_[i];
return DiscreteKey(j, cardinalities_.at(j));
}

public:
/// Constructor
AllDiff(const DiscreteKeys& dkeys);

// print
void print(const std::string& s = "", const KeyFormatter& formatter =
DefaultKeyFormatter) const override;

/// equals
bool equals(const DiscreteFactor& other, double tol) const override {
if (!dynamic_cast<const AllDiff*>(&other))
return false;
else {
const AllDiff& f(static_cast<const AllDiff&>(other));
return cardinalities_.size() == f.cardinalities_.size() &&
std::equal(cardinalities_.begin(), cardinalities_.end(),
f.cardinalities_.begin());
}
}

/// Calculate value = expensive !
double operator()(const Values& values) const override;
/// Calculate value = expensive !
double operator()(const Values& values) const override;

/// Convert into a decisiontree, can be *very* expensive !
DecisionTreeFactor toDecisionTreeFactor() const override;
/// Convert into a decisiontree, can be *very* expensive !
DecisionTreeFactor toDecisionTreeFactor() const override;

/// Multiply into a decisiontree
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
/// Multiply into a decisiontree
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;

/*
* Ensure Arc-consistency
* Arc-consistency involves creating binaryAllDiff constraints
* In which case the combinatorial hyper-arc explosion disappears.
* @param j domain to be checked
* @param domains all other domains
*/
bool ensureArcConsistency(size_t j, std::vector<Domain>& domains) const override;
/*
* Ensure Arc-consistency
* Arc-consistency involves creating binaryAllDiff constraints
* In which case the combinatorial hyper-arc explosion disappears.
* @param j domain to be checked
* @param domains all other domains
*/
bool ensureArcConsistency(size_t j,
std::vector<Domain>& domains) const override;

/// Partially apply known values
Constraint::shared_ptr partiallyApply(const Values&) const override;
/// Partially apply known values
Constraint::shared_ptr partiallyApply(const Values&) const override;

/// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply(const std::vector<Domain>&) const override;
};
/// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply(
const std::vector<Domain>&) const override;
};

} // namespace gtsam
} // namespace gtsam
Loading

0 comments on commit 770fda9

Please sign in to comment.