From 6dd1a0f830559074d8f5915b7079e8c04e1e7b11 Mon Sep 17 00:00:00 2001 From: Amine Date: Tue, 13 Sep 2022 10:16:33 +0200 Subject: [PATCH] feat: make loss computation vectorized and change target building to handle better class ids --- doctr/datasets/datasets/tensorflow.py | 4 + .../differentiable_binarization/base.py | 39 +++-- .../differentiable_binarization/tensorflow.py | 139 ++++++++++++------ doctr/models/detection/linknet/base.py | 28 ++-- doctr/models/detection/linknet/tensorflow.py | 58 ++++---- 5 files changed, 160 insertions(+), 108 deletions(-) diff --git a/doctr/datasets/datasets/tensorflow.py b/doctr/datasets/datasets/tensorflow.py index 4d8320eee1..96c53b001b 100644 --- a/doctr/datasets/datasets/tensorflow.py +++ b/doctr/datasets/datasets/tensorflow.py @@ -25,6 +25,10 @@ def _read_sample(self, index: int) -> Tuple[tf.Tensor, Any]: if isinstance(target, dict): assert "boxes" in target, "Target should contain 'boxes' key" assert "labels" in target, "Target should contain 'labels' key" + elif isinstance(target, tuple): + assert isinstance(target[0], str) or isinstance( + target[0], np.ndarray + ), "Target should be a string or a numpy array" else: assert isinstance(target, str) or isinstance( target, np.ndarray diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py index b5cebc291c..edd0ee648a 100644 --- a/doctr/models/detection/differentiable_binarization/base.py +++ b/doctr/models/detection/differentiable_binarization/base.py @@ -5,7 +5,7 @@ # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization -from typing import List, Tuple, Union, Dict +from typing import Dict, List, Tuple, Union import cv2 import numpy as np @@ -268,9 +268,11 @@ def build_target( output_shape: Tuple[int, int, int, int], ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - for i, tgt in enumerate(target): + new_target = [] + for tgt in target: if isinstance(tgt, np.ndarray): - target[i] = {"words": tgt} + new_target.append({"words": tgt}) + target = new_target.copy() if any(t.dtype != np.float32 for tgt in target for t in tgt.values()): raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.") if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()): @@ -279,17 +281,19 @@ def build_target( input_dtype = list(target[0].values())[0].dtype if len(target) > 0 else np.float32 h, w = output_shape[1:-1] - seg_target: np.ndarray = np.zeros(output_shape, dtype=np.uint8) - seg_mask: np.ndarray = np.ones(output_shape, dtype=bool) - thresh_target: np.ndarray = np.zeros(output_shape, dtype=np.float32) - thresh_mask: np.ndarray = np.ones(output_shape, dtype=np.uint8) + target_shape = (output_shape[0], output_shape[-1], h, w) + seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8) + seg_mask: np.ndarray = np.ones(target_shape, dtype=bool) + thresh_target: np.ndarray = np.zeros(target_shape, dtype=np.float32) + thresh_mask: np.ndarray = np.ones(target_shape, dtype=np.uint8) for idx, tgt in enumerate(target): for class_idx, _target in enumerate(tgt.values()): # Draw each polygon on gt if _target.shape[0] == 0: # Empty image, full masked - seg_mask[idx, :, :, class_idx] = False + # seg_mask[idx, :, :, class_idx] = False + seg_mask[idx, class_idx] = False # Absolute bounding boxes abs_boxes = _target.copy() @@ -317,7 +321,8 @@ def build_target( for box, box_size, poly in zip(abs_boxes, boxes_size, polys): # Mask boxes that are too small if box_size < self.min_size_box: - seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False + # seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False continue # Negative shrink for gt, as described in paper @@ -330,18 +335,24 @@ def build_target( # Draw polygon on gt if it is valid if len(shrinked) == 0: - seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False + # seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False continue shrinked = np.array(shrinked[0]).reshape(-1, 2) if shrinked.shape[0] <= 2 or not Polygon(shrinked).is_valid: - seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False + # seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False continue - cv2.fillPoly(seg_target[idx, :, :, class_idx][..., None], [shrinked.astype(np.int32)], 1) + cv2.fillPoly(seg_target[idx, class_idx], [shrinked.astype(np.int32)], 1) # Draw on both thresh map and thresh mask - poly, thresh_target[idx, :, :, class_idx], thresh_mask[idx, :, :, class_idx] = self.draw_thresh_map( - poly, thresh_target[idx, :, :, class_idx], thresh_mask[idx, :, :, class_idx] + poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map( + poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] ) + seg_target = seg_target.transpose((0, 2, 3, 1)) + seg_mask = seg_mask.transpose((0, 2, 3, 1)) + thresh_target = thresh_target.transpose((0, 2, 3, 1)) + thresh_mask = thresh_mask.transpose((0, 2, 3, 1)) thresh_target = thresh_target.astype(input_dtype) * (self.thresh_max - self.thresh_min) + self.thresh_min diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index 64f18aaea2..63df13574d 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -180,59 +180,102 @@ def compute_loss( A loss tensor """ + # prob_map = tf.math.sigmoid(out_map) + # thresh_map = tf.math.sigmoid(thresh_map) + # + # seg_target_all, seg_mask_all, thresh_target_all, thresh_mask_all = self.build_target(target, out_map.shape) + # seg_target_all = tf.convert_to_tensor(seg_target_all, dtype=out_map.dtype) + # seg_mask_all = tf.convert_to_tensor(seg_mask_all, dtype=tf.bool) + # thresh_target_all = tf.convert_to_tensor(thresh_target_all, dtype=out_map.dtype) + # thresh_mask_all = tf.convert_to_tensor(thresh_mask_all, dtype=tf.bool) + + # final_loss = tf.convert_to_tensor(0, dtype=float) + # for idx in range(out_map.shape[-1]): + # seg_target = seg_target_all[..., idx] + # seg_mask = seg_mask_all[..., idx] + # thresh_target = thresh_target_all[..., idx] + # thresh_mask = thresh_mask_all[..., idx] + # _out_map = out_map[..., idx] + # _thresh_map = thresh_map[..., idx] + # _prob_map = prob_map[..., idx] + # # Compute balanced BCE loss for proba_map + # bce_scale = 5.0 + # + # bce_loss = tf.keras.losses.binary_crossentropy( + # seg_target[..., None], _out_map[..., None], from_logits=True + # )[seg_mask] + # + # neg_target = 1 - seg_target[seg_mask] + # positive_count = tf.math.reduce_sum(seg_target[seg_mask]) + # negative_count = tf.math.reduce_min([tf.math.reduce_sum(neg_target), 3.0 * positive_count]) + # negative_loss = bce_loss * neg_target + # negative_loss, _ = tf.nn.top_k(negative_loss, tf.cast(negative_count, tf.int32)) + # sum_losses = tf.math.reduce_sum(bce_loss * seg_target[seg_mask]) + tf.math.reduce_sum(negative_loss) + # balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6) + # + # # Compute dice loss for approxbin_map + # bin_map = 1 / (1 + tf.exp(-50.0 * (_prob_map[seg_mask] - _thresh_map[seg_mask]))) + # + # bce_min = tf.math.reduce_min(bce_loss) + # weights = (bce_loss - bce_min) / (tf.math.reduce_max(bce_loss) - bce_min) + 1.0 + # inter = tf.math.reduce_sum(bin_map * seg_target[seg_mask] * weights) + # union = tf.math.reduce_sum(bin_map) + tf.math.reduce_sum(seg_target[seg_mask]) + 1e-8 + # dice_loss = 1 - 2.0 * (inter + eps) / (union + eps) + # + # # Compute l1 loss for thresh_map + # l1_scale = 10.0 + # if tf.reduce_any(thresh_mask): + # l1_loss = tf.math.reduce_mean(tf.math.abs(_thresh_map[thresh_mask] - thresh_target[thresh_mask])) + # else: + # l1_loss = tf.constant(0.0) + # + # final_loss += l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss + # return final_loss + # prob_map = tf.math.sigmoid(tf.squeeze(out_map, axis=[-1])) # thresh_map = tf.math.sigmoid(tf.squeeze(thresh_map, axis=[-1])) prob_map = tf.math.sigmoid(out_map) thresh_map = tf.math.sigmoid(thresh_map) - seg_target_all, seg_mask_all, thresh_target_all, thresh_mask_all = self.build_target(target, out_map.shape) - seg_target_all = tf.convert_to_tensor(seg_target_all, dtype=out_map.dtype) - seg_mask_all = tf.convert_to_tensor(seg_mask_all, dtype=tf.bool) - thresh_target_all = tf.convert_to_tensor(thresh_target_all, dtype=out_map.dtype) - thresh_mask_all = tf.convert_to_tensor(thresh_mask_all, dtype=tf.bool) - - final_loss = tf.convert_to_tensor(0, dtype=float) - for idx in range(out_map.shape[-1]): - seg_target = seg_target_all[..., idx] - seg_mask = seg_mask_all[..., idx] - thresh_target = thresh_target_all[..., idx] - thresh_mask = thresh_mask_all[..., idx] - _out_map = out_map[..., idx] - _thresh_map = thresh_map[..., idx] - _prob_map = prob_map[..., idx] - # Compute balanced BCE loss for proba_map - bce_scale = 5.0 - - bce_loss = tf.keras.losses.binary_crossentropy( - seg_target[..., None], _out_map[..., None], from_logits=True - )[seg_mask] - - neg_target = 1 - seg_target[seg_mask] - positive_count = tf.math.reduce_sum(seg_target[seg_mask]) - negative_count = tf.math.reduce_min([tf.math.reduce_sum(neg_target), 3.0 * positive_count]) - negative_loss = bce_loss * neg_target - negative_loss, _ = tf.nn.top_k(negative_loss, tf.cast(negative_count, tf.int32)) - sum_losses = tf.math.reduce_sum(bce_loss * seg_target[seg_mask]) + tf.math.reduce_sum(negative_loss) - balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6) - - # Compute dice loss for approxbin_map - bin_map = 1 / (1 + tf.exp(-50.0 * (_prob_map[seg_mask] - _thresh_map[seg_mask]))) - - bce_min = tf.math.reduce_min(bce_loss) - weights = (bce_loss - bce_min) / (tf.math.reduce_max(bce_loss) - bce_min) + 1.0 - inter = tf.math.reduce_sum(bin_map * seg_target[seg_mask] * weights) - union = tf.math.reduce_sum(bin_map) + tf.math.reduce_sum(seg_target[seg_mask]) + 1e-8 - dice_loss = 1 - 2.0 * (inter + eps) / (union + eps) - - # Compute l1 loss for thresh_map - l1_scale = 10.0 - if tf.reduce_any(thresh_mask): - l1_loss = tf.math.reduce_mean(tf.math.abs(_thresh_map[thresh_mask] - thresh_target[thresh_mask])) - else: - l1_loss = tf.constant(0.0) - - final_loss += l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss - return final_loss + seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape) + seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype) + seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool) + thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype) + thresh_mask = tf.convert_to_tensor(thresh_mask, dtype=tf.bool) + + # Compute balanced BCE loss for proba_map + bce_scale = 5.0 + bce_loss = tf.keras.losses.binary_crossentropy( + seg_target[..., None], + out_map[..., None], + from_logits=True, + )[seg_mask] + + neg_target = 1 - seg_target[seg_mask] + positive_count = tf.math.reduce_sum(seg_target[seg_mask]) + negative_count = tf.math.reduce_min([tf.math.reduce_sum(neg_target), 3.0 * positive_count]) + negative_loss = bce_loss * neg_target + negative_loss, _ = tf.nn.top_k(negative_loss, tf.cast(negative_count, tf.int32)) + sum_losses = tf.math.reduce_sum(bce_loss * seg_target[seg_mask]) + tf.math.reduce_sum(negative_loss) + balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6) + + # Compute dice loss for approxbin_map + bin_map = 1 / (1 + tf.exp(-50.0 * (prob_map[seg_mask] - thresh_map[seg_mask]))) + + bce_min = tf.math.reduce_min(bce_loss) + weights = (bce_loss - bce_min) / (tf.math.reduce_max(bce_loss) - bce_min) + 1.0 + inter = tf.math.reduce_sum(bin_map * seg_target[seg_mask] * weights) + union = tf.math.reduce_sum(bin_map) + tf.math.reduce_sum(seg_target[seg_mask]) + 1e-8 + dice_loss = 1 - 2.0 * inter / union + + # Compute l1 loss for thresh_map + l1_scale = 10.0 + if tf.reduce_any(thresh_mask): + l1_loss = tf.math.reduce_mean(tf.math.abs(thresh_map[thresh_mask] - thresh_target[thresh_mask])) + else: + l1_loss = tf.constant(0.0) + + return l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss def call( self, diff --git a/doctr/models/detection/linknet/base.py b/doctr/models/detection/linknet/base.py index 6828c6f816..b5aaf0ee1b 100644 --- a/doctr/models/detection/linknet/base.py +++ b/doctr/models/detection/linknet/base.py @@ -5,7 +5,7 @@ # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization -from typing import List, Tuple, Union, Dict +from typing import Dict, List, Tuple, Union import cv2 import numpy as np @@ -160,16 +160,18 @@ def build_target( output_shape: Tuple[int, int, int], ) -> Tuple[np.ndarray, np.ndarray]: - for i, tgt in enumerate(target): + new_target = [] + for tgt in target: if isinstance(tgt, np.ndarray): - target[i] = {"words": tgt} + new_target.append({"words": tgt}) + target = new_target.copy() if any(t.dtype != np.float32 for tgt in target for t in tgt.values()): raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.") if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()): raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.") h, w, num_classes = output_shape - target_shape = (len(target), h, w, num_classes) + target_shape = (len(target), num_classes, h, w) seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8) seg_mask: np.ndarray = np.ones(target_shape, dtype=bool) @@ -179,7 +181,7 @@ def build_target( # Draw each polygon on gt if _target.shape[0] == 0: # Empty image, full masked - seg_mask[idx, :, :, class_idx] = False + seg_mask[idx, class_idx] = False # Absolute bounding boxes abs_boxes = _target.copy() @@ -208,7 +210,7 @@ def build_target( for poly, box, box_size in zip(polys, abs_boxes, boxes_size): # Mask boxes that are too small if box_size < self.min_size_box: - seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False continue # Negative shrink for gt, as described in paper @@ -221,17 +223,17 @@ def build_target( # Draw polygon on gt if it is valid if len(shrunken) == 0: - seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False continue shrunken = np.array(shrunken[0]).reshape(-1, 2) if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid: - seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False continue - cv2.fillPoly(seg_target[idx, :, :, class_idx][..., None], [shrunken.astype(np.int32)], 1) + cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1) - # Don't forget to switch back to channel first if PyTorch is used - if not is_tf_available(): - seg_target = seg_target.transpose(0, 3, 1, 2) - seg_mask = seg_mask.transpose(0, 3, 1, 2) + # Don't forget to switch back to channel last if Tensorflow is used + if is_tf_available(): + seg_target = seg_target.transpose((0, 2, 3, 1)) + seg_mask = seg_mask.transpose((0, 2, 3, 1)) return seg_target, seg_mask diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py index 60215990a0..b650c8174d 100644 --- a/doctr/models/detection/linknet/tensorflow.py +++ b/doctr/models/detection/linknet/tensorflow.py @@ -180,39 +180,31 @@ def compute_loss( Returns: A loss tensor """ - seg_target_all, seg_mask_all = self.build_target(target, out_map.shape[1:]) - dice_loss_all = tf.convert_to_tensor(0, dtype=float) - focal_loss_all = tf.convert_to_tensor(0, dtype=float) - for idx in range(seg_target_all.shape[-1]): - seg_target = seg_target_all[..., idx] # [..., None] - seg_mask = seg_mask_all[..., idx] # [..., None] - _out_map = out_map[..., idx] # [..., None] - seg_target = tf.convert_to_tensor(seg_target, dtype=_out_map.dtype) - seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool) - seg_mask = tf.cast(seg_mask, tf.float32) - - bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], _out_map[..., None], from_logits=True) - proba_map = tf.sigmoid(_out_map) - - # Focal loss - if gamma < 0: - raise ValueError("Value of gamma should be greater than or equal to zero.") - # Convert logits to prob, compute gamma factor - p_t = (seg_target * proba_map) + ((1 - seg_target) * (1 - proba_map)) - alpha_t = seg_target * alpha + (1 - seg_target) * (1 - alpha) - # Unreduced loss - focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss - # Class reduced - focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2)) / tf.reduce_sum(seg_mask, (0, 1, 2)) - - # Dice loss - inter = tf.math.reduce_sum(seg_mask * proba_map * seg_target, (0, 1, 2)) - cardinality = tf.math.reduce_sum((proba_map + seg_target), (0, 1, 2)) - dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps) - - focal_loss_all += tf.reduce_mean(focal_loss) - dice_loss_all += tf.reduce_mean(dice_loss) - return focal_loss_all + dice_loss_all + seg_target, seg_mask = self.build_target(target, out_map.shape[1:]) + seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype) + seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool) + seg_mask = tf.cast(seg_mask, tf.float32) + + bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True) + proba_map = tf.sigmoid(out_map) + + # Focal loss + if gamma < 0: + raise ValueError("Value of gamma should be greater than or equal to zero.") + # Convert logits to prob, compute gamma factor + p_t = (seg_target * proba_map) + ((1 - seg_target) * (1 - proba_map)) + alpha_t = seg_target * alpha + (1 - seg_target) * (1 - alpha) + # Unreduced loss + focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss + # Class reduced + focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / tf.reduce_sum(seg_mask, (0, 1, 2, 3)) + + # Dice loss + inter = tf.math.reduce_sum(seg_mask * proba_map * seg_target, (0, 1, 2, 3)) + cardinality = tf.math.reduce_sum((proba_map + seg_target), (0, 1, 2, 3)) + dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps) + + return focal_loss + dice_loss def call( self,