Skip to content

Commit

Permalink
feat: add multiclass to pytorch and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aminemindee committed Sep 22, 2022
1 parent 771228a commit 2e2a281
Show file tree
Hide file tree
Showing 18 changed files with 319 additions and 259 deletions.
4 changes: 4 additions & 0 deletions doctr/datasets/datasets/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def _read_sample(self, index: int) -> Tuple[torch.Tensor, Any]:
if isinstance(target, dict):
assert "boxes" in target, "Target should contain 'boxes' key"
assert "labels" in target, "Target should contain 'labels' key"
elif isinstance(target, tuple):
assert isinstance(target[0], str) or isinstance(
target[0], np.ndarray
), "Target should be a string or a numpy array"
else:
assert isinstance(target, str) or isinstance(
target, np.ndarray
Expand Down
103 changes: 52 additions & 51 deletions doctr/io/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ class Page(Element):

_exported_keys: List[str] = ["page_idx", "dimensions", "orientation", "language"]
_children_names: List[str] = ["blocks"]
blocks: List[Block] = []
blocks: Dict[str, List[Block]] = {}

def __init__(
self,
Expand All @@ -247,7 +247,7 @@ def __init__(

def render(self, block_break: str = "\n\n") -> str:
"""Renders the full text of the element"""
return block_break.join(b.render() for b in self.blocks)
return block_break.join(b.render() for blocks in self.blocks.values() for b in blocks)

def extra_repr(self) -> str:
return f"dimensions={self.dimensions}"
Expand Down Expand Up @@ -316,65 +316,66 @@ def export_as_xml(self, file_title: str = "docTR - XML export (hOCR)") -> Tuple[
},
)
# iterate over the blocks / lines / words and create the XML elements in body line by line with the attributes
for block in self.blocks:
if len(block.geometry) != 2:
raise TypeError("XML export is only available for straight bounding boxes for now.")
(xmin, ymin), (xmax, ymax) = block.geometry
block_div = SubElement(
body,
"div",
attrib={
"class": "ocr_carea",
"id": f"block_{block_count}",
"title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
{int(round(xmax * width))} {int(round(ymax * height))}",
},
)
paragraph = SubElement(
block_div,
"p",
attrib={
"class": "ocr_par",
"id": f"par_{block_count}",
"title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
{int(round(xmax * width))} {int(round(ymax * height))}",
},
)
block_count += 1
for line in block.lines:
(xmin, ymin), (xmax, ymax) = line.geometry
# NOTE: baseline, x_size, x_descenders, x_ascenders are currently initialized to 0
line_span = SubElement(
paragraph,
"span",
for class_name, blocks in self.blocks.items():
for block in blocks:
if len(block.geometry) != 2:
raise TypeError("XML export is only available for straight bounding boxes for now.")
(xmin, ymin), (xmax, ymax) = block.geometry
block_div = SubElement(
body,
"div",
attrib={
"class": "ocr_line",
"id": f"line_{line_count}",
"class": "ocr_carea",
"id": f"{class_name}_block_{block_count}",
"title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
{int(round(xmax * width))} {int(round(ymax * height))}; \
baseline 0 0; x_size 0; x_descenders 0; x_ascenders 0",
{int(round(xmax * width))} {int(round(ymax * height))}",
},
)
line_count += 1
for word in line.words:
(xmin, ymin), (xmax, ymax) = word.geometry
conf = word.confidence
word_div = SubElement(
line_span,
paragraph = SubElement(
block_div,
"p",
attrib={
"class": "ocr_par",
"id": f"par_{block_count}",
"title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
{int(round(xmax * width))} {int(round(ymax * height))}",
},
)
block_count += 1
for line in block.lines:
(xmin, ymin), (xmax, ymax) = line.geometry
# NOTE: baseline, x_size, x_descenders, x_ascenders are currently initialized to 0
line_span = SubElement(
paragraph,
"span",
attrib={
"class": "ocrx_word",
"id": f"word_{word_count}",
"class": "ocr_line",
"id": f"line_{line_count}",
"title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
{int(round(xmax * width))} {int(round(ymax * height))}; \
x_wconf {int(round(conf * 100))}",
baseline 0 0; x_size 0; x_descenders 0; x_ascenders 0",
},
)
# set the text
word_div.text = word.value
word_count += 1

return (ET.tostring(page_hocr, encoding="utf-8", method="xml"), ET.ElementTree(page_hocr))
line_count += 1
for word in line.words:
(xmin, ymin), (xmax, ymax) = word.geometry
conf = word.confidence
word_div = SubElement(
line_span,
"span",
attrib={
"class": "ocrx_word",
"id": f"word_{word_count}",
"title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
{int(round(xmax * width))} {int(round(ymax * height))}; \
x_wconf {int(round(conf * 100))}",
},
)
# set the text
word_div.text = word.value
word_count += 1

return ET.tostring(page_hocr, encoding="utf-8", method="xml"), ET.ElementTree(page_hocr)

@classmethod
def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
Expand Down
35 changes: 34 additions & 1 deletion doctr/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from math import floor
from statistics import median_low
from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import cv2
import numpy as np
Expand Down Expand Up @@ -161,3 +161,36 @@ def get_language(text: str) -> Tuple[str, float]:
if len(text) <= 1 or (len(text) <= 5 and lang.prob <= 0.2):
return "unknown", 0.0
return lang.lang, lang.prob


def invert_list_dict_to_dict_list(list_dict: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Convert a List of Dict of elements to a Dict of list of elements

    The keys of the result are taken from the first dictionary of the list; every
    dictionary then appends its values to the matching lists, in input order.

    Args:
        list_dict (List): the list of dictionaries

    Returns:
        dict_list (Dict): the dictionary of lists (empty if `list_dict` is empty)

    Raises:
        KeyError: if a dictionary holds a key that is absent from the first dictionary
    """
    # Nothing to invert: avoid indexing into an empty list
    if not list_dict:
        return {}

    dict_list: Dict[str, List[Any]] = {k: [] for k in list_dict[0]}
    for dic in list_dict:
        for k, v in dic.items():
            dict_list[k].append(v)
    return dict_list


def invert_dict_list_to_list_dict(dict_list: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Convert a Dict of list of elements to a List of Dict of elements

    The number of output dictionaries is the length of the first list of the input.
    Every output dictionary starts with all keys set to None, so a list shorter than
    the first one leaves None in the trailing dictionaries.

    Args:
        dict_list (Dict): the dictionary of lists

    Returns:
        list_dict (List): the list of dictionaries (empty if `dict_list` is empty)

    Raises:
        IndexError: if a list is longer than the first list of the dictionary
    """
    # Nothing to invert: avoid fetching the first value of an empty dict
    if not dict_list:
        return []

    # The first list sets the number of output dictionaries
    n = len(next(iter(dict_list.values())))
    list_dict: List[Dict[str, Any]] = [dict.fromkeys(dict_list) for _ in range(n)]
    for k, values in dict_list.items():
        for i, v in enumerate(values):
            list_dict[i][k] = v
    return list_dict
63 changes: 1 addition & 62 deletions doctr/models/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import numpy as np
from scipy.cluster.hierarchy import fclusterdata

from doctr.file_utils import is_tf_available
from doctr.io.elements import Block, Document, Line, Page, Word
from doctr.utils.geometry import estimate_page_angle, resolve_enclosing_bbox, resolve_enclosing_rbbox, rotate_boxes
from doctr.utils.repr import NestedObject
Expand Down Expand Up @@ -40,10 +39,6 @@ def __init__(
self.resolve_blocks = resolve_blocks
self.paragraph_break = paragraph_break
self.export_as_straight_boxes = export_as_straight_boxes
if is_tf_available():
self.__call__ = self.tf_call
else:
self.__call__ = self.torch_call # type: ignore[assignment]

@staticmethod
def _sort_boxes(boxes: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -284,63 +279,7 @@ def extra_repr(self) -> str:
f"export_as_straight_boxes={self.export_as_straight_boxes}"
)

def torch_call(
self,
boxes: List[np.ndarray],
text_preds: List[List[Tuple[str, float]]],
page_shapes: List[Tuple[int, int]],
orientations: Optional[List[Dict[str, Any]]] = None,
languages: Optional[List[Dict[str, Any]]] = None,
) -> Document:
"""Re-arrange detected words into structured blocks
Args:
boxes: list of N elements, where each element represents the localization predictions, of shape (*, 5)
or (*, 6) for all words for a given page
text_preds: list of N elements, where each element is the list of all word prediction (text + confidence)
page_shapes: shape of each page, of size N
Returns:
document object
"""
if len(boxes) != len(text_preds) or len(boxes) != len(page_shapes):
raise ValueError("All arguments are expected to be lists of the same size")

_orientations = (
orientations if isinstance(orientations, list) else [None] * len(boxes) # type: ignore[list-item]
)
_languages = languages if isinstance(languages, list) else [None] * len(boxes) # type: ignore[list-item]
if self.export_as_straight_boxes and len(boxes) > 0:
# If boxes are already straight OK, else fit a bounding rect
if boxes[0].ndim == 3:
straight_boxes: List[np.ndarray] = []
# Iterate over pages
for p_boxes in boxes:
# Iterate over boxes of the pages
straight_boxes.append(np.concatenate((p_boxes.min(1), p_boxes.max(1)), 1))
boxes = straight_boxes

_pages = [
Page(
{
"words": self._build_blocks(
page_boxes,
word_preds,
)
},
_idx,
shape,
orientation,
language,
)
for _idx, shape, page_boxes, word_preds, orientation, language in zip(
range(len(boxes)), page_shapes, boxes, text_preds, _orientations, _languages
)
]

return Document(_pages)

def tf_call(
def __call__(
self,
boxes: List[Dict[str, np.ndarray]],
text_preds: List[Dict[str, List[Tuple[str, float]]]],
Expand Down
21 changes: 15 additions & 6 deletions doctr/models/detection/differentiable_binarization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import pyclipper
from shapely.geometry import Polygon

from doctr.file_utils import is_tf_available

from ..core import DetectionPostProcessor

__all__ = ["DBPostProcessor"]
Expand Down Expand Up @@ -272,6 +274,8 @@ def build_target(
for tgt in target:
if isinstance(tgt, np.ndarray):
new_target.append({"words": tgt})
else:
new_target.append(tgt)
target = new_target.copy()
if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
Expand All @@ -280,8 +284,12 @@ def build_target(

input_dtype = list(target[0].values())[0].dtype if len(target) > 0 else np.float32

h, w = output_shape[1:-1]
target_shape = (output_shape[0], output_shape[-1], h, w)
if is_tf_available():
h, w = output_shape[1:-1]
target_shape = (output_shape[0], output_shape[-1], h, w)
else:
h, w = output_shape[-2:]
target_shape = output_shape
seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
seg_mask: np.ndarray = np.ones(target_shape, dtype=bool)
thresh_target: np.ndarray = np.zeros(target_shape, dtype=np.float32)
Expand Down Expand Up @@ -349,10 +357,11 @@ def build_target(
poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map(
poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx]
)
seg_target = seg_target.transpose((0, 2, 3, 1))
seg_mask = seg_mask.transpose((0, 2, 3, 1))
thresh_target = thresh_target.transpose((0, 2, 3, 1))
thresh_mask = thresh_mask.transpose((0, 2, 3, 1))
if is_tf_available():
seg_target = seg_target.transpose((0, 2, 3, 1))
seg_mask = seg_mask.transpose((0, 2, 3, 1))
thresh_target = thresh_target.transpose((0, 2, 3, 1))
thresh_mask = thresh_mask.transpose((0, 2, 3, 1))

thresh_target = thresh_target.astype(input_dtype) * (self.thresh_max - self.thresh_min) + self.thresh_min

Expand Down
20 changes: 16 additions & 4 deletions doctr/models/detection/differentiable_binarization/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,13 @@ def __init__(
assume_straight_pages: bool = True,
exportable: bool = False,
cfg: Optional[Dict[str, Any]] = None,
class_names: List[str] = ["words"],
) -> None:

super().__init__()
self.class_names = class_names
if cfg and cfg.get("class_names"):
self.class_names = cfg["class_names"]
self.cfg = cfg

conv_layer = DeformConv2d if deform_conv else nn.Conv2d
Expand Down Expand Up @@ -207,8 +211,12 @@ def forward(

if target is None or return_preds:
# Post-process boxes (keep only text predictions)
# out["preds"] = [
# preds[0] for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
# ]
out["preds"] = [
preds[0] for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
{class_name: p for class_name, p in zip(self.class_names, preds)}
for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
]

if target is not None:
Expand All @@ -231,8 +239,8 @@ def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target:
A loss tensor
"""

prob_map = torch.sigmoid(out_map.squeeze(1))
thresh_map = torch.sigmoid(thresh_map.squeeze(1))
prob_map = torch.sigmoid(out_map)
thresh_map = torch.sigmoid(thresh_map)

targets = self.build_target(target, prob_map.shape) # type: ignore[arg-type]

Expand All @@ -247,7 +255,11 @@ def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target:
dice_loss = torch.zeros(1, device=out_map.device)
l1_loss = torch.zeros(1, device=out_map.device)
if torch.any(seg_mask):
bce_loss = F.binary_cross_entropy_with_logits(out_map.squeeze(1), seg_target, reduction="none")[seg_mask]
bce_loss = F.binary_cross_entropy_with_logits(
out_map,
seg_target,
reduction="none",
)[seg_mask]

neg_target = 1 - seg_target[seg_mask]
positive_count = seg_target[seg_mask].sum()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ def compute_loss(
out_map: tf.Tensor,
thresh_map: tf.Tensor,
target: List[np.ndarray],
eps: float = 1e-8,
) -> tf.Tensor:
"""Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes
and a list of masks for each image. From there it computes the loss with the model output
Expand All @@ -178,7 +177,6 @@ def compute_loss(
out_map: output feature map of the model of shape (N, H, W, C)
thresh_map: threshold map of shape (N, H, W, C)
target: list of dictionary where each dict has a `boxes` and a `flags` entry
eps: epsilon factor in dice loss
Returns:
A loss tensor
Expand Down
Loading

0 comments on commit 2e2a281

Please sign in to comment.