diff --git a/examples/add_labels.py b/examples/add_labels.py index a58180e2ca3..25bc395de00 100644 --- a/examples/add_labels.py +++ b/examples/add_labels.py @@ -7,7 +7,6 @@ .. tags:: layers, visualization-basic """ - from skimage import data from skimage.filters import threshold_otsu from skimage.measure import label @@ -26,7 +25,7 @@ cleared = remove_small_objects(clear_border(bw), 20) # label image regions -label_image = label(cleared) +label_image = label(cleared).astype("uint8") # initialise viewer with coins image viewer = napari.view_image(image, name='coins', rgb=False) diff --git a/napari/layers/labels/_tests/test_labels.py b/napari/layers/labels/_tests/test_labels.py index 3d0f75443e2..e6db0a74e7d 100644 --- a/napari/layers/labels/_tests/test_labels.py +++ b/napari/layers/labels/_tests/test_labels.py @@ -23,6 +23,7 @@ from napari.layers.labels._labels_utils import get_contours from napari.utils import Colormap from napari.utils.colormaps import label_colormap +from napari.utils.colormaps.colormap import DirectLabelColormap def test_random_labels(): @@ -1685,6 +1686,35 @@ def test_copy(): assert l1.data is l3.data +@pytest.mark.parametrize( + "colormap,expected", + [ + (label_colormap(49, 0.5), [0, 1]), + ( + DirectLabelColormap( + color_dict={ + 0: np.array([0, 0, 0, 0]), + 1: np.array([1, 0, 0, 1]), + None: np.array([1, 1, 0, 1]), + } + ), + [1, 2], + ), + ], + ids=["auto", "direct"], +) +def test_draw(colormap, expected): + labels = Labels(np.zeros((30, 30), dtype=np.uint32)) + labels.mode = "paint" + labels.colormap = colormap + labels.selected_label = 1 + npt.assert_array_equal(np.unique(labels._slice.image.raw), [0]) + npt.assert_array_equal(np.unique(labels._slice.image.view), expected[:1]) + labels._draw(1, (15, 15), (15, 15)) + npt.assert_array_equal(np.unique(labels._slice.image.raw), [0, 1]) + npt.assert_array_equal(np.unique(labels._slice.image.view), expected) + + class TestLabels: @staticmethod def get_objects(): diff --git a/napari/layers/labels/labels.py b/napari/layers/labels/labels.py index 2df8b8c5d55..c91ca66cf5b 100644 --- a/napari/layers/labels/labels.py +++ b/napari/layers/labels/labels.py @@ -992,9 +992,22 @@ def _setup_cache(self, labels): if self._cached_labels is not None: return + if isinstance(self._colormap, LabelColormap): + mapped_background = _cast_labels_data_to_texture_dtype_auto( + labels.dtype.type(self.colormap.background_value), + self._random_colormap, + ) + else: # direct + mapped_background = _cast_labels_data_to_texture_dtype_direct( + labels.dtype.type(self.colormap.background_value), + self._direct_colormap, + ) + self._cached_labels = np.zeros_like(labels) - self._cached_mapped_labels = np.zeros_like( - labels, dtype=self._get_cache_dtype(labels.dtype) + self._cached_mapped_labels = np.full( + shape=labels.shape, + fill_value=mapped_background, + dtype=self._get_cache_dtype(labels.dtype), ) def _raw_to_displayed(