diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 2a4fe9f7a8..57df016140 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -431,6 +431,63 @@ def get_data( metadata[key] = [m[key] for m in metadata_list] return _stack_images(patch_list, metadata), metadata + def _compute_mpp_target_res(self, closest_lvl, closest_lvl_dim, mpp_list, mpp: tuple): + """ + Computes the target dimensions for resizing a whole slide image + to match a user-specified resolution in microns per pixel (MPP). + + Args: + closest_lvl: Whole slide image level closest to user-provided MPP resolution. + closest_lvl_dim: Dimensions (height, width) of the image at the closest level. + mpp_list: List of MPP values for all levels of the whole slide image. + mpp: The MPP resolution at which the whole slide image representation should be extracted. + + """ + mpp_closest_lvl = mpp_list[closest_lvl] + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + ds_factor_x = mpp_closest_lvl_x / mpp[0] + ds_factor_y = mpp_closest_lvl_y / mpp[1] + + target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) + + return target_res_x, target_res_y + + def _compute_mpp_tolerances(self, closest_lvl, mpp_list, mpp, atol, rtol) -> bool: + """ + Determines if user-provided MPP values are within a specified tolerance of the closest + level's MPP and checks if the closest level has higher resolution than desired MPP. + + Args: + closest_lvl: Whole slide image level closest to user-provided MPP resolution. + mpp_list: List of MPP values for all levels of the whole slide image. + mpp: The MPP resolution at which the whole slide image representation should be extracted. + atol: Absolute tolerance for MPP comparison. + rtol: Relative tolerance for MPP comparison. + + """ + user_mpp_x, user_mpp_y = mpp + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] + + # Define tolerance intervals for x and y of closest level + lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol + upper_bound_x = mpp_closest_lvl_x * (1 + rtol) + atol + lower_bound_y = mpp_closest_lvl_y * (1 - rtol) - atol + upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol + + # Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level + is_within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) + is_within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) + is_within_tolerance = is_within_tolerance_x & is_within_tolerance_y + + # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp + closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x + closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y + closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y + + return is_within_tolerance, closest_level_is_bigger + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ Verify whether the specified file or files format is supported by WSI reader. @@ -603,6 +660,25 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: """ return self.reader.get_mpp(wsi, level) + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray: + """ + Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. + The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. + If the user-provided mpp is larger than the mpp of the closest level, + the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, + the next lower level (which has a higher resolution) is chosen. + The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. + + Args: + wsi: whole slide image object from WSIReader + mpp: the resolution in micron per pixel at which the representation of the whole slide image should be extracted. + atol: the acceptable absolute tolerance for resolution in micro per pixel. + rtol: the acceptable relative tolerance for resolution in micro per pixel. + + """ + return self.reader.get_wsi_at_mpp(wsi, mpp, atol, rtol) + def get_power(self, wsi, level: int) -> float: """ Returns the micro-per-pixel resolution of the whole slide image at a given level. @@ -744,6 +820,48 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> Any: + """ + Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. + The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. + If the user-provided mpp is larger than the mpp of the closest level, + the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, + the next lower level (which has a higher resolution) is chosen. + The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. + + Args: + wsi: whole slide image object from WSIReader + mpp: the resolution in micron per pixel at which the representation of the whole slide image should be extracted. + atol: the acceptable absolute tolerance for resolution in micro per pixel. + rtol: the acceptable relative tolerance for resolution in micro per pixel. + + """ + cp, _ = optional_import("cupy") + + mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.resolutions["level_count"])] + closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) + + within_tolerance, closest_level_is_bigger = super()._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) + + if within_tolerance: + # If the image at the desired mpp resolution is within tolerances, return the image at closest_level. + closest_lvl_wsi = wsi.read_region( + (0, 0), level=closest_lvl, size=wsi.resolutions["level_dimensions"][closest_lvl], num_workers=self.num_workers + ) + + elif closest_level_is_bigger: + # Otherwise, select the level closest to the desired mpp with a higher resolution and downsample it. + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + + else: + # If both checks fail, increase resolution (i.e., decrement level) and then downsample it. + closest_lvl = closest_lvl - 1 + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + + wsi_arr = cp.asnumpy(closest_lvl_wsi) + return wsi_arr + def get_power(self, wsi, level: int) -> float: """ Returns the objective power of the whole slide image at a given level. @@ -828,6 +946,29 @@ def _get_patch( return patch + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): + """ + Resizes the whole slide image to the specified resolution in microns per pixel (mpp). + + Args: + wsi: whole slide image object from WSIReader + user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. + closest_lvl: the wsi level that is closest to the user-provided mpp resolution. + mpp_list: list of mpp values for all levels of a whole slide image. + + """ + cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") + cp, _ = optional_import("cupy") + + closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] + + target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + + wsi_arr = cp.array(wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers)) + closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=1) + + return closest_lvl_wsi + @require_pkg(pkg_name="openslide") class OpenSlideWSIReader(BaseWSIReader): @@ -940,6 +1081,47 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray: + """ + Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. + The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. + If the user-provided mpp is larger than the mpp of the closest level, + the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, + the next lower level (which has a higher resolution) is chosen. + The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. + + Args: + wsi: whole slide image object from WSIReader + mpp: the resolution in micron per pixel at which the representation of the whole slide image should be extracted. + atol: the acceptable absolute tolerance for resolution in micro per pixel. + rtol: the acceptable relative tolerance for resolution in micro per pixel. + + """ + + mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.level_count)] + closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) + + within_tolerance, closest_level_is_bigger = super()._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) + + if within_tolerance: + # If the image at the desired mpp resolution is within tolerances, return the image at closest_level. + closest_lvl_wsi = wsi.read_region( + (0, 0), level=closest_lvl, size=wsi.level_dimensions[closest_lvl] + ) + + elif closest_level_is_bigger: + # Otherwise, select the level closest to the desired mpp with a higher resolution and downsample it. + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + + else: + # If both checks fail, increase resolution (i.e., decrement level) and then downsample it. + closest_lvl = closest_lvl - 1 + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + + wsi_arr = np.array(closest_lvl_wsi) + return wsi_arr + def get_power(self, wsi, level: int) -> float: """ Returns the objective power of the whole slide image at a given level. @@ -1010,6 +1192,28 @@ def _get_patch( return patch + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): + """ + Resizes the whole slide image to the specified resolution in microns per pixel (mpp). + + Args: + wsi: whole slide image object from WSIReader + user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. + closest_lvl: the wsi level that is closest to the user-provided mpp resolution. + mpp_list: list of mpp values for all levels of a whole slide image. + + """ + pil_image, _ = optional_import("PIL", name="Image") + + closest_lvl_dim = wsi.level_dimensions[closest_lvl] + + target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + + return closest_lvl_wsi + @require_pkg(pkg_name="tifffile") class TiffFileWSIReader(BaseWSIReader): @@ -1103,12 +1307,55 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: unit = "micrometer" convert_to_micron = ConvertUnits(unit, "micrometer") - # Here x and y resolutions are rational numbers so each of them is represented by a tuple. + + # Here, x and y resolutions are rational numbers so each of them is represented by a tuple. yres = wsi.pages[level].tags["YResolution"].value xres = wsi.pages[level].tags["XResolution"].value - return convert_to_micron(yres[1] / yres[0]), convert_to_micron(xres[1] / xres[0]) + if xres[0] & yres[0]: + return convert_to_micron(yres[1] / yres[0]), convert_to_micron(xres[1] / xres[0]) + else: + raise ValueError("The `XResolution` and/or `YResolution` property of the image is zero, " + "which is needed to obtain `mpp` for this file. Please use `level` instead.") + raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") + + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray: + """ + Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. + The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. + If the user-provided mpp is larger than the mpp of the closest level, + the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, + the next lower level (which has a higher resolution) is chosen. + The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. + + Args: + wsi: whole slide image object from WSIReader + mpp: the resolution in micron per pixel at which the representation of the whole slide image should be extracted. + atol: the acceptable absolute tolerance for resolution in micro per pixel. + rtol: the acceptable relative tolerance for resolution in micro per pixel. + + """ + + mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] # Fails for some Tifffiles + closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) + + within_tolerance, closest_level_is_bigger = super()._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) + + if within_tolerance: + # If the image at the desired mpp resolution is within tolerances, return the image at closest_level. + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=self.get_size(wsi, closest_lvl)) + + elif closest_level_is_bigger: + # Otherwise, select the level closest to the desired mpp with a higher resolution and downsample it. + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + + else: + # If both checks fail, increase resolution (i.e., decrement level) and then downsample it. + closest_lvl = closest_lvl - 1 + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) - raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") + wsi_arr = np.array(closest_lvl_wsi) + return wsi_arr def get_power(self, wsi, level: int) -> float: """ @@ -1154,7 +1401,7 @@ def _get_patch( Extracts and returns a patch image form the whole slide image. Args: - wsi: a whole slide image object loaded from a file or a lis of such objects + wsi: a whole slide image object loaded from a file or a list of such objects location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). size: (height, width) tuple giving the patch size at the given level (`level`). If None, it is set to the full image size at the given level. @@ -1186,3 +1433,25 @@ def _get_patch( patch = np.take(patch, [0, 1, 2], self.channel_dim) return patch + + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): + """ + Resizes the whole slide image to the specified resolution in microns per pixel (mpp). + + Args: + wsi: whole slide image object from WSIReader + user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. + closest_lvl: the wsi level that is closest to the user-provided mpp resolution. + mpp_list: list of mpp values for all levels of a whole slide image. + + """ + pil_image, _ = optional_import("PIL", name="Image") + + closest_lvl_dim = self.get_size(wsi, closest_lvl) + + target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + + closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + + return closest_lvl_wsi diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py index 66a5116c1a..768e3ed2bb 100644 --- a/monai/transforms/regularization/array.py +++ b/monai/transforms/regularization/array.py @@ -25,7 +25,6 @@ class Mixer(RandomizableTransform): - def __init__(self, batch_size: int, alpha: float = 1.0) -> None: """ Mixer is a base class providing the basic logic for the mixup-class of