diff --git a/doctr/models/detection/linknet/base.py b/doctr/models/detection/linknet/base.py index 17d425c56c..15bbb6dc81 100644 --- a/doctr/models/detection/linknet/base.py +++ b/doctr/models/detection/linknet/base.py @@ -107,7 +107,7 @@ def build_target( self, target: List[np.ndarray], output_shape: Tuple[int, int], - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + ) -> Tuple[np.ndarray, np.ndarray]: if any(t.dtype != np.float32 for t in target): raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.") @@ -119,7 +119,6 @@ def build_target( if self.assume_straight_pages: seg_target = np.zeros(target_shape, dtype=bool) - edge_mask = np.zeros(target_shape, dtype=bool) else: seg_target = np.zeros(target_shape, dtype=np.uint8) @@ -144,7 +143,12 @@ def build_target( abs_boxes[:, [0, 2]] *= w abs_boxes[:, [1, 3]] *= h abs_boxes = abs_boxes.round().astype(np.int32) - polys = [None] * abs_boxes.shape[0] # Unused + polys = np.stack([ + abs_boxes[:, [0, 1]], + abs_boxes[:, [0, 3]], + abs_boxes[:, [2, 3]], + abs_boxes[:, [2, 1]], + ], axis=1) boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1]) for poly, box, box_size in zip(polys, abs_boxes, boxes_size): @@ -159,19 +163,10 @@ def build_target( if box.shape == (4, 2): box = [np.min(box[:, 0]), np.min(box[:, 1]), np.max(box[:, 0]), np.max(box[:, 1])] seg_target[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = True - # top edge - edge_mask[idx, box[1], box[0]: min(box[2] + 1, w)] = True - # bot edge - edge_mask[idx, min(box[3], h - 1), box[0]: min(box[2] + 1, w)] = True - # left edge - edge_mask[idx, box[1]: min(box[3] + 1, h), box[0]] = True - # right edge - edge_mask[idx, box[1]: min(box[3] + 1, h), min(box[2], w - 1)] = True # Don't forget to switch back to channel first if PyTorch is used if not is_tf_available(): seg_target = seg_target.transpose(0, 3, 1, 2) seg_mask = seg_mask.transpose(0, 3, 1, 2) - edge_mask = edge_mask.transpose(0, 3, 1, 2) - return seg_target, seg_mask, edge_mask + return seg_target, seg_mask diff --git a/doctr/models/detection/linknet/pytorch.py b/doctr/models/detection/linknet/pytorch.py index fc50691bd3..dc55a0862d 100644 --- a/doctr/models/detection/linknet/pytorch.py +++ b/doctr/models/detection/linknet/pytorch.py @@ -158,7 +158,6 @@ def compute_loss( self, out_map: torch.Tensor, target: List[np.ndarray], - edge_factor: float = 2., ) -> torch.Tensor: """Compute linknet loss, BCE with boosted box edges or focal loss. Focal loss implementation based on `_. @@ -166,26 +165,26 @@ def compute_loss( Args: out_map: output feature map of the model of shape (N, 1, H, W) target: list of dictionary where each dict has a `boxes` and a `flags` entry - edge_factor: boost factor for box edges (in case of BCE) Returns: A loss tensor """ - seg_target, seg_mask, edge_mask = self.build_target(target, out_map.shape[-2:]) # type: ignore[arg-type] + seg_target, seg_mask = self.build_target(target, out_map.shape[-2:]) # type: ignore[arg-type] seg_target, seg_mask = torch.from_numpy(seg_target).to(dtype=out_map.dtype), torch.from_numpy(seg_mask) seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device) - if edge_factor > 0: - edge_mask = torch.from_numpy(edge_mask).to(dtype=out_map.dtype, device=out_map.device) - # Get the cross_entropy for each entry - loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction='none') + # BCE loss + bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction='none') + + # Dice loss + prob_map = torch.nn.functional.sigmoid(out_map) + inter = (prob_map[seg_mask] * seg_target[seg_mask]).sum() + cardinality = (prob_map[seg_mask] + seg_target[seg_mask]).sum() + dice_loss = 1 - 2 * inter / (cardinality + 1e-8) - # Compute BCE loss with highlighted edges - if edge_factor > 0: - loss = ((1 + (edge_factor - 1) * edge_mask) * loss) # Only consider contributions overlaping the mask - return loss[seg_mask].mean() + return bce_loss[seg_mask].mean() + dice_loss def _linknet( diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py index 792e503c7b..1c1b40295a 100644 --- a/doctr/models/detection/linknet/tensorflow.py +++ b/doctr/models/detection/linknet/tensorflow.py @@ -139,7 +139,6 @@ def compute_loss( self, out_map: tf.Tensor, target: List[np.ndarray], - edge_factor: float = 2., ) -> tf.Tensor: """Compute linknet loss, BCE with boosted box edges or focal loss. Focal loss implementation based on `_. @@ -147,29 +146,25 @@ def compute_loss( Args: out_map: output feature map of the model of shape N x H x W x 1 target: list of dictionary where each dict has a `boxes` and a `flags` entry - edge_factor: boost factor for box edges (in case of BCE) Returns: A loss tensor """ - seg_target, seg_mask, edge_mask = self.build_target(target, out_map.shape[1:3]) + seg_target, seg_mask = self.build_target(target, out_map.shape[1:3]) seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype) seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool) - if edge_factor > 0: - edge_mask = tf.convert_to_tensor(edge_mask, dtype=tf.bool) - # Get the cross_entropy for each entry - loss = tf.keras.losses.binary_crossentropy(seg_target, out_map, from_logits=True)[..., None] + # BCE loss + bce_loss = tf.keras.losses.binary_crossentropy(seg_target, out_map, from_logits=True)[..., None] - # Compute BCE loss with highlighted edges - if edge_factor > 0: - loss = tf.math.multiply( - 1 + (edge_factor - 1) * tf.cast(edge_mask, out_map.dtype), - loss - ) + # Dice loss + prob_map = tf.math.sigmoid(out_map) + inter = tf.math.reduce_sum(prob_map[seg_mask] * seg_target[seg_mask]) + cardinality = tf.math.reduce_sum(prob_map[seg_mask] + seg_target[seg_mask]) + dice_loss = 1 - 2 * inter / (cardinality + 1e-8) - return tf.reduce_mean(loss[seg_mask]) + return tf.math.reduce_mean(bce_loss[seg_mask]) + dice_loss def call( self,