From 9a954f10a75a748627454f54c71c641568ce5b6f Mon Sep 17 00:00:00 2001 From: kinderst Date: Sat, 1 Jun 2024 15:24:29 -0600 Subject: [PATCH] fixing for inference --- src/segmentation/generic/utils/utils_train.py | 2 +- .../data_modules/datasets_mask2former.py | 26 ++++++++++++------- .../mask2former/models/models_mask2former.py | 4 +-- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/segmentation/generic/utils/utils_train.py b/src/segmentation/generic/utils/utils_train.py index b48b26e..136df4e 100644 --- a/src/segmentation/generic/utils/utils_train.py +++ b/src/segmentation/generic/utils/utils_train.py @@ -44,7 +44,7 @@ def get_dataloader_from_csv(model_arch, csv_path, csv_img_path_col, csv_label_pa data_df = pd.read_csv(csv_path) image_paths = data_df[csv_img_path_col].tolist() - mask_paths = data_df[csv_label_path_col].tolist() + mask_paths = data_df[csv_label_path_col].tolist() if csv_label_path_col is not None else None data_set: Dataset[Tuple[torch.Tensor, torch.Tensor, np.ndarray[Any, Any], np.ndarray[Any, Any]]] if model_arch == 'mask2former': diff --git a/src/segmentation/mask2former/data_modules/datasets_mask2former.py b/src/segmentation/mask2former/data_modules/datasets_mask2former.py index 9a897cd..f9a5fcb 100644 --- a/src/segmentation/mask2former/data_modules/datasets_mask2former.py +++ b/src/segmentation/mask2former/data_modules/datasets_mask2former.py @@ -17,16 +17,24 @@ def __len__(self): return len(self.images) def __getitem__(self, idx): - # Get paths + # Get Image path_to_img = os.path.join(self.image_root_dir, self.images[idx]) - path_to_mask = os.path.join(self.mask_root_dir, self.masks[idx]) - - # Open image original_image = np.array(Image.open(path_to_img)) - original_segmentation_map = np.array(Image.open(path_to_mask).convert('L')) - # TODO: Fix hardcode - original_segmentation_map[original_segmentation_map == 0] = 1 - original_segmentation_map[original_segmentation_map == 255] = 2 + + original_image_width = original_image.shape[0] + original_image_height = original_image.shape[1] + + # Get mask if available. Otherwise, mask can be just all 0's + if self.masks is None or self.label_root_dir is None: + # Create a new array with the same shape, filled with ones (background) + original_segmentation_map = np.ones((original_image_width, original_image_height), dtype=np.uint8) + original_segmentation_map[0, 0] = 2 + else: + path_to_mask = os.path.join(self.mask_root_dir, self.masks[idx]) + original_segmentation_map = np.array(Image.open(path_to_mask).convert('L')) + # TODO: Fix hardcode + original_segmentation_map[original_segmentation_map == 0] = 1 + original_segmentation_map[original_segmentation_map == 255] = 2 # Get transformed image transformed = self.transform(image=original_image, mask=original_segmentation_map) @@ -46,4 +54,4 @@ def __getitem__(self, idx): mask_labels = batched['mask_labels'][0] # (2, 512, 512) class_labels = batched['class_labels'][0] # (2,) - return pixel_values, mask_labels, class_labels + return pixel_values, mask_labels, class_labels, (original_image_width, original_image_height), self.images[idx] diff --git a/src/segmentation/mask2former/models/models_mask2former.py b/src/segmentation/mask2former/models/models_mask2former.py index 91a26ea..5a19565 100644 --- a/src/segmentation/mask2former/models/models_mask2former.py +++ b/src/segmentation/mask2former/models/models_mask2former.py @@ -36,7 +36,7 @@ def __init__(self, configs, num_classes, preprocessor): # will be used during inference def forward(self, x): - pixel_values, mask_labels, class_labels = x + pixel_values, mask_labels, class_labels, original_sizes, filenames = x output = self.model( pixel_values=pixel_values, mask_labels=mask_labels, @@ -45,7 +45,7 @@ def forward(self, x): return output def common_step(self, batch, batch_idx): - pixel_values, mask_labels, class_labels = batch + pixel_values, mask_labels, class_labels, original_sizes, filenames = batch # Forward pass outputs = self(batch)