From 5f3dc7ef397e74c853bb34882b78b6e37001a9c1 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Mon, 27 Mar 2023 23:21:26 +0100 Subject: [PATCH 01/11] Add new version of the SSIM metric > > Co-authored-by: PedroFerreiradaCosta Signed-off-by: Walter Hugo Lopez Pinaya --- monai/metrics/regression.py | 277 +++++++++++++++++++++--------------- 1 file changed, 163 insertions(+), 114 deletions(-) diff --git a/monai/metrics/regression.py b/monai/metrics/regression.py index 92c36da715..0b7b7c2a20 100644 --- a/monai/metrics/regression.py +++ b/monai/metrics/regression.py @@ -13,7 +13,7 @@ import math from abc import abstractmethod -from collections.abc import Callable +from collections.abc import Callable, Sequence from functools import partial from typing import Any @@ -21,7 +21,7 @@ import torch.nn.functional as F from monai.metrics.utils import do_metric_reduction -from monai.utils import MetricReduction +from monai.utils import MetricReduction, StrEnum, convert_data_type, ensure_tuple_rep from monai.utils.type_conversion import convert_to_dst_type from .metric import CumulativeIterationMetric @@ -232,9 +232,14 @@ def compute_mean_error_metrics(y_pred: torch.Tensor, y: torch.Tensor, func: Call return torch.mean(flt(func(y - y_pred)), dim=-1, keepdim=True) +class KernelType(StrEnum): + GAUSSIAN = "gaussian" + UNIFORM = "uniform" + + class SSIMMetric(RegressionMetric): r""" - Build a Pytorch version of the SSIM metric based on the original formula of SSIM + Computes the Structural Similarity Index Measure (SSIM). .. math:: \operatorname {SSIM}(x,y) =\frac {(2 \mu_x \mu_y + c_1)(2 \sigma_{xy} + c_2)}{((\mu_x^2 + \ @@ -243,19 +248,18 @@ class SSIMMetric(RegressionMetric): For more info, visit https://vicuesoft.com/glossary/term/ssim-ms-ssim/ - Modified and adopted from: - https://github.com/facebookresearch/fastMRI/blob/main/banding_removal/fastmri/ssim_loss_mixin.py - SSIM reference paper: Wang, Zhou, et al. "Image quality assessment: from error visibility to structural similarity." IEEE transactions on image processing 13.4 (2004): 600-612. Args: - data_range: dynamic range of the data - win_size: gaussian weighting window size + spatial_dims: number of spatial dimensions of the input images. + data_range: value range of input images. (usually 1.0 or 255) + kernel_type: type of kernel, can be "gaussian" or "uniform". + kernel_size: size of kernel + kernel_sigma: standard deviation for Gaussian kernel. k1: stability constant used in the luminance denominator k2: stability constant used in the contrast denominator - spatial_dims: if 2, input shape is expected to be (B,C,W,H). if 3, it is expected to be (B,C,W,H,D) reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction @@ -264,126 +268,171 @@ class SSIMMetric(RegressionMetric): def __init__( self, - data_range: torch.Tensor, - win_size: int = 7, + spatial_dims: int, + data_range: float = 1.0, + kernel_type: KernelType | str = KernelType.GAUSSIAN, + kernel_size: int | Sequence[int, ...] = 11, + kernel_sigma: int | Sequence[int, ...] 
= 1.5, k1: float = 0.01, k2: float = 0.03, - spatial_dims: int = 2, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, ): super().__init__(reduction=reduction, get_not_nans=get_not_nans) - self.data_range = data_range - self.win_size = win_size - self.k1, self.k2 = k1, k2 + self.spatial_dims = spatial_dims - self.cov_norm = (win_size**2) / (win_size**2 - 1) - self.w = torch.ones([1, 1] + [win_size for _ in range(spatial_dims)]) / win_size**spatial_dims - - def _compute_intermediate_statistics(self, x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, ...]: - data_range = self.data_range[(None,) * (self.spatial_dims + 2)] - # determine whether to work with 2D convolution or 3D - conv = getattr(F, f"conv{self.spatial_dims}d") - w = convert_to_dst_type(src=self.w, dst=x)[0] - - c1 = (self.k1 * data_range) ** 2 # stability constant for luminance - c2 = (self.k2 * data_range) ** 2 # stability constant for contrast - ux = conv(x, w) # mu_x - uy = conv(y, w) # mu_y - uxx = conv(x * x, w) # mu_x^2 - uyy = conv(y * y, w) # mu_y^2 - uxy = conv(x * y, w) # mu_xy - vx = self.cov_norm * (uxx - ux * ux) # sigma_x - vy = self.cov_norm * (uyy - uy * uy) # sigma_y - vxy = self.cov_norm * (uxy - ux * uy) # sigma_xy - - return c1, c2, ux, uy, vx, vy, vxy - - def _compute_metric(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + self.data_range = data_range + self.kernel_type = kernel_type + + if not isinstance(kernel_size, Sequence): + kernel_size = ensure_tuple_rep(kernel_size, spatial_dims) + self.kernel_size = kernel_size + + if not isinstance(kernel_sigma, Sequence): + kernel_sigma = ensure_tuple_rep(kernel_sigma, spatial_dims) + self.kernel_sigma = kernel_sigma + + self.k1 = k1 + self.k2 = k2 + + def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ Args: - x: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D. - A fastMRI sample should use the 2D format with C being the number of slices. - y: second sample (e.g., the reconstructed image). It has similar shape as x - - Returns: - ssim_value - - Example: - .. code-block:: python - - import torch - x = torch.ones([1,1,10,10])/2 # ground truth - y = torch.ones([1,1,10,10])/2 # prediction - data_range = x.max().unsqueeze(0) - # the following line should print 1.0 (or 0.9999) - print(SSIMMetric(data_range=data_range,spatial_dims=2)._compute_metric(x,y)) + y_pred: Predicted image. + It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D]. + y: Reference image. + It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D]. + + Raises: + ValueError: when `y_pred` is not a 2D or 3D image. """ - if x.shape[1] > 1: # handling multiple channels (C>1) - if x.shape[1] != y.shape[1]: - raise ValueError( - f"x and y should have the same number of channels, " - f"but x has {x.shape[1]} channels and y has {y.shape[1]} channels." - ) - - ssim = torch.stack( - [ - SSIMMetric(self.data_range, self.win_size, self.k1, self.k2, self.spatial_dims)( # type: ignore[misc] - x[:, i, ...].unsqueeze(1), y[:, i, ...].unsqueeze(1) - ) - for i in range(x.shape[1]) - ] + dims = y_pred.ndimension() + if self.spatial_dims == 2 and dims != 4: + raise ValueError( + f"y_pred should have 4 dimensions (batch, channel, height, width) when using {self.spatial_dims} " + f"spatial dimensions, got {dims}." 
) - channel_wise_ssim = ssim.mean(1).view(-1, 1) - return channel_wise_ssim - c1, c2, ux, uy, vx, vy, vxy = self._compute_intermediate_statistics(x, y) + if self.spatial_dims == 3 and dims != 5: + raise ValueError( + f"y_pred should have 4 dimensions (batch, channel, height, width, depth) when using {self.spatial_dims}" + f" spatial dimensions, got {dims}." + ) - numerator = (2 * ux * uy + c1) * (2 * vxy + c2) - denom = (ux**2 + uy**2 + c1) * (vx + vy + c2) - ssim_value = numerator / denom - # [B, 1] - ssim_per_batch: torch.Tensor = ssim_value.view(ssim_value.shape[1], -1).mean(1, keepdim=True) + ssim_value_full_image, _ = compute_ssim_and_cs( + y_pred=y_pred, + y=y, + spatial_dims=self.spatial_dims, + data_range=self.data_range, + kernel_type=self.kernel_type, + kernel_size=self.kernel_size, + kernel_sigma=self.kernel_sigma, + k1=self.k1, + k2=self.k2, + ) + + ssim_per_batch: torch.Tensor = ssim_value_full_image.view(ssim_value_full_image.shape[0], -1).mean( + 1, keepdim=True + ) return ssim_per_batch - def _compute_metric_and_contrast(self, x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D. - A fastMRI sample should use the 2D format with C being the number of slices. - y: second sample (e.g., the reconstructed image). It has similar shape as x - Returns: - ssim_value, cs_value +def _gaussian_kernel(spatial_dims, channel: int, kernel_size, kernel_sigma) -> torch.Tensor: + """Computes 2D or 3D gaussian kernel. + + Args: + channel: number of channels in the image + """ + + def gaussian_1d(kernel_size: int, sigma: float) -> torch.Tensor: + """Computes 1D gaussian kernel. + + Args: + kernel_size: size of the gaussian kernel + sigma: Standard deviation of the gaussian kernel """ - if x.shape[1] > 1: # handling multiple channels (C>1) - if x.shape[1] != y.shape[1]: - raise ValueError( - f"x and y should have the same number of channels, " - f"but x has {x.shape[1]} channels and y has {y.shape[1]} channels." 
- ) - - ssim_ls = [] - cs_ls = [] - for i in range(x.shape[1]): - ssim_val, cs_val = SSIMMetric( - self.data_range, self.win_size, self.k1, self.k2, self.spatial_dims - )._compute_metric_and_contrast(x[:, i, ...].unsqueeze(1), y[:, i, ...].unsqueeze(1)) - ssim_ls.append(ssim_val) - cs_ls.append(cs_val) - channel_wise_ssim: torch.Tensor = torch.stack(ssim_ls).mean(1).view(-1, 1) - channel_wise_cs: torch.Tensor = torch.stack(cs_ls).mean(1).view(-1, 1) - return channel_wise_ssim, channel_wise_cs - - c1, c2, ux, uy, vx, vy, vxy = self._compute_intermediate_statistics(x, y) - - numerator = (2 * ux * uy + c1) * (2 * vxy + c2) - denom = (ux**2 + uy**2 + c1) * (vx + vy + c2) - ssim_value = numerator / denom - # [B, 1] - ssim_per_batch: torch.Tensor = ssim_value.view(ssim_value.shape[1], -1).mean(1, keepdim=True) - - cs_per_batch: torch.Tensor = (2 * vxy + c2) / (vx + vy + c2) # contrast sensitivity function - cs_per_batch = cs_per_batch.view(cs_per_batch.shape[0], -1).mean(1, keepdim=True) # [B, 1] - return ssim_per_batch, cs_per_batch + dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1) + gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2) + return (gauss / gauss.sum()).unsqueeze(dim=0) + + gaussian_kernel_x = gaussian_1d(kernel_size[0], kernel_sigma[0]) + gaussian_kernel_y = gaussian_1d(kernel_size[1], kernel_sigma[1]) + kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size) + + kernel_dimensions = (channel, 1, kernel_size[0], kernel_size[1]) + + if spatial_dims == 3: + gaussian_kernel_z = gaussian_1d(kernel_size[2], kernel_sigma[2])[None,] + kernel = torch.mul( + kernel.unsqueeze(-1).repeat(1, 1, kernel_size[2]), + gaussian_kernel_z.expand(kernel_size[0], kernel_size[1], kernel_size[2]), + ) + kernel_dimensions = (channel, 1, kernel_size[0], kernel_size[1], kernel_size[2]) + + return kernel.expand(kernel_dimensions) + + +def compute_ssim_and_cs( + y_pred: torch.Tensor, + y: torch.Tensor, + spatial_dims: int, + data_range: float = 1.0, + kernel_type: KernelType | str = KernelType.GAUSSIAN, + kernel_size: Sequence[int, ...] = 11, + kernel_sigma: Sequence[int, ...] = 1.5, + k1: float = 0.01, + k2: float = 0.03, +): + """ + Function to compute the Structural Similarity Index Measure (SSIM) and Contrast Sensitivity (CS) for a batch + of images. + + Args: + y_pred: batch of predicted images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3]) + y: batch of target images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3]) + spatial_dims: number of spatial dimensions of the images (2, 3) + data_range: the data range of the images. + kernel_type: the type of kernel to use for the SSIM computation. Can be either "gaussian" or "uniform". + kernel_size: the size of the kernel to use for the SSIM computation. + kernel_sigma: the standard deviation of the kernel to use for the SSIM computation. + k1: the first stability constant. + k2: the second stability constant. + + Returns: + ssim: the Structural Similarity Index Measure score for the batch of images. + cs: the Contrast Sensitivity for the batch of images. 
+ """ + if y.shape != y_pred.shape: + raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") + + y_pred = convert_data_type(y_pred, output_type=torch.Tensor, dtype=torch.float)[0] + y = convert_data_type(y, output_type=torch.Tensor, dtype=torch.float)[0] + + num_channels = y_pred.size(1) + + if kernel_type == KernelType.GAUSSIAN: + kernel = _gaussian_kernel(spatial_dims, num_channels, kernel_size, kernel_sigma) + elif kernel_type == KernelType.UNIFORM: + kernel = torch.ones((num_channels, 1, *kernel_size)) / torch.prod(torch.tensor(kernel_size)) + + kernel = convert_to_dst_type(src=kernel, dst=y_pred)[0] + + c1 = (k1 * data_range) ** 2 # stability constant for luminance + c2 = (k2 * data_range) ** 2 # stability constant for contrast + + conv_fn = getattr(F, f"conv{spatial_dims}d") + mu_x = conv_fn(y_pred, kernel, groups=num_channels) + mu_y = conv_fn(y, kernel, groups=num_channels) + mu_xx = conv_fn(y_pred * y_pred, kernel, groups=num_channels) + mu_yy = conv_fn(y * y, kernel, groups=num_channels) + mu_xy = conv_fn(y_pred * y, kernel, groups=num_channels) + + sigma_x = mu_xx - mu_x * mu_x + sigma_y = mu_yy - mu_y * mu_y + sigma_xy = mu_xy - mu_x * mu_y + + contrast_sensitivity = (2 * sigma_xy + c2) / (sigma_x + sigma_y + c2) + ssim_value_full_image = ((2 * mu_x * mu_y + c1) / (mu_x**2 + mu_y**2 + c1)) * contrast_sensitivity + + return ssim_value_full_image, contrast_sensitivity From 47b8a9b2cf8d992a66b184c6f85b5cacd33f11e5 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Mon, 27 Mar 2023 23:29:18 +0100 Subject: [PATCH 02/11] Fix typing > > Signed-off-by: Walter Hugo Lopez Pinaya Co-authored-by: Pedro F. da Costa Signed-off-by: Walter Hugo Lopez Pinaya --- monai/metrics/regression.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/metrics/regression.py b/monai/metrics/regression.py index 0b7b7c2a20..d38a296543 100644 --- a/monai/metrics/regression.py +++ b/monai/metrics/regression.py @@ -277,7 +277,7 @@ def __init__( k2: float = 0.03, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, - ): + ) -> None: super().__init__(reduction=reduction, get_not_nans=get_not_nans) self.spatial_dims = spatial_dims @@ -383,7 +383,7 @@ def compute_ssim_and_cs( kernel_sigma: Sequence[int, ...] = 1.5, k1: float = 0.01, k2: float = 0.03, -): +) -> tuple[torch.Tensor, torch.Tensor]: """ Function to compute the Structural Similarity Index Measure (SSIM) and Contrast Sensitivity (CS) for a batch of images. From 4106eef926c111d00f8d6e34f6d3d31eebe835c6 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Mon, 27 Mar 2023 23:49:52 +0100 Subject: [PATCH 03/11] Add tests Signed-off-by: Walter Hugo Lopez Pinaya Co-authored-by: Pedro F. 
da Costa --- tests/test_ssim_metric.py | 99 ++++++++++++++++++++------------------- 1 file changed, 51 insertions(+), 48 deletions(-) diff --git a/tests/test_ssim_metric.py b/tests/test_ssim_metric.py index 5505e5b750..467e478937 100644 --- a/tests/test_ssim_metric.py +++ b/tests/test_ssim_metric.py @@ -14,60 +14,63 @@ import unittest import torch -from parameterized import parameterized -from monai.metrics.regression import SSIMMetric +from monai.metrics.regression import SSIMMetric, compute_ssim_and_cs +from monai.utils import set_determinism -x = torch.ones([1, 1, 10, 10]) / 2 -y1 = torch.ones([1, 1, 10, 10]) / 2 -y2 = torch.zeros([1, 1, 10, 10]) -data_range = x.max().unsqueeze(0) -TESTS2D = [(x, y1, data_range, torch.tensor(1.0).unsqueeze(0)), (x, y2, data_range, torch.tensor(0.0).unsqueeze(0))] -x = torch.ones([1, 1, 10, 10, 10]) / 2 -y1 = torch.ones([1, 1, 10, 10, 10]) / 2 -y2 = torch.zeros([1, 1, 10, 10, 10]) -data_range = x.max().unsqueeze(0) -TESTS3D = [(x, y1, data_range, torch.tensor(1.0).unsqueeze(0)), (x, y2, data_range, torch.tensor(0.0).unsqueeze(0))] +class TestSSIMMetric(unittest.TestCase): + def test2d_gaussian(self): + set_determinism(0) + preds = torch.abs(torch.randn(2, 3, 16, 16)) + target = torch.abs(torch.randn(2, 3, 16, 16)) + preds = preds / preds.max() + target = target / target.max() -x = torch.ones([3, 3, 10, 10, 10]) / 2 -y1 = torch.ones([3, 3, 10, 10, 10]) / 2 -data_range = x.max().unsqueeze(0) -res = torch.tensor(1.0).unsqueeze(0) * torch.ones((3, 1)) -TESTSCS = [(x, y1, data_range, res)] + metric = SSIMMetric(spatial_dims=2, data_range=1.0, kernel_type="gaussian") + metric(preds, target) + result = metric.aggregate() + expected_value = 0.045415 + self.assertTrue(expected_value - result.item() < 0.000001) + def test2d_uniform(self): + set_determinism(0) + preds = torch.abs(torch.randn(2, 3, 16, 16)) + target = torch.abs(torch.randn(2, 3, 16, 16)) + preds = preds / preds.max() + target = target / target.max() -class TestSSIMMetric(unittest.TestCase): - @parameterized.expand(TESTS2D) - def test2d(self, x, y, drange, res): - result = SSIMMetric(data_range=drange, spatial_dims=2)._compute_metric(x, y) - self.assertTrue(isinstance(result, torch.Tensor)) - self.assertTrue(torch.abs(res - result).item() < 0.001) - - ssim = SSIMMetric(data_range=drange, spatial_dims=2) - ssim(x, y) - result2 = ssim.aggregate() - self.assertTrue(isinstance(result2, torch.Tensor)) - self.assertTrue(torch.abs(result2 - result).item() < 0.001) - - @parameterized.expand(TESTS3D) - def test3d(self, x, y, drange, res): - result = SSIMMetric(data_range=drange, spatial_dims=3)._compute_metric(x, y) - self.assertTrue(isinstance(result, torch.Tensor)) - self.assertTrue(torch.abs(res - result).item() < 0.001) - - ssim = SSIMMetric(data_range=drange, spatial_dims=3) - ssim(x, y) - result2 = ssim.aggregate() - self.assertTrue(isinstance(result2, torch.Tensor)) - self.assertTrue(torch.abs(result2 - result).item() < 0.001) - - @parameterized.expand(TESTSCS) - def testfull(self, x, y, drange, res): - result, cs = SSIMMetric(data_range=drange, spatial_dims=3)._compute_metric_and_contrast(x, y) - self.assertTrue(isinstance(result, torch.Tensor)) - self.assertTrue(isinstance(cs, torch.Tensor)) - self.assertTrue((torch.abs(res - cs) < 0.001).all().item()) + metric = SSIMMetric(spatial_dims=2, data_range=1.0, kernel_type="uniform") + metric(preds, target) + result = metric.aggregate() + expected_value = 0.050103 + self.assertTrue(expected_value - result.item() < 0.000001) + + def test3d_gaussian(self): + 
set_determinism(0) + preds = torch.abs(torch.randn(2, 3, 16, 16, 16)) + target = torch.abs(torch.randn(2, 3, 16, 16, 16)) + preds = preds / preds.max() + target = target / target.max() + + metric = SSIMMetric(spatial_dims=3, data_range=1.0, kernel_type="gaussian") + metric(preds, target) + result = metric.aggregate() + expected_value = 0.017644 + self.assertTrue(expected_value - result.item() < 0.000001) + + def input_ill_input_shape(self): + with self.assertRaises(ValueError): + metric = SSIMMetric(spatial_dims=3) + metric(torch.randn(1, 1, 16, 16), torch.randn(1, 1, 16, 16)) + + with self.assertRaises(ValueError): + metric = SSIMMetric(spatial_dims=2) + metric(torch.randn(1, 1, 16, 16, 16), torch.randn(1, 1, 16, 16, 16)) + + def mismatch_y_pred_and_y(self): + with self.assertRaises(ValueError): + compute_ssim_and_cs(y_pred=torch.randn(1, 1, 16, 8), y=torch.randn(1, 1, 16, 16), spatial_dims=2) if __name__ == "__main__": From 6a8bd772681f0dcd9bd90fc6a93554e7737185c9 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Wed, 29 Mar 2023 10:24:11 +0100 Subject: [PATCH 04/11] Fix typing Signed-off-by: Walter Hugo Lopez Pinaya Co-authored-by: Mark Graham --- monai/metrics/regression.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/monai/metrics/regression.py b/monai/metrics/regression.py index d38a296543..ba8b238222 100644 --- a/monai/metrics/regression.py +++ b/monai/metrics/regression.py @@ -272,7 +272,7 @@ def __init__( data_range: float = 1.0, kernel_type: KernelType | str = KernelType.GAUSSIAN, kernel_size: int | Sequence[int, ...] = 11, - kernel_sigma: int | Sequence[int, ...] = 1.5, + kernel_sigma: float | Sequence[float, ...] = 1.5, k1: float = 0.01, k2: float = 0.03, reduction: MetricReduction | str = MetricReduction.MEAN, @@ -338,7 +338,9 @@ def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor return ssim_per_batch -def _gaussian_kernel(spatial_dims, channel: int, kernel_size, kernel_sigma) -> torch.Tensor: +def _gaussian_kernel( + spatial_dims: int, channel: int, kernel_size: Sequence[int, ...], kernel_sigma: Sequence[float, ...] +) -> torch.Tensor: """Computes 2D or 3D gaussian kernel. Args: @@ -380,7 +382,7 @@ def compute_ssim_and_cs( data_range: float = 1.0, kernel_type: KernelType | str = KernelType.GAUSSIAN, kernel_size: Sequence[int, ...] = 11, - kernel_sigma: Sequence[int, ...] = 1.5, + kernel_sigma: Sequence[float, ...] 
= 1.5, k1: float = 0.01, k2: float = 0.03, ) -> tuple[torch.Tensor, torch.Tensor]: From fa6677e83a4c1a5eb29bb820e6f9f6089553ecd0 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Wed, 29 Mar 2023 11:07:54 +0100 Subject: [PATCH 05/11] [WIP] Add loss --- monai/losses/ssim_loss.py | 65 +++++++++++++++++++++++++++---------- monai/metrics/regression.py | 11 ++++--- 2 files changed, 55 insertions(+), 21 deletions(-) diff --git a/monai/losses/ssim_loss.py b/monai/losses/ssim_loss.py index e8e5d0c2ba..90fadc8297 100644 --- a/monai/losses/ssim_loss.py +++ b/monai/losses/ssim_loss.py @@ -11,18 +11,18 @@ from __future__ import annotations +from collections.abc import Sequence + import torch from torch.nn.modules.loss import _Loss -from monai.metrics.regression import SSIMMetric +from monai.metrics.regression import KernelType, SSIMMetric +from monai.utils import ensure_tuple_rep class SSIMLoss(_Loss): """ - Build a Pytorch version of the SSIM loss function based on the original formula of SSIM - - Modified and adopted from: - https://github.com/facebookresearch/fastMRI/blob/main/banding_removal/fastmri/ssim_loss_mixin.py + Compute the loss function based on the Structural Similarity Index Measure (SSIM) Metric. For more info, visit https://vicuesoft.com/glossary/term/ssim-ms-ssim/ @@ -32,29 +32,60 @@ class SSIMLoss(_Loss): similarity." IEEE transactions on image processing 13.4 (2004): 600-612. """ - def __init__(self, win_size: int = 7, k1: float = 0.01, k2: float = 0.03, spatial_dims: int = 2): + def __init__( + self, + spatial_dims: int, + data_range: float = 1.0, + kernel_type: KernelType | str = KernelType.GAUSSIAN, + kernel_size: int | Sequence[int, ...] = 11, + kernel_sigma: float | Sequence[float, ...] = 1.5, + k1: float = 0.01, + k2: float = 0.03, + ): """ Args: - win_size: gaussian weighting window size + spatial_dims: number of spatial dimensions of the input images. + data_range: value range of input images. (usually 1.0 or 255) + kernel_type: type of kernel, can be "gaussian" or "uniform". + kernel_size: size of kernel + kernel_sigma: standard deviation for Gaussian kernel. k1: stability constant used in the luminance denominator k2: stability constant used in the contrast denominator - spatial_dims: if 2, input shape is expected to be (B,C,H,W). if 3, it is expected to be (B,C,H,W,D) """ super().__init__() - self.win_size = win_size - self.k1, self.k2 = k1, k2 self.spatial_dims = spatial_dims - - def forward(self, x: torch.Tensor, y: torch.Tensor, data_range: torch.Tensor) -> torch.Tensor: + self.data_range = data_range + self.kernel_type = kernel_type + + if not isinstance(kernel_size, Sequence): + kernel_size = ensure_tuple_rep(kernel_size, spatial_dims) + self.kernel_size = kernel_size + + if not isinstance(kernel_sigma, Sequence): + kernel_sigma = ensure_tuple_rep(kernel_sigma, spatial_dims) + self.kernel_sigma = kernel_sigma + + self.k1 = k1 + self.k2 = k2 + + self.ssim_metric = SSIMMetric( + spatial_dims=self.spatial_dims, + data_range=self.data_range, + kernel_type=self.kernel_type, + kernel_size=self.kernel_size, + kernel_sigma=self.kernel_sigma, + k1=self.k1, + k2=self.k2, + ) + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ Args: - x: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D and pseudo-3D data, - and (B,C,W,H,D) for 3D data, - y: second sample (e.g., the reconstructed image). It has similar shape as x. 
- data_range: dynamic range of the data + x: batch of predicted images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3]) + y: batch of target images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3]) Returns: - 1-ssim_value (recall this is meant to be a loss function) + 1 minus the Structural Similarity Index Measure (recall this is meant to be a loss function) Example: .. code-block:: python diff --git a/monai/metrics/regression.py b/monai/metrics/regression.py index ba8b238222..8ec74aef25 100644 --- a/monai/metrics/regression.py +++ b/monai/metrics/regression.py @@ -339,12 +339,15 @@ def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor def _gaussian_kernel( - spatial_dims: int, channel: int, kernel_size: Sequence[int, ...], kernel_sigma: Sequence[float, ...] + spatial_dims: int, num_channels: int, kernel_size: Sequence[int, ...], kernel_sigma: Sequence[float, ...] ) -> torch.Tensor: """Computes 2D or 3D gaussian kernel. Args: - channel: number of channels in the image + spatial_dims: number of spatial dimensions of the input images. + num_channels: number of channels in the image + kernel_size: size of kernel + kernel_sigma: standard deviation for Gaussian kernel. """ def gaussian_1d(kernel_size: int, sigma: float) -> torch.Tensor: @@ -362,7 +365,7 @@ def gaussian_1d(kernel_size: int, sigma: float) -> torch.Tensor: gaussian_kernel_y = gaussian_1d(kernel_size[1], kernel_sigma[1]) kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size) - kernel_dimensions = (channel, 1, kernel_size[0], kernel_size[1]) + kernel_dimensions = (num_channels, 1, kernel_size[0], kernel_size[1]) if spatial_dims == 3: gaussian_kernel_z = gaussian_1d(kernel_size[2], kernel_sigma[2])[None,] @@ -370,7 +373,7 @@ def gaussian_1d(kernel_size: int, sigma: float) -> torch.Tensor: kernel.unsqueeze(-1).repeat(1, 1, kernel_size[2]), gaussian_kernel_z.expand(kernel_size[0], kernel_size[1], kernel_size[2]), ) - kernel_dimensions = (channel, 1, kernel_size[0], kernel_size[1], kernel_size[2]) + kernel_dimensions = (num_channels, 1, kernel_size[0], kernel_size[1], kernel_size[2]) return kernel.expand(kernel_dimensions) From e3545ae0ebddc26134551ef0f636832f29e4670e Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Wed, 29 Mar 2023 15:25:56 +0100 Subject: [PATCH 06/11] Add loss --- monai/losses/ssim_loss.py | 58 ++++++++++++----------------- tests/test_ssim_loss.py | 78 ++++++++++++++++----------------------- 2 files changed, 56 insertions(+), 80 deletions(-) diff --git a/monai/losses/ssim_loss.py b/monai/losses/ssim_loss.py index 90fadc8297..1a84d23b5d 100644 --- a/monai/losses/ssim_loss.py +++ b/monai/losses/ssim_loss.py @@ -17,7 +17,7 @@ from torch.nn.modules.loss import _Loss from monai.metrics.regression import KernelType, SSIMMetric -from monai.utils import ensure_tuple_rep +from monai.utils import LossReduction, ensure_tuple_rep class SSIMLoss(_Loss): @@ -41,6 +41,7 @@ def __init__( kernel_sigma: float | Sequence[float, ...] = 1.5, k1: float = 0.01, k2: float = 0.03, + reduction: LossReduction | str = LossReduction.MEAN, ): """ Args: @@ -51,8 +52,14 @@ def __init__( kernel_sigma: standard deviation for Gaussian kernel. k1: stability constant used in the luminance denominator k2: stability constant used in the contrast denominator + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. 
+ - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + """ - super().__init__() + super().__init__(reduction=LossReduction(reduction).value) self.spatial_dims = spatial_dims self.data_range = data_range self.kernel_type = kernel_type @@ -78,14 +85,14 @@ def __init__( k2=self.k2, ) - def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: - x: batch of predicted images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3]) - y: batch of target images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3]) + input: batch of predicted images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3]) + target: batch of target images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3]) Returns: - 1 minus the Structural Similarity Index Measure (recall this is meant to be a loss function) + 1 minus the ssim index (recall this is meant to be a loss function) Example: .. code-block:: python @@ -95,41 +102,24 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # 2D data x = torch.ones([1,1,10,10])/2 y = torch.ones([1,1,10,10])/2 - data_range = x.max().unsqueeze(0) - # the following line should print 1.0 (or 0.9999) - print(1-SSIMLoss(spatial_dims=2)(x,y,data_range)) + print(1-SSIMLoss(spatial_dims=2)(x,y)) # pseudo-3D data x = torch.ones([1,5,10,10])/2 # 5 could represent number of slices y = torch.ones([1,5,10,10])/2 - data_range = x.max().unsqueeze(0) - # the following line should print 1.0 (or 0.9999) - print(1-SSIMLoss(spatial_dims=2)(x,y,data_range)) + print(1-SSIMLoss(spatial_dims=2)(x,y)) # 3D data x = torch.ones([1,1,10,10,10])/2 y = torch.ones([1,1,10,10,10])/2 - data_range = x.max().unsqueeze(0) - # the following line should print 1.0 (or 0.9999) - print(1-SSIMLoss(spatial_dims=3)(x,y,data_range)) + print(1-SSIMLoss(spatial_dims=3)(x,y)) """ - if x.shape[0] == 1: - ssim_value: torch.Tensor = SSIMMetric( - data_range, self.win_size, self.k1, self.k2, self.spatial_dims - )._compute_tensor(x, y) - elif x.shape[0] > 1: - for i in range(x.shape[0]): - ssim_val: torch.Tensor = SSIMMetric( - data_range, self.win_size, self.k1, self.k2, self.spatial_dims - )._compute_tensor(x[i : i + 1], y[i : i + 1]) - if i == 0: - ssim_value = ssim_val - else: - ssim_value = torch.cat((ssim_value.view(i), ssim_val.view(1)), dim=0) - - else: - raise ValueError("Batch size is not nonnegative integer value") - # 1- dimensional tensor is only allowed - ssim_value = ssim_value.view(-1, 1) - loss: torch.Tensor = 1 - ssim_value.mean() + ssim_value = self.ssim_metric._compute_tensor(input, target).view(-1, 1) + loss: torch.Tensor = 1 - ssim_value + + if self.reduction == LossReduction.MEAN.value: + loss = torch.mean(loss) # the batch average + elif self.reduction == LossReduction.SUM.value: + loss = torch.sum(loss) # sum over the batch + return loss diff --git a/tests/test_ssim_loss.py b/tests/test_ssim_loss.py index a4ba66300b..d3b95b950a 100644 --- a/tests/test_ssim_loss.py +++ b/tests/test_ssim_loss.py @@ -13,58 +13,44 @@ import unittest +import numpy as np import torch -from parameterized import parameterized from monai.losses.ssim_loss import SSIMLoss - -TESTS2D = [] -for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: - for batch_size in [1, 2, 
16]: - x = torch.ones([batch_size, 1, 10, 10]) / 2 - y1 = torch.ones([batch_size, 1, 10, 10]) / 2 - y2 = torch.zeros([batch_size, 1, 10, 10]) - data_range = x.max().unsqueeze(0) - TESTS2D.append((x.to(device), y1.to(device), data_range.to(device), torch.tensor(1.0).unsqueeze(0).to(device))) - TESTS2D.append((x.to(device), y2.to(device), data_range.to(device), torch.tensor(0.0).unsqueeze(0).to(device))) - -TESTS3D = [] -for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: - for batch_size in [1, 2, 16]: - x = torch.ones([batch_size, 1, 10, 10, 10]) / 2 - y1 = torch.ones([batch_size, 1, 10, 10, 10]) / 2 - y2 = torch.zeros([batch_size, 1, 10, 10, 10]) - data_range = x.max().unsqueeze(0) - TESTS3D.append((x.to(device), y1.to(device), data_range.to(device), torch.tensor(1.0).unsqueeze(0).to(device))) - TESTS3D.append((x.to(device), y2.to(device), data_range.to(device), torch.tensor(0.0).unsqueeze(0).to(device))) - -TESTS2D_GRAD = [] -for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: - for batch_size in [1, 2, 16]: - x = torch.ones([batch_size, 1, 10, 10]) / 2 - y = torch.ones([batch_size, 1, 10, 10]) / 2 - y.requires_grad_(True) - data_range = x.max().unsqueeze(0) - TESTS2D_GRAD.append([x.to(device), y.to(device), data_range.to(device)]) +from monai.utils import set_determinism +from tests.utils import test_script_save class TestSSIMLoss(unittest.TestCase): - @parameterized.expand(TESTS2D) - def test2d(self, x, y, drange, res): - result = 1 - SSIMLoss(spatial_dims=2)(x, y, drange) - self.assertTrue(isinstance(result, torch.Tensor)) - self.assertTrue(torch.abs(res - result).item() < 0.001) - - @parameterized.expand(TESTS2D_GRAD) - def test_grad(self, x, y, drange): - result = 1 - SSIMLoss(spatial_dims=2)(x, y, drange) - self.assertTrue(result.requires_grad) - - @parameterized.expand(TESTS3D) - def test3d(self, x, y, drange, res): - result = 1 - SSIMLoss(spatial_dims=3)(x, y, drange) - self.assertTrue(isinstance(result, torch.Tensor)) - self.assertTrue(torch.abs(res - result).item() < 0.001) + def test_shape(self): + set_determinism(0) + preds = torch.abs(torch.randn(2, 3, 16, 16)) + target = torch.abs(torch.randn(2, 3, 16, 16)) + preds = preds / preds.max() + target = target / target.max() + + result = SSIMLoss(spatial_dims=2, data_range=1.0, kernel_type="gaussian", reduction="mean").forward( + preds, target + ) + expected_val = 0.9546 + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4) + + result = SSIMLoss(spatial_dims=2, data_range=1.0, kernel_type="gaussian", reduction="sum").forward( + preds, target + ) + expected_val = 1.9092 + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4) + + result = SSIMLoss(spatial_dims=2, data_range=1.0, kernel_type="gaussian", reduction="none").forward( + preds, target + ) + expected_val = [[0.9121], [0.9971]] + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4) + + def test_script(self): + loss = SSIMLoss(spatial_dims=2) + test_input = torch.ones(2, 2, 16, 16) + test_script_save(loss, test_input, test_input) if __name__ == "__main__": From 1b8ed97eb7538884dc0c5209826e4b0251f6750e Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Wed, 29 Mar 2023 19:37:50 +0100 Subject: [PATCH 07/11] fix typing --- monai/losses/ssim_loss.py | 4 ++-- monai/metrics/regression.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/monai/losses/ssim_loss.py b/monai/losses/ssim_loss.py 
index 1a84d23b5d..5ab87a5f58 100644 --- a/monai/losses/ssim_loss.py +++ b/monai/losses/ssim_loss.py @@ -37,8 +37,8 @@ def __init__( spatial_dims: int, data_range: float = 1.0, kernel_type: KernelType | str = KernelType.GAUSSIAN, - kernel_size: int | Sequence[int, ...] = 11, - kernel_sigma: float | Sequence[float, ...] = 1.5, + kernel_size: int | Sequence[int] = 11, + kernel_sigma: float | Sequence[float] = 1.5, k1: float = 0.01, k2: float = 0.03, reduction: LossReduction | str = LossReduction.MEAN, diff --git a/monai/metrics/regression.py b/monai/metrics/regression.py index 8ec74aef25..b54daff88f 100644 --- a/monai/metrics/regression.py +++ b/monai/metrics/regression.py @@ -271,8 +271,8 @@ def __init__( spatial_dims: int, data_range: float = 1.0, kernel_type: KernelType | str = KernelType.GAUSSIAN, - kernel_size: int | Sequence[int, ...] = 11, - kernel_sigma: float | Sequence[float, ...] = 1.5, + kernel_size: int | Sequence[int] = 11, + kernel_sigma: float | Sequence[float] = 1.5, k1: float = 0.01, k2: float = 0.03, reduction: MetricReduction | str = MetricReduction.MEAN, @@ -339,7 +339,7 @@ def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor def _gaussian_kernel( - spatial_dims: int, num_channels: int, kernel_size: Sequence[int, ...], kernel_sigma: Sequence[float, ...] + spatial_dims: int, num_channels: int, kernel_size: Sequence[int], kernel_sigma: Sequence[float] ) -> torch.Tensor: """Computes 2D or 3D gaussian kernel. @@ -384,8 +384,8 @@ def compute_ssim_and_cs( spatial_dims: int, data_range: float = 1.0, kernel_type: KernelType | str = KernelType.GAUSSIAN, - kernel_size: Sequence[int, ...] = 11, - kernel_sigma: Sequence[float, ...] = 1.5, + kernel_size: Sequence[int] = 11, + kernel_sigma: Sequence[float] = 1.5, k1: float = 0.01, k2: float = 0.03, ) -> tuple[torch.Tensor, torch.Tensor]: From b5bebb57fb4fbc5fe410c96065c2d36f68905221 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Wed, 29 Mar 2023 20:51:00 +0100 Subject: [PATCH 08/11] Fix arguments --- monai/metrics/regression.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/metrics/regression.py b/monai/metrics/regression.py index b54daff88f..b6fb4a9cd3 100644 --- a/monai/metrics/regression.py +++ b/monai/metrics/regression.py @@ -382,10 +382,10 @@ def compute_ssim_and_cs( y_pred: torch.Tensor, y: torch.Tensor, spatial_dims: int, + kernel_size: Sequence[int], + kernel_sigma: Sequence[float], data_range: float = 1.0, kernel_type: KernelType | str = KernelType.GAUSSIAN, - kernel_size: Sequence[int] = 11, - kernel_sigma: Sequence[float] = 1.5, k1: float = 0.01, k2: float = 0.03, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -396,11 +396,11 @@ def compute_ssim_and_cs( Args: y_pred: batch of predicted images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3]) y: batch of target images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3]) + kernel_size: the size of the kernel to use for the SSIM computation. + kernel_sigma: the standard deviation of the kernel to use for the SSIM computation. spatial_dims: number of spatial dimensions of the images (2, 3) data_range: the data range of the images. kernel_type: the type of kernel to use for the SSIM computation. Can be either "gaussian" or "uniform". - kernel_size: the size of the kernel to use for the SSIM computation. - kernel_sigma: the standard deviation of the kernel to use for the SSIM computation. k1: the first stability constant. k2: the second stability constant. 
From fa91c3c44898eee27c85ac0b2246de1806e038ae Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Wed, 29 Mar 2023 20:54:39 +0100 Subject: [PATCH 09/11] Fix type --- monai/metrics/regression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/metrics/regression.py b/monai/metrics/regression.py index b6fb4a9cd3..fc68d726ed 100644 --- a/monai/metrics/regression.py +++ b/monai/metrics/regression.py @@ -365,7 +365,7 @@ def gaussian_1d(kernel_size: int, sigma: float) -> torch.Tensor: gaussian_kernel_y = gaussian_1d(kernel_size[1], kernel_sigma[1]) kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size) - kernel_dimensions = (num_channels, 1, kernel_size[0], kernel_size[1]) + kernel_dimensions: tuple[int, ...] = (num_channels, 1, kernel_size[0], kernel_size[1]) if spatial_dims == 3: gaussian_kernel_z = gaussian_1d(kernel_size[2], kernel_sigma[2])[None,] From f1791100c816d74bbc18fc0f1360724efdd3e314 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Wed, 29 Mar 2023 21:09:58 +0100 Subject: [PATCH 10/11] DCO Remediation Commit for Walter Hugo Lopez Pinaya I, Walter Hugo Lopez Pinaya , hereby add my Signed-off-by to this commit: fa6677e83a4c1a5eb29bb820e6f9f6089553ecd0 I, Walter Hugo Lopez Pinaya , hereby add my Signed-off-by to this commit: e3545ae0ebddc26134551ef0f636832f29e4670e I, Walter Hugo Lopez Pinaya , hereby add my Signed-off-by to this commit: 1b8ed97eb7538884dc0c5209826e4b0251f6750e I, Walter Hugo Lopez Pinaya , hereby add my Signed-off-by to this commit: b5bebb57fb4fbc5fe410c96065c2d36f68905221 I, Walter Hugo Lopez Pinaya , hereby add my Signed-off-by to this commit: fa91c3c44898eee27c85ac0b2246de1806e038ae Signed-off-by: Walter Hugo Lopez Pinaya From 57f05bafa09e7e9e6d4e9c66be6b6f1f0e01077f Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Mon, 3 Apr 2023 10:16:01 +0100 Subject: [PATCH 11/11] Fix type Signed-off-by: Walter Hugo Lopez Pinaya --- monai/losses/ssim_loss.py | 12 ++++++------ monai/metrics/regression.py | 10 +++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/monai/losses/ssim_loss.py b/monai/losses/ssim_loss.py index 5ab87a5f58..8ea3eb116b 100644 --- a/monai/losses/ssim_loss.py +++ b/monai/losses/ssim_loss.py @@ -37,7 +37,7 @@ def __init__( spatial_dims: int, data_range: float = 1.0, kernel_type: KernelType | str = KernelType.GAUSSIAN, - kernel_size: int | Sequence[int] = 11, + win_size: int | Sequence[int] = 11, kernel_sigma: float | Sequence[float] = 1.5, k1: float = 0.01, k2: float = 0.03, @@ -48,7 +48,7 @@ def __init__( spatial_dims: number of spatial dimensions of the input images. data_range: value range of input images. (usually 1.0 or 255) kernel_type: type of kernel, can be "gaussian" or "uniform". - kernel_size: size of kernel + win_size: window size of kernel kernel_sigma: standard deviation for Gaussian kernel. 
k1: stability constant used in the luminance denominator k2: stability constant used in the contrast denominator @@ -64,9 +64,9 @@ def __init__( self.data_range = data_range self.kernel_type = kernel_type - if not isinstance(kernel_size, Sequence): - kernel_size = ensure_tuple_rep(kernel_size, spatial_dims) - self.kernel_size = kernel_size + if not isinstance(win_size, Sequence): + win_size = ensure_tuple_rep(win_size, spatial_dims) + self.kernel_size = win_size if not isinstance(kernel_sigma, Sequence): kernel_sigma = ensure_tuple_rep(kernel_sigma, spatial_dims) @@ -79,7 +79,7 @@ def __init__( spatial_dims=self.spatial_dims, data_range=self.data_range, kernel_type=self.kernel_type, - kernel_size=self.kernel_size, + win_size=self.kernel_size, kernel_sigma=self.kernel_sigma, k1=self.k1, k2=self.k2, diff --git a/monai/metrics/regression.py b/monai/metrics/regression.py index fc68d726ed..c315a2eac0 100644 --- a/monai/metrics/regression.py +++ b/monai/metrics/regression.py @@ -256,7 +256,7 @@ class SSIMMetric(RegressionMetric): spatial_dims: number of spatial dimensions of the input images. data_range: value range of input images. (usually 1.0 or 255) kernel_type: type of kernel, can be "gaussian" or "uniform". - kernel_size: size of kernel + win_size: window size of kernel kernel_sigma: standard deviation for Gaussian kernel. k1: stability constant used in the luminance denominator k2: stability constant used in the contrast denominator @@ -271,7 +271,7 @@ def __init__( spatial_dims: int, data_range: float = 1.0, kernel_type: KernelType | str = KernelType.GAUSSIAN, - kernel_size: int | Sequence[int] = 11, + win_size: int | Sequence[int] = 11, kernel_sigma: float | Sequence[float] = 1.5, k1: float = 0.01, k2: float = 0.03, @@ -284,9 +284,9 @@ def __init__( self.data_range = data_range self.kernel_type = kernel_type - if not isinstance(kernel_size, Sequence): - kernel_size = ensure_tuple_rep(kernel_size, spatial_dims) - self.kernel_size = kernel_size + if not isinstance(win_size, Sequence): + win_size = ensure_tuple_rep(win_size, spatial_dims) + self.kernel_size = win_size if not isinstance(kernel_sigma, Sequence): kernel_sigma = ensure_tuple_rep(kernel_sigma, spatial_dims)
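
A minimal usage sketch of the reworked API, assuming the post-patch signatures introduced above (`spatial_dims`, `data_range`, `kernel_type`) and following the calling pattern used in the updated tests; the random inputs and the 2D/`data_range=1.0` settings are illustrative choices, not values mandated by the API:

    import torch

    from monai.losses.ssim_loss import SSIMLoss
    from monai.metrics.regression import SSIMMetric

    # illustrative inputs: two random 2D batches scaled to [0, 1]
    y_pred = torch.rand(2, 3, 16, 16)
    y = torch.rand(2, 3, 16, 16)

    # cumulative metric: call per batch, then aggregate over everything seen so far
    metric = SSIMMetric(spatial_dims=2, data_range=1.0, kernel_type="gaussian")
    metric(y_pred, y)
    ssim = metric.aggregate()  # mean SSIM over the batch

    # loss counterpart: 1 - SSIM, reduced according to `reduction`
    loss = SSIMLoss(spatial_dims=2, data_range=1.0, reduction="mean")(y_pred, y)

Unlike the previous implementation, neither the metric nor the loss takes a `data_range` tensor at call time; the value range is fixed once at construction.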