Skip to content

Commit

Permalink
Refactor Optimiser #92
Browse files Browse the repository at this point in the history
  • Loading branch information
onurulgen committed Nov 24, 2023
1 parent b2a32ff commit 8182839
Show file tree
Hide file tree
Showing 15 changed files with 269 additions and 250 deletions.
2 changes: 1 addition & 1 deletion niftyreg_build_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
368
369
1 change: 1 addition & 0 deletions reg-lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ add_library(_reg_compute ${NIFTYREG_LIBRARY_TYPE}
Content.cpp
DefContent.cpp
F3dContent.cpp
Optimiser.cpp
Platform.cpp
Measure.cpp
)
Expand Down
2 changes: 1 addition & 1 deletion reg-lib/Compute.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma once

#include "Content.h"
#include "_reg_optimiser.h"
#include "Optimiser.hpp"

class Compute {
public:
Expand Down
140 changes: 72 additions & 68 deletions reg-lib/cpu/_reg_optimiser.cpp → reg-lib/Optimiser.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
/** @file _reg_optimiser.cpp
/** @file Optimiser.cpp
* @author Marc Modat
* @date 20/07/2012
*/

#include "_reg_optimiser.h"
#include "Optimiser.hpp"

/* *************************************************************** */
namespace NiftyReg {
/* *************************************************************** */
template <class T>
reg_optimiser<T>::reg_optimiser() {
Optimiser<T>::Optimiser() {
this->dofNumber = 0;
this->dofNumberBw = 0;
this->ndim = 3;
Expand All @@ -30,7 +32,7 @@ reg_optimiser<T>::reg_optimiser() {
}
/* *************************************************************** */
template <class T>
reg_optimiser<T>::~reg_optimiser() {
Optimiser<T>::~Optimiser() {
if (this->bestDof) {
free(this->bestDof);
this->bestDof = nullptr;
Expand All @@ -43,19 +45,19 @@ reg_optimiser<T>::~reg_optimiser() {
}
/* *************************************************************** */
template <class T>
void reg_optimiser<T>::Initialise(size_t nvox,
int ndim,
bool optX,
bool optY,
bool optZ,
size_t maxIt,
size_t startIt,
InterfaceOptimiser *intOpt,
T *cppData,
T *gradData,
size_t nvoxBw,
T *cppDataBw,
T *gradDataBw) {
void Optimiser<T>::Initialise(size_t nvox,
int ndim,
bool optX,
bool optY,
bool optZ,
size_t maxIt,
size_t startIt,
InterfaceOptimiser *intOpt,
T *cppData,
T *gradData,
size_t nvoxBw,
T *cppDataBw,
T *gradDataBw) {
this->dofNumber = nvox;
this->ndim = ndim;
this->optimiseX = optX;
Expand Down Expand Up @@ -87,7 +89,7 @@ void reg_optimiser<T>::Initialise(size_t nvox,
}
/* *************************************************************** */
template <class T>
void reg_optimiser<T>::RestoreBestDof() {
void Optimiser<T>::RestoreBestDof() {
// Restore forward transformation
memcpy(this->currentDof, this->bestDof, this->dofNumber * sizeof(T));
// Restore backward transformation if required
Expand All @@ -96,7 +98,7 @@ void reg_optimiser<T>::RestoreBestDof() {
}
/* *************************************************************** */
template <class T>
void reg_optimiser<T>::StoreCurrentDof() {
void Optimiser<T>::StoreCurrentDof() {
// Save forward transformation
memcpy(this->bestDof, this->currentDof, this->dofNumber * sizeof(T));
// Save backward transformation if required
Expand All @@ -105,7 +107,7 @@ void reg_optimiser<T>::StoreCurrentDof() {
}
/* *************************************************************** */
template <class T>
void reg_optimiser<T>::Perturbation(float length) {
void Optimiser<T>::Perturbation(float length) {
// Initialise the randomiser
srand((unsigned)time(nullptr));
// Reset the number of iteration
Expand All @@ -124,7 +126,7 @@ void reg_optimiser<T>::Perturbation(float length) {
}
/* *************************************************************** */
template <class T>
void reg_optimiser<T>::Optimise(T maxLength, T smallLength, T& startLength) {
void Optimiser<T>::Optimise(T maxLength, T smallLength, T& startLength) {
size_t lineIteration = 0;
float addedLength = 0;
float currentLength = static_cast<float>(startLength);
Expand Down Expand Up @@ -170,8 +172,11 @@ void reg_optimiser<T>::Optimise(T maxLength, T smallLength, T& startLength) {
this->RestoreBestDof();
}
/* *************************************************************** */
template class Optimiser<float>;
template class Optimiser<double>;
/* *************************************************************** */
template <class T>
reg_conjugateGradient<T>::reg_conjugateGradient(): reg_optimiser<T>::reg_optimiser() {
ConjugateGradient<T>::ConjugateGradient(): Optimiser<T>::Optimiser() {
this->array1 = nullptr;
this->array1Bw = nullptr;
this->array2 = nullptr;
Expand All @@ -180,7 +185,7 @@ reg_conjugateGradient<T>::reg_conjugateGradient(): reg_optimiser<T>::reg_optimis
}
/* *************************************************************** */
template <class T>
reg_conjugateGradient<T>::~reg_conjugateGradient() {
ConjugateGradient<T>::~ConjugateGradient() {
if (this->array1) {
free(this->array1);
this->array1 = nullptr;
Expand All @@ -201,20 +206,20 @@ reg_conjugateGradient<T>::~reg_conjugateGradient() {
}
/* *************************************************************** */
template <class T>
void reg_conjugateGradient<T>::Initialise(size_t nvox,
int ndim,
bool optX,
bool optY,
bool optZ,
size_t maxIt,
size_t startIt,
InterfaceOptimiser *intOpt,
T *cppData,
T *gradData,
size_t nvoxBw,
T *cppDataBw,
T *gradDataBw) {
reg_optimiser<T>::Initialise(nvox, ndim, optX, optY, optZ, maxIt, startIt, intOpt, cppData, gradData, nvoxBw, cppDataBw, gradDataBw);
void ConjugateGradient<T>::Initialise(size_t nvox,
int ndim,
bool optX,
bool optY,
bool optZ,
size_t maxIt,
size_t startIt,
InterfaceOptimiser *intOpt,
T *cppData,
T *gradData,
size_t nvoxBw,
T *cppDataBw,
T *gradDataBw) {
Optimiser<T>::Initialise(nvox, ndim, optX, optY, optZ, maxIt, startIt, intOpt, cppData, gradData, nvoxBw, cppDataBw, gradDataBw);
this->firstCall = true;
if (this->array1) free(this->array1);
if (this->array2) free(this->array2);
Expand All @@ -232,7 +237,7 @@ void reg_conjugateGradient<T>::Initialise(size_t nvox,
}
/* *************************************************************** */
template <class T>
void reg_conjugateGradient<T>::UpdateGradientValues() {
void ConjugateGradient<T>::UpdateGradientValues() {
#ifdef WIN32
long i;
long num = (long)this->dofNumber;
Expand Down Expand Up @@ -321,21 +326,22 @@ void reg_conjugateGradient<T>::UpdateGradientValues() {
}
/* *************************************************************** */
template <class T>
void reg_conjugateGradient<T>::Optimise(T maxLength,
T smallLength,
T &startLength) {
void ConjugateGradient<T>::Optimise(T maxLength, T smallLength, T& startLength) {
this->UpdateGradientValues();
reg_optimiser<T>::Optimise(maxLength, smallLength, startLength);
Optimiser<T>::Optimise(maxLength, smallLength, startLength);
}
/* *************************************************************** */
template <class T>
void reg_conjugateGradient<T>::Perturbation(float length) {
reg_optimiser<T>::Perturbation(length);
void ConjugateGradient<T>::Perturbation(float length) {
Optimiser<T>::Perturbation(length);
this->firstCall = true;
}
/* *************************************************************** */
template class ConjugateGradient<float>;
template class ConjugateGradient<double>;
/* *************************************************************** */
template <class T>
reg_lbfgs<T>::reg_lbfgs(): reg_optimiser<T>::reg_optimiser() {
Lbfgs<T>::Lbfgs(): Optimiser<T>::Optimiser() {
this->stepToKeep = 5;
this->oldDof = nullptr;
this->oldGrad = nullptr;
Expand All @@ -344,7 +350,7 @@ reg_lbfgs<T>::reg_lbfgs(): reg_optimiser<T>::reg_optimiser() {
}
/* *************************************************************** */
template <class T>
reg_lbfgs<T>::~reg_lbfgs() {
Lbfgs<T>::~Lbfgs() {
if (this->oldDof) {
free(this->oldDof);
this->oldDof = nullptr;
Expand Down Expand Up @@ -374,20 +380,20 @@ reg_lbfgs<T>::~reg_lbfgs() {
}
/* *************************************************************** */
template <class T>
void reg_lbfgs<T>::Initialise(size_t nvox,
int ndim,
bool optX,
bool optY,
bool optZ,
size_t maxIt,
size_t startIt,
InterfaceOptimiser *intOpt,
T *cppData,
T *gradData,
size_t nvoxBw,
T *cppDataBw,
T *gradDataBw) {
reg_optimiser<T>::Initialise(nvox, ndim, optX, optY, optZ, maxIt, startIt, intOpt, cppData, gradData, nvoxBw, cppDataBw, gradDataBw);
void Lbfgs<T>::Initialise(size_t nvox,
int ndim,
bool optX,
bool optY,
bool optZ,
size_t maxIt,
size_t startIt,
InterfaceOptimiser *intOpt,
T *cppData,
T *gradData,
size_t nvoxBw,
T *cppDataBw,
T *gradDataBw) {
Optimiser<T>::Initialise(nvox, ndim, optX, optY, optZ, maxIt, startIt, intOpt, cppData, gradData, nvoxBw, cppDataBw, gradDataBw);
this->stepToKeep = 5;
this->diffDof = (T**)malloc(this->stepToKeep * sizeof(T*));
this->diffGrad = (T**)malloc(this->stepToKeep * sizeof(T*));
Expand All @@ -404,17 +410,15 @@ void reg_lbfgs<T>::Initialise(size_t nvox,
}
/* *************************************************************** */
template <class T>
void reg_lbfgs<T>::UpdateGradientValues() {

void Lbfgs<T>::UpdateGradientValues() {
NR_FATAL_ERROR("Not implemented");
}
/* *************************************************************** */
template <class T>
void reg_lbfgs<T>::Optimise(T maxLength,
T smallLength,
T &startLength) {
void Lbfgs<T>::Optimise(T maxLength, T smallLength, T& startLength) {
this->UpdateGradientValues();
reg_optimiser<T>::Optimise(maxLength,
smallLength,
startLength);
Optimiser<T>::Optimise(maxLength, smallLength, startLength);
}
/* *************************************************************** */
} // namespace NiftyReg
/* *************************************************************** */
34 changes: 17 additions & 17 deletions reg-lib/cpu/_reg_optimiser.h → reg-lib/Optimiser.hpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
/** @file _reg_optimiser.h
/** @file Optimiser.hpp
* @author Marc Modat
* @date 20/07/2012
*/

#pragma once

#include "_reg_maths.h"
#include <stdlib.h>
#include <stdio.h>
#include <time.h>
#include "_reg_tools.h"

/* *************************************************************** */
namespace NiftyReg {
/* *************************************************************** */
/** @brief Interface between the registration class and the optimiser
*/
Expand All @@ -23,11 +22,11 @@ class InterfaceOptimiser {
virtual void UpdateBestObjFunctionValue() = 0;
};
/* *************************************************************** */
/** @class reg_optimiser
/** @class Optimiser
* @brief Standard gradient ascent optimisation
*/
template <class T>
class reg_optimiser {
class Optimiser {
protected:
bool isSymmetric;
size_t dofNumber;
Expand Down Expand Up @@ -55,8 +54,8 @@ class reg_optimiser {
virtual void UpdateGradientValues() {}

public:
reg_optimiser();
virtual ~reg_optimiser();
Optimiser();
virtual ~Optimiser();
virtual void StoreCurrentDof();
virtual void RestoreBestDof();
virtual size_t GetDofNumber() {
Expand Down Expand Up @@ -141,11 +140,11 @@ class reg_optimiser {
virtual void Perturbation(float length);
};
/* *************************************************************** */
/** @class reg_conjugateGradient
/** @class ConjugateGradient
* @brief Conjugate gradient ascent optimisation
*/
template <class T>
class reg_conjugateGradient: public reg_optimiser<T> {
class ConjugateGradient: public Optimiser<T> {
protected:
T *array1;
T *array1Bw;
Expand All @@ -159,8 +158,8 @@ class reg_conjugateGradient: public reg_optimiser<T> {
virtual void UpdateGradientValues() override;

public:
reg_conjugateGradient();
virtual ~reg_conjugateGradient();
ConjugateGradient();
virtual ~ConjugateGradient();
virtual void Initialise(size_t nvox,
int ndim,
bool optX,
Expand All @@ -184,7 +183,7 @@ class reg_conjugateGradient: public reg_optimiser<T> {
* @brief
*/
template <class T>
class reg_lbfgs: public reg_optimiser<T> {
class Lbfgs: public Optimiser<T> {
protected:
size_t stepToKeep;
T *oldDof;
Expand All @@ -198,8 +197,8 @@ class reg_lbfgs: public reg_optimiser<T> {
virtual void UpdateGradientValues() override;

public:
reg_lbfgs();
virtual ~reg_lbfgs();
Lbfgs();
virtual ~Lbfgs();
virtual void Initialise(size_t nvox,
int ndim,
bool optX,
Expand All @@ -218,4 +217,5 @@ class reg_lbfgs: public reg_optimiser<T> {
T& startLength) override;
};
/* *************************************************************** */
#include "_reg_optimiser.cpp"
} // namespace NiftyReg
/* *************************************************************** */
Loading

0 comments on commit 8182839

Please sign in to comment.