Fix quality issues #33

Merged
merged 6 commits into from Oct 6, 2022
45 changes: 35 additions & 10 deletions cryocare/internals/CryoCAREDataModule.py
@@ -11,10 +11,19 @@
class CryoCARE_Dataset(tf.keras.utils.Sequence):
def __init__(self, tomo_paths_odd=None, tomo_paths_even=None, n_samples_per_tomo=None,
extraction_shapes=None, mean=None, std=None,
sample_shape=(64, 64, 64), shuffle=True, n_normalization_samples=500):
sample_shape=(64, 64, 64), shuffle=True, n_normalization_samples=500, tilt_axis=None):
self.tomo_paths_odd = tomo_paths_odd
self.tomo_paths_even = tomo_paths_even
self.n_samples_per_tomo = n_samples_per_tomo
self.tilt_axis = tilt_axis

if self.tilt_axis is not None:
tilt_axis_index = ["Z", "Y", "X"].index(self.tilt_axis)
rot_axes = [0, 1, 2]
rot_axes.remove(tilt_axis_index)
self.rot_axes = tuple(rot_axes)
else:
self.rot_axes = None

self.extraction_shapes = extraction_shapes
self.mean = mean
@@ -49,11 +58,12 @@ def save(self, path):
extraction_shapes=self.extraction_shapes,
sample_shape=self.sample_shape,
shuffle=self.shuffle,
coords=self.coords)
coords=self.coords,
tilt_axis=self.tilt_axis)

@classmethod
def load(cls, path):
tmp = np.load(path)
tmp = np.load(path, allow_pickle=True)
tomo_paths_odd = [str(p) for p in tmp['tomo_paths_odd']]
tomo_paths_even = [str(p) for p in tmp['tomo_paths_even']]
mean = tmp['mean']
@@ -63,6 +73,10 @@ def load(cls, path):
sample_shape = tmp['sample_shape']
shuffle = tmp['shuffle']
coords = tmp['coords']
if isinstance(tmp['tilt_axis'], np.ndarray):
tilt_axis = None
else:
tilt_axis = tmp['tilt_axis']

ds = cls(tomo_paths_odd=tomo_paths_odd,
tomo_paths_even=tomo_paths_even,
@@ -71,7 +85,8 @@
n_samples_per_tomo=n_samples_per_tomo,
extraction_shapes=extraction_shapes,
sample_shape=sample_shape,
shuffle=shuffle)
shuffle=shuffle,
tilt_axis=tilt_axis)
ds.coords = coords
return ds

@@ -120,7 +135,16 @@ def create_random_coords(self, z, y, x, n_samples):

return np.stack([z_coords, y_coords, x_coords], -1)

def random_swapper(self, x, y):
def augment(self, x, y):
if self.tilt_axis is not None:
if self.sample_shape[0] == self.sample_shape[1] and \
self.sample_shape[0] == self.sample_shape[2]:
rot_k = np.random.randint(0, 4, 1)

x[...,0] = np.rot90(x[...,0], k=rot_k, axes=self.rot_axes)
y[...,0] = np.rot90(y[...,0], k=rot_k, axes=self.rot_axes)


if np.random.rand() > 0.5:
return y, x
else:
@@ -140,8 +164,7 @@ def __getitem__(self, idx):
odd_subvolume = self.tomos_odd[tomo_index].data[z:z + self.sample_shape[0],
y:y + self.sample_shape[1],
x:x + self.sample_shape[2]]

return self.random_swapper(np.array(even_subvolume)[..., np.newaxis], np.array(odd_subvolume)[..., np.newaxis])
return self.augment(np.array(even_subvolume)[..., np.newaxis], np.array(odd_subvolume)[..., np.newaxis])

def __iter__(self):
for idx in self.indices:
@@ -182,7 +205,8 @@ def setup(self, tomo_paths_odd, tomo_paths_even, n_samples_per_tomo, validation_
n_samples_per_tomo * (1 - validation_fraction)),
extraction_shapes=train_extraction_shapes,
sample_shape=sample_shape,
shuffle=True, n_normalization_samples=n_normalization_samples)
shuffle=True, n_normalization_samples=n_normalization_samples,
tilt_axis=tilt_axis)

self.val_dataset = CryoCARE_Dataset(tomo_paths_odd=tomo_paths_odd,
tomo_paths_even=tomo_paths_even,
@@ -191,7 +215,8 @@
n_samples_per_tomo=int(n_samples_per_tomo * validation_fraction),
extraction_shapes=val_extraction_shapes,
sample_shape=sample_shape,
shuffle=False)
shuffle=False,
tilt_axis=None)

def save(self, path):
self.train_dataset.save(join(path, 'train_data.npz'))
@@ -211,7 +236,7 @@ def __compute_extraction_shapes__(self, even_path, odd_path, tilt_axis_index, sa
assert even.data.shape[1] > 2 * sample_shape[1]
assert even.data.shape[2] > 2 * sample_shape[2]

val_cut_off = int(even.data.shape[tilt_axis_index] * validation_fraction)
val_cut_off = int(even.data.shape[tilt_axis_index] * (1 - validation_fraction))
Collaborator

I don't think that this is correct.

In cryoCARE_extract_train the validation fraction is already defined as:

validation_fraction=(1.0 - config['split'])

See: https://github.com/juglab/cryoCARE_pip/blob/master/cryocare/scripts/cryoCARE_extract_train_data.py#L28

Given a split of 0.9: with your change, 90% of the data would be validation data?

Collaborator Author

Here

extraction_shape_train[tilt_axis_index] = [0, val_cut_off]
the training data then goes from 0 up to val_cut_off. Maybe I am confusing it :/
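
For concreteness, a rough sketch of the arithmetic with this change applied (illustrative numbers only; split = 0.9 and 200 slices along the tilt axis are assumptions, not values from any config):

split = 0.9
validation_fraction = 1.0 - split                         # 0.1, as defined in cryoCARE_extract_train_data.py
n_slices = 200                                            # assumed extent along the tilt axis
val_cut_off = int(n_slices * (1 - validation_fraction))   # 180
train_range = [0, val_cut_off]                            # slices 0..180 -> 90% of the data for training
val_range = [val_cut_off, n_slices]                       # slices 180..200 -> 10% for validation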

Collaborator

You are totally right :-)

Collaborator

Nice catch! Please merge it.

Collaborator Author

I didn't manage to run it yet. Secretly hoping that @rdrighetto finds the time 😇

Contributor
@rdrighetto Oct 5, 2022

OK, I tested it and now I get this error:

2022-10-05 15:28:10.372969: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
  0%|          | 0/500 [00:04<?, ?it/s]
Traceback (most recent call last):
  File "/scicore/home/engel0006/GROUP/pool-engel/soft/cryo-care/cryoCARE_pip/cryocare_11/bin/cryoCARE_extract_train_data.py", line 45, in <module>
    main()
  File "/scicore/home/engel0006/GROUP/pool-engel/soft/cryo-care/cryoCARE_pip/cryocare_11/bin/cryoCARE_extract_train_data.py", line 27, in main
    dm.setup(config['odd'], config['even'], n_samples_per_tomo=config['num_slices'],
  File "/scicore/home/engel0006/GROUP/pool-engel/soft/cryo-care/cryoCARE_pip/cryocare_11/lib/python3.8/site-packages/cryocare/internals/CryoCAREDataModule.py", line 194, in setup
    self.train_dataset = CryoCARE_Dataset(tomo_paths_odd=tomo_paths_odd,
  File "/scicore/home/engel0006/GROUP/pool-engel/soft/cryo-care/cryoCARE_pip/cryocare_11/lib/python3.8/site-packages/cryocare/internals/CryoCAREDataModule.py", line 46, in __init__
    self.compute_mean_std(n_samples=n_normalization_samples)
  File "/scicore/home/engel0006/GROUP/pool-engel/soft/cryo-care/cryoCARE_pip/cryocare_11/lib/python3.8/site-packages/cryocare/internals/CryoCAREDataModule.py", line 91, in compute_mean_std
    x, _ = self.__getitem__(i)
  File "/scicore/home/engel0006/GROUP/pool-engel/soft/cryo-care/cryoCARE_pip/cryocare_11/lib/python3.8/site-packages/cryocare/internals/CryoCAREDataModule.py", line 161, in __getitem__
    return self.augment(np.array(even_subvolume)[..., np.newaxis], np.array(odd_subvolume)[..., np.newaxis])
  File "/scicore/home/engel0006/GROUP/pool-engel/soft/cryo-care/cryoCARE_pip/cryocare_11/lib/python3.8/site-packages/cryocare/internals/CryoCAREDataModule.py", line 137, in augment
    x[i] = np.rot90(x[i], k=rot_k[i], axes=self.rot_axes)
ValueError: could not broadcast input array from shape (1,72,72) into shape (72,72,1)

I tried to fix it by myself without success ☹️. What I did notice is the following:

  1. The problem occurs because x and y have shape (72, 72, 72, 1) when augment() is called
  2. Therefore, when k=rot_k[i] is 1 or 3 in np.rot90 (i.e. a rotation of 90 or 270 degrees), the resulting array has shape (1,72,72), which NumPy then tries to put back into an array whose original shape is (72,72,1) (i.e. x[i])

As I said, I tried to fix this by getting rid of this 4th dimension within augment() and placing it back when returning from that function, but then the code would break somewhere else, so I decided it's better to stop and call the experts 😅
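
For reference, a minimal numpy sketch of that shape mismatch (the shapes and rot_axes=(0, 2) are illustrative assumptions, not taken from the module):

import numpy as np

x = np.zeros((72, 72, 72, 1))                 # sub-volume with a trailing channel axis

# As in the traceback: x[i] still carries the channel axis, shape (72, 72, 1),
# so a 90-degree rotation over two unequal axes changes its shape.
rotated = np.rot90(x[0], k=1, axes=(0, 2))
print(rotated.shape)                          # (1, 72, 72) -> cannot be assigned back into x[0]

# Rotating only the spatial block, as the current diff does via x[..., 0],
# keeps the shape intact because all three sample dimensions are equal.
x[..., 0] = np.rot90(x[..., 0], k=1, axes=(0, 2))
print(x.shape)                                # still (72, 72, 72, 1)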

Thanks again!

Collaborator Author

Thanks! I pushed a hot-fix.

I promise, if it does not work this time I will set up everything on my end and stop coding in the GitHub IDE!

if ((even.data.shape[tilt_axis_index] - val_cut_off) < sample_shape[tilt_axis_index]) or val_cut_off < sample_shape[tilt_axis_index]:
val_cut_off = even.data.shape[tilt_axis_index] - sample_shape[tilt_axis_index] - 1
