From 1c4afd7306a2b766f8d71a5460bd9edc151412eb Mon Sep 17 00:00:00 2001
From: sarthakpati <sarthak.pati@hotmail.com>
Date: Thu, 19 Dec 2024 09:58:05 -0500
Subject: [PATCH] ensure that the brain mask and void image are treated
 differently

---
 GANDLF/cli/generate_metrics.py | 83 ++++++++++++++++++++--------------
 1 file changed, 49 insertions(+), 34 deletions(-)

diff --git a/GANDLF/cli/generate_metrics.py b/GANDLF/cli/generate_metrics.py
index d484b63a0..1821aa2dd 100644
--- a/GANDLF/cli/generate_metrics.py
+++ b/GANDLF/cli/generate_metrics.py
@@ -302,21 +302,22 @@ def __percentile_clip(
             reference_tensor = (
                 input_tensor if reference_tensor is None else reference_tensor
             )
-            v_min, v_max = np.percentile(
-                reference_tensor, [p_min, p_max]
-            )  # get p_min percentile and p_max percentile
-
+            # get p_min percentile and p_max percentile
+            v_min, v_max = np.percentile(reference_tensor, [p_min, p_max])
             # set lower bound to be 0 if strictlyPositive is enabled
             v_min = max(v_min, 0.0) if strictlyPositive else v_min
-            output_tensor = np.clip(
-                input_tensor, v_min, v_max
-            )  # clip values to percentiles from reference_tensor
-            output_tensor = (output_tensor - v_min) / (
-                v_max - v_min
-            )  # normalizes values to [0;1]
+            # clip values to percentiles from reference_tensor
+            output_tensor = np.clip(input_tensor, v_min, v_max)
+            # normalizes values to [0;1]
+            output_tensor = (output_tensor - v_min) / (v_max - v_min)
             return output_tensor
 
-        input_df = __update_header_location_case_insensitive(input_df, "Mask", False)
+        # these are additional columns that could be present for synthesis tasks
+        for column_to_make_case_insensitive in ["Mask", "VoidImage"]:
+            input_df = __update_header_location_case_insensitive(
+                input_df, column_to_make_case_insensitive, False
+            )
+
         for _, row in tqdm(input_df.iterrows(), total=input_df.shape[0]):
             current_subject_id = row["SubjectID"]
             overall_stats_dict[current_subject_id] = {}
@@ -332,6 +333,15 @@ def __percentile_clip(
                 )
             ).byte()
 
+            void_image_present = True if "VoidImage" in row else False
+            void_image = (
+                __fix_2d_tensor(torchio.ScalarImage(row["VoidImage"]).data)
+                if "VoidImage" in row
+                else torch.from_numpy(
+                    np.ones(target_image.numpy().shape, dtype=np.uint8)
+                )
+            )
+
             # Get Infill region (we really are only interested in the infill region)
             output_infill = (pred_image * mask).float()
             gt_image_infill = (target_image * mask).float()
@@ -339,9 +349,10 @@ def __percentile_clip(
             # Normalize to [0;1] based on GT (otherwise MSE will depend on the image intensity range)
             normalize = parameters.get("normalize", True)
             if normalize:
+                # reference: void image if provided, otherwise all unmasked tissue
                 reference_tensor = (
-                    target_image * ~mask
-                )  # use all the tissue that is not masked for normalization
+                    target_image * ~mask if not void_image_present else void_image
+                )
                 gt_image_infill = __percentile_clip(
                     gt_image_infill,
                     reference_tensor=reference_tensor,
@@ -357,9 +368,9 @@ def __percentile_clip(
                     strictlyPositive=True,
                 )
 
-            overall_stats_dict[current_subject_id][
-                "ssim"
-            ] = structural_similarity_index(output_infill, gt_image_infill, mask).item()
+            overall_stats_dict[current_subject_id]["ssim"] = (
+                structural_similarity_index(output_infill, gt_image_infill, mask).item()
+            )
 
             # ncc metrics
             compute_ncc = parameters.get("compute_ncc", True)
@@ -386,6 +397,10 @@ def __percentile_clip(
                 output_infill, gt_image_infill
             ).item()
 
+            overall_stats_dict[current_subject_id]["rmse"] = mean_squared_error(
+                output_infill, gt_image_infill, squared=False
+            ).item()
+
             overall_stats_dict[current_subject_id]["msle"] = mean_squared_log_error(
                 output_infill, gt_image_infill
             ).item()
@@ -400,30 +415,30 @@ def __percentile_clip(
             ).item()
 
             # same as above but with epsilon for robustness
-            overall_stats_dict[current_subject_id][
-                "psnr_eps"
-            ] = peak_signal_noise_ratio(
-                output_infill, gt_image_infill, epsilon=sys.float_info.epsilon
-            ).item()
+            overall_stats_dict[current_subject_id]["psnr_eps"] = (
+                peak_signal_noise_ratio(
+                    output_infill, gt_image_infill, epsilon=sys.float_info.epsilon
+                ).item()
+            )
 
             # only use fix data range to [0;1] if the data was normalized before
             if normalize:
                 # torchmetrics PSNR but with fixed data range of 0 to 1
-                overall_stats_dict[current_subject_id][
-                    "psnr_01"
-                ] = peak_signal_noise_ratio(
-                    output_infill, gt_image_infill, data_range=(0, 1)
-                ).item()
+                overall_stats_dict[current_subject_id]["psnr_01"] = (
+                    peak_signal_noise_ratio(
+                        output_infill, gt_image_infill, data_range=(0, 1)
+                    ).item()
+                )
 
                 # same as above but with epsilon for robustness
-                overall_stats_dict[current_subject_id][
-                    "psnr_01_eps"
-                ] = peak_signal_noise_ratio(
-                    output_infill,
-                    gt_image_infill,
-                    data_range=(0, 1),
-                    epsilon=sys.float_info.epsilon,
-                ).item()
+                overall_stats_dict[current_subject_id]["psnr_01_eps"] = (
+                    peak_signal_noise_ratio(
+                        output_infill,
+                        gt_image_infill,
+                        data_range=(0, 1),
+                        epsilon=sys.float_info.epsilon,
+                    ).item()
+                )
 
     pprint(overall_stats_dict)
     if outputfile is not None: