Skip to content

Commit

Permalink
feat: add clDice loss (#6763)
Browse files Browse the repository at this point in the history
Fixes #5938

### Description

This PR aims to add the `SoftclDiceLoss` and the `SoftDiceclDiceLoss`
from [clDice - a Novel Topology-Preserving Loss Function for Tubular
Structure
Segmentation](https://openaccess.thecvf.com/content/CVPR2021/papers/Shit_clDice_-_A_Novel_Topology-Preserving_Loss_Function_for_Tubular_Structure_CVPR_2021_paper.pdf)

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Saurav Maheshkar <[email protected]>
  • Loading branch information
SauravMaheshkar authored Jul 25, 2023
1 parent 28c9083 commit 2800a76
Show file tree
Hide file tree
Showing 3 changed files with 241 additions and 0 deletions.
1 change: 1 addition & 0 deletions monai/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from __future__ import annotations

from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss
from .contrastive import ContrastiveLoss
from .deform import BendingEnergyLoss
from .dice import (
Expand Down
184 changes: 184 additions & 0 deletions monai/losses/cldice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss


def soft_erode(img: torch.Tensor) -> torch.Tensor: # type: ignore
"""
Perform soft erosion on the input image
Args:
img: the shape should be BCH(WD)
Adapted from:
https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L6
"""
if len(img.shape) == 4:
p1 = -(F.max_pool2d(-img, (3, 1), (1, 1), (1, 0)))
p2 = -(F.max_pool2d(-img, (1, 3), (1, 1), (0, 1)))
return torch.min(p1, p2) # type: ignore
elif len(img.shape) == 5:
p1 = -(F.max_pool3d(-img, (3, 1, 1), (1, 1, 1), (1, 0, 0)))
p2 = -(F.max_pool3d(-img, (1, 3, 1), (1, 1, 1), (0, 1, 0)))
p3 = -(F.max_pool3d(-img, (1, 1, 3), (1, 1, 1), (0, 0, 1)))
return torch.min(torch.min(p1, p2), p3) # type: ignore


def soft_dilate(img: torch.Tensor) -> torch.Tensor: # type: ignore
"""
Perform soft dilation on the input image
Args:
img: the shape should be BCH(WD)
Adapted from:
https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L18
"""
if len(img.shape) == 4:
return F.max_pool2d(img, (3, 3), (1, 1), (1, 1)) # type: ignore
elif len(img.shape) == 5:
return F.max_pool3d(img, (3, 3, 3), (1, 1, 1), (1, 1, 1)) # type: ignore


def soft_open(img: torch.Tensor) -> torch.Tensor:
"""
Wrapper function to perform soft opening on the input image
Args:
img: the shape should be BCH(WD)
Adapted from:
https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L25
"""
eroded_image = soft_erode(img)
dilated_image = soft_dilate(eroded_image)
return dilated_image


def soft_skel(img: torch.Tensor, iter_: int) -> torch.Tensor:
"""
Perform soft skeletonization on the input image
Adapted from:
https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L29
Args:
img: the shape should be BCH(WD)
iter_: number of iterations for skeletonization
Returns:
skeletonized image
"""
img1 = soft_open(img)
skel = F.relu(img - img1)
for _ in range(iter_):
img = soft_erode(img)
img1 = soft_open(img)
delta = F.relu(img - img1)
skel = skel + F.relu(delta - skel * delta)
return skel


def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
"""
Function to compute soft dice loss
Adapted from:
https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L22
Args:
y_true: the shape should be BCH(WD)
y_pred: the shape should be BCH(WD)
Returns:
dice loss
"""
intersection = torch.sum((y_true * y_pred)[:, 1:, ...])
coeff = (2.0 * intersection + smooth) / (torch.sum(y_true[:, 1:, ...]) + torch.sum(y_pred[:, 1:, ...]) + smooth)
soft_dice: torch.Tensor = 1.0 - coeff
return soft_dice


class SoftclDiceLoss(_Loss):
"""
Compute the Soft clDice loss defined in:
Shit et al. (2021) clDice -- A Novel Topology-Preserving Loss Function
for Tubular Structure Segmentation. (https://arxiv.org/abs/2003.07311)
Adapted from:
https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L7
"""

def __init__(self, iter_: int = 3, smooth: float = 1.0) -> None:
"""
Args:
iter_: Number of iterations for skeletonization
smooth: Smoothing parameter
"""
super().__init__()
self.iter = iter_
self.smooth = smooth

def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
skel_pred = soft_skel(y_pred, self.iter)
skel_true = soft_skel(y_true, self.iter)
tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / (
torch.sum(skel_pred[:, 1:, ...]) + self.smooth
)
tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / (
torch.sum(skel_true[:, 1:, ...]) + self.smooth
)
cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)
return cl_dice


class SoftDiceclDiceLoss(_Loss):
"""
Compute the Soft clDice loss defined in:
Shit et al. (2021) clDice -- A Novel Topology-Preserving Loss Function
for Tubular Structure Segmentation. (https://arxiv.org/abs/2003.07311)
Adapted from:
https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L38
"""

def __init__(self, iter_: int = 3, alpha: float = 0.5, smooth: float = 1.0) -> None:
"""
Args:
iter_: Number of iterations for skeletonization
smooth: Smoothing parameter
alpha: Weighing factor for cldice
"""
super().__init__()
self.iter = iter_
self.smooth = smooth
self.alpha = alpha

def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
dice = soft_dice(y_true, y_pred, self.smooth)
skel_pred = soft_skel(y_pred, self.iter)
skel_true = soft_skel(y_true, self.iter)
tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / (
torch.sum(skel_pred[:, 1:, ...]) + self.smooth
)
tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / (
torch.sum(skel_true[:, 1:, ...]) + self.smooth
)
cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)
total_loss: torch.Tensor = (1.0 - self.alpha) * dice + self.alpha * cl_dice
return total_loss
56 changes: 56 additions & 0 deletions tests/test_cldice_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.losses import SoftclDiceLoss, SoftDiceclDiceLoss

TEST_CASES = [
[ # shape: (1, 4), (1, 4)
{"y_pred": torch.ones((100, 3, 256, 256)), "y_true": torch.ones((100, 3, 256, 256))},
0.0,
],
[ # shape: (1, 5), (1, 5)
{"y_pred": torch.ones((100, 3, 256, 256, 5)), "y_true": torch.ones((100, 3, 256, 256, 5))},
0.0,
],
]


class TestclDiceLoss(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_result(self, y_pred_data, expected_val):
loss = SoftclDiceLoss()
loss_dice = SoftDiceclDiceLoss()
result = loss(**y_pred_data)
result_dice = loss_dice(**y_pred_data)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(result_dice.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)

def test_with_cuda(self):
loss = SoftclDiceLoss()
loss_dice = SoftDiceclDiceLoss()
i = torch.ones((100, 3, 256, 256))
j = torch.ones((100, 3, 256, 256))
if torch.cuda.is_available():
i = i.cuda()
j = j.cuda()
output = loss(i, j)
output_dice = loss_dice(i, j)
np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(output_dice.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4)


if __name__ == "__main__":
unittest.main()

0 comments on commit 2800a76

Please sign in to comment.