Skip to content

Commit

Permalink
putting comments and added parameter to get rmse
Browse files Browse the repository at this point in the history
  • Loading branch information
sarthakpati committed Dec 19, 2024
1 parent f737db0 commit ca49b4c
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions GANDLF/metrics/synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,21 @@ def structural_similarity_index(
return ssim_idx.mean()


def mean_squared_error(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
def mean_squared_error(
prediction: torch.Tensor, target: torch.Tensor, squared: bool = True
) -> torch.Tensor:
"""
Computes the mean squared error between the target and prediction.
Args:
prediction (torch.Tensor): The prediction tensor.
target (torch.Tensor): The target tensor.
squared (bool, optional): Whether to return squared error. Defaults to True.
Returns:
torch.Tensor: The mean squared error or its square root.
"""
mse = MeanSquaredError()
mse = MeanSquaredError(squared=squared)
return mse(preds=prediction, target=target)


Expand Down Expand Up @@ -78,10 +84,9 @@ def peak_signal_noise_ratio(
return psnr(preds=prediction, target=target)
else: # implementation of PSNR that does not give 'inf'/'nan' when 'mse==0'
mse = mean_squared_error(target, prediction)
if data_range == None: # compute data_range like torchmetrics if not given
min_v = (
0 if torch.min(target) > 0 else torch.min(target)
) # look at this line
if data_range is None: # compute data_range like torchmetrics if not given
# put the min value to 0 if all values are positive
min_v = 0 if torch.min(target) > 0 else torch.min(target)
max_v = torch.max(target)
else:
min_v, max_v = data_range
Expand Down

0 comments on commit ca49b4c

Please sign in to comment.