diff --git a/GANDLF/losses/loss_interface.py b/GANDLF/losses/loss_interface.py
index 49d5b5031..53c5a9325 100644
--- a/GANDLF/losses/loss_interface.py
+++ b/GANDLF/losses/loss_interface.py
@@ -9,9 +9,7 @@ def __init__(self, params: dict):
         self.params = params
 
     @abstractmethod
-    def forward(
-        self, prediction: torch.Tensor, target: torch.Tensor, *args
-    ) -> torch.Tensor:
+    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         pass
 
 
@@ -49,9 +47,7 @@ def _single_class_loss_calculator(
         """Compute loss for a pair of prediction and target tensors. To be implemented by child classes."""
         pass
 
-    def forward(
-        self, prediction: torch.Tensor, target: torch.Tensor, *args
-    ) -> torch.Tensor:
+    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         accumulated_loss = torch.tensor(0.0, device=prediction.device)
 
         for class_idx in range(self.num_classes):
@@ -64,6 +60,7 @@ def forward(
                 current_loss = current_loss * self.penalty_weights[class_idx]
             accumulated_loss += current_loss
 
+        # TODO shouldn't we always divide by the number of classes?
         if self.penalty_weights is None:
             accumulated_loss /= self.num_classes
 
diff --git a/GANDLF/losses/segmentation.py b/GANDLF/losses/segmentation.py
index 35feb3c25..675dab74c 100644
--- a/GANDLF/losses/segmentation.py
+++ b/GANDLF/losses/segmentation.py
@@ -124,6 +124,95 @@ def _single_class_loss_calculator(
         return loss
 
 
+class MulticlassFocalLoss(AbstractSegmentationMultiClassLoss):
+    """
+    This class computes the Focal loss between two tensors.
+    """
+
+    def __init__(self, params: dict):
+        super().__init__(params)
+
+        self.ce_loss_helper = torch.nn.CrossEntropyLoss(reduction="none")
+        loss_params = params["loss_function"]
+        self.alpha = 1.0
+        self.gamma = 2.0
+        self.output_aggregation = "sum"
+        if isinstance(loss_params, dict):
+            self.alpha = loss_params.get("alpha", self.alpha)
+            self.gamma = loss_params.get("gamma", self.gamma)
+            self.output_aggregation = loss_params.get(
+                "size_average",
+                self.output_aggregation,  # key is named "size_average" to stay consistent with the config format
+            )
+        assert self.output_aggregation in [
+            "sum",
+            "mean",
+        ], f"Invalid output aggregation method defined for Focal Loss: {self.output_aggregation}. Valid options are ['sum', 'mean']"
+
+    def _single_class_loss_calculator(
+        self, prediction: torch.Tensor, target: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Compute focal loss for a single class. It is based on the following formulas:
+            FocalLoss(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)
+            CrossEntropy(pred, target) = -log(pred) if target = 1 else -log(1 - pred)
+            CrossEntropy(p_t) = CrossEntropy(pred, target) = -log(p_t)
+            p_t = p if target = 1 else 1 - p
+        """
+        ce_loss = self.ce_loss_helper(prediction, target)
+        p_t = torch.exp(-ce_loss)
+        loss = self.alpha * (1 - p_t) ** self.gamma * ce_loss
+        return loss.sum() if self.output_aggregation == "sum" else loss.mean()
+
+    def _compute_single_class_loss(
+        self, prediction: torch.Tensor, target: torch.Tensor, class_idx: int
+    ) -> torch.Tensor:
+        """Compute loss for a single class."""
+        loss_value = self._single_class_loss_calculator(
+            prediction[:, class_idx, ...], target[:, class_idx, ...]
+        )
+        return loss_value  # no need to subtract from 1 in this case, hence the override
+
+
+class KullbackLeiblerDivergence(AbstractLossFunction):
+    def forward(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
+        """
+        Calculates the Kullback-Leibler divergence between two Gaussian distributions.
+
+        Args:
+            mu (torch.Tensor): The mean of the first Gaussian distribution.
+            logvar (torch.Tensor): The logarithm of the variance of the first Gaussian distribution.
+
+        Returns:
+            torch.Tensor: The computed Kullback-Leibler divergence.
+        """
+        loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
+        return loss.mean()
+
+
+# Dice scores and dice losses
+def dice(predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+    """
+    This function computes a dice score between two tensors.
+
+    Args:
+        predicted (torch.Tensor): Predicted value by the network.
+        target (torch.Tensor): Required target label to match the prediction against.
+
+    Returns:
+        torch.Tensor: The computed dice score.
+    """
+    predicted_flat = predicted.flatten()
+    label_flat = target.flatten()
+    intersection = (predicted_flat * label_flat).sum()
+
+    dice_score = (2.0 * intersection + sys.float_info.min) / (
+        predicted_flat.sum() + label_flat.sum() + sys.float_info.min
+    )
+
+    return dice_score
+
+
 def mcc(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
     """
     This function computes the Matthews Correlation Coefficient (MCC) between two tensors. Adapted from https://github.com/kakumarabhishek/MCC-Loss/blob/main/loss.py.
@@ -212,22 +301,6 @@ def generic_loss_calculator(
     return accumulated_loss
 
 
-class KullbackLeiblerDivergence(AbstractLossFunction):
-    def forward(self, mu: torch.Tensor, logvar: torch.Tensor, *args) -> torch.Tensor:
-        """
-        Calculates the Kullback-Leibler divergence between two Gaussian distributions.
-
-        Args:
-            mu (torch.Tensor): The mean of the first Gaussian distribution.
-            logvar (torch.Tensor): The logarithm of the variance of the first Gaussian distribution.
-
-        Returns:
-            torch.Tensor: The computed Kullback-Leibler divergence
-        """
-        loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
-        return loss.mean()
-
-
 def MCD_loss(
     predicted: torch.Tensor, target: torch.Tensor, params: dict
 ) -> torch.Tensor:
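
A quick sanity check for the focal-loss arithmetic in `MulticlassFocalLoss._single_class_loss_calculator` (a minimal sketch, not part of the patch; the shapes, seed, and values are made up for illustration): since `p_t = exp(-ce_loss)` is exactly the probability the model assigns to the true class, setting `gamma = 0` and `alpha = 1` must collapse `alpha * (1 - p_t) ** gamma * ce_loss` to plain cross-entropy.

```python
import torch

torch.manual_seed(0)
prediction = torch.randn(4, 3)      # (batch, classes) logits, illustrative only
target = torch.randint(0, 3, (4,))  # integer class labels

ce_loss = torch.nn.CrossEntropyLoss(reduction="none")(prediction, target)
p_t = torch.exp(-ce_loss)  # probability assigned to the true class

# gamma = 0 and alpha = 1 collapse the focal term to plain cross-entropy
focal = 1.0 * (1 - p_t) ** 0.0 * ce_loss
assert torch.allclose(focal, ce_loss)
```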
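`KullbackLeiblerDivergence.forward` implements the standard closed form for KL(N(mu, sigma^2) || N(0, 1)), the usual VAE regularizer. A minimal cross-check against `torch.distributions` (the input values here are arbitrary):

```python
import torch
from torch.distributions import Normal, kl_divergence

mu = torch.tensor([0.5, -1.0])
logvar = torch.tensor([0.1, -0.3])

# Closed form, as in KullbackLeiblerDivergence.forward (before the batch mean)
closed_form = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)

# Reference: per-dimension KL(N(mu, sigma^2) || N(0, 1)) with sigma = exp(logvar / 2)
reference = kl_divergence(
    Normal(mu, (0.5 * logvar).exp()),
    Normal(torch.zeros_like(mu), torch.ones_like(mu)),
).sum(dim=-1)

assert torch.allclose(closed_form, reference)
```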
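The `dice` helper is the usual 2·|A∩B| / (|A| + |B|) overlap score, with `sys.float_info.min` added as smoothing so that two empty masks score 1 rather than dividing 0 by 0. A tiny worked example with made-up binary masks: two of three foreground voxels overlap, so the score is 2·2 / (3 + 3) ≈ 0.667.

```python
import sys
import torch

predicted = torch.tensor([1.0, 1.0, 1.0, 0.0])  # 3 foreground voxels
target = torch.tensor([1.0, 1.0, 0.0, 1.0])     # 3 foreground voxels, 2 shared

intersection = (predicted * target).sum()  # 2.0
dice_score = (2.0 * intersection + sys.float_info.min) / (
    predicted.sum() + target.sum() + sys.float_info.min
)
print(dice_score)  # tensor(0.6667)
```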