Skip to content

Commit

Permalink
Merge pull request #139 from gchenfc/feature/linear-fg-marginals
Browse files Browse the repository at this point in the history
Gaussian Factor Graph Marginals
  • Loading branch information
dellaert authored Oct 10, 2019
2 parents 2f6edee + 5900eaa commit 38cf6bd
Show file tree
Hide file tree
Showing 5 changed files with 244 additions and 82 deletions.
91 changes: 91 additions & 0 deletions cython/gtsam/tests/test_GaussianFactorGraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""
GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved
See LICENSE for the license information
Unit tests for Linear Factor Graphs.
Author: Frank Dellaert & Gerry Chen
"""
# pylint: disable=invalid-name, no-name-in-module, no-member

from __future__ import print_function

import unittest

import gtsam
from gtsam.utils.test_case import GtsamTestCase

import numpy as np

def create_graph():
"""Create a basic linear factor graph for testing"""
graph = gtsam.GaussianFactorGraph()

x0 = gtsam.symbol(ord('x'), 0)
x1 = gtsam.symbol(ord('x'), 1)
x2 = gtsam.symbol(ord('x'), 2)

BETWEEN_NOISE = gtsam.noiseModel_Diagonal.Sigmas(np.ones(1))
PRIOR_NOISE = gtsam.noiseModel_Diagonal.Sigmas(np.ones(1))

graph.add(x1, np.eye(1), x0, -np.eye(1), np.ones(1), BETWEEN_NOISE)
graph.add(x2, np.eye(1), x1, -np.eye(1), 2*np.ones(1), BETWEEN_NOISE)
graph.add(x0, np.eye(1), np.zeros(1), PRIOR_NOISE)

return graph, (x0, x1, x2)

class TestGaussianFactorGraph(GtsamTestCase):
"""Tests for Gaussian Factor Graphs."""

def test_fg(self):
"""Test solving a linear factor graph"""
graph, X = create_graph()
result = graph.optimize()

EXPECTEDX = [0, 1, 3]

# check solutions
self.assertAlmostEqual(EXPECTEDX[0], result.at(X[0]), delta=1e-8)
self.assertAlmostEqual(EXPECTEDX[1], result.at(X[1]), delta=1e-8)
self.assertAlmostEqual(EXPECTEDX[2], result.at(X[2]), delta=1e-8)

def test_convertNonlinear(self):
"""Test converting a linear factor graph to a nonlinear one"""
graph, X = create_graph()

EXPECTEDM = [1, 2, 3]

# create nonlinear factor graph for marginalization
nfg = gtsam.LinearContainerFactor.ConvertLinearGraph(graph)
optimizer = gtsam.LevenbergMarquardtOptimizer(nfg, gtsam.Values())
nlresult = optimizer.optimizeSafely()

# marginalize
marginals = gtsam.Marginals(nfg, nlresult)
m = [marginals.marginalCovariance(x) for x in X]

# check linear marginalizations
self.assertAlmostEqual(EXPECTEDM[0], m[0], delta=1e-8)
self.assertAlmostEqual(EXPECTEDM[1], m[1], delta=1e-8)
self.assertAlmostEqual(EXPECTEDM[2], m[2], delta=1e-8)

def test_linearMarginalization(self):
"""Marginalize a linear factor graph"""
graph, X = create_graph()
result = graph.optimize()

EXPECTEDM = [1, 2, 3]

# linear factor graph marginalize
marginals = gtsam.Marginals(graph, result)
m = [marginals.marginalCovariance(x) for x in X]

# check linear marginalizations
self.assertAlmostEqual(EXPECTEDM[0], m[0], delta=1e-8)
self.assertAlmostEqual(EXPECTEDM[1], m[1], delta=1e-8)
self.assertAlmostEqual(EXPECTEDM[2], m[2], delta=1e-8)

if __name__ == '__main__':
unittest.main()
4 changes: 4 additions & 0 deletions gtsam.h
Original file line number Diff line number Diff line change
Expand Up @@ -2011,6 +2011,10 @@ class Values {
class Marginals {
Marginals(const gtsam::NonlinearFactorGraph& graph,
const gtsam::Values& solution);
Marginals(const gtsam::GaussianFactorGraph& gfgraph,
const gtsam::Values& solution);
Marginals(const gtsam::GaussianFactorGraph& gfgraph,
const gtsam::VectorValues& solutionvec);

void print(string s) const;
Matrix marginalCovariance(size_t variable) const;
Expand Down
35 changes: 27 additions & 8 deletions gtsam/nonlinear/Marginals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,36 @@ namespace gtsam {
/* ************************************************************************* */
Marginals::Marginals(const NonlinearFactorGraph& graph, const Values& solution, Factorization factorization,
EliminateableFactorGraph<GaussianFactorGraph>::OptionalOrdering ordering)
{
: values_(solution), factorization_(factorization) {
gttic(MarginalsConstructor);

// Linearize graph
graph_ = *graph.linearize(solution);
computeBayesTree(ordering);
}

/* ************************************************************************* */
Marginals::Marginals(const GaussianFactorGraph& graph, const VectorValues& solution, Factorization factorization,
EliminateableFactorGraph<GaussianFactorGraph>::OptionalOrdering ordering)
: graph_(graph), factorization_(factorization) {
gttic(MarginalsConstructor);
Values vals;
for (const auto& keyValue: solution) {
vals.insert(keyValue.first, keyValue.second);
}
values_ = vals;
computeBayesTree(ordering);
}

// Store values
values_ = solution;
/* ************************************************************************* */
Marginals::Marginals(const GaussianFactorGraph& graph, const Values& solution, Factorization factorization,
EliminateableFactorGraph<GaussianFactorGraph>::OptionalOrdering ordering)
: graph_(graph), values_(solution), factorization_(factorization) {
gttic(MarginalsConstructor);
computeBayesTree(ordering);
}

/* ************************************************************************* */
void Marginals::computeBayesTree(EliminateableFactorGraph<GaussianFactorGraph>::OptionalOrdering ordering) {
// Compute BayesTree
factorization_ = factorization;
if(factorization_ == CHOLESKY)
bayesTree_ = *graph_.eliminateMultifrontal(ordering, EliminatePreferCholesky);
else if(factorization_ == QR)
Expand Down Expand Up @@ -125,7 +144,7 @@ JointMarginal Marginals::jointMarginalInformation(const KeyVector& variables) co
// Get dimensions from factor graph
std::vector<size_t> dims;
dims.reserve(variablesSorted.size());
for(Key key: variablesSorted) {
for(const auto& key: variablesSorted) {
dims.push_back(values_.at(key).dim());
}

Expand All @@ -142,7 +161,7 @@ VectorValues Marginals::optimize() const {
void JointMarginal::print(const std::string& s, const KeyFormatter& formatter) const {
cout << s << "Joint marginal on keys ";
bool first = true;
for(Key key: keys_) {
for(const auto& key: keys_) {
if(!first)
cout << ", ";
else
Expand Down
26 changes: 25 additions & 1 deletion gtsam/nonlinear/Marginals.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class GTSAM_EXPORT Marginals {
/// Default constructor only for Cython wrapper
Marginals(){}

/** Construct a marginals class.
/** Construct a marginals class from a nonlinear factor graph.
* @param graph The factor graph defining the full joint density on all variables.
* @param solution The linearization point about which to compute Gaussian marginals (usually the MLE as obtained from a NonlinearOptimizer).
* @param factorization The linear decomposition mode - either Marginals::CHOLESKY (faster and suitable for most problems) or Marginals::QR (slower but more numerically stable for poorly-conditioned problems).
Expand All @@ -60,6 +60,24 @@ class GTSAM_EXPORT Marginals {
Marginals(const NonlinearFactorGraph& graph, const Values& solution, Factorization factorization = CHOLESKY,
EliminateableFactorGraph<GaussianFactorGraph>::OptionalOrdering ordering = boost::none);

/** Construct a marginals class from a linear factor graph.
* @param graph The factor graph defining the full joint density on all variables.
* @param solution The solution point to compute Gaussian marginals.
* @param factorization The linear decomposition mode - either Marginals::CHOLESKY (faster and suitable for most problems) or Marginals::QR (slower but more numerically stable for poorly-conditioned problems).
* @param ordering An optional variable ordering for elimination.
*/
Marginals(const GaussianFactorGraph& graph, const Values& solution, Factorization factorization = CHOLESKY,
EliminateableFactorGraph<GaussianFactorGraph>::OptionalOrdering ordering = boost::none);

/** Construct a marginals class from a linear factor graph.
* @param graph The factor graph defining the full joint density on all variables.
* @param solution The solution point to compute Gaussian marginals.
* @param factorization The linear decomposition mode - either Marginals::CHOLESKY (faster and suitable for most problems) or Marginals::QR (slower but more numerically stable for poorly-conditioned problems).
* @param ordering An optional variable ordering for elimination.
*/
Marginals(const GaussianFactorGraph& graph, const VectorValues& solution, Factorization factorization = CHOLESKY,
EliminateableFactorGraph<GaussianFactorGraph>::OptionalOrdering ordering = boost::none);

/** print */
void print(const std::string& str = "Marginals: ", const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;

Expand All @@ -81,6 +99,12 @@ class GTSAM_EXPORT Marginals {

/** Optimize the bayes tree */
VectorValues optimize() const;

protected:

/** Compute the Bayes Tree as a helper function to the constructor */
void computeBayesTree(EliminateableFactorGraph<GaussianFactorGraph>::OptionalOrdering ordering);

};

/**
Expand Down
Loading

0 comments on commit 38cf6bd

Please sign in to comment.