Skip to content

Commit

Permalink
Added function get_img_at_mpp to class OpenSlideWSIReader of module w…
Browse files Browse the repository at this point in the history
…si_reader.py
  • Loading branch information
Nikolas Schmitz committed Mar 22, 2024
1 parent c649934 commit f3e7d03
Showing 1 changed file with 82 additions and 0 deletions.
82 changes: 82 additions & 0 deletions monai/data/wsi_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import numpy as np
import torch
import cv2

from monai.config import DtypeLike, NdarrayOrTensor, PathLike
from monai.data.image_reader import ImageReader, _stack_images
Expand Down Expand Up @@ -940,6 +941,87 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]:

raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.")

def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array:
"""
Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution.
The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user.
If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp.
Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen.
The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value.
Args:
wsi: whole slide image object from WSIReader
mpp: the resolution in micron per pixel at which the representation of the whole slide image should be extracted.
atol: the acceptable absolute tolerance for resolution in micro per pixel.
rtol: the acceptable relative tolerance for resolution in micro per pixel.
"""

user_mpp_x, user_mpp_y = mpp
mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.level_count)]
closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # Should not throw ValueError, instead just return the closest value;
mpp_closest_lvl = mpp_list[closest_lvl]
closest_lvl_dim = wsi.level_dimensions[closest_lvl]

print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}')
mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl

# Define tolerance intervals for x and y of closest level
lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol
upper_bound_x = mpp_closest_lvl_x * (1 + rtol) + atol
lower_bound_y = mpp_closest_lvl_y * (1 - rtol) - atol
upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol

# Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level
within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x)
within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y)
within_tolerance = within_tolerance_x & within_tolerance_y

if within_tolerance:
# Take closest_level and continue with returning img at level
print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.')
closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3]

return closest_lvl_wsi
else:
# If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp
closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x
closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y
closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y

if closest_level_is_bigger:
ds_factor_x = mpp_closest_lvl_x / user_mpp_x
ds_factor_y = mpp_closest_lvl_y / user_mpp_y

closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3]

target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x))
target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y))

closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR)

print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}')
return closest_lvl_wsi
else:
# Else: increase resolution (ie, decrement level) and then downsample
closest_lvl = closest_lvl - 1
mpp_closest_lvl = mpp_list[closest_lvl] # Update MPP
mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl

ds_factor_x = mpp_closest_lvl_x / user_mpp_x
ds_factor_y = mpp_closest_lvl_y / user_mpp_y

closest_lvl_dim = wsi.level_dimensions[closest_lvl]
closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3]

target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x))
target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y))

closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR)

print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}')
return closest_lvl_wsi

def get_power(self, wsi, level: int) -> float:
"""
Returns the objective power of the whole slide image at a given level.
Expand Down

0 comments on commit f3e7d03

Please sign in to comment.