Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/sg 858 ignore multiple labels segmentation metrics support #1177

Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
type checking for ignore index expanded to iterable
shaydeci committed Jun 19, 2023
commit 59620c8d909660766ccf813bc930e3d4b5dd9a44
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import typing

import numpy as np
import torch
import torchmetrics
@@ -230,7 +232,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
self.total_label += pixel_labeled

def _handle_multiple_ignored_inds(self, target):
if isinstance(self.ignore_label, list):
if isinstance(self.ignore_label, typing.Iterable):
evaluated_classes_mask = torch.ones_like(target)
for ignored_label in self.ignore_label:
evaluated_classes_mask = evaluated_classes_mask.masked_fill(target.eq(ignored_label), 0)
@@ -264,7 +266,7 @@ def _handle_multiple_ignored_inds(ignore_index: Union[int, List[int]], num_class
:param num_classes: int, num_classes (original, before mapping) being passed to segmentation metric classesץ
:return:ignore_index, ignore_index_list, num_classes, unfiltered_num_classesignore_index, ignore_index_list, num_classes, unfiltered_num_classes
"""
if isinstance(ignore_index, list):
if isinstance(ignore_index, typing.Iterable):
ignore_index_list = ignore_index
unfiltered_num_classes = num_classes
num_classes = num_classes - len(ignore_index_list) + 1
@@ -310,7 +312,7 @@ def __init__(

if num_classes <= 1:
raise ValueError(f"IoU class only for multi-class usage! For binary usage, please call {BinaryIOU.__name__}")
if isinstance(ignore_index, list) and reduction == "none":
if isinstance(ignore_index, typing.Iterable) and reduction == "none":
raise ValueError("passing multiple ignore indices ")
ignore_index, ignore_index_list, num_classes, unfiltered_num_classes = _handle_multiple_ignored_inds(ignore_index, num_classes)