diff --git a/python/cucim/src/cucim/skimage/_shared/tests/test_utils.py b/python/cucim/src/cucim/skimage/_shared/tests/test_utils.py index da0824bb4..f7135a0e3 100644 --- a/python/cucim/src/cucim/skimage/_shared/tests/test_utils.py +++ b/python/cucim/src/cucim/skimage/_shared/tests/test_utils.py @@ -157,7 +157,8 @@ def test_validate_interpolation_order(dtype, order): ) def test_supported_float_dtype_real(dtype): float_dtype = _supported_float_type(dtype) - if dtype in [np.float16, np.float32]: + if dtype in [np.float16, np.float32, np.int8, np.uint8, np.int16, + np.uint16, bool]: assert float_dtype == np.float32 else: assert float_dtype == np.float64 @@ -188,7 +189,9 @@ def test_supported_float_dtype_input_kinds(dtype): 'dtypes, expected', [ ((np.float16, np.float64), np.float64), - ([np.float32, np.uint16, np.int8], np.float64), + ([np.float32, np.uint16, np.int8], np.float32), + ([np.float32, bool], np.float32), + ([np.float32, np.uint32, np.int16], np.float64), ({np.float32, np.float16}, np.float32), ] ) diff --git a/python/cucim/src/cucim/skimage/_shared/utils.py b/python/cucim/src/cucim/skimage/_shared/utils.py index 1ee6cd600..6798628c7 100644 --- a/python/cucim/src/cucim/skimage/_shared/utils.py +++ b/python/cucim/src/cucim/skimage/_shared/utils.py @@ -654,9 +654,10 @@ def convert_to_float(image, preserve_range): # Convert image to double only if it is not single or double # precision float if image.dtype.char not in 'df': - image = image.astype(float) + image = image.astype(_supported_float_type(image.dtype)) else: from ..util.dtype import img_as_float + image = img_as_float(image) return image @@ -730,14 +731,21 @@ def _fix_ndimage_mode(mode): new_float_type = { # preserved types - cp.float32().dtype.char: cp.float32, - cp.float64().dtype.char: cp.float64, - cp.complex64().dtype.char: cp.complex64, - cp.complex128().dtype.char: cp.complex128, - # altered types - cp.float16().dtype.char: cp.float32, - 'g': cp.float64, # cp.float128 ; doesn't exist on windows - 'G': cp.complex128, # cp.complex256 ; doesn't exist on windows + 'f': cp.float32, # float32 + 'd': cp.float64, # float64 + 'F': cp.complex64, # complex64 + 'D': cp.complex128, # complex128 + # promoted float types + 'e': cp.float32, # float16 + # truncated float types + 'g': cp.float64, # float128 (doesn't exist on windows) + 'G': cp.complex128, # complex256 (doesn't exist on windows) + # integer types that can be exactly represented in float32 + 'b': cp.float32, # int8 + 'B': cp.float32, # uint8 + 'h': cp.float32, # int16 + 'H': cp.float32, # uint16 + '?': cp.float32, # bool } diff --git a/python/cucim/src/cucim/skimage/color/tests/test_colorconv.py b/python/cucim/src/cucim/skimage/color/tests/test_colorconv.py index 5b59bec25..0d50cf119 100644 --- a/python/cucim/src/cucim/skimage/color/tests/test_colorconv.py +++ b/python/cucim/src/cucim/skimage/color/tests/test_colorconv.py @@ -14,7 +14,8 @@ import cupy as cp import numpy as np import pytest -from cupy.testing import assert_array_almost_equal, assert_array_equal +from cupy.testing import (assert_allclose, assert_array_almost_equal, + assert_array_equal) from numpy.testing import assert_equal from skimage import data @@ -236,7 +237,7 @@ def test_xyz_rgb_roundtrip(self, channel_axis): round_trip = xyz2rgb(rgb2xyz(img_rgb, channel_axis=channel_axis), channel_axis=channel_axis) - assert_array_almost_equal(round_trip, img_rgb) + assert_allclose(round_trip, img_rgb, rtol=1e-5, atol=1e-5) # RGB<->HED roundtrip with ubyte image def test_hed_rgb_roundtrip(self): @@ -497,12 +498,15 @@ def test_rgb2lab_brucelindbloom(self): def test_lab_rgb_roundtrip(self, channel_axis): img_rgb = img_as_float(self.img_rgb) img_rgb = cp.moveaxis(img_rgb, source=-1, destination=channel_axis) - assert_array_almost_equal( + + assert_allclose( lab2rgb( rgb2lab(img_rgb, channel_axis=channel_axis), channel_axis=channel_axis ), img_rgb, + rtol=1e-5, + atol=1e-5, ) def test_rgb2lab_dtype(self): @@ -633,12 +637,14 @@ def test_luv2rgb_dtype(self): def test_luv_rgb_roundtrip(self, channel_axis): img_rgb = img_as_float(self.img_rgb) img_rgb = cp.moveaxis(img_rgb, source=-1, destination=channel_axis) - assert_array_almost_equal( + assert_allclose( luv2rgb( rgb2luv(img_rgb, channel_axis=channel_axis), channel_axis=channel_axis ), img_rgb, + rtol=1e-4, + atol=1e-4, ) def test_lab_rgb_outlier(self): @@ -674,7 +680,7 @@ def test_lab_lch_roundtrip(self, channel_axis): lab2lch(lab, channel_axis=channel_axis), channel_axis=channel_axis, ) - assert_array_almost_equal(lab2, lab) + assert_allclose(lab2, lab, rtol=1e-4, atol=1e-4) def test_rgb_lch_roundtrip(self): rgb = img_as_float(self.img_rgb) @@ -682,7 +688,7 @@ def test_rgb_lch_roundtrip(self): lch = lab2lch(lab) lab2 = lch2lab(lch) rgb2 = lab2rgb(lab2) - assert_array_almost_equal(rgb, rgb2) + assert_allclose(rgb, rgb2, rtol=1e-4, atol=1e-4) def test_lab_lch_0d(self): lab0 = self._get_lab0() @@ -736,26 +742,41 @@ def test_yuv(self): def test_yuv_roundtrip(self, channel_axis): img_rgb = img_as_float(self.img_rgb)[::16, ::16] img_rgb = cp.moveaxis(img_rgb, source=-1, destination=channel_axis) - assert_array_almost_equal( + assert_allclose( yuv2rgb(rgb2yuv(img_rgb, channel_axis=channel_axis), channel_axis=channel_axis), - img_rgb) - assert_array_almost_equal( + img_rgb, + rtol=1e-5, + atol=1e-5, + ) + assert_allclose( yiq2rgb(rgb2yiq(img_rgb, channel_axis=channel_axis), channel_axis=channel_axis), - img_rgb) - assert_array_almost_equal( + img_rgb, + rtol=1e-5, + atol=1e-5, + ) + assert_allclose( ypbpr2rgb(rgb2ypbpr(img_rgb, channel_axis=channel_axis), channel_axis=channel_axis), - img_rgb) - assert_array_almost_equal( + img_rgb, + rtol=1e-5, + atol=1e-5, + ) + assert_allclose( ycbcr2rgb(rgb2ycbcr(img_rgb, channel_axis=channel_axis), channel_axis=channel_axis), - img_rgb) - assert_array_almost_equal( + img_rgb, + rtol=1e-5, + atol=1e-5, + ) + assert_allclose( ydbdr2rgb(rgb2ydbdr(img_rgb, channel_axis=channel_axis), channel_axis=channel_axis), - img_rgb) + img_rgb, + rtol=1e-5, + atol=1e-5, + ) def test_rgb2yuv_dtype(self): img = self.colbars_array.astype('float64') diff --git a/python/cucim/src/cucim/skimage/exposure/exposure.py b/python/cucim/src/cucim/skimage/exposure/exposure.py index 807fbb939..f7c0eeb4f 100644 --- a/python/cucim/src/cucim/skimage/exposure/exposure.py +++ b/python/cucim/src/cucim/skimage/exposure/exposure.py @@ -440,7 +440,7 @@ def intensity_range(image, range_values="image", clip_negative=False): return i_min, i_max -def _output_dtype(dtype_or_range): +def _output_dtype(dtype_or_range, image_dtype): """Determine the output dtype for rescale_intensity. The dtype is determined according to the following rules: @@ -450,13 +450,16 @@ def _output_dtype(dtype_or_range): in which case the data type that can contain it will be used (e.g. uint16 in this case). - if ``dtype_or_range`` is a pair of values, the output data type will be - float. + ``_supported_float_type(image_dtype)``. This preserves float32 output for + float32 inputs. Parameters ---------- dtype_or_range : type, string, or 2-tuple of int/float The desired range for the output, expressed as either a NumPy dtype or as a (min, max) pair of numbers. + image_dtype : np.dtype + The input image dtype. Returns ------- @@ -465,7 +468,7 @@ def _output_dtype(dtype_or_range): """ if type(dtype_or_range) in [list, tuple, np.ndarray]: # pair of values: always return float. - return float + return utils._supported_float_type(image_dtype) if type(dtype_or_range) == type: # already a type: return it return dtype_or_range @@ -577,9 +580,9 @@ def rescale_intensity(image, in_range="image", out_range="dtype"): array([127, 127, 127], dtype=int32) """ if out_range in ['dtype', 'image']: - out_dtype = _output_dtype(image.dtype.type) + out_dtype = _output_dtype(image.dtype.type, image.dtype) else: - out_dtype = _output_dtype(out_range) + out_dtype = _output_dtype(out_range, image.dtype) imin, imax = map(float, intensity_range(image, in_range)) omin, omax = map(float, intensity_range(image, out_range, diff --git a/python/cucim/src/cucim/skimage/exposure/tests/test_exposure.py b/python/cucim/src/cucim/skimage/exposure/tests/test_exposure.py index be917f2f7..f27c8fff7 100644 --- a/python/cucim/src/cucim/skimage/exposure/tests/test_exposure.py +++ b/python/cucim/src/cucim/skimage/exposure/tests/test_exposure.py @@ -387,7 +387,7 @@ def test_rescale_float_output(): image = cp.array([-128, 0, 127], dtype=cp.int8) output_image = exposure.rescale_intensity(image, out_range=(0, 255)) cp.testing.assert_array_equal(output_image, [0, 128, 255]) - assert output_image.dtype == float + assert output_image.dtype == _supported_float_type(image.dtype) def test_rescale_raises_on_incorrect_out_range(): diff --git a/python/cucim/src/cucim/skimage/feature/corner.py b/python/cucim/src/cucim/skimage/feature/corner.py index 48c9b3764..8e174eb2d 100644 --- a/python/cucim/src/cucim/skimage/feature/corner.py +++ b/python/cucim/src/cucim/skimage/feature/corner.py @@ -140,8 +140,13 @@ def structure_tensor(image, sigma=1, mode="constant", cval=0, order=None): if order == "xy": derivatives = reversed(derivatives) + # Autodetection as done internally to Gaussian, but set it here to silence + # a warning. + channel_axis = -1 if (image.ndim == 3 and image.shape[-1] == 3) else None + # structure tensor - A_elems = [gaussian(der0 * der1, sigma, mode=mode, cval=cval) + A_elems = [gaussian(der0 * der1, sigma, mode=mode, cval=cval, + channel_axis=channel_axis) for der0, der1 in combinations_with_replacement(derivatives, 2)] return A_elems @@ -205,7 +210,12 @@ def hessian_matrix(image, sigma=1, mode='constant', cval=0, order='rc'): float_dtype = _supported_float_type(image.dtype) image = image.astype(float_dtype, copy=False) - gaussian_filtered = gaussian(image, sigma=sigma, mode=mode, cval=cval) + # Autodetection as done internally to Gaussian, but set it here to silence + # a warning. + channel_axis = -1 if (image.ndim == 3 and image.shape[-1] == 3) else None + + gaussian_filtered = gaussian(image, sigma=sigma, mode=mode, cval=cval, + channel_axis=channel_axis) gradients = cp.gradient(gaussian_filtered) axes = range(image.ndim) diff --git a/python/cucim/src/cucim/skimage/feature/template.py b/python/cucim/src/cucim/skimage/feature/template.py index cbf0c7811..ce3443c27 100644 --- a/python/cucim/src/cucim/skimage/feature/template.py +++ b/python/cucim/src/cucim/skimage/feature/template.py @@ -134,7 +134,9 @@ def match_template(image, template, pad_input=False, mode='constant', image_shape = image.shape float_dtype = _supported_float_type(image.dtype) - image = image.astype(float_dtype, copy=False) + + # Note: keep image in float64 for accuracy of cumsum operations, etc. + image = image.astype(cp.float64, copy=False) template = template.astype(float_dtype, copy=False) pad_width = tuple((width, width) for width in template.shape) @@ -153,11 +155,12 @@ def match_template(image, template, pad_input=False, mode='constant', image_window_sum = _window_sum_3d(image, template.shape) image_window_sum2 = _window_sum_3d(image * image, template.shape) - template_mean = template.mean() + # perform mean and sum in float64 for accuracy + template_mean = template.mean(dtype=cp.float64) template_volume = _misc.prod(template.shape) template_ssd = template - template_mean template_ssd *= template_ssd - template_ssd = cp.sum(template_ssd) + template_ssd = cp.sum(template_ssd, dtype=cp.float64) if image.ndim == 2: xcorr = signal.fftconvolve(image, template[::-1, ::-1], diff --git a/python/cucim/src/cucim/skimage/feature/tests/test_template.py b/python/cucim/src/cucim/skimage/feature/tests/test_template.py index 61effa978..4b9050927 100644 --- a/python/cucim/src/cucim/skimage/feature/tests/test_template.py +++ b/python/cucim/src/cucim/skimage/feature/tests/test_template.py @@ -26,7 +26,7 @@ def test_template(dtype): target = cp.asarray(target) result = match_template(image, target) - assert result.dtype == dtype + assert result.dtype == result.dtype delta = 5 positions = peak_local_max(result, min_distance=delta) @@ -200,6 +200,5 @@ def test_bounding_values(): template = cp.zeros((3, 3)) template[1, 1] = 1 result = match_template(image, template) - print(result.max()) assert result.max() < 1 + 1e-7 assert result.min() > -1 - 1e-7 diff --git a/python/cucim/src/cucim/skimage/filters/_fft_based.py b/python/cucim/src/cucim/skimage/filters/_fft_based.py index f659f2c55..905edbb72 100644 --- a/python/cucim/src/cucim/skimage/filters/_fft_based.py +++ b/python/cucim/src/cucim/skimage/filters/_fft_based.py @@ -123,6 +123,7 @@ def butterworth( else np.delete(image.shape, channel_axis)) is_real = cp.isrealobj(image) float_dtype = _supported_float_type(image.dtype, allow_complex=True) + image = image.astype(float_dtype, copy=False) wfilt = _get_ND_butterworth_filter( fft_shape, cutoff_frequency_ratio, order, high_pass, is_real, float_dtype diff --git a/python/cucim/src/cucim/skimage/filters/ridges.py b/python/cucim/src/cucim/skimage/filters/ridges.py index b635711ed..1ce7d4a91 100644 --- a/python/cucim/src/cucim/skimage/filters/ridges.py +++ b/python/cucim/src/cucim/skimage/filters/ridges.py @@ -234,13 +234,13 @@ def meijering(image, sigmas=range(1, 10, 2), alpha=None, if alpha is None: alpha = 1.0 / ndim - float_dtype = _supported_float_type(image.dtype) - image = image.astype(float_dtype, copy=False) - # Invert image to detect dark ridges on bright background if black_ridges: image = invert(image) + float_dtype = _supported_float_type(image.dtype) + image = image.astype(float_dtype, copy=False) + # Generate empty (n+1)D arrays for storing auxiliary images filtered at # different (sigma) scales filtered_array = cp.empty(sigmas.shape + image.shape, dtype=float_dtype) diff --git a/python/cucim/src/cucim/skimage/filters/tests/test_fft_based.py b/python/cucim/src/cucim/skimage/filters/tests/test_fft_based.py index 457e54276..9dcaa43dc 100644 --- a/python/cucim/src/cucim/skimage/filters/tests/test_fft_based.py +++ b/python/cucim/src/cucim/skimage/filters/tests/test_fft_based.py @@ -108,7 +108,7 @@ def test_butterworth_4D_channel(chan, dtype): def test_butterworth_correctness_bw(): - small = cp.array(coins()[180:190, 260:270]) + small = cp.array(coins()[180:190, 260:270], dtype=float) filtered = butterworth(small, cutoff_frequency_ratio=0.2) correct = cp.array( @@ -129,7 +129,7 @@ def test_butterworth_correctness_bw(): def test_butterworth_correctness_rgb(): - small = cp.array(astronaut()[135:145, 205:215]) + small = cp.array(astronaut()[135:145, 205:215], dtype=float) filtered = butterworth(small, cutoff_frequency_ratio=0.3, high_pass=True, diff --git a/python/cucim/src/cucim/skimage/filters/tests/test_ridges.py b/python/cucim/src/cucim/skimage/filters/tests/test_ridges.py index 3c328f5d2..13b944f72 100644 --- a/python/cucim/src/cucim/skimage/filters/tests/test_ridges.py +++ b/python/cucim/src/cucim/skimage/filters/tests/test_ridges.py @@ -4,7 +4,7 @@ from cupy.testing import assert_allclose, assert_array_equal, assert_array_less from skimage.data import camera, retina -from cucim.skimage import img_as_float +from cucim.skimage import img_as_float, img_as_float64 from cucim.skimage._shared.utils import _supported_float_type from cucim.skimage.color import rgb2gray from cucim.skimage.filters import frangi, hessian, meijering, sato @@ -178,9 +178,12 @@ def test_3d_linearity(): atol=1e-3) -def test_2d_cropped_camera_image(): - +@pytest.mark.parametrize('dtype', ['float64', 'uint8']) +def test_2d_cropped_camera_image(dtype): a_black = crop(cp.array(camera()), ((200, 212), (100, 312))) + assert a_black.dtype == cp.uint8 + if dtype == 'float64': + a_black = img_as_float64(a_black) a_white = invert(a_black) zeros = cp.zeros((100, 100)) @@ -208,9 +211,13 @@ def test_ridge_output_dtype(func, dtype): assert func(img).dtype == _supported_float_type(img.dtype) -def test_3d_cropped_camera_image(): +@pytest.mark.parametrize('dtype', ['float64', 'uint8']) +def test_3d_cropped_camera_image(dtype): a_black = crop(cp.asarray(camera()), ((200, 212), (100, 312))) + assert a_black.dtype == cp.uint8 + if dtype == 'float64': + a_black = img_as_float64(a_black) a_black = cp.dstack([a_black, a_black, a_black]) a_white = invert(a_black) diff --git a/python/cucim/src/cucim/skimage/measure/tests/test_regionprops.py b/python/cucim/src/cucim/skimage/measure/tests/test_regionprops.py index c9bb3bd1d..09a36cfac 100644 --- a/python/cucim/src/cucim/skimage/measure/tests/test_regionprops.py +++ b/python/cucim/src/cucim/skimage/measure/tests/test_regionprops.py @@ -157,14 +157,14 @@ def test_area_bbox(): def test_moments_central(): mu = regionprops(SAMPLE)[0].moments_central # determined with OpenCV - assert_almost_equal(mu[2, 0], 436.00000000000045) + assert_almost_equal(mu[2, 0], 436.00000000000045, decimal=4) # different from OpenCV results, bug in OpenCV - assert_almost_equal(mu[3, 0], -737.333333333333) - assert_almost_equal(mu[1, 1], -87.33333333333303) - assert_almost_equal(mu[2, 1], -127.5555555555593) - assert_almost_equal(mu[0, 2], 1259.7777777777774) - assert_almost_equal(mu[1, 2], 2000.296296296291) - assert_almost_equal(mu[0, 3], -760.0246913580195) + assert_almost_equal(mu[3, 0], -737.333333333333, decimal=3) + assert_almost_equal(mu[1, 1], -87.33333333333303, decimal=3) + assert_almost_equal(mu[2, 1], -127.5555555555593, decimal=3) + assert_almost_equal(mu[0, 2], 1259.7777777777774, decimal=2) + assert_almost_equal(mu[1, 2], 2000.296296296291, decimal=2) + assert_almost_equal(mu[0, 3], -760.0246913580195, decimal=2) def test_centroid(): @@ -330,7 +330,7 @@ def test_axis_major_length(): length = regionprops(SAMPLE)[0].axis_major_length # MATLAB has different interpretation of ellipse than found in literature, # here implemented as found in literature - assert_almost_equal(length, 16.7924234999) + assert_almost_equal(length, 16.7924234999, decimal=5) def test_intensity_max(): @@ -355,7 +355,7 @@ def test_axis_minor_length(): length = regionprops(SAMPLE)[0].axis_minor_length # MATLAB has different interpretation of ellipse than found in literature, # here implemented as found in literature - assert_almost_equal(length, 9.739302807263) + assert_almost_equal(length, 9.739302807263, decimal=5) def test_moments(): diff --git a/python/cucim/src/cucim/skimage/metrics/tests/test_structural_similarity.py b/python/cucim/src/cucim/skimage/metrics/tests/test_structural_similarity.py index 60281038d..46532c6c4 100644 --- a/python/cucim/src/cucim/skimage/metrics/tests/test_structural_similarity.py +++ b/python/cucim/src/cucim/skimage/metrics/tests/test_structural_similarity.py @@ -30,7 +30,7 @@ def test_structural_similarity_patch_range(): Y = (rstate.rand(N, N) * 255).astype(cp.uint8) assert structural_similarity(X, Y, win_size=N) < 0.1 - assert_equal(structural_similarity(X, X, win_size=N), 1) + assert_almost_equal(structural_similarity(X, X, win_size=N), 1) def test_structural_similarity_image(): @@ -40,7 +40,7 @@ def test_structural_similarity_image(): Y = (rstate.rand(N, N) * 255).astype(cp.uint8) S0 = structural_similarity(X, X, win_size=3) - assert_equal(S0, 1) + assert_almost_equal(S0, 1) S1 = structural_similarity(X, Y, win_size=3) assert S1 < 0.3 @@ -51,10 +51,10 @@ def test_structural_similarity_image(): mssim0, S3 = structural_similarity(X, Y, full=True) assert_equal(S3.shape, X.shape) mssim = structural_similarity(X, Y) - assert_equal(mssim0, mssim) + assert_almost_equal(mssim0, mssim) # structural_similarity of image with itself should be 1.0 - assert_equal(structural_similarity(X, X), 1.0) + assert_almost_equal(structural_similarity(X, X), 1.0) # Because we are forcing a random seed state, it is probably good to test @@ -223,7 +223,15 @@ def test_gaussian_structural_similarity_vs_IPOL(): def test_mssim_vs_legacy(): # check that ssim with default options matches skimage 0.11 result mssim_skimage_0pt17 = 0.3674518327910367 + + # uint8 will be computed in float32 precision mssim = structural_similarity(cam, cam_noisy) + assert_almost_equal(mssim, mssim_skimage_0pt17, decimal=4) + + # also check with double precision and explicit specification of data_range + mssim = structural_similarity(cam.astype(float), + cam_noisy.astype(float), + data_range=255) assert_almost_equal(mssim, mssim_skimage_0pt17) diff --git a/python/cucim/src/cucim/skimage/restoration/tests/test_denoise.py b/python/cucim/src/cucim/skimage/restoration/tests/test_denoise.py index 965ea13e4..6020dac41 100644 --- a/python/cucim/src/cucim/skimage/restoration/tests/test_denoise.py +++ b/python/cucim/src/cucim/skimage/restoration/tests/test_denoise.py @@ -112,7 +112,7 @@ def test_denoise_tv_chambolle_float_result_range(): denoised_int_astro = restoration.denoise_tv_chambolle(int_astro, weight=0.1) # test if the value range of output float data is within [0.0:1.0] - assert denoised_int_astro.dtype == float + assert denoised_int_astro.dtype == _supported_float_type(int_astro.dtype) assert cp.max(denoised_int_astro) <= 1.0 assert cp.min(denoised_int_astro) >= 0.0 @@ -126,8 +126,9 @@ def test_denoise_tv_chambolle_3d(): mask += 20 * cp.random.rand(*mask.shape) mask[mask < 0] = 0 mask[mask > 255] = 255 - res = restoration.denoise_tv_chambolle(mask.astype(np.uint8), weight=0.1) - assert res.dtype == float + mask = mask.astype(np.uint8) + res = restoration.denoise_tv_chambolle(mask, weight=0.1) + assert res.dtype == _supported_float_type(mask.dtype) assert res.std() * 255 < mask.std() @@ -136,16 +137,18 @@ def test_denoise_tv_chambolle_1d(): x = 125 + 100 * cp.sin(cp.linspace(0, 8 * cp.pi, 1000)) x += 20 * cp.random.rand(x.size) x = cp.clip(x, 0, 255) - res = restoration.denoise_tv_chambolle(x.astype(np.uint8), weight=0.1) - assert res.dtype == float + x = x.astype(np.uint8) + res = restoration.denoise_tv_chambolle(x, weight=0.1) + assert res.dtype == _supported_float_type(x.dtype) assert res.std() * 255 < x.std() def test_denoise_tv_chambolle_4d(): """ TV denoising for a 4D input.""" im = 255 * cp.random.rand(8, 8, 8, 8) - res = restoration.denoise_tv_chambolle(im.astype(np.uint8), weight=0.1) - assert res.dtype == float + im = im.astype(np.uint8) + res = restoration.denoise_tv_chambolle(im, weight=0.1) + assert res.dtype == _supported_float_type(im.dtype) assert res.std() * 255 < im.std() diff --git a/python/cucim/src/cucim/skimage/transform/tests/test_warps.py b/python/cucim/src/cucim/skimage/transform/tests/test_warps.py index e8a1f6219..6504c082d 100644 --- a/python/cucim/src/cucim/skimage/transform/tests/test_warps.py +++ b/python/cucim/src/cucim/skimage/transform/tests/test_warps.py @@ -380,8 +380,8 @@ def test_resize_dtype(): assert resize(x, (10, 10), preserve_range=False).dtype == x.dtype assert resize(x, (10, 10), preserve_range=True).dtype == x.dtype - assert resize(x_u8, (10, 10), preserve_range=False).dtype == cp.double - assert resize(x_u8, (10, 10), preserve_range=True).dtype == cp.double + assert resize(x_u8, (10, 10), preserve_range=False).dtype == cp.float32 + assert resize(x_u8, (10, 10), preserve_range=True).dtype == cp.float32 assert resize(x_b, (10, 10), preserve_range=False).dtype == bool assert resize(x_b, (10, 10), preserve_range=True).dtype == bool assert resize(x_f32, (10, 10), preserve_range=False).dtype == x_f32.dtype @@ -917,13 +917,13 @@ def test_resize_local_mean_dtype(): assert resize_local_mean(x, (10, 10), preserve_range=True).dtype == x.dtype assert resize_local_mean(x_u8, (10, 10), - preserve_range=False).dtype == cp.double + preserve_range=False).dtype == cp.float32 assert resize_local_mean(x_u8, (10, 10), - preserve_range=True).dtype == cp.double + preserve_range=True).dtype == cp.float32 assert resize_local_mean(x_b, (10, 10), - preserve_range=False).dtype == cp.double + preserve_range=False).dtype == cp.float32 assert resize_local_mean(x_b, (10, 10), - preserve_range=True).dtype == cp.double + preserve_range=True).dtype == cp.float32 assert resize_local_mean(x_f32, (10, 10), preserve_range=False).dtype == x_f32.dtype assert resize_local_mean(x_f32, (10, 10), diff --git a/python/cucim/src/cucim/skimage/util/dtype.py b/python/cucim/src/cucim/skimage/util/dtype.py index 6390f369e..027a49882 100644 --- a/python/cucim/src/cucim/skimage/util/dtype.py +++ b/python/cucim/src/cucim/skimage/util/dtype.py @@ -3,6 +3,8 @@ import cupy as cp +from .._shared.utils import _supported_float_type + __all__ = ['img_as_float32', 'img_as_float64', 'img_as_float', 'img_as_int', 'img_as_uint', 'img_as_ubyte', 'img_as_bool', 'dtype_limits'] @@ -464,7 +466,9 @@ def img_as_float(image, force_copy=False): and can be outside the ranges [0.0, 1.0] or [-1.0, 1.0]. """ - return _convert(image, cp.floating, force_copy) + # casts float16, float32 and 8 or 16-bit integer types to float32 + float_dtype = _supported_float_type(image.dtype) + return _convert(image, float_dtype, force_copy) def img_as_uint(image, force_copy=False):