Skip to content

Commit

Permalink
feat: make loss computation vectorized and change target building to …
Browse files Browse the repository at this point in the history
…handle better class ids
  • Loading branch information
aminemindee committed Sep 19, 2022
1 parent 82b051b commit 6dd1a0f
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 108 deletions.
4 changes: 4 additions & 0 deletions doctr/datasets/datasets/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def _read_sample(self, index: int) -> Tuple[tf.Tensor, Any]:
if isinstance(target, dict):
assert "boxes" in target, "Target should contain 'boxes' key"
assert "labels" in target, "Target should contain 'labels' key"
elif isinstance(target, tuple):
assert isinstance(target[0], str) or isinstance(
target[0], np.ndarray
), "Target should be a string or a numpy array"
else:
assert isinstance(target, str) or isinstance(
target, np.ndarray
Expand Down
39 changes: 25 additions & 14 deletions doctr/models/detection/differentiable_binarization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization

from typing import List, Tuple, Union, Dict
from typing import Dict, List, Tuple, Union

import cv2
import numpy as np
Expand Down Expand Up @@ -268,9 +268,11 @@ def build_target(
output_shape: Tuple[int, int, int, int],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:

for i, tgt in enumerate(target):
new_target = []
for tgt in target:
if isinstance(tgt, np.ndarray):
target[i] = {"words": tgt}
new_target.append({"words": tgt})
target = new_target.copy()
if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()):
Expand All @@ -279,17 +281,19 @@ def build_target(
input_dtype = list(target[0].values())[0].dtype if len(target) > 0 else np.float32

h, w = output_shape[1:-1]
seg_target: np.ndarray = np.zeros(output_shape, dtype=np.uint8)
seg_mask: np.ndarray = np.ones(output_shape, dtype=bool)
thresh_target: np.ndarray = np.zeros(output_shape, dtype=np.float32)
thresh_mask: np.ndarray = np.ones(output_shape, dtype=np.uint8)
target_shape = (output_shape[0], output_shape[-1], h, w)
seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
seg_mask: np.ndarray = np.ones(target_shape, dtype=bool)
thresh_target: np.ndarray = np.zeros(target_shape, dtype=np.float32)
thresh_mask: np.ndarray = np.ones(target_shape, dtype=np.uint8)

for idx, tgt in enumerate(target):
for class_idx, _target in enumerate(tgt.values()):
# Draw each polygon on gt
if _target.shape[0] == 0:
# Empty image, full masked
seg_mask[idx, :, :, class_idx] = False
# seg_mask[idx, :, :, class_idx] = False
seg_mask[idx, class_idx] = False

# Absolute bounding boxes
abs_boxes = _target.copy()
Expand Down Expand Up @@ -317,7 +321,8 @@ def build_target(
for box, box_size, poly in zip(abs_boxes, boxes_size, polys):
# Mask boxes that are too small
if box_size < self.min_size_box:
seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
# seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
continue

# Negative shrink for gt, as described in paper
Expand All @@ -330,18 +335,24 @@ def build_target(

# Draw polygon on gt if it is valid
if len(shrinked) == 0:
seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
# seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
continue
shrinked = np.array(shrinked[0]).reshape(-1, 2)
if shrinked.shape[0] <= 2 or not Polygon(shrinked).is_valid:
seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
# seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
continue
cv2.fillPoly(seg_target[idx, :, :, class_idx][..., None], [shrinked.astype(np.int32)], 1)
cv2.fillPoly(seg_target[idx, class_idx], [shrinked.astype(np.int32)], 1)

# Draw on both thresh map and thresh mask
poly, thresh_target[idx, :, :, class_idx], thresh_mask[idx, :, :, class_idx] = self.draw_thresh_map(
poly, thresh_target[idx, :, :, class_idx], thresh_mask[idx, :, :, class_idx]
poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map(
poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx]
)
seg_target = seg_target.transpose((0, 2, 3, 1))
seg_mask = seg_mask.transpose((0, 2, 3, 1))
thresh_target = thresh_target.transpose((0, 2, 3, 1))
thresh_mask = thresh_mask.transpose((0, 2, 3, 1))

thresh_target = thresh_target.astype(input_dtype) * (self.thresh_max - self.thresh_min) + self.thresh_min

Expand Down
139 changes: 91 additions & 48 deletions doctr/models/detection/differentiable_binarization/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,59 +180,102 @@ def compute_loss(
A loss tensor
"""

# prob_map = tf.math.sigmoid(out_map)
# thresh_map = tf.math.sigmoid(thresh_map)
#
# seg_target_all, seg_mask_all, thresh_target_all, thresh_mask_all = self.build_target(target, out_map.shape)
# seg_target_all = tf.convert_to_tensor(seg_target_all, dtype=out_map.dtype)
# seg_mask_all = tf.convert_to_tensor(seg_mask_all, dtype=tf.bool)
# thresh_target_all = tf.convert_to_tensor(thresh_target_all, dtype=out_map.dtype)
# thresh_mask_all = tf.convert_to_tensor(thresh_mask_all, dtype=tf.bool)

# final_loss = tf.convert_to_tensor(0, dtype=float)
# for idx in range(out_map.shape[-1]):
# seg_target = seg_target_all[..., idx]
# seg_mask = seg_mask_all[..., idx]
# thresh_target = thresh_target_all[..., idx]
# thresh_mask = thresh_mask_all[..., idx]
# _out_map = out_map[..., idx]
# _thresh_map = thresh_map[..., idx]
# _prob_map = prob_map[..., idx]
# # Compute balanced BCE loss for proba_map
# bce_scale = 5.0
#
# bce_loss = tf.keras.losses.binary_crossentropy(
# seg_target[..., None], _out_map[..., None], from_logits=True
# )[seg_mask]
#
# neg_target = 1 - seg_target[seg_mask]
# positive_count = tf.math.reduce_sum(seg_target[seg_mask])
# negative_count = tf.math.reduce_min([tf.math.reduce_sum(neg_target), 3.0 * positive_count])
# negative_loss = bce_loss * neg_target
# negative_loss, _ = tf.nn.top_k(negative_loss, tf.cast(negative_count, tf.int32))
# sum_losses = tf.math.reduce_sum(bce_loss * seg_target[seg_mask]) + tf.math.reduce_sum(negative_loss)
# balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6)
#
# # Compute dice loss for approxbin_map
# bin_map = 1 / (1 + tf.exp(-50.0 * (_prob_map[seg_mask] - _thresh_map[seg_mask])))
#
# bce_min = tf.math.reduce_min(bce_loss)
# weights = (bce_loss - bce_min) / (tf.math.reduce_max(bce_loss) - bce_min) + 1.0
# inter = tf.math.reduce_sum(bin_map * seg_target[seg_mask] * weights)
# union = tf.math.reduce_sum(bin_map) + tf.math.reduce_sum(seg_target[seg_mask]) + 1e-8
# dice_loss = 1 - 2.0 * (inter + eps) / (union + eps)
#
# # Compute l1 loss for thresh_map
# l1_scale = 10.0
# if tf.reduce_any(thresh_mask):
# l1_loss = tf.math.reduce_mean(tf.math.abs(_thresh_map[thresh_mask] - thresh_target[thresh_mask]))
# else:
# l1_loss = tf.constant(0.0)
#
# final_loss += l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss
# return final_loss

# prob_map = tf.math.sigmoid(tf.squeeze(out_map, axis=[-1]))
# thresh_map = tf.math.sigmoid(tf.squeeze(thresh_map, axis=[-1]))
prob_map = tf.math.sigmoid(out_map)
thresh_map = tf.math.sigmoid(thresh_map)

seg_target_all, seg_mask_all, thresh_target_all, thresh_mask_all = self.build_target(target, out_map.shape)
seg_target_all = tf.convert_to_tensor(seg_target_all, dtype=out_map.dtype)
seg_mask_all = tf.convert_to_tensor(seg_mask_all, dtype=tf.bool)
thresh_target_all = tf.convert_to_tensor(thresh_target_all, dtype=out_map.dtype)
thresh_mask_all = tf.convert_to_tensor(thresh_mask_all, dtype=tf.bool)

final_loss = tf.convert_to_tensor(0, dtype=float)
for idx in range(out_map.shape[-1]):
seg_target = seg_target_all[..., idx]
seg_mask = seg_mask_all[..., idx]
thresh_target = thresh_target_all[..., idx]
thresh_mask = thresh_mask_all[..., idx]
_out_map = out_map[..., idx]
_thresh_map = thresh_map[..., idx]
_prob_map = prob_map[..., idx]
# Compute balanced BCE loss for proba_map
bce_scale = 5.0

bce_loss = tf.keras.losses.binary_crossentropy(
seg_target[..., None], _out_map[..., None], from_logits=True
)[seg_mask]

neg_target = 1 - seg_target[seg_mask]
positive_count = tf.math.reduce_sum(seg_target[seg_mask])
negative_count = tf.math.reduce_min([tf.math.reduce_sum(neg_target), 3.0 * positive_count])
negative_loss = bce_loss * neg_target
negative_loss, _ = tf.nn.top_k(negative_loss, tf.cast(negative_count, tf.int32))
sum_losses = tf.math.reduce_sum(bce_loss * seg_target[seg_mask]) + tf.math.reduce_sum(negative_loss)
balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6)

# Compute dice loss for approxbin_map
bin_map = 1 / (1 + tf.exp(-50.0 * (_prob_map[seg_mask] - _thresh_map[seg_mask])))

bce_min = tf.math.reduce_min(bce_loss)
weights = (bce_loss - bce_min) / (tf.math.reduce_max(bce_loss) - bce_min) + 1.0
inter = tf.math.reduce_sum(bin_map * seg_target[seg_mask] * weights)
union = tf.math.reduce_sum(bin_map) + tf.math.reduce_sum(seg_target[seg_mask]) + 1e-8
dice_loss = 1 - 2.0 * (inter + eps) / (union + eps)

# Compute l1 loss for thresh_map
l1_scale = 10.0
if tf.reduce_any(thresh_mask):
l1_loss = tf.math.reduce_mean(tf.math.abs(_thresh_map[thresh_mask] - thresh_target[thresh_mask]))
else:
l1_loss = tf.constant(0.0)

final_loss += l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss
return final_loss
seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape)
seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype)
seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype)
thresh_mask = tf.convert_to_tensor(thresh_mask, dtype=tf.bool)

# Compute balanced BCE loss for proba_map
bce_scale = 5.0
bce_loss = tf.keras.losses.binary_crossentropy(
seg_target[..., None],
out_map[..., None],
from_logits=True,
)[seg_mask]

neg_target = 1 - seg_target[seg_mask]
positive_count = tf.math.reduce_sum(seg_target[seg_mask])
negative_count = tf.math.reduce_min([tf.math.reduce_sum(neg_target), 3.0 * positive_count])
negative_loss = bce_loss * neg_target
negative_loss, _ = tf.nn.top_k(negative_loss, tf.cast(negative_count, tf.int32))
sum_losses = tf.math.reduce_sum(bce_loss * seg_target[seg_mask]) + tf.math.reduce_sum(negative_loss)
balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6)

# Compute dice loss for approxbin_map
bin_map = 1 / (1 + tf.exp(-50.0 * (prob_map[seg_mask] - thresh_map[seg_mask])))

bce_min = tf.math.reduce_min(bce_loss)
weights = (bce_loss - bce_min) / (tf.math.reduce_max(bce_loss) - bce_min) + 1.0
inter = tf.math.reduce_sum(bin_map * seg_target[seg_mask] * weights)
union = tf.math.reduce_sum(bin_map) + tf.math.reduce_sum(seg_target[seg_mask]) + 1e-8
dice_loss = 1 - 2.0 * inter / union

# Compute l1 loss for thresh_map
l1_scale = 10.0
if tf.reduce_any(thresh_mask):
l1_loss = tf.math.reduce_mean(tf.math.abs(thresh_map[thresh_mask] - thresh_target[thresh_mask]))
else:
l1_loss = tf.constant(0.0)

return l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss

def call(
self,
Expand Down
28 changes: 15 additions & 13 deletions doctr/models/detection/linknet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization

from typing import List, Tuple, Union, Dict
from typing import Dict, List, Tuple, Union

import cv2
import numpy as np
Expand Down Expand Up @@ -160,16 +160,18 @@ def build_target(
output_shape: Tuple[int, int, int],
) -> Tuple[np.ndarray, np.ndarray]:

for i, tgt in enumerate(target):
new_target = []
for tgt in target:
if isinstance(tgt, np.ndarray):
target[i] = {"words": tgt}
new_target.append({"words": tgt})
target = new_target.copy()
if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()):
raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.")

h, w, num_classes = output_shape
target_shape = (len(target), h, w, num_classes)
target_shape = (len(target), num_classes, h, w)

seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
seg_mask: np.ndarray = np.ones(target_shape, dtype=bool)
Expand All @@ -179,7 +181,7 @@ def build_target(
# Draw each polygon on gt
if _target.shape[0] == 0:
# Empty image, full masked
seg_mask[idx, :, :, class_idx] = False
seg_mask[idx, class_idx] = False

# Absolute bounding boxes
abs_boxes = _target.copy()
Expand Down Expand Up @@ -208,7 +210,7 @@ def build_target(
for poly, box, box_size in zip(polys, abs_boxes, boxes_size):
# Mask boxes that are too small
if box_size < self.min_size_box:
seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
continue

# Negative shrink for gt, as described in paper
Expand All @@ -221,17 +223,17 @@ def build_target(

# Draw polygon on gt if it is valid
if len(shrunken) == 0:
seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
continue
shrunken = np.array(shrunken[0]).reshape(-1, 2)
if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
continue
cv2.fillPoly(seg_target[idx, :, :, class_idx][..., None], [shrunken.astype(np.int32)], 1)
cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1)

# Don't forget to switch back to channel first if PyTorch is used
if not is_tf_available():
seg_target = seg_target.transpose(0, 3, 1, 2)
seg_mask = seg_mask.transpose(0, 3, 1, 2)
# Don't forget to switch back to channel last if Tensorflow is used
if is_tf_available():
seg_target = seg_target.transpose((0, 2, 3, 1))
seg_mask = seg_mask.transpose((0, 2, 3, 1))

return seg_target, seg_mask
Loading

0 comments on commit 6dd1a0f

Please sign in to comment.