Skip to content

Commit

Permalink
Merge pull request #198 from Simple-Robotics/topic/composite-cost-cas…
Browse files Browse the repository at this point in the history
…t-residual

[modelling/costs] Add getResidual() templated getter for composite cost functions
  • Loading branch information
ManifoldFR authored Sep 17, 2024
2 parents 923e7c0 + b20c824 commit 0e62452
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 25 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- Getter `getResidual<Derived>()` for composite cost functions ([#198](https://github.com/Simple-Robotics/aligator/pull/198))

### Changed

- Change storage for `ConstraintStack` to using two `std::vector<polymorphic<>>` the struct `StageConstraintTpl` is now merely a convenient API shortcut for the end-user.
Expand Down
11 changes: 11 additions & 0 deletions include/aligator/modelling/costs/log-residual-cost.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ template <typename Scalar> struct LogResidualCostTpl : CostAbstractTpl<Scalar> {
return std::make_shared<Data>(this->ndx(), this->nu,
residual_->createData());
}

/// @brief Get a pointer to the underlying type of the residual, by attempting
/// to cast.
template <typename Derived> Derived *getResidual() {
return dynamic_cast<Derived *>(&*residual_);
}

/// @copybrief getResidual().
template <typename Derived> const Derived *getResidual() const {
return dynamic_cast<const Derived *>(&*residual_);
}
};

extern template struct LogResidualCostTpl<context::Scalar>;
Expand Down
11 changes: 11 additions & 0 deletions include/aligator/modelling/costs/quad-residual-cost.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ struct QuadraticResidualCostTpl : CostAbstractTpl<_Scalar> {
return std::make_shared<Data>(this->ndx(), this->nu,
residual_->createData());
}

/// @brief Get a pointer to the underlying type of the residual, by attempting
/// to cast.
template <typename Derived> Derived *getResidual() {
return dynamic_cast<Derived *>(&*residual_);
}

/// @copybrief getResidual().
template <typename Derived> const Derived *getResidual() const {
return dynamic_cast<const Derived *>(&*residual_);
}
};

extern template struct QuadraticResidualCostTpl<context::Scalar>;
Expand Down
4 changes: 2 additions & 2 deletions tests/continuous.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ BOOST_AUTO_TEST_CASE(create_data) {
pinocchio::buildModels::humanoidRandom(model);

using StateMultibody = proxsuite::nlp::MultibodyPhaseSpace<double>;
auto spaceptr = std::make_shared<StateMultibody>(model);
const StateMultibody state{model};
Eigen::MatrixXd B(model.nv, model.nv);
B.setIdentity();
dynamics::MultibodyFreeFwdDynamicsTpl<double> contdyn(*spaceptr, B);
dynamics::MultibodyFreeFwdDynamicsTpl<double> contdyn(state, B);

using ContDataAbstract = dynamics::ContinuousDynamicsDataTpl<double>;
using Data = dynamics::MultibodyFreeFwdDataTpl<double>;
Expand Down
54 changes: 31 additions & 23 deletions tests/costs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ using T = double;
using context::MatrixXs;
using context::VectorXs;
using QuadraticResidualCost = QuadraticResidualCostTpl<T>;
using StateError = StateErrorResidualTpl<T>;

void fd_test(VectorXs x0, VectorXs u0, MatrixXs weights,
QuadraticResidualCost qres, shared_ptr<context::CostData> data) {
Expand Down Expand Up @@ -42,63 +43,70 @@ void fd_test(VectorXs x0, VectorXs u0, MatrixXs weights,

BOOST_AUTO_TEST_CASE(quad_state_se2) {
using SE2 = proxsuite::nlp::SETpl<2, T>;
auto space = std::make_shared<SE2>();
SE2 space;

const Eigen::Index ndx = space->ndx();
const Eigen::Index ndx = space.ndx();
const Eigen::Index nu = 1UL;
Eigen::VectorXd u0(nu);
u0.setZero();

const auto target = space->rand();
const auto target = space.rand();

const auto fun =
std::make_shared<StateErrorResidualTpl<T>>(*space, nu, target);
const StateError fun(space, nu, target);

BOOST_CHECK_EQUAL(fun->nr, ndx);
BOOST_CHECK_EQUAL(fun.nr, ndx);
Eigen::MatrixXd weights(ndx, ndx);
weights.setIdentity();
const auto qres =
std::make_shared<QuadraticStateCostTpl<T>>(*space, nu, target, weights);
const QuadraticStateCostTpl<T> qres(space, nu, target, weights);

shared_ptr<context::CostData> data = qres->createData();
auto fd = fun->createData();
shared_ptr<context::CostData> data = qres.createData();
auto fd = fun.createData();

const int nrepeats = 10;

for (int k = 0; k < nrepeats; k++) {
Eigen::VectorXd x0 = space->rand();
fd_test(x0, u0, weights, *qres, data);
Eigen::VectorXd x0 = space.rand();
fd_test(x0, u0, weights, qres, data);
}

const StateError *fun_cast = qres.getResidual<StateError>();
BOOST_CHECK(fun_cast != nullptr);
}

BOOST_AUTO_TEST_CASE(quad_state_highdim) {
using VectorSpace = proxsuite::nlp::VectorSpaceTpl<T>;
const Eigen::Index ndx = 56;
const auto space = std::make_shared<VectorSpace>(ndx);
const VectorSpace space(ndx);
const Eigen::Index nu = 1UL;

Eigen::VectorXd u0(nu);
u0.setZero();

const auto target = space->rand();
const auto target = space.rand();

const auto fun =
std::make_shared<StateErrorResidualTpl<T>>(*space, nu, target);
const StateErrorResidualTpl<T> fun(space, nu, target);

BOOST_CHECK_EQUAL(fun->nr, ndx);
BOOST_CHECK_EQUAL(fun.nr, ndx);
Eigen::MatrixXd weights(ndx, ndx);
weights.setIdentity();
const auto qres =
std::make_shared<QuadraticStateCostTpl<T>>(*space, nu, target, weights);
const QuadraticStateCostTpl<T> qres(space, nu, target, weights);

shared_ptr<context::CostData> data = qres->createData();
auto fd = fun->createData();
shared_ptr<context::CostData> data = qres.createData();
auto fd = fun.createData();

const int nrepeats = 10;

for (int k = 0; k < nrepeats; k++) {
Eigen::VectorXd x0 = space->rand();
fd_test(x0, u0, weights, *qres, data);
Eigen::VectorXd x0 = space.rand();
fd_test(x0, u0, weights, qres, data);
}

const StateError *fun_cast = qres.getResidual<StateError>();
BOOST_CHECK(fun_cast != nullptr);

{
const auto *try_cast = qres.getResidual<ControlErrorResidualTpl<T>>();
BOOST_CHECK(try_cast == nullptr);
}
}

Expand Down

0 comments on commit 0e62452

Please sign in to comment.