Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new version of the SSIM metric #6250

Merged
merged 12 commits into from
Apr 22, 2023
Merged
113 changes: 67 additions & 46 deletions monai/losses/ssim_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,18 @@

from __future__ import annotations

from collections.abc import Sequence

import torch
from torch.nn.modules.loss import _Loss

from monai.metrics.regression import SSIMMetric
from monai.metrics.regression import KernelType, SSIMMetric
from monai.utils import LossReduction, ensure_tuple_rep


class SSIMLoss(_Loss):
"""
Build a PyTorch version of the SSIM loss function based on the original formula of SSIM

Modified and adopted from:
https://github.com/facebookresearch/fastMRI/blob/main/banding_removal/fastmri/ssim_loss_mixin.py
Compute the loss function based on the Structural Similarity Index Measure (SSIM) Metric.

For more info, visit
https://vicuesoft.com/glossary/term/ssim-ms-ssim/
Expand All @@ -32,29 +32,67 @@ class SSIMLoss(_Loss):
similarity." IEEE transactions on image processing 13.4 (2004): 600-612.
"""

def __init__(self, win_size: int = 7, k1: float = 0.01, k2: float = 0.03, spatial_dims: int = 2):
def __init__(
self,
spatial_dims: int,
data_range: float = 1.0,
kernel_type: KernelType | str = KernelType.GAUSSIAN,
kernel_size: int | Sequence[int] = 11,
kernel_sigma: float | Sequence[float] = 1.5,
k1: float = 0.01,
k2: float = 0.03,
wyli marked this conversation as resolved.
Show resolved Hide resolved
reduction: LossReduction | str = LossReduction.MEAN,
):
"""
Args:
win_size: gaussian weighting window size
spatial_dims: number of spatial dimensions of the input images.
data_range: value range of the input images (usually 1.0 or 255).
kernel_type: type of kernel, can be "gaussian" or "uniform".
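kernel_size: size of the kernel. A single int is used for every spatial dimension; a sequence gives one size per spatial dimension.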
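kernel_sigma: standard deviation of the Gaussian kernel. A single float is used for every spatial dimension; a sequence gives one value per spatial dimension.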
k1: stability constant used in the luminance denominator
k2: stability constant used in the contrast denominator
spatial_dims: if 2, input shape is expected to be (B,C,H,W). if 3, it is expected to be (B,C,H,W,D)
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.
- ``"none"``: no reduction will be applied.
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
- ``"sum"``: the output will be summed.

"""
super().__init__()
self.win_size = win_size
self.k1, self.k2 = k1, k2
super().__init__(reduction=LossReduction(reduction).value)
self.spatial_dims = spatial_dims

def forward(self, x: torch.Tensor, y: torch.Tensor, data_range: torch.Tensor) -> torch.Tensor:
self.data_range = data_range
self.kernel_type = kernel_type

if not isinstance(kernel_size, Sequence):
kernel_size = ensure_tuple_rep(kernel_size, spatial_dims)
self.kernel_size = kernel_size

if not isinstance(kernel_sigma, Sequence):
kernel_sigma = ensure_tuple_rep(kernel_sigma, spatial_dims)
self.kernel_sigma = kernel_sigma

self.k1 = k1
self.k2 = k2

self.ssim_metric = SSIMMetric(
spatial_dims=self.spatial_dims,
data_range=self.data_range,
kernel_type=self.kernel_type,
kernel_size=self.kernel_size,
kernel_sigma=self.kernel_sigma,
k1=self.k1,
k2=self.k2,
)

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Args:
x: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D and pseudo-3D data,
and (B,C,W,H,D) for 3D data,
y: second sample (e.g., the reconstructed image). It has similar shape as x.
data_range: dynamic range of the data
input: batch of predicted images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3])
target: batch of target images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3])

Returns:
1-ssim_value (recall this is meant to be a loss function)
1 minus the ssim index (recall this is meant to be a loss function)

Example:
.. code-block:: python
Expand All @@ -64,41 +102,24 @@ def forward(self, x: torch.Tensor, y: torch.Tensor, data_range: torch.Tensor) ->
# 2D data
x = torch.ones([1,1,10,10])/2
y = torch.ones([1,1,10,10])/2
data_range = x.max().unsqueeze(0)
# the following line should print 1.0 (or 0.9999)
print(1-SSIMLoss(spatial_dims=2)(x,y,data_range))
print(1-SSIMLoss(spatial_dims=2)(x,y))

# pseudo-3D data
x = torch.ones([1,5,10,10])/2 # 5 could represent number of slices
y = torch.ones([1,5,10,10])/2
data_range = x.max().unsqueeze(0)
# the following line should print 1.0 (or 0.9999)
print(1-SSIMLoss(spatial_dims=2)(x,y,data_range))
print(1-SSIMLoss(spatial_dims=2)(x,y))

# 3D data
x = torch.ones([1,1,10,10,10])/2
y = torch.ones([1,1,10,10,10])/2
data_range = x.max().unsqueeze(0)
# the following line should print 1.0 (or 0.9999)
print(1-SSIMLoss(spatial_dims=3)(x,y,data_range))
print(1-SSIMLoss(spatial_dims=3)(x,y))
"""
if x.shape[0] == 1:
ssim_value: torch.Tensor = SSIMMetric(
data_range, self.win_size, self.k1, self.k2, self.spatial_dims
)._compute_tensor(x, y)
elif x.shape[0] > 1:
for i in range(x.shape[0]):
ssim_val: torch.Tensor = SSIMMetric(
data_range, self.win_size, self.k1, self.k2, self.spatial_dims
)._compute_tensor(x[i : i + 1], y[i : i + 1])
if i == 0:
ssim_value = ssim_val
else:
ssim_value = torch.cat((ssim_value.view(i), ssim_val.view(1)), dim=0)

else:
raise ValueError("Batch size is not nonnegative integer value")
# 1- dimensional tensor is only allowed
ssim_value = ssim_value.view(-1, 1)
loss: torch.Tensor = 1 - ssim_value.mean()
ssim_value = self.ssim_metric._compute_tensor(input, target).view(-1, 1)
loss: torch.Tensor = 1 - ssim_value

if self.reduction == LossReduction.MEAN.value:
loss = torch.mean(loss) # the batch average
elif self.reduction == LossReduction.SUM.value:
loss = torch.sum(loss) # sum over the batch

return loss
Loading