Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
grt123 processing adaptation
Browse files Browse the repository at this point in the history
  • Loading branch information
vessemer committed Oct 18, 2017
1 parent dca2e70 commit 389824f
Show file tree
Hide file tree
Showing 12 changed files with 271 additions and 219 deletions.
75 changes: 11 additions & 64 deletions prediction/src/algorithms/classify/src/gtr123_model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import torch
import numpy as np
import SimpleITK as sitk

import torch
from src.preprocess import load_ct, preprocess_ct, crop_patches
from torch import nn
from torch.autograd import Variable

from ....preprocess.gtr123_preprocess import lum_trans, resample

""""
Classification model from team gtr123
Code adapted from https://github.com/lfz/DSB2017
Expand Down Expand Up @@ -224,55 +221,11 @@ def forward(self, xlist, coordlist):
return nodulePred, casePred, out


class SimpleCrop(object):
""" """

def __init__(self):
self.crop_size = config['crop_size']
self.scaleLim = config['scaleLim']
self.radiusLim = config['radiusLim']
self.stride = config['stride']
self.filling_value = config['filling_value']

def __call__(self, imgs, target):
crop_size = np.array(self.crop_size).astype('int')

start = (target[:3] - crop_size / 2).astype('int')
pad = [[0, 0]]

for i in range(3):
if start[i] < 0:
leftpad = -start[i]
start[i] = 0
else:
leftpad = 0
if start[i] + crop_size[i] > imgs.shape[i + 1]:
rightpad = start[i] + crop_size[i] - imgs.shape[i + 1]
else:
rightpad = 0

pad.append([leftpad, rightpad])

imgs = np.pad(imgs, pad, 'constant', constant_values=self.filling_value)
crop = imgs[:, start[0]:start[0] + crop_size[0], start[1]:start[1] + crop_size[1],
start[2]:start[2] + crop_size[2]]

normstart = np.array(start).astype('float32') / np.array(imgs.shape[1:]) - 0.5
normsize = np.array(crop_size).astype('float32') / np.array(imgs.shape[1:])
xx, yy, zz = np.meshgrid(np.linspace(normstart[0], normstart[0] + normsize[0], self.crop_size[0] / self.stride),
np.linspace(normstart[1], normstart[1] + normsize[1], self.crop_size[1] / self.stride),
np.linspace(normstart[2], normstart[2] + normsize[2], self.crop_size[2] / self.stride),
indexing='ij')
coord = np.concatenate([xx[np.newaxis, ...], yy[np.newaxis, ...], zz[np.newaxis, :]], 0).astype('float32')

return crop, coord


def predict(image_itk, nodule_list, model_path="src/algorithms/classify/assets/gtr123_model.ckpt"):
def predict(ct_path, nodule_list, model_path="src/algorithms/classify/assets/gtr123_model.ckpt"):
"""
Args:
image_itk: ITK dicom image
ct_path (str): path to a MetaImage or DICOM data.
nodule_list: List of nodules
model_path: Path to the torch model (Default value = "src/algorithms/classify/assets/gtr123_model.ckpt")
Expand All @@ -292,20 +245,14 @@ def predict(image_itk, nodule_list, model_path="src/algorithms/classify/assets/g
# else:
# casenet = torch.nn.parallel.DistributedDataParallel(casenet)

image = sitk.GetArrayFromImage(image_itk)
spacing = np.array(image_itk.GetSpacing())[::-1]
image = lum_trans(image)
image = resample(image, spacing, np.array([1, 1, 1]), order=1)[0]

crop = SimpleCrop()

preprocess = preprocess_ct.PreprocessCT(clip_lower=-1200., clip_upper=600., spacing=1., order=1,
min_max_normalize=True, scale=255, dtype='uint8')
ct_array, meta = preprocess(*load_ct.load_ct(ct_path))
patches = crop_patches.patches_from_ct(ct_array, meta, config['crop_size'], nodule_list,
stride=config['stride'], pad_value=config['filling_value'])
results = []
for nodule in nodule_list:
print(nodule)
nod_location = np.array([np.float32(nodule[s]) for s in ["z", "y", "x"]])
nod_location *= spacing
cropped_image, coords = crop(image[np.newaxis], nod_location)
cropped_image = Variable(torch.from_numpy(cropped_image[np.newaxis]).float())
for nodule, (cropped_image, coords) in zip(nodule_list, patches):
cropped_image = Variable(torch.from_numpy(cropped_image[np.newaxis, np.newaxis]).float())
cropped_image.volatile = True
coords = Variable(torch.from_numpy(coords[np.newaxis]).float())
coords.volatile = True
Expand Down
32 changes: 2 additions & 30 deletions prediction/src/algorithms/classify/trained_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,10 @@
for if nodules are concerning or not.
"""


import SimpleITK as sitk
from src.algorithms.classify.src import gtr123_model
from src.preprocess.load_ct import load_ct


def predict(dicom_path, centroids, model_path=None,
preprocess_ct=None, preprocess_model_input=None):
def predict(dicom_path, centroids, model_path=None):
""" Predicts if centroids are concerning or not.
Given path to a DICOM image and an iterator of centroids:
Expand All @@ -31,10 +27,6 @@ def predict(dicom_path, centroids, model_path=None,
'y': int,
'z': int}
model_path (str): A path to the serialized model
preprocess_ct (preprocess.preprocess_dicom.PreprocessDicom): A preprocess
method which aimed at brining the input data to the desired view.
preprocess_model_input (callable[ndarray, list[dict]]): preprocess for a model
input.
Returns:
list[dict]: a list of centroids with the probability they are
Expand All @@ -44,25 +36,5 @@ def predict(dicom_path, centroids, model_path=None,
'z': int,
'p_concerning': float}
"""
reader = sitk.ImageSeriesReader()
filenames = reader.GetGDCMSeriesFileNames(dicom_path)

if not filenames:
raise ValueError("The path doesn't contain neither .mhd nor .dcm files")

reader.SetFileNames(filenames)
image = reader.Execute()

if preprocess_ct:
voxel_data = preprocess_ct(load_ct(dicom_path))
else:
voxel_data = image

if preprocess_model_input:
preprocessed = preprocess_model_input(voxel_data, centroids)
else:
preprocessed = voxel_data

model_path = model_path or "src/algorithms/classify/assets/gtr123_model.ckpt"

return gtr123_model.predict(preprocessed, centroids, model_path)
return gtr123_model.predict(dicom_path, centroids, model_path)
35 changes: 18 additions & 17 deletions prediction/src/algorithms/identify/src/gtr123_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import numpy as np
import torch
from scipy.special import expit
from src.preprocess import preprocess_ct, load_ct
from src.preprocess.extract_lungs import extract_lungs
from torch import nn
from torch.autograd import Variable
from scipy.special import expit

import SimpleITK as sitk
import numpy as np

from ....preprocess.extract_lungs import extract_lungs
from ....preprocess.gtr123_preprocess import lum_trans, resample

""""
Detector model from team gtr123
Expand All @@ -34,11 +31,12 @@
config['r_rand_crop'] = 0.3
config['pad_value'] = 170

__all__ = ["Net", "lum_trans", "resample", "GetPBB", "SplitComb"]
__all__ = ["Net", "GetPBB", "SplitComb"]


class PostRes(nn.Module):
""" """

def __init__(self, n_in, n_out, stride=1):
super(PostRes, self).__init__()
self.conv1 = nn.Conv3d(n_in, n_out, kernel_size=3, stride=stride, padding=1)
Expand Down Expand Up @@ -181,6 +179,7 @@ def forward(self, x, coord):

class GetPBB(object):
""" """

def __init__(self, stride=4, anchors=(10.0, 30.0, 60.)):
self.stride = stride
self.anchors = np.asarray(anchors)
Expand Down Expand Up @@ -211,6 +210,7 @@ def __call__(self, output, thresh=-3, ismask=False):

class SplitComb(object):
""" """

def __init__(self, side_len, max_stride, stride, margin, pad_value):
self.side_len = side_len
self.max_stride = max_stride
Expand Down Expand Up @@ -446,7 +446,7 @@ def filter_lungs(image, spacing=(1, 1, 1), fill_value=170):
return extracted, mask


def predict(image_itk, model_path="src/algorithms/identify/assets/dsb2017_detector.ckpt"):
def predict(ct_path, model_path="src/algorithms/identify/assets/dsb2017_detector.ckpt"):
"""
Args:
Expand All @@ -458,10 +458,10 @@ def predict(image_itk, model_path="src/algorithms/identify/assets/dsb2017_detect
List of Nodule locations and probabilities
"""

spacing = np.array(image_itk.GetSpacing())[::-1]
image = sitk.GetArrayFromImage(image_itk)
masked_image, mask = filter_lungs(image)
ct_array, meta = load_ct.load_ct(ct_path)
meta = load_ct.MetaImage(meta)
spacing = np.array(meta.spacing)
masked_image, mask = filter_lungs(ct_array)
# masked_image = image
net = Net()
net.load_state_dict(torch.load(model_path)["state_dict"])
Expand All @@ -473,11 +473,12 @@ def predict(image_itk, model_path="src/algorithms/identify/assets/dsb2017_detect
# We have to use small batches until the next release of PyTorch, as bigger ones will segfault for CPU
# split_comber = SplitComb(side_len=int(32), margin=16, max_stride=16, stride=4, pad_value=170)
# Transform image to the 0-255 range and resample to 1x1x1mm
imgs = lum_trans(masked_image)
imgs = resample(imgs, spacing, np.array([1, 1, 1]), order=1)[0]
imgs = imgs[np.newaxis, ...]
preprocess = preprocess_ct.PreprocessCT(clip_lower=-1200., clip_upper=600., spacing=1., order=1,
min_max_normalize=True, scale=255, dtype='uint8')
ct_array, meta = preprocess(ct_array, meta)
ct_array = ct_array[np.newaxis, ...]

imgT, coords, nzhw = split_data(imgs, split_comber=split_comber)
imgT, coords, nzhw = split_data(ct_array, split_comber=split_comber)
results = []
# Loop over the image chunks
for img, coord in zip(imgT, coords):
Expand Down
54 changes: 40 additions & 14 deletions prediction/src/preprocess/crop_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import scipy.ndimage
from src.preprocess import load_ct


def mm2voxel(coord, origin=0., spacing=1.):
Expand All @@ -20,13 +21,11 @@ def mm2voxel(coord, origin=0., spacing=1.):
coord = np.array(coord)
origin = scipy.ndimage._ni_support._normalize_sequence(origin, len(coord))
spacing = scipy.ndimage._ni_support._normalize_sequence(spacing, len(coord))
coord = np.ceil(coord - np.array(origin)) / np.array(spacing)
print(coord)
print(origin)
return coord.astype(np.int).tolist()
coord = np.ceil((coord - np.array(origin)) / np.array(spacing))
return coord.astype(np.int)


def crop_patch(ct_array, meta, patch_shape=None, centroids=None):
def crop_patch(ct_array, meta, patch_shape=None, centroids=None, stride=None, pad_value=0):
""" Generator yield a patch of a desired shape for each centroid
from a given a CT scan.
Expand All @@ -39,27 +38,51 @@ def crop_patch(ct_array, meta, patch_shape=None, centroids=None):
'y': int,
'z': int}
meta (src.preprocess.load_ct.MetaData): meta information of the CT scan.
stride (int): stride for patch coordinates meshgrid.
If None is set (default), then no meshgrid will be returned.
pad_value (int): value with which an array padding will be performed.
Yields:
np.ndarray: a cropped patch from the CT scan.
np.ndarray: cropped patch from a CT scan.
np.ndarray | None: meshgrid of a patch.
"""
if centroids is None:
centroids = []
if patch_shape is None:
patch_shape = []

if not isinstance(meta, load_ct.MetaData):
meta = load_ct.MetaData(meta)

patch_shape = scipy.ndimage._ni_support._normalize_sequence(patch_shape, len(ct_array.shape))
padding = np.ceil(np.array(patch_shape) / 2.).astype(np.int)
patch_shape = np.array(patch_shape)
init_shape = np.array(ct_array.shape)
padding = np.ceil(patch_shape / 2.).astype(np.int)
padding = np.stack([padding, padding], axis=1)
ct_array = np.pad(ct_array, padding, mode='edge')
ct_array = np.pad(ct_array, padding, mode='constant', constant_values=pad_value)
for centroid in centroids:
centroid = mm2voxel([centroid[axis] for axis in 'zyx'], meta.origin, meta.spacing)
yield ct_array[centroid[0]: centroid[0] + patch_shape[0],
centroid[1]: centroid[1] + patch_shape[1],
centroid[2]: centroid[2] + patch_shape[2]]

patch = ct_array[centroid[0]: centroid[0] + patch_shape[0],
centroid[1]: centroid[1] + patch_shape[1],
centroid[2]: centroid[2] + patch_shape[2]]

if stride:
init_shape += np.clip(patch_shape // 2 - centroid, 0, np.inf).astype(np.int64)
init_shape += np.clip(centroid + patch_shape // 2 - init_shape, 0, np.inf).astype(np.int64)

normstart = (np.array(centroid) - patch_shape / 2) / init_shape - 0.5
normsize = patch_shape / init_shape
xx, yy, zz = np.meshgrid(np.linspace(normstart[0], normstart[0] + normsize[0], patch_shape[0] // stride),
np.linspace(normstart[1], normstart[1] + normsize[1], patch_shape[1] // stride),
np.linspace(normstart[2], normstart[2] + normsize[2], patch_shape[2] // stride),
indexing='ij')
coord = np.concatenate([xx[np.newaxis, ...], yy[np.newaxis, ...], zz[np.newaxis, :]], 0).astype('float32')
yield patch, coord

yield patch


def patches_from_ct(ct_array, meta, patch_shape=None, centroids=None):
def patches_from_ct(ct_array, meta, patch_shape=None, centroids=None, stride=None, pad_value=0):
""" Given a CT scan, and a list of centroids return the list of patches
of the desired patch shape.
Expand All @@ -74,6 +97,9 @@ def patches_from_ct(ct_array, meta, patch_shape=None, centroids=None):
'y': int,
'z': int}
meta (src.preprocess.load_ct.MetaData): meta information of the CT scan.
stride (int): stride for patches' coordinates meshgrids.
If None is set (default), then no meshgrid will be returned.
pad_value (int): value with which an array padding will be performed.
Yields:
np.ndarray: a cropped patch from the CT scan.
Expand All @@ -82,6 +108,6 @@ def patches_from_ct(ct_array, meta, patch_shape=None, centroids=None):
patch_shape = []
if centroids is None:
centroids = []
patch_generator = crop_patch(ct_array, meta, patch_shape, centroids)
patch_generator = crop_patch(ct_array, meta, patch_shape, centroids, stride, pad_value)
patches = itertools.islice(patch_generator, len(centroids))
return list(patches)
Loading

0 comments on commit 389824f

Please sign in to comment.