Skip to content

Commit

Permalink
Add loss
Browse files Browse the repository at this point in the history
  • Loading branch information
Warvito committed Mar 29, 2023
1 parent fa6677e commit e3545ae
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 80 deletions.
58 changes: 24 additions & 34 deletions monai/losses/ssim_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torch.nn.modules.loss import _Loss

from monai.metrics.regression import KernelType, SSIMMetric
from monai.utils import ensure_tuple_rep
from monai.utils import LossReduction, ensure_tuple_rep


class SSIMLoss(_Loss):
Expand All @@ -41,6 +41,7 @@ def __init__(
kernel_sigma: float | Sequence[float, ...] = 1.5,
k1: float = 0.01,
k2: float = 0.03,
reduction: LossReduction | str = LossReduction.MEAN,
):
"""
Args:
Expand All @@ -51,8 +52,14 @@ def __init__(
kernel_sigma: standard deviation for Gaussian kernel.
k1: stability constant used in the luminance denominator
k2: stability constant used in the contrast denominator
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.
- ``"none"``: no reduction will be applied.
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
- ``"sum"``: the output will be summed.
"""
super().__init__()
super().__init__(reduction=LossReduction(reduction).value)
self.spatial_dims = spatial_dims
self.data_range = data_range
self.kernel_type = kernel_type
Expand All @@ -78,14 +85,14 @@ def __init__(
k2=self.k2,
)

def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Args:
x: batch of predicted images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3])
y: batch of target images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3])
input: batch of predicted images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3])
target: batch of target images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3])
Returns:
1 minus the Structural Similarity Index Measure (recall this is meant to be a loss function)
1 minus the ssim index (recall this is meant to be a loss function)
Example:
.. code-block:: python
Expand All @@ -95,41 +102,24 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# 2D data
x = torch.ones([1,1,10,10])/2
y = torch.ones([1,1,10,10])/2
data_range = x.max().unsqueeze(0)
# the following line should print 1.0 (or 0.9999)
print(1-SSIMLoss(spatial_dims=2)(x,y,data_range))
print(1-SSIMLoss(spatial_dims=2)(x,y))
# pseudo-3D data
x = torch.ones([1,5,10,10])/2 # 5 could represent number of slices
y = torch.ones([1,5,10,10])/2
data_range = x.max().unsqueeze(0)
# the following line should print 1.0 (or 0.9999)
print(1-SSIMLoss(spatial_dims=2)(x,y,data_range))
print(1-SSIMLoss(spatial_dims=2)(x,y))
# 3D data
x = torch.ones([1,1,10,10,10])/2
y = torch.ones([1,1,10,10,10])/2
data_range = x.max().unsqueeze(0)
# the following line should print 1.0 (or 0.9999)
print(1-SSIMLoss(spatial_dims=3)(x,y,data_range))
print(1-SSIMLoss(spatial_dims=3)(x,y))
"""
if x.shape[0] == 1:
ssim_value: torch.Tensor = SSIMMetric(
data_range, self.win_size, self.k1, self.k2, self.spatial_dims
)._compute_tensor(x, y)
elif x.shape[0] > 1:
for i in range(x.shape[0]):
ssim_val: torch.Tensor = SSIMMetric(
data_range, self.win_size, self.k1, self.k2, self.spatial_dims
)._compute_tensor(x[i : i + 1], y[i : i + 1])
if i == 0:
ssim_value = ssim_val
else:
ssim_value = torch.cat((ssim_value.view(i), ssim_val.view(1)), dim=0)

else:
raise ValueError("Batch size is not nonnegative integer value")
# 1- dimensional tensor is only allowed
ssim_value = ssim_value.view(-1, 1)
loss: torch.Tensor = 1 - ssim_value.mean()
ssim_value = self.ssim_metric._compute_tensor(input, target).view(-1, 1)
loss: torch.Tensor = 1 - ssim_value

if self.reduction == LossReduction.MEAN.value:
loss = torch.mean(loss) # the batch average
elif self.reduction == LossReduction.SUM.value:
loss = torch.sum(loss) # sum over the batch

return loss
78 changes: 32 additions & 46 deletions tests/test_ssim_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,58 +13,44 @@

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.losses.ssim_loss import SSIMLoss

TESTS2D = []
for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]:
for batch_size in [1, 2, 16]:
x = torch.ones([batch_size, 1, 10, 10]) / 2
y1 = torch.ones([batch_size, 1, 10, 10]) / 2
y2 = torch.zeros([batch_size, 1, 10, 10])
data_range = x.max().unsqueeze(0)
TESTS2D.append((x.to(device), y1.to(device), data_range.to(device), torch.tensor(1.0).unsqueeze(0).to(device)))
TESTS2D.append((x.to(device), y2.to(device), data_range.to(device), torch.tensor(0.0).unsqueeze(0).to(device)))

TESTS3D = []
for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]:
for batch_size in [1, 2, 16]:
x = torch.ones([batch_size, 1, 10, 10, 10]) / 2
y1 = torch.ones([batch_size, 1, 10, 10, 10]) / 2
y2 = torch.zeros([batch_size, 1, 10, 10, 10])
data_range = x.max().unsqueeze(0)
TESTS3D.append((x.to(device), y1.to(device), data_range.to(device), torch.tensor(1.0).unsqueeze(0).to(device)))
TESTS3D.append((x.to(device), y2.to(device), data_range.to(device), torch.tensor(0.0).unsqueeze(0).to(device)))

TESTS2D_GRAD = []
for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]:
for batch_size in [1, 2, 16]:
x = torch.ones([batch_size, 1, 10, 10]) / 2
y = torch.ones([batch_size, 1, 10, 10]) / 2
y.requires_grad_(True)
data_range = x.max().unsqueeze(0)
TESTS2D_GRAD.append([x.to(device), y.to(device), data_range.to(device)])
from monai.utils import set_determinism
from tests.utils import test_script_save


class TestSSIMLoss(unittest.TestCase):
@parameterized.expand(TESTS2D)
def test2d(self, x, y, drange, res):
result = 1 - SSIMLoss(spatial_dims=2)(x, y, drange)
self.assertTrue(isinstance(result, torch.Tensor))
self.assertTrue(torch.abs(res - result).item() < 0.001)

@parameterized.expand(TESTS2D_GRAD)
def test_grad(self, x, y, drange):
result = 1 - SSIMLoss(spatial_dims=2)(x, y, drange)
self.assertTrue(result.requires_grad)

@parameterized.expand(TESTS3D)
def test3d(self, x, y, drange, res):
result = 1 - SSIMLoss(spatial_dims=3)(x, y, drange)
self.assertTrue(isinstance(result, torch.Tensor))
self.assertTrue(torch.abs(res - result).item() < 0.001)
def test_shape(self):
set_determinism(0)
preds = torch.abs(torch.randn(2, 3, 16, 16))
target = torch.abs(torch.randn(2, 3, 16, 16))
preds = preds / preds.max()
target = target / target.max()

result = SSIMLoss(spatial_dims=2, data_range=1.0, kernel_type="gaussian", reduction="mean").forward(
preds, target
)
expected_val = 0.9546
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4)

result = SSIMLoss(spatial_dims=2, data_range=1.0, kernel_type="gaussian", reduction="sum").forward(
preds, target
)
expected_val = 1.9092
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4)

result = SSIMLoss(spatial_dims=2, data_range=1.0, kernel_type="gaussian", reduction="none").forward(
preds, target
)
expected_val = [[0.9121], [0.9971]]
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4)

def test_script(self):
loss = SSIMLoss(spatial_dims=2)
test_input = torch.ones(2, 2, 16, 16)
test_script_save(loss, test_input, test_input)


if __name__ == "__main__":
Expand Down

0 comments on commit e3545ae

Please sign in to comment.