-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'topic/more-finite-difference-helpers' into devel
- Loading branch information
Showing
7 changed files
with
124 additions
and
79 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
58 changes: 58 additions & 0 deletions
58
include/aligator/modelling/autodiff/cost-finite-difference.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
#pragma once | ||
#include "aligator/core/cost-abstract.hpp" | ||
#include <proxsuite-nlp/manifold-base.hpp> | ||
|
||
namespace aligator { | ||
namespace autodiff { | ||
|
||
template <typename Scalar> | ||
struct CostFiniteDifferenceHelper : CostAbstractTpl<Scalar> { | ||
using Manifold = ManifoldAbstractTpl<Scalar>; | ||
using CostBase = CostAbstractTpl<Scalar>; | ||
using CostData = CostDataAbstractTpl<Scalar>; | ||
|
||
using CostBase::space; | ||
|
||
ALIGATOR_DYNAMIC_TYPEDEFS(Scalar); | ||
|
||
struct Data; | ||
|
||
CostFiniteDifferenceHelper(xyz::polymorphic<CostBase> cost, | ||
const Scalar fd_eps); | ||
|
||
void evaluate(const ConstVectorRef &x, const ConstVectorRef &u, | ||
CostData &data_) const override; | ||
|
||
void computeGradients(const ConstVectorRef &x, const ConstVectorRef &u, | ||
CostData &data_) const override; | ||
|
||
/// @brief Compute the cost Hessians \f$(\ell_{ij})_{i,j \in \{x,u\}}\f$ | ||
void computeHessians(const ConstVectorRef &, const ConstVectorRef &, | ||
CostData &) const override {} | ||
|
||
auto createData() const -> shared_ptr<CostData> override; | ||
|
||
xyz::polymorphic<CostBase> cost_; | ||
Scalar fd_eps; | ||
}; | ||
|
||
template <typename Scalar> | ||
struct CostFiniteDifferenceHelper<Scalar>::Data : CostData { | ||
|
||
shared_ptr<CostData> c1, c2; | ||
VectorXs dx, du; | ||
VectorXs xp, up; | ||
|
||
Data(CostFiniteDifferenceHelper const &obj) | ||
: CostData(obj), dx(obj.ndx()), du(obj.nu), xp(obj.nx()), up(obj.nu) { | ||
c1 = obj.cost_->createData(); | ||
c2 = obj.cost_->createData(); | ||
} | ||
}; | ||
|
||
#ifdef ALIGATOR_ENABLE_TEMPLATE_INSTANTIATION | ||
extern template struct CostFiniteDifferenceHelper<context::Scalar>; | ||
#endif | ||
|
||
} // namespace autodiff | ||
} // namespace aligator |
57 changes: 57 additions & 0 deletions
57
include/aligator/modelling/autodiff/cost-finite-difference.hxx
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
#pragma once | ||
|
||
#include "aligator/modelling/autodiff/cost-finite-difference.hpp" | ||
|
||
namespace aligator::autodiff { | ||
|
||
template <typename Scalar> | ||
CostFiniteDifferenceHelper<Scalar>::CostFiniteDifferenceHelper( | ||
xyz::polymorphic<CostBase> cost, const Scalar fd_eps) | ||
: CostBase(cost->space, cost->nu), cost_(cost), fd_eps(fd_eps) {} | ||
|
||
template <typename Scalar> | ||
void CostFiniteDifferenceHelper<Scalar>::evaluate(const ConstVectorRef &x, | ||
const ConstVectorRef &u, | ||
CostData &data_) const { | ||
Data &d = static_cast<Data &>(data_); | ||
cost_->evaluate(x, u, *d.c1); | ||
|
||
d.value_ = d.c1->value_; | ||
} | ||
|
||
template <typename Scalar> | ||
void CostFiniteDifferenceHelper<Scalar>::computeGradients( | ||
const ConstVectorRef &x, const ConstVectorRef &u, CostData &data_) const { | ||
Data &d = static_cast<Data &>(data_); | ||
Manifold const &space = *this->space; | ||
|
||
cost_->evaluate(x, u, *d.c1); | ||
|
||
d.dx.setZero(); | ||
for (int i = 0; i < this->ndx(); i++) { | ||
d.dx[i] = fd_eps; | ||
space.integrate(x, d.dx, d.xp); | ||
cost_->evaluate(d.xp, u, *d.c2); | ||
|
||
d.Lx_[i] = (d.c2->value_ - d.c1->value_) / fd_eps; | ||
d.dx[i] = 0.; | ||
} | ||
|
||
d.du.setZero(); | ||
for (int i = 0; i < this->nu; i++) { | ||
d.du[i] = fd_eps; | ||
d.up = u + d.du; | ||
cost_->evaluate(x, d.up, *d.c2); | ||
|
||
d.Lu_[i] = (d.c2->value_ - d.c1->value_) / fd_eps; | ||
d.du[i] = 0.; | ||
} | ||
} | ||
|
||
template <typename Scalar> | ||
auto CostFiniteDifferenceHelper<Scalar>::createData() const | ||
-> shared_ptr<CostData> { | ||
return std::make_shared<Data>(*this); | ||
} | ||
|
||
} // namespace aligator::autodiff |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
#include "aligator/modelling/autodiff/cost-finite-difference.hxx" | ||
|
||
namespace aligator::autodiff { | ||
|
||
template struct CostFiniteDifferenceHelper<context::Scalar>; | ||
|
||
} // namespace aligator::autodiff |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters