Skip to content

Commit

Permalink
feat: add loading backbone pretrained for multiclass detection, new e…
Browse files Browse the repository at this point in the history
…lements for kie predictor (#6)

* feat: ✨ add load backbone

* feat: change kie predictor out

* fix new elements for kie, dataset when class is empty and fix and add tests

* fix api kie route

* fix evaluate kie script

* fix black

* remove commented code

* update README
  • Loading branch information
aminemindee authored Dec 5, 2022
1 parent 07b51a3 commit aa15f43
Show file tree
Hide file tree
Showing 12 changed files with 257 additions and 229 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,9 @@ result = model(doc)

predictions = result.pages[0].predictions
for class_name in predictions.keys():
list_blocks = predictions[class_name]
print(f"Prediction for {class_name}: {list_blocks}")
list_predictions = predictions[class_name]
for prediction in list_predictions:
print(f"Prediction for {class_name}: {prediction}")
```
The KIE predictor results for each page are returned as a dictionary: each key is a class name, and its value is the list of predictions for that class.

Expand Down
6 changes: 2 additions & 4 deletions api/app/routes/kie.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ async def perform_kie(file: UploadFile = File(...)):

return {
class_name: [
OCROut(box=(*word.geometry[0], *word.geometry[1]), value=word.value)
for block in out.pages[0].predictions[class_name]
for line in block.lines
for word in line.words
OCROut(box=(*prediction.geometry[0], *prediction.geometry[1]), value=prediction.value)
for prediction in out.pages[0].predictions[class_name]
]
for class_name in out.pages[0].predictions.keys()
}
2 changes: 1 addition & 1 deletion doctr/datasets/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def format_polygons(
elif isinstance(polygons, dict):
self._class_names += list(polygons.keys())
polygons_classes = [k for k, v in polygons.items() for _ in v]
_polygons = np.concatenate([np.asarray(poly, dtype=np_dtype) for poly in polygons.values()], axis=0)
_polygons = np.concatenate([np.asarray(poly, dtype=np_dtype) for poly in polygons.values() if poly], axis=0)
else:
raise TypeError(f"polygons should be a dictionary or list, it was {type(polygons)}")
geoms = _polygons if use_polygons else np.concatenate((_polygons.min(axis=1), _polygons.max(axis=1)), axis=1)
Expand Down
86 changes: 28 additions & 58 deletions doctr/io/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from doctr.utils.repr import NestedObject
from doctr.utils.visualization import synthesize_kie_page, synthesize_page, visualize_kie_page, visualize_page

__all__ = ["Element", "Word", "Artefact", "Line", "Block", "Page", "KIEPage", "Document"]
__all__ = ["Element", "Word", "Artefact", "Line", "Prediction", "Block", "Page", "KIEPage", "Document"]


class Element(NestedObject):
Expand Down Expand Up @@ -166,6 +166,17 @@ def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
return cls(**kwargs)


class Prediction(Word):
    """Prediction element: a `Word` subclass with its own rendering and repr.

    Inherits everything from `Word`; only `render` and `extra_repr` are
    overridden here.
    """

    def render(self) -> str:
        """Return the text of this prediction, i.e. its `value` attribute."""
        text = self.value
        return text

    def extra_repr(self) -> str:
        """Return the one-line summary (value, confidence, geometry) for the repr."""
        summary = f"value='{self.value}', confidence={self.confidence:.2}, bounding_box={self.geometry}"
        return summary


class Block(Element):
"""Implements a block element as a collection of lines and artefacts
Expand Down Expand Up @@ -396,11 +407,11 @@ class KIEPage(Element):

_exported_keys: List[str] = ["page_idx", "dimensions", "orientation", "language"]
_children_names: List[str] = ["predictions"]
predictions: Dict[str, List[Block]] = {}
predictions: Dict[str, List[Prediction]] = {}

def __init__(
self,
predictions: Dict[str, List[Block]],
predictions: Dict[str, List[Prediction]],
page_idx: int,
dimensions: Tuple[int, int],
orientation: Optional[Dict[str, Any]] = None,
Expand All @@ -412,9 +423,11 @@ def __init__(
self.orientation = orientation if isinstance(orientation, dict) else dict(value=None, confidence=None)
self.language = language if isinstance(language, dict) else dict(value=None, confidence=None)

def render(self, block_break: str = "\n\n") -> str:
def render(self, prediction_break: str = "\n\n") -> str:
"""Renders the full text of the element"""
return block_break.join(b.render() for blocks in self.predictions.values() for b in blocks)
return prediction_break.join(
f"{class_name}: {p.render()}" for class_name, predictions in self.predictions.items() for p in predictions
)

def extra_repr(self) -> str:
return f"dimensions={self.dimensions}"
Expand Down Expand Up @@ -450,9 +463,7 @@ def export_as_xml(self, file_title: str = "docTR - XML export (hOCR)") -> Tuple[
a tuple of the XML byte string, and its ElementTree
"""
p_idx = self.page_idx
block_count: int = 1
line_count: int = 1
word_count: int = 1
prediction_count: int = 1
height, width = self.dimensions
language = self.language if "language" in self.language.keys() else "en"
# Create the XML root element
Expand Down Expand Up @@ -483,72 +494,31 @@ def export_as_xml(self, file_title: str = "docTR - XML export (hOCR)") -> Tuple[
},
)
# iterate over the blocks / lines / words and create the XML elements in body line by line with the attributes
for class_name, blocks in self.predictions.items():
for block in blocks:
if len(block.geometry) != 2:
for class_name, predictions in self.predictions.items():
for prediction in predictions:
if len(prediction.geometry) != 2:
raise TypeError("XML export is only available for straight bounding boxes for now.")
(xmin, ymin), (xmax, ymax) = block.geometry
block_div = SubElement(
(xmin, ymin), (xmax, ymax) = prediction.geometry
prediction_div = SubElement(
body,
"div",
attrib={
"class": "ocr_carea",
"id": f"{class_name}_block_{block_count}",
"title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
{int(round(xmax * width))} {int(round(ymax * height))}",
},
)
paragraph = SubElement(
block_div,
"p",
attrib={
"class": "ocr_par",
"id": f"par_{block_count}",
"id": f"{class_name}_prediction_{prediction_count}",
"title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
{int(round(xmax * width))} {int(round(ymax * height))}",
},
)
block_count += 1
for line in block.lines:
(xmin, ymin), (xmax, ymax) = line.geometry
# NOTE: baseline, x_size, x_descenders, x_ascenders is currently initalized to 0
line_span = SubElement(
paragraph,
"span",
attrib={
"class": "ocr_line",
"id": f"line_{line_count}",
"title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
{int(round(xmax * width))} {int(round(ymax * height))}; \
baseline 0 0; x_size 0; x_descenders 0; x_ascenders 0",
},
)
line_count += 1
for word in line.words:
(xmin, ymin), (xmax, ymax) = word.geometry
conf = word.confidence
word_div = SubElement(
line_span,
"span",
attrib={
"class": "ocrx_word",
"id": f"word_{word_count}",
"title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
{int(round(xmax * width))} {int(round(ymax * height))}; \
x_wconf {int(round(conf * 100))}",
},
)
# set the text
word_div.text = word.value
word_count += 1
prediction_div.text = prediction.value
prediction_count += 1

return ET.tostring(page_hocr, encoding="utf-8", method="xml"), ET.ElementTree(page_hocr)

@classmethod
def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
kwargs = {k: save_dict[k] for k in cls._exported_keys}
kwargs.update(
{"predictions": [Block.from_dict(predictions_dict) for predictions_dict in save_dict["predictions"]]}
{"predictions": [Prediction.from_dict(predictions_dict) for predictions_dict in save_dict["predictions"]]}
)
return cls(**kwargs)

Expand Down
46 changes: 43 additions & 3 deletions doctr/models/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
from scipy.cluster.hierarchy import fclusterdata

from doctr.io.elements import Block, Document, KIEDocument, KIEPage, Line, Page, Word
from doctr.io.elements import Block, Document, KIEDocument, KIEPage, Line, Page, Prediction, Word
from doctr.utils.geometry import estimate_page_angle, resolve_enclosing_bbox, resolve_enclosing_rbbox, rotate_boxes
from doctr.utils.repr import NestedObject

Expand Down Expand Up @@ -335,7 +335,7 @@ def __call__(


class KIEDocumentBuilder(DocumentBuilder):
"""Implements a document builder
"""Implements a KIE document builder
Args:
resolve_lines: whether words should be automatically grouped into lines
Expand All @@ -353,7 +353,7 @@ def __call__( # type: ignore[override]
orientations: Optional[List[Dict[str, Any]]] = None,
languages: Optional[List[Dict[str, Any]]] = None,
) -> KIEDocument:
"""Re-arrange detected words into structured blocks
"""Re-arrange detected words into structured predictions
Args:
boxes: list of N dictionaries, where each element represents the localization predictions for a class,
Expand Down Expand Up @@ -403,3 +403,43 @@ def __call__( # type: ignore[override]
]

return KIEDocument(_pages)

def _build_blocks(  # type: ignore[override]
    self,
    boxes: np.ndarray,
    word_preds: List[Tuple[str, float]],
) -> List[Prediction]:
    """Turn the detected words of one class into a flat list of predictions.

    No line or block grouping is done here: every detected word becomes one
    `Prediction`, emitted in the order returned by `self._sort_boxes`.

    Args:
        boxes: bounding boxes of all detected words of the page, of shape (N, 5) or (N, 4, 2)
        word_preds: list of all detected words of the page, of shape N; each entry is read
            as (value, confidence)

    Returns:
        list of prediction elements (empty when there are no boxes)

    Raises:
        ValueError: if the number of boxes differs from the number of word predictions
    """

    if boxes.shape[0] != len(word_preds):
        raise ValueError(f"Incompatible argument lengths: {boxes.shape[0]}, {len(word_preds)}")

    if boxes.shape[0] == 0:
        return []

    # The array rank is the same for every box, so decide the geometry format once.
    as_polygons = boxes.ndim == 3

    # 3-D input is sorted as-is; for 2-D input only the first four columns
    # (the coordinates) are handed to the sorter.
    order, _ = self._sort_boxes(boxes if as_polygons else boxes[:, :4])

    results: List[Prediction] = []
    for idx in order:
        if as_polygons:
            # One tuple per point of the (4, 2) polygon
            geometry = tuple([tuple(pt) for pt in boxes[idx].tolist()])
        else:
            # ((xmin, ymin), (xmax, ymax)) taken from the first four columns
            geometry = ((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], boxes[idx, 3]))
        results.append(
            Prediction(
                value=word_preds[idx][0],
                confidence=word_preds[idx][1],
                geometry=geometry,  # type: ignore[arg-type]
            )
        )
    return results
Loading

0 comments on commit aa15f43

Please sign in to comment.