Skip to content

Commit

Permalink
Add symmetric scheme support for reg_test_conjugateGradient #92
Browse files Browse the repository at this point in the history
  • Loading branch information
onurulgen committed Jul 24, 2023
1 parent 7204698 commit 6b33dce
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 47 deletions.
2 changes: 1 addition & 1 deletion niftyreg_build_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
291
292
169 changes: 123 additions & 46 deletions reg-test/reg_test_conjugateGradient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

class ConjugateGradientTest: public InterfaceOptimiser {
protected:
using TestData = std::tuple<std::string, NiftiImage, NiftiImage, NiftiImage, NiftiImage>;
using TestCase = std::tuple<shared_ptr<Platform>, unique_ptr<F3dContent>, TestData, bool, bool, bool, float>;
using TestData = std::tuple<std::string, NiftiImage, NiftiImage, NiftiImage, NiftiImage, NiftiImage, NiftiImage>;
using TestCase = std::tuple<shared_ptr<Platform>, unique_ptr<F3dContent>, unique_ptr<F3dContent>, TestData, bool, bool, bool, float>;

inline static vector<TestCase> testCases;

Expand Down Expand Up @@ -54,13 +54,17 @@ class ConjugateGradientTest: public InterfaceOptimiser {
// Generate the different test cases
// Test 2D
NiftiImage controlPointGrid2d = CreateControlPointGrid(reference2d);
NiftiImage controlPointGridBw2d(controlPointGrid2d);
NiftiImage bestControlPointGrid2d(controlPointGrid2d, NiftiImage::Copy::ImageInfoAndAllocData);
NiftiImage transformationGradient2d(controlPointGrid2d, NiftiImage::Copy::ImageInfoAndAllocData);
NiftiImage transformationGradientBw2d(controlPointGrid2d, NiftiImage::Copy::ImageInfoAndAllocData);
auto bestCpp2dPtr = bestControlPointGrid2d.data();
auto transGrad2dPtr = transformationGradient2d.data();
auto transGradBw2dPtr = transformationGradientBw2d.data();
for (size_t i = 0; i < transformationGradient2d.nVoxels(); ++i) {
bestCpp2dPtr[i] = distr(gen);
transGrad2dPtr[i] = distr(gen);
transGradBw2dPtr[i] = distr(gen);
}

// Add the test data
Expand All @@ -69,28 +73,36 @@ class ConjugateGradientTest: public InterfaceOptimiser {
"2D",
std::move(reference2d),
std::move(controlPointGrid2d),
std::move(controlPointGridBw2d),
std::move(bestControlPointGrid2d),
std::move(transformationGradient2d)
std::move(transformationGradient2d),
std::move(transformationGradientBw2d)
));

// Test 3D
NiftiImage controlPointGrid3d = CreateControlPointGrid(reference3d);
NiftiImage controlPointGridBw3d(controlPointGrid3d);
NiftiImage bestControlPointGrid3d(controlPointGrid3d, NiftiImage::Copy::ImageInfoAndAllocData);
NiftiImage transformationGradient3d(controlPointGrid3d, NiftiImage::Copy::ImageInfoAndAllocData);
NiftiImage transformationGradientBw3d(controlPointGrid3d, NiftiImage::Copy::ImageInfoAndAllocData);
auto bestCpp3dPtr = bestControlPointGrid3d.data();
auto transGrad3dPtr = transformationGradient3d.data();
auto transGradBw3dPtr = transformationGradientBw3d.data();
for (size_t i = 0; i < transformationGradient3d.nVoxels(); ++i) {
bestCpp3dPtr[i] = distr(gen);
transGrad3dPtr[i] = distr(gen);
transGradBw3dPtr[i] = distr(gen);
}

// Add the test data
testData.emplace_back(TestData(
"3D",
std::move(reference3d),
std::move(controlPointGrid3d),
std::move(controlPointGridBw3d),
std::move(bestControlPointGrid3d),
std::move(transformationGradient3d)
std::move(transformationGradient3d),
std::move(transformationGradientBw3d)
));

// Add platforms, optimise*, and scale to the test data
Expand All @@ -104,10 +116,11 @@ class ConjugateGradientTest: public InterfaceOptimiser {
for (int optimiseZ = 0; optimiseZ < 2; optimiseZ++) {
// Make a copy of the test data
auto td = testData;
auto&& [testName, reference, controlPointGrid, bestControlPointGrid, transGrad] = td;
auto&& [testName, reference, controlPointGrid, controlPointGridBw, bestControlPointGrid, transGrad, transGradBw] = td;
// Add content
unique_ptr<F3dContent> content{ contentCreator->Create(reference, reference, controlPointGrid) };
testCases.push_back({ platform, std::move(content), std::move(td), optimiseX, optimiseY, optimiseZ, distr(gen) });
unique_ptr<F3dContent> contentBw{ contentCreator->Create(reference, reference, controlPointGridBw) };
testCases.push_back({ platform, std::move(content), std::move(contentBw), std::move(td), optimiseX, optimiseY, optimiseZ, distr(gen) });
}
}
}
Expand Down Expand Up @@ -148,38 +161,64 @@ class ConjugateGradientTest: public InterfaceOptimiser {
}
}

void UpdateGradientValues(NiftiImage& gradient, const bool& firstCall) {
void UpdateGradientValues(NiftiImage& gradient, const bool& firstCall, const bool& isSymmetric, NiftiImage *gradientBw) {
// Create array1 and array2
static NiftiImage array1;
static NiftiImage array2;
static NiftiImage array1, array1Bw;
static NiftiImage array2, array2Bw;
if (firstCall) {
array1 = NiftiImage(gradient, NiftiImage::Copy::ImageInfoAndAllocData);
array2 = NiftiImage(gradient, NiftiImage::Copy::ImageInfoAndAllocData);
array1 = array2 = NiftiImage(gradient, NiftiImage::Copy::ImageInfoAndAllocData);
if (isSymmetric)
array1Bw = array2Bw = NiftiImage(*gradientBw, NiftiImage::Copy::ImageInfoAndAllocData);
}

auto gradientPtr = gradient.data();
auto array1Ptr = array1.data();
auto array2Ptr = array2.data();
NiftiImageData gradientBwPtr, array1BwPtr, array2BwPtr;
if (isSymmetric) {
gradientBwPtr = gradientBw->data();
array1BwPtr = array1Bw.data();
array2BwPtr = array2Bw.data();
}

if (firstCall) {
// Initialise array1 and array2
for (size_t i = 0; i < gradient.nVoxels(); i++)
array2Ptr[i] = array1Ptr[i] = -static_cast<float>(gradientPtr[i]);
if (isSymmetric) {
for (size_t i = 0; i < gradientBw->nVoxels(); i++)
array2BwPtr[i] = array1BwPtr[i] = -static_cast<float>(gradientBwPtr[i]);
}
} else {
// Calculate gam
double dgg = 0, gg = 0;
for (size_t i = 0; i < gradient.nVoxels(); i++) {
gg += static_cast<float>(array2Ptr[i]) * static_cast<float>(array1Ptr[i]);
dgg += (static_cast<float>(gradientPtr[i]) + static_cast<float>(array1Ptr[i])) * static_cast<float>(gradientPtr[i]);
}
const double gam = dgg / gg;
double gam = dgg / gg;
if (isSymmetric) {
double dggBw = 0, ggBw = 0;
for (size_t i = 0; i < gradientBw->nVoxels(); i++) {
ggBw += static_cast<float>(array2BwPtr[i]) * static_cast<float>(array1BwPtr[i]);
dggBw += (static_cast<float>(gradientBwPtr[i]) + static_cast<float>(array1BwPtr[i])) * static_cast<float>(gradientBwPtr[i]);
}
gam = (dgg + dggBw) / (gg + ggBw);
}

// Update gradient values
for (size_t i = 0; i < gradient.nVoxels(); i++) {
array1Ptr[i] = -static_cast<float>(gradientPtr[i]);
array2Ptr[i] = static_cast<float>(array1Ptr[i]) + gam * static_cast<float>(array2Ptr[i]);
gradientPtr[i] = -static_cast<float>(array2Ptr[i]);
}
if (isSymmetric) {
for (size_t i = 0; i < gradientBw->nVoxels(); i++) {
array1BwPtr[i] = -static_cast<float>(gradientBwPtr[i]);
array2BwPtr[i] = static_cast<float>(array1BwPtr[i]) + gam * static_cast<float>(array2BwPtr[i]);
gradientBwPtr[i] = -static_cast<float>(array2BwPtr[i]);
}
}
}
}

Expand All @@ -193,8 +232,8 @@ TEST_CASE_METHOD(ConjugateGradientTest, "Conjugate gradient", "[ConjugateGradien
// Loop over all generated test cases
for (auto&& testCase : testCases) {
// Retrieve test information
auto&& [platform, content, testData, optimiseX, optimiseY, optimiseZ, scale] = testCase;
auto&& [testName, reference, controlPointGrid, bestControlPointGrid, transGrad] = testData;
auto&& [platform, content, contentBw, testData, optimiseX, optimiseY, optimiseZ, scale] = testCase;
auto&& [testName, reference, controlPointGrid, controlPointGridBw, bestControlPointGrid, transGrad, transGradBw] = testData;
const std::string sectionName = testName + " " + platform->GetName() + " " + (optimiseX ? "X" : "noX") + " " + (optimiseY ? "Y" : "noY") + " " + (optimiseZ ? "Z" : "noZ") + " scale = " + std::to_string(scale);

SECTION(sectionName) {
Expand All @@ -207,11 +246,15 @@ TEST_CASE_METHOD(ConjugateGradientTest, "Conjugate gradient", "[ConjugateGradien
img.disown();
content->UpdateControlPointGrid();

// Set the transformation gradient
// Set the transformation gradients
img = content->GetTransformationGradient();
img.copyData(transGrad);
img.disown();
content->UpdateTransformationGradient();
img = contentBw->GetTransformationGradient();
img.copyData(transGradBw);
img.disown();
contentBw->UpdateTransformationGradient();

// Create a copy of the control point grid for expected results
NiftiImage controlPointGridExpected = bestControlPointGrid;
Expand All @@ -237,41 +280,75 @@ TEST_CASE_METHOD(ConjugateGradientTest, "Conjugate gradient", "[ConjugateGradien
// Update the gradient values
// Only run once by discarding other optimiseX, optimiseY, optimiseZ combinations
if (!optimiseX && !optimiseY && !optimiseZ) {
std::cout << "\n**************** UpdateGradientValues " << sectionName << " ****************" << std::endl;

// Initialise the conjugate gradient
optimiser->UpdateGradientValues();
UpdateGradientValues(transGrad, true);
// Fill the gradient with random values
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> distr(0, 1);
auto gradientPtr = transGrad.data();
for (size_t i = 0; i < transGrad.nVoxels(); i++)
gradientPtr[i] = distr(gen);
// Update the transformation gradient
img = content->GetTransformationGradient();
img.copyData(transGrad);
img.disown();
content->UpdateTransformationGradient();
// Get the gradient values
optimiser->UpdateGradientValues();
UpdateGradientValues(transGrad, false);

// Check the results
img = content->GetTransformationGradient();
const auto gradPtr = img.data();
const auto gradExpPtr = transGrad.data();
img.disown();
for (size_t i = 0; i < transGrad.nVoxels(); ++i) {
const float gradVal = gradPtr[i];
const float gradExpVal = gradExpPtr[i];
std::cout << i << " " << gradVal << " " << gradExpVal << std::endl;
REQUIRE(fabs(gradVal - gradExpVal) < EPS);
for (int isSymmetric = 0; isSymmetric < 2; isSymmetric++) {
std::cout << "\n**************** UpdateGradientValues " << sectionName + (isSymmetric ? " Symmetric" : "") << " ****************" << std::endl;

// Create a random number generator
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> distr(0, 1);

// Create a symmetric optimiser if required
if (isSymmetric)
optimiser.reset(platform->template CreateOptimiser<float>(*content, *this, 0, true, optimiseX, optimiseY, optimiseZ, contentBw.get()));

// Initialise the conjugate gradients
optimiser->UpdateGradientValues();
UpdateGradientValues(transGrad, true, isSymmetric, &transGradBw);

// Fill the gradients with random values
auto gradientPtr = transGrad.data();
auto gradientBwPtr = transGradBw.data();
for (size_t i = 0; i < transGrad.nVoxels(); i++) {
gradientPtr[i] = distr(gen);
if (isSymmetric)
gradientBwPtr[i] = distr(gen);
}
// Update the transformation gradients
img = content->GetTransformationGradient();
img.copyData(transGrad);
img.disown();
content->UpdateTransformationGradient();
if (isSymmetric) {
img = contentBw->GetTransformationGradient();
img.copyData(transGradBw);
img.disown();
contentBw->UpdateTransformationGradient();
}

// Get the gradient values
optimiser->UpdateGradientValues();
UpdateGradientValues(transGrad, false, isSymmetric, &transGradBw);

// Check the results
img = content->GetTransformationGradient();
const auto gradPtr = img.data();
const auto gradExpPtr = transGrad.data();
img.disown();
NiftiImageData gradBwPtr, gradExpBwPtr;
if (isSymmetric) {
img = contentBw->GetTransformationGradient();
gradBwPtr = img.data();
gradExpBwPtr = transGradBw.data();
img.disown();
}
for (size_t i = 0; i < transGrad.nVoxels(); ++i) {
const float gradVal = gradPtr[i];
const float gradExpVal = gradExpPtr[i];
std::cout << i << " " << gradVal << " " << gradExpVal << std::endl;
REQUIRE(fabs(gradVal - gradExpVal) < EPS);
if (isSymmetric) {
const float gradBwVal = gradBwPtr[i];
const float gradExpBwVal = gradExpBwPtr[i];
std::cout << i << " " << gradBwVal << " " << gradExpBwVal << " backwards" << std::endl;
REQUIRE(fabs(gradBwVal - gradExpBwVal) < EPS);
}
}
}
}
// Ensure the termination of content before CudaContext
content.reset();
contentBw.reset();
}
}
}

0 comments on commit 6b33dce

Please sign in to comment.