Skip to content

Commit

Permalink
ensure that the brain mask and void image are treated differently
Browse files Browse the repository at this point in the history
  • Loading branch information
sarthakpati committed Dec 19, 2024
1 parent ca49b4c commit 1c4afd7
Showing 1 changed file with 49 additions and 34 deletions.
83 changes: 49 additions & 34 deletions GANDLF/cli/generate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,21 +302,22 @@ def __percentile_clip(
reference_tensor = (
input_tensor if reference_tensor is None else reference_tensor
)
v_min, v_max = np.percentile(
reference_tensor, [p_min, p_max]
) # get p_min percentile and p_max percentile

# get p_min percentile and p_max percentile
v_min, v_max = np.percentile(reference_tensor, [p_min, p_max])
# set lower bound to be 0 if strictlyPositive is enabled
v_min = max(v_min, 0.0) if strictlyPositive else v_min
output_tensor = np.clip(
input_tensor, v_min, v_max
) # clip values to percentiles from reference_tensor
output_tensor = (output_tensor - v_min) / (
v_max - v_min
) # normalizes values to [0;1]
# clip values to percentiles from reference_tensor
output_tensor = np.clip(input_tensor, v_min, v_max)
# normalizes values to [0;1]
output_tensor = (output_tensor - v_min) / (v_max - v_min)
return output_tensor

input_df = __update_header_location_case_insensitive(input_df, "Mask", False)
# these are additional columns that could be present for synthesis tasks
for column_to_make_case_insensitive in ["Mask", "VoidImage"]:
input_df = __update_header_location_case_insensitive(
input_df, column_to_make_case_insensitive, False
)

for _, row in tqdm(input_df.iterrows(), total=input_df.shape[0]):
current_subject_id = row["SubjectID"]
overall_stats_dict[current_subject_id] = {}
Expand All @@ -332,16 +333,26 @@ def __percentile_clip(
)
).byte()

void_image_present = True if "VoidImage" in row else False
void_image = (
__fix_2d_tensor(torchio.ScalarImage(row["VoidImage"]).data)
if "VoidImage" in row
else torch.from_numpy(
np.ones(target_image.numpy().shape, dtype=np.uint8)
)
)

# Get Infill region (we really are only interested in the infill region)
output_infill = (pred_image * mask).float()
gt_image_infill = (target_image * mask).float()

# Normalize to [0;1] based on GT (otherwise MSE will depend on the image intensity range)
normalize = parameters.get("normalize", True)
if normalize:
# use all the tissue that is not masked for normalization
reference_tensor = (
target_image * ~mask
) # use all the tissue that is not masked for normalization
target_image * ~mask if not void_image_present else void_image
)
gt_image_infill = __percentile_clip(
gt_image_infill,
reference_tensor=reference_tensor,
Expand All @@ -357,9 +368,9 @@ def __percentile_clip(
strictlyPositive=True,
)

overall_stats_dict[current_subject_id][
"ssim"
] = structural_similarity_index(output_infill, gt_image_infill, mask).item()
overall_stats_dict[current_subject_id]["ssim"] = (
structural_similarity_index(output_infill, gt_image_infill, mask).item()
)

# ncc metrics
compute_ncc = parameters.get("compute_ncc", True)
Expand All @@ -386,6 +397,10 @@ def __percentile_clip(
output_infill, gt_image_infill
).item()

overall_stats_dict[current_subject_id]["rmse"] = mean_squared_error(

Check warning on line 400 in GANDLF/cli/generate_metrics.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`rmse` is not a recognized word. (unrecognized-spelling)
output_infill, gt_image_infill, squared=False
).item()

overall_stats_dict[current_subject_id]["msle"] = mean_squared_log_error(
output_infill, gt_image_infill
).item()
Expand All @@ -400,30 +415,30 @@ def __percentile_clip(
).item()

# same as above but with epsilon for robustness
overall_stats_dict[current_subject_id][
"psnr_eps"
] = peak_signal_noise_ratio(
output_infill, gt_image_infill, epsilon=sys.float_info.epsilon
).item()
overall_stats_dict[current_subject_id]["psnr_eps"] = (
peak_signal_noise_ratio(
output_infill, gt_image_infill, epsilon=sys.float_info.epsilon
).item()
)

# only use fix data range to [0;1] if the data was normalized before
if normalize:
# torchmetrics PSNR but with fixed data range of 0 to 1
overall_stats_dict[current_subject_id][
"psnr_01"
] = peak_signal_noise_ratio(
output_infill, gt_image_infill, data_range=(0, 1)
).item()
overall_stats_dict[current_subject_id]["psnr_01"] = (
peak_signal_noise_ratio(
output_infill, gt_image_infill, data_range=(0, 1)
).item()
)

# same as above but with epsilon for robustness
overall_stats_dict[current_subject_id][
"psnr_01_eps"
] = peak_signal_noise_ratio(
output_infill,
gt_image_infill,
data_range=(0, 1),
epsilon=sys.float_info.epsilon,
).item()
overall_stats_dict[current_subject_id]["psnr_01_eps"] = (
peak_signal_noise_ratio(
output_infill,
gt_image_infill,
data_range=(0, 1),
epsilon=sys.float_info.epsilon,
).item()
)

pprint(overall_stats_dict)
if outputfile is not None:
Expand Down

0 comments on commit 1c4afd7

Please sign in to comment.