Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap mEstimators #117

Merged
merged 17 commits into from
Oct 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions gtsam.h
Original file line number Diff line number Diff line change
Expand Up @@ -1362,6 +1362,9 @@ virtual class Null: gtsam::noiseModel::mEstimator::Base {

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

virtual class Fair: gtsam::noiseModel::mEstimator::Base {
Expand All @@ -1370,6 +1373,9 @@ virtual class Fair: gtsam::noiseModel::mEstimator::Base {

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
double residual(double error) const;
};

virtual class Huber: gtsam::noiseModel::mEstimator::Base {
Expand All @@ -1378,6 +1384,20 @@ virtual class Huber: gtsam::noiseModel::mEstimator::Base {

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

virtual class Cauchy: gtsam::noiseModel::mEstimator::Base {
dellaert marked this conversation as resolved.
Show resolved Hide resolved
Cauchy(double k);
static gtsam::noiseModel::mEstimator::Cauchy* Create(double k);

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

virtual class Tukey: gtsam::noiseModel::mEstimator::Base {
Expand All @@ -1386,6 +1406,9 @@ virtual class Tukey: gtsam::noiseModel::mEstimator::Base {

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

virtual class Welsch: gtsam::noiseModel::mEstimator::Base {
Expand All @@ -1394,8 +1417,43 @@ virtual class Welsch: gtsam::noiseModel::mEstimator::Base {

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

virtual class GemanMcClure: gtsam::noiseModel::mEstimator::Base {
GemanMcClure(double c);
static gtsam::noiseModel::mEstimator::GemanMcClure* Create(double c);

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

virtual class DCS: gtsam::noiseModel::mEstimator::Base {
DCS(double c);
static gtsam::noiseModel::mEstimator::DCS* Create(double c);

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

virtual class L2WithDeadZone: gtsam::noiseModel::mEstimator::Base {
L2WithDeadZone(double k);
static gtsam::noiseModel::mEstimator::L2WithDeadZone* Create(double k);

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

}///\namespace mEstimator

Expand Down
129 changes: 106 additions & 23 deletions gtsam/linear/NoiseModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -718,15 +718,26 @@ void Null::print(const std::string &s="") const
Null::shared_ptr Null::Create()
{ return shared_ptr(new Null()); }

/* ************************************************************************* */
// Fair
/* ************************************************************************* */

Fair::Fair(double c, const ReweightScheme reweight) : Base(reweight), c_(c) {
if (c_ <= 0) {
throw runtime_error("mEstimator Fair takes only positive double in constructor.");
}
}

/* ************************************************************************* */
// Fair
/* ************************************************************************* */
double Fair::weight(double error) const {
return 1.0 / (1.0 + std::abs(error) / c_);
}
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved

double Fair::residual(double error) const {
const double absError = std::abs(error);
const double normalizedError = absError / c_;
const double c_2 = c_ * c_;
return c_2 * (normalizedError - std::log(1 + normalizedError));
}

void Fair::print(const std::string &s="") const
{ cout << s << "fair (" << c_ << ")" << endl; }
Expand All @@ -750,6 +761,20 @@ Huber::Huber(double k, const ReweightScheme reweight) : Base(reweight), k_(k) {
}
}

double Huber::weight(double error) const {
const double absError = std::abs(error);
return (absError <= k_) ? (1.0) : (k_ / absError);
}

double Huber::residual(double error) const {
const double absError = std::abs(error);
if (absError <= k_) { // |x| <= k
return error*error / 2;
} else { // |x| > k
return k_ * (absError - (k_/2));
}
}

void Huber::print(const std::string &s="") const {
cout << s << "huber (" << k_ << ")" << endl;
}
Expand All @@ -774,6 +799,16 @@ Cauchy::Cauchy(double k, const ReweightScheme reweight) : Base(reweight), k_(k),
}
}

double Cauchy::weight(double error) const {
return ksquared_ / (ksquared_ + error*error);
}

double Cauchy::residual(double error) const {
const double xc2 = error / k_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor nit: xc2 is usually error*error/ksquared_, and then can do std::log(1 + xc2) below.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@varunagrawal I know you can't finish this yet because of other priorities, but could you incorporate this change into the new PR?

const double val = std::log(1 + (xc2*xc2));
return ksquared_ * val * 0.5;
}

void Cauchy::print(const std::string &s="") const {
cout << s << "cauchy (" << k_ << ")" << endl;
}
Expand All @@ -791,7 +826,31 @@ Cauchy::shared_ptr Cauchy::Create(double c, const ReweightScheme reweight) {
/* ************************************************************************* */
// Tukey
/* ************************************************************************* */
Tukey::Tukey(double c, const ReweightScheme reweight) : Base(reweight), c_(c), csquared_(c * c) {}

Tukey::Tukey(double c, const ReweightScheme reweight) : Base(reweight), c_(c), csquared_(c * c) {
if (c <= 0) {
throw runtime_error("mEstimator Tukey takes only positive double in constructor.");
}
}

double Tukey::weight(double error) const {
if (std::abs(error) <= c_) {
const double xc2 = error*error/csquared_;
return (1.0-xc2)*(1.0-xc2);
}
return 0.0;
}
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved

double Tukey::residual(double error) const {
double absError = std::abs(error);
if (absError <= c_) {
const double xc2 = error*error/csquared_;
const double t = (1 - xc2)*(1 - xc2)*(1 - xc2);
return csquared_ * (1 - t) / 6.0;
} else {
return csquared_ / 6.0;
}
}

void Tukey::print(const std::string &s="") const {
std::cout << s << ": Tukey (" << c_ << ")" << std::endl;
Expand All @@ -810,8 +869,19 @@ Tukey::shared_ptr Tukey::Create(double c, const ReweightScheme reweight) {
/* ************************************************************************* */
// Welsch
/* ************************************************************************* */

Welsch::Welsch(double c, const ReweightScheme reweight) : Base(reweight), c_(c), csquared_(c * c) {}

double Welsch::weight(double error) const {
const double xc2 = (error*error)/csquared_;
return std::exp(-xc2);
}

double Welsch::residual(double error) const {
const double xc2 = (error*error)/csquared_;
return csquared_ * 0.5 * (1 - std::exp(-xc2) );
}

void Welsch::print(const std::string &s="") const {
std::cout << s << ": Welsch (" << c_ << ")" << std::endl;
}
Expand All @@ -826,24 +896,6 @@ Welsch::shared_ptr Welsch::Create(double c, const ReweightScheme reweight) {
return shared_ptr(new Welsch(c, reweight));
}

#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V4
Welsh::Welsh(double c, const ReweightScheme reweight) : Base(reweight), c_(c), csquared_(c * c) {}

void Welsh::print(const std::string &s="") const {
std::cout << s << ": Welsh (" << c_ << ")" << std::endl;
}

bool Welsh::equals(const Base &expected, double tol) const {
const Welsh* p = dynamic_cast<const Welsh*>(&expected);
if (p == NULL) return false;
return std::abs(c_ - p->c_) < tol;
}

Welsh::shared_ptr Welsh::Create(double c, const ReweightScheme reweight) {
return shared_ptr(new Welsh(c, reweight));
}
#endif

/* ************************************************************************* */
// GemanMcClure
/* ************************************************************************* */
Expand All @@ -858,6 +910,12 @@ double GemanMcClure::weight(double error) const {
return c4/(c2error*c2error);
}

double GemanMcClure::residual(double error) const {
const double c2 = c_*c_;
const double error2 = error*error;
return 0.5 * (c2 * error2) / (c2 + error2);
}

void GemanMcClure::print(const std::string &s="") const {
std::cout << s << ": Geman-McClure (" << c_ << ")" << std::endl;
}
Expand Down Expand Up @@ -890,6 +948,16 @@ double DCS::weight(double error) const {
return 1.0;
}

double DCS::residual(double error) const {
// This is the simplified version of Eq 9 from (Agarwal13icra)
// after you simplify and cancel terms.
const double e2 = error*error;
const double e4 = e2*e2;
const double c2 = c_*c_;

return (c2*e2 + c_*e4) / ((e2 + c_)*(e2 + c_));
}

void DCS::print(const std::string &s="") const {
std::cout << s << ": DCS (" << c_ << ")" << std::endl;
}
Expand All @@ -908,12 +976,27 @@ DCS::shared_ptr DCS::Create(double c, const ReweightScheme reweight) {
// L2WithDeadZone
/* ************************************************************************* */

L2WithDeadZone::L2WithDeadZone(double k, const ReweightScheme reweight) : Base(reweight), k_(k) {
L2WithDeadZone::L2WithDeadZone(double k, const ReweightScheme reweight)
: Base(reweight), k_(k) {
if (k_ <= 0) {
throw runtime_error("mEstimator L2WithDeadZone takes only positive double in constructor.");
}
}

double L2WithDeadZone::weight(double error) const {
// note that this code is slightly uglier than residual, because there are three distinct
// cases to handle (left of deadzone, deadzone, right of deadzone) instead of the two
// cases (deadzone, non-deadzone) in residual.
if (std::abs(error) <= k_) return 0.0;
else if (error > k_) return (-k_+error)/error;
else return (k_+error)/error;
}

double L2WithDeadZone::residual(double error) const {
const double abs_error = std::abs(error);
return (abs_error < k_) ? 0.0 : 0.5*(k_-abs_error)*(k_-abs_error);
}

void L2WithDeadZone::print(const std::string &s="") const {
std::cout << s << ": L2WithDeadZone (" << k_ << ")" << std::endl;
}
Expand Down
Loading