Skip to content

Commit

Permalink
fix: fix loss computation and make training work
Browse files Browse the repository at this point in the history
  • Loading branch information
aminemindee committed Sep 8, 2022
1 parent ecf8dae commit f9996cd
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 269 deletions.
18 changes: 9 additions & 9 deletions doctr/models/detection/differentiable_binarization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,35 +273,35 @@ def build_target(
target[i] = {"words": tgt}
if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()
):
if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()):
raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.")

input_dtype = list(target[0].values())[0].dtype if len(target) > 0 else np.float32

h, w = output_shape[1:-1]
seg_target: np.ndarray = np.zeros(output_shape, dtype=np.uint8)
seg_mask: np.ndarray = np.ones(output_shape, dtype=bool)
thresh_target: np.ndarray = np.zeros(output_shape, dtype=np.float32)
thresh_mask: np.ndarray = np.ones(output_shape, dtype=np.uint8)

for idx, tgt in enumerate(target):
for class_idx, _target in enumerate(tgt.values()):
# Draw each polygon on gt
# Draw each polygon on gt
if _target.shape[0] == 0:
# Empty image, fully masked
seg_mask[idx] = False
seg_mask[idx, :, :, class_idx] = False

# Absolute bounding boxes
abs_boxes = _target.copy()
if abs_boxes.ndim == 3:
abs_boxes[:, :, 0] *= output_shape[-1]
abs_boxes[:, :, 1] *= output_shape[-2]
abs_boxes[:, :, 0] *= w
abs_boxes[:, :, 1] *= h
polys = abs_boxes
boxes_size = np.linalg.norm(abs_boxes[:, 2, :] - abs_boxes[:, 0, :], axis=-1)
abs_boxes = np.concatenate((abs_boxes.min(1), abs_boxes.max(1)), -1).round().astype(np.int32)
else:
abs_boxes[:, [0, 2]] *= output_shape[-1]
abs_boxes[:, [1, 3]] *= output_shape[-2]
abs_boxes[:, [0, 2]] *= w
abs_boxes[:, [1, 3]] *= h
abs_boxes = abs_boxes.round().astype(np.int32)
polys = np.stack(
[
Expand Down Expand Up @@ -336,7 +336,7 @@ def build_target(
if shrinked.shape[0] <= 2 or not Polygon(shrinked).is_valid:
seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False
continue
cv2.fillPoly(seg_target[idx], [shrinked.astype(np.int32)], 1)
cv2.fillPoly(seg_target[idx, :, :, class_idx][..., None], [shrinked.astype(np.int32)], 1)

# Draw on both thresh map and thresh mask
poly, thresh_target[idx, :, :, class_idx], thresh_mask[idx, :, :, class_idx] = self.draw_thresh_map(
Expand Down
98 changes: 57 additions & 41 deletions doctr/models/detection/differentiable_binarization/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,21 @@ def __init__(

self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages)

def compute_loss(self, out_map: tf.Tensor, thresh_map: tf.Tensor, target: List[np.ndarray]) -> tf.Tensor:
def compute_loss(
self,
out_map: tf.Tensor,
thresh_map: tf.Tensor,
target: List[np.ndarray],
eps: float = 1e-8,
) -> tf.Tensor:
"""Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes
and a list of masks for each image. From there it computes the loss with the model output
Args:
out_map: output feature map of the model of shape (N, H, W, C)
thresh_map: threshold map of shape (N, H, W, C)
target: list of dictionaries, where each dict has a `boxes` and a `flags` entry
eps: small constant added to both the numerator and the denominator of the dice ratio
Returns:
A loss tensor
Expand All @@ -178,45 +185,54 @@ def compute_loss(self, out_map: tf.Tensor, thresh_map: tf.Tensor, target: List[n
prob_map = tf.math.sigmoid(out_map)
thresh_map = tf.math.sigmoid(thresh_map)

seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape)
seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype)
seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype)
thresh_mask = tf.convert_to_tensor(thresh_mask, dtype=tf.bool)

# Compute balanced BCE loss for proba_map
bce_scale = 5.0
bce_loss = tf.concat([
tf.keras.losses.binary_crossentropy(seg_target[:,:,:,idx: idx+1], out_map[:,:,:,idx: idx+1], from_logits=True)[..., None]
for idx in range(out_map.shape[-1])
], axis=-1)[seg_mask]
# bce_loss = tf.keras.losses.binary_crossentropy(seg_target, out_map, from_logits=True)[..., None][seg_mask]

neg_target = 1 - seg_target[seg_mask]
positive_count = tf.math.reduce_sum(seg_target[seg_mask])
negative_count = tf.math.reduce_min([tf.math.reduce_sum(neg_target), 3.0 * positive_count])
negative_loss = bce_loss * neg_target
negative_loss, _ = tf.nn.top_k(negative_loss, tf.cast(negative_count, tf.int32))
sum_losses = tf.math.reduce_sum(bce_loss * seg_target[seg_mask]) + tf.math.reduce_sum(negative_loss)
balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6)

# Compute dice loss for approxbin_map
bin_map = 1 / (1 + tf.exp(-50.0 * (prob_map[seg_mask] - thresh_map[seg_mask])))

bce_min = tf.math.reduce_min(bce_loss)
weights = (bce_loss - bce_min) / (tf.math.reduce_max(bce_loss) - bce_min) + 1.0
inter = tf.math.reduce_sum(bin_map * seg_target[seg_mask] * weights)
union = tf.math.reduce_sum(bin_map) + tf.math.reduce_sum(seg_target[seg_mask]) + 1e-8
dice_loss = 1 - 2.0 * inter / union

# Compute l1 loss for thresh_map
l1_scale = 10.0
if tf.reduce_any(thresh_mask):
l1_loss = tf.math.reduce_mean(tf.math.abs(thresh_map[thresh_mask] - thresh_target[thresh_mask]))
else:
l1_loss = tf.constant(0.0)

return l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss
seg_target_all, seg_mask_all, thresh_target_all, thresh_mask_all = self.build_target(target, out_map.shape)
seg_target_all = tf.convert_to_tensor(seg_target_all, dtype=out_map.dtype)
seg_mask_all = tf.convert_to_tensor(seg_mask_all, dtype=tf.bool)
thresh_target_all = tf.convert_to_tensor(thresh_target_all, dtype=out_map.dtype)
thresh_mask_all = tf.convert_to_tensor(thresh_mask_all, dtype=tf.bool)

final_loss = tf.convert_to_tensor(0, dtype=float)
for idx in range(out_map.shape[-1]):
seg_target = seg_target_all[..., idx]
seg_mask = seg_mask_all[..., idx]
thresh_target = thresh_target_all[..., idx]
thresh_mask = thresh_mask_all[..., idx]
_out_map = out_map[..., idx]
_thresh_map = thresh_map[..., idx]
_prob_map = prob_map[..., idx]
# Compute balanced BCE loss for proba_map
bce_scale = 5.0

bce_loss = tf.keras.losses.binary_crossentropy(
seg_target[..., None], _out_map[..., None], from_logits=True
)[seg_mask]

neg_target = 1 - seg_target[seg_mask]
positive_count = tf.math.reduce_sum(seg_target[seg_mask])
negative_count = tf.math.reduce_min([tf.math.reduce_sum(neg_target), 3.0 * positive_count])
negative_loss = bce_loss * neg_target
negative_loss, _ = tf.nn.top_k(negative_loss, tf.cast(negative_count, tf.int32))
sum_losses = tf.math.reduce_sum(bce_loss * seg_target[seg_mask]) + tf.math.reduce_sum(negative_loss)
balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6)

# Compute dice loss for approxbin_map
bin_map = 1 / (1 + tf.exp(-50.0 * (_prob_map[seg_mask] - _thresh_map[seg_mask])))

bce_min = tf.math.reduce_min(bce_loss)
weights = (bce_loss - bce_min) / (tf.math.reduce_max(bce_loss) - bce_min) + 1.0
inter = tf.math.reduce_sum(bin_map * seg_target[seg_mask] * weights)
union = tf.math.reduce_sum(bin_map) + tf.math.reduce_sum(seg_target[seg_mask]) + 1e-8
dice_loss = 1 - 2.0 * (inter + eps) / (union + eps)

# Compute l1 loss for thresh_map
l1_scale = 10.0
if tf.reduce_any(thresh_mask):
l1_loss = tf.math.reduce_mean(tf.math.abs(_thresh_map[thresh_mask] - thresh_target[thresh_mask]))
else:
l1_loss = tf.constant(0.0)

final_loss += l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss
return final_loss

def call(
self,
Expand Down Expand Up @@ -244,7 +260,7 @@ def call(

if target is None or return_preds:
# Post-process boxes (keep only text predictions)
out['preds'] = self.postprocessor(prob_map.numpy())
out["preds"] = self.postprocessor(prob_map.numpy())
# out["preds"] = [preds[0] for preds in self.postprocessor(prob_map.numpy())]

if target is not None:
Expand Down
Loading

0 comments on commit f9996cd

Please sign in to comment.