Skip to content

Commit

Permalink
Adds data augmentation as part of kfold process
Browse files Browse the repository at this point in the history
  • Loading branch information
jac241 committed Jul 19, 2019
1 parent cb3ae97 commit 74e9a26
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 19 deletions.
49 changes: 49 additions & 0 deletions augment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from argparse import ArgumentParser
from collections import namedtuple
from pathlib import Path

import tables
from Augmentor import DataPipeline

Samples = namedtuple('Samples', 'imdata truth')


def parse_args():
parser = ArgumentParser()

parser.add_argument('input_file_path', help='.h5 file path to augment, output will be same place with _aug added')

return parser.parse_args()


def main():
args = parse_args()
data = tables.open_file(args.input_file_path, mode='r')
samples = augment_images(xs=data.root.imdata, ys=data.root.truth)

print(save_samples(samples))


def augment_images(xs, ys):
pipeline = DataPipeline(xs, ys)

pipeline.random_distortion(probability=0.75, grid_height=16, grid_width=16, magnitude=4)
pipeline.rotate(probability=0.9, max_left_rotation=15, max_right_rotation=15)
pipeline.flip_top_bottom(probability=0.05)
pipeline.flip_left_right(probability=0.5)
pipeline.zoom(0.75, min_factor=0, max_factor=0.1)
pipeline.random_brightness(probability=0.9, min_factor=0.9, max_factor=1.1)

return pipeline.sample(len(xs) * 10)


def save_samples(samples, input_file_path):
p = Path(input_file_path)
output_file_path = Path(p.parent, f'{p.stem}_aug.h5')
# tables.open_file(str(output_file_path), 'w')
return output_file_path



if __name__ == '__main__':
main()
8 changes: 4 additions & 4 deletions config/run/combined_remote_dice.args
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
--training_model_name=combined_new_thresholding_fold_4.h5
--data_file=./datasets/combined_new_thresholding.h5
--training_model_name=combined_aug_fold_0.h5
--data_file=./datasets/combined_new_thresholding_aug.h5
--data_split=0.9
--training_split=datasets/combined_new_thresholding_kfold/fold_4_train.pkl
--validation_split=datasets/combined_new_thresholding_kfold/fold_4_val.pkl
--training_split=datasets/combined_new_thresholding_kfold_aug/fold_0_train.pkl
--validation_split=datasets/combined_new_thresholding_kfold/fold_0_val.pkl
--n_epochs=100
--image_masks=Muscle
--problem_type=Segmentation
Expand Down
10 changes: 10 additions & 0 deletions config/run/combined_remote_dice_aug.args
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
--training_model_name=combined_aug_fold_0.h5
--data_file=./datasets/combined_new_thresholding_aug.h5
--data_split=0.9
--training_split=datasets/combined_new_thresholding_kfold_aug/fold_0_train.pkl
--validation_split=datasets/combined_new_thresholding_kfold/fold_0_val.pkl
--n_epochs=100
--image_masks=Muscle
--problem_type=Segmentation
--GPU=1
--batch_size=16
159 changes: 148 additions & 11 deletions kfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,40 +5,53 @@

import numpy as np
import tables
from Augmentor import DataPipeline
from keras_preprocessing.image import ImageDataGenerator
from sklearn.model_selection import KFold, train_test_split

from unet3d.utils import pickle_dump

ThreewayKFold = namedtuple('ThreewayKFold', 'train_indices val_indices test_indices')
ThreewayKFold = namedtuple('ThreewayKFold', 'train_indices val_indices test_indices xs ys subject_ids')


def main():
args = parse_args()

with tables.open_file(args.input_file) as data_file:
kfold_indices = get_kfold_indices(
xs=np.array(data_file.root.imdata),
ys=np.array(data_file.root.truth)
)
xs = np.array(data_file.root.imdata)
ys = np.array(data_file.root.truth)
subject_ids = np.array(data_file.root.subject_ids)

kfold = augment_if_desired(
kfold=get_kfold(xs, ys, subject_ids),
should_augment=args.augment,
samples_per_image=args.samples_per_image
)

if args.augment:
save_augmented_dataset(input_file_path=args.input_file, kfold=kfold)

kfold_directory = create_output_directory(
input_file_path=Path(args.input_file),
containing_dir_path=Path(args.output_dir)
containing_dir_path=Path(args.output_dir),
did_augment=args.augment
)

save_kfold_indices(kfold_directory, kfold_indices)
save_kfold_indices(kfold_directory, kfold)


def parse_args():
parser = ArgumentParser()

parser.add_argument('input_file', help='Input .h5 file')
parser.add_argument('output_dir', help='Directory where dir of kfold files will be created')
parser.add_argument('--augment', action='store_true')
parser.add_argument('--samples_per_image', default=10, type=int, help='Number of samples of each image to take with augmentation')

return parser.parse_args()


def get_kfold_indices(xs, ys):
def get_kfold(xs, ys, subject_ids):
sets_of_training_train_indices = []
sets_of_training_val_indices = []
sets_of_test_indices = []
Expand All @@ -51,11 +64,135 @@ def get_kfold_indices(xs, ys):
sets_of_training_val_indices.append(train_val_indices)
sets_of_test_indices.append(test_indices)

return ThreewayKFold(sets_of_training_train_indices, sets_of_training_val_indices, sets_of_test_indices)
return ThreewayKFold(
sets_of_training_train_indices,
sets_of_training_val_indices,
sets_of_test_indices,
xs,
ys,
subject_ids
)


def augment_if_desired(kfold, should_augment=False, **augmentation_options):
if should_augment:
return augment_data(kfold, **augmentation_options)
else:
return kfold


def augment_data(kfold_indices, samples_per_image):
new_xs = []
new_ys = []
new_subject_ids=[]

index = len(kfold_indices.xs) # appending new images to end of current list
new_train_indices = []
for fold_number in range(len(kfold_indices.train_indices)):
print('augmenting fold: ', fold_number)
subset_of_xs = kfold_indices.xs[kfold_indices.train_indices[fold_number]]
subset_of_ys = kfold_indices.ys[kfold_indices.train_indices[fold_number]]

x_samples, y_samples = get_augmented_samples(subset_of_xs, subset_of_ys, samples_per_image)

current_fold_train_indices = []

new_xs.extend(x_samples)
new_ys.extend(y_samples)

for _ in range(len(x_samples)):
current_fold_train_indices.append(index)
new_subject_ids.append(f'a{index}')
index = index + 1

new_train_indices.append(current_fold_train_indices)

return ThreewayKFold(
train_indices=new_train_indices,
val_indices=kfold_indices.val_indices,
test_indices=kfold_indices.test_indices,
xs=np.concatenate([kfold_indices.xs, new_xs]),
ys=np.concatenate([kfold_indices.ys, new_ys]),
subject_ids=np.concatenate([kfold_indices.subject_ids, new_subject_ids])
)


def get_augmented_samples(subset_of_xs, subset_of_ys, samples_per_image):
# pipeline = get_augmentation_pipeline(subset_of_xs, subset_of_ys)
# # samples = pipeline.sample(len(subset_of_xs) * samples_per_image)
# samples = pipeline.sample(10)
# return samples
datagen_args = dict(
horizontal_flip=True,
vertical_flip=True,
zoom_range=0.1,
rotation_range=20,
shear_range=5,
)

seed = 1
image_datagen = ImageDataGenerator(**datagen_args)
mask_datagen = ImageDataGenerator(**datagen_args)

x_samples = []
y_samples = []

image_datagen.fit(subset_of_xs, augment=True, seed=seed)
mask_datagen.fit(subset_of_ys, augment=True, seed=seed)

image_generator = image_datagen.flow(subset_of_xs, seed=seed)
mask_generator = image_datagen.flow(subset_of_ys, seed=seed)

for batch_number, (x_batch, y_batch) in enumerate(zip(image_generator, mask_generator)):
print('Batch: ', batch_number)

x_samples.extend(x_batch)
y_samples.extend(y_batch)

if len(x_samples) > len(subset_of_xs) * 10:
break

return np.array(x_samples), make_mask_boolean(np.array(y_samples))


def make_mask_boolean(mask):
mask[mask > 0] = 1
return mask


def get_augmentation_pipeline(xs, ys):
im_shape = (256, 256)

pipeline = DataPipeline([
[x, y]
for x, y
in zip(xs.reshape(len(xs), *im_shape), ys.reshape(len(ys), *im_shape))
])

# pipeline.random_distortion(probability=0.75, grid_height=4, grid_width=4, magnitude=4)
pipeline.rotate(probability=0.9, max_left_rotation=15, max_right_rotation=15)
pipeline.flip_top_bottom(probability=0.05)
pipeline.flip_left_right(probability=0.5)
pipeline.zoom(0.75, min_factor=1.0, max_factor=1.1)
# pipeline.random_brightness(probability=0.9, min_factor=0.9, max_factor=1.1)

return pipeline


def save_augmented_dataset(input_file_path, kfold):
with tables.open_file(_get_augmented_file_path(input_file_path), 'w') as hd5:
hd5.create_array(hd5.root, 'imdata', kfold.xs)
hd5.create_array(hd5.root, 'truth', kfold.ys)
hd5.create_array(hd5.root, 'subject_ids', kfold.subject_ids)


def _get_augmented_file_path(input_file_path):
old_path = Path(input_file_path)
return str(Path(old_path.parent, f'{old_path.stem}_aug.h5'))


def create_output_directory(input_file_path, containing_dir_path):
new_dir_name = f'{input_file_path.stem}_kfold'
def create_output_directory(input_file_path, containing_dir_path, did_augment):
new_dir_name = f'{input_file_path.stem}_kfold_aug' if did_augment else f'{input_file_path.stem}_kfold'
new_output_directory = Path(containing_dir_path, new_dir_name)
new_output_directory.mkdir(exist_ok=True)

Expand Down
8 changes: 4 additions & 4 deletions tests/kfold_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from unittest import TestCase

from kfold import get_kfold_indices
from kfold import get_kfold


class TestKfold(TestCase):
Expand All @@ -11,14 +11,14 @@ def setUp(self) -> None:
self.y = np.arange(200, 210)

def test_get_kfold_indices_should_split_data_into_3_sets_of_5_folds_by_default(self):
xs, vals, ys = get_kfold_indices(self.x, self.y)
xs, vals, ys = get_kfold(self.x, self.y)

self.assertEqual(len(xs), 5)
self.assertEqual(len(vals), 5)
self.assertEqual(len(ys), 5)

def test_get_kfold_indices_should_not_overlap(self):
kfold_indices = get_kfold_indices(self.x, self.y)
kfold_indices = get_kfold(self.x, self.y)

for train, val, test in zip(*kfold_indices):
self.assertFalse(any(i in val for i in train))
Expand All @@ -31,4 +31,4 @@ def test_get_kfold_indices_should_not_overlap(self):
self.assertFalse(any(i in val for i in test))

def test_get_kfold_indices_should_error_if_input_arrays_lengths_not_equal(self):
self.assertRaises(ValueError, get_kfold_indices, xs=[1], ys=[2, 3])
self.assertRaises(ValueError, get_kfold, xs=[1], ys=[2, 3])

0 comments on commit 74e9a26

Please sign in to comment.