Skip to content

Commit

Permalink
Merge pull request #19 from uncbiag/feat-multiGradICON
Browse files Browse the repository at this point in the history
Add multiGradICON
  • Loading branch information
HastingsGreer authored Aug 1, 2024
2 parents e569afd + 1b6b106 commit c2c3f1e
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 5 deletions.
31 changes: 31 additions & 0 deletions .github/workflows/gpu-test-action.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: gpu-tests

on:
pull_request:
push:
branches: [dev, main]

jobs:
test-linux:
runs-on: [self-hosted, linux]
strategy:
max-parallel: 5

steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
uses: actions/setup-python@v2
with:
python-version: 3.7
- name: Install dependencies
run: |
pip install -r requirements.txt
pip install -e .
- name: fast test with unittest
run: |
python -m unittest -k CPU
- name: GPU test with unittest
run: |
python -m unittest discover
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
icon_registration>=1.1.5
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ packages = find:
python_requires = >=3.7

install_requires =
icon_registration>=1.1.4
icon_registration>=1.1.5

[options.packages.find]
where = src
Expand Down
47 changes: 43 additions & 4 deletions src/unigradicon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,40 @@ def make_network(input_shape, include_last_step=False, lmbda=1.5, loss_fn=icon.L
net.assign_identity_map(input_shape)
return net

def make_sim(similarity):
if similarity == "lncc":
return icon.LNCC(sigma=5)
elif similarity == "lncc2":
return icon. SquaredLNCC(sigma=5)
elif similarity == "mind":
return icon.MINDSSC(radius=2, dilation=2)
else:
raise ValueError(f"Similarity measure {similarity} not recognized. Choose from [lncc, lncc2, mind].")

def get_multigradicon(loss_fn=icon.LNCC(sigma=5)):
net = make_network(input_shape, include_last_step=True, loss_fn=loss_fn)
from os.path import exists
weights_location = "network_weights/multigradicon1.0/Step_2_final.trch"
if not exists(weights_location):
print("Downloading pretrained multigradicon model")
import urllib.request
import os
download_path = "https://github.com/uncbiag/uniGradICON/releases/download/multigradicon_weights/Step_2_final.trch"
os.makedirs("network_weights/multigradicon1.0/", exist_ok=True)
urllib.request.urlretrieve(download_path, weights_location)
print(f"Loading weights from {weights_location}")
trained_weights = torch.load(weights_location, map_location=torch.device("cpu"))
net.regis_net.load_state_dict(trained_weights)
net.to(config.device)
net.eval()
return net

def get_unigradicon():
net = make_network(input_shape, include_last_step=True)
def get_unigradicon(loss_fn=icon.LNCC(sigma=5)):
net = make_network(input_shape, include_last_step=True, loss_fn=loss_fn)
from os.path import exists
weights_location = "network_weights/unigradicon1.0/Step_2_final.trch"
if not exists(weights_location):
print("Downloading pretrained model")
print("Downloading pretrained unigradicon model")
import urllib.request
import os
download_path = "https://github.com/uncbiag/uniGradICON/releases/download/unigradicon_weights/Step_2_final.trch"
Expand All @@ -177,6 +204,14 @@ def get_unigradicon():
net.eval()
return net

def get_model_from_model_zoo(model_name="unigradicon", loss_fn=icon.LNCC(sigma=5)):
if model_name == "unigradicon":
return get_unigradicon(loss_fn)
elif model_name == "multigradicon":
return get_multigradicon(loss_fn)
else:
raise ValueError(f"Model {model_name} not recognized. Choose from [unigradicon, multigradicon].")

def quantile(arr: torch.Tensor, q):
arr = arr.flatten()
l = len(arr)
Expand Down Expand Up @@ -241,10 +276,14 @@ def main():
default=None, type=str, help="The path to save the warped image.")
parser.add_argument("--io_iterations", required=False,
default="50", help="The number of IO iterations. Default is 50. Set to 'None' to disable IO.")
parser.add_argument("--io_sim", required=False,
default="lncc", help="The similarity measure used in IO. Default is LNCC. Choose from [lncc, lncc2, mind].")
parser.add_argument("--model", required=False,
default="unigradicon", help="The model to load. Default is unigradicon. Choose from [unigradicon, multigradicon].")

args = parser.parse_args()

net = get_unigradicon()
net = get_model_from_model_zoo(args.model, make_sim(args.io_sim))

fixed = itk.imread(args.fixed)
moving = itk.imread(args.moving)
Expand Down
110 changes: 110 additions & 0 deletions tests/test_command_arguments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import itk
import numpy as np
import unittest
import icon_registration.test_utils

import subprocess
import os
import torch


class TestCommandInterface(unittest.TestCase):
def __init__(self, methodName: str = "runTest") -> None:
super().__init__(methodName)
icon_registration.test_utils.download_test_data()
self.test_data_dir = icon_registration.test_utils.TEST_DATA_DIR
self.test_temp_dir = f"{self.test_data_dir}/temp"
os.makedirs(self.test_temp_dir, exist_ok=True)
self.device = torch.cuda.current_device()

def test_register_unigradicon_inference(self):
subprocess.run([
"unigradicon-register",
"--fixed", f"{self.test_data_dir}/lung_test_data/copd1_highres_EXP_STD_COPD_img.nii.gz",
"--fixed_modality", "ct",
"--fixed_segmentation", f"{self.test_data_dir}/lung_test_data/copd1_highres_EXP_STD_COPD_label.nii.gz",
"--moving", f"{self.test_data_dir}/lung_test_data/copd1_highres_INSP_STD_COPD_img.nii.gz",
"--moving_modality", "ct",
"--moving_segmentation", f"{self.test_data_dir}/lung_test_data/copd1_highres_INSP_STD_COPD_label.nii.gz",
"--transform_out", f"{self.test_temp_dir}/transform.hdf5",
"--io_iterations", "None"
])

# load transform
phi_AB = itk.transformread(f"{self.test_temp_dir}/transform.hdf5")[0]

assert isinstance(phi_AB, itk.CompositeTransform)

insp_points = icon_registration.test_utils.read_copd_pointset(
str(
icon_registration.test_utils.TEST_DATA_DIR
/ "lung_test_data/copd1_300_iBH_xyz_r1.txt"
)
)
exp_points = icon_registration.test_utils.read_copd_pointset(
str(
icon_registration.test_utils.TEST_DATA_DIR
/ "lung_test_data/copd1_300_eBH_xyz_r1.txt"
)
)

dists = []
for i in range(len(insp_points)):
px, py = (
insp_points[i],
np.array(phi_AB.TransformPoint(tuple(exp_points[i]))),
)
dists.append(np.sqrt(np.sum((px - py) ** 2)))
print(np.mean(dists))
self.assertLess(np.mean(dists), 2.1)

# remove temp file
os.remove(f"{self.test_temp_dir}/transform.hdf5")

def test_register_multigradicon_inference(self):
subprocess.run([
"unigradicon-register",
"--fixed", f"{self.test_data_dir}/lung_test_data/copd1_highres_EXP_STD_COPD_img.nii.gz",
"--fixed_modality", "ct",
"--fixed_segmentation", f"{self.test_data_dir}/lung_test_data/copd1_highres_EXP_STD_COPD_label.nii.gz",
"--moving", f"{self.test_data_dir}/lung_test_data/copd1_highres_INSP_STD_COPD_img.nii.gz",
"--moving_modality", "ct",
"--moving_segmentation", f"{self.test_data_dir}/lung_test_data/copd1_highres_INSP_STD_COPD_label.nii.gz",
"--transform_out", f"{self.test_temp_dir}/transform.hdf5",
"--io_iterations", "None",
"--model", "multigradicon"
])

# load transform
phi_AB = itk.transformread(f"{self.test_temp_dir}/transform.hdf5")[0]

assert isinstance(phi_AB, itk.CompositeTransform)

insp_points = icon_registration.test_utils.read_copd_pointset(
str(
icon_registration.test_utils.TEST_DATA_DIR
/ "lung_test_data/copd1_300_iBH_xyz_r1.txt"
)
)
exp_points = icon_registration.test_utils.read_copd_pointset(
str(
icon_registration.test_utils.TEST_DATA_DIR
/ "lung_test_data/copd1_300_eBH_xyz_r1.txt"
)
)

dists = []
for i in range(len(insp_points)):
px, py = (
insp_points[i],
np.array(phi_AB.TransformPoint(tuple(exp_points[i]))),
)
dists.append(np.sqrt(np.sum((px - py) ** 2)))
print(np.mean(dists))
self.assertLess(np.mean(dists), 3.8)

# remove temp file
os.remove(f"{self.test_temp_dir}/transform.hdf5")



19 changes: 19 additions & 0 deletions tests/test_requirements_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import unittest


class TestImports(unittest.TestCase):

def test_requirements_match_cfg(self):
from inspect import getsourcefile
import os.path as path, sys
import configparser

current_dir = path.dirname(path.abspath(getsourcefile(lambda: 0)))
parent_dir = current_dir[: current_dir.rfind(path.sep)]

with open(parent_dir + "/requirements.txt") as f:
requirements_txt = "\n" + f.read()
requirements_cfg = configparser.ConfigParser()
requirements_cfg.read(parent_dir + "/setup.cfg")
requirements_cfg = requirements_cfg["options"]["install_requires"]
self.assertEqual(requirements_txt, requirements_cfg)

0 comments on commit c2c3f1e

Please sign in to comment.