Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add dice loss in linknet #816

Merged
merged 3 commits into from
Feb 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions doctr/models/detection/linknet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def build_target(
self,
target: List[np.ndarray],
output_shape: Tuple[int, int],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
) -> Tuple[np.ndarray, np.ndarray]:

if any(t.dtype != np.float32 for t in target):
raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
Expand All @@ -119,7 +119,6 @@ def build_target(

if self.assume_straight_pages:
seg_target = np.zeros(target_shape, dtype=bool)
edge_mask = np.zeros(target_shape, dtype=bool)
else:
seg_target = np.zeros(target_shape, dtype=np.uint8)

Expand All @@ -144,7 +143,12 @@ def build_target(
abs_boxes[:, [0, 2]] *= w
abs_boxes[:, [1, 3]] *= h
abs_boxes = abs_boxes.round().astype(np.int32)
polys = [None] * abs_boxes.shape[0] # Unused
polys = np.stack([
abs_boxes[:, [0, 1]],
abs_boxes[:, [0, 3]],
abs_boxes[:, [2, 3]],
abs_boxes[:, [2, 1]],
], axis=1)
boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1])

for poly, box, box_size in zip(polys, abs_boxes, boxes_size):
Expand All @@ -159,19 +163,10 @@ def build_target(
if box.shape == (4, 2):
box = [np.min(box[:, 0]), np.min(box[:, 1]), np.max(box[:, 0]), np.max(box[:, 1])]
seg_target[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = True
# top edge
edge_mask[idx, box[1], box[0]: min(box[2] + 1, w)] = True
# bot edge
edge_mask[idx, min(box[3], h - 1), box[0]: min(box[2] + 1, w)] = True
# left edge
edge_mask[idx, box[1]: min(box[3] + 1, h), box[0]] = True
# right edge
edge_mask[idx, box[1]: min(box[3] + 1, h), min(box[2], w - 1)] = True

# Don't forget to switch back to channel first if PyTorch is used
if not is_tf_available():
seg_target = seg_target.transpose(0, 3, 1, 2)
seg_mask = seg_mask.transpose(0, 3, 1, 2)
edge_mask = edge_mask.transpose(0, 3, 1, 2)

return seg_target, seg_mask, edge_mask
return seg_target, seg_mask
21 changes: 10 additions & 11 deletions doctr/models/detection/linknet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,34 +158,33 @@ def compute_loss(
self,
out_map: torch.Tensor,
target: List[np.ndarray],
edge_factor: float = 2.,
) -> torch.Tensor:
"""Compute linknet loss: BCE, restricted to the segmentation mask, combined with a dice loss.
Focal loss implementation based on `<https://github.com/tensorflow/addons/>`_.

Args:
out_map: output feature map of the model of shape (N, 1, H, W)
target: list of dictionary where each dict has a `boxes` and a `flags` entry
edge_factor: boost factor for box edges (in case of BCE)

Returns:
A loss tensor
"""
seg_target, seg_mask, edge_mask = self.build_target(target, out_map.shape[-2:]) # type: ignore[arg-type]
seg_target, seg_mask = self.build_target(target, out_map.shape[-2:]) # type: ignore[arg-type]

seg_target, seg_mask = torch.from_numpy(seg_target).to(dtype=out_map.dtype), torch.from_numpy(seg_mask)
seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
if edge_factor > 0:
edge_mask = torch.from_numpy(edge_mask).to(dtype=out_map.dtype, device=out_map.device)

# Get the cross_entropy for each entry
loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction='none')
# BCE loss
bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction='none')

# Dice loss
prob_map = torch.nn.functional.sigmoid(out_map)
inter = (prob_map[seg_mask] * seg_target[seg_mask]).sum()
cardinality = (prob_map[seg_mask] + seg_target[seg_mask]).sum()
dice_loss = 1 - 2 * inter / (cardinality + 1e-8)
Comment on lines +182 to +184
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we'll have to update this in another PR because it highly advantages classes that occupy the biggest area

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I will do it in the next PR !

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To curb this, may be FL + Dice Loss could be a better option? Because, even if the area was same for the classes, since total area of combined bboxes / total area of image could be really less, it may tend to learn more negatives.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, I meant that the baseline option would be to average the loss over classes (not mixing the contribution of each class to the loss before the last step). But your suggestion @SiddhantBahuguna is worth exploring as a follow-up and could improve results 👍

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your suggestions, I will add focal loss in the next PR !


# Compute BCE loss with highlighted edges
if edge_factor > 0:
loss = ((1 + (edge_factor - 1) * edge_mask) * loss)
# Only consider contributions overlapping the mask
return loss[seg_mask].mean()
return bce_loss[seg_mask].mean() + dice_loss


def _linknet(
Expand Down
23 changes: 9 additions & 14 deletions doctr/models/detection/linknet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,37 +139,32 @@ def compute_loss(
self,
out_map: tf.Tensor,
target: List[np.ndarray],
edge_factor: float = 2.,
) -> tf.Tensor:
"""Compute linknet loss: BCE, restricted to the segmentation mask, combined with a dice loss.
Focal loss implementation based on `<https://github.com/tensorflow/addons/>`_.

Args:
out_map: output feature map of the model of shape N x H x W x 1
target: list of dictionary where each dict has a `boxes` and a `flags` entry
edge_factor: boost factor for box edges (in case of BCE)

Returns:
A loss tensor
"""
seg_target, seg_mask, edge_mask = self.build_target(target, out_map.shape[1:3])
seg_target, seg_mask = self.build_target(target, out_map.shape[1:3])

seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype)
seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
if edge_factor > 0:
edge_mask = tf.convert_to_tensor(edge_mask, dtype=tf.bool)

# Get the cross_entropy for each entry
loss = tf.keras.losses.binary_crossentropy(seg_target, out_map, from_logits=True)[..., None]
# BCE loss
bce_loss = tf.keras.losses.binary_crossentropy(seg_target, out_map, from_logits=True)[..., None]

# Compute BCE loss with highlighted edges
if edge_factor > 0:
loss = tf.math.multiply(
1 + (edge_factor - 1) * tf.cast(edge_mask, out_map.dtype),
loss
)
# Dice loss
prob_map = tf.math.sigmoid(out_map)
inter = tf.math.reduce_sum(prob_map[seg_mask] * seg_target[seg_mask])
cardinality = tf.math.reduce_sum(prob_map[seg_mask] + seg_target[seg_mask])
dice_loss = 1 - 2 * inter / (cardinality + 1e-8)

return tf.reduce_mean(loss[seg_mask])
return tf.math.reduce_mean(bce_loss[seg_mask]) + dice_loss

def call(
self,
Expand Down