Skip to content

Commit

Permalink
Merge pull request #205 from Simple-Robotics/topic/more-casting-short…
Browse files Browse the repository at this point in the history
…cuts

[core | modelling] Add templated getters for cost and dynamics
  • Loading branch information
ManifoldFR authored Sep 19, 2024
2 parents b1ec455 + ec926aa commit 92f171d
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 1 deletion.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- Templated getters `getCost<U>()` and `getDynamics<U>()` in the StageModel class, and another `getDynamics<U>()` for integrator classes, to get the concrete types ([##205](https://github.com/Simple-Robotics/aligator/pull/205))

### Changed

- All map types are now `boost::unordered_map` ([#203](https://github.com/Simple-Robotics/aligator/pull/203))
Expand Down
28 changes: 28 additions & 0 deletions include/aligator/core/stage-model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@

namespace aligator {

#define ALIGATOR_CHECK_DERIVED_CLASS(Base, Derived) \
static_assert((std::is_base_of_v<Base, Derived>), \
"Failed check for derived class.")

/** @brief A stage in the control problem.
*
* @details Each stage containts cost functions, dynamical
Expand Down Expand Up @@ -46,6 +50,30 @@ template <typename _Scalar> struct StageModelTpl {
/// Dynamics model
PolyDynamics dynamics_;

/// @brief Get a pointer to an expected concrete type for the cost function.
template <typename U> U *getCost() {
ALIGATOR_CHECK_DERIVED_CLASS(Cost, U);
return dynamic_cast<U *>(&*cost_);
}

/// @copybrief castCost()
template <typename U> const U *getCost() const {
ALIGATOR_CHECK_DERIVED_CLASS(Cost, U);
return dynamic_cast<const U *>(&*cost_);
}

/// @brief Get a pointer to an expected concrete type for the dynamics class.
template <typename U> U *getDynamics() {
ALIGATOR_CHECK_DERIVED_CLASS(Dynamics, U);
return dynamic_cast<U *>(&*dynamics_);
}

/// @copybrief castDynamics()
template <typename U> const U *getDynamics() const {
ALIGATOR_CHECK_DERIVED_CLASS(Dynamics, U);
return dynamic_cast<const U *>(&*dynamics_);
}

/// Constructor assumes the control space is a Euclidean space of
/// dimension @p nu.
StageModelTpl(const PolyCost &cost, const PolyDynamics &dynamics);
Expand Down
7 changes: 7 additions & 0 deletions include/aligator/modelling/dynamics/integrator-abstract.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ struct IntegratorAbstractTpl : DynamicsModelTpl<_Scalar> {
/// The underlying continuous dynamics.
xyz::polymorphic<ContinuousDynamics> continuous_dynamics_;

template <typename U> U *getDynamics() {
return dynamic_cast<U *>(&*continuous_dynamics_);
}
template <typename U> const U *getDynamics() const {
return dynamic_cast<const U *>(&*continuous_dynamics_);
}

/// Constructor from instances of DynamicsType.
explicit IntegratorAbstractTpl(
const xyz::polymorphic<ContinuousDynamics> &cont_dynamics);
Expand Down
5 changes: 5 additions & 0 deletions include/aligator/modelling/dynamics/integrator-explicit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ struct ExplicitIntegratorAbstractTpl : ExplicitDynamicsModelTpl<_Scalar> {

xyz::polymorphic<ODEType> ode_;

template <typename U> U *getDynamics() { return dynamic_cast<U *>(&*ode_); }
template <typename U> const U *getDynamics() const {
return dynamic_cast<const U *>(&*ode_);
}

explicit ExplicitIntegratorAbstractTpl(
const xyz::polymorphic<ODEType> &cont_dynamics);
virtual ~ExplicitIntegratorAbstractTpl() = default;
Expand Down
7 changes: 6 additions & 1 deletion tests/problem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,15 @@ BOOST_AUTO_TEST_CASE(test_problem) {

auto nu = f.nu;
auto &space = f.space;
auto &stage = *f.problem.stages_[0];
const auto &stage = *f.problem.stages_[0];
BOOST_CHECK_EQUAL(stage.numPrimal(), space.ndx() + nu);
BOOST_CHECK_EQUAL(stage.numDual(), space.ndx());

auto *p_dyn = stage.getDynamics<MyModel>();
BOOST_CHECK(p_dyn);
auto *p_cost = stage.getCost<MyCost>();
BOOST_CHECK(p_cost);

Eigen::VectorXd u0(nu);
u0.setZero();
auto x0 = stage.xspace_->rand();
Expand Down

0 comments on commit 92f171d

Please sign in to comment.