Skip to content

Commit

Permalink
fixing for inference
Browse files Browse the repository at this point in the history
  • Loading branch information
kinderst committed Jun 1, 2024
1 parent a863e7f commit 9a954f1
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/segmentation/generic/utils/utils_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_dataloader_from_csv(model_arch, csv_path, csv_img_path_col, csv_label_pa
data_df = pd.read_csv(csv_path)

image_paths = data_df[csv_img_path_col].tolist()
mask_paths = data_df[csv_label_path_col].tolist()
mask_paths = data_df[csv_label_path_col].tolist() if csv_label_path_col is not None else None

data_set: Dataset[Tuple[torch.Tensor, torch.Tensor, np.ndarray[Any, Any], np.ndarray[Any, Any]]]
if model_arch == 'mask2former':
Expand Down
26 changes: 17 additions & 9 deletions src/segmentation/mask2former/data_modules/datasets_mask2former.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,24 @@ def __len__(self):
return len(self.images)

def __getitem__(self, idx):
# Get paths
# Get Image
path_to_img = os.path.join(self.image_root_dir, self.images[idx])
path_to_mask = os.path.join(self.mask_root_dir, self.masks[idx])

# Open image
original_image = np.array(Image.open(path_to_img))
original_segmentation_map = np.array(Image.open(path_to_mask).convert('L'))
# TODO: Fix hardcode
original_segmentation_map[original_segmentation_map == 0] = 1
original_segmentation_map[original_segmentation_map == 255] = 2

original_image_width = original_image.shape[0]
original_image_height = original_image.shape[1]

# Get mask if available. Otherwise, mask can be just all 0's
if self.masks is None or self.label_root_dir is None:
# Create a new array with the same shape, filled with ones (background)
original_segmentation_map = np.ones((original_image_width, original_image_height), dtype=np.uint8)
original_segmentation_map[0, 0] = 2
else:
path_to_mask = os.path.join(self.mask_root_dir, self.masks[idx])
original_segmentation_map = np.array(Image.open(path_to_mask).convert('L'))
# TODO: Fix hardcode
original_segmentation_map[original_segmentation_map == 0] = 1
original_segmentation_map[original_segmentation_map == 255] = 2

# Get transformed image
transformed = self.transform(image=original_image, mask=original_segmentation_map)
Expand All @@ -46,4 +54,4 @@ def __getitem__(self, idx):
mask_labels = batched['mask_labels'][0] # (2, 512, 512)
class_labels = batched['class_labels'][0] # (2,)

return pixel_values, mask_labels, class_labels
return pixel_values, mask_labels, class_labels, (original_image_width, original_image_height), self.images[idx]
4 changes: 2 additions & 2 deletions src/segmentation/mask2former/models/models_mask2former.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, configs, num_classes, preprocessor):

# will be used during inference
def forward(self, x):
pixel_values, mask_labels, class_labels = x
pixel_values, mask_labels, class_labels, original_sizes, filenames = x
output = self.model(
pixel_values=pixel_values,
mask_labels=mask_labels,
Expand All @@ -45,7 +45,7 @@ def forward(self, x):
return output

def common_step(self, batch, batch_idx):
pixel_values, mask_labels, class_labels = batch
pixel_values, mask_labels, class_labels, original_sizes, filenames = batch

# Forward pass
outputs = self(batch)
Expand Down

0 comments on commit 9a954f1

Please sign in to comment.