Skip to content

Commit

Permalink
Add new version of the SSIM metric (#6250)
Browse files Browse the repository at this point in the history
fixes #6249

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.

---------

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
Co-authored-by: Pedro F. da Costa <[email protected]>
Co-authored-by: Mark Graham <[email protected]>
  • Loading branch information
3 people authored Apr 22, 2023
1 parent 9a5b900 commit 43902e3
Show file tree
Hide file tree
Showing 4 changed files with 319 additions and 255 deletions.
113 changes: 67 additions & 46 deletions monai/losses/ssim_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,18 @@

from __future__ import annotations

from collections.abc import Sequence

import torch
from torch.nn.modules.loss import _Loss

from monai.metrics.regression import SSIMMetric
from monai.metrics.regression import KernelType, SSIMMetric
from monai.utils import LossReduction, ensure_tuple_rep


class SSIMLoss(_Loss):
"""
Build a Pytorch version of the SSIM loss function based on the original formula of SSIM
Modified and adopted from:
https://github.com/facebookresearch/fastMRI/blob/main/banding_removal/fastmri/ssim_loss_mixin.py
Compute the loss function based on the Structural Similarity Index Measure (SSIM) Metric.
For more info, visit
https://vicuesoft.com/glossary/term/ssim-ms-ssim/
Expand All @@ -32,29 +32,67 @@ class SSIMLoss(_Loss):
similarity." IEEE transactions on image processing 13.4 (2004): 600-612.
"""

def __init__(self, win_size: int = 7, k1: float = 0.01, k2: float = 0.03, spatial_dims: int = 2):
def __init__(
self,
spatial_dims: int,
data_range: float = 1.0,
kernel_type: KernelType | str = KernelType.GAUSSIAN,
win_size: int | Sequence[int] = 11,
kernel_sigma: float | Sequence[float] = 1.5,
k1: float = 0.01,
k2: float = 0.03,
reduction: LossReduction | str = LossReduction.MEAN,
):
"""
Args:
win_size: gaussian weighting window size
spatial_dims: number of spatial dimensions of the input images.
data_range: value range of input images. (usually 1.0 or 255)
kernel_type: type of kernel, can be "gaussian" or "uniform".
win_size: window size of kernel
kernel_sigma: standard deviation for Gaussian kernel.
k1: stability constant used in the luminance denominator
k2: stability constant used in the contrast denominator
spatial_dims: if 2, input shape is expected to be (B,C,H,W). if 3, it is expected to be (B,C,H,W,D)
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.
- ``"none"``: no reduction will be applied.
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
- ``"sum"``: the output will be summed.
"""
super().__init__()
self.win_size = win_size
self.k1, self.k2 = k1, k2
super().__init__(reduction=LossReduction(reduction).value)
self.spatial_dims = spatial_dims

def forward(self, x: torch.Tensor, y: torch.Tensor, data_range: torch.Tensor) -> torch.Tensor:
self.data_range = data_range
self.kernel_type = kernel_type

if not isinstance(win_size, Sequence):
win_size = ensure_tuple_rep(win_size, spatial_dims)
self.kernel_size = win_size

if not isinstance(kernel_sigma, Sequence):
kernel_sigma = ensure_tuple_rep(kernel_sigma, spatial_dims)
self.kernel_sigma = kernel_sigma

self.k1 = k1
self.k2 = k2

self.ssim_metric = SSIMMetric(
spatial_dims=self.spatial_dims,
data_range=self.data_range,
kernel_type=self.kernel_type,
win_size=self.kernel_size,
kernel_sigma=self.kernel_sigma,
k1=self.k1,
k2=self.k2,
)

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Args:
x: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D and pseudo-3D data,
and (B,C,W,H,D) for 3D data,
y: second sample (e.g., the reconstructed image). It has similar shape as x.
data_range: dynamic range of the data
input: batch of predicted images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3])
target: batch of target images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3])
Returns:
1-ssim_value (recall this is meant to be a loss function)
1 minus the ssim index (recall this is meant to be a loss function)
Example:
.. code-block:: python
Expand All @@ -64,41 +102,24 @@ def forward(self, x: torch.Tensor, y: torch.Tensor, data_range: torch.Tensor) ->
# 2D data
x = torch.ones([1,1,10,10])/2
y = torch.ones([1,1,10,10])/2
data_range = x.max().unsqueeze(0)
# the following line should print 1.0 (or 0.9999)
print(1-SSIMLoss(spatial_dims=2)(x,y,data_range))
print(1-SSIMLoss(spatial_dims=2)(x,y))
# pseudo-3D data
x = torch.ones([1,5,10,10])/2 # 5 could represent number of slices
y = torch.ones([1,5,10,10])/2
data_range = x.max().unsqueeze(0)
# the following line should print 1.0 (or 0.9999)
print(1-SSIMLoss(spatial_dims=2)(x,y,data_range))
print(1-SSIMLoss(spatial_dims=2)(x,y))
# 3D data
x = torch.ones([1,1,10,10,10])/2
y = torch.ones([1,1,10,10,10])/2
data_range = x.max().unsqueeze(0)
# the following line should print 1.0 (or 0.9999)
print(1-SSIMLoss(spatial_dims=3)(x,y,data_range))
print(1-SSIMLoss(spatial_dims=3)(x,y))
"""
if x.shape[0] == 1:
ssim_value: torch.Tensor = SSIMMetric(
data_range, self.win_size, self.k1, self.k2, self.spatial_dims
)._compute_tensor(x, y)
elif x.shape[0] > 1:
for i in range(x.shape[0]):
ssim_val: torch.Tensor = SSIMMetric(
data_range, self.win_size, self.k1, self.k2, self.spatial_dims
)._compute_tensor(x[i : i + 1], y[i : i + 1])
if i == 0:
ssim_value = ssim_val
else:
ssim_value = torch.cat((ssim_value.view(i), ssim_val.view(1)), dim=0)

else:
raise ValueError("Batch size is not nonnegative integer value")
# 1- dimensional tensor is only allowed
ssim_value = ssim_value.view(-1, 1)
loss: torch.Tensor = 1 - ssim_value.mean()
ssim_value = self.ssim_metric._compute_tensor(input, target).view(-1, 1)
loss: torch.Tensor = 1 - ssim_value

if self.reduction == LossReduction.MEAN.value:
loss = torch.mean(loss) # the batch average
elif self.reduction == LossReduction.SUM.value:
loss = torch.sum(loss) # sum over the batch

return loss
Loading

0 comments on commit 43902e3

Please sign in to comment.