Segmentation losses refactored
szmazurek committed Nov 16, 2024
1 parent 227a86f · commit f7e168b
Showing 2 changed files with 92 additions and 22 deletions.
9 changes: 3 additions & 6 deletions GANDLF/losses/loss_interface.py
@@ -9,9 +9,7 @@ def __init__(self, params: dict):
        self.params = params

    @abstractmethod
-    def forward(
-        self, prediction: torch.Tensor, target: torch.Tensor, *args
-    ) -> torch.Tensor:
+    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        pass


@@ -49,9 +47,7 @@ def _single_class_loss_calculator(
        """Compute loss for a pair of prediction and target tensors. To be implemented by child classes."""
        pass

-    def forward(
-        self, prediction: torch.Tensor, target: torch.Tensor, *args
-    ) -> torch.Tensor:
+    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        accumulated_loss = torch.tensor(0.0, device=prediction.device)

        for class_idx in range(self.num_classes):
@@ -64,6 +60,7 @@ def forward(
                current_loss = current_loss * self.penalty_weights[class_idx]
            accumulated_loss += current_loss

+        # TODO shouldn't we always divide by the number of classes?
        if self.penalty_weights is None:
            accumulated_loss /= self.num_classes
105 changes: 89 additions & 16 deletions GANDLF/losses/segmentation.py
@@ -124,6 +124,95 @@ def _single_class_loss_calculator(
        return loss


+class MulticlassFocalLoss(AbstractSegmentationMultiClassLoss):
+    """
+    This class computes the Focal loss between two tensors.
+    """
+
+    def __init__(self, params: dict):
+        super().__init__(params)
+
+        self.ce_loss_helper = torch.nn.CrossEntropyLoss(reduction="none")
+        loss_params = params["loss_function"]
+        self.alpha = 1.0
+        self.gamma = 2.0
+        self.output_aggregation = "sum"
+        if isinstance(loss_params, dict):
+            self.alpha = loss_params.get("alpha", self.alpha)
+            self.gamma = loss_params.get("gamma", self.gamma)
+            self.output_aggregation = loss_params.get(
+                "size_average",
+                self.output_aggregation,  # key name kept as "size_average" for consistency with the config format
+            )
+        assert self.output_aggregation in [
+            "sum",
+            "mean",
+        ], f"Invalid output aggregation method defined for Focal Loss: {self.output_aggregation}. Valid options are ['sum', 'mean']"
+
+    def _single_class_loss_calculator(
+        self, prediction: torch.Tensor, target: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Compute focal loss for a single class. It is based on the following formulas:
+            FocalLoss(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)
+            CrossEntropy(pred, target) = -log(pred) if target = 1 else -log(1 - pred)
+            CrossEntropy(p_t) = CrossEntropy(pred, target) = -log(p_t)
+            p_t = p if target = 1 else 1 - p
+        """
+        ce_loss = self.ce_loss_helper(prediction, target)
+        p_t = torch.exp(-ce_loss)
+        loss = self.alpha * (1 - p_t) ** self.gamma * ce_loss
+        return loss.sum() if self.output_aggregation == "sum" else loss.mean()
+
+    def _compute_single_class_loss(
+        self, prediction: torch.Tensor, target: torch.Tensor, class_idx: int
+    ) -> torch.Tensor:
+        """Compute loss for a single class."""
+        loss_value = self._single_class_loss_calculator(
+            prediction[:, class_idx, ...], target[:, class_idx, ...]
+        )
+        return loss_value  # no need to subtract from 1 in this case, hence the override
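
As a quick illustration (not part of the commit, plain PyTorch only): the p_t trick above recovers the softmax probability assigned to the true class from the per-sample cross-entropy, and with alpha = 1 and gamma = 0 the focal term reduces exactly to cross-entropy:

import torch

ce = torch.nn.CrossEntropyLoss(reduction="none")
logits = torch.randn(4, 3)          # 4 samples, 3 classes
labels = torch.randint(0, 3, (4,))

ce_loss = ce(logits, labels)
p_t = torch.exp(-ce_loss)           # probability of the true class under softmax
focal = 1.0 * (1 - p_t) ** 0.0 * ce_loss  # alpha = 1, gamma = 0
assert torch.allclose(focal, ce_loss)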


+class KullbackLeiblerDivergence(AbstractLossFunction):
+    def forward(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
+        """
+        Calculates the Kullback-Leibler divergence between a Gaussian distribution
+        N(mu, sigma^2) and the standard normal distribution N(0, 1).
+
+        Args:
+            mu (torch.Tensor): The mean of the Gaussian distribution.
+            logvar (torch.Tensor): The logarithm of the variance of the Gaussian distribution.
+
+        Returns:
+            torch.Tensor: The computed Kullback-Leibler divergence.
+        """
+        loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
+        return loss.mean()
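
A small sanity check (illustrative, not from the commit): the closed-form term is zero exactly when the approximate posterior matches the standard normal prior, i.e. mu = 0 and logvar = 0:

import torch

mu = torch.zeros(2, 8)
logvar = torch.zeros(2, 8)
kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
assert torch.allclose(kl.mean(), torch.tensor(0.0))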


+# Dice scores and dice losses
+def dice(predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+    """
+    This function computes a dice score between two tensors.
+
+    Args:
+        predicted (torch.Tensor): Value predicted by the network.
+        target (torch.Tensor): Target label to compare the prediction against.
+
+    Returns:
+        torch.Tensor: The computed dice score.
+    """
+    predicted_flat = predicted.flatten()
+    label_flat = target.flatten()
+    intersection = (predicted_flat * label_flat).sum()
+
+    dice_score = (2.0 * intersection + sys.float_info.min) / (
+        predicted_flat.sum() + label_flat.sum() + sys.float_info.min
+    )
+
+    return dice_score
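
For intuition (illustrative only, assuming the dice function above is in scope): identical binary masks score 1, disjoint masks score essentially 0:

import torch

mask = torch.tensor([1.0, 0.0, 1.0, 1.0])
print(dice(mask, mask))        # tensor(1.)
print(dice(mask, 1.0 - mask))  # tensor(0.), up to the tiny smoothing term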


def mcc(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """
    This function computes the Matthews Correlation Coefficient (MCC) between two tensors. Adapted from https://github.com/kakumarabhishek/MCC-Loss/blob/main/loss.py.
@@ -212,22 +301,6 @@ def generic_loss_calculator(
    return accumulated_loss


-class KullbackLeiblerDivergence(AbstractLossFunction):
-    def forward(self, mu: torch.Tensor, logvar: torch.Tensor, *args) -> torch.Tensor:
-        """
-        Calculates the Kullback-Leibler divergence between two Gaussian distributions.
-        Args:
-            mu (torch.Tensor): The mean of the first Gaussian distribution.
-            logvar (torch.Tensor): The logarithm of the variance of the first Gaussian distribution.
-        Returns:
-            torch.Tensor: The computed Kullback-Leibler divergence
-        """
-        loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
-        return loss.mean()


def MCD_loss(
    predicted: torch.Tensor, target: torch.Tensor, params: dict
) -> torch.Tensor:
