From e42b9edd351b3799d1b23c871969f4823753d561 Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Wed, 19 Jun 2024 11:32:33 +0200 Subject: [PATCH 1/3] fix #2300; dataloader did not stack segmentations properly if deep supervision was disabled --- nnunetv2/training/dataloading/data_loader_2d.py | 5 ++++- nnunetv2/training/dataloading/data_loader_3d.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/nnunetv2/training/dataloading/data_loader_2d.py b/nnunetv2/training/dataloading/data_loader_2d.py index 655a7aae5..08bfad87a 100644 --- a/nnunetv2/training/dataloading/data_loader_2d.py +++ b/nnunetv2/training/dataloading/data_loader_2d.py @@ -101,7 +101,10 @@ def generate_train_batch(self): images.append(tmp['image']) segs.append(tmp['segmentation']) data_all = torch.stack(images) - seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))] + if isinstance(segs[0], list): + seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))] + else: + seg_all = torch.stack(segs) del segs, images if torch is not None: torch.set_num_threads(torch_nthreads) diff --git a/nnunetv2/training/dataloading/data_loader_3d.py b/nnunetv2/training/dataloading/data_loader_3d.py index 3131e1f09..d17928475 100644 --- a/nnunetv2/training/dataloading/data_loader_3d.py +++ b/nnunetv2/training/dataloading/data_loader_3d.py @@ -64,7 +64,10 @@ def generate_train_batch(self): images.append(tmp['image']) segs.append(tmp['segmentation']) data_all = torch.stack(images) - seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))] + if isinstance(segs[0], list): + seg_all = [torch.stack([s[i] for s in segs]) for i in range(len(segs[0]))] + else: + seg_all = torch.stack(segs) del segs, images if torch is not None: torch.set_num_threads(torch_nthreads) From ec229b53c206ecf79c786e4b838e3c511462c76e Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Wed, 19 Jun 2024 11:46:21 +0200 Subject: [PATCH 2/3] add trainers with 500 and 750 epochs --- .../training_length/nnUNetTrainer_Xepochs.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py b/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py index e3a71a000..9d4867003 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py +++ b/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py @@ -55,6 +55,20 @@ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dic self.num_epochs = 250 +class nnUNetTrainer_500epochs(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 500 + + +class nnUNetTrainer_750epochs(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 750 + + class nnUNetTrainer_2000epochs(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, device: torch.device = torch.device('cuda')): From ed88855f5fa03142d989e90c5ae1588b89858d00 Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Wed, 19 Jun 2024 14:00:30 +0200 Subject: [PATCH 3/3] use np.allclose as suggested in #2282 --- nnunetv2/imageio/base_reader_writer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nnunetv2/imageio/base_reader_writer.py b/nnunetv2/imageio/base_reader_writer.py index 2847478ae..4ca536e5d 100644 --- a/nnunetv2/imageio/base_reader_writer.py +++ b/nnunetv2/imageio/base_reader_writer.py @@ -21,11 +21,11 @@ class BaseReaderWriter(ABC): @staticmethod def _check_all_same(input_list): - # compare all entries to the first - for i in input_list[1:]: - if i != input_list[0]: - return False - return True + if len(input_list) == 1: + return True + else: + # compare all entries to the first + return np.allclose(input_list[0], input_list[1:]) @staticmethod def _check_all_same_array(input_list):