Commit e33889b: add multi-class integration in prediction pipeline
aminemindee committed Sep 19, 2022
1 parent 6dd1a0f commit e33889b
Showing 10 changed files with 221 additions and 151 deletions.
8 changes: 3 additions & 5 deletions doctr/datasets/datasets/base.py
@@ -8,8 +8,6 @@
 from pathlib import Path
 from typing import Any, Callable, List, Optional, Tuple, Union

-import numpy as np
-
 from doctr.io.image import get_img_shape
 from doctr.utils.data import download_from_url

@@ -57,11 +55,11 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
             img = self.img_transforms(img)

         if self.sample_transforms is not None:
-            if isinstance(target, np.ndarray):
-                img, target = self.sample_transforms(img, target)
-            elif isinstance(target, dict):
+            if isinstance(target, dict):
                 for k, v in target.items():
                     img, target[k] = self.sample_transforms(img, v)
+            else:
+                img, target = self.sample_transforms(img, target)

         return img, target

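For context, the reworked branch makes dict targets (one entry per detection class) the primary case and falls back to the old single-target path otherwise. A minimal sketch of the new dispatch, using a hypothetical identity transform in place of a real `sample_transforms` pipeline:

```python
import numpy as np

def apply_sample_transforms(img, target, sample_transforms):
    # Dict targets are transformed per class key; any other target type
    # (e.g. a single np.ndarray of boxes) takes the single-target path.
    if isinstance(target, dict):
        for k, v in target.items():
            img, target[k] = sample_transforms(img, v)
    else:
        img, target = sample_transforms(img, target)
    return img, target

# Hypothetical usage with two detection classes on one page:
identity = lambda image, boxes: (image, boxes)
img = np.zeros((64, 64, 3), dtype=np.uint8)
target = {"words": np.empty((0, 4)), "artefacts": np.empty((0, 4))}
img, target = apply_sample_transforms(img, target, identity)
```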
9 changes: 7 additions & 2 deletions doctr/io/elements.py
@@ -42,7 +42,12 @@ def export(self) -> Dict[str, Any]:

         export_dict = {k: getattr(self, k) for k in self._exported_keys}
         for children_name in self._children_names:
-            export_dict[children_name] = [c.export() for c in getattr(self, children_name)]
+            if children_name in ["blocks"]:
+                export_dict[children_name] = {
+                    k: [item.export() for item in c] for k, c in getattr(self, children_name).items()
+                }
+            else:
+                export_dict[children_name] = [c.export() for c in getattr(self, children_name)]

         return export_dict

@@ -228,7 +233,7 @@ class Page(Element):

     def __init__(
         self,
-        blocks: List[Block],
+        blocks: Dict[str, List[Block]],
         page_idx: int,
         dimensions: Tuple[int, int],
         orientation: Optional[Dict[str, Any]] = None,
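With `blocks` now a `Dict[str, List[Block]]`, `Page.export()` nests exported blocks under their class name instead of a flat list. A hedged sketch of the resulting structure (the non-`blocks` keys follow `Page`'s exported fields; the class names and values here are invented):

```python
# Hypothetical export of a Page holding two detection classes.
page_export = {
    "page_idx": 0,
    "dimensions": (1024, 768),
    "orientation": {"value": None, "confidence": None},
    "language": {"value": None, "confidence": None},
    "blocks": {
        "words": [],      # one exported Block dict per detected text block
        "artefacts": [],  # a second class appears as its own key
    },
}
```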
77 changes: 72 additions & 5 deletions doctr/models/builder.py
@@ -9,6 +9,7 @@
 import numpy as np
 from scipy.cluster.hierarchy import fclusterdata

+from doctr.file_utils import is_tf_available
 from doctr.io.elements import Block, Document, Line, Page, Word
 from doctr.utils.geometry import estimate_page_angle, resolve_enclosing_bbox, resolve_enclosing_rbbox, rotate_boxes
 from doctr.utils.repr import NestedObject
@@ -39,6 +40,10 @@ def __init__(
         self.resolve_blocks = resolve_blocks
         self.paragraph_break = paragraph_break
         self.export_as_straight_boxes = export_as_straight_boxes
+        if is_tf_available():
+            self.__call__ = self.tf_call
+        else:
+            self.__call__ = self.torch_call  # type: ignore[assignment]

     @staticmethod
     def _sort_boxes(boxes: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
@@ -279,7 +284,7 @@ def extra_repr(self) -> str:
f"export_as_straight_boxes={self.export_as_straight_boxes}"
)

def __call__(
def torch_call(
self,
boxes: List[np.ndarray],
text_preds: List[List[Tuple[str, float]]],
@@ -317,10 +322,72 @@ def __call__(

         _pages = [
             Page(
-                self._build_blocks(
-                    page_boxes,
-                    word_preds,
-                ),
+                {
+                    "words": self._build_blocks(
+                        page_boxes,
+                        word_preds,
+                    )
+                },
                 _idx,
                 shape,
                 orientation,
                 language,
             )
             for _idx, shape, page_boxes, word_preds, orientation, language in zip(
                 range(len(boxes)), page_shapes, boxes, text_preds, _orientations, _languages
             )
         ]

         return Document(_pages)

+    def tf_call(
+        self,
+        boxes: List[Dict[str, np.ndarray]],
+        text_preds: List[Dict[str, List[Tuple[str, float]]]],
+        page_shapes: List[Tuple[int, int]],
+        orientations: Optional[List[Dict[str, Any]]] = None,
+        languages: Optional[List[Dict[str, Any]]] = None,
+    ) -> Document:
+        """Re-arrange detected words into structured blocks
+
+        Args:
+            boxes: list of N elements, where each element represents the localization predictions, of shape (*, 5)
+                or (*, 6) for all words for a given page
+            text_preds: list of N elements, where each element is the list of all word predictions (text + confidence)
+            page_shapes: shape of each page, of size N
+
+        Returns:
+            document object
+        """
+        if len(boxes) != len(text_preds) or len(boxes) != len(page_shapes):
+            raise ValueError("All arguments are expected to be lists of the same size")
+
+        _orientations = (
+            orientations if isinstance(orientations, list) else [None] * len(boxes)  # type: ignore[list-item]
+        )
+        _languages = languages if isinstance(languages, list) else [None] * len(boxes)  # type: ignore[list-item]
+        if self.export_as_straight_boxes and len(boxes) > 0:
+            # If boxes are already straight OK, else fit a bounding rect
+            if list(boxes[0].values())[0].ndim == 3:
+                straight_boxes: List[Dict[str, np.ndarray]] = []
+                # Iterate over pages
+                for p_boxes in boxes:
+                    # Iterate over boxes of the pages
+                    straight_boxes_dict = {}
+                    for k, box in p_boxes.items():
+                        straight_boxes_dict[k] = np.concatenate((box.min(1), box.max(1)), 1)
+                    straight_boxes.append(straight_boxes_dict)
+                boxes = straight_boxes
+
+        _pages = [
+            Page(
+                {
+                    k: self._build_blocks(
+                        page_boxes[k],
+                        word_preds[k],
+                    )
+                    for k in page_boxes.keys()
+                },
+                _idx,
+                shape,
+                orientation,
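The constructor now binds `self.__call__` to `tf_call` or `torch_call` depending on `is_tf_available()`. To make the two contracts concrete, here is a hedged sketch of how each entry point expects its inputs (box values, class names, and page shapes are invented; `DocumentBuilder` is the class patched above):

```python
import numpy as np
from doctr.models.builder import DocumentBuilder

builder = DocumentBuilder()

# torch_call: one plain (*, 5) array of word boxes per page, single implicit class.
boxes_pt = [np.array([[0.1, 0.1, 0.4, 0.2, 0.99]])]
text_preds_pt = [[("hello", 0.98)]]
doc = builder.torch_call(boxes_pt, text_preds_pt, [(1024, 768)])

# tf_call: per-page dicts keyed by detection class.
boxes_tf = [{"words": np.array([[0.1, 0.1, 0.4, 0.2, 0.99]])}]
text_preds_tf = [{"words": [("hello", 0.98)]}]
doc = builder.tf_call(boxes_tf, text_preds_tf, [(1024, 768)])
```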
64 changes: 8 additions & 56 deletions doctr/models/detection/differentiable_binarization/tensorflow.py
@@ -125,9 +125,13 @@ def __init__(
         assume_straight_pages: bool = True,
         exportable: bool = False,
         cfg: Optional[Dict[str, Any]] = None,
+        class_names: List[str] = ["words"],
     ) -> None:

         super().__init__()
+        self.class_names = class_names
+        if cfg and cfg.get("class_names"):
+            self.class_names = cfg["class_names"]
         self.cfg = cfg

         self.feat_extractor = feature_extractor
@@ -180,60 +184,6 @@ def compute_loss(
             A loss tensor
         """

-        # prob_map = tf.math.sigmoid(out_map)
-        # thresh_map = tf.math.sigmoid(thresh_map)
-        #
-        # seg_target_all, seg_mask_all, thresh_target_all, thresh_mask_all = self.build_target(target, out_map.shape)
-        # seg_target_all = tf.convert_to_tensor(seg_target_all, dtype=out_map.dtype)
-        # seg_mask_all = tf.convert_to_tensor(seg_mask_all, dtype=tf.bool)
-        # thresh_target_all = tf.convert_to_tensor(thresh_target_all, dtype=out_map.dtype)
-        # thresh_mask_all = tf.convert_to_tensor(thresh_mask_all, dtype=tf.bool)
-
-        # final_loss = tf.convert_to_tensor(0, dtype=float)
-        # for idx in range(out_map.shape[-1]):
-        #     seg_target = seg_target_all[..., idx]
-        #     seg_mask = seg_mask_all[..., idx]
-        #     thresh_target = thresh_target_all[..., idx]
-        #     thresh_mask = thresh_mask_all[..., idx]
-        #     _out_map = out_map[..., idx]
-        #     _thresh_map = thresh_map[..., idx]
-        #     _prob_map = prob_map[..., idx]
-        #     # Compute balanced BCE loss for proba_map
-        #     bce_scale = 5.0
-        #
-        #     bce_loss = tf.keras.losses.binary_crossentropy(
-        #         seg_target[..., None], _out_map[..., None], from_logits=True
-        #     )[seg_mask]
-        #
-        #     neg_target = 1 - seg_target[seg_mask]
-        #     positive_count = tf.math.reduce_sum(seg_target[seg_mask])
-        #     negative_count = tf.math.reduce_min([tf.math.reduce_sum(neg_target), 3.0 * positive_count])
-        #     negative_loss = bce_loss * neg_target
-        #     negative_loss, _ = tf.nn.top_k(negative_loss, tf.cast(negative_count, tf.int32))
-        #     sum_losses = tf.math.reduce_sum(bce_loss * seg_target[seg_mask]) + tf.math.reduce_sum(negative_loss)
-        #     balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6)
-        #
-        #     # Compute dice loss for approxbin_map
-        #     bin_map = 1 / (1 + tf.exp(-50.0 * (_prob_map[seg_mask] - _thresh_map[seg_mask])))
-        #
-        #     bce_min = tf.math.reduce_min(bce_loss)
-        #     weights = (bce_loss - bce_min) / (tf.math.reduce_max(bce_loss) - bce_min) + 1.0
-        #     inter = tf.math.reduce_sum(bin_map * seg_target[seg_mask] * weights)
-        #     union = tf.math.reduce_sum(bin_map) + tf.math.reduce_sum(seg_target[seg_mask]) + 1e-8
-        #     dice_loss = 1 - 2.0 * (inter + eps) / (union + eps)
-        #
-        #     # Compute l1 loss for thresh_map
-        #     l1_scale = 10.0
-        #     if tf.reduce_any(thresh_mask):
-        #         l1_loss = tf.math.reduce_mean(tf.math.abs(_thresh_map[thresh_mask] - thresh_target[thresh_mask]))
-        #     else:
-        #         l1_loss = tf.constant(0.0)
-        #
-        #     final_loss += l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss
-        # return final_loss
-
-        # prob_map = tf.math.sigmoid(tf.squeeze(out_map, axis=[-1]))
-        # thresh_map = tf.math.sigmoid(tf.squeeze(thresh_map, axis=[-1]))
         prob_map = tf.math.sigmoid(out_map)
         thresh_map = tf.math.sigmoid(thresh_map)
@@ -303,8 +253,10 @@ def call(

         if target is None or return_preds:
             # Post-process boxes (keep only text predictions)
-            out["preds"] = self.postprocessor(prob_map.numpy())
-            # out["preds"] = [preds[0] for preds in self.postprocessor(prob_map.numpy())]
+            out["preds"] = [
+                {class_name: p for class_name, p in zip(self.class_names, preds)}
+                for preds in self.postprocessor(prob_map.numpy())
+            ]

         if target is not None:
             thresh_map = self.threshold_head(feat_concat, **kwargs)
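The postprocessor returns, for each page, one array of boxes per output channel; zipping that with `class_names` is what turns the raw output into a per-class dict. A small hedged illustration with dummy arrays standing in for the postprocessor output:

```python
import numpy as np

class_names = ["words", "artefacts"]
# Stand-in for self.postprocessor(prob_map.numpy()): a batch of one page,
# with one (num_boxes, 5) array per class channel.
batched = [[np.zeros((3, 5)), np.zeros((1, 5))]]

out_preds = [
    {class_name: p for class_name, p in zip(class_names, preds)}
    for preds in batched
]
assert out_preds[0]["words"].shape == (3, 5)
assert out_preds[0]["artefacts"].shape == (1, 5)
```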
11 changes: 9 additions & 2 deletions doctr/models/detection/linknet/tensorflow.py
@@ -121,9 +121,14 @@ def __init__(
         assume_straight_pages: bool = True,
         exportable: bool = False,
         cfg: Optional[Dict[str, Any]] = None,
+        class_names: List[str] = ["words"],
     ) -> None:
         super().__init__(cfg=cfg)

+        self.class_names = class_names
+        if self.cfg and self.cfg.get("class_names"):
+            self.class_names = self.cfg["class_names"]
+
         self.exportable = exportable
         self.assume_straight_pages = assume_straight_pages

@@ -231,8 +236,10 @@ def call(

         if target is None or return_preds:
             # Post-process boxes
-            out["preds"] = self.postprocessor(prob_map.numpy())
-            # out["preds"] = [preds[0] for preds in self.postprocessor(prob_map.numpy())]
+            out["preds"] = [
+                {class_name: p for class_name, p in zip(self.class_names, preds)}
+                for preds in self.postprocessor(prob_map.numpy())
+            ]

         if target is not None:
             loss = self.compute_loss(logits, target)
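Both detection heads resolve their classes the same way: the `class_names` constructor argument supplies the default, but a `class_names` entry in `cfg` takes precedence. A minimal sketch of that precedence rule, with invented class names:

```python
def resolve_class_names(cfg, class_names=["words"]):
    # A class_names entry in cfg overrides the constructor default.
    if cfg and cfg.get("class_names"):
        return cfg["class_names"]
    return class_names

assert resolve_class_names(None) == ["words"]
assert resolve_class_names({"class_names": ["words", "artefacts"]}) == ["words", "artefacts"]
```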
4 changes: 2 additions & 2 deletions doctr/models/detection/predictor/tensorflow.py
@@ -3,7 +3,7 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

-from typing import Any, List, Union
+from typing import Any, Dict, List, Union

 import numpy as np
 import tensorflow as tf
@@ -38,7 +38,7 @@ def __call__(
         self,
         pages: List[Union[np.ndarray, tf.Tensor]],
         **kwargs: Any,
-    ) -> List[np.ndarray]:
+    ) -> List[Dict[str, np.ndarray]]:

         # Dimension check
         if any(page.ndim != 3 for page in pages):
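Downstream code that used to receive one plain array per page now selects a class key first. A hedged usage sketch against the public `detection_predictor` factory (the `"words"` key assumes the default single-class setup):

```python
import numpy as np
from doctr.models import detection_predictor

predictor = detection_predictor("db_resnet50", pretrained=True)
page = np.zeros((1024, 768, 3), dtype=np.uint8)

out = predictor([page])        # now List[Dict[str, np.ndarray]]
word_boxes = out[0]["words"]   # before this commit, out[0] was the array itself
```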
2 changes: 1 addition & 1 deletion doctr/models/predictor/pytorch.py
@@ -134,7 +134,7 @@ def forward(
             for page_boxes, page, angle, mask in zip(boxes, pages, origin_page_orientations, origin_page_shapes)
         ]

-        out = self.doc_builder(
+        out = self.doc_builder.torch_call(
             boxes,
             text_preds,
             [page.shape[:2] if channels_last else page.shape[-2:] for page in pages],  # type: ignore[misc]