From f6b8d8013534735cbd44b044085fdd4a7265b9f1 Mon Sep 17 00:00:00 2001 From: Ryan William Conrad Date: Tue, 1 Nov 2022 15:56:22 -0400 Subject: [PATCH] added stitching of 3d chunks --- empanada/array_utils.py | 27 +-- empanada/data/bc_dataset.py | 12 +- empanada/inference/engines.py | 5 +- empanada/inference/matcher.py | 186 +++++++++++++++--- empanada/inference/stitch.py | 175 ++++++++++++++++ empanada/inference/tile.py | 81 +++++++- empanada/inference/watershed.py | 13 +- empanada/losses.py | 149 ++++++++++++++ empanada/models/point_rend.py | 27 ++- empanada/models/quantization/__init__.py | 4 +- .../models/quantization/panoptic_bifpn.py | 3 +- .../models/quantization/panoptic_deeplab.py | 144 +++++++++++++- .../mmm_median_inference_celegans.yaml | 4 +- .../mmm_median_inference_fly_brain.yaml | 4 +- .../configs/mmm_median_inference_hela.yaml | 2 +- .../configs/mmm_panoptic_deeplab_bc.yaml | 6 +- .../configs/mmm_ws_inference_fly_brain.yaml | 12 +- projects/mitonet/scripts/evaluate3d.py | 9 +- projects/mitonet/scripts/evaluate3d_bc.py | 2 + scripts/export_model.py | 12 +- scripts/train.py | 4 +- 21 files changed, 784 insertions(+), 97 deletions(-) create mode 100644 empanada/inference/stitch.py diff --git a/empanada/array_utils.py b/empanada/array_utils.py index be4b72f..fe098e0 100644 --- a/empanada/array_utils.py +++ b/empanada/array_utils.py @@ -137,23 +137,6 @@ def merge_boxes(box1, box2): return tuple(merged_box) -def box_iou(boxes1, boxes2=None, return_intersection=False): - # do pairwise box iou if no boxes2 - if boxes2 is None: - boxes2 = boxes1 - - intersect = box_intersection(boxes1, boxes2) - area1 = box_area(boxes1) - area2 = box_area(boxes2) - - # union is a matrix of same size as intersect - union = area1[:, None] + area2[None, :] - intersect - iou = intersect / union - if return_intersection: - return iou, intersect - else: - return iou - @numba.jit(nopython=True) def _box_iou(boxes1, boxes2): ndim = boxes1.shape[1] // 2 @@ -623,14 +606,8 @@ def vote_by_ranges(list_of_ranges, vote_thr=2): if len(list_of_ranges) >= vote_thr: # get all the starts and ends of the ranges - starts = sorted([r[0][0] for r in list_of_ranges]) - ends = sorted([r[-1][1] for r in list_of_ranges]) - - init_index = starts[vote_thr - 1] - term_index = ends[-vote_thr] + 1 - ranges = concat_sort_ranges(list_of_ranges) - return np.array(rle_voting(ranges, vote_thr, init_index, term_index)) + return np.array(rle_voting(ranges, vote_thr)) else: return np.array([]) @@ -754,4 +731,4 @@ def numpy_fill_instances(volume, instances): for s,e in zip(starts, ends): volume[s:e] = instance_id - return volume.reshape(shape) \ No newline at end of file + return volume.reshape(shape) diff --git a/empanada/data/bc_dataset.py b/empanada/data/bc_dataset.py index c551700..c3bdab7 100644 --- a/empanada/data/bc_dataset.py +++ b/empanada/data/bc_dataset.py @@ -78,6 +78,8 @@ def __init__( self, data_dir, transforms=None, + sem_thr=None, + cnt_thr=None, weight_gamma=0.3, norms=None, ): @@ -99,13 +101,15 @@ class only. 
data_dir, transforms, weight_gamma ) + self.sem_thr = sem_thr + self.cnt_thr = cnt_thr self.norms = norms def __getitem__(self, idx): # transformed and paste example f = self.impaths[idx] - image = np.load(f) - mask = np.load(self.mskpaths[idx]) + image = np.load(f, allow_pickle=True) + mask = np.load(self.mskpaths[idx], allow_pickle=True) assert image.ndim == 3 assert mask.ndim == 4 # has 2 channels at first dim @@ -124,6 +128,10 @@ def __getitem__(self, idx): if k == 'image' and self.norms is not None: v = (v - self.norms[0]) / self.norms[1] v = v[None] + elif k == 'sem' and self.sem_thr is not None: + v = v >= self.sem_thr + elif k == 'cnt' and self.cnt_thr is not None: + v = v >= self.cnt_thr # move to torch output[k] = torch.from_numpy(v).float() diff --git a/empanada/inference/engines.py b/empanada/inference/engines.py index af157ff..9e6a281 100644 --- a/empanada/inference/engines.py +++ b/empanada/inference/engines.py @@ -10,7 +10,6 @@ ) from collections import deque - __all__ = [ 'PanopticDeepLabEngine', 'PanopticDeepLabEngine3d', @@ -246,8 +245,10 @@ def __init__( self.coarse_boundaries = coarse_boundaries @torch.no_grad() - def infer(self, image, render_steps=2): + def infer(self, image, render_steps=2): + #render_steps = 0 model_out = self.model(image, render_steps, interpolate_ins=not self.coarse_boundaries) + #model_out['sem_logits'] = F.interpolate(model_out['sem_logits'], scale_factor=4, mode='bilinear', align_corners=True) # notice that sem is NOT sem_logits model_out['sem'] = logits_to_prob(model_out['sem_logits']) diff --git a/empanada/inference/matcher.py b/empanada/inference/matcher.py index 85ce4aa..fd90655 100644 --- a/empanada/inference/matcher.py +++ b/empanada/inference/matcher.py @@ -7,6 +7,8 @@ __all__ = [ 'fast_matcher', + 'rle_iou_matrix', + 'connect_chunk_boundaries', 'rle_matcher', 'RLEMatcher' ] @@ -79,8 +81,10 @@ def fast_matcher( if len(labels1) == 0 or len(labels2) == 0: empty = np.array([]) - if return_ioa: + if return_ioa and return_iou: # no matches, only labels, no matrices + return (empty, empty), (labels1, labels2), empty, empty, empty + elif return_iou or return_ioa: return (empty, empty), (labels1, labels2), empty, empty else: return (empty, empty), (labels1, labels2), empty @@ -133,6 +137,150 @@ def fast_matcher( return output +def rle_iou_matrix( + target_instance_rles, + match_instance_rles, + return_intersection=False, + return_ioa=False +): + r"""Computes the IoU (i.e., cost) matrix for Hungarian matching + on run length encodings. + + Args: + target_instance_rles: Dictionary of instances to match against. Keys are + instance 'labels' and values are a dictionary of ('box', 'starts', 'runs'). + + match_instance_rles: Dictionary of instances to match. Keys are + instance 'labels' and values are a dictionary of ('box', 'starts', 'runs') + + return_intersection: Whether to return total intersection area between + instances in target and match. + + return_ioa: Whether to return intersection-over-area (IoA) scores between + instances in target and match. + + Returns: + iou_matrix: Array of (n, m) pairwise IoU scores between instances + in target and match. + + intersection_matrix: Array of (n, m) pairwise intersection areas between instances + in target and match. Only returned in return_intersection is True. + + ioa_matrix: Array of (n, m) pairwise IoA scores between instances + in target and match. Only returned in return_ioa is True. 
+ """ + # extract bounding boxes and labels for + # all objects in each instance segmentation + target_labels, target_boxes, target_starts, target_runs =\ + unpack_rle_attrs(target_instance_rles) + + match_labels, match_boxes, match_starts, match_runs =\ + unpack_rle_attrs(match_instance_rles) + + if len(target_labels) == 0 or len(match_labels) == 0: + empty = np.array([]) + if return_ioa and return_intersection: + return empty, empty, empty + elif return_ioa or return_intersection: + return empty, empty + else: + return empty + + # compute mask IoUs of all possible matches + iou_matrix = np.zeros((len(target_boxes), len(match_boxes)), dtype='float') + + if return_intersection: + inter_matrix = np.zeros((len(target_boxes), len(match_boxes)), dtype='float') + + if return_ioa: + ioa_matrix = np.zeros((len(target_boxes), len(match_boxes)), dtype='float') + + # match the boxes + box_matches = np.array(box_iou(target_boxes, match_boxes).nonzero()).T + + # compute rle overlap scores + for r1, r2 in box_matches: + iou_out = rle_iou( + target_starts[r1], target_runs[r1], + match_starts[r2], match_runs[r2], + return_intersection=return_intersection + ) + if return_intersection: + iou_matrix[r1, r2] = iou_out[0] + inter_matrix[r1, r2] = iou_out[1] + else: + iou_matrix[r1, r2] = iou_out + + if return_ioa: + ioa_matrix[r1, r2] = rle_ioa( + target_starts[r1], target_runs[r1], + match_starts[r2], match_runs[r2], + ) + + if return_intersection and return_ioa: + return iou_matrix, inter_matrix, ioa_matrix + elif return_intersection: + return iou_matrix, inter_matrix + elif return_ioa: + return iou_matrix, ioa_matrix + else: + return iou_matrix + +def connect_chunk_boundaries( + bound1, + bound2, + iou_thr=0.1, + area_thr=100 +): + r"""Finds all connections between labels in neighboring chunks of + run length encoded segmentations. + + Args: + bound1: Dictionary of instances to match against. Keys are + instance 'labels' and values are a dictionary of ('box', 'starts', 'runs'). + + bound2: Dictionary of instances to match. Keys are + instance 'labels' and values are a dictionary of ('box', 'starts', 'runs') + + iou_thr (float): Minimum iou score between instances in bound1 and bound2 + to add a connection. + + area_thr (float): Minimum overlap area between instances in bound1 and bound2 + to add a connection. + + Returns: + edges: List of connections as tuple. First item is instance id in + bound1 and second item is instance id in bound2. + + """ + if iou_thr is None and area_thr is None: + raise Exception(f'iou_thr and area_thr cannot both be None!') + + # extract the label ids + b1_labels = unpack_rle_attrs(bound1)[0] + b2_labels = unpack_rle_attrs(bound2)[0] + + # compute iou and intersection areas + iou_matrix, inter_matrix = rle_iou_matrix( + bound1, bound2, return_intersection=True + ) + + # no matches + if not np.any(iou_matrix): + return [] + + # mask for where labels are matched + mask = np.zeros(iou_matrix.shape, dtype='bool') + if iou_thr is not None: + mask = np.logical_or(mask, iou_matrix >= iou_thr) + if area_thr is not None: + mask = np.logical_or(mask, inter_matrix >= area_thr) + + n, m = np.where(mask) + edges = np.stack([b1_labels[n], b2_labels[m]], axis=1).tolist() + + return edges + def rle_matcher( target_instance_rles, match_instance_rles, @@ -173,43 +321,37 @@ def rle_matcher( ioa_matrix: Array of (n, m) pairwise IoA scores between instances in target and match. Only returned in return_ioa is True. 
""" - # screen matches by bounding box iou # extract bounding boxes and labels for # all objects in each instance segmentation target_labels, target_boxes, target_starts, target_runs =\ unpack_rle_attrs(target_instance_rles) - + match_labels, match_boxes, match_starts, match_runs =\ unpack_rle_attrs(match_instance_rles) if len(target_labels) == 0 or len(match_labels) == 0: empty = np.array([]) - if return_ioa: + if return_ioa and return_iou: # no matches, only labels, no matrices + return (empty, empty), (target_labels, match_labels), empty, empty, empty + elif return_ioa or return_iou: return (empty, empty), (target_labels, match_labels), empty, empty else: return (empty, empty), (target_labels, match_labels), empty + + cost_matrix = rle_iou_matrix( + target_instance_rles, + match_instance_rles, + return_ioa=return_ioa + ) - # compute mask IoUs of all possible matches - iou_matrix = np.zeros((len(target_boxes), len(match_boxes)), dtype='float') if return_ioa: - ioa_matrix = np.zeros((len(target_boxes), len(match_boxes)), dtype=np.float32) + iou_matrix = cost_matrix[0] + ioa_matrix = cost_matrix[1] + else: + iou_matrix = cost_matrix - # match the boxes - box_matches = np.array(box_iou(target_boxes, match_boxes).nonzero()).T - for r1, r2 in box_matches: - iou_matrix[r1, r2] = rle_iou( - target_starts[r1], target_runs[r1], - match_starts[r2], match_runs[r2], - ) - - if return_ioa: - ioa_matrix[r1, r2] = rle_ioa( - target_starts[r1], target_runs[r1], - match_starts[r2], match_runs[r2], - ) - - # returns tuple of indices and ious/ioas of instances + # maximize matched ious match_rows, match_cols = linear_sum_assignment(iou_matrix, maximize=True) # filter out matches with iou less than thr diff --git a/empanada/inference/stitch.py b/empanada/inference/stitch.py new file mode 100644 index 0000000..9816ddc --- /dev/null +++ b/empanada/inference/stitch.py @@ -0,0 +1,175 @@ +import numpy as np +import networkx as nx +from empanada.array_utils import merge_boxes, merge_rles +from empanada.inference.matcher import connect_chunk_boundaries + +__all__ = [ + 'global_instance_graph', + 'add_instance_edges', + 'merge_graph_instances', + 'remove_small_objects', + 'create_forward_map', + 'relabel_chunk_rles' +] + +def extend_dict(dict1, dict2): + for k,v in dict2.items(): + if k in dict1: + dict1[k].extend(v) + else: + dict1[k] = v + +def calculate_global_box( + local_box, + chunk_index, + chunk_shape, + chunk_dims +): + offset = np.unravel_index(chunk_index, chunk_dims) + + # project the box into global space + global_box = [ + s + c * chunk_shape[i % 3] + for i, (s,c) in enumerate(zip(local_box, 2 * offset)) + ] + + return global_box + +def global_instance_graph( + chunks, + cuber, + rle_class, + initial_label=1 +): + graph = nx.Graph() + chunk_instance_map = {} + + for chunk_idx, chunk_attrs in chunks.items(): + for label_id, label_attrs in chunk_attrs['rle'][rle_class].items(): + # add a global label node with info about + # the source chunk and label + area = label_attrs['runs'].sum() + box = calculate_global_box( + label_attrs['box'], chunk_idx, + cuber.cube_shape, cuber.chunk_dims + ) + + graph.add_node( + initial_label, area=area, box=box, + chunk_lookup={chunk_idx: [label_id]} + ) + + chunk_instance_map[chunk_idx] = \ + {label_id: initial_label} | chunk_instance_map.get(chunk_idx, {}) + + initial_label += 1 + + return graph, chunk_instance_map + +def add_instance_edges( + graph, + chunks, + chunk_instance_map, + cuber, + class_id, + iou_thr=0.1, + area_thr=100 +): + pairs = [ + ('right', 'left'), + 
('bottom', 'top'), + ('back', 'front') + ] + + for chunk_index in chunks.keys(): + # add connections from each of 3 neighbor chunks + for pair, nix in zip(pairs, cuber.find_neighbors(chunk_index)): + if nix is None: + continue + + a, b = pair + edges = connect_chunk_boundaries( + chunks[chunk_index]['boundaries'][a][class_id], + chunks[nix]['boundaries'][b][class_id], + iou_thr, area_thr + ) + # convert from local to global labels + # and add edges to the graph + for edge in edges: + cl, nl = edge + cnode = chunk_instance_map[chunk_index][cl] + nnode = chunk_instance_map[nix][nl] + graph.add_edge(cnode, nnode) + +def merge_nodes( + graph, + root, + node, + remove_node=True +): + # merge the chunk lookup + extend_dict( + graph.nodes[root]['chunk_lookup'], + graph.nodes[node]['chunk_lookup'] + ) + + # merge the bounding boxes + graph.nodes[root]['box'] = merge_boxes( + graph.nodes[root]['box'], graph.nodes[node]['box'] + ) + + # merge the label areas + graph.nodes[root]['area'] += graph.nodes[node]['area'] + + if remove_node: + graph.remove_node(node) + +def merge_graph_instances(graph): + # merge instances in each connected component + # of the graph + for comp in list(nx.connected_components(graph)): + # sort so that the + comp = sorted(list(comp)) + root = comp[0] + for node in comp[1:]: + merge_nodes(graph, root, node, True) + +def remove_small_objects(graph, min_size=1000): + filtered = [] + for node in graph.nodes: + area = graph.nodes[node]['area'] + if area < min_size: + filtered.append(node) + + graph.remove_nodes_from(filtered) + +def create_forward_map(graph): + # merge connected components to the lowest label value + # and store the correct mapping of labels for each chunk + forward_map = {} + for node in graph.nodes: + for chunk_index, chunk_labels in graph.nodes[node]['chunk_lookup'].items(): + chunk_map = {cl: node for cl in chunk_labels} + forward_map[chunk_index] = chunk_map | forward_map.get(chunk_index, {}) + + return forward_map + +def relabel_chunk_rles(chunks, class_id, forward_map): + for chunk_index, chunk_attrs in chunks.items(): + relabeled = {} + lookup_table = forward_map.get(chunk_index, {}) + + for old, new in lookup_table.items(): + old_rle = chunk_attrs['rle'][class_id][old] + if new in relabeled: + merged_s, merged_r = merge_rles( + relabeled[new]['starts'], relabeled[new]['runs'], + old_rle['starts'], old_rle['runs'] + ) + + relabeled[new]['starts'] = merged_s + relabeled[new]['runs'] = merged_r + else: + relabeled[new] = old_rle + + chunk_attrs['rle'][class_id] = relabeled \ No newline at end of file diff --git a/empanada/inference/tile.py b/empanada/inference/tile.py index 1ec4e9f..519a38c 100644 --- a/empanada/inference/tile.py +++ b/empanada/inference/tile.py @@ -1,9 +1,10 @@ +import math import numpy as np from cztile.fixed_total_area_strategy import AlmostEqualBorderFixedTotalAreaStrategy2D from cztile.tiling_strategy import Rectangle as czrect from empanada.array_utils import rle_voting, merge_rles -__all__ = ['Tiler'] +__all__ = ['Tiler', 'Cuber'] def calculate_overlap_rle(yranges, xranges, image_shape): r"""Creates a run length encoding of the overlap between tiles. 
@@ -198,4 +199,80 @@ def __call__(self, image, tile_index): yslice = slice(*self.yranges[tile_index]) xslice = slice(*self.xranges[tile_index]) - return image[yslice, xslice] \ No newline at end of file + return image[yslice, xslice] + +class Cuber: + def __init__(self, array_shape, cube_shape): + assert len(array_shape) == len(cube_shape) == 3 + self.array_shape = array_shape + self.cube_shape = cube_shape + + self.chunk_dims = tuple( + [math.ceil(s / cs) for s,cs in zip(array_shape, cube_shape)] + ) + + self.cubes = self._get_cubes() + + def _get_cubes(self): + r"""Create slicable data cubes for the given array. + + Returns: + cubes (Dict[int, Tuple(slice)]): cube ROIs indexed + by the raveled cube index + + """ + cubes = {} + + d, h, w = self.array_shape + zs, ys, xs = self.cube_shape + cd, ch, cw = self.chunk_dims + + for zc, z in enumerate(range(0, d, zs)): + for yc, y in enumerate(range(0, h, ys)): + for xc, x in enumerate(range(0, w, xs)): + slices = ( + slice(z, min(z + zs, d)), + slice(y, min(y + ys, h)), + slice(x, min(x + xs, w)) + ) + + # compute the raveled cube index + cube_index = (zc * ch * cw) + (yc * cw) + xc + cubes[cube_index] = slices + + return cubes + + def find_neighbors(self, cube_index): + r"""Finds the indices of cubes to the right, bottom and back + of the given cube. + + Args: + cube_index (int): Index of a cube in the Cuber. + + Returns: + neighbors (Tuple[int]): Index of cube neighbors + to the right, bottom and back respectively. Any or + all of the neighbors may be None. + + """ + cd, ch, cw = self.chunk_dims + + # get the raveled cube indices + if (cube_index + 1) % cw != 0: + right = cube_index + 1 + else: + right = None + + if (cube_index // cw + 1) % ch != 0: + bottom = cube_index + cw + else: + bottom = None + + if (cube_index // (ch * cw) + 1) % cd != 0: + back = cube_index + (ch * cw) + else: + back = None + + return (right, bottom, back) + + \ No newline at end of file diff --git a/empanada/inference/watershed.py b/empanada/inference/watershed.py index 7ba4bfb..4ad4161 100644 --- a/empanada/inference/watershed.py +++ b/empanada/inference/watershed.py @@ -133,7 +133,7 @@ def bc_watershed( thres3=0.85, seed_thres=32, min_size=128, - label_divisor=1000, + label_divisor=None, use_mask_wts=False ): r"""Convert binary foreground probability maps and instance contours to @@ -163,8 +163,11 @@ def bc_watershed( segm = mask_watershed(foreground, seed) else: segm = watershed(-semantic.astype(np.float64), seed, mask=foreground) + + if min_size is not None and min_size > 0: + segm = size_threshold(segm, min_size) + if label_divisor is not None: + segm[segm > 0] += label_divisor - segm = size_threshold(segm, min_size) - segm[segm > 0] += label_divisor - - return cast2dtype(segm) + #return cast2dtype(segm) + return segm \ No newline at end of file diff --git a/empanada/losses.py b/empanada/losses.py index e9c0576..4a0bb35 100644 --- a/empanada/losses.py +++ b/empanada/losses.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from empanada.models.point_rend import point_sample __all__ = [ @@ -98,6 +99,110 @@ def forward(self, point_logits, point_coords, labels): return point_losses +class BootstrapPointRendLoss(nn.Module): + r"""Standard (binary) cross-entropy between logits at + points sampled by the point rend module. 
+ """ + def __init__(self, beta=0.8, mode='hard'): + super(BootstrapPointRendLoss, self).__init__() + self.bce = nn.BCEWithLogitsLoss(reduction='mean') + self.beta = beta + self.mode = mode + + def forward(self, point_logits, point_coords, labels): + # sample the labels at the given coordinates + point_labels = point_sample( + labels.unsqueeze(1).float(), point_coords, + mode="nearest", align_corners=False + ) + + point_probas = torch.sigmoid(point_logits) + if self.mode == 'soft': + boot_labels = (self.beta * point_labels) + (1.0 - self.beta) * point_probas + else: + boot_labels = (self.beta * point_labels) + (1.0 - self.beta) * (point_probas > 0.5).float() + + point_losses = self.bce(point_logits, boot_labels) + + return point_losses + +class BootstrapDiceLoss(nn.Module): + """ + Calculates the bootstrapped dice loss between model output logits and + a noisy ground truth labelmap. The loss targets are modified to be + a linear combination of the noisy ground truth the model's own prediction + confidence. They are calculated as: + + boot_target = (beta * noisy_ground_truth) + (1.0 - beta) * model_predictions [1] + + References: + [1] https://arxiv.org/abs/1412.6596 + + Arguments: + ---------- + beta: Float, in the range (0, 1). Beta = 1 is equivalent to regular dice loss. + Controls the level of mixing between the noisy ground truth and the model's own + predictions. Default 0.8. + + eps: Float. A small float value used to prevent division by zero. Default, 1e-7. + + mode: Choice of ['hard', 'soft']. In the soft mode, model predictions are + probabilities in the range [0, 1]. In the hard mode, model predictions are + "hardened" such that: + + model_predictions = 1 when probability > 0.5 + model_predictions = 0 when probability <= 0.5 + + Default is 'hard'. + + Example Usage: + -------------- + + model = Model() + criterion = BootstrapDiceLoss(beta=0.8, mode='hard') + output = model(input) + loss = criterion(output, noisy_ground_truth) + loss.backward() + + """ + def __init__(self, beta=0.8, eps=1e-7, mode='hard'): + super(BootstrapDiceLoss, self).__init__() + self.beta = beta + self.eps = eps + self.mode = mode + + def forward(self, output, target): + if target.ndim == output.ndim - 1: + target = target.unsqueeze(1) + + n_classes = output.shape[1] + n_classes = 2 if n_classes == 1 else n_classes + empty_dims = (1,) * (target.ndim - 2) + + k = torch.arange(0, n_classes).view(1, n_classes, *empty_dims).to(target.device) + target = (target == k) + + if n_classes == 2: + pos_prob = torch.sigmoid(output) + neg_prob = 1 - pos_prob + probas = torch.cat([neg_prob, pos_prob], dim=1) + else: + probas = F.softmax(output, dim=1) + + target = target.type(output.dtype) + + if self.mode == 'soft': + boot_target = (self.beta * target) + (1.0 - self.beta) * probas + else: + boot_target = (self.beta * target) + (1.0 - self.beta) * (probas > 0.5).float() + + dims = (0,) + tuple(range(2, boot_target.ndimension())) + intersection = torch.sum(probas * boot_target, dims) + cardinality = torch.sum(probas + boot_target, dims) + + dice_loss = ((2. * intersection) / (cardinality + self.eps)).mean() + return 1 - dice_loss + class PanopticLoss(nn.Module): r"""Defines the overall panoptic loss function which combines semantic segmentation, instance centers and offsets. @@ -195,3 +300,47 @@ def forward(self, output, target): aux_loss['total_loss'] = total_loss.item() return total_loss, aux_loss + +class BootstrapBCLoss(nn.Module): + r"""Defines the overall loss for a boundary contour prediction + model. 
+ + Args: + pr_weight: Float, weight to apply to the point rend semantic + segmentation loss. Only applies if using a Point Rend enabled model. + + top_k_percent: Float, fraction of largest semantic segmentation + loss values to consider in BootstrapCE. + + """ + def __init__( + self, + pr_weight=1, + beta=0.8, + eps=1e-7, + mode='hard' + ): + super(BootstrapBCLoss, self).__init__() + self.dice_loss = BootstrapDiceLoss(beta=beta, eps=eps, mode=mode) + self.pr_loss = BootstrapPointRendLoss(beta=beta, mode=mode) + self.pr_weight = pr_weight + + def forward(self, output, target): + # mask losses + sem_dice = self.dice_loss(output['sem_logits'], target['sem']) + cnt_dice = self.dice_loss(output['cnt_logits'], target['cnt']) + + aux_loss = {'sem_dice': sem_dice.item(), 'cnt_dice': cnt_dice.item()} + total_loss = sem_dice + cnt_dice + + # add the point rend losses from both + if 'sem_points' in output: + sem_pr_ce = self.pr_loss(output['sem_points'], output['sem_point_coords'], target['sem']) + cnt_pr_ce = self.pr_loss(output['cnt_points'], output['cnt_point_coords'], target['cnt']) + + aux_loss['sem_pr_ce'] = sem_pr_ce.item() + aux_loss['cnt_pr_ce'] = cnt_pr_ce.item() + total_loss += self.pr_weight * (sem_pr_ce + cnt_pr_ce) + + aux_loss['total_loss'] = total_loss.item() + return total_loss, aux_loss \ No newline at end of file diff --git a/empanada/models/point_rend.py b/empanada/models/point_rend.py index 1e775cb..91c91e4 100644 --- a/empanada/models/point_rend.py +++ b/empanada/models/point_rend.py @@ -279,16 +279,25 @@ def forward(self, coarse_sem_seg_logits, features): N, C = sem_seg_logits.size()[:2] point_indices = point_indices.unsqueeze(1).expand(-1, C, -1) - spatial_dims = sem_seg_logits.size()[-dim:] - dsize = 1 - for s in spatial_dims: - dsize *= s + # clunky if statements to appease TorchScript... 
+ if dim == 3: + D, H, W = sem_seg_logits.size()[-3:] + dsize = D * H * W + + sem_seg_logits = ( + sem_seg_logits.reshape(N, C, dsize) + .scatter_(2, point_indices, point_logits) + .view(N, C, D, H, W) + ) + else: + H, W = sem_seg_logits.size()[-2:] + dsize = H * W - sem_seg_logits = ( - sem_seg_logits.reshape(N, C, dsize) - .scatter_(2, point_indices, point_logits) - .view(N, C, *spatial_dims) - ) + sem_seg_logits = ( + sem_seg_logits.reshape(N, C, dsize) + .scatter_(2, point_indices, point_logits) + .view(N, C, H, W) + ) pr_out['sem_seg_logits'] = sem_seg_logits diff --git a/empanada/models/quantization/__init__.py b/empanada/models/quantization/__init__.py index 2ebd86d..a1a0662 100644 --- a/empanada/models/quantization/__init__.py +++ b/empanada/models/quantization/__init__.py @@ -1,2 +1,2 @@ -from empanada.models.quantization.panoptic_deeplab import QuantizablePanopticDeepLab, QuantizablePanopticDeepLabPR -from empanada.models.quantization.panoptic_bifpn import QuantizablePanopticBiFPN, QuantizablePanopticBiFPNPR \ No newline at end of file +from empanada.models.quantization.panoptic_deeplab import * +from empanada.models.quantization.panoptic_bifpn import * \ No newline at end of file diff --git a/empanada/models/quantization/panoptic_bifpn.py b/empanada/models/quantization/panoptic_bifpn.py index 41ba833..d67504f 100644 --- a/empanada/models/quantization/panoptic_bifpn.py +++ b/empanada/models/quantization/panoptic_bifpn.py @@ -8,7 +8,8 @@ from typing import List, Dict __all__ = [ - 'QuantizablePanopticBiFPN' + 'QuantizablePanopticBiFPN', + 'QuantizablePanopticBiFPNPR' ] def _replace_relu(module): diff --git a/empanada/models/quantization/panoptic_deeplab.py b/empanada/models/quantization/panoptic_deeplab.py index 70097ed..6b4ef23 100644 --- a/empanada/models/quantization/panoptic_deeplab.py +++ b/empanada/models/quantization/panoptic_deeplab.py @@ -6,6 +6,7 @@ from empanada.models.quantization.point_rend import QuantizablePointRendSemSegHead from empanada.models.quantization.decoders import QuantizablePanopticDeepLabDecoder from empanada.models.heads import PanopticDeepLabHead +from empanada.models.panoptic_deeplab import _make3d from empanada.models.blocks import * from typing import List, Dict @@ -16,7 +17,8 @@ __all__ = [ 'QuantizablePanopticDeepLab', - 'QuantizablePanopticDeepLabPR' + 'QuantizablePanopticDeepLabPR', + 'QuantizablePanopticDeepLabBC' ] def _replace_relu(module): @@ -53,7 +55,7 @@ def __init__( assert (encoder in backbones), \ f'Invalid encoder name {encoder}, choices are {backbones}' assert stage4_stride in [16, 32] - assert min(low_level_stages) > 0 + #assert min(low_level_stages) > 0 self.decoder_channels = decoder_channels self.num_classes = num_classes @@ -262,4 +264,140 @@ def fuse_model(self): self.semantic_decoder.fuse_model() self.semantic_pr.fuse_model() if self.instance_decoder is not None: - self.instance_decoder.fuse_model() \ No newline at end of file + self.instance_decoder.fuse_model() + +class QuantizablePanopticDeepLabBC(QuantizablePanopticDeepLab): + def __init__( + self, + num_fc=3, + train_num_points=1024, + oversample_ratio=3, + importance_sample_ratio=0.75, + subdivision_steps=2, + subdivision_num_points=8192, + **kwargs + ): + super(QuantizablePanopticDeepLabBC, self).__init__(**kwargs) + + # remove instance center and regression layers + del self.ins_center + del self.ins_xy + + # create the boundary head + self.boundary_head = PanopticDeepLabHead(self.decoder_channels, 1) + + self.semantic_pr = QuantizablePointRendSemSegHead( + 
self.decoder_channels, self.num_classes, num_fc, + train_num_points, oversample_ratio, + importance_sample_ratio, subdivision_steps, + subdivision_num_points, quantize=kwargs['quantize'] + ) + + self.boundary_pr = QuantizablePointRendSemSegHead( + self.decoder_channels, self.num_classes, num_fc, + train_num_points, oversample_ratio, + importance_sample_ratio, subdivision_steps, + subdivision_num_points, quantize=kwargs['quantize'] + ) + + self.dimension = kwargs['dimension'] + if self.dimension == 3: + _make3d(self) + + def fix_qconfig(self, observer='fbgemm'): + self.encoder.qconfig = torch.quantization.get_default_qconfig(observer) + self.semantic_decoder.qconfig = torch.quantization.get_default_qconfig(observer) + if self.instance_decoder is not None: + self.instance_decoder.qconfig = torch.quantization.get_default_qconfig(observer) + + self.semantic_head.qconfig = torch.quantization.get_default_qconfig(observer) + self.boundary_head.qconfig = torch.quantization.get_default_qconfig(observer) + + self.quant.qconfig = torch.quantization.get_default_qconfig(observer) + self.dequant.qconfig = torch.quantization.get_default_qconfig(observer) + + def prepare_quantization(self): + torch.quantization.prepare(self.encoder, inplace=True) + torch.quantization.prepare(self.semantic_decoder, inplace=True) + if self.instance_decoder is not None: + torch.quantization.prepare(self.instance_decoder, inplace=True) + + torch.quantization.prepare(self.semantic_head, inplace=True) + torch.quantization.prepare(self.boundary_head, inplace=True) + + torch.quantization.prepare(self.quant, inplace=True) + torch.quantization.prepare(self.dequant, inplace=True) + + def _apply_heads( + self, + semantic_x, + instance_x, + render_steps: int, + interpolate_ins: bool + ): + heads_out = {} + + sem = self.semantic_head(semantic_x) + cnt = self.boundary_head(instance_x) + + if self.training: + sem_pr_out: Dict[str, torch.Tensor] = self.semantic_pr(sem, semantic_x) + cnt_pr_out: Dict[str, torch.Tensor] = self.boundary_pr(cnt, instance_x) + + # interpolate to original resolution (4x) + heads_out['sem_logits'] = self.interpolate(sem_pr_out['sem_seg_logits']) + heads_out['sem_points'] = sem_pr_out['point_logits'] + heads_out['sem_point_coords'] = sem_pr_out['point_coords'] + + heads_out['cnt_logits'] = self.interpolate(cnt_pr_out['sem_seg_logits']) + heads_out['cnt_points'] = cnt_pr_out['point_logits'] + heads_out['cnt_point_coords'] = cnt_pr_out['point_coords'] + + # dequant all outputs + heads_out = {k: self.dequant(v) for k,v in heads_out.items()} + else: + # update the number of subdivisions + self.semantic_pr.subdivision_steps = render_steps + self.boundary_pr.subdivision_steps = render_steps + + sem = self.dequant(sem) + semantic_x = self.dequant(semantic_x) + sem_pr_out: Dict[str, torch.Tensor] = self.semantic_pr( + sem, semantic_x + ) + heads_out['sem_logits'] = sem_pr_out['sem_seg_logits'] + + cnt = self.dequant(cnt) + instance_x = self.dequant(instance_x) + cnt_pr_out: Dict[str, torch.Tensor] = self.boundary_pr( + cnt, instance_x + ) + heads_out['cnt_logits'] = cnt_pr_out['sem_seg_logits'] + + return heads_out + + def forward(self, x, render_steps: int=2, interpolate_ins: bool=True): + if self.training: + assert isinstance(self.quant, nn.Identity), \ + "Quantized training not supported!" 
+ + x = self.quant(x) + + pyramid_features, semantic_x, instance_x = self._encode_decode(x) + output: Dict[str, torch.Tensor] = self._apply_heads( + semantic_x, instance_x, render_steps, interpolate_ins + ) + + output = torch.cat( + [output['sem_logits'], output['cnt_logits']], dim=1 + ) + + return output + + def fuse_model(self): + self.encoder.fuse_model() + self.semantic_decoder.fuse_model() + self.semantic_pr.fuse_model() + self.boundary_pr.fuse_model() + if self.instance_decoder is not None: + self.instance_decoder.fuse_model() diff --git a/projects/mitonet/configs/mmm_median_inference_celegans.yaml b/projects/mitonet/configs/mmm_median_inference_celegans.yaml index 23dad3a..f147542 100644 --- a/projects/mitonet/configs/mmm_median_inference_celegans.yaml +++ b/projects/mitonet/configs/mmm_median_inference_celegans.yaml @@ -3,7 +3,7 @@ BASE: "./mmm_median_inference.yaml" # parameters for the inference engine engine_params: median_kernel_size: 3 - confidence_thr: 0.3 + confidence_thr: 0.5 consensus_params: pixel_vote_thr: 1 @@ -11,4 +11,4 @@ consensus_params: # simple object size/shape filters filters: - { name: "remove_small_objects", min_size: 500 } - - { name: "remove_pancakes", min_span: 8 } \ No newline at end of file + - { name: "remove_pancakes", min_span: 8 } diff --git a/projects/mitonet/configs/mmm_median_inference_fly_brain.yaml b/projects/mitonet/configs/mmm_median_inference_fly_brain.yaml index c3aa885..26b5528 100644 --- a/projects/mitonet/configs/mmm_median_inference_fly_brain.yaml +++ b/projects/mitonet/configs/mmm_median_inference_fly_brain.yaml @@ -1,7 +1,5 @@ BASE: "./mmm_median_inference.yaml" -axes: [ 'xy', 'xz', 'yz' ] - # parameters for the inference engine engine_params: median_kernel_size: 3 @@ -9,4 +7,4 @@ engine_params: # simple object size/shape filters filters: - { name: "remove_small_objects", min_size: 500 } - - { name: "remove_pancakes", min_span: 8 } \ No newline at end of file + - { name: "remove_pancakes", min_span: 8 } diff --git a/projects/mitonet/configs/mmm_median_inference_hela.yaml b/projects/mitonet/configs/mmm_median_inference_hela.yaml index defbda0..6810fdd 100644 --- a/projects/mitonet/configs/mmm_median_inference_hela.yaml +++ b/projects/mitonet/configs/mmm_median_inference_hela.yaml @@ -7,4 +7,4 @@ engine_params: # simple object size/shape filters filters: - { name: "remove_small_objects", min_size: 800 } - - { name: "remove_pancakes", min_span: 8 } \ No newline at end of file + - { name: "remove_pancakes", min_span: 4 } diff --git a/projects/mitonet/configs/mmm_panoptic_deeplab_bc.yaml b/projects/mitonet/configs/mmm_panoptic_deeplab_bc.yaml index a8dacae..97d0a8b 100644 --- a/projects/mitonet/configs/mmm_panoptic_deeplab_bc.yaml +++ b/projects/mitonet/configs/mmm_panoptic_deeplab_bc.yaml @@ -83,9 +83,10 @@ TRAIN: # dataset parameters batch_size: 64 - dataset_class: "BCDataset" - weight_gamma: 0.3 workers: 8 + dataset_class: "BCDataset" + dataset_params: + weight_gamma: 0.3 augmentations: - { aug: "RandomScale", scale_limit: [ -0.9, 1 ]} @@ -116,6 +117,7 @@ EVAL: - { metric: "IoU", name: "contour_iou", labels: [ 1 ], output_key: "cnt", target_key: "cnt"} # parameters needed for inference + engine: "BCEngine" engine_params: thing_list: [ 1 ] label_divisor: 1000 diff --git a/projects/mitonet/configs/mmm_ws_inference_fly_brain.yaml b/projects/mitonet/configs/mmm_ws_inference_fly_brain.yaml index 0690249..3901e82 100644 --- a/projects/mitonet/configs/mmm_ws_inference_fly_brain.yaml +++ b/projects/mitonet/configs/mmm_ws_inference_fly_brain.yaml 
@@ -7,14 +7,14 @@ labels: [ 1 ] # parameters for the inference engine engine: "BCEngine3d" engine_params: - median_kernel_size: 5 + median_kernel_size: 3 watershed_params: - thres1: 0.5 - thres2: 0.5 - thres3: 0.25 - seed_thres: 300 + thres1: 0.3 + thres2: 0.2 + thres3: 0.1 + seed_thres: 500 min_size: 500 - label_divisor: 1000 + label_divisor: 10000 use_mask_wts: True diff --git a/projects/mitonet/scripts/evaluate3d.py b/projects/mitonet/scripts/evaluate3d.py index 397e015..d8aa1fa 100644 --- a/projects/mitonet/scripts/evaluate3d.py +++ b/projects/mitonet/scripts/evaluate3d.py @@ -95,7 +95,6 @@ def parse_args(): # create a separate tracker for # each prediction axis and each segmentation class trackers = create_axis_trackers(axes, class_labels, label_divisor, shape) - for axis_name, axis in axes.items(): print(f'Predicting {axis_name} stack') @@ -169,7 +168,7 @@ def parse_args(): finish_tracking(trackers[axis_name]) for tracker in trackers[axis_name]: apply_filters(tracker, filters_dict) - + # create the final instance segmentations for class_id in config['INFERENCE']['labels']: class_name = config['DATASET']['class_names'][class_id] @@ -195,7 +194,7 @@ def parse_args(): overwrite=True, chunks=(1, None, None) ) fill_volume(consensus_vol, consensus_tracker.instances, processes=4) - consensus_tracker.write_to_json(os.path.join(volume_path, f'{config_name}_{class_name}_pred.json')) + consensus_tracker.write_to_json(os.path.join(volume_path, f'{config_name}_{class_name}_pred_test.json')) # run evaluation semantic_metrics = {'IoU': iou} @@ -206,13 +205,13 @@ def parse_args(): for class_id, class_name in config['DATASET']['class_names'].items(): gt_json = os.path.join(volume_path, f'{class_name}_gt.json') - pred_json = os.path.join(volume_path, f'{config_name}_{class_name}_pred.json') + pred_json = os.path.join(volume_path, f'{config_name}_{class_name}_pred_test.json') results = evaluator(gt_json, pred_json) results = {f'{class_name}_{k}': v for k,v in results.items()} for k, v in results.items(): print(k, v) - + try: run_id = state.get('run_id') if run_id is not None: diff --git a/projects/mitonet/scripts/evaluate3d_bc.py b/projects/mitonet/scripts/evaluate3d_bc.py index 3db0017..083190b 100644 --- a/projects/mitonet/scripts/evaluate3d_bc.py +++ b/projects/mitonet/scripts/evaluate3d_bc.py @@ -170,6 +170,8 @@ def parse_args(): for index2d,seg2d in tqdm(enumerate(instance_seg), total=len(instance_seg)): rle_seg = pan_seg_to_rle_seg(seg2d, [1], label_divisor, [1], force_connected=False) pred_tracker.update(rle_seg[1], index2d) + + print('Number of instances', len(pred_tracker.instances)) pred_tracker.finish() pred_tracker.write_to_json(os.path.join(volume_path, f'{config_name}_{class_name}_pred.json')) diff --git a/scripts/export_model.py b/scripts/export_model.py index 824dcee..2e7e69c 100644 --- a/scripts/export_model.py +++ b/scripts/export_model.py @@ -125,9 +125,15 @@ def main(): torch.jit.save(model, model_out) print('Exported model successfully.') + dimension = config['MODEL'].get('dimension') + if dimension == 3: + tensor = torch.randn((1, 1, 128, 128, 128)) + else: + tensor = torch.randn((1, 1, 256, 256)) + # NOTE: Do this after saving or model performance is degraded with torch.no_grad(): - x = torch.randn((1, 1, 256, 256)).cuda() + x = tensor.cuda() output = model(x) print('Validated forward pass.') @@ -176,8 +182,8 @@ def main(): 'dataset_params': config['TRAIN']['dataset_params'], 'criterion': config['TRAIN']['criterion'], 'criterion_params': config['TRAIN']['criterion_params'], - 
'engine': config['EVAL']['engine'], - 'engine_params': config['EVAL']['engine_params'], + 'engine': config['EVAL'].get('engine'), + 'engine_params': config['EVAL'].get('engine_params'), } desc = { 'model': model_out, diff --git a/scripts/train.py b/scripts/train.py index ec01c45..18180a5 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -564,7 +564,7 @@ def train( # measure elapsed time batch_time.update(time.time() - end) end = time.time() - + if i % config['TRAIN']['print_freq'] == 0: progress.display(i) @@ -715,4 +715,4 @@ def _get_batch_fmtstr(self, num_batches): return '[' + fmt + '/' + fmt.format(num_batches) + ']' if __name__ == "__main__": - main() \ No newline at end of file + main()