diff --git a/doctr/models/detection/linknet/base.py b/doctr/models/detection/linknet/base.py
index 17d425c56c..15bbb6dc81 100644
--- a/doctr/models/detection/linknet/base.py
+++ b/doctr/models/detection/linknet/base.py
@@ -107,7 +107,7 @@ def build_target(
self,
target: List[np.ndarray],
output_shape: Tuple[int, int],
- ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ ) -> Tuple[np.ndarray, np.ndarray]:
if any(t.dtype != np.float32 for t in target):
raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
@@ -119,7 +119,6 @@ def build_target(
if self.assume_straight_pages:
seg_target = np.zeros(target_shape, dtype=bool)
- edge_mask = np.zeros(target_shape, dtype=bool)
else:
seg_target = np.zeros(target_shape, dtype=np.uint8)
@@ -144,7 +143,12 @@ def build_target(
abs_boxes[:, [0, 2]] *= w
abs_boxes[:, [1, 3]] *= h
abs_boxes = abs_boxes.round().astype(np.int32)
- polys = [None] * abs_boxes.shape[0] # Unused
+ polys = np.stack([
+ abs_boxes[:, [0, 1]],
+ abs_boxes[:, [0, 3]],
+ abs_boxes[:, [2, 3]],
+ abs_boxes[:, [2, 1]],
+ ], axis=1)
boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1])
for poly, box, box_size in zip(polys, abs_boxes, boxes_size):
@@ -159,19 +163,10 @@ def build_target(
if box.shape == (4, 2):
box = [np.min(box[:, 0]), np.min(box[:, 1]), np.max(box[:, 0]), np.max(box[:, 1])]
seg_target[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = True
- # top edge
- edge_mask[idx, box[1], box[0]: min(box[2] + 1, w)] = True
- # bot edge
- edge_mask[idx, min(box[3], h - 1), box[0]: min(box[2] + 1, w)] = True
- # left edge
- edge_mask[idx, box[1]: min(box[3] + 1, h), box[0]] = True
- # right edge
- edge_mask[idx, box[1]: min(box[3] + 1, h), min(box[2], w - 1)] = True
# Don't forget to switch back to channel first if PyTorch is used
if not is_tf_available():
seg_target = seg_target.transpose(0, 3, 1, 2)
seg_mask = seg_mask.transpose(0, 3, 1, 2)
- edge_mask = edge_mask.transpose(0, 3, 1, 2)
- return seg_target, seg_mask, edge_mask
+ return seg_target, seg_mask
diff --git a/doctr/models/detection/linknet/pytorch.py b/doctr/models/detection/linknet/pytorch.py
index fc50691bd3..dc55a0862d 100644
--- a/doctr/models/detection/linknet/pytorch.py
+++ b/doctr/models/detection/linknet/pytorch.py
@@ -158,7 +158,6 @@ def compute_loss(
self,
out_map: torch.Tensor,
target: List[np.ndarray],
- edge_factor: float = 2.,
) -> torch.Tensor:
-    """Compute linknet loss, BCE with boosted box edges or focal loss. Focal loss implementation based on
-    `<https://github.com/tensorflow/addons/>`_.
+    """Compute linknet loss, BCE with an additional Dice loss term, or focal loss. Focal loss implementation based on
+    `<https://github.com/tensorflow/addons/>`_.
@@ -166,26 +165,26 @@ def compute_loss(
Args:
out_map: output feature map of the model of shape (N, 1, H, W)
target: list of dictionary where each dict has a `boxes` and a `flags` entry
- edge_factor: boost factor for box edges (in case of BCE)
Returns:
A loss tensor
"""
- seg_target, seg_mask, edge_mask = self.build_target(target, out_map.shape[-2:]) # type: ignore[arg-type]
+ seg_target, seg_mask = self.build_target(target, out_map.shape[-2:]) # type: ignore[arg-type]
seg_target, seg_mask = torch.from_numpy(seg_target).to(dtype=out_map.dtype), torch.from_numpy(seg_mask)
seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
- if edge_factor > 0:
- edge_mask = torch.from_numpy(edge_mask).to(dtype=out_map.dtype, device=out_map.device)
- # Get the cross_entropy for each entry
- loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction='none')
+ # BCE loss
+ bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction='none')
+
+ # Dice loss
+    prob_map = torch.sigmoid(out_map)
+ inter = (prob_map[seg_mask] * seg_target[seg_mask]).sum()
+ cardinality = (prob_map[seg_mask] + seg_target[seg_mask]).sum()
+ dice_loss = 1 - 2 * inter / (cardinality + 1e-8)
- # Compute BCE loss with highlighted edges
- if edge_factor > 0:
- loss = ((1 + (edge_factor - 1) * edge_mask) * loss)
# Only consider contributions overlaping the mask
- return loss[seg_mask].mean()
+ return bce_loss[seg_mask].mean() + dice_loss
def _linknet(
diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py
index 792e503c7b..1c1b40295a 100644
--- a/doctr/models/detection/linknet/tensorflow.py
+++ b/doctr/models/detection/linknet/tensorflow.py
@@ -139,7 +139,6 @@ def compute_loss(
self,
out_map: tf.Tensor,
target: List[np.ndarray],
- edge_factor: float = 2.,
) -> tf.Tensor:
-    """Compute linknet loss, BCE with boosted box edges or focal loss. Focal loss implementation based on
-    `<https://github.com/tensorflow/addons/>`_.
+    """Compute linknet loss, BCE with an additional Dice loss term, or focal loss. Focal loss implementation based on
+    `<https://github.com/tensorflow/addons/>`_.
@@ -147,29 +146,25 @@ def compute_loss(
Args:
out_map: output feature map of the model of shape N x H x W x 1
target: list of dictionary where each dict has a `boxes` and a `flags` entry
- edge_factor: boost factor for box edges (in case of BCE)
Returns:
A loss tensor
"""
- seg_target, seg_mask, edge_mask = self.build_target(target, out_map.shape[1:3])
+ seg_target, seg_mask = self.build_target(target, out_map.shape[1:3])
seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype)
seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
- if edge_factor > 0:
- edge_mask = tf.convert_to_tensor(edge_mask, dtype=tf.bool)
- # Get the cross_entropy for each entry
- loss = tf.keras.losses.binary_crossentropy(seg_target, out_map, from_logits=True)[..., None]
+ # BCE loss
+ bce_loss = tf.keras.losses.binary_crossentropy(seg_target, out_map, from_logits=True)[..., None]
- # Compute BCE loss with highlighted edges
- if edge_factor > 0:
- loss = tf.math.multiply(
- 1 + (edge_factor - 1) * tf.cast(edge_mask, out_map.dtype),
- loss
- )
+ # Dice loss
+ prob_map = tf.math.sigmoid(out_map)
+ inter = tf.math.reduce_sum(prob_map[seg_mask] * seg_target[seg_mask])
+ cardinality = tf.math.reduce_sum(prob_map[seg_mask] + seg_target[seg_mask])
+ dice_loss = 1 - 2 * inter / (cardinality + 1e-8)
- return tf.reduce_mean(loss[seg_mask])
+ return tf.math.reduce_mean(bce_loss[seg_mask]) + dice_loss
def call(
self,