diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index b31d4d9c3a..f3f099160f 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -19,6 +19,7 @@ import numpy as np import torch +import cv2 from monai.config import DtypeLike, NdarrayOrTensor, PathLike from monai.data.image_reader import ImageReader, _stack_images @@ -940,6 +941,87 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") + def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + """ + Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. + The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. + If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. + The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. + + Args: + wsi: whole slide image object from WSIReader + mpp: the resolution in micron per pixel at which the representation of the whole slide image should be extracted. + atol: the acceptable absolute tolerance for resolution in micro per pixel. + rtol: the acceptable relative tolerance for resolution in micro per pixel. + + """ + + user_mpp_x, user_mpp_y = mpp + mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.level_count)] + closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # Should not throw ValueError, instead just return the closest value; + mpp_closest_lvl = mpp_list[closest_lvl] + closest_lvl_dim = wsi.level_dimensions[closest_lvl] + + print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}') + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + # Define tolerance intervals for x and y of closest level + lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol + upper_bound_x = mpp_closest_lvl_x * (1 + rtol) + atol + lower_bound_y = mpp_closest_lvl_y * (1 - rtol) - atol + upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol + + # Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level + within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) + within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) + within_tolerance = within_tolerance_x & within_tolerance_y + + if within_tolerance: + # Take closest_level and continue with returning img at level + print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') + closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3] + + return closest_lvl_wsi + else: + # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp + closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x + closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y + closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y + + if closest_level_is_bigger: + ds_factor_x = mpp_closest_lvl_x / user_mpp_x + ds_factor_y = mpp_closest_lvl_y / user_mpp_y + + closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3] + + target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + + closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR) + + print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') + return closest_lvl_wsi + else: + # Else: increase resolution (ie, decrement level) and then downsample + closest_lvl = closest_lvl - 1 + mpp_closest_lvl = mpp_list[closest_lvl] # Update MPP + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + ds_factor_x = mpp_closest_lvl_x / user_mpp_x + ds_factor_y = mpp_closest_lvl_y / user_mpp_y + + closest_lvl_dim = wsi.level_dimensions[closest_lvl] + closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3] + + target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + + closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR) + + print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') + return closest_lvl_wsi + def get_power(self, wsi, level: int) -> float: """ Returns the objective power of the whole slide image at a given level.