Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix generalized dice computation #7970

Merged
merged 18 commits into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 77 additions & 48 deletions monai/metrics/generalized_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,47 @@
import torch

from monai.metrics.utils import do_metric_reduction, ignore_background
from monai.utils import MetricReduction, Weight, look_up_option
from monai.utils import MetricReduction, Weight, deprecated_arg, deprecated_arg_default, look_up_option

from .metric import CumulativeIterationMetric


class GeneralizedDiceScore(CumulativeIterationMetric):
"""Compute the Generalized Dice Score metric between tensors, as the complement of the Generalized Dice Loss defined in:
"""
Compute the Generalized Dice Score metric between tensors.

This metric is the complement of the Generalized Dice Loss defined in:
Sudre, C. et. al. (2017) Generalised Dice overlap as a deep learning
loss function for highly unbalanced segmentations. DLMIA 2017.
loss function for highly unbalanced segmentations. DLMIA 2017.

The inputs `y_pred` and `y` are expected to be one-hot, binarized channel-first
or batch-first tensors, i.e., CHW[D] or BCHW[D].
The inputs `y_pred` and `y` are expected to be one-hot, binarized batch-first tensors, i.e., NCHW[D].

Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

Args:
include_background (bool, optional): whether to include the background class (assumed to be in channel 0), in the
include_background: Whether to include the background class (assumed to be in channel 0) in the
score computation. Defaults to True.
reduction (str, optional): define mode of reduction to the metrics. Available reduction modes:
{``"none"``, ``"mean_batch"``, ``"sum_batch"``}. Default to ``"mean_batch"``. If "none", will not do reduction.
weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
reduction: Define mode of reduction to the metrics. Available reduction modes:
{``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
ground truth volume into a weight factor. Defaults to ``"square"``.

Raises:
ValueError: when the `weight_type` is not one of {``"none"``, ``"mean"``, ``"sum"``}.
ValueError: When the `reduction` is not one of MetricReduction enum.
"""

@deprecated_arg_default(
"reduction",
old_default=MetricReduction.MEAN_BATCH,
new_default=MetricReduction.MEAN,
since="1.4.0",
replaced="1.5.0",
msg_suffix=(
"Old versions computed `mean` when `mean_batch` was provided due to bug in reduction, "
"If you want to retain the old behavior (calculating the mean), please explicitly set the parameter to 'mean'."
),
)
def __init__(
self,
include_background: bool = True,
Expand All @@ -50,79 +63,90 @@ def __init__(
) -> None:
super().__init__()
self.include_background = include_background
reduction_options = [
"none",
"mean_batch",
"sum_batch",
MetricReduction.NONE,
MetricReduction.MEAN_BATCH,
MetricReduction.SUM_BATCH,
]
self.reduction = reduction
if self.reduction not in reduction_options:
raise ValueError(f"reduction must be one of {reduction_options}")
self.reduction = look_up_option(reduction, MetricReduction)
self.weight_type = look_up_option(weight_type, Weight)
self.sum_over_classes = self.reduction in {
MetricReduction.SUM,
MetricReduction.MEAN,
MetricReduction.MEAN_CHANNEL,
MetricReduction.SUM_CHANNEL,
}

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
"""Computes the Generalized Dice Score and returns a tensor with its per image values.
"""
Computes the Generalized Dice Score and returns a tensor with its per image values.

Args:
y_pred (torch.Tensor): binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
y_pred (torch.Tensor): Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions.
y (torch.Tensor): binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.
y (torch.Tensor): Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.

Returns:
torch.Tensor: Generalized Dice Score averaged across batch and class

Raises:
ValueError: if `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
ValueError: If `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
"""
return compute_generalized_dice(
y_pred=y_pred, y=y, include_background=self.include_background, weight_type=self.weight_type
y_pred=y_pred,
y=y,
include_background=self.include_background,
weight_type=self.weight_type,
sum_over_classes=self.sum_over_classes,
)

@deprecated_arg(
"reduction",
since="1.3.3",
removed="1.7.0",
msg_suffix="Reduction will be ignored. Set reduction during init. as gen.dice needs it during compute",
)
def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Tensor:
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
"""
Execute reduction logic for the output of `compute_generalized_dice`.

Args:
reduction (Union[MetricReduction, str, None], optional): define mode of reduction to the metrics.
Available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``}.
Defaults to ``"mean"``. If "none", will not do reduction.
Returns:
torch.Tensor: Aggregated metric value.

Raises:
ValueError: If the data to aggregate is not a PyTorch Tensor.
"""
data = self.get_buffer()
if not isinstance(data, torch.Tensor):
raise ValueError("The data to aggregate must be a PyTorch Tensor.")

# Validate reduction argument if specified
if reduction is not None:
reduction_options = ["none", "mean", "sum", "mean_batch", "sum_batch"]
if reduction not in reduction_options:
raise ValueError(f"reduction must be one of {reduction_options}")

# Do metric reduction and return
f, _ = do_metric_reduction(data, reduction or self.reduction)
f, _ = do_metric_reduction(data, self.reduction)

return f


def compute_generalized_dice(
y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, weight_type: Weight | str = Weight.SQUARE
y_pred: torch.Tensor,
y: torch.Tensor,
include_background: bool = True,
weight_type: Weight | str = Weight.SQUARE,
sum_over_classes: bool = False,
) -> torch.Tensor:
"""Computes the Generalized Dice Score and returns a tensor with its per image values.
"""
Computes the Generalized Dice Score and returns a tensor with its per image values.

Args:
y_pred (torch.Tensor): binarized segmentation model output. It should be binarized, in one-hot format
y_pred (torch.Tensor): Binarized segmentation model output. It should be binarized, in one-hot format
and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the
remaining are the spatial dimensions.
y (torch.Tensor): binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
include_background (bool, optional): whether to include score computation on the first channel of the
y (torch.Tensor): Binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
include_background: Whether to include score computation on the first channel of the
predicted output. Defaults to True.
weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to
transform ground truth volume into a weight factor. Defaults to ``"square"``.
sum_over_labels (bool): Whether to sum the numerator and denominator across all labels before the final computation.

Returns:
torch.Tensor: per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes].
torch.Tensor: Per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes].

Raises:
ValueError: if `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
ValueError: If `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
or `y_pred` and `y` don't have the same shape.
"""
# Ensure tensors have at least 3 dimensions and have the same shape
Expand Down Expand Up @@ -158,16 +182,21 @@ def compute_generalized_dice(
b[infs] = 0
b[infs] = torch.max(b)

# Compute the weighted numerator and denominator, summing along the class axis
numer = 2.0 * (intersection * w).sum(dim=1)
denom = (denominator * w).sum(dim=1)
# Compute the weighted numerator and denominator, summing along the class axis when sum_over_classes is True
if sum_over_classes:
numer = 2.0 * (intersection * w).sum(dim=1, keepdim=True)
denom = (denominator * w).sum(dim=1, keepdim=True)
y_pred_o = y_pred_o.sum(dim=-1, keepdim=True)
else:
numer = 2.0 * (intersection * w)
denom = denominator * w
y_pred_o = y_pred_o

# Compute the score
generalized_dice_score = numer / denom

# Handle zero division. Where denom == 0 and the prediction volume is 0, score is 1.
# Where denom == 0 but the prediction volume is not 0, score is 0
y_pred_o = y_pred_o.sum(dim=-1)
denom_zeros = denom == 0
generalized_dice_score[denom_zeros] = torch.where(
(y_pred_o == 0)[denom_zeros],
Expand Down
Loading
Loading