From ca49b4ca71988de850791f3691be3ef0e924e9d7 Mon Sep 17 00:00:00 2001 From: sarthakpati Date: Thu, 19 Dec 2024 09:57:43 -0500 Subject: [PATCH] putting comments and added parameter to get rmse --- GANDLF/metrics/synthesis.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/GANDLF/metrics/synthesis.py b/GANDLF/metrics/synthesis.py index ba1b4113e..f29933da7 100644 --- a/GANDLF/metrics/synthesis.py +++ b/GANDLF/metrics/synthesis.py @@ -39,15 +39,21 @@ def structural_similarity_index( return ssim_idx.mean() -def mean_squared_error(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: +def mean_squared_error( + prediction: torch.Tensor, target: torch.Tensor, squared: bool = True +) -> torch.Tensor: """ Computes the mean squared error between the target and prediction. Args: prediction (torch.Tensor): The prediction tensor. target (torch.Tensor): The target tensor. + squared (bool, optional): Whether to return squared error. Defaults to True. + + Returns: + torch.Tensor: The mean squared error or its square root. """ - mse = MeanSquaredError() + mse = MeanSquaredError(squared=squared) return mse(preds=prediction, target=target) @@ -78,10 +84,9 @@ def peak_signal_noise_ratio( return psnr(preds=prediction, target=target) else: # implementation of PSNR that does not give 'inf'/'nan' when 'mse==0' mse = mean_squared_error(target, prediction) - if data_range == None: # compute data_range like torchmetrics if not given - min_v = ( - 0 if torch.min(target) > 0 else torch.min(target) - ) # look at this line + if data_range is None: # compute data_range like torchmetrics if not given + # put the min value to 0 if all values are positive + min_v = 0 if torch.min(target) > 0 else torch.min(target) max_v = torch.max(target) else: min_v, max_v = data_range