Skip to content

Commit

Permalink
Porting losses to new interface WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
szmazurek committed Nov 16, 2024
1 parent a9729e5 commit 227a86f
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 38 deletions.
75 changes: 52 additions & 23 deletions GANDLF/losses/loss_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,68 @@
from abc import ABC, abstractmethod


class AbstractLossFunction(ABC, nn.Module):
class AbstractLossFunction(nn.Module, ABC):
def __init__(self, params: dict):
super().__init__()
nn.Module.__init__(self)
self.params = params

@abstractmethod
def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
def forward(
self, prediction: torch.Tensor, target: torch.Tensor, *args
) -> torch.Tensor:
pass


class WeightedCE(AbstractLossFunction):
class AbstractSegmentationMultiClassLoss(AbstractLossFunction):
"""
Base class for loss funcions that are used for multi-class segmentation tasks.
"""

def __init__(self, params: dict):
super().__init__(params)
self.num_classes = len(params["model"]["class_list"])
self.penalty_weights = params["penalty_weights"]

def _compute_single_class_loss(
self, prediction: torch.Tensor, target: torch.Tensor, class_idx: int
) -> torch.Tensor:
"""Compute loss for a single class."""
loss_value = self._single_class_loss_calculator(
prediction[:, class_idx, ...], target[:, class_idx, ...]
)
return 1 - loss_value

def _optional_loss_operations(self, loss: torch.Tensor) -> torch.Tensor:
"""
Cross entropy loss using class weights if provided.
Perform addtional operations of the loss value. Defaults to identity operation.
If needed, child classes can override this method. Useful in the cases where
for example, the loss value needs to log-transformed or clipped.
"""
super().__init__(params)
return loss

def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if len(target.shape) > 1 and target.shape[-1] == 1:
target = torch.squeeze(target, -1)

weights = None
if self.params.get("penalty_weights") is not None:
num_classes = len(self.params["penalty_weights"])
assert (
prediction.shape[-1] == num_classes
), f"Number of classes {num_classes} does not match prediction shape {prediction.shape[-1]}"

weights = torch.tensor(
list(self.params["penalty_weights"].values()),
dtype=torch.float32,
device=target.device,
@abstractmethod
def _single_class_loss_calculator(
self, prediction: torch.Tensor, target: torch.Tensor
) -> torch.Tensor:
"""Compute loss for a pair of prediction and target tensors. To be implemented by child classes."""
pass

def forward(
self, prediction: torch.Tensor, target: torch.Tensor, *args
) -> torch.Tensor:
accumulated_loss = torch.tensor(0.0, device=prediction.device)

for class_idx in range(self.num_classes):
current_loss = self._compute_single_class_loss(
prediction, target, class_idx
)
current_loss = self._optional_loss_operations(current_loss)

if self.penalty_weights is not None:
current_loss = current_loss * self.penalty_weights[class_idx]
accumulated_loss += current_loss

if self.penalty_weights is None:
accumulated_loss /= self.num_classes

cel = nn.CrossEntropyLoss(weight=weights)
return cel(prediction, target)
return accumulated_loss
144 changes: 129 additions & 15 deletions GANDLF/losses/segmentation.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,127 @@
import sys
from typing import List, Optional
import torch
from .loss_interface import AbstractSegmentationMultiClassLoss, AbstractLossFunction


# Dice scores and dice losses
def dice(predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
class MulticlassDiceLoss(AbstractSegmentationMultiClassLoss):
"""
This class computes the Dice loss between two tensors.
"""
This function computes a dice score between two tensors.

Args:
predicted (torch.Tensor): Predicted value by the network.
target (torch.Tensor): Required target label to match the predicted with
def _single_class_loss_calculator(
self, prediction: torch.Tensor, target: torch.Tensor
) -> torch.Tensor:
"""
Compute Dice score for a single class.
Returns:
torch.Tensor: The computed dice score.
Args:
prediction (torch.Tensor): Network's predicted segmentation mask
target (torch.Tensor): Target segmentation mask
Returns:
torch.Tensor: The computed dice score.
"""
predicted_flat = prediction.flatten()
label_flat = target.flatten()
intersection = (predicted_flat * label_flat).sum()

dice_score = (2.0 * intersection + sys.float_info.min) / (
predicted_flat.sum() + label_flat.sum() + sys.float_info.min
)

return dice_score


class MulticlassDiceLogLoss(MulticlassDiceLoss):
def _optional_loss_operations(self, loss):
return -torch.log(
loss + torch.finfo(torch.float32).eps
) # epsilon for numerical stability


class MulticlassMCCLoss(AbstractSegmentationMultiClassLoss):
"""
This class computes the Matthews Correlation Coefficient (MCC) loss between two tensors.
"""
predicted_flat = predicted.flatten()
label_flat = target.flatten()
intersection = (predicted_flat * label_flat).sum()

dice_score = (2.0 * intersection + sys.float_info.min) / (
predicted_flat.sum() + label_flat.sum() + sys.float_info.min
)
def _single_class_loss_calculator(
self, prediction: torch.Tensor, target: torch.Tensor
) -> torch.Tensor:
"""
Compute MCC score for a single class.
Args:
prediction (torch.Tensor): Network's predicted segmentation mask
target (torch.Tensor): Target segmentation mask
Returns:
torch.Tensor: The computed MCC score.
"""
tp = torch.sum(torch.mul(prediction, target))
tn = torch.sum(torch.mul((1 - prediction), (1 - target)))
fp = torch.sum(torch.mul(prediction, (1 - target)))
fn = torch.sum(torch.mul((1 - prediction), target))

numerator = torch.mul(tp, tn) - torch.mul(fp, fn)
# Adding epsilon to the denominator to avoid divide-by-zero errors.
denominator = (
torch.sqrt(
torch.add(tp, 1, fp)
* torch.add(tp, 1, fn)
* torch.add(tn, 1, fp)
* torch.add(tn, 1, fn)
)
+ torch.finfo(torch.float32).eps
)

return dice_score
return torch.div(numerator.sum(), denominator.sum())


class MulticlassMCLLogLoss(MulticlassMCCLoss):
def _optional_loss_operations(self, loss):
return -torch.log(
loss + torch.finfo(torch.float32).eps
) # epsilon for numerical stability


class MulticlassTverskyLoss(AbstractSegmentationMultiClassLoss):
"""
This class computes the Tversky loss between two tensors.
"""

def __init__(self, params: dict):
super().__init__(params)
self.alpha = params.get("alpha", 0.5)
self.beta = params.get("beta", 0.5)

def _single_class_loss_calculator(
self, prediction: torch.Tensor, target: torch.Tensor
) -> torch.Tensor:
"""
Compute Tversky score for a single class.
Args:
prediction (torch.Tensor): Network's predicted segmentation mask
target (torch.Tensor): Target segmentation mask
Returns:
torch.Tensor: The computed Tversky score.
"""
predicted_flat = prediction.contiguous().view(-1)
target_flat = target.contiguous().view(-1)

true_positives = (predicted_flat * target_flat).sum()
false_positives = ((1 - target_flat) * predicted_flat).sum()
false_negatives = (target_flat * (1 - predicted_flat)).sum()

numerator = true_positives
denominator = (
true_positives + self.alpha * false_positives + self.beta * false_negatives
)
loss = (numerator + sys.float_info.min) / (denominator + sys.float_info.min)

return loss


def mcc(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -114,6 +212,22 @@ def generic_loss_calculator(
return accumulated_loss


class KullbackLeiblerDivergence(AbstractLossFunction):
def forward(self, mu: torch.Tensor, logvar: torch.Tensor, *args) -> torch.Tensor:
"""
Calculates the Kullback-Leibler divergence between two Gaussian distributions.
Args:
mu (torch.Tensor): The mean of the first Gaussian distribution.
logvar (torch.Tensor): The logarithm of the variance of the first Gaussian distribution.
Returns:
torch.Tensor: The computed Kullback-Leibler divergence
"""
loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
return loss.mean()


def MCD_loss(
predicted: torch.Tensor, target: torch.Tensor, params: dict
) -> torch.Tensor:
Expand Down

0 comments on commit 227a86f

Please sign in to comment.