Skip to content

Commit

Permalink
Merge pull request #117 from borglab/feature/wrap-mestimator-weight
Browse files Browse the repository at this point in the history
Wrap mEstimators
  • Loading branch information
varunagrawal authored Oct 10, 2019
2 parents 38cf6bd + 4c9f9ec commit a4ac57c
Show file tree
Hide file tree
Showing 6 changed files with 383 additions and 132 deletions.
58 changes: 58 additions & 0 deletions gtsam.h
Original file line number Diff line number Diff line change
Expand Up @@ -1362,6 +1362,9 @@ virtual class Null: gtsam::noiseModel::mEstimator::Base {

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

virtual class Fair: gtsam::noiseModel::mEstimator::Base {
Expand All @@ -1370,6 +1373,9 @@ virtual class Fair: gtsam::noiseModel::mEstimator::Base {

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

virtual class Huber: gtsam::noiseModel::mEstimator::Base {
Expand All @@ -1378,6 +1384,20 @@ virtual class Huber: gtsam::noiseModel::mEstimator::Base {

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

virtual class Cauchy: gtsam::noiseModel::mEstimator::Base {
Cauchy(double k);
static gtsam::noiseModel::mEstimator::Cauchy* Create(double k);

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

virtual class Tukey: gtsam::noiseModel::mEstimator::Base {
Expand All @@ -1386,6 +1406,9 @@ virtual class Tukey: gtsam::noiseModel::mEstimator::Base {

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

virtual class Welsch: gtsam::noiseModel::mEstimator::Base {
Expand All @@ -1394,8 +1417,43 @@ virtual class Welsch: gtsam::noiseModel::mEstimator::Base {

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

virtual class GemanMcClure: gtsam::noiseModel::mEstimator::Base {
GemanMcClure(double c);
static gtsam::noiseModel::mEstimator::GemanMcClure* Create(double c);

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

virtual class DCS: gtsam::noiseModel::mEstimator::Base {
DCS(double c);
static gtsam::noiseModel::mEstimator::DCS* Create(double c);

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

virtual class L2WithDeadZone: gtsam::noiseModel::mEstimator::Base {
L2WithDeadZone(double k);
static gtsam::noiseModel::mEstimator::L2WithDeadZone* Create(double k);

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

}///\namespace mEstimator

Expand Down
129 changes: 106 additions & 23 deletions gtsam/linear/NoiseModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -718,15 +718,26 @@ void Null::print(const std::string &s="") const
Null::shared_ptr Null::Create()
{ return shared_ptr(new Null()); }

/* ************************************************************************* */
// Fair
/* ************************************************************************* */

Fair::Fair(double c, const ReweightScheme reweight) : Base(reweight), c_(c) {
if (c_ <= 0) {
throw runtime_error("mEstimator Fair takes only positive double in constructor.");
}
}

/* ************************************************************************* */
// Fair
/* ************************************************************************* */
double Fair::weight(double error) const {
return 1.0 / (1.0 + std::abs(error) / c_);
}

double Fair::residual(double error) const {
const double absError = std::abs(error);
const double normalizedError = absError / c_;
const double c_2 = c_ * c_;
return c_2 * (normalizedError - std::log(1 + normalizedError));
}

void Fair::print(const std::string &s="") const
{ cout << s << "fair (" << c_ << ")" << endl; }
Expand All @@ -750,6 +761,20 @@ Huber::Huber(double k, const ReweightScheme reweight) : Base(reweight), k_(k) {
}
}

double Huber::weight(double error) const {
const double absError = std::abs(error);
return (absError <= k_) ? (1.0) : (k_ / absError);
}

double Huber::residual(double error) const {
const double absError = std::abs(error);
if (absError <= k_) { // |x| <= k
return error*error / 2;
} else { // |x| > k
return k_ * (absError - (k_/2));
}
}

void Huber::print(const std::string &s="") const {
cout << s << "huber (" << k_ << ")" << endl;
}
Expand All @@ -774,6 +799,16 @@ Cauchy::Cauchy(double k, const ReweightScheme reweight) : Base(reweight), k_(k),
}
}

double Cauchy::weight(double error) const {
return ksquared_ / (ksquared_ + error*error);
}

double Cauchy::residual(double error) const {
const double xc2 = error / k_;
const double val = std::log(1 + (xc2*xc2));
return ksquared_ * val * 0.5;
}

void Cauchy::print(const std::string &s="") const {
cout << s << "cauchy (" << k_ << ")" << endl;
}
Expand All @@ -791,7 +826,31 @@ Cauchy::shared_ptr Cauchy::Create(double c, const ReweightScheme reweight) {
/* ************************************************************************* */
// Tukey
/* ************************************************************************* */
Tukey::Tukey(double c, const ReweightScheme reweight) : Base(reweight), c_(c), csquared_(c * c) {}

Tukey::Tukey(double c, const ReweightScheme reweight) : Base(reweight), c_(c), csquared_(c * c) {
if (c <= 0) {
throw runtime_error("mEstimator Tukey takes only positive double in constructor.");
}
}

double Tukey::weight(double error) const {
if (std::abs(error) <= c_) {
const double xc2 = error*error/csquared_;
return (1.0-xc2)*(1.0-xc2);
}
return 0.0;
}

double Tukey::residual(double error) const {
double absError = std::abs(error);
if (absError <= c_) {
const double xc2 = error*error/csquared_;
const double t = (1 - xc2)*(1 - xc2)*(1 - xc2);
return csquared_ * (1 - t) / 6.0;
} else {
return csquared_ / 6.0;
}
}

void Tukey::print(const std::string &s="") const {
std::cout << s << ": Tukey (" << c_ << ")" << std::endl;
Expand All @@ -810,8 +869,19 @@ Tukey::shared_ptr Tukey::Create(double c, const ReweightScheme reweight) {
/* ************************************************************************* */
// Welsch
/* ************************************************************************* */

Welsch::Welsch(double c, const ReweightScheme reweight) : Base(reweight), c_(c), csquared_(c * c) {}

double Welsch::weight(double error) const {
const double xc2 = (error*error)/csquared_;
return std::exp(-xc2);
}

double Welsch::residual(double error) const {
const double xc2 = (error*error)/csquared_;
return csquared_ * 0.5 * (1 - std::exp(-xc2) );
}

void Welsch::print(const std::string &s="") const {
std::cout << s << ": Welsch (" << c_ << ")" << std::endl;
}
Expand All @@ -826,24 +896,6 @@ Welsch::shared_ptr Welsch::Create(double c, const ReweightScheme reweight) {
return shared_ptr(new Welsch(c, reweight));
}

#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V4
Welsh::Welsh(double c, const ReweightScheme reweight) : Base(reweight), c_(c), csquared_(c * c) {}

void Welsh::print(const std::string &s="") const {
std::cout << s << ": Welsh (" << c_ << ")" << std::endl;
}

bool Welsh::equals(const Base &expected, double tol) const {
const Welsh* p = dynamic_cast<const Welsh*>(&expected);
if (p == NULL) return false;
return std::abs(c_ - p->c_) < tol;
}

Welsh::shared_ptr Welsh::Create(double c, const ReweightScheme reweight) {
return shared_ptr(new Welsh(c, reweight));
}
#endif

/* ************************************************************************* */
// GemanMcClure
/* ************************************************************************* */
Expand All @@ -858,6 +910,12 @@ double GemanMcClure::weight(double error) const {
return c4/(c2error*c2error);
}

double GemanMcClure::residual(double error) const {
const double c2 = c_*c_;
const double error2 = error*error;
return 0.5 * (c2 * error2) / (c2 + error2);
}

void GemanMcClure::print(const std::string &s="") const {
std::cout << s << ": Geman-McClure (" << c_ << ")" << std::endl;
}
Expand Down Expand Up @@ -890,6 +948,16 @@ double DCS::weight(double error) const {
return 1.0;
}

double DCS::residual(double error) const {
// This is the simplified version of Eq 9 from (Agarwal13icra)
// after you simplify and cancel terms.
const double e2 = error*error;
const double e4 = e2*e2;
const double c2 = c_*c_;

return (c2*e2 + c_*e4) / ((e2 + c_)*(e2 + c_));
}

void DCS::print(const std::string &s="") const {
std::cout << s << ": DCS (" << c_ << ")" << std::endl;
}
Expand All @@ -908,12 +976,27 @@ DCS::shared_ptr DCS::Create(double c, const ReweightScheme reweight) {
// L2WithDeadZone
/* ************************************************************************* */

L2WithDeadZone::L2WithDeadZone(double k, const ReweightScheme reweight) : Base(reweight), k_(k) {
L2WithDeadZone::L2WithDeadZone(double k, const ReweightScheme reweight)
: Base(reweight), k_(k) {
if (k_ <= 0) {
throw runtime_error("mEstimator L2WithDeadZone takes only positive double in constructor.");
}
}

double L2WithDeadZone::weight(double error) const {
// note that this code is slightly uglier than residual, because there are three distinct
// cases to handle (left of deadzone, deadzone, right of deadzone) instead of the two
// cases (deadzone, non-deadzone) in residual.
if (std::abs(error) <= k_) return 0.0;
else if (error > k_) return (-k_+error)/error;
else return (k_+error)/error;
}

double L2WithDeadZone::residual(double error) const {
const double abs_error = std::abs(error);
return (abs_error < k_) ? 0.0 : 0.5*(k_-abs_error)*(k_-abs_error);
}

void L2WithDeadZone::print(const std::string &s="") const {
std::cout << s << ": L2WithDeadZone (" << k_ << ")" << std::endl;
}
Expand Down
Loading

0 comments on commit a4ac57c

Please sign in to comment.