Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added use of C2FViT to run Affine and Rigid alignment #88

Merged
merged 41 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
bf97489
Update README.md
Pei-mao May 6, 2024
6c2cd40
Update README.md
Pei-mao May 11, 2024
a273f10
Merge branch 'htylab:main' into main
Pei-mao May 11, 2024
1656f5b
Add affine and registration
Pei-mao May 11, 2024
b2ce682
Update setup.py
Pei-mao May 11, 2024
7d44ac7
Merge branch 'main' into main
Pei-mao May 11, 2024
72d32d7
Update setup.py
Pei-mao May 11, 2024
997d7ec
Merge branch 'htylab:main' into main
Pei-mao May 12, 2024
9444cee
Merge branch 'htylab:main' into main
Pei-mao May 20, 2024
46f14eb
Modify MNI152
Pei-mao May 20, 2024
099c825
Merge branch 'main' of https://github.com/Pei-mao/tigerbx
Pei-mao May 25, 2024
0b3f3b2
add Evaluate_registration
Pei-mao May 27, 2024
1d48174
Update validation.md
Pei-mao May 27, 2024
794feb9
Merge branch 'main' of https://github.com/Pei-mao/tigerbx
Pei-mao May 28, 2024
1d751ff
Modify the validate call
Pei-mao May 28, 2024
94615a5
Merge branch 'htylab:main' into main
Pei-mao May 29, 2024
0cd5896
Merge branch 'htylab:main' into main
Pei-mao May 30, 2024
3961eb5
Change MNI152 position, updata reg_v002 evaluation
Pei-mao May 30, 2024
4bf4746
Merge branch 'htylab:main' into main
Pei-mao May 31, 2024
3bcec88
use tempfile method
Pei-mao May 31, 2024
b686c7a
Merge branch 'htylab:main' into main
Pei-mao May 31, 2024
27c01eb
Add template parameter
Pei-mao May 31, 2024
00c0f54
Revise template call
Pei-mao Jun 1, 2024
1769cbc
Merge branch 'htylab:main' into main
Pei-mao Jun 2, 2024
20e8787
Revise user template
Pei-mao Jun 2, 2024
71ff3de
Merge branch 'htylab:main' into main
Pei-mao Jun 2, 2024
94720b5
Change validate from continuous to nearest
Pei-mao Jun 2, 2024
c15acf2
Merge branch 'htylab:main' into main
Pei-mao Jun 6, 2024
076bb5b
Merge branch 'htylab:main' into main
Pei-mao Jun 7, 2024
ee99001
Add rigid parameter
Pei-mao Jun 7, 2024
2b9319e
encoder and decoder
Pei-mao Jun 7, 2024
3c48b2c
revise
Pei-mao Jun 7, 2024
ec70487
Merge branch 'htylab:main' into main
Pei-mao Jun 8, 2024
f646822
revise validate.py
Pei-mao Jun 9, 2024
aa9b8ea
Merge branch 'htylab:main' into main
Pei-mao Jun 10, 2024
333af81
Correct the data of validate.py
Pei-mao Jun 10, 2024
aeda361
Correct the data of validate.py
Pei-mao Jun 10, 2024
13678e9
test
Pei-mao Jul 11, 2024
b9e6149
test
Pei-mao Jul 11, 2024
e206545
Merge branch 'htylab:main' into main
Pei-mao Jul 22, 2024
90147a6
Added use of C2FViT to run Affine alignment
Pei-mao Aug 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions doc/validation.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
df, metric3 = tigerbx.val('aseg_123', 'aseg', 'temp', GPU=True)
df, metric4 = tigerbx.val('dgm_123', 'aseg', 'temp', GPU=True)
df, metric5 = tigerbx.val('syn_123', 'aseg', 'temp', GPU=True)
df, metric6 = tigerbx.val('reg_60', 'aseg, 'temp', GPU=True, template='Template_T1_tbet.nii.gz')
df, metric6 = tigerbx.val('reg_60', 'aseg', 'temp', GPU=True, template='Template_T1_tbet.nii.gz')


print('bet_NFBS', metric1)
Expand Down Expand Up @@ -46,7 +46,7 @@
| Amygdala | L | 0.737 | 0.764| 0.716| R | 0.727 | 0.750| 0.711|
| Mean | L | 0.833 | 0.846| 0.820| R | 0.829 | 0.841| 0.807|
#### Registration:
mean dice: 0.808
mean dice: 0.797

#### Skull Stripping
bet_NFBS: 0.973
Expand Down
2 changes: 1 addition & 1 deletion doc/validation_archive.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
| Amygdala | L | 0.737 | 0.764| 0.716| R | 0.727 | 0.750| 0.711|
| Mean | L | 0.833 | 0.846| 0.820| R | 0.829 | 0.841| 0.807|
#### Registration:
mean dice: 0.808
mean dice: 0.797

#### Skull Stripping
bet_NFBS: 0.973
Expand Down
110 changes: 69 additions & 41 deletions tigerbx/bx.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tigerbx import lib_tool
from tigerbx import lib_bx
import copy
from nilearn.image import resample_to_img, reorder_img
from nilearn.image import resample_to_img, reorder_img, resample_img

# determine if application is a script file or frozen exe
if getattr(sys, 'frozen', False):
Expand Down Expand Up @@ -225,6 +225,8 @@ def run_args(args):
omodel['cgw'] = 'mprage_cgw_v001_r111.onnx'
omodel['syn'] = 'mprage_synthseg_v003_r111.onnx'
omodel['reg'] = 'mprage_reg_v002_train.onnx'
omodel['affine'] = 'mprage_affine_v001_train.onnx'
omodel['rigid'] = 'mprage_rigid_v001_train.onnx'
omodel['encode'] = 'mprage_encode_v1.onnx'
omodel['decode'] = 'mprage_decode_v1.onnx'

Expand Down Expand Up @@ -422,59 +424,85 @@ def run_args(args):
result_filedict['ct'] = fn

if run_d['affine'] or run_d['rigid'] or run_d['registration']:
import SimpleITK as sitk
bet = lib_bx.read_nib(input_nib) * lib_bx.read_nib(tbetmask_nib)
bet = bet.astype(input_nib.dataobj.dtype)
bet_nib = nib.Nifti1Image(bet, input_nib.affine, input_nib.header)
bet_nib = reorder_img(bet_nib, resample='continuous')
#bet = bet_nib.get_fdata()
bet_sitk = lib_bx.from_nib_get_sitk(bet_nib)

bet_nib = reorder_img(bet_nib, resample='continuous')
ori_affine = bet_nib.affine
bet_data = bet_nib.get_fdata()
bet_data, _ = lib_bx.pad_to_shape(bet_data, (256, 256, 256))
bet_data, _ = lib_bx.crop_image(bet_data, target_shape=(256, 256, 256))

template_nib = lib_tool.get_template(run_d['template'])
template_sitk = lib_bx.from_nib_get_sitk(template_nib)

#template_nib = reorder_img(template_nib, resample='continuous')
template_nib = reorder_img(template_nib, resample='continuous')
fixed_affine = template_nib.affine
template_data = template_nib.get_fdata()
template_data, pad_width = lib_bx.pad_to_shape(template_data, (256, 256, 256))

moving = bet_data.astype(np.float32)[None, ...][None, ...]
moving = lib_bx.min_max_norm(moving)
if run_d['template'] == None:
template_data = np.clip(template_data, a_min=2500, a_max=np.max(template_data))
fixed = template_data.astype(np.float32)[None, ...][None, ...]
fixed = lib_bx.min_max_norm(fixed)

if run_d['rigid']:
rigid_sitk, final_transform = lib_bx.affine_reg(template_sitk, bet_sitk, mode='rigid')
rigid_nib = lib_bx.from_sitk_get_nib(rigid_sitk)
model_ff = lib_tool.get_model(omodel['rigid'])
output = lib_tool.predict(model_ff, [moving, fixed], GPU=args.gpu, mode='reg')
rigided, regid_matrix = np.squeeze(output[0]), np.squeeze(output[1])
rigided = lib_bx.remove_padding(rigided, pad_width)

rigid_nib = nib.Nifti1Image(rigided, fixed_affine)
fn = save_nib(rigid_nib, ftemplate, 'rigid')
result_dict['rigid'] = rigid_nib
result_filedict['rigid'] = fn

Af_sitk, final_transform = lib_bx.affine_reg(template_sitk, bet_sitk)
Af_nib = lib_bx.from_sitk_get_nib(Af_sitk)

result_dict['Affine_matrix'] = final_transform
if run_d['affine']:
fn = save_nib(Af_nib, ftemplate, 'Af')
result_dict['Af'] = Af_nib
result_filedict['Af'] = fn

if run_d['registration']:
Af_data = Af_nib.get_fdata()
moving_image = Af_data.astype(np.float32)[None, ...][None, ...]
moving_image = moving_image/np.max(moving_image)
fixed_image = template_data.astype(np.float32)[None, ...][None, ...]
fixed_image = fixed_image/np.max(fixed_image)
model_ff = lib_tool.get_model(omodel['reg'])

output = lib_tool.predict(model_ff, [moving_image, fixed_image], GPU=args.gpu, mode='reg')
moved = np.squeeze(output[0])
warp = np.squeeze(output[1])
moved_nib = nib.Nifti1Image(moved,
template_nib.affine, template_nib.header)
warp_nib = nib.Nifti1Image(warp,
template_nib.affine, template_nib.header)
if run_d['affine'] or run_d['registration']:

fn = save_nib(moved_nib, ftemplate, 'reg')
result_dict['reg'] = moved_nib
result_filedict['reg'] = fn

#fn = save_nib(warp_nib, ftemplate, 'dense_warp')
result_dict['dense_warp'] = warp_nib
#result_filedict['dense_warp'] = fn
model_ff = lib_tool.get_model(omodel['affine'])
output = lib_tool.predict(model_ff, [moving, fixed], GPU=args.gpu, mode='reg')
affined, affine_matrix, init_flow = np.squeeze(output[0]), np.squeeze(output[1]), output[2]
initflow_nib = nib.Nifti1Image(init_flow, ori_affine)
result_dict['init_flow'] = initflow_nib
affined = lib_bx.remove_padding(affined, pad_width)
affine_nib = nib.Nifti1Image(affined, fixed_affine)

result_dict['Affine_matrix'] = affine_matrix
if run_d['affine']:
fn = save_nib(affine_nib, ftemplate, 'Af')
result_dict['Af'] = affine_nib
result_filedict['Af'] = fn

if run_d['registration']:
template_data = template_nib.get_fdata()

fixed_image = template_data.astype(np.float32)[None, ...][None, ...]
fixed_image = lib_bx.min_max_norm(fixed_image)
#fixed_image = fixed_image/np.max(fixed_image)

Af_data = affine_nib.get_fdata()
moving_image = Af_data.astype(np.float32)[None, ...][None, ...]
#moving_image = moving_image/np.max(moving_image)

model_ff = lib_tool.get_model(omodel['reg'])

output = lib_tool.predict(model_ff, [moving_image, fixed_image], GPU=args.gpu, mode='reg')
moved = np.squeeze(output[0])
warp = np.squeeze(output[1])
moved_nib = nib.Nifti1Image(moved,
fixed_affine, template_nib.header)
warp_nib = nib.Nifti1Image(warp,
fixed_affine, template_nib.header)

fn = save_nib(moved_nib, ftemplate, 'reg')
result_dict['reg'] = moved_nib
result_filedict['reg'] = fn

#fn = save_nib(warp_nib, ftemplate, 'dense_warp')
result_dict['dense_warp'] = warp_nib
#result_filedict['dense_warp'] = fn

print('Processing time: %d seconds' % (time.time() - t))
if len(input_file_list) == 1:
Expand Down
39 changes: 39 additions & 0 deletions tigerbx/lib_bx.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,4 +329,43 @@ def from_sitk_get_nib(sitk_image):
return nii_image_copy


def pad_to_shape(img, target_shape):
"""
Pads the input image with zeros to match the target shape.
"""
padding = [(max(0, t - s)) for s, t in zip(img.shape, target_shape)]
pad_width = [(p // 2, p - (p // 2)) for p in padding]
padded_img = np.pad(img, pad_width, mode='constant', constant_values=0)
return padded_img, pad_width


def min_max_norm(img):
max = np.max(img)
min = np.min(img)

norm_img = (img - min) / (max - min)

return norm_img


def crop_image(image, target_shape):
"""Crops the image to the target shape."""
current_shape = image.shape
crop_slices = []

for i in range(len(target_shape)):
start = (current_shape[i] - target_shape[i]) // 2
end = start + target_shape[i]
crop_slices.append(slice(start, end))

cropped_image = image[tuple(crop_slices)]
return cropped_image, crop_slices


def remove_padding(padded_img, pad_width):
"""
Removes the padding from the input image based on the pad_width.
"""
slices = [slice(p[0], -p[1] if p[1] != 0 else None) for p in pad_width]
cropped_img = padded_img[tuple(slices)]
return cropped_img
7 changes: 6 additions & 1 deletion tigerbx/lib_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,15 @@ def download(url, file_name):
def get_template(template_ff):
mni_template = nib.load(join(application_path, 'template', 'MNI152_T1_1mm_brain.nii.gz'))
mni_affine = mni_template.affine

if template_ff:
full_path = join(application_path, 'template', template_ff)
if isfile(template_ff):
full_path = template_ff

if isfile(full_path):
user_template_nib = nib.load(full_path)
#resampled_template = lib_bx.resample_voxel(user_template_nib, (1, 1, 1), (256, 256, 256))
resampled_template = resample_img(user_template_nib, target_affine=mni_affine, target_shape=[160, 224, 192])
return resampled_template
else:
Expand Down Expand Up @@ -260,7 +262,10 @@ def predict(model, data, GPU, mode=None):
input_names = [input.name for input in session.get_inputs()]
inputs = {input_names[0]: data[0], input_names[1]: data[1]}
return session.run(None, inputs)

if mode == 'affine_transform':
input_names = [input.name for input in session.get_inputs()]
inputs = {input_names[0]: data[0], input_names[1]: data[1], input_names[2]: data[2]}
return session.run(None, inputs)
if mode == 'encode':
mu, sigma = session.run(None, {session.get_inputs()[0].name: data.astype(data_type)}, )
return mu, sigma
Expand Down
Binary file modified tigerbx/template/MNI152_T1_1mm_brain.nii.gz
Binary file not shown.
Binary file modified tigerbx/template/MNI152_T1_1mm_brain_aseg.nii.gz
Binary file not shown.
39 changes: 25 additions & 14 deletions tigerbx/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,30 +208,41 @@ def val(argstring, input_dir, output_dir=None, model=None, GPU=False, debug=Fals
result = tigerbx.run(gpu_str + 'r', f, output_dir, model=model, template=template)

model_transform = lib_tool.get_model('mprage_transform.onnx')
import SimpleITK as sitk
model_affine_transform = lib_tool.get_model('mprage_affine_transform_v001_train.onnx')

template_nib = lib_tool.get_template(template)
template_sitk = lib_bx.from_nib_get_sitk(template_nib)
template_nib = reorder_img(template_nib, resample='continuous')
template_data = template_nib.get_fdata()
template_data, pad_width = lib_bx.pad_to_shape(template_data, (256, 256, 256))

moving_seg_sitk = sitk.ReadImage(f.replace('raw60', 'label60'), sitk.sitkFloat32)
Af_seg_sitk = lib_bx.affine_transform(template_sitk, moving_seg_sitk, result['Affine_matrix'])
Af_seg_nib = lib_bx.from_sitk_get_nib(Af_seg_sitk)

Af_seg_data = Af_seg_nib.get_fdata().astype(np.float32)
Af_seg_data = np.expand_dims(Af_seg_data, axis=0)
Af_seg_data = np.expand_dims(Af_seg_data, axis=0)
moving_seg_nib = nib.load(f.replace('raw60', 'label60'))
moving_seg_nib = reorder_img(moving_seg_nib, resample='nearest')
moving_seg_data = moving_seg_nib.get_fdata().astype(np.float32)
moving_seg_data, _ = lib_bx.pad_to_shape(moving_seg_data, (256, 256, 256))
moving_seg_data, _ = lib_bx.crop_image(moving_seg_data, target_shape=(256, 256, 256))
moving_seg = np.expand_dims(np.expand_dims(moving_seg_data, axis=0), axis=1)

init_flow = result['init_flow'].get_fdata().astype(np.float32)
Affine_matrix = result['Affine_matrix'].astype(np.float32)
Affine_matrix= np.expand_dims(Affine_matrix, axis=0)

output = lib_tool.predict(model_affine_transform, [moving_seg, init_flow, Affine_matrix], GPU=None, mode='affine_transform')
moved_seg = np.squeeze(output[0])

moved_seg = lib_bx.remove_padding(moved_seg, pad_width)

moved_seg = np.expand_dims(np.expand_dims(moved_seg, axis=0), axis=1)
warp = result['dense_warp'].get_fdata().astype(np.float32)
warp = np.expand_dims(warp, axis=0)
output = lib_tool.predict(model_transform, [Af_seg_data, warp], GPU=None, mode='reg')
output = lib_tool.predict(model_transform, [moved_seg, warp], GPU=None, mode='reg')
moved_seg = np.squeeze(output[0])
moved_seg_nib = nib.Nifti1Image(moved_seg,
template_nib.affine, template_nib.header)



mask_pred = reorder_img(moved_seg_nib, resample='nearest').get_fdata().astype(int)
template_seg = lib_tool.get_template_seg(template)
mask_gt = reorder_img(template_seg, resample='nearest').get_fdata().astype(int)



dice26 = get_dice26(mask_gt, mask_pred)
dsc_list.append(dice26)

Expand Down
Loading