From e18d2b8f3cb8315874dc35a8b21db9bf3b3ed1a2 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Tue, 19 Nov 2024 21:36:57 +0100 Subject: [PATCH 1/9] Add label stitching functionality --- elf/segmentation/stitching.py | 163 ++++++++++++++++++++++++++++++---- 1 file changed, 148 insertions(+), 15 deletions(-) diff --git a/elf/segmentation/stitching.py b/elf/segmentation/stitching.py index d2fe61c..e8a26b6 100644 --- a/elf/segmentation/stitching.py +++ b/elf/segmentation/stitching.py @@ -1,10 +1,12 @@ import multiprocessing from concurrent import futures +from typing import Tuple, Optional, Callable -import nifty.tools as nt -import numpy as np import vigra -from nifty.ground_truth import overlap +import numpy as np + +import nifty.tools as nt +from nifty.ground_truth import overlap as compute_overlap try: from napari.utils import progress as tqdm @@ -16,11 +18,17 @@ def stitch_segmentation( - input_, segmentation_function, - tile_shape, tile_overlap, beta=0.5, - shape=None, with_background=True, n_threads=None, - return_before_stitching=False, verbose=True, -): + input_: np.ndarray, + segmentation_function: Callable, + tile_shape: Tuple[int, int], + tile_overlap: Tuple[int, int], + beta: float = 0.5, + shape: Optional[Tuple[int, int]] = None, + with_background: bool = True, + n_threads: Optional[int] = None, + return_before_stitching: bool = False, + verbose: bool = True, +) -> np.ndarray: """Run segmentation function tilewise and stitch the results based on overlap. Arguments: @@ -28,7 +36,7 @@ def stitch_segmentation( e.g. XYC for a 2D image with channels. segmentation_function [callable] - the function to perform segmentation for each tile. Needs to be a segmentation that takes the input (for the tile) as well as the id of the tile as input. - I.e. the function needs to have a signature like this: 'def my_seg_func(tile_input_, tile_id)'. + i.e. the function needs to have a signature like this: 'def my_seg_func(tile_input_, tile_id)'. 
The tile_id is passed in case the segmentation routine differs depending on the tile; it can be ignored in most cases. tile_shape [tuple] - shape of the individual tiles. @@ -101,10 +109,12 @@ def _compute_overlaps(block_id): this_seg, ngb_seg = block_segs[block_id], block_segs[ngb_id] # get the global coordinates of the block face - face = tuple(slice(beg_out, end_out) if d != axis else slice(beg_out, beg_in + tile_overlap[d]) - for d, (beg_out, end_out, beg_in) in enumerate(zip(this_block.outerBlock.begin, - this_block.outerBlock.end, - this_block.innerBlock.begin))) + face = tuple( + slice(beg_out, end_out) if d != axis else slice(beg_out, beg_in + tile_overlap[d]) + for d, (beg_out, end_out, beg_in) in enumerate( + zip(this_block.outerBlock.begin, this_block.outerBlock.end, this_block.innerBlock.begin) + )) + # map to the two local face coordinates this_face_bb = tuple( slice(fa.start - offset, fa.stop - offset) for fa, offset in zip(face, this_block.outerBlock.begin) @@ -116,10 +126,10 @@ def _compute_overlaps(block_id): # load the two segmentations for the face this_face = this_seg[this_face_bb] ngb_face = ngb_seg[ngb_face_bb] - assert this_face.shape == ngb_face.shape + assert this_face.shape == ngb_face.shape, (this_face.shape, ngb_face.shape) # compute the object overlaps - overlap_comp = overlap(this_face, ngb_face) + overlap_comp = compute_overlap(this_face, ngb_face) this_ids = np.unique(this_face) overlaps = {this_id: overlap_comp.overlapArraysNormalized(this_id, sorted=False) for this_id in this_ids} overlap_ids = {this_id: ovlps[0] for this_id, ovlps in overlaps.items()} @@ -175,4 +185,127 @@ def _compute_overlaps(block_id): if return_before_stitching: return seg_stitched, seg + + return seg_stitched + + +def stitch_tiled_segmentation( + segmentation: np.ndarray, + tile_shape: Tuple[int, int], + overlap: int = 1, + n_threads: Optional[int] = None, + verbose: bool = True, +) -> np.ndarray: + """Functionality for stitching segmentations tile-wise based 
on overlap. + + Args: + segmentation: The input segmentation. + tile_shape: The shape of inidividual tiles. + overlap: The overlap of tiles. + It is responsible to compute the edge nodes for the desired overlap region. + n_threads: The number of threads used for parallelized operations. + Set to the number of cores by default. + verbose: Whether to print the progress bars. + + Returns: + The stitched segmentation with merged labels. + """ + shape = segmentation.shape + ndim = len(shape) + blocking = nt.blocking([0] * ndim, shape, tile_shape) + n_blocks = blocking.numberOfBlocks + + block_segs = [] + + # Get the tiles from the segmentation of shape: 'tile_shape'. + def _fetch_tiles(block_id): + block = blocking.getBlock(block_id) + bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end)) + block_seg = segmentation[bb] + block_segs.append(block_seg) + + n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads + with futures.ThreadPoolExecutor(n_threads) as tp: + list(tqdm(tp.map( + _fetch_tiles, range(n_blocks)), total=n_blocks, desc="Get tiles from the segmentation", disable=not verbose, + )) + + # Conpute the Region Adjacency Graph (RAG) for the tiled segmentation. + # and the edges between block boundaries (stitch edges). + seg_ids = np.unique(segmentation) + rag = compute_rag(segmentation) + + # We initialize the edge disaffinities with a high value (corresponding to a low overlap) + # so that merging things that are not on the edge is very unlikely + # but not completely impossible in case it is needed for a consistent solution. + edge_disaffinities = np.full(rag.numberOfEdges, 0.9, dtype="float32") + + def _compute_overlaps(block_id): + # For each axis, load the face with the lower block neighbor and compute the object overlaps + for axis in range(ndim): + ngb_id = blocking.getNeighborId(block_id, axis, lower=True) + if ngb_id == -1: + continue + + # Load the respective tiles. 
+ this_seg, ngb_seg = block_segs[block_id], block_segs[ngb_id] + + # Get the local face coordinates of the respective tiles. + # We get the face region of the shape defined by 'overlap' + # eg. The default '1' returns a 1d cross-section of the tile interfaces. + face_bb = tuple(slice(None) if d != axis else slice(0, overlap) for d in range(ndim)) + ngb_face_bb = tuple( + slice(None) if d != axis else slice(ngb_seg.shape[d] - overlap, ngb_seg.shape[d]) for d in range(ndim) + ) + + # Load the two segmentations for the face. + this_face = this_seg[face_bb] + ngb_face = ngb_seg[ngb_face_bb] + + # Both the faces from each tile are expected to be of the same shape + assert this_face.shape == ngb_face.shape, (this_face.shape, ngb_face.shape) + + # Compute the object overlaps. + # In this step, we compute the per-instance overlap over both faces + overlap_comp = compute_overlap(this_face, ngb_face) + this_ids = np.unique(this_face).astype("uint32") + overlaps = {this_id: overlap_comp.overlapArraysNormalized(this_id, sorted=False) for this_id in this_ids} + overlap_ids = {this_id: ovlps[0] for this_id, ovlps in overlaps.items()} + overlap_values = {this_id: ovlps[1] for this_id, ovlps in overlaps.items()} + overlap_uv_ids = np.array([ + [this_id, ovlp_id] for this_id, ovlp_ids in overlap_ids.items() for ovlp_id in ovlp_ids + ]) + overlap_values = np.array([ovlp for ovlps in overlap_values.values() for ovlp in ovlps], dtype="float32") + assert len(overlap_uv_ids) == len(overlap_values) + + # Next, we remove the invalid edges. + # We might have ids in the overlaps that are not in the segmentation. We filter them out. + valid_uv_ids = np.isin(overlap_uv_ids, seg_ids).all(axis=1) + if valid_uv_ids.sum() == 0: + continue + overlap_uv_ids, overlap_values = overlap_uv_ids[valid_uv_ids], overlap_values[valid_uv_ids] + assert len(overlap_uv_ids) == len(overlap_values) + + # Get the edge ids. 
+ edge_ids = rag.findEdges(overlap_uv_ids) + valid_edges = edge_ids != -1 + if valid_edges.sum() == 0: + continue + edge_ids, overlap_values = edge_ids[valid_edges], overlap_values[valid_edges] + assert len(edge_ids) == len(overlap_values) + + # And set the global edge disaffinities to (1 - overlap). + edge_disaffinities[edge_ids] = (1.0 - overlap_values) + + with futures.ThreadPoolExecutor(n_threads) as tp: + list(tqdm(tp.map( + _compute_overlaps, range(n_blocks)), total=n_blocks, desc="Compute object overlaps", disable=not verbose, + )) + + costs = compute_edge_costs(edge_disaffinities, beta=0.5) + + # Run multicut to get the segmentation result. + node_labels = multicut_decomposition(rag, costs) + seg_stitched = project_node_labels_to_pixels(rag, node_labels) + return seg_stitched From 4d07cc14ae78fe88a62f78b213b00a6ff9c3d4bd Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Tue, 19 Nov 2024 23:15:32 +0100 Subject: [PATCH 2/9] Restore extracting tiles block wise --- elf/segmentation/stitching.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/elf/segmentation/stitching.py b/elf/segmentation/stitching.py index e8a26b6..554548a 100644 --- a/elf/segmentation/stitching.py +++ b/elf/segmentation/stitching.py @@ -218,18 +218,12 @@ def stitch_tiled_segmentation( block_segs = [] # Get the tiles from the segmentation of shape: 'tile_shape'. 
- def _fetch_tiles(block_id): + for block_id in tqdm(range(n_blocks), desc="Get tiles from the segmentation", disable=not verbose): block = blocking.getBlock(block_id) bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end)) block_seg = segmentation[bb] block_segs.append(block_seg) - n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads - with futures.ThreadPoolExecutor(n_threads) as tp: - list(tqdm(tp.map( - _fetch_tiles, range(n_blocks)), total=n_blocks, desc="Get tiles from the segmentation", disable=not verbose, - )) - # Conpute the Region Adjacency Graph (RAG) for the tiled segmentation. # and the edges between block boundaries (stitch edges). seg_ids = np.unique(segmentation) From 5dad99c283fb0f48deefd4f36bab0fb204759586 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Wed, 20 Nov 2024 13:48:21 +0100 Subject: [PATCH 3/9] Add test for stitching labels --- test/segmentation/test_stitching.py | 47 ++++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/test/segmentation/test_stitching.py b/test/segmentation/test_stitching.py index efa7c76..b70960b 100644 --- a/test/segmentation/test_stitching.py +++ b/test/segmentation/test_stitching.py @@ -1,15 +1,48 @@ import unittest -from elf.evaluation import rand_index +import numpy as np from skimage.data import binary_blobs from skimage.measure import label +from elf.evaluation import rand_index + class TestStitching(unittest.TestCase): def get_data(self, size=1024, ndim=2): data = binary_blobs(size, blob_size_fraction=0.1, volume_fraction=0.2, n_dim=ndim) return data + def get_tiled_data(self, size=1024, ndim=2, tile_shape=(512, 512)): + data = self.get_data(size=size, ndim=ndim) + data = label(data) # Ensure all inputs are instances (the blobs are semantic labels) + + # Create tiles out of the data. + # Ensure offset for objects per tile to get individual ids per object per tile. 
+ import nifty.tools as nt + blocking = nt.blocking([0] * ndim, data.shape, tile_shape) + n_blocks = blocking.numberOfBlocks + + offset = 0 + bb_tiles, tiles = [], [] + for tile_id in range(n_blocks): + block = blocking.getBlock(tile_id) + bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end)) + + tile = data[bb] + tile = label(tile) + tile[tile != 0] += offset + offset = tile.max() + + tiles.append(tile) + bb_tiles.append(bb) + + # Finally, let's stitch back the individual tiles. + labels = np.zeros(data.shape) + for tile, loc in zip(tiles, bb_tiles): + labels[loc] = tile + + return labels, data # returns the stitched labels and original labels + def test_stitch_segmentation(self): from elf.segmentation.stitching import stitch_segmentation @@ -43,6 +76,18 @@ def _segment(input_, block_id=None): are, _ = rand_index(segmentation, expected_segmentation) self.assertTrue(are < 0.05) + def test_stitch_tiled_segmentation(self): + from elf.segmentation.stitching import stitch_tiled_segmentation + + tile_shapes = [(224, 224), (256, 256), (512, 512)] + for tile_shape in tile_shapes: + # Get the tiled segmentation with unmerged instances at tile interfaces. 
+ labels, original_labels = self.get_tiled_data() + stitched_labels = stitch_tiled_segmentation(segmentation=labels, tile_shape=tile_shape) + self.assertEqual(labels.shape, stitched_labels.shape) + # self.assertEqual(len(np.unique(original_labels)), len(np.unique(stitched_labels))) + print(len(np.unique(original_labels)), len(np.unique(stitched_labels))) + if __name__ == "__main__": unittest.main() From 891a297207b5d504fc0004ed39aea6cef993d0e1 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Wed, 20 Nov 2024 19:18:40 +0100 Subject: [PATCH 4/9] Refactor tiling and stitching of test samples into one step --- test/segmentation/test_stitching.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/test/segmentation/test_stitching.py b/test/segmentation/test_stitching.py index b70960b..9a34ffd 100644 --- a/test/segmentation/test_stitching.py +++ b/test/segmentation/test_stitching.py @@ -16,14 +16,15 @@ def get_tiled_data(self, size=1024, ndim=2, tile_shape=(512, 512)): data = self.get_data(size=size, ndim=ndim) data = label(data) # Ensure all inputs are instances (the blobs are semantic labels) - # Create tiles out of the data. + # Create tiles out of the data for testing label stitching. # Ensure offset for objects per tile to get individual ids per object per tile. + # And finally stitch back the tiles. import nifty.tools as nt blocking = nt.blocking([0] * ndim, data.shape, tile_shape) n_blocks = blocking.numberOfBlocks + labels = np.zeros(data.shape) offset = 0 - bb_tiles, tiles = [], [] for tile_id in range(n_blocks): block = blocking.getBlock(tile_id) bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end)) @@ -33,13 +34,7 @@ def get_tiled_data(self, size=1024, ndim=2, tile_shape=(512, 512)): tile[tile != 0] += offset offset = tile.max() - tiles.append(tile) - bb_tiles.append(bb) - - # Finally, let's stitch back the individual tiles. 
- labels = np.zeros(data.shape) - for tile, loc in zip(tiles, bb_tiles): - labels[loc] = tile + labels[bb] = tile return labels, data # returns the stitched labels and original labels From 18f4a087c94345ea0c1151a6faa4af2f80b9a5d3 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Wed, 20 Nov 2024 22:39:25 +0100 Subject: [PATCH 5/9] Update original data patch shape --- test/segmentation/test_stitching.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/segmentation/test_stitching.py b/test/segmentation/test_stitching.py index 9a34ffd..2d9a85e 100644 --- a/test/segmentation/test_stitching.py +++ b/test/segmentation/test_stitching.py @@ -12,9 +12,9 @@ def get_data(self, size=1024, ndim=2): data = binary_blobs(size, blob_size_fraction=0.1, volume_fraction=0.2, n_dim=ndim) return data - def get_tiled_data(self, size=1024, ndim=2, tile_shape=(512, 512)): + def get_tiled_data(self, tile_shape, size=1024, ndim=2): data = self.get_data(size=size, ndim=ndim) - data = label(data) # Ensure all inputs are instances (the blobs are semantic labels) + original_data = label(data) # Ensure all inputs are instances (the blobs are semantic labels) # Create tiles out of the data for testing label stitching. # Ensure offset for objects per tile to get individual ids per object per tile. @@ -36,7 +36,7 @@ def get_tiled_data(self, size=1024, ndim=2, tile_shape=(512, 512)): labels[bb] = tile - return labels, data # returns the stitched labels and original labels + return labels, original_data # returns the stitched labels and original labels def test_stitch_segmentation(self): from elf.segmentation.stitching import stitch_segmentation @@ -77,7 +77,7 @@ def test_stitch_tiled_segmentation(self): tile_shapes = [(224, 224), (256, 256), (512, 512)] for tile_shape in tile_shapes: # Get the tiled segmentation with unmerged instances at tile interfaces. 
- labels, original_labels = self.get_tiled_data() + labels, original_labels = self.get_tiled_data(tile_shape=tile_shape, size=1000) stitched_labels = stitch_tiled_segmentation(segmentation=labels, tile_shape=tile_shape) self.assertEqual(labels.shape, stitched_labels.shape) # self.assertEqual(len(np.unique(original_labels)), len(np.unique(stitched_labels))) From 1ed5b84124064602f404465a9b2c32fc889a318e Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 29 Dec 2024 22:50:48 +0100 Subject: [PATCH 6/9] Fix issues with tile offset in segmentation stitching --- elf/segmentation/stitching.py | 92 +++++++++++++++-------------- test/segmentation/test_stitching.py | 25 ++++---- 2 files changed, 64 insertions(+), 53 deletions(-) diff --git a/elf/segmentation/stitching.py b/elf/segmentation/stitching.py index 554548a..ec2b610 100644 --- a/elf/segmentation/stitching.py +++ b/elf/segmentation/stitching.py @@ -1,6 +1,6 @@ import multiprocessing from concurrent import futures -from typing import Tuple, Optional, Callable +from typing import Callable, Tuple, Optional, Union import vigra import numpy as np @@ -28,31 +28,34 @@ def stitch_segmentation( n_threads: Optional[int] = None, return_before_stitching: bool = False, verbose: bool = True, -) -> np.ndarray: +) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: """Run segmentation function tilewise and stitch the results based on overlap. - Arguments: - input_ [np.ndarray] - the input data. If the data has channels they need to be passed as last dimension, + Args: + input_: The input data. If the data has channels they need to be passed as last dimension, e.g. XYC for a 2D image with channels. - segmentation_function [callable] - the function to perform segmentation for each tile. - Needs to be a segmentation that takes the input (for the tile) as well as the id of the tile as input. + segmentation_function: the function to perform segmentation for each tile. 
+ It must take the input (for the tile) as well as the id of the tile as input; i.e. the function needs to have a signature like this: 'def my_seg_func(tile_input_, tile_id)'. - The tile_id is passed in case the segmentation routine differs depending on the tile; - it can be ignored in most cases. - tile_shape [tuple] - shape of the individual tiles. - tile_overlap [tuple] - overlap of the tiles. + The tile_id is passed in case the segmentation differs based on the tile and can be ignored otherwise. + tile_shape: Shape of the individual tiles. + tile_overlap: Overlap of the tiles. The input to the segmentation function will have the size tile_shape + 2 * tile_overlap. The tile overlap will be used to compute the overlap between objects, which will be used for stitching. - beta [float] - parameter to bias the stitching results towards more over-segmentation (beta > 0.5) - or more under-segmentation (beta < 0.5). Has to be in the exclusive range 0 to 1. (default: 0.5) - shape [tuple] - the shape of the segmentation. By default this will use the shape of the input, but if the - input has channels it needs to be passed manually. (default: None) - with_background [bool] - whether this is a segmentation problem with background. In this case the - background id (which is hard-coded to 0), will not be stitched. (default: True) - n_threads [int] - number of threads that will be used for parallelized operations. - Set to the number of cores by default. (default: None) - return_before_stitching [bool] - return the result before stitching (for debugging). (default: False) - verbose [bool] - whether to print progress bars. (default: True) + beta: Parameter to bias the stitching results towards more over-segmentation (beta > 0.5) + or more under-segmentation (beta < 0.5). Has to be in the exclusive range 0 to 1. + shape: Shape of the segmentation. By default this will use the shape of the input, but if the + input has channels it needs to be passed. 
+ with_background: Whether this is a segmentation problem with background. In this case the + background id (which is hard-coded to 0), will not be stitched. + n_threads: Number of threads that will be used for parallelized operations. + Set to the number of cores by default. + return_before_stitching: Return the result before stitching for debugging. + verbose: Whether to print progress bars. + + Returns: + The stitched segmentation. + The segmentation before stitching, if return_before_stitching is set to True. """ shape = input_.shape if shape is None else shape @@ -64,19 +67,24 @@ def stitch_segmentation( seg = np.zeros(shape, dtype="uint64") n_blocks = blocking.numberOfBlocks - # TODO enable parallelisation - # run tiled segmentation + + # Run tiled segmentation. for block_id in tqdm(range(n_blocks), total=n_blocks, desc="Run tiled segmentation", disable=not verbose): block = blocking.getBlockWithHalo(block_id, list(tile_overlap)) outer_bb = tuple(slice(beg, end) for beg, end in zip(block.outerBlock.begin, block.outerBlock.end)) block_input = input_[outer_bb] block_seg = segmentation_function(block_input, block_id) + if with_background: - block_seg[block_seg != 0] += id_offset + block_mask = block_seg != 0 + # We need to make sure that empty blocks do not reset the offset. + if block_mask.sum() > 0: + block_seg[block_mask] += id_offset + id_offset = block_seg.max() else: block_seg += id_offset - id_offset = block_seg.max() + id_offset = block_seg.max() inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end)) local_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end)) @@ -84,19 +92,19 @@ def stitch_segmentation( seg[inner_bb] = block_seg[local_bb] block_segs.append(block_seg) - # compute the region adjacency graph for the tiled segmentation - # and the edges between block boundaries (stitch edges) + # Compute the region adjacency graph for the tiled segmentation. 
+ # In order to compute the edges between block boundaries (stitch edges). seg_ids = np.unique(seg) rag = compute_rag(seg, n_threads=n_threads) - # we initialize the edge disaffinities with a high value (corresponding to a low overlap) - # so that merging things that are not on the edge is very unlikely - # but not completely impossible in case it is needed for a consistent solution + # We initialize the edge disaffinities with a high value (corresponding to a low overlap), + # so that merging pairs that are not on the edge is very unlikely + # but not completely impossible in case it is needed for a consistent solution. edge_disaffinties = np.full(rag.numberOfEdges, 0.9, dtype="float32") def _compute_overlaps(block_id): - # for each axis, load the face with the lower block neighbor and compute the object overlaps + # For each axis, load the face with the lower block neighbor and compute the object overlaps. for axis in range(ndim): ngb_id = blocking.getNeighborId(block_id, axis, lower=True) if ngb_id == -1: @@ -105,17 +113,18 @@ def _compute_overlaps(block_id): this_block = blocking.getBlockWithHalo(block_id, list(tile_overlap)) ngb_block = blocking.getBlockWithHalo(ngb_id, list(tile_overlap)) - # load the full block segmentations + # Load the full block segmentations. this_seg, ngb_seg = block_segs[block_id], block_segs[ngb_id] - # get the global coordinates of the block face + # Get the global coordinates of the block face. face = tuple( slice(beg_out, end_out) if d != axis else slice(beg_out, beg_in + tile_overlap[d]) for d, (beg_out, end_out, beg_in) in enumerate( zip(this_block.outerBlock.begin, this_block.outerBlock.end, this_block.innerBlock.begin) - )) + ) + ) - # map to the two local face coordinates + # Map to the two local face coordinates. 
this_face_bb = tuple( slice(fa.start - offset, fa.stop - offset) for fa, offset in zip(face, this_block.outerBlock.begin) ) @@ -123,12 +132,12 @@ def _compute_overlaps(block_id): slice(fa.start - offset, fa.stop - offset) for fa, offset in zip(face, ngb_block.outerBlock.begin) ) - # load the two segmentations for the face + # Load the two segmentations for the face. this_face = this_seg[this_face_bb] ngb_face = ngb_seg[ngb_face_bb] assert this_face.shape == ngb_face.shape, (this_face.shape, ngb_face.shape) - # compute the object overlaps + # Compute the object overlaps. overlap_comp = compute_overlap(this_face, ngb_face) this_ids = np.unique(this_face) overlaps = {this_id: overlap_comp.overlapArraysNormalized(this_id, sorted=False) for this_id in this_ids} @@ -140,11 +149,9 @@ def _compute_overlaps(block_id): overlap_values = np.array([ovlp for ovlps in overlap_values.values() for ovlp in ovlps], dtype="float32") assert len(overlap_uv_ids) == len(overlap_values) - # - get the edge ids - # - exclude invalid edge - # - set the global edge disaffinities to 1 - overlap + # Get the edge ids, then exclude invalid edges and set the edge disaffinities to 1 - overlap. - # we might have ids in the overlaps that are not in the final seg, these need to be filtered + # We might have ids in the overlaps that are not in the final segmentation, these need to be filtered. valid_uv_ids = np.isin(overlap_uv_ids, seg_ids).all(axis=1) if valid_uv_ids.sum() == 0: continue @@ -166,15 +173,14 @@ def _compute_overlaps(block_id): _compute_overlaps, range(n_blocks)), total=n_blocks, desc="Compute object overlaps", disable=not verbose )) - # if we have background set all the edges that are connecting 0 to another element - # to be very unlikely + # If we have background, then set all the edges that are connecting 0 to another element to be very unlikely. 
if with_background: uv_ids = rag.uvIds() bg_edges = rag.findEdges(uv_ids[(uv_ids == 0).any(axis=1)]) edge_disaffinties[bg_edges] = 0.99 costs = compute_edge_costs(edge_disaffinties, beta=beta) - # run multicut to get the segmentation result + # Run multicut to get the segmentation result. node_labels = multicut_decomposition(rag, costs) seg_stitched = project_node_labels_to_pixels(rag, node_labels, n_threads=n_threads) diff --git a/test/segmentation/test_stitching.py b/test/segmentation/test_stitching.py index 2d9a85e..e5fc1d1 100644 --- a/test/segmentation/test_stitching.py +++ b/test/segmentation/test_stitching.py @@ -1,6 +1,7 @@ import unittest import numpy as np +import nifty.tools as nt from skimage.data import binary_blobs from skimage.measure import label @@ -9,7 +10,7 @@ class TestStitching(unittest.TestCase): def get_data(self, size=1024, ndim=2): - data = binary_blobs(size, blob_size_fraction=0.1, volume_fraction=0.2, n_dim=ndim) + data = binary_blobs(size, blob_size_fraction=0.1, volume_fraction=0.25, n_dim=ndim) return data def get_tiled_data(self, tile_shape, size=1024, ndim=2): @@ -19,7 +20,6 @@ def get_tiled_data(self, tile_shape, size=1024, ndim=2): # Create tiles out of the data for testing label stitching. # Ensure offset for objects per tile to get individual ids per object per tile. # And finally stitch back the tiles. 
- import nifty.tools as nt blocking = nt.blocking([0] * ndim, data.shape, tile_shape) n_blocks = blocking.numberOfBlocks @@ -31,8 +31,10 @@ def get_tiled_data(self, tile_shape, size=1024, ndim=2): tile = data[bb] tile = label(tile) - tile[tile != 0] += offset - offset = tile.max() + tile_mask = tile != 0 + if tile_mask.sum() > 0: + tile[tile_mask] += offset + offset = tile.max() labels[bb] = tile @@ -46,14 +48,16 @@ def _segment(input_, block_id=None): return segmentation.astype("uint32") tile_overlap = (32, 32) - tile_shapes = [(128, 128), (256, 256), (128, 256)] + tile_shapes = [(128, 128), (256, 256), (128, 256), (224, 224)] for tile_shape in tile_shapes: for _ in range(3): # test 3 times with different data data = self.get_data() expected_segmentation = _segment(data) segmentation = stitch_segmentation(data, _segment, tile_shape, tile_overlap, verbose=False) + are, _ = rand_index(segmentation, expected_segmentation) - self.assertTrue(are < 0.05) + # We allow for some tolerance, because small objects might get stitched incorrectly. + self.assertTrue(np.isclose(are, 0, rtol=1e-3, atol=1e-3)) def test_stitch_segmentation_3d(self): from elf.segmentation.stitching import stitch_segmentation @@ -69,7 +73,7 @@ def _segment(input_, block_id=None): expected_segmentation = _segment(data) segmentation = stitch_segmentation(data, _segment, tile_shape, tile_overlap, verbose=False) are, _ = rand_index(segmentation, expected_segmentation) - self.assertTrue(are < 0.05) + self.assertTrue(np.isclose(are, 0, rtol=1e-3, atol=1e-3)) def test_stitch_tiled_segmentation(self): from elf.segmentation.stitching import stitch_tiled_segmentation @@ -78,10 +82,11 @@ def test_stitch_tiled_segmentation(self): for tile_shape in tile_shapes: # Get the tiled segmentation with unmerged instances at tile interfaces. 
labels, original_labels = self.get_tiled_data(tile_shape=tile_shape, size=1000) - stitched_labels = stitch_tiled_segmentation(segmentation=labels, tile_shape=tile_shape) + stitched_labels = stitch_tiled_segmentation(segmentation=labels, tile_shape=tile_shape, verbose=False) self.assertEqual(labels.shape, stitched_labels.shape) - # self.assertEqual(len(np.unique(original_labels)), len(np.unique(stitched_labels))) - print(len(np.unique(original_labels)), len(np.unique(stitched_labels))) + + are, _ = rand_index(stitched_labels, original_labels) + self.assertTrue(np.isclose(are, 0, rtol=1e-3, atol=1e-3)) if __name__ == "__main__": From 1b59252107eb078b6bf158d784e88e105ec17353 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 29 Dec 2024 23:00:52 +0100 Subject: [PATCH 7/9] Use imageio.v3 --- elf/io/image_stack_wrapper.py | 4 ++-- elf/io/knossos_wrapper.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/elf/io/image_stack_wrapper.py b/elf/io/image_stack_wrapper.py index 385d65b..446d05a 100644 --- a/elf/io/image_stack_wrapper.py +++ b/elf/io/image_stack_wrapper.py @@ -4,7 +4,7 @@ from glob import glob import numpy as np -import imageio +import imageio.v3 as imageio try: import tifffile @@ -149,7 +149,7 @@ def _read_image(self, index): return imageio.imread(self.files[index]) def _read_volume(self): - return imageio.volread(self.files) + return imageio.imread(self.files) def _load_roi_from_stack(self, roi): return self._volume[roi] diff --git a/elf/io/knossos_wrapper.py b/elf/io/knossos_wrapper.py index a0dec5a..751a37d 100644 --- a/elf/io/knossos_wrapper.py +++ b/elf/io/knossos_wrapper.py @@ -3,7 +3,7 @@ from concurrent import futures import numpy as np -import imageio +import imageio.v3 as imageio from ..util import (normalize_index, squeeze_singletons, map_chunk_to_roi, chunks_overlapping_roi) From 7ea24d5ba3733437ddec8df2190c2e8c5aaee1ae Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 30 Dec 2024 00:48:38 +0100 Subject: [PATCH 
8/9] Update segmentation stitching functionality --- elf/segmentation/stitching.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/elf/segmentation/stitching.py b/elf/segmentation/stitching.py index ec2b610..c2320e3 100644 --- a/elf/segmentation/stitching.py +++ b/elf/segmentation/stitching.py @@ -100,7 +100,7 @@ def stitch_segmentation( # We initialize the edge disaffinities with a high value (corresponding to a low overlap), # so that merging pairs that are not on the edge is very unlikely # but not completely impossible in case it is needed for a consistent solution. - edge_disaffinties = np.full(rag.numberOfEdges, 0.9, dtype="float32") + edge_disaffinities = np.full(rag.numberOfEdges, 0.9, dtype="float32") def _compute_overlaps(block_id): @@ -165,7 +165,7 @@ def _compute_overlaps(block_id): edge_ids, overlap_values = edge_ids[valid_edges], overlap_values[valid_edges] assert len(edge_ids) == len(overlap_values) - edge_disaffinties[edge_ids] = (1.0 - overlap_values) + edge_disaffinities[edge_ids] = (1.0 - overlap_values) n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads with futures.ThreadPoolExecutor(n_threads) as tp: @@ -177,8 +177,8 @@ def _compute_overlaps(block_id): if with_background: uv_ids = rag.uvIds() bg_edges = rag.findEdges(uv_ids[(uv_ids == 0).any(axis=1)]) - edge_disaffinties[bg_edges] = 0.99 - costs = compute_edge_costs(edge_disaffinties, beta=beta) + edge_disaffinities[bg_edges] = 0.99 + costs = compute_edge_costs(edge_disaffinities, beta=beta) # Run multicut to get the segmentation result. node_labels = multicut_decomposition(rag, costs) @@ -199,18 +199,22 @@ def stitch_tiled_segmentation( segmentation: np.ndarray, tile_shape: Tuple[int, int], overlap: int = 1, + with_background: bool = True, n_threads: Optional[int] = None, verbose: bool = True, ) -> np.ndarray: - """Functionality for stitching segmentations tile-wise based on overlap. 
+ """Stitch a segmentation that is split into tiles. + + The ids in the tiles of the input segmentation have to be unique, + i.e. the segmentations have to be separate across tiles. Args: segmentation: The input segmentation. - tile_shape: The shape of inidividual tiles. - overlap: The overlap of tiles. - It is responsible to compute the edge nodes for the desired overlap region. - n_threads: The number of threads used for parallelized operations. - Set to the number of cores by default. + tile_shape: The shape of tiles. + overlap: The overlap between adjacent tiles that is used to compute overlap for stitching objects. + with_background: Whether this is a segmentation problem with background. In this case the + background id (which is hard-coded to 0), will not be stitched. + n_threads: The number of threads used for parallelized operations. Set to the number of cores by default. verbose: Whether to print the progress bars. Returns: @@ -302,6 +306,10 @@ def _compute_overlaps(block_id): _compute_overlaps, range(n_blocks)), total=n_blocks, desc="Compute object overlaps", disable=not verbose, )) + uv_ids = rag.uvIds() + if with_background: + bg_edges = rag.findEdges(uv_ids[(uv_ids == 0).any(axis=1)]) + edge_disaffinities[bg_edges] = 0.99 costs = compute_edge_costs(edge_disaffinities, beta=0.5) # Run multicut to get the segmentation result. 
From 8134fce7e453d9d36e26530d1a6b1bc420c8514d Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 30 Dec 2024 10:36:56 +0100 Subject: [PATCH 9/9] Lower test stringency --- test/segmentation/test_stitching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/segmentation/test_stitching.py b/test/segmentation/test_stitching.py index e5fc1d1..08dbfab 100644 --- a/test/segmentation/test_stitching.py +++ b/test/segmentation/test_stitching.py @@ -73,7 +73,7 @@ def _segment(input_, block_id=None): expected_segmentation = _segment(data) segmentation = stitch_segmentation(data, _segment, tile_shape, tile_overlap, verbose=False) are, _ = rand_index(segmentation, expected_segmentation) - self.assertTrue(np.isclose(are, 0, rtol=1e-3, atol=1e-3)) + self.assertTrue(np.isclose(are, 0, rtol=1e-2, atol=1e-2)) def test_stitch_tiled_segmentation(self): from elf.segmentation.stitching import stitch_tiled_segmentation