diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 264a6b5e51..39d22cdcb9 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -28,14 +28,6 @@ jobs: with: python-version: ${{ matrix.python }} architecture: x64 - - name: Install poetry - uses: abatilo/actions-poetry@v2.0.0 - with: - poetry-version: 1.1.13 - - name: Lock the requirements - run: | - cd api - make lock - name: Build & run docker run: cd api && docker-compose up -d --build - name: Ping server
diff --git a/.github/workflows/scripts.yml b/.github/workflows/scripts.yml index b485b5e665..aaf435c9e7 100644 --- a/.github/workflows/scripts.yml +++ b/.github/workflows/scripts.yml @@ -140,7 +140,9 @@ jobs: python -m pip install --upgrade pip pip install -e .[torch] --upgrade - name: Run evaluation script - run: python scripts/evaluate.py db_resnet50 crnn_vgg16_bn --samples 10 + run: | + python scripts/evaluate.py db_resnet50 crnn_vgg16_bn --samples 10 + python scripts/evaluate_kie.py db_resnet50 crnn_vgg16_bn --samples 10 test-collectenv: runs-on: ${{ matrix.os }}
diff --git a/README.md b/README.md index 5be770103c..beadb25e3c 100644 --- a/README.md +++ b/README.md @@ -99,6 +99,31 @@ You can also export them as a nested dict, more appropriate for JSON format: json_output = result.export() ``` +### Use the KIE predictor +The KIE predictor is a more flexible predictor than the OCR predictor, as its detection model can detect multiple classes in a document. For example, you can have a detection model that detects just dates and addresses in a document. + +The KIE predictor makes it possible to combine a detector with multiple classes with a recognition model, and to have the whole pipeline already set up for you. + +```python +from doctr.io import DocumentFile +from doctr.models import kie_predictor + +# Model +model = kie_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True) +# PDF +doc = DocumentFile.from_pdf("path/to/your/doc.pdf") +# Analyze +result = model(doc) + +predictions = result.pages[0].predictions +for class_name in predictions.keys(): + list_predictions = predictions[class_name] + for prediction in list_predictions: + print(f"Prediction for {class_name}: {prediction}") +``` +The KIE predictor results per page are in a dictionary format, with each key representing a class name and its value being the list of predictions for that class. + + ### If you are looking for support from the Mindee team [![Bad OCR test detection image asking the developer if they need help](https://github.com/mindee/doctr/releases/download/v0.5.1/doctr-need-help.png)](https://mindee.com/product/doctr) @@ -247,7 +272,10 @@ Looking to integrate docTR into your API? Here is a template to get you started #### Deploy your API locally Specific dependencies are required to run the API template, which you can install as follows: ```shell -pip install -r api/requirements.txt +cd api/ +pip install poetry +make lock +pip install -r requirements.txt ``` You can now run your API locally: @@ -262,7 +290,7 @@ PORT=8002 docker-compose up -d --build #### What you have deployed -Your API should now be running locally on your port 8002. Access your automatically-built documentation at [http://localhost:8002/redoc](http://localhost:8002/redoc) and enjoy your three functional routes ("/detection", "/recognition", "/ocr"). Here is an example with Python to send a request to the OCR route: +Your API should now be running locally on port 8002.
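Once the container is up, the new /kie route added later in this diff (api/app/routes/kie.py) can be exercised just like the OCR one. A minimal sketch, assuming a local deployment on port 8002 and a hypothetical input file; the response shape follows that route's `response_model` of `Dict[str, List[OCROut]]`:

```python
import requests

# hypothetical sample image; any document image works
with open("path/to/your/doc.jpg", "rb") as f:
    data = f.read()

response = requests.post("http://localhost:8002/kie", files={"file": data}).json()

# results are grouped by detection class; each entry carries a relative
# bounding box and the recognized value
for class_name, predictions in response.items():
    for prediction in predictions:
        print(class_name, prediction["box"], prediction["value"])
```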
Access your automatically-built documentation at [http://localhost:8002/redoc](http://localhost:8002/redoc) and enjoy your four functional routes ("/detection", "/recognition", "/ocr", "/kie"). Here is an example with Python to send a request to the OCR route: ```python import requests
diff --git a/api/Dockerfile b/api/Dockerfile index a803589108..148eaabd61 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -7,17 +7,18 @@ ENV PYTHONDONTWRITEBYTECODE 1 ENV PYTHONUNBUFFERED 1 ENV PYTHONPATH "${PYTHONPATH}:/app" -# copy requirements file -COPY requirements.txt /app/requirements.txt +RUN apt-get update \ + && apt-get install --no-install-recommends ffmpeg libsm6 libxext6 make -y \ + && apt-get autoremove -y \ + && rm -rf /var/lib/apt/lists/* + +COPY pyproject.toml /app/pyproject.toml COPY Makefile /app/Makefile -RUN apt-get update \ - && apt-get install --no-install-recommends ffmpeg libsm6 libxext6 -y \ - && pip install --upgrade pip setuptools wheel \ +RUN pip install --upgrade pip setuptools wheel poetry \ + && make lock \ && pip install -r /app/requirements.txt \ && pip cache purge \ - && apt-get autoremove -y \ - && rm -rf /var/lib/apt/lists/* \ && rm -rf /root/.cache/pip # copy project
diff --git a/api/Makefile b/api/Makefile index 3472a78172..cd18b1c3ec 100644 --- a/api/Makefile +++ b/api/Makefile @@ -5,7 +5,7 @@ lock: poetry lock poetry export -f requirements.txt --without-hashes --output requirements.txt - poetry export -f requirements.txt --without-hashes --dev --output requirements-dev.txt + poetry export -f requirements.txt --without-hashes --with dev --output requirements-dev.txt # Run the docker run:
diff --git a/api/app/main.py b/api/app/main.py index 4f081ef6dc..f4fe9d18ad 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -9,7 +9,7 @@ from fastapi.openapi.utils import get_openapi from app import config as cfg -from app.routes import detection, ocr, recognition +from app.routes import detection, kie, ocr, recognition app = FastAPI(title=cfg.PROJECT_NAME, description=cfg.PROJECT_DESCRIPTION, debug=cfg.DEBUG, version=cfg.VERSION) @@ -18,6 +18,7 @@ app.include_router(recognition.router, prefix="/recognition", tags=["recognition"]) app.include_router(detection.router, prefix="/detection", tags=["detection"]) app.include_router(ocr.router, prefix="/ocr", tags=["ocr"]) +app.include_router(kie.router, prefix="/kie", tags=["kie"]) # Middleware
diff --git a/api/app/routes/detection.py b/api/app/routes/detection.py index e074530331..f53a9b6c8a 100644 --- a/api/app/routes/detection.py +++ b/api/app/routes/detection.py @@ -9,6 +9,7 @@ from app.schemas import DetectionOut from app.vision import det_predictor +from doctr.file_utils import CLASS_NAME from doctr.io import decode_img_as_tensor router = APIRouter() @@ -19,4 +20,4 @@ async def text_detection(file: UploadFile = File(...)): """Runs docTR text detection model to analyze the input image""" img = decode_img_as_tensor(file.file.read()) boxes = det_predictor([img])[0] - return [DetectionOut(box=box.tolist()) for box in boxes[:, :-1]] + return [DetectionOut(box=box.tolist()) for box in boxes[CLASS_NAME][:, :-1]]
diff --git a/api/app/routes/kie.py b/api/app/routes/kie.py new file mode 100644 index 0000000000..cf17d7fcdd --- /dev/null +++ b/api/app/routes/kie.py @@ -0,0 +1,29 @@ +# Copyright (C) 2022, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
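+# Illustrative response shape for this route (hypothetical values): a dict keyed
+# by detection class, e.g. {"words": [{"box": [0.10, 0.20, 0.31, 0.24], "value": "hello"}]};
+# "words" is the default CLASS_NAME defined in doctr/file_utils.py further down this diff.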
+ +from typing import Dict, List + +from fastapi import APIRouter, File, UploadFile, status + +from app.schemas import OCROut +from app.vision import kie_predictor +from doctr.io import decode_img_as_tensor + +router = APIRouter() + + +@router.post("/", response_model=Dict[str, List[OCROut]], status_code=status.HTTP_200_OK, summary="Perform KIE") +async def perform_kie(file: UploadFile = File(...)): + """Runs docTR KIE model to analyze the input image""" + img = decode_img_as_tensor(file.file.read()) + out = kie_predictor([img]) + + return { + class_name: [ + OCROut(box=(*prediction.geometry[0], *prediction.geometry[1]), value=prediction.value) + for prediction in out.pages[0].predictions[class_name] + ] + for class_name in out.pages[0].predictions.keys() + } diff --git a/api/app/routes/ocr.py b/api/app/routes/ocr.py index c761c3c304..114e0f10c2 100644 --- a/api/app/routes/ocr.py +++ b/api/app/routes/ocr.py @@ -22,5 +22,7 @@ async def perform_ocr(file: UploadFile = File(...)): return [ OCROut(box=(*word.geometry[0], *word.geometry[1]), value=word.value) - for word in out.pages[0].blocks[0].lines[0].words + for block in out.pages[0].blocks + for line in block.lines + for word in line.words ] diff --git a/api/app/vision.py b/api/app/vision.py index 7e2eb57aba..b4cd0772af 100644 --- a/api/app/vision.py +++ b/api/app/vision.py @@ -9,8 +9,9 @@ if any(gpu_devices): tf.config.experimental.set_memory_growth(gpu_devices[0], True) -from doctr.models import ocr_predictor +from doctr.models import kie_predictor, ocr_predictor predictor = ocr_predictor(pretrained=True) det_predictor = predictor.det_predictor reco_predictor = predictor.reco_predictor +kie_predictor = kie_predictor(pretrained=True) diff --git a/api/pyproject.toml b/api/pyproject.toml index c57407db74..4ac0564aa3 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -4,16 +4,16 @@ build-backend = "poetry.masonry.api" [tool.poetry] name = "doctr-api" -version = "0.5.2a0" +version = "0.7.1a0" description = "Backend template for your OCR API with docTR" authors = ["Mindee "] license = "Apache-2.0" [tool.poetry.dependencies] -python = ">=3.8,<3.11" +python = ">=3.8.2,<3.11" # pypdfium2 needs a python version above 3.8.2 tensorflow = ">=2.9.0,<3.0.0" tensorflow-addons = ">=0.17.1" -python-doctr = ">=0.2.0" +python-doctr = { version = ">=0.7.0", extras = ['tf'] } # Fastapi: minimum version required to avoid pydantic error # cf. 
https://github.com/tiangolo/fastapi/issues/4168 fastapi = ">=0.73.0"
diff --git a/api/tests/routes/test_kie.py b/api/tests/routes/test_kie.py new file mode 100644 index 0000000000..659f7f2481 --- /dev/null +++ b/api/tests/routes/test_kie.py @@ -0,0 +1,29 @@ +import numpy as np +import pytest +from scipy.optimize import linear_sum_assignment + +from doctr.utils.metrics import box_iou + + +@pytest.mark.asyncio +async def test_perform_kie(test_app_asyncio, mock_detection_image): + + response = await test_app_asyncio.post("/kie", files={"file": mock_detection_image}) + assert response.status_code == 200 + json_response = response.json() + + gt_boxes = np.array([[1240, 430, 1355, 470], [1360, 430, 1495, 470]], dtype=np.float32) + gt_boxes[:, [0, 2]] = gt_boxes[:, [0, 2]] / 1654 + gt_boxes[:, [1, 3]] = gt_boxes[:, [1, 3]] / 2339 + gt_labels = ["Hello", "world!"] + + # Check that IoU with GT is reasonable + assert isinstance(json_response, dict) and len(list(json_response.values())[0]) == gt_boxes.shape[0] + pred_boxes = np.array([elt["box"] for json_out in json_response.values() for elt in json_out]) + pred_labels = np.array([elt["value"] for json_out in json_response.values() for elt in json_out]) + iou_mat = box_iou(gt_boxes, pred_boxes) + gt_idxs, pred_idxs = linear_sum_assignment(-iou_mat) + is_kept = iou_mat[gt_idxs, pred_idxs] >= 0.8 + gt_idxs, pred_idxs = gt_idxs[is_kept], pred_idxs[is_kept] + assert gt_idxs.shape[0] == gt_boxes.shape[0] + assert all(gt_labels[gt_idx] == pred_labels[pred_idx] for gt_idx, pred_idx in zip(gt_idxs, pred_idxs))
diff --git a/docs/source/modules/models.rst b/docs/source/modules/models.rst index 79154b3c58..e6a6e307ad 100644 --- a/docs/source/modules/models.rst +++ b/docs/source/modules/models.rst @@ -81,6 +81,8 @@ doctr.models.zoo .. autofunction:: doctr.models.ocr_predictor +.. autofunction:: doctr.models.kie_predictor + doctr.models.factory --------------------
diff --git a/doctr/__init__.py b/doctr/__init__.py index 14390c4cd1..1e27fc9197 100644 --- a/doctr/__init__.py +++ b/doctr/__init__.py @@ -1,3 +1,3 @@ -from . import datasets, io, models, transforms, utils +from .
import io, datasets, models, transforms, utils from .file_utils import is_tf_available, is_torch_available from .version import __version__ # noqa: F401 diff --git a/doctr/datasets/datasets/base.py b/doctr/datasets/datasets/base.py index 55665e4a26..1b6a63532f 100644 --- a/doctr/datasets/datasets/base.py +++ b/doctr/datasets/datasets/base.py @@ -8,6 +8,9 @@ from pathlib import Path from typing import Any, Callable, List, Optional, Tuple, Union +import numpy as np + +from doctr.file_utils import copy_tensor from doctr.io.image import get_img_shape from doctr.utils.data import download_from_url @@ -55,7 +58,13 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: img = self.img_transforms(img) if self.sample_transforms is not None: - img, target = self.sample_transforms(img, target) + if isinstance(target, dict) and all([isinstance(item, np.ndarray) for item in target.values()]): + img_transformed = copy_tensor(img) + for class_name, bboxes in target.items(): + img_transformed, target[class_name] = self.sample_transforms(img, bboxes) + img = img_transformed + else: + img, target = self.sample_transforms(img, target) return img, target diff --git a/doctr/datasets/datasets/pytorch.py b/doctr/datasets/datasets/pytorch.py index 55c130c749..b21ee64f2a 100644 --- a/doctr/datasets/datasets/pytorch.py +++ b/doctr/datasets/datasets/pytorch.py @@ -25,6 +25,12 @@ def _read_sample(self, index: int) -> Tuple[torch.Tensor, Any]: if isinstance(target, dict): assert "boxes" in target, "Target should contain 'boxes' key" assert "labels" in target, "Target should contain 'labels' key" + elif isinstance(target, tuple): + assert len(target) == 2 + assert isinstance(target[0], str) or isinstance( + target[0], np.ndarray + ), "first element of the tuple should be a string or a numpy array" + assert isinstance(target[1], list), "second element of the tuple should be a list" else: assert isinstance(target, str) or isinstance( target, np.ndarray diff --git a/doctr/datasets/datasets/tensorflow.py b/doctr/datasets/datasets/tensorflow.py index 4d8320eee1..a3b9f214b0 100644 --- a/doctr/datasets/datasets/tensorflow.py +++ b/doctr/datasets/datasets/tensorflow.py @@ -25,6 +25,12 @@ def _read_sample(self, index: int) -> Tuple[tf.Tensor, Any]: if isinstance(target, dict): assert "boxes" in target, "Target should contain 'boxes' key" assert "labels" in target, "Target should contain 'labels' key" + elif isinstance(target, tuple): + assert len(target) == 2 + assert isinstance(target[0], str) or isinstance( + target[0], np.ndarray + ), "first element of the tuple should be a string or a numpy array" + assert isinstance(target[1], list), "second element of the tuple should be a list" else: assert isinstance(target, str) or isinstance( target, np.ndarray diff --git a/doctr/datasets/detection.py b/doctr/datasets/detection.py index 0cc63c7ac0..98d27be616 100644 --- a/doctr/datasets/detection.py +++ b/doctr/datasets/detection.py @@ -5,14 +5,14 @@ import json import os -from typing import Any, List, Tuple +from typing import Any, Dict, List, Tuple, Type, Union import numpy as np -from doctr.io.image import get_img_shape -from doctr.utils.geometry import convert_to_relative_coords +from doctr.file_utils import CLASS_NAME from .datasets import AbstractDataset +from .utils import pre_transform_multiclass __all__ = ["DetectionDataset"] @@ -41,24 +41,55 @@ def __init__( ) -> None: super().__init__( img_folder, - pre_transforms=lambda img, boxes: (img, convert_to_relative_coords(boxes, get_img_shape(img))), + 
pre_transforms=pre_transform_multiclass, **kwargs, ) # File existence check + self._class_names: List = [] if not os.path.exists(label_path): raise FileNotFoundError(f"unable to locate {label_path}") with open(label_path, "rb") as f: labels = json.load(f) - self.data: List[Tuple[str, np.ndarray]] = [] + self.data: List[Tuple[str, Tuple[np.ndarray, List[str]]]] = [] np_dtype = np.float32 for img_name, label in labels.items(): # File existence check if not os.path.exists(os.path.join(self.root, img_name)): raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}") - polygons: np.ndarray = np.asarray(label["polygons"], dtype=np_dtype) - geoms = polygons if use_polygons else np.concatenate((polygons.min(axis=1), polygons.max(axis=1)), axis=1) + geoms, polygons_classes = self.format_polygons(label["polygons"], use_polygons, np_dtype) - self.data.append((img_name, np.asarray(geoms, dtype=np_dtype))) + self.data.append((img_name, (np.asarray(geoms, dtype=np_dtype), polygons_classes))) + + def format_polygons( + self, polygons: Union[List, Dict], use_polygons: bool, np_dtype: Type + ) -> Tuple[np.ndarray, List[str]]: + """format polygons into an array + + Args: + polygons: the bounding boxes + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + np_dtype: dtype of array + + Returns: + geoms: bounding boxes as np array + polygons_classes: list of classes for each bounding box + """ + if isinstance(polygons, list): + self._class_names += [CLASS_NAME] + polygons_classes = [CLASS_NAME for _ in polygons] + _polygons: np.ndarray = np.asarray(polygons, dtype=np_dtype) + elif isinstance(polygons, dict): + self._class_names += list(polygons.keys()) + polygons_classes = [k for k, v in polygons.items() for _ in v] + _polygons = np.concatenate([np.asarray(poly, dtype=np_dtype) for poly in polygons.values() if poly], axis=0) + else: + raise TypeError(f"polygons should be a dictionary or list, it was {type(polygons)}") + geoms = _polygons if use_polygons else np.concatenate((_polygons.min(axis=1), _polygons.max(axis=1)), axis=1) + return geoms, polygons_classes + + @property + def class_names(self): + return sorted(list(set(self._class_names))) diff --git a/doctr/datasets/utils.py b/doctr/datasets/utils.py index 767790e201..5a0476adc7 100644 --- a/doctr/datasets/utils.py +++ b/doctr/datasets/utils.py @@ -20,7 +20,7 @@ from .vocabs import VOCABS -__all__ = ["translate", "encode_string", "decode_sequence", "encode_sequences"] +__all__ = ["translate", "encode_string", "decode_sequence", "encode_sequences", "pre_transform_multiclass"] ImageTensor = TypeVar("ImageTensor") @@ -183,3 +183,19 @@ def crop_bboxes_from_image(img_path: Union[str, Path], geoms: np.ndarray) -> Lis if geoms.ndim == 2 and geoms.shape[1] == 4: return extract_crops(img, geoms.astype(dtype=int)) raise ValueError("Invalid geometry format") + + +def pre_transform_multiclass(img, target: Tuple[np.ndarray, List]) -> Tuple[np.ndarray, Dict[str, List]]: + """Converts multiclass target to relative coordinates. 
+ + Args: + img: Image + target: tuple of target polygons and their classes names + """ + boxes = convert_to_relative_coords(target[0], get_img_shape(img)) + boxes_classes = target[1] + boxes_dict: Dict = {k: [] for k in sorted(set(boxes_classes))} + for k, poly in zip(boxes_classes, boxes): + boxes_dict[k].append(poly) + boxes_dict = {k: np.stack(v, axis=0) for k, v in boxes_dict.items()} + return img, boxes_dict diff --git a/doctr/file_utils.py b/doctr/file_utils.py index 15a9f02991..e41b1cabde 100644 --- a/doctr/file_utils.py +++ b/doctr/file_utils.py @@ -10,13 +10,16 @@ import os import sys +CLASS_NAME: str = "words" + + if sys.version_info < (3, 8): import importlib_metadata else: import importlib.metadata as importlib_metadata -__all__ = ["is_tf_available", "is_torch_available"] +__all__ = ["is_tf_available", "is_torch_available", "CLASS_NAME", "copy_tensor"] ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) @@ -85,3 +88,12 @@ def is_torch_available(): def is_tf_available(): return _tf_available + + +def copy_tensor(x): + if is_tf_available(): + import tensorflow as tf + + return tf.identity(x) + elif is_torch_available(): + return x.detach().clone() diff --git a/doctr/io/elements.py b/doctr/io/elements.py index 743c976a0c..0ae9e8ed55 100644 --- a/doctr/io/elements.py +++ b/doctr/io/elements.py @@ -19,9 +19,9 @@ from doctr.utils.common_types import BoundingBox from doctr.utils.geometry import resolve_enclosing_bbox, resolve_enclosing_rbbox from doctr.utils.repr import NestedObject -from doctr.utils.visualization import synthesize_page, visualize_page +from doctr.utils.visualization import synthesize_kie_page, synthesize_page, visualize_kie_page, visualize_page -__all__ = ["Element", "Word", "Artefact", "Line", "Block", "Page", "Document"] +__all__ = ["Element", "Word", "Artefact", "Line", "Prediction", "Block", "Page", "KIEPage", "Document"] class Element(NestedObject): @@ -42,7 +42,12 @@ def export(self) -> Dict[str, Any]: export_dict = {k: getattr(self, k) for k in self._exported_keys} for children_name in self._children_names: - export_dict[children_name] = [c.export() for c in getattr(self, children_name)] + if children_name in ["predictions"]: + export_dict[children_name] = { + k: [item.export() for item in c] for k, c in getattr(self, children_name).items() + } + else: + export_dict[children_name] = [c.export() for c in getattr(self, children_name)] return export_dict @@ -161,6 +166,17 @@ def from_dict(cls, save_dict: Dict[str, Any], **kwargs): return cls(**kwargs) +class Prediction(Word): + """Implements a prediction element""" + + def render(self) -> str: + """Renders the full text of the element""" + return self.value + + def extra_repr(self) -> str: + return f"value='{self.value}', confidence={self.confidence:.2}, bounding_box={self.geometry}" + + class Block(Element): """Implements a block element as a collection of lines and artefacts @@ -378,6 +394,135 @@ def from_dict(cls, save_dict: Dict[str, Any], **kwargs): return cls(**kwargs) +class KIEPage(Element): + """Implements a KIE page element as a collection of predictions + + Args: + predictions: Dictionary with list of block elements for each detection class + page_idx: the index of the page in the input raw document + dimensions: the page size in pixels in format (height, width) + orientation: a dictionary with the value of the rotation angle in degress and confidence of the prediction + language: a dictionary with the language value and confidence of 
the prediction + """ + + _exported_keys: List[str] = ["page_idx", "dimensions", "orientation", "language"] + _children_names: List[str] = ["predictions"] + predictions: Dict[str, List[Prediction]] = {} + + def __init__( + self, + predictions: Dict[str, List[Prediction]], + page_idx: int, + dimensions: Tuple[int, int], + orientation: Optional[Dict[str, Any]] = None, + language: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__(predictions=predictions) + self.page_idx = page_idx + self.dimensions = dimensions + self.orientation = orientation if isinstance(orientation, dict) else dict(value=None, confidence=None) + self.language = language if isinstance(language, dict) else dict(value=None, confidence=None) + + def render(self, prediction_break: str = "\n\n") -> str: + """Renders the full text of the element""" + return prediction_break.join( + f"{class_name}: {p.render()}" for class_name, predictions in self.predictions.items() for p in predictions + ) + + def extra_repr(self) -> str: + return f"dimensions={self.dimensions}" + + def show(self, page: np.ndarray, interactive: bool = True, preserve_aspect_ratio: bool = False, **kwargs) -> None: + """Overlay the result on a given image + + Args: + page: image encoded as a numpy array in uint8 + interactive: whether the display should be interactive + preserve_aspect_ratio: pass True if you passed True to the predictor + """ + visualize_kie_page(self.export(), page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio) + plt.show(**kwargs) + + def synthesize(self, **kwargs) -> np.ndarray: + """Synthesize the page from the predictions + + Returns: + synthesized page + """ + + return synthesize_kie_page(self.export(), **kwargs) + + def export_as_xml(self, file_title: str = "docTR - XML export (hOCR)") -> Tuple[bytes, ET.ElementTree]: + """Export the page as XML (hOCR-format) + convention: https://github.com/kba/hocr-spec/blob/master/1.2/spec.md + + Args: + file_title: the title of the XML file + + Returns: + a tuple of the XML byte string, and its ElementTree + """ + p_idx = self.page_idx + prediction_count: int = 1 + height, width = self.dimensions + language = self.language if "language" in self.language.keys() else "en" + # Create the XML root element + page_hocr = ETElement("html", attrib={"xmlns": "http://www.w3.org/1999/xhtml", "xml:lang": str(language)}) + # Create the header / SubElements of the root element + head = SubElement(page_hocr, "head") + SubElement(head, "title").text = file_title + SubElement(head, "meta", attrib={"http-equiv": "Content-Type", "content": "text/html; charset=utf-8"}) + SubElement( + head, + "meta", + attrib={"name": "ocr-system", "content": f"python-doctr {doctr.__version__}"}, # type: ignore[attr-defined] + ) + SubElement( + head, + "meta", + attrib={"name": "ocr-capabilities", "content": "ocr_page ocr_carea ocr_par ocr_line ocrx_word"}, + ) + # Create the body + body = SubElement(page_hocr, "body") + SubElement( + body, + "div", + attrib={ + "class": "ocr_page", + "id": f"page_{p_idx + 1}", + "title": f"image; bbox 0 0 {width} {height}; ppageno 0", + }, + ) + # iterate over the blocks / lines / words and create the XML elements in body line by line with the attributes + for class_name, predictions in self.predictions.items(): + for prediction in predictions: + if len(prediction.geometry) != 2: + raise TypeError("XML export is only available for straight bounding boxes for now.") + (xmin, ymin), (xmax, ymax) = prediction.geometry + prediction_div = SubElement( + body, + "div", + 
attrib={ + "class": "ocr_carea", + "id": f"{class_name}_prediction_{prediction_count}", + "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \ + {int(round(xmax * width))} {int(round(ymax * height))}", + }, + ) + prediction_div.text = prediction.value + prediction_count += 1 + + return ET.tostring(page_hocr, encoding="utf-8", method="xml"), ET.ElementTree(page_hocr) + + @classmethod + def from_dict(cls, save_dict: Dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in cls._exported_keys} + kwargs.update( + {"predictions": [Prediction.from_dict(predictions_dict) for predictions_dict in save_dict["predictions"]]} + ) + return cls(**kwargs) + + class Document(Element): """Implements a document element as a collection of pages @@ -432,3 +577,20 @@ def from_dict(cls, save_dict: Dict[str, Any], **kwargs): kwargs = {k: save_dict[k] for k in cls._exported_keys} kwargs.update({"pages": [Page.from_dict(page_dict) for page_dict in save_dict["pages"]]}) return cls(**kwargs) + + +class KIEDocument(Document): + """Implements a document element as a collection of pages + + Args: + pages: list of page elements + """ + + _children_names: List[str] = ["pages"] + pages: List[KIEPage] = [] # type: ignore[assignment] + + def __init__( + self, + pages: List[KIEPage], + ) -> None: + super().__init__(pages=pages) # type: ignore[arg-type] diff --git a/doctr/models/_utils.py b/doctr/models/_utils.py index b7cf452c0f..bcde26ca59 100644 --- a/doctr/models/_utils.py +++ b/doctr/models/_utils.py @@ -5,13 +5,13 @@ from math import floor from statistics import median_low -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import cv2 import numpy as np from langdetect import LangDetectException, detect_langs -__all__ = ["estimate_orientation", "get_bitmap_angle", "get_language"] +__all__ = ["estimate_orientation", "get_bitmap_angle", "get_language", "invert_data_structure"] def get_max_width_length_ratio(contour: np.ndarray) -> float: @@ -161,3 +161,26 @@ def get_language(text: str) -> Tuple[str, float]: if len(text) <= 1 or (len(text) <= 5 and lang.prob <= 0.2): return "unknown", 0.0 return lang.lang, lang.prob + + +def invert_data_structure( + x: Union[List[Dict[str, Any]], Dict[str, List[Any]]] +) -> Union[List[Dict[str, Any]], Dict[str, List[Any]]]: + """Invert a List of Dict of elements to a Dict of list of elements and the other way around + + Args: + x: a list of dictionaries with the same keys or a dictionary of lists of the same length + + Returns: + dictionary of list when x is a list of dictionaries or a list of dictionaries when x is dictionary of lists + """ + + if isinstance(x, dict): + assert ( + len(set([len(v) for v in x.values()])) == 1 + ), "All the lists in the dictionnary should have the same length." 
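+        # e.g. invert_data_structure({"boxes": [b0, b1], "classes": [c0, c1]})
+        #      == [{"boxes": b0, "classes": c0}, {"boxes": b1, "classes": c1}]
+        # (illustrative key names; list inputs are mapped back the other way)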
+ return [dict(zip(x, t)) for t in zip(*x.values())] + elif isinstance(x, list): + return {k: [dic[k] for dic in x] for k in x[0]} + else: + raise TypeError(f"Expected input to be either a dict or a list, got {type(x)} instead.")
diff --git a/doctr/models/builder.py b/doctr/models/builder.py index b804d3c157..221c9856b2 100644 --- a/doctr/models/builder.py +++ b/doctr/models/builder.py @@ -9,7 +9,7 @@ import numpy as np from scipy.cluster.hierarchy import fclusterdata -from doctr.io.elements import Block, Document, Line, Page, Word +from doctr.io.elements import Block, Document, KIEDocument, KIEPage, Line, Page, Prediction, Word from doctr.utils.geometry import estimate_page_angle, resolve_enclosing_bbox, resolve_enclosing_rbbox, rotate_boxes from doctr.utils.repr import NestedObject @@ -332,3 +332,114 @@ def __call__( ] return Document(_pages) + + +class KIEDocumentBuilder(DocumentBuilder): + """Implements a KIE document builder + + Args: + resolve_lines: whether words should be automatically grouped into lines + resolve_blocks: whether lines should be automatically grouped into blocks + paragraph_break: relative length of the minimum space separating paragraphs + export_as_straight_boxes: if True, force straight boxes in the export (fit a rectangle + box to all rotated boxes). Else, keep the boxes format unchanged, no matter what it is. + """ + + def __call__( # type: ignore[override] + self, + boxes: List[Dict[str, np.ndarray]], + text_preds: List[Dict[str, List[Tuple[str, float]]]], + page_shapes: List[Tuple[int, int]], + orientations: Optional[List[Dict[str, Any]]] = None, + languages: Optional[List[Dict[str, Any]]] = None, + ) -> KIEDocument: + """Re-arrange detected words into structured predictions + + Args: + boxes: list of N dictionaries, where each element represents the localization predictions for a class, + of shape (*, 5) or (*, 6) for all predictions + text_preds: list of N dictionaries, where each element is the list of all word predictions + page_shapes: shape of each page, of size N + + Returns: + document object + """ + if len(boxes) != len(text_preds) or len(boxes) != len(page_shapes): + raise ValueError("All arguments are expected to be lists of the same size") + _orientations = ( + orientations if isinstance(orientations, list) else [None] * len(boxes) # type: ignore[list-item] + ) + _languages = languages if isinstance(languages, list) else [None] * len(boxes) # type: ignore[list-item] + if self.export_as_straight_boxes and len(boxes) > 0: + # If boxes are already straight OK, else fit a bounding rect + if next(iter(boxes[0].values())).ndim == 3: + straight_boxes: List[Dict[str, np.ndarray]] = [] + # Iterate over pages + for p_boxes in boxes: + # Iterate over boxes of the pages + straight_boxes_dict = {} + for k, box in p_boxes.items(): + straight_boxes_dict[k] = np.concatenate((box.min(1), box.max(1)), 1) + straight_boxes.append(straight_boxes_dict) + boxes = straight_boxes + + _pages = [ + KIEPage( + { + k: self._build_blocks( + page_boxes[k], + word_preds[k], + ) + for k in page_boxes.keys() + }, + _idx, + shape, + orientation, + language, + ) + for _idx, shape, page_boxes, word_preds, orientation, language in zip( + range(len(boxes)), page_shapes, boxes, text_preds, _orientations, _languages + ) + ] + + return KIEDocument(_pages) + + def _build_blocks( # type: ignore[override] + self, + boxes: np.ndarray, + word_preds: List[Tuple[str, float]], + ) -> List[Prediction]: + """Gather independent words in structured blocks + + Args: + boxes: bounding boxes of all
detected words of the page, of shape (N, 5) or (N, 4, 2) + word_preds: list of all detected words of the page, of shape N + + Returns: + list of block elements + """ + + if boxes.shape[0] != len(word_preds): + raise ValueError(f"Incompatible argument lengths: {boxes.shape[0]}, {len(word_preds)}") + + if boxes.shape[0] == 0: + return [] + + # Decide whether we try to form lines + _boxes = boxes + idxs, _ = self._sort_boxes(_boxes if _boxes.ndim == 3 else _boxes[:, :4]) + predictions = [ + Prediction( + value=word_preds[idx][0], + confidence=word_preds[idx][1], + geometry=tuple([tuple(pt) for pt in boxes[idx].tolist()]), # type: ignore[arg-type] + ) + if boxes.ndim == 3 + else Prediction( + value=word_preds[idx][0], + confidence=word_preds[idx][1], + geometry=((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], boxes[idx, 3])), + ) + for idx in idxs + ] + return predictions diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py index 202f739d14..ca0cb23abe 100644 --- a/doctr/models/detection/differentiable_binarization/base.py +++ b/doctr/models/detection/differentiable_binarization/base.py @@ -5,7 +5,7 @@ # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization -from typing import List, Tuple, Union +from typing import Dict, List, Tuple, Union import cv2 import numpy as np @@ -264,79 +264,96 @@ def draw_thresh_map( def build_target( self, - target: List[np.ndarray], - output_shape: Tuple[int, int, int], + target: List[Dict[str, np.ndarray]], + output_shape: Tuple[int, int, int, int], + channels_last: bool = True, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - if any(t.dtype != np.float32 for t in target): + if any(t.dtype != np.float32 for tgt in target for t in tgt.values()): raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.") - if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for t in target): + if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()): raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.") - input_dtype = target[0].dtype if len(target) > 0 else np.float32 - - seg_target: np.ndarray = np.zeros(output_shape, dtype=np.uint8) - seg_mask: np.ndarray = np.ones(output_shape, dtype=bool) - thresh_target: np.ndarray = np.zeros(output_shape, dtype=np.float32) - thresh_mask: np.ndarray = np.ones(output_shape, dtype=np.uint8) - - for idx, _target in enumerate(target): - # Draw each polygon on gt - if _target.shape[0] == 0: - # Empty image, full masked - seg_mask[idx] = False - - # Absolute bounding boxes - abs_boxes = _target.copy() - if abs_boxes.ndim == 3: - abs_boxes[:, :, 0] *= output_shape[-1] - abs_boxes[:, :, 1] *= output_shape[-2] - polys = abs_boxes - boxes_size = np.linalg.norm(abs_boxes[:, 2, :] - abs_boxes[:, 0, :], axis=-1) - abs_boxes = np.concatenate((abs_boxes.min(1), abs_boxes.max(1)), -1).round().astype(np.int32) - else: - abs_boxes[:, [0, 2]] *= output_shape[-1] - abs_boxes[:, [1, 3]] *= output_shape[-2] - abs_boxes = abs_boxes.round().astype(np.int32) - polys = np.stack( - [ - abs_boxes[:, [0, 1]], - abs_boxes[:, [0, 3]], - abs_boxes[:, [2, 3]], - abs_boxes[:, [2, 1]], - ], - axis=1, - ) - boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1]) - - for box, box_size, poly in zip(abs_boxes, boxes_size, polys): - # Mask boxes that are too small - if box_size < self.min_size_box: - seg_mask[idx, box[1] 
: box[3] + 1, box[0] : box[2] + 1] = False - continue - - # Negative shrink for gt, as described in paper - polygon = Polygon(poly) - distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length - subject = [tuple(coor) for coor in poly] - padding = pyclipper.PyclipperOffset() - padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) - shrinked = padding.Execute(-distance) - - # Draw polygon on gt if it is valid - if len(shrinked) == 0: - seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False - continue - shrinked = np.array(shrinked[0]).reshape(-1, 2) - if shrinked.shape[0] <= 2 or not Polygon(shrinked).is_valid: - seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False - continue - cv2.fillPoly(seg_target[idx], [shrinked.astype(np.int32)], 1) + input_dtype = next(iter(target[0].values())).dtype if len(target) > 0 else np.float32 - # Draw on both thresh map and thresh mask - poly, thresh_target[idx], thresh_mask[idx] = self.draw_thresh_map( - poly, thresh_target[idx], thresh_mask[idx] - ) + if channels_last: + h, w = output_shape[1:-1] + target_shape = (output_shape[0], output_shape[-1], h, w) # (Batch_size, num_classes, h, w) + else: + h, w = output_shape[-2:] + target_shape = output_shape # (Batch_size, num_classes, h, w) + seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8) + seg_mask: np.ndarray = np.ones(target_shape, dtype=bool) + thresh_target: np.ndarray = np.zeros(target_shape, dtype=np.float32) + thresh_mask: np.ndarray = np.ones(target_shape, dtype=np.uint8) + + for idx, tgt in enumerate(target): + for class_idx, _tgt in enumerate(tgt.values()): + # Draw each polygon on gt + if _tgt.shape[0] == 0: + # Empty image, full masked + # seg_mask[idx, :, :, class_idx] = False + seg_mask[idx, class_idx] = False + + # Absolute bounding boxes + abs_boxes = _tgt.copy() + if abs_boxes.ndim == 3: + abs_boxes[:, :, 0] *= w + abs_boxes[:, :, 1] *= h + polys = abs_boxes + boxes_size = np.linalg.norm(abs_boxes[:, 2, :] - abs_boxes[:, 0, :], axis=-1) + abs_boxes = np.concatenate((abs_boxes.min(1), abs_boxes.max(1)), -1).round().astype(np.int32) + else: + abs_boxes[:, [0, 2]] *= w + abs_boxes[:, [1, 3]] *= h + abs_boxes = abs_boxes.round().astype(np.int32) + polys = np.stack( + [ + abs_boxes[:, [0, 1]], + abs_boxes[:, [0, 3]], + abs_boxes[:, [2, 3]], + abs_boxes[:, [2, 1]], + ], + axis=1, + ) + boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1]) + + for box, box_size, poly in zip(abs_boxes, boxes_size, polys): + # Mask boxes that are too small + if box_size < self.min_size_box: + # seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False + continue + + # Negative shrink for gt, as described in paper + polygon = Polygon(poly) + distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length + subject = [tuple(coor) for coor in poly] + padding = pyclipper.PyclipperOffset() + padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + shrinked = padding.Execute(-distance) + + # Draw polygon on gt if it is valid + if len(shrinked) == 0: + # seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, class_idx] = False + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False + continue + shrinked = np.array(shrinked[0]).reshape(-1, 2) + if shrinked.shape[0] <= 2 or not Polygon(shrinked).is_valid: + # seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1, 
class_idx] = False + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False + continue + cv2.fillPoly(seg_target[idx, class_idx], [shrinked.astype(np.int32)], 1) + + # Draw on both thresh map and thresh mask + poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map( + poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] + ) + if channels_last: + seg_target = seg_target.transpose((0, 2, 3, 1)) + seg_mask = seg_mask.transpose((0, 2, 3, 1)) + thresh_target = thresh_target.transpose((0, 2, 3, 1)) + thresh_mask = thresh_mask.transpose((0, 2, 3, 1)) thresh_target = thresh_target.astype(input_dtype) * (self.thresh_max - self.thresh_min) + self.thresh_min diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index c12c191588..fa1a02e4f2 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -13,6 +13,8 @@ from torchvision.models._utils import IntermediateLayerGetter from torchvision.ops.deform_conv import DeformConv2d +from doctr.file_utils import CLASS_NAME + from ...classification import mobilenet_v3_large from ...utils import load_pretrained_params from .base import DBPostProcessor, _DBNet @@ -108,10 +110,10 @@ class DBNet(_DBNet, nn.Module): feature extractor: the backbone serving as feature extractor head_chans: the number of channels in the head deform_conv: whether to use deformable convolution - num_classes: number of output channels in the segmentation map assume_straight_pages: if True, fit straight bounding boxes only exportable: onnx exportable returns only logits cfg: the configuration dict of the model + class_names: list of class names """ def __init__( @@ -119,14 +121,16 @@ def __init__( feat_extractor: IntermediateLayerGetter, head_chans: int = 256, deform_conv: bool = False, - num_classes: int = 1, bin_thresh: float = 0.3, assume_straight_pages: bool = True, exportable: bool = False, cfg: Optional[Dict[str, Any]] = None, + class_names: List[str] = [CLASS_NAME], ) -> None: super().__init__() + self.class_names = class_names + num_classes: int = len(self.class_names) self.cfg = cfg conv_layer = DeformConv2d if deform_conv else nn.Conv2d @@ -209,7 +213,8 @@ def forward( if target is None or return_preds: # Post-process boxes (keep only text predictions) out["preds"] = [ - preds[0] for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy()) + dict(zip(self.class_names, preds)) + for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy()) ] if target is not None: @@ -232,10 +237,10 @@ def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target: A loss tensor """ - prob_map = torch.sigmoid(out_map.squeeze(1)) - thresh_map = torch.sigmoid(thresh_map.squeeze(1)) + prob_map = torch.sigmoid(out_map) + thresh_map = torch.sigmoid(thresh_map) - targets = self.build_target(target, prob_map.shape) # type: ignore[arg-type] + targets = self.build_target(target, prob_map.shape, False) # type: ignore[arg-type] seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1]) seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device) @@ -248,7 +253,11 @@ def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target: dice_loss = torch.zeros(1, device=out_map.device) l1_loss = torch.zeros(1, device=out_map.device) if torch.any(seg_mask): - 
bce_loss = F.binary_cross_entropy_with_logits(out_map.squeeze(1), seg_target, reduction="none")[seg_mask] + bce_loss = F.binary_cross_entropy_with_logits( + out_map, + seg_target, + reduction="none", + )[seg_mask] neg_target = 1 - seg_target[seg_mask] positive_count = seg_target[seg_mask].sum() @@ -298,6 +307,10 @@ def _dbnet( {layer_name: str(idx) for idx, layer_name in enumerate(fpn_layers)}, ) + if not kwargs.get("class_names", None): + kwargs["class_names"] = default_cfgs[arch].get("class_names", [CLASS_NAME]) + else: + kwargs["class_names"] = sorted(kwargs["class_names"]) # Build the model model = DBNet(feat_extractor, cfg=default_cfgs[arch], **kwargs) # Load pretrained parameters diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index ee8c3ea01f..d78c1329ac 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -14,6 +14,7 @@ from tensorflow.keras import layers from tensorflow.keras.applications import ResNet50 +from doctr.file_utils import CLASS_NAME from doctr.models.utils import IntermediateLayerGetter, conv_sequence, load_pretrained_params from doctr.utils.repr import NestedObject @@ -109,10 +110,10 @@ class DBNet(_DBNet, keras.Model, NestedObject): Args: feature extractor: the backbone serving as feature extractor fpn_channels: number of channels each extracted feature maps is mapped to - num_classes: number of output channels in the segmentation map assume_straight_pages: if True, fit straight bounding boxes only exportable: onnx exportable returns only logits cfg: the configuration dict of the model + class_names: list of class names """ _children_names: List[str] = ["feat_extractor", "fpn", "probability_head", "threshold_head", "postprocessor"] @@ -121,14 +122,16 @@ def __init__( self, feature_extractor: IntermediateLayerGetter, fpn_channels: int = 128, # to be set to 256 to represent the author's initial idea - num_classes: int = 1, bin_thresh: float = 0.3, assume_straight_pages: bool = True, exportable: bool = False, cfg: Optional[Dict[str, Any]] = None, + class_names: List[str] = [CLASS_NAME], ) -> None: super().__init__() + self.class_names = class_names + num_classes: int = len(self.class_names) self.cfg = cfg self.feat_extractor = feature_extractor @@ -161,7 +164,12 @@ def __init__( self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh) - def compute_loss(self, out_map: tf.Tensor, thresh_map: tf.Tensor, target: List[np.ndarray]) -> tf.Tensor: + def compute_loss( + self, + out_map: tf.Tensor, + thresh_map: tf.Tensor, + target: List[Dict[str, np.ndarray]], + ) -> tf.Tensor: """Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes and a list of masks for each image. 
From there it computes the loss with the model output @@ -174,10 +182,10 @@ def compute_loss(self, out_map: tf.Tensor, thresh_map: tf.Tensor, target: List[n A loss tensor """ - prob_map = tf.math.sigmoid(tf.squeeze(out_map, axis=[-1])) - thresh_map = tf.math.sigmoid(tf.squeeze(thresh_map, axis=[-1])) + prob_map = tf.math.sigmoid(out_map) + thresh_map = tf.math.sigmoid(thresh_map) - seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape[:3]) + seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape, True) seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype) seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool) thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype) @@ -185,7 +193,11 @@ def compute_loss(self, out_map: tf.Tensor, thresh_map: tf.Tensor, target: List[n # Compute balanced BCE loss for proba_map bce_scale = 5.0 - bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map, from_logits=True)[seg_mask] + bce_loss = tf.keras.losses.binary_crossentropy( + seg_target[..., None], + out_map[..., None], + from_logits=True, + )[seg_mask] neg_target = 1 - seg_target[seg_mask] positive_count = tf.math.reduce_sum(seg_target[seg_mask]) @@ -216,7 +228,7 @@ def compute_loss(self, out_map: tf.Tensor, thresh_map: tf.Tensor, target: List[n def call( self, x: tf.Tensor, - target: Optional[List[np.ndarray]] = None, + target: Optional[List[Dict[str, np.ndarray]]] = None, return_model_output: bool = False, return_preds: bool = False, **kwargs: Any, @@ -239,7 +251,7 @@ def call( if target is None or return_preds: # Post-process boxes (keep only text predictions) - out["preds"] = [preds[0] for preds in self.postprocessor(prob_map.numpy())] + out["preds"] = [dict(zip(self.class_names, preds)) for preds in self.postprocessor(prob_map.numpy())] if target is not None: thresh_map = self.threshold_head(feat_concat, **kwargs) @@ -264,6 +276,10 @@ def _db_resnet( # Patch the config _cfg = deepcopy(default_cfgs[arch]) _cfg["input_shape"] = input_shape or _cfg["input_shape"] + if not kwargs.get("class_names", None): + kwargs["class_names"] = _cfg.get("class_names", [CLASS_NAME]) + else: + kwargs["class_names"] = sorted(kwargs["class_names"]) # Feature extractor feat_extractor = IntermediateLayerGetter( diff --git a/doctr/models/detection/linknet/base.py b/doctr/models/detection/linknet/base.py index c48fa4f600..7fdf852345 100644 --- a/doctr/models/detection/linknet/base.py +++ b/doctr/models/detection/linknet/base.py @@ -5,14 +5,13 @@ # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization -from typing import List, Tuple, Union +from typing import Dict, List, Tuple, Union import cv2 import numpy as np import pyclipper from shapely.geometry import Polygon -from doctr.file_utils import is_tf_available from doctr.models.core import BaseModel from ..core import DetectionPostProcessor @@ -156,78 +155,95 @@ class _LinkNet(BaseModel): def build_target( self, - target: List[np.ndarray], - output_shape: Tuple[int, int], + target: List[Dict[str, np.ndarray]], + output_shape: Tuple[int, int, int], + channels_last: bool = True, ) -> Tuple[np.ndarray, np.ndarray]: + """Build the target, and it's mask to be used from loss computation. 
- if any(t.dtype != np.float32 for t in target): + Args: + target: target coming from dataset + output_shape: shape of the output of the model without batch_size + channels_last: whether channels are last or not + + Returns: + the new formatted target and the mask + """ + + if any(t.dtype != np.float32 for tgt in target for t in tgt.values()): raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.") - if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for t in target): + if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()): raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.") - h, w = output_shape - target_shape = (len(target), h, w, 1) + h: int + w: int + if channels_last: + h, w, num_classes = output_shape + else: + num_classes, h, w = output_shape + target_shape = (len(target), num_classes, h, w) seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8) seg_mask: np.ndarray = np.ones(target_shape, dtype=bool) - for idx, _target in enumerate(target): - # Draw each polygon on gt - if _target.shape[0] == 0: - # Empty image, full masked - seg_mask[idx] = False - - # Absolute bounding boxes - abs_boxes = _target.copy() - - if abs_boxes.ndim == 3: - abs_boxes[:, :, 0] *= w - abs_boxes[:, :, 1] *= h - polys = abs_boxes - boxes_size = np.linalg.norm(abs_boxes[:, 2, :] - abs_boxes[:, 0, :], axis=-1) - abs_boxes = np.concatenate((abs_boxes.min(1), abs_boxes.max(1)), -1).round().astype(np.int32) - else: - abs_boxes[:, [0, 2]] *= w - abs_boxes[:, [1, 3]] *= h - abs_boxes = abs_boxes.round().astype(np.int32) - polys = np.stack( - [ - abs_boxes[:, [0, 1]], - abs_boxes[:, [0, 3]], - abs_boxes[:, [2, 3]], - abs_boxes[:, [2, 1]], - ], - axis=1, - ) - boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1]) - - for poly, box, box_size in zip(polys, abs_boxes, boxes_size): - # Mask boxes that are too small - if box_size < self.min_size_box: - seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False - continue - - # Negative shrink for gt, as described in paper - polygon = Polygon(poly) - distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length - subject = [tuple(coor) for coor in poly] - padding = pyclipper.PyclipperOffset() - padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) - shrunken = padding.Execute(-distance) - - # Draw polygon on gt if it is valid - if len(shrunken) == 0: - seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False - continue - shrunken = np.array(shrunken[0]).reshape(-1, 2) - if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid: - seg_mask[idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False - continue - cv2.fillPoly(seg_target[idx], [shrunken.astype(np.int32)], 1) - - # Don't forget to switch back to channel first if PyTorch is used - if not is_tf_available(): - seg_target = seg_target.transpose(0, 3, 1, 2) - seg_mask = seg_mask.transpose(0, 3, 1, 2) + for idx, tgt in enumerate(target): + for class_idx, _tgt in enumerate(tgt.values()): + # Draw each polygon on gt + if _tgt.shape[0] == 0: + # Empty image, full masked + seg_mask[idx, class_idx] = False + + # Absolute bounding boxes + abs_boxes = _tgt.copy() + + if abs_boxes.ndim == 3: + abs_boxes[:, :, 0] *= w + abs_boxes[:, :, 1] *= h + polys = abs_boxes + boxes_size = np.linalg.norm(abs_boxes[:, 2, :] - abs_boxes[:, 0, :], axis=-1) + abs_boxes = np.concatenate((abs_boxes.min(1), abs_boxes.max(1)), 
-1).round().astype(np.int32) + else: + abs_boxes[:, [0, 2]] *= w + abs_boxes[:, [1, 3]] *= h + abs_boxes = abs_boxes.round().astype(np.int32) + polys = np.stack( + [ + abs_boxes[:, [0, 1]], + abs_boxes[:, [0, 3]], + abs_boxes[:, [2, 3]], + abs_boxes[:, [2, 1]], + ], + axis=1, + ) + boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1]) + + for poly, box, box_size in zip(polys, abs_boxes, boxes_size): + # Mask boxes that are too small + if box_size < self.min_size_box: + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False + continue + + # Negative shrink for gt, as described in paper + polygon = Polygon(poly) + distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length + subject = [tuple(coor) for coor in poly] + padding = pyclipper.PyclipperOffset() + padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + shrunken = padding.Execute(-distance) + + # Draw polygon on gt if it is valid + if len(shrunken) == 0: + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False + continue + shrunken = np.array(shrunken[0]).reshape(-1, 2) + if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid: + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False + continue + cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1) + + # Don't forget to switch back to channel last if Tensorflow is used + if channels_last: + seg_target = seg_target.transpose((0, 2, 3, 1)) + seg_mask = seg_mask.transpose((0, 2, 3, 1)) return seg_target, seg_mask diff --git a/doctr/models/detection/linknet/pytorch.py b/doctr/models/detection/linknet/pytorch.py index 34c493fe2e..70ce32c77c 100644 --- a/doctr/models/detection/linknet/pytorch.py +++ b/doctr/models/detection/linknet/pytorch.py @@ -11,6 +11,7 @@ from torch.nn import functional as F from torchvision.models._utils import IntermediateLayerGetter +from doctr.file_utils import CLASS_NAME from doctr.models.classification import resnet18, resnet34, resnet50 from ...utils import load_pretrained_params @@ -91,25 +92,27 @@ class LinkNet(nn.Module, _LinkNet): Args: feature extractor: the backbone serving as feature extractor - num_classes: number of output channels in the segmentation map head_chans: number of channels in the head layers assume_straight_pages: if True, fit straight bounding boxes only exportable: onnx exportable returns only logits cfg: the configuration dict of the model + class_names: list of class names """ def __init__( self, feat_extractor: IntermediateLayerGetter, - num_classes: int = 1, bin_thresh: float = 0.1, head_chans: int = 32, assume_straight_pages: bool = True, exportable: bool = False, cfg: Optional[Dict[str, Any]] = None, + class_names: List[str] = [CLASS_NAME], ) -> None: super().__init__() + self.class_names = class_names + num_classes: int = len(self.class_names) self.cfg = cfg self.exportable = exportable self.assume_straight_pages = assume_straight_pages @@ -182,7 +185,8 @@ def forward( if target is None or return_preds: # Post-process boxes out["preds"] = [ - preds[0] for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy()) + dict(zip(self.class_names, preds)) + for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy()) ] if target is not None: @@ -203,21 +207,22 @@ def compute_loss( `_. 
Args: - out_map: output feature map of the model of shape (N, 1, H, W) + out_map: output feature map of the model of shape (N, num_classes, H, W) target: list of dictionary where each dict has a `boxes` and a `flags` entry gamma: modulating factor in the focal loss formula alpha: balancing factor in the focal loss formula + eps: epsilon factor in dice loss Returns: A loss tensor """ - _target, _mask = self.build_target(target, out_map.shape[-2:]) # type: ignore[arg-type] + _target, _mask = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type] seg_target, seg_mask = torch.from_numpy(_target).to(dtype=out_map.dtype), torch.from_numpy(_mask) seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device) seg_mask = seg_mask.to(dtype=torch.float32) - bce_loss = bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none") + bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none") proba_map = torch.sigmoid(out_map) # Focal loss @@ -228,15 +233,15 @@ def compute_loss( # Unreduced version focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss # Class reduced - focal_loss = (seg_mask * focal_loss).sum((0, 2, 3)) / seg_mask.sum((0, 2, 3)) + focal_loss = (seg_mask * focal_loss).sum((0, 1, 2, 3)) / seg_mask.sum((0, 1, 2, 3)) # Dice loss - inter = (seg_mask * proba_map * seg_target).sum((0, 2, 3)) - cardinality = (seg_mask * (proba_map + seg_target)).sum((0, 2, 3)) + inter = (seg_mask * proba_map * seg_target).sum((0, 1, 2, 3)) + cardinality = (seg_mask * (proba_map + seg_target)).sum((0, 1, 2, 3)) dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps) # Return the full loss (equal sum of focal loss and dice loss) - return focal_loss.mean() + dice_loss.mean() + return focal_loss + dice_loss def _linknet( @@ -256,6 +261,10 @@ def _linknet( backbone, {layer_name: str(idx) for idx, layer_name in enumerate(fpn_layers)}, ) + if not kwargs.get("class_names", None): + kwargs["class_names"] = default_cfgs[arch].get("class_names", [CLASS_NAME]) + else: + kwargs["class_names"] = sorted(kwargs["class_names"]) # Build the model model = LinkNet(feat_extractor, cfg=default_cfgs[arch], **kwargs) diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py index 1b35ef2b1e..cb1b19fa1c 100644 --- a/doctr/models/detection/linknet/tensorflow.py +++ b/doctr/models/detection/linknet/tensorflow.py @@ -13,6 +13,7 @@ from tensorflow import keras from tensorflow.keras import Model, Sequential, layers +from doctr.file_utils import CLASS_NAME from doctr.models.classification import resnet18, resnet34, resnet50 from doctr.models.utils import IntermediateLayerGetter, conv_sequence, load_pretrained_params from doctr.utils.repr import NestedObject @@ -21,7 +22,6 @@ __all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50", "linknet_resnet18_rotation"] - default_cfgs: Dict[str, Dict[str, Any]] = { "linknet_resnet18": { "mean": (0.798, 0.785, 0.772), @@ -79,7 +79,6 @@ def __init__( out_chans: int, in_shapes: List[Tuple[int, ...]], ) -> None: - super().__init__() self.out_chans = out_chans strides = [2] * (len(in_shapes) - 1) + [1] @@ -107,10 +106,10 @@ class LinkNet(_LinkNet, keras.Model): Args: feature extractor: the backbone serving as feature extractor fpn_channels: number of channels each extracted feature maps is mapped to - num_classes: number of output channels in the segmentation map assume_straight_pages: if True, fit straight bounding boxes only exportable: onnx 
exportable returns only logits cfg: the configuration dict of the model + class_names: list of class names """ _children_names: List[str] = ["feat_extractor", "fpn", "classifier", "postprocessor"] @@ -119,14 +118,17 @@ def __init__( self, feat_extractor: IntermediateLayerGetter, fpn_channels: int = 64, - num_classes: int = 1, bin_thresh: float = 0.1, assume_straight_pages: bool = True, exportable: bool = False, cfg: Optional[Dict[str, Any]] = None, + class_names: List[str] = [CLASS_NAME], ) -> None: super().__init__(cfg=cfg) + self.class_names = class_names + num_classes: int = len(self.class_names) + self.exportable = exportable self.assume_straight_pages = assume_straight_pages @@ -165,7 +167,7 @@ def __init__( def compute_loss( self, out_map: tf.Tensor, - target: List[np.ndarray], + target: List[Dict[str, np.ndarray]], gamma: float = 2.0, alpha: float = 0.5, eps: float = 1e-8, @@ -178,17 +180,17 @@ def compute_loss( target: list of dictionary where each dict has a `boxes` and a `flags` entry gamma: modulating factor in the focal loss formula alpha: balancing factor in the focal loss formula + eps: epsilon factor in dice loss Returns: A loss tensor """ - seg_target, seg_mask = self.build_target(target, out_map.shape[1:3]) - + seg_target, seg_mask = self.build_target(target, out_map.shape[1:], True) seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype) seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool) seg_mask = tf.cast(seg_mask, tf.float32) - bce_loss = tf.keras.losses.binary_crossentropy(seg_target, out_map, from_logits=True)[..., None] + bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True) proba_map = tf.sigmoid(out_map) # Focal loss @@ -200,19 +202,19 @@ def compute_loss( # Unreduced loss focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss # Class reduced - focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2)) / tf.reduce_sum(seg_mask, (0, 1, 2)) + focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / tf.reduce_sum(seg_mask, (0, 1, 2, 3)) # Dice loss - inter = tf.math.reduce_sum(seg_mask * proba_map * seg_target, (0, 1, 2)) - cardinality = tf.math.reduce_sum(seg_mask * (proba_map + seg_target), (0, 1, 2)) + inter = tf.math.reduce_sum(seg_mask * proba_map * seg_target, (0, 1, 2, 3)) + cardinality = tf.math.reduce_sum(seg_mask * (proba_map + seg_target), (0, 1, 2, 3)) dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps) - return tf.reduce_mean(focal_loss) + tf.reduce_mean(dice_loss) + return focal_loss + dice_loss def call( self, x: tf.Tensor, - target: Optional[List[np.ndarray]] = None, + target: Optional[List[Dict[str, np.ndarray]]] = None, return_model_output: bool = False, return_preds: bool = False, **kwargs: Any, @@ -234,7 +236,7 @@ def call( if target is None or return_preds: # Post-process boxes - out["preds"] = [preds[0] for preds in self.postprocessor(prob_map.numpy())] + out["preds"] = [dict(zip(self.class_names, preds)) for preds in self.postprocessor(prob_map.numpy())] if target is not None: loss = self.compute_loss(logits, target) @@ -252,12 +254,15 @@ def _linknet( input_shape: Optional[Tuple[int, int, int]] = None, **kwargs: Any, ) -> LinkNet: - pretrained_backbone = pretrained_backbone and not pretrained # Patch the config _cfg = deepcopy(default_cfgs[arch]) _cfg["input_shape"] = input_shape or default_cfgs[arch]["input_shape"] + if not kwargs.get("class_names", None): + kwargs["class_names"] = _cfg.get("class_names", [CLASS_NAME]) + else: + kwargs["class_names"] = 
sorted(kwargs["class_names"]) # Feature extractor feat_extractor = IntermediateLayerGetter( diff --git a/doctr/models/detection/predictor/tensorflow.py b/doctr/models/detection/predictor/tensorflow.py index 9a6fd89ca5..c4735dc367 100644 --- a/doctr/models/detection/predictor/tensorflow.py +++ b/doctr/models/detection/predictor/tensorflow.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import Any, List, Union +from typing import Any, Dict, List, Union import numpy as np import tensorflow as tf @@ -38,7 +38,7 @@ def __call__( self, pages: List[Union[np.ndarray, tf.Tensor]], **kwargs: Any, - ) -> List[np.ndarray]: + ) -> List[Dict[str, np.ndarray]]: # Dimension check if any(page.ndim != 3 for page in pages): diff --git a/doctr/models/kie_predictor/__init__.py b/doctr/models/kie_predictor/__init__.py new file mode 100644 index 0000000000..ff30c3b2e7 --- /dev/null +++ b/doctr/models/kie_predictor/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available + +if is_tf_available(): + from .tensorflow import * +else: + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/kie_predictor/base.py b/doctr/models/kie_predictor/base.py new file mode 100644 index 0000000000..f92b39f15a --- /dev/null +++ b/doctr/models/kie_predictor/base.py @@ -0,0 +1,42 @@ +# Copyright (C) 2022, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any, Optional + +from doctr.models.builder import KIEDocumentBuilder + +from ..classification.predictor import CropOrientationPredictor + +from ..predictor.base import _OCRPredictor + +__all__ = ["_KIEPredictor"] + + +class _KIEPredictor(_OCRPredictor): + """Implements an object able to localize and identify text elements in a set of documents + + Args: + assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages + without rotated textual elements. + straighten_pages: if True, estimates the page general orientation based on the median line orientation. + Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped + accordingly. Doing so will improve performance for documents with page-uniform rotations. + preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding) + symmetric_pad: if True and preserve_aspect_ratio is True, pad the image symmetrically. + kwargs: keyword args of `DocumentBuilder` + """ + + crop_orientation_predictor: Optional[CropOrientationPredictor] + + def __init__( + self, + assume_straight_pages: bool = True, + straighten_pages: bool = False, + preserve_aspect_ratio: bool = True, + symmetric_pad: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs) + + self.doc_builder: KIEDocumentBuilder = KIEDocumentBuilder(**kwargs) diff --git a/doctr/models/kie_predictor/pytorch.py b/doctr/models/kie_predictor/pytorch.py new file mode 100644 index 0000000000..5e4b0cfd44 --- /dev/null +++ b/doctr/models/kie_predictor/pytorch.py @@ -0,0 +1,175 @@ +# Copyright (C) 2022, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
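With detection predictors now typed as `List[Dict[str, np.ndarray]]` (one dict per page, mapping each class name to its boxes), downstream consumers iterate over classes before boxes. Below is a minimal sketch of that contract; the class names are hypothetical, and the `(N, 5)` row layout `[xmin, ymin, xmax, ymax, confidence]` follows the detection tests further down in this diff:

```python
import numpy as np

# One dict per page: class name -> (N, 5) array of relative boxes + confidence score
loc_preds = [
    {
        "class_name_1": np.array([[0.1, 0.1, 0.3, 0.2, 0.9]], dtype=np.float32),
        "class_name_2": np.array([[0.5, 0.5, 0.7, 0.6, 0.8]], dtype=np.float32),
    },
]
for page_preds in loc_preds:
    for class_name, boxes in page_preds.items():
        print(class_name, boxes[:, :-1])  # drop the trailing confidence column
```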
+ + from typing import Any, Dict, List, Union + + import numpy as np + import torch + from torch import nn + + from doctr.io.elements import Document + from doctr.models._utils import estimate_orientation, get_language, invert_data_structure + from doctr.models.detection.predictor import DetectionPredictor + from doctr.models.recognition.predictor import RecognitionPredictor + from doctr.utils.geometry import rotate_boxes, rotate_image + + from .base import _KIEPredictor + + __all__ = ["KIEPredictor"] + + + class KIEPredictor(nn.Module, _KIEPredictor): + """Implements an object able to localize and identify text elements in a set of documents + + Args: + det_predictor: detection module + reco_predictor: recognition module + assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages + without rotated textual elements. + straighten_pages: if True, estimates the page general orientation based on the median line orientation. + Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped + accordingly. Doing so will improve performance for documents with page-uniform rotations. + detect_orientation: if True, the estimated general page orientation will be added to the predictions for each + page. Doing so will slightly deteriorate the overall latency. + detect_language: if True, the language prediction will be added to the predictions for each + page. Doing so will slightly deteriorate the overall latency. + kwargs: keyword args of `DocumentBuilder` + """ + + def __init__( + self, + det_predictor: DetectionPredictor, + reco_predictor: RecognitionPredictor, + assume_straight_pages: bool = True, + straighten_pages: bool = False, + preserve_aspect_ratio: bool = False, + symmetric_pad: bool = True, + detect_orientation: bool = False, + detect_language: bool = False, + **kwargs: Any, + ) -> None: + + nn.Module.__init__(self) + self.det_predictor = det_predictor.eval() # type: ignore[attr-defined] + self.reco_predictor = reco_predictor.eval() # type: ignore[attr-defined] + _KIEPredictor.__init__( + self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs + ) + self.detect_orientation = detect_orientation + self.detect_language = detect_language + + @torch.no_grad() + def forward( + self, + pages: List[Union[np.ndarray, torch.Tensor]], + **kwargs: Any, + ) -> Document: + + # Dimension check + if any(page.ndim != 3 for page in pages): + raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") + + origin_page_shapes = [page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:] for page in pages] + + # Detect document rotation and rotate pages + if self.detect_orientation: + origin_page_orientations = [estimate_orientation(page) for page in pages] # type: ignore[arg-type] + orientations = [ + {"value": orientation_page, "confidence": 1.0} for orientation_page in origin_page_orientations + ] + else: + orientations = None + if self.straighten_pages: + origin_page_orientations = ( + origin_page_orientations + if self.detect_orientation + else [estimate_orientation(page) for page in pages] # type: ignore[arg-type] + ) + pages = [ + rotate_image(page, -angle, expand=True) # type: ignore[arg-type] + for page, angle in zip(pages, origin_page_orientations) + ] + + # Localize text elements + loc_preds = self.det_predictor(pages, **kwargs) + dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore[assignment] + # Check 
whether crop mode should be switched to channels first + channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray) + + # Rectify crops if aspect ratio + dict_loc_preds = { + k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items() # type: ignore[arg-type] + } + + # Crop images + crops = {} + for class_name in dict_loc_preds.keys(): + crops[class_name], dict_loc_preds[class_name] = self._prepare_crops( + pages, # type: ignore[arg-type] + dict_loc_preds[class_name], + channels_last=channels_last, + assume_straight_pages=self.assume_straight_pages, + ) + # Rectify crop orientation + if not self.assume_straight_pages: + for class_name in dict_loc_preds.keys(): + crops[class_name], dict_loc_preds[class_name] = self._rectify_crops( + crops[class_name], dict_loc_preds[class_name] + ) + # Identify character sequences + word_preds = { + k: self.reco_predictor([crop for page_crops in crop_value for crop in page_crops], **kwargs) + for k, crop_value in crops.items() + } + + boxes: Dict = {} + text_preds: Dict = {} + for class_name in dict_loc_preds.keys(): + boxes[class_name], text_preds[class_name] = self._process_predictions( + dict_loc_preds[class_name], word_preds[class_name] + ) + + boxes_per_page: List[Dict] = invert_data_structure(boxes) # type: ignore[assignment] + text_preds_per_page: List[Dict] = invert_data_structure(text_preds) # type: ignore[assignment] + if self.detect_language: + languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page] + languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages] + else: + languages_dict = None + # Rotate back pages and boxes while keeping original image size + if self.straighten_pages: + boxes_per_page = [ + { + k: rotate_boxes( + page_boxes, + angle, + orig_shape=page.shape[:2] + if isinstance(page, np.ndarray) + else page.shape[1:], # type: ignore[arg-type] + target_shape=mask, # type: ignore[arg-type] + ) + for k, page_boxes in page_boxes_dict.items() + } + for page_boxes_dict, page, angle, mask in zip( + boxes_per_page, pages, origin_page_orientations, origin_page_shapes + ) + ] + + out = self.doc_builder( + boxes_per_page, + text_preds_per_page, + [page.shape[:2] if channels_last else page.shape[-2:] for page in pages], # type: ignore[misc] + orientations, + languages_dict, + ) + return out + + @staticmethod + def get_text(text_pred: Dict) -> str: + text = [] + for value in text_pred.values(): + text += [item[0] for item in value] + + return " ".join(text) diff --git a/doctr/models/kie_predictor/tensorflow.py b/doctr/models/kie_predictor/tensorflow.py new file mode 100644 index 0000000000..dc408bef03 --- /dev/null +++ b/doctr/models/kie_predictor/tensorflow.py @@ -0,0 +1,163 @@ +# Copyright (C) 2022, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
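The PyTorch predictor above (and the TensorFlow one that follows) leans on `invert_data_structure` to flip between a dict of per-page values and a list of per-page dicts. A short sketch of the round-trip, mirroring the behaviour pinned down by `test_convert_list_dict` later in this diff (toy values):

```python
from doctr.models._utils import invert_data_structure

per_class = {"k1": [[0], [0], [0]], "k2": [[1], [1], [1]]}  # key -> one value per page
per_page = invert_data_structure(per_class)
# -> [{"k1": [0], "k2": [1]}, {"k1": [0], "k2": [1]}, {"k1": [0], "k2": [1]}]
assert invert_data_structure(per_page) == per_class  # applying it twice round-trips
```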
+ + from typing import Any, Dict, List, Union + + import numpy as np + import tensorflow as tf + + from doctr.io.elements import Document + from doctr.models._utils import estimate_orientation, get_language, invert_data_structure + from doctr.models.detection.predictor import DetectionPredictor + from doctr.models.recognition.predictor import RecognitionPredictor + from doctr.utils.geometry import rotate_boxes, rotate_image + from doctr.utils.repr import NestedObject + + from .base import _KIEPredictor + + __all__ = ["KIEPredictor"] + + + class KIEPredictor(NestedObject, _KIEPredictor): + """Implements an object able to localize and identify text elements in a set of documents + + Args: + det_predictor: detection module + reco_predictor: recognition module + assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages + without rotated textual elements. + straighten_pages: if True, estimates the page general orientation based on the median line orientation. + Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped + accordingly. Doing so will improve performance for documents with page-uniform rotations. + detect_orientation: if True, the estimated general page orientation will be added to the predictions for each + page. Doing so will slightly deteriorate the overall latency. + detect_language: if True, the language prediction will be added to the predictions for each + page. Doing so will slightly deteriorate the overall latency. + kwargs: keyword args of `DocumentBuilder` + """ + + _children_names = ["det_predictor", "reco_predictor", "doc_builder"] + + def __init__( + self, + det_predictor: DetectionPredictor, + reco_predictor: RecognitionPredictor, + assume_straight_pages: bool = True, + straighten_pages: bool = False, + preserve_aspect_ratio: bool = False, + symmetric_pad: bool = True, + detect_orientation: bool = False, + detect_language: bool = False, + **kwargs: Any, + ) -> None: + + self.det_predictor = det_predictor + self.reco_predictor = reco_predictor + _KIEPredictor.__init__( + self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs + ) + self.detect_orientation = detect_orientation + self.detect_language = detect_language + + def __call__( + self, + pages: List[Union[np.ndarray, tf.Tensor]], + **kwargs: Any, + ) -> Document: + + # Dimension check + if any(page.ndim != 3 for page in pages): + raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") + + origin_page_shapes = [page.shape[:2] for page in pages] + + # Detect document rotation and rotate pages + if self.detect_orientation: + origin_page_orientations = [estimate_orientation(page) for page in pages] + orientations = [ + {"value": orientation_page, "confidence": 1.0} for orientation_page in origin_page_orientations + ] + else: + orientations = None + if self.straighten_pages: + origin_page_orientations = ( + origin_page_orientations if self.detect_orientation else [estimate_orientation(page) for page in pages] + ) + pages = [rotate_image(page, -angle, expand=True) for page, angle in zip(pages, origin_page_orientations)] + + # Localize text elements + loc_preds = self.det_predictor(pages, **kwargs) + + dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore[assignment] + # Rectify crops if aspect ratio + dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()} + + # Crop images + crops 
= {} + for class_name in dict_loc_preds.keys(): + crops[class_name], dict_loc_preds[class_name] = self._prepare_crops( + pages, dict_loc_preds[class_name], channels_last=True, assume_straight_pages=self.assume_straight_pages + ) + # Rectify crop orientation + if not self.assume_straight_pages: + for class_name in dict_loc_preds.keys(): + crops[class_name], dict_loc_preds[class_name] = self._rectify_crops( + crops[class_name], dict_loc_preds[class_name] + ) + + # Identify character sequences + word_preds = { + k: self.reco_predictor([crop for page_crops in crop_value for crop in page_crops], **kwargs) + for k, crop_value in crops.items() + } + + boxes: Dict = {} + text_preds: Dict = {} + for class_name in dict_loc_preds.keys(): + boxes[class_name], text_preds[class_name] = self._process_predictions( + dict_loc_preds[class_name], word_preds[class_name] + ) + + boxes_per_page: List[Dict] = invert_data_structure(boxes) # type: ignore[assignment] + text_preds_per_page: List[Dict] = invert_data_structure(text_preds) # type: ignore[assignment] + + if self.detect_language: + languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page] + languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages] + else: + languages_dict = None + # Rotate back pages and boxes while keeping original image size + if self.straighten_pages: + boxes_per_page = [ + { + k: rotate_boxes( + page_boxes, + angle, + orig_shape=page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:], + target_shape=mask, # type: ignore[arg-type] + ) + for k, page_boxes in page_boxes_dict.items() + } + for page_boxes_dict, page, angle, mask in zip( + boxes_per_page, pages, origin_page_orientations, origin_page_shapes + ) + ] + + out = self.doc_builder( + boxes_per_page, + text_preds_per_page, + origin_page_shapes, # type: ignore[arg-type] + orientations, + languages_dict, + ) + return out + + @staticmethod + def get_text(text_pred: Dict) -> str: + text = [] + for value in text_pred.values(): + text += [item[0] for item in value] + + return " ".join(text) diff --git a/doctr/models/predictor/pytorch.py b/doctr/models/predictor/pytorch.py index ef81b08807..f7c7710876 100644 --- a/doctr/models/predictor/pytorch.py +++ b/doctr/models/predictor/pytorch.py @@ -94,6 +94,11 @@ def forward( # Localize text elements loc_preds = self.det_predictor(pages, **kwargs) + assert all( + len(loc_pred) == 1 for loc_pred in loc_preds + ), "Detection Model in ocr_predictor should output only one class" + + loc_preds = [list(loc_pred.values())[0] for loc_pred in loc_preds] # Check whether crop mode should be switched to channels first channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray) diff --git a/doctr/models/predictor/tensorflow.py b/doctr/models/predictor/tensorflow.py index c63ae8e335..840fe165b8 100644 --- a/doctr/models/predictor/tensorflow.py +++ b/doctr/models/predictor/tensorflow.py @@ -88,7 +88,12 @@ def __call__( pages = [rotate_image(page, -angle, expand=True) for page, angle in zip(pages, origin_page_orientations)] # Localize text elements - loc_preds = self.det_predictor(pages, **kwargs) + loc_preds_dict = self.det_predictor(pages, **kwargs) + assert all( + len(loc_pred) == 1 for loc_pred in loc_preds_dict + ), "Detection Model in ocr_predictor should output only one class" + + loc_preds: List[np.ndarray] = [list(loc_pred.values())[0] for loc_pred in loc_preds_dict] # Rectify crops if aspect ratio loc_preds = self._remove_padding(pages, loc_preds) diff --git 
a/doctr/models/zoo.py b/doctr/models/zoo.py index a9621df4e3..e8a7936c95 100644 --- a/doctr/models/zoo.py +++ b/doctr/models/zoo.py @@ -6,10 +6,11 @@ from typing import Any from .detection.zoo import detection_predictor +from .kie_predictor import KIEPredictor from .predictor import OCRPredictor from .recognition.zoo import recognition_predictor -__all__ = ["ocr_predictor"] +__all__ = ["ocr_predictor", "kie_predictor"] def _predictor( @@ -113,3 +114,106 @@ def ocr_predictor( detect_language=detect_language, **kwargs, ) + + +def _kie_predictor( + det_arch: Any, + reco_arch: Any, + pretrained: bool, + pretrained_backbone: bool = True, + assume_straight_pages: bool = True, + preserve_aspect_ratio: bool = False, + symmetric_pad: bool = True, + det_bs: int = 2, + reco_bs: int = 128, + detect_orientation: bool = False, + detect_language: bool = False, + **kwargs, +) -> KIEPredictor: + + # Detection + det_predictor = detection_predictor( + det_arch, + pretrained=pretrained, + pretrained_backbone=pretrained_backbone, + batch_size=det_bs, + assume_straight_pages=assume_straight_pages, + preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, + ) + + # Recognition + reco_predictor = recognition_predictor( + reco_arch, pretrained=pretrained, pretrained_backbone=pretrained_backbone, batch_size=reco_bs + ) + + return KIEPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=assume_straight_pages, + preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, + detect_orientation=detect_orientation, + detect_language=detect_language, + **kwargs, + ) + + +def kie_predictor( + det_arch: Any = "db_resnet50", + reco_arch: Any = "crnn_vgg16_bn", + pretrained: bool = False, + pretrained_backbone: bool = True, + assume_straight_pages: bool = True, + preserve_aspect_ratio: bool = False, + symmetric_pad: bool = True, + export_as_straight_boxes: bool = False, + detect_orientation: bool = False, + detect_language: bool = False, + **kwargs: Any, +) -> KIEPredictor: + """End-to-end KIE architecture using one model for localization, and another for text recognition. + + >>> import numpy as np + >>> from doctr.models import kie_predictor + >>> model = kie_predictor('db_resnet50', 'crnn_vgg16_bn', pretrained=True) + >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8) + >>> out = model([input_page]) + + Args: + det_arch: name of the detection architecture or the model itself to use + (e.g. 'db_resnet50', 'db_mobilenet_v3_large') + reco_arch: name of the recognition architecture or the model itself to use + (e.g. 'crnn_vgg16_bn', 'sar_resnet31') + pretrained: If True, returns a model pre-trained on our OCR dataset + pretrained_backbone: If True, returns a model with a pretrained backbone + assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages + without rotated textual elements. + preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before + running the detection model on it. + symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right. + export_as_straight_boxes: when assume_straight_pages is set to False, export final predictions + (potentially rotated) as straight bounding boxes. + detect_orientation: if True, the estimated general page orientation will be added to the predictions for each + page. Doing so will slightly deteriorate the overall latency. 
+ detect_language: if True, the language prediction will be added to the predictions for each + page. Doing so will slightly deteriorate the overall latency. + kwargs: keyword args of `KIEPredictor` + + Returns: + KIE predictor + """ + + return _kie_predictor( + det_arch, + reco_arch, + pretrained, + pretrained_backbone=pretrained_backbone, + assume_straight_pages=assume_straight_pages, + preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, + export_as_straight_boxes=export_as_straight_boxes, + detect_orientation=detect_orientation, + detect_language=detect_language, + **kwargs, + ) diff --git a/doctr/utils/visualization.py b/doctr/utils/visualization.py index f6ee075ae8..dbc8792f45 100644 --- a/doctr/utils/visualization.py +++ b/doctr/utils/visualization.py @@ -2,7 +2,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. - +import colorsys from copy import deepcopy from typing import Any, Dict, List, Optional, Tuple, Union @@ -18,7 +18,7 @@ from .common_types import BoundingBox, Polygon4P from .fonts import get_font -__all__ = ["visualize_page", "synthesize_page", "draw_boxes"] +__all__ = ["visualize_page", "synthesize_page", "visualize_kie_page", "synthesize_kie_page", "draw_boxes"] def rect_patch( @@ -139,6 +139,24 @@ def create_obj_patch( raise ValueError("invalid geometry format") +def get_colors(num_colors: int) -> List[Tuple[float, float, float]]: + """Generate num_colors colors for matplotlib + + Args: + num_colors: number of colors to generate + + Returns: + colors: list of generated colors + """ + colors = [] + for i in np.arange(0.0, 360.0, 360.0 / num_colors): + hue = i / 360.0 + lightness = (50 + np.random.rand() * 10) / 100.0 + saturation = (90 + np.random.rand() * 10) / 100.0 + colors.append(colorsys.hls_to_rgb(hue, lightness, saturation)) + return colors + + def visualize_page( page: Dict[str, Any], image: np.ndarray, @@ -313,6 +331,130 @@ def synthesize_page( return response +def visualize_kie_page( + page: Dict[str, Any], + image: np.ndarray, + words_only: bool = False, + display_artefacts: bool = True, + scale: float = 10, + interactive: bool = True, + add_labels: bool = True, + **kwargs: Any, +) -> Figure: + """Visualize a full page with KIE predictions + + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from doctr.utils.visualization import visualize_kie_page + >>> from doctr.models import kie_predictor + >>> model = kie_predictor(pretrained=True) + >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8) + >>> out = model([input_page]) + >>> visualize_kie_page(out.pages[0].export(), input_page) + >>> plt.show() + + Args: + page: the exported KIEPage of a KIEDocument + image: np array of the page, needs to have the same shape as page['dimensions'] + words_only: whether only words should be displayed + display_artefacts: whether artefacts should be displayed + scale: figsize of the largest windows side + interactive: whether the plot should be interactive + add_labels: for static plot, adds text labels on top of bounding box + """ + # Get proper scale and aspect ratio + h, w = image.shape[:2] + size = (scale * w / h, scale) if h > w else (scale, h / w * scale) + fig, ax = plt.subplots(figsize=size) + # Display the image + ax.imshow(image) + # hide both axis + ax.axis("off") + + if interactive: + artists: List[patches.Patch] = [] # instantiate an empty list of patches (to be drawn on the page) + + colors = {k: color for color, k in 
zip(get_colors(len(page["predictions"])), page["predictions"])} + for key, value in page["predictions"].items(): + for prediction in value: + if not words_only: + rect = create_obj_patch( + prediction["geometry"], + page["dimensions"], + label=f"{key} \n {prediction['value']} (confidence: {prediction['confidence']:.2%})", + color=colors[key], + linewidth=1, + **kwargs, + ) + # add patch on figure + ax.add_patch(rect) + if interactive: + # add patch to cursor's artists + artists.append(rect) + + if interactive: + # Create mpl Cursor to hover patches in artists + mplcursors.Cursor(artists, hover=2).connect("add", lambda sel: sel.annotation.set_text(sel.artist.get_label())) + fig.tight_layout(pad=0.0) + + return fig + + +def synthesize_kie_page( + page: Dict[str, Any], + draw_proba: bool = False, + font_family: Optional[str] = None, +) -> np.ndarray: + """Draw the content of the element page (KIE response) on a blank page. + + Args: + page: exported KIEPage object to represent + draw_proba: if True, draw words in colors to represent confidence. Blue: p=1, red: p=0 + font_family: family of the font + + Return: + the synthesized page + """ + + # Draw template + h, w = page["dimensions"] + response = 255 * np.ones((h, w, 3), dtype=np.int32) + + # Draw each word + for predictions in page["predictions"].values(): + for prediction in predictions: + # Get absolute word geometry + (xmin, ymin), (xmax, ymax) = prediction["geometry"] + xmin, xmax = int(round(w * xmin)), int(round(w * xmax)) + ymin, ymax = int(round(h * ymin)), int(round(h * ymax)) + + # White drawing context adapted to font size, 0.75 factor to convert pts --> pix + font = get_font(font_family, int(0.75 * (ymax - ymin))) + img = Image.new("RGB", (xmax - xmin, ymax - ymin), color=(255, 255, 255)) + d = ImageDraw.Draw(img) + # Draw in black the value of the word + try: + d.text((0, 0), prediction["value"], font=font, fill=(0, 0, 0)) + except UnicodeEncodeError: + # When character cannot be encoded, use its unidecode version + d.text((0, 0), unidecode(prediction["value"]), font=font, fill=(0, 0, 0)) + + # Colorize if draw_proba + if draw_proba: + p = int(255 * prediction["confidence"]) + mask = np.where(np.array(img) == 0, 1, 0) + proba: np.ndarray = np.array([255 - p, 0, p]) + color = mask * proba[np.newaxis, np.newaxis, :] + white_mask = 255 * (1 - mask) + img = color + white_mask + + # Write to response page + response[ymin:ymax, xmin:xmax, :] = np.array(img) + + return response + + def draw_boxes(boxes: np.ndarray, image: np.ndarray, color: Optional[Tuple[int, int, int]] = None, **kwargs) -> None: """Draw an array of relative straight boxes on an image diff --git a/references/detection/README.md b/references/detection/README.md index 9ae20005e6..e2ff24dfbc 100644 --- a/references/detection/README.md +++ b/references/detection/README.md @@ -57,7 +57,30 @@ labels.json ... } ``` +If you want to train a model with multiple classes, you can use the following format, where `polygons` is a dictionary in which each key is a class name and its value holds all the polygons for that class. +labels.json +```shell +{ + "sample_img_01.png": { + 'img_dimensions': (900, 600), + 'img_hash': "theimagedumpmyhash", + 'polygons': { + "class_name_1": [[[x10, y10], [x20, y20], [x30, y30], [x40, y40]], ...], + "class_name_2": [[[x11, y11], [x21, y21], [x31, y31], [x41, y41]], ...] 
+ } + }, + "sample_img_02.png": { + 'img_dimensions': (900, 600), + 'img_hash': "thisisahash", + 'polygons': { + "class_name_1": [[[x12, y12], [x22, y22], [x32, y32], [x42, y42]], ...], + "class_name_2": [[[x13, y13], [x23, y23], [x33, y33], [x43, y43]], ...] + } + }, + ... +} +``` ## Advanced options Feel free to inspect the multiple script options to customize your training to your own needs! diff --git a/references/detection/evaluate_pytorch.py b/references/detection/evaluate_pytorch.py index bba36987b0..8826ad0579 100644 --- a/references/detection/evaluate_pytorch.py +++ b/references/detection/evaluate_pytorch.py @@ -5,6 +5,8 @@ import os +from doctr.file_utils import CLASS_NAME + os.environ["USE_TORCH"] = "1" import logging @@ -36,7 +38,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): if torch.cuda.is_available(): images = images.cuda() images = batch_transforms(images) - targets = [t["boxes"] for t in targets] + targets = [{CLASS_NAME: t["boxes"]} for t in targets] if amp: with torch.cuda.amp.autocast(): out = model(images, targets, return_preds=True) @@ -44,9 +46,10 @@ out = model(images, targets, return_preds=True) # Compute metric loc_preds = out["preds"] - for boxes_gt, boxes_pred in zip(targets, loc_preds): - # Remove scores - val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :-1]) + for target, loc_pred in zip(targets, loc_preds): + for boxes_gt, boxes_pred in zip(target.values(), loc_pred.values()): + # Remove scores + val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :-1]) val_loss += out["loss"].item() batch_cnt += 1 diff --git a/references/detection/evaluate_tensorflow.py b/references/detection/evaluate_tensorflow.py index b93457f5ec..0cef7a463d 100644 --- a/references/detection/evaluate_tensorflow.py +++ b/references/detection/evaluate_tensorflow.py @@ -5,6 +5,8 @@ import os +from doctr.file_utils import CLASS_NAME + os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" @@ -35,13 +37,14 @@ def evaluate(model, val_loader, batch_transforms, val_metric): val_loss, batch_cnt = 0, 0 for images, targets in tqdm(val_loader): images = batch_transforms(images) - targets = [t["boxes"] for t in targets] + targets = [{CLASS_NAME: t["boxes"]} for t in targets] out = model(images, targets, training=False, return_preds=True) # Compute metric loc_preds = out["preds"] - for boxes_gt, boxes_pred in zip(targets, loc_preds): - # Remove scores - val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :-1]) + for target, loc_pred in zip(targets, loc_preds): + for boxes_gt, boxes_pred in zip(target.values(), loc_pred.values()): + # Remove scores + val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :-1]) val_loss += out["loss"].numpy() batch_cnt += 1 diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py index 9f8c475942..e17aa708d2 100644 --- a/references/detection/train_pytorch.py +++ b/references/detection/train_pytorch.py @@ -155,11 +155,12 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): out = model(images, targets, return_preds=True) # Compute metric loc_preds = out["preds"] - for boxes_gt, boxes_pred in zip(targets, loc_preds): - if args.rotation and args.eval_straight: - # Convert pred to boxes [xmin, ymin, xmax, ymax] N, 4, 2 --> N, 4 - boxes_pred = np.concatenate((boxes_pred.min(axis=1), boxes_pred.max(axis=1)), axis=-1) - val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :4]) + for target, loc_pred in 
zip(targets, loc_preds): + for boxes_gt, boxes_pred in zip(target.values(), loc_pred.values()): + if args.rotation and args.eval_straight: + # Convert pred to boxes [xmin, ymin, xmax, ymax] N, 4, 2 --> N, 4 + boxes_pred = np.concatenate((boxes_pred.min(axis=1), boxes_pred.max(axis=1)), axis=-1) + val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :4]) val_loss += out["loss"].item() batch_cnt += 1 @@ -220,7 +221,11 @@ def main(args): batch_transforms = Normalize(mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287)) # Load doctr model - model = detection.__dict__[args.arch](pretrained=args.pretrained, assume_straight_pages=not args.rotation) + model = detection.__dict__[args.arch]( + pretrained=args.pretrained, + assume_straight_pages=not args.rotation, + class_names=val_set.class_names, + ) # Resume weights if isinstance(args.resume, str): diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index b9e77365e5..03aca4ff39 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -30,7 +30,7 @@ from doctr.datasets import DataLoader, DetectionDataset from doctr.models import detection from doctr.utils.metrics import LocalizationConfusion -from utils import plot_recorder, plot_samples +from utils import load_backbone, plot_recorder, plot_samples def record_lr( @@ -115,11 +115,12 @@ def evaluate(model, val_loader, batch_transforms, val_metric): out = model(images, targets, training=False, return_preds=True) # Compute metric loc_preds = out["preds"] - for boxes_gt, boxes_pred in zip(targets, loc_preds): - if args.rotation and args.eval_straight: - # Convert pred to boxes [xmin, ymin, xmax, ymax] N, 4, 2 --> N, 4 - boxes_pred = np.concatenate((boxes_pred.min(axis=1), boxes_pred.max(axis=1)), axis=-1) - val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :4]) + for target, loc_pred in zip(targets, loc_preds): + for boxes_gt, boxes_pred in zip(target.values(), loc_pred.values()): + if args.rotation and args.eval_straight: + # Convert pred to boxes [xmin, ymin, xmax, ymax] N, 4, 2 --> N, 4 + boxes_pred = np.concatenate((boxes_pred.min(axis=1), boxes_pred.max(axis=1)), axis=-1) + val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :4]) val_loss += out["loss"].numpy() batch_cnt += 1 @@ -192,12 +193,18 @@ def main(args): pretrained=args.pretrained, input_shape=(args.input_size, args.input_size, 3), assume_straight_pages=not args.rotation, + class_names=val_set.class_names, ) # Resume weights if isinstance(args.resume, str): model.load_weights(args.resume) + if isinstance(args.pretrained_backbone, str): + print("Loading backbone weights.") + model = load_backbone(model, args.pretrained_backbone) + print("Done.") + # Metrics val_metric = LocalizationConfusion( use_polygons=args.rotation and not args.eval_straight, @@ -366,6 +373,7 @@ def parse_args(): parser.add_argument("--lr", type=float, default=0.001, help="learning rate for the optimizer (Adam)") parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading") parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint") + parser.add_argument("--pretrained-backbone", type=str, default=None, help="Path to your backbone weights") parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop") parser.add_argument( "--freeze-backbone", dest="freeze_backbone", action="store_true", help="freeze model backbone for fine-tuning" diff --git 
a/references/detection/utils.py b/references/detection/utils.py index fab8e9408d..470fe81399 100644 --- a/references/detection/utils.py +++ b/references/detection/utils.py @@ -3,6 +3,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. +import pickle from typing import Dict, List import cv2 @@ -20,16 +21,17 @@ def plot_samples(images, targets: List[Dict[str, np.ndarray]]) -> None: img = img.transpose(1, 2, 0) target = np.zeros(img.shape[:2], np.uint8) - boxes = targets[idx].copy() - boxes[:, [0, 2]] = boxes[:, [0, 2]] * img.shape[1] - boxes[:, [1, 3]] = boxes[:, [1, 3]] * img.shape[0] - boxes[:, :4] = boxes[:, :4].round().astype(int) - - for box in boxes: - if boxes.ndim == 3: - cv2.fillPoly(target, [np.int0(box)], 1) - else: - target[int(box[1]) : int(box[3]) + 1, int(box[0]) : int(box[2]) + 1] = 1 + tgts = targets[idx].copy() + for key, boxes in tgts.items(): + boxes[:, [0, 2]] = boxes[:, [0, 2]] * img.shape[1] + boxes[:, [1, 3]] = boxes[:, [1, 3]] * img.shape[0] + boxes[:, :4] = boxes[:, :4].round().astype(int) + + for box in boxes: + if boxes.ndim == 3: + cv2.fillPoly(target, [np.int0(box)], 1) + else: + target[int(box[1]) : int(box[3]) + 1, int(box[0]) : int(box[2]) + 1] = 1 if nb_samples > 1: axes[0][idx].imshow(img) axes[1][idx].imshow(target.astype(bool)) @@ -81,3 +83,11 @@ def plot_recorder(lr_recorder, loss_recorder, beta: float = 0.95, **kwargs) -> N plt.ylim(vals[min_idx] - 0.1 * delta, max_val + 0.2 * delta) plt.grid(True, linestyle="--", axis="x") plt.show(**kwargs) + + +def load_backbone(model, weights_path): + with open(weights_path, "rb") as f: + pretrained_backbone_weights = pickle.load(f) + model.feat_extractor.set_weights(pretrained_backbone_weights[0]) + model.fpn.set_weights(pretrained_backbone_weights[1]) + return model diff --git a/scripts/evaluate_kie.py b/scripts/evaluate_kie.py new file mode 100644 index 0000000000..8be3fcd92f --- /dev/null +++ b/scripts/evaluate_kie.py @@ -0,0 +1,217 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
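The new `--pretrained-backbone` flag expects a pickle file holding the pair `[feat_extractor_weights, fpn_weights]`, as read back by `load_backbone` above. Below is a hypothetical counterpart for producing such a file from a trained TensorFlow model; `save_backbone` is not part of this diff, only a sketch of the assumed layout:

```python
import pickle

def save_backbone(model, weights_path: str) -> None:
    # Dump weights in the [feat_extractor, fpn] order that load_backbone expects
    weights = [model.feat_extractor.get_weights(), model.fpn.get_weights()]
    with open(weights_path, "wb") as f:
        pickle.dump(weights, f)
```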
+ +import os + +from doctr.io.elements import KIEDocument + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + +import numpy as np +from tqdm import tqdm + +from doctr import datasets +from doctr.file_utils import is_tf_available +from doctr.models import kie_predictor +from doctr.utils.geometry import extract_crops, extract_rcrops +from doctr.utils.metrics import LocalizationConfusion, OCRMetric, TextMatch + +# Enable GPU growth if using TF +if is_tf_available(): + import tensorflow as tf + + gpu_devices = tf.config.experimental.list_physical_devices("GPU") + if any(gpu_devices): + tf.config.experimental.set_memory_growth(gpu_devices[0], True) +else: + import torch + + +def _pct(val): + return "N/A" if val is None else f"{val:.2%}" + + +def main(args): + + if not args.rotation: + args.eval_straight = True + + predictor = kie_predictor( + args.detection, + args.recognition, + pretrained=True, + reco_bs=args.batch_size, + assume_straight_pages=not args.rotation, + ) + + if args.img_folder and args.label_file: + testset = datasets.OCRDataset( + img_folder=args.img_folder, + label_file=args.label_file, + ) + sets = [testset] + else: + train_set = datasets.__dict__[args.dataset](train=True, download=True, use_polygons=not args.eval_straight) + val_set = datasets.__dict__[args.dataset](train=False, download=True, use_polygons=not args.eval_straight) + sets = [train_set, val_set] + + reco_metric = TextMatch() + if args.mask_shape: + det_metric = LocalizationConfusion( + iou_thresh=args.iou, use_polygons=not args.eval_straight, mask_shape=(args.mask_shape, args.mask_shape) + ) + e2e_metric = OCRMetric( + iou_thresh=args.iou, use_polygons=not args.eval_straight, mask_shape=(args.mask_shape, args.mask_shape) + ) + else: + det_metric = LocalizationConfusion(iou_thresh=args.iou, use_polygons=not args.eval_straight) + e2e_metric = OCRMetric(iou_thresh=args.iou, use_polygons=not args.eval_straight) + + sample_idx = 0 + extraction_fn = extract_crops if args.eval_straight else extract_rcrops + + for dataset in sets: + for page, target in tqdm(dataset): + # GT + gt_boxes = target["boxes"] + gt_labels = target["labels"] + + if args.img_folder and args.label_file: + x, y, w, h = gt_boxes[:, 0], gt_boxes[:, 1], gt_boxes[:, 2], gt_boxes[:, 3] + xmin, ymin = np.clip(x - w / 2, 0, 1), np.clip(y - h / 2, 0, 1) + xmax, ymax = np.clip(x + w / 2, 0, 1), np.clip(y + h / 2, 0, 1) + gt_boxes = np.stack([xmin, ymin, xmax, ymax], axis=-1) + + # Forward + out: KIEDocument + if is_tf_available(): + out = predictor(page[None, ...]) + crops = extraction_fn(page, gt_boxes) + reco_out = predictor.reco_predictor(crops) + else: + with torch.no_grad(): + out = predictor(page[None, ...]) + # We directly crop on PyTorch tensors, which are in channels_first + crops = extraction_fn(page, gt_boxes, channels_last=False) + reco_out = predictor.reco_predictor(crops) + + if len(reco_out): + reco_words, _ = zip(*reco_out) + else: + reco_words = [] + + # Unpack preds + pred_boxes = [] + pred_labels = [] + for page in out.pages: + height, width = page.dimensions + for predictions in page.predictions.values(): + for prediction in predictions: + if not args.rotation: + (a, b), (c, d) = prediction.geometry + else: + ( + [x1, y1], + [x2, y2], + [x3, y3], + [x4, y4], + ) = prediction.geometry + if gt_boxes.dtype == int: + if not args.rotation: + pred_boxes.append([int(a * width), int(b * height), int(c * width), int(d * height)]) + else: + if args.eval_straight: + pred_boxes.append( + [ + int(width * min(x1, x2, x3, x4)), + int(height * min(y1, y2, y3, 
y4)), + int(width * max(x1, x2, x3, x4)), + int(height * max(y1, y2, y3, y4)), + ] + ) + else: + pred_boxes.append( + [ + [int(x1 * width), int(y1 * height)], + [int(x2 * width), int(y2 * height)], + [int(x3 * width), int(y3 * height)], + [int(x4 * width), int(y4 * height)], + ] + ) + else: + if not args.rotation: + pred_boxes.append([a, b, c, d]) + else: + if args.eval_straight: + pred_boxes.append( + [ + min(x1, x2, x3, x4), + min(y1, y2, y3, y4), + max(x1, x2, x3, x4), + max(y1, y2, y3, y4), + ] + ) + else: + pred_boxes.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]]) + pred_labels.append(prediction.value) + + # Update the metric + det_metric.update(gt_boxes, np.asarray(pred_boxes)) + reco_metric.update(gt_labels, reco_words) + e2e_metric.update(gt_boxes, np.asarray(pred_boxes), gt_labels, pred_labels) + + # Loop break + sample_idx += 1 + if isinstance(args.samples, int) and args.samples == sample_idx: + break + if isinstance(args.samples, int) and args.samples == sample_idx: + break + + # Unpack aggregated metrics + print( + f"Model Evaluation (model= {args.detection} + {args.recognition}, " + f"dataset={'OCRDataset' if args.img_folder else args.dataset})" + ) + recall, precision, mean_iou = det_metric.summary() + print(f"Text Detection - Recall: {_pct(recall)}, Precision: {_pct(precision)}, Mean IoU: {_pct(mean_iou)}") + acc = reco_metric.summary() + print(f"Text Recognition - Accuracy: {_pct(acc['raw'])} (unicase: {_pct(acc['unicase'])})") + recall, precision, mean_iou = e2e_metric.summary() + print( + f"KIE OCR - Recall: {_pct(recall['raw'])} (unicase: {_pct(recall['unicase'])}), " + f"Precision: {_pct(precision['raw'])} (unicase: {_pct(precision['unicase'])}), Mean IoU: {_pct(mean_iou)}" + ) + + +def parse_args(): + import argparse + + parser = argparse.ArgumentParser( + description="DocTR end-to-end evaluation", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument("detection", type=str, help="Text detection model to use for analysis") + parser.add_argument("recognition", type=str, help="Text recognition model to use for analysis") + parser.add_argument("--iou", type=float, default=0.5, help="IoU threshold to match a pair of boxes") + parser.add_argument("--dataset", type=str, default="FUNSD", help="choose a dataset: FUNSD, CORD") + parser.add_argument("--img_folder", type=str, default=None, help="Only for local sets, path to images") + parser.add_argument("--label_file", type=str, default=None, help="Only for local sets, path to labels") + parser.add_argument("--rotation", dest="rotation", action="store_true", help="run rotated OCR + postprocessing") + parser.add_argument("-b", "--batch_size", type=int, default=32, help="batch size for recognition") + parser.add_argument("--mask_shape", type=int, default=None, help="mask shape for mask iou (only for rotation)") + parser.add_argument("--samples", type=int, default=None, help="evaluate only on the N first samples") + parser.add_argument( + "--eval-straight", + action="store_true", + help="evaluate on straight pages with straight bbox (to use the quick and light metric)", + ) + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/setup.py b/setup.py index c5ae8b0b32..b06872cdcd 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ from setuptools import setup PKG_NAME = "python-doctr" -VERSION = os.getenv("BUILD_VERSION", "0.6.1a0") +VERSION = os.getenv("BUILD_VERSION", "0.7.1a0") if __name__ == "__main__": diff --git 
a/tests/common/test_io_elements.py b/tests/common/test_io_elements.py index 3b50b66041..3048a49ae1 100644 --- a/tests/common/test_io_elements.py +++ b/tests/common/test_io_elements.py @@ -3,6 +3,7 @@ import numpy as np import pytest +from doctr.file_utils import CLASS_NAME from doctr.io import elements @@ -41,6 +42,19 @@ def _mock_lines(size=(1, 1), offset=(0, 0)): ] +def _mock_prediction(size=(1.0, 1.0), offset=(0, 0), confidence=0.9): + return [ + elements.Prediction( + "hello", confidence, ((offset[0], offset[1]), (size[0] / 2 + offset[0], size[1] / 2 + offset[1])) + ), + elements.Prediction( + "world", + confidence, + ((size[0] / 2 + offset[0], size[1] / 2 + offset[1]), (size[0] + offset[0], size[1] + offset[1])), + ), + ] + + def _mock_blocks(size=(1, 1), offset=(0, 0)): sub_size = (size[0] / 4, size[1] / 4) return [ @@ -74,6 +88,25 @@ def _mock_pages(block_size=(1, 1), block_offset=(0, 0)): ] +def _mock_kie_pages(prediction_size=(1, 1), prediction_offset=(0, 0)): + return [ + elements.KIEPage( + {CLASS_NAME: _mock_prediction(prediction_size, prediction_offset)}, + 0, + (300, 200), + {"value": 0.0, "confidence": 1.0}, + {"value": "EN", "confidence": 0.8}, + ), + elements.KIEPage( + {CLASS_NAME: _mock_prediction(prediction_size, prediction_offset)}, + 1, + (500, 1000), + {"value": 0.15, "confidence": 0.8}, + {"value": "FR", "confidence": 0.7}, + ), + ] + + def test_element(): with pytest.raises(KeyError): elements.Element(sub_elements=[1]) @@ -158,6 +191,32 @@ def test_artefact(): assert artefact.__repr__() == f"Artefact(type='{artefact_type}', confidence={conf:.2})" +def test_prediction(): + prediction_str = "hello" + conf = 0.8 + geom = ((0, 0), (1, 1)) + prediction = elements.Prediction(prediction_str, conf, geom) + + # Attribute checks + assert prediction.value == prediction_str + assert prediction.confidence == conf + assert prediction.geometry == geom + + # Render + assert prediction.render() == prediction_str + + # Export + assert prediction.export() == {"value": prediction_str, "confidence": conf, "geometry": geom} + + # Repr + assert prediction.__repr__() == f"Prediction(value='hello', confidence={conf:.2}, bounding_box={geom})" + + # Class method + state_dict = {"value": "there", "confidence": 0.1, "geometry": ((0, 0), (0.5, 0.5))} + prediction = elements.Prediction.from_dict(state_dict) + assert prediction.export() == state_dict + + def test_block(): geom = ((0, 0), (1, 1)) sub_size = (geom[1][0] / 2, geom[1][0] / 2) @@ -230,6 +289,53 @@ def test_page(): assert img.shape == (*page_size, 3) +def test_kiepage(): + page_idx = 0 + page_size = (300, 200) + orientation = {"value": 0.0, "confidence": 0.0} + language = {"value": "EN", "confidence": 0.8} + predictions = {CLASS_NAME: _mock_prediction()} + kie_page = elements.KIEPage(predictions, page_idx, page_size, orientation, language) + + # Attribute checks + assert len(kie_page.predictions) == len(predictions) + assert all(isinstance(b, elements.Prediction) for b in kie_page.predictions[CLASS_NAME]) + assert kie_page.page_idx == page_idx + assert kie_page.dimensions == page_size + assert kie_page.orientation == orientation + assert kie_page.language == language + + # Render + assert kie_page.render() == "words: hello\n\nwords: world" + + # Export + assert kie_page.export() == { + "predictions": {CLASS_NAME: [b.export() for b in predictions[CLASS_NAME]]}, + "page_idx": page_idx, + "dimensions": page_size, + "orientation": orientation, + "language": language, + } + + # Export XML + assert ( + isinstance(kie_page.export_as_xml(), 
tuple) + and isinstance(kie_page.export_as_xml()[0], (bytes, bytearray)) + and isinstance(kie_page.export_as_xml()[1], ElementTree) + ) + + # Repr + assert "\n".join(repr(kie_page).split("\n")[:2]) == f"KIEPage(\n dimensions={repr(page_size)}" + + # Show + kie_page.show(np.zeros((256, 256, 3), dtype=np.uint8), block=False) + + # Synthesize + img = kie_page.synthesize() + assert isinstance(img, np.ndarray) + assert img.shape == (*page_size, 3) + + def test_document(): pages = _mock_pages() doc = elements.Document(pages) @@ -254,3 +360,29 @@ def test_document(): # Synthesize img_list = doc.synthesize() assert isinstance(img_list, list) and len(img_list) == len(pages) + + +def test_kie_document(): + pages = _mock_kie_pages() + doc = elements.KIEDocument(pages) + + # Attribute checks + assert len(doc.pages) == len(pages) + assert all(isinstance(p, elements.KIEPage) for p in doc.pages) + + # Render + page_export = "words: hello\n\nwords: world" + assert doc.render() == f"{page_export}\n\n\n\n{page_export}" + + # Export + assert doc.export() == {"pages": [p.export() for p in pages]} + + # Export XML + assert isinstance(doc.export_as_xml(), list) and len(doc.export_as_xml()) == len(pages) + + # Show + doc.show([np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(len(pages))], block=False) + + # Synthesize + img_list = doc.synthesize() + assert isinstance(img_list, list) and len(img_list) == len(pages) diff --git a/tests/common/test_models.py b/tests/common/test_models.py index bfcb0a7967..f228ba5c62 100644 --- a/tests/common/test_models.py +++ b/tests/common/test_models.py @@ -6,7 +6,7 @@ import requests from doctr.io import reader -from doctr.models._utils import estimate_orientation, get_bitmap_angle, get_language +from doctr.models._utils import estimate_orientation, get_bitmap_angle, get_language, invert_data_structure from doctr.utils import geometry @@ -63,3 +63,14 @@ def test_get_lang(): lang = get_language("a") assert lang[0] == "unknown" assert lang[1] == 0.0 + + +def test_convert_list_dict(): + dic = {"k1": [[0], [0], [0]], "k2": [[1], [1], [1]]} + L = [{"k1": [0], "k2": [1]}, {"k1": [0], "k2": [1]}, {"k1": [0], "k2": [1]}] + + converted_dic = invert_data_structure(dic) + converted_list = invert_data_structure(L) + + assert converted_dic == L + assert converted_list == dic diff --git a/tests/common/test_models_builder.py b/tests/common/test_models_builder.py index 967600b4d3..8c89d370a5 100644 --- a/tests/common/test_models_builder.py +++ b/tests/common/test_models_builder.py @@ -1,20 +1,28 @@ import numpy as np import pytest +from doctr.file_utils import CLASS_NAME from doctr.io import Document +from doctr.io.elements import KIEDocument from doctr.models import builder +words_per_page = 10 + +boxes_1 = {CLASS_NAME: np.random.rand(words_per_page, 6)} # dict format +boxes_1[CLASS_NAME][:2] *= boxes_1[CLASS_NAME][2:4] + +boxes_2 = np.random.rand(words_per_page, 6) # array format +boxes_2[:2] *= boxes_2[2:4] + def test_documentbuilder(): - words_per_page = 10 num_pages = 2 # Don't resolve lines doc_builder = builder.DocumentBuilder(resolve_lines=False, resolve_blocks=False) - boxes = np.random.rand(words_per_page, 6) + boxes = np.random.rand(words_per_page, 6) # array format boxes[:2] *= boxes[2:4] - # Arg consistency check with pytest.raises(ValueError): doc_builder([boxes, boxes], [("hello", 1.0)] * 3, [(100, 200), (100, 200)]) @@ -42,7 +50,7 @@ def test_documentbuilder(): ] ) doc_builder_2 = builder.DocumentBuilder(resolve_blocks=False, resolve_lines=False, 
export_as_straight_boxes=True) - out = doc_builder_2([boxes], [[("hello", 0.99), ("world", 0.99)]], [(100, 100)]) + out = doc_builder_2([boxes], [[("hello", 0.99), ("word", 0.99)]], [(100, 100)]) assert out.pages[0].blocks[0].lines[0].words[-1].geometry == ((0.45, 0.5), (0.6, 0.65)) # Repr @@ -52,6 +60,62 @@ def test_documentbuilder(): ) +def test_kiedocumentbuilder(): + + num_pages = 2 + + # Don't resolve lines + doc_builder = builder.KIEDocumentBuilder(resolve_lines=False, resolve_blocks=False) + predictions = {CLASS_NAME: np.random.rand(words_per_page, 6)} # dict format + predictions[CLASS_NAME][:2] *= predictions[CLASS_NAME][2:4] + # Arg consistency check + with pytest.raises(ValueError): + doc_builder([predictions, predictions], [{CLASS_NAME: ("hello", 1.0)}] * 3, [(100, 200), (100, 200)]) + out = doc_builder( + [predictions, predictions], + [{CLASS_NAME: [("hello", 1.0)] * words_per_page}] * num_pages, + [(100, 200), (100, 200)], + ) + assert isinstance(out, KIEDocument) + assert len(out.pages) == num_pages + # 1 Block & 1 line per page + assert len(out.pages[0].predictions) == 1 + assert len(out.pages[0].predictions[CLASS_NAME]) == words_per_page + + # Resolve lines + doc_builder = builder.KIEDocumentBuilder(resolve_lines=True, resolve_blocks=True) + out = doc_builder( + [predictions, predictions], + [{CLASS_NAME: [("hello", 1.0)] * words_per_page}] * num_pages, + [(100, 200), (100, 200)], + ) + + # No detection + predictions = {CLASS_NAME: np.zeros((0, 5))} + out = doc_builder([predictions, predictions], [{CLASS_NAME: []}, {CLASS_NAME: []}], [(100, 200), (100, 200)]) + assert len(out.pages[0].predictions[CLASS_NAME]) == 0 + + # Rotated boxes to export as straight boxes + predictions = { + CLASS_NAME: np.array( + [ + [[0.1, 0.1], [0.2, 0.2], [0.15, 0.25], [0.05, 0.15]], + [[0.5, 0.5], [0.6, 0.6], [0.55, 0.65], [0.45, 0.55]], + ] + ) + } + doc_builder_2 = builder.KIEDocumentBuilder(resolve_blocks=False, resolve_lines=False, export_as_straight_boxes=True) + out = doc_builder_2([predictions], [{CLASS_NAME: [("hello", 0.99), ("word", 0.99)]}], [(100, 100)]) + assert out.pages[0].predictions[CLASS_NAME][0].geometry == ((0.05, 0.1), (0.2, 0.25)) + assert out.pages[0].predictions[CLASS_NAME][1].geometry == ((0.45, 0.5), (0.6, 0.65)) + + # Repr + assert ( + repr(doc_builder) == "KIEDocumentBuilder(resolve_lines=True, " + "resolve_blocks=True, paragraph_break=0.035, export_as_straight_boxes=False)" + ) + + @pytest.mark.parametrize( "input_boxes, sorted_idxs", [ @@ -71,7 +135,6 @@ def test_documentbuilder(): ], ) def test_sort_boxes(input_boxes, sorted_idxs): - doc_builder = builder.DocumentBuilder() assert doc_builder._sort_boxes(np.asarray(input_boxes))[0].tolist() == sorted_idxs @@ -95,6 +158,5 @@ def test_sort_boxes(input_boxes, sorted_idxs): ], ) def test_resolve_lines(input_boxes, lines): - doc_builder = builder.DocumentBuilder() assert doc_builder._resolve_lines(np.asarray(input_boxes)) == lines diff --git a/tests/pytorch/test_datasets_pt.py b/tests/pytorch/test_datasets_pt.py index f3d0af3386..bb692d3411 100644 --- a/tests/pytorch/test_datasets_pt.py +++ b/tests/pytorch/test_datasets_pt.py @@ -7,6 +7,7 @@ from torch.utils.data import DataLoader, RandomSampler from doctr import datasets +from doctr.file_utils import CLASS_NAME from doctr.transforms import Resize @@ -92,11 +93,13 @@ def test_detection_dataset(mock_image_folder, mock_detection_label): ) assert len(ds) == 5 - img, target = ds[0] + img, target_dict = ds[0] + target = target_dict[CLASS_NAME] assert isinstance(img, 
diff --git a/tests/pytorch/test_datasets_pt.py b/tests/pytorch/test_datasets_pt.py
index f3d0af3386..bb692d3411 100644
--- a/tests/pytorch/test_datasets_pt.py
+++ b/tests/pytorch/test_datasets_pt.py
@@ -7,6 +7,7 @@ from torch.utils.data import DataLoader, RandomSampler

 from doctr import datasets
+from doctr.file_utils import CLASS_NAME
 from doctr.transforms import Resize
@@ -92,11 +93,13 @@ def test_detection_dataset(mock_image_folder, mock_detection_label):
     )

     assert len(ds) == 5
-    img, target = ds[0]
+    img, target_dict = ds[0]
+    target = target_dict[CLASS_NAME]
     assert isinstance(img, torch.Tensor)
     assert img.dtype == torch.float32
     assert img.shape[-2:] == input_size
     # Bounding boxes
+    assert isinstance(target_dict, dict)
     assert isinstance(target, np.ndarray) and target.dtype == np.float32
     assert np.all(np.logical_and(target[:, :4] >= 0, target[:, :4] <= 1))
     assert target.shape[1] == 4
@@ -104,8 +107,9 @@ def test_detection_dataset(mock_image_folder, mock_detection_label):
     loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn)
     images, targets = next(iter(loader))
     assert isinstance(images, torch.Tensor) and images.shape == (2, 3, *input_size)
-    assert isinstance(targets, list) and all(isinstance(elt, np.ndarray) for elt in targets)
-
+    assert isinstance(targets, list) and all(
+        isinstance(elt, np.ndarray) for target in targets for elt in target.values()
+    )
     # Rotated DS
     rotated_ds = datasets.DetectionDataset(
         img_folder=mock_image_folder,
@@ -114,7 +118,7 @@ def test_detection_dataset(mock_image_folder, mock_detection_label):
         use_polygons=True,
     )
     _, r_target = rotated_ds[0]
-    assert r_target.shape[1:] == (4, 2)
+    assert r_target[CLASS_NAME].shape[1:] == (4, 2)

     # File existence check
     img_name, _ = ds.data[0]
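
On both backends the dataset change is identical: `DetectionDataset.__getitem__` now yields `(img, target_dict)`, where the target is a dict keyed by class name instead of a bare array. A sketch of consuming it (the folder and label paths below are placeholders):

```python
from doctr import datasets
from doctr.file_utils import CLASS_NAME

# Placeholder paths: substitute your own image folder and label file
ds = datasets.DetectionDataset(
    img_folder="path/to/images",
    label_path="path/to/labels.json",
)
img, target_dict = ds[0]
boxes = target_dict[CLASS_NAME]  # (N, 4) relative boxes for the single default class
```
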
diff --git a/tests/pytorch/test_models_detection_pt.py b/tests/pytorch/test_models_detection_pt.py
index f976b09a91..de929a13e8 100644
--- a/tests/pytorch/test_models_detection_pt.py
+++ b/tests/pytorch/test_models_detection_pt.py
@@ -6,6 +6,7 @@ import pytest
 import torch

+from doctr.file_utils import CLASS_NAME
 from doctr.models import detection
 from doctr.models.detection._utils import dilate, erode
 from doctr.models.detection.predictor import DetectionPredictor
@@ -29,8 +30,8 @@ def test_detection_models(arch_name, input_shape, output_size, out_prob):
     assert isinstance(model, torch.nn.Module)
     input_tensor = torch.rand((batch_size, *input_shape))
     target = [
-        np.array([[0.5, 0.5, 1, 1], [0.5, 0.5, 0.8, 0.8]], dtype=np.float32),
-        np.array([[0.5, 0.5, 1, 1], [0.5, 0.5, 0.8, 0.9]], dtype=np.float32),
+        {CLASS_NAME: np.array([[0.5, 0.5, 1, 1], [0.5, 0.5, 0.8, 0.8]], dtype=np.float32)},
+        {CLASS_NAME: np.array([[0.5, 0.5, 1, 1], [0.5, 0.5, 0.8, 0.9]], dtype=np.float32)},
     ]
     if torch.cuda.is_available():
         model.cuda()
@@ -44,22 +45,27 @@ def test_detection_models(arch_name, input_shape, output_size, out_prob):
     if out_prob:
         assert torch.all((out["out_map"] >= 0) & (out["out_map"] <= 1))
     # Check boxes
-    for boxes in out["preds"]:
-        assert boxes.shape[1] == 5
-        assert np.all(boxes[:, :2] < boxes[:, 2:4])
-        assert np.all(boxes[:, :4] >= 0) and np.all(boxes[:, :4] <= 1)
+    for boxes_dict in out["preds"]:
+        for boxes in boxes_dict.values():
+            assert boxes.shape[1] == 5
+            assert np.all(boxes[:, :2] < boxes[:, 2:4])
+            assert np.all(boxes[:, :4] >= 0) and np.all(boxes[:, :4] <= 1)
     # Check loss
     assert isinstance(out["loss"], torch.Tensor)
     # Check the rotated case (same targets)
     target = [
-        np.array(
-            [[[0.5, 0.5], [1, 0.5], [1, 1], [0.5, 1]], [[0.5, 0.5], [0.8, 0.5], [0.8, 0.8], [0.5, 0.8]]],
-            dtype=np.float32,
-        ),
-        np.array(
-            [[[0.5, 0.5], [1, 0.5], [1, 1], [0.5, 1]], [[0.5, 0.5], [0.8, 0.5], [0.8, 0.9], [0.5, 0.9]]],
-            dtype=np.float32,
-        ),
+        {
+            CLASS_NAME: np.array(
+                [[[0.5, 0.5], [1, 0.5], [1, 1], [0.5, 1]], [[0.5, 0.5], [0.8, 0.5], [0.8, 0.8], [0.5, 0.8]]],
+                dtype=np.float32,
+            )
+        },
+        {
+            CLASS_NAME: np.array(
+                [[[0.5, 0.5], [1, 0.5], [1, 1], [0.5, 1]], [[0.5, 0.5], [0.8, 0.5], [0.8, 0.9], [0.5, 0.9]]],
+                dtype=np.float32,
+            )
+        },
     ]
     loss = model(input_tensor, target)["loss"]
     assert isinstance(loss, torch.Tensor) and ((loss - out["loss"]).abs() / loss).item() < 1e-1
@@ -87,7 +93,8 @@ def test_detection_zoo(arch_name):
     with torch.no_grad():
         out = predictor(input_tensor)
-    assert all(isinstance(boxes, np.ndarray) and boxes.shape[1] == 5 for boxes in out)
+    assert all(isinstance(boxes, dict) for boxes in out)
+    assert all(isinstance(boxes[CLASS_NAME], np.ndarray) and boxes[CLASS_NAME].shape[1] == 5 for boxes in out)


 def test_erode():
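
The zoo assertions above spell out the new detection output contract: the predictor returns one dict per page, keyed by class name, each value an `(N, 5)` array of `xmin, ymin, xmax, ymax, confidence` in relative coordinates. A minimal sketch of reading it (PyTorch backend; pretrained weights download on first use):

```python
import torch

from doctr.file_utils import CLASS_NAME
from doctr.models.detection.zoo import detection_predictor

predictor = detection_predictor("db_mobilenet_v3_large", pretrained=True)
input_tensor = torch.rand((2, 3, 1024, 1024))
with torch.no_grad():
    out = predictor(input_tensor)
# One dict per page; boxes for the default class
boxes = out[0][CLASS_NAME]  # (N, 5): xmin, ymin, xmax, ymax, confidence
```
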
diff --git a/tests/pytorch/test_models_zoo_pt.py b/tests/pytorch/test_models_zoo_pt.py
index ed2987afa2..e2b26b5b55 100644
--- a/tests/pytorch/test_models_zoo_pt.py
+++ b/tests/pytorch/test_models_zoo_pt.py
@@ -4,8 +4,10 @@
 from doctr import models
 from doctr.io import Document, DocumentFile
+from doctr.io.elements import KIEDocument
 from doctr.models import detection, recognition
 from doctr.models.detection.predictor import DetectionPredictor
+from doctr.models.kie_predictor import KIEPredictor
 from doctr.models.predictor import OCRPredictor
 from doctr.models.preprocessor import PreProcessor
 from doctr.models.recognition.predictor import RecognitionPredictor
@@ -68,6 +70,63 @@ def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa
     assert out.pages[0].orientation["value"] == orientation
+
+
+@pytest.mark.parametrize(
+    "assume_straight_pages, straighten_pages",
+    [
+        [True, False],
+        [False, False],
+        [True, True],
+    ],
+)
+def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pages):
+    det_bsize = 4
+    det_predictor = DetectionPredictor(
+        PreProcessor(output_size=(512, 512), batch_size=det_bsize),
+        detection.db_mobilenet_v3_large(
+            pretrained=False,
+            pretrained_backbone=False,
+            assume_straight_pages=assume_straight_pages,
+        ),
+    )

+    assert not det_predictor.model.training
+
+    reco_bsize = 32
+    reco_predictor = RecognitionPredictor(
+        PreProcessor(output_size=(32, 128), batch_size=reco_bsize, preserve_aspect_ratio=True),
+        recognition.crnn_vgg16_bn(pretrained=False, pretrained_backbone=False, vocab=mock_vocab),
+    )
+
+    assert not reco_predictor.model.training
+
+    doc = DocumentFile.from_pdf(mock_pdf)
+
+    predictor = KIEPredictor(
+        det_predictor,
+        reco_predictor,
+        assume_straight_pages=assume_straight_pages,
+        straighten_pages=straighten_pages,
+        detect_orientation=True,
+        detect_language=True,
+    )
+
+    if assume_straight_pages:
+        assert predictor.crop_orientation_predictor is None
+    else:
+        assert isinstance(predictor.crop_orientation_predictor, nn.Module)
+
+    out = predictor(doc)
+    assert isinstance(out, KIEDocument)
+    assert len(out.pages) == 2
+    # Dimension check
+    with pytest.raises(ValueError):
+        input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8)
+        _ = predictor([input_page])
+
+    orientation = 0
+    assert out.pages[0].orientation["value"] == orientation
+
+
 def _test_predictor(predictor):
     # Output checks
     assert isinstance(predictor, OCRPredictor)
@@ -85,6 +144,23 @@ def _test_predictor(predictor):
     _ = predictor([input_page])
+
+
+def _test_kiepredictor(predictor):
+    # Output checks
+    assert isinstance(predictor, KIEPredictor)
+
+    doc = [np.zeros((512, 512, 3), dtype=np.uint8)]
+    out = predictor(doc)
+    # Document
+    assert isinstance(out, KIEDocument)
+
+    # The input doc has 1 page
+    assert len(out.pages) == 1
+    # Dimension check
+    with pytest.raises(ValueError):
+        input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8)
+        _ = predictor([input_page])
+
+
 @pytest.mark.parametrize(
     "det_arch, reco_arch",
     [
@@ -109,3 +185,21 @@ def test_zoo_models(det_arch, reco_arch):
     # passing detection model as recognition model
     with pytest.raises(ValueError):
         models.ocr_predictor(reco_arch=det_model, pretrained=True)
+
+    # KIE predictor
+    predictor = models.kie_predictor(det_arch, reco_arch, pretrained=True)
+    _test_kiepredictor(predictor)
+
+    # passing model instance directly
+    det_model = detection.__dict__[det_arch](pretrained=True)
+    reco_model = recognition.__dict__[reco_arch](pretrained=True)
+    predictor = models.kie_predictor(det_model, reco_model)
+    _test_kiepredictor(predictor)
+
+    # passing recognition model as detection model
+    with pytest.raises(ValueError):
+        models.kie_predictor(det_arch=reco_model, pretrained=True)
+
+    # passing detection model as recognition model
+    with pytest.raises(ValueError):
+        models.kie_predictor(reco_arch=det_model, pretrained=True)
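
Both backends assemble the KIE pipeline from the same parts as OCR; only the predictor class and the output type change. A condensed sketch of the construction performed in `test_kiepredictor` above (PyTorch backend, with pretrained weights substituted so it runs standalone; batch sizes are illustrative):

```python
from doctr.models import detection, recognition
from doctr.models.detection.predictor import DetectionPredictor
from doctr.models.kie_predictor import KIEPredictor
from doctr.models.preprocessor import PreProcessor
from doctr.models.recognition.predictor import RecognitionPredictor

det_predictor = DetectionPredictor(
    PreProcessor(output_size=(512, 512), batch_size=4),
    detection.db_mobilenet_v3_large(pretrained=True),
)
reco_predictor = RecognitionPredictor(
    PreProcessor(output_size=(32, 128), batch_size=32, preserve_aspect_ratio=True),
    recognition.crnn_vgg16_bn(pretrained=True),
)
# Same building blocks as OCRPredictor, but the output is a KIEDocument
predictor = KIEPredictor(det_predictor, reco_predictor)
```
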
diff --git a/tests/tensorflow/test_datasets_tf.py b/tests/tensorflow/test_datasets_tf.py
index 1783acebe0..06c2f0862e 100644
--- a/tests/tensorflow/test_datasets_tf.py
+++ b/tests/tensorflow/test_datasets_tf.py
@@ -7,6 +7,7 @@
 from doctr import datasets
 from doctr.datasets import DataLoader
+from doctr.file_utils import CLASS_NAME
 from doctr.transforms import Resize
@@ -66,11 +67,13 @@ def test_detection_dataset(mock_image_folder, mock_detection_label):
     )

     assert len(ds) == 5
-    img, target = ds[0]
+    img, target_dict = ds[0]
+    target = target_dict[CLASS_NAME]
     assert isinstance(img, tf.Tensor)
     assert img.shape[:2] == input_size
     assert img.dtype == tf.float32
     # Bounding boxes
+    assert isinstance(target_dict, dict)
     assert isinstance(target, np.ndarray) and target.dtype == np.float32
     assert np.all(np.logical_and(target[:, :4] >= 0, target[:, :4] <= 1))
     assert target.shape[1] == 4
@@ -78,7 +81,9 @@ def test_detection_dataset(mock_image_folder, mock_detection_label):
     loader = DataLoader(ds, batch_size=2)
     images, targets = next(iter(loader))
     assert isinstance(images, tf.Tensor) and images.shape == (2, *input_size, 3)
-    assert isinstance(targets, list) and all(isinstance(elt, np.ndarray) for elt in targets)
+    assert isinstance(targets, list) and all(
+        isinstance(elt, np.ndarray) for target in targets for elt in target.values()
+    )

     # Rotated DS
     rotated_ds = datasets.DetectionDataset(
@@ -88,7 +93,7 @@ def test_detection_dataset(mock_image_folder, mock_detection_label):
         use_polygons=True,
     )
     _, r_target = rotated_ds[0]
-    assert r_target.shape[1:] == (4, 2)
+    assert r_target[CLASS_NAME].shape[1:] == (4, 2)

     # File existence check
     img_name, _ = ds.data[0]
diff --git a/tests/tensorflow/test_models_detection_tf.py b/tests/tensorflow/test_models_detection_tf.py
index 620e389454..301dd29afb 100644
--- a/tests/tensorflow/test_models_detection_tf.py
+++ b/tests/tensorflow/test_models_detection_tf.py
@@ -6,6 +6,7 @@
 import pytest
 import tensorflow as tf

+from doctr.file_utils import CLASS_NAME
 from doctr.io import DocumentFile
 from doctr.models import detection
 from doctr.models.detection._utils import dilate, erode
@@ -31,8 +32,8 @@ def test_detection_models(arch_name, input_shape, output_size, out_prob):
     assert isinstance(model, tf.keras.Model)
     input_tensor = tf.random.uniform(shape=[batch_size, *input_shape], minval=0, maxval=1)
     target = [
-        np.array([[0.5, 0.5, 1, 1], [0.5, 0.5, 0.8, 0.8]], dtype=np.float32),
-        np.array([[0.5, 0.5, 1, 1], [0.5, 0.5, 0.8, 0.9]], dtype=np.float32),
+        {CLASS_NAME: np.array([[0.5, 0.5, 1, 1], [0.5, 0.5, 0.8, 0.8]], dtype=np.float32)},
+        {CLASS_NAME: np.array([[0.5, 0.5, 1, 1], [0.5, 0.5, 0.8, 0.9]], dtype=np.float32)},
     ]
     # test training model
     out = model(input_tensor, target, return_model_output=True, return_preds=True, training=True)
@@ -46,31 +47,32 @@ def test_detection_models(arch_name, input_shape, output_size, out_prob):
     if out_prob:
         assert np.all(np.logical_and(seg_map >= 0, seg_map <= 1))
     # Check boxes
-    for boxes in out["preds"]:
-        assert boxes.shape[1] == 5
-        assert np.all(boxes[:, :2] < boxes[:, 2:4])
-        assert np.all(boxes[:, :4] >= 0) and np.all(boxes[:, :4] <= 1)
+    for boxes_dict in out["preds"]:
+        for boxes in boxes_dict.values():
+            assert boxes.shape[1] == 5
+            assert np.all(boxes[:, :2] < boxes[:, 2:4])
+            assert np.all(boxes[:, :4] >= 0) and np.all(boxes[:, :4] <= 1)
     # Check loss
     assert isinstance(out["loss"], tf.Tensor)
     # Target checks
     target = [
-        np.array([[0, 0, 1, 1]], dtype=np.uint8),
-        np.array([[0, 0, 1, 1]], dtype=np.uint8),
+        {CLASS_NAME: np.array([[0, 0, 1, 1]], dtype=np.uint8)},
+        {CLASS_NAME: np.array([[0, 0, 1, 1]], dtype=np.uint8)},
     ]
     with pytest.raises(AssertionError):
         out = model(input_tensor, target, training=True)
     target = [
-        np.array([[0, 0, 1.5, 1.5]], dtype=np.float32),
-        np.array([[-0.2, -0.3, 1, 1]], dtype=np.float32),
+        {CLASS_NAME: np.array([[0, 0, 1.5, 1.5]], dtype=np.float32)},
+        {CLASS_NAME: np.array([[-0.2, -0.3, 1, 1]], dtype=np.float32)},
     ]
     with pytest.raises(ValueError):
         out = model(input_tensor, target, training=True)
     # Check the rotated case
     target = [
-        np.array([[0.75, 0.75, 0.5, 0.5, 0], [0.65, 0.65, 0.3, 0.3, 0]], dtype=np.float32),
-        np.array([[0.75, 0.75, 0.5, 0.5, 0], [0.65, 0.7, 0.3, 0.4, 0]], dtype=np.float32),
+        {CLASS_NAME: np.array([[0.75, 0.75, 0.5, 0.5, 0], [0.65, 0.65, 0.3, 0.3, 0]], dtype=np.float32)},
+        {CLASS_NAME: np.array([[0.75, 0.75, 0.5, 0.5, 0], [0.65, 0.7, 0.3, 0.4, 0]], dtype=np.float32)},
     ]
     loss = model(input_tensor, target, training=True)["loss"]
     assert isinstance(loss, tf.Tensor) and ((loss - out["loss"]) / loss).numpy() < 25e-2
@@ -136,7 +138,8 @@ def test_detection_zoo(arch_name):
     assert isinstance(predictor, DetectionPredictor)
     input_tensor = tf.random.uniform(shape=[2, 1024, 1024, 3], minval=0, maxval=1)
     out = predictor(input_tensor)
-    assert all(isinstance(boxes, np.ndarray) and boxes.shape[1] == 5 for boxes in out)
+    assert all(isinstance(boxes, dict) for boxes in out)
+    assert all(isinstance(boxes[CLASS_NAME], np.ndarray) and boxes[CLASS_NAME].shape[1] == 5 for boxes in out)


 def test_detection_zoo_error():
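
One behavioral detail worth calling out from the target checks above: with the new dict targets, detection models still validate the inner arrays, rejecting integer dtypes with an `AssertionError` and out-of-range relative coordinates with a `ValueError`. A standalone sketch of that behavior (TensorFlow backend, mirroring the test; the 512x512 input shape is an assumption chosen for speed):

```python
import numpy as np
import pytest
import tensorflow as tf

from doctr.file_utils import CLASS_NAME
from doctr.models import detection

model = detection.db_mobilenet_v3_large(pretrained=False, input_shape=(512, 512, 3))
input_tensor = tf.random.uniform(shape=[2, 512, 512, 3], minval=0, maxval=1)

# Integer box arrays are rejected up front
with pytest.raises(AssertionError):
    model(input_tensor, [{CLASS_NAME: np.array([[0, 0, 1, 1]], dtype=np.uint8)}] * 2, training=True)

# Relative coordinates outside [0, 1] are rejected as well
with pytest.raises(ValueError):
    model(input_tensor, [{CLASS_NAME: np.array([[0, 0, 1.5, 1.5]], dtype=np.float32)}] * 2, training=True)
```
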
diff --git a/tests/tensorflow/test_models_zoo_tf.py b/tests/tensorflow/test_models_zoo_tf.py
index 2899aec558..5cd4d0be5d 100644
--- a/tests/tensorflow/test_models_zoo_tf.py
+++ b/tests/tensorflow/test_models_zoo_tf.py
@@ -2,10 +2,13 @@
 import pytest

 from doctr import models
+from doctr.file_utils import CLASS_NAME
 from doctr.io import Document, DocumentFile
+from doctr.io.elements import KIEDocument
 from doctr.models import detection, recognition
 from doctr.models.detection.predictor import DetectionPredictor
 from doctr.models.detection.zoo import detection_predictor
+from doctr.models.kie_predictor import KIEPredictor
 from doctr.models.predictor import OCRPredictor
 from doctr.models.preprocessor import PreProcessor
 from doctr.models.recognition.predictor import RecognitionPredictor
@@ -119,6 +122,114 @@ def test_trained_ocr_predictor(mock_tilted_payslip):
     assert out.pages[0].blocks[0].lines[0].words[0].value == "Mr."
+
+
+@pytest.mark.parametrize(
+    "assume_straight_pages, straighten_pages",
+    [
+        [True, False],
+        [False, False],
+        [True, True],
+    ],
+)
+def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pages):
+    det_bsize = 4
+    det_predictor = DetectionPredictor(
+        PreProcessor(output_size=(512, 512), batch_size=det_bsize),
+        detection.db_mobilenet_v3_large(
+            pretrained=True,
+            pretrained_backbone=False,
+            input_shape=(512, 512, 3),
+            assume_straight_pages=assume_straight_pages,
+        ),
+    )
+
+    reco_bsize = 16
+    reco_predictor = RecognitionPredictor(
+        PreProcessor(output_size=(32, 128), batch_size=reco_bsize, preserve_aspect_ratio=True),
+        recognition.crnn_vgg16_bn(pretrained=False, pretrained_backbone=False, vocab=mock_vocab),
+    )
+
+    doc = DocumentFile.from_pdf(mock_pdf)
+
+    predictor = KIEPredictor(
+        det_predictor,
+        reco_predictor,
+        assume_straight_pages=assume_straight_pages,
+        straighten_pages=straighten_pages,
+        detect_orientation=True,
+        detect_language=True,
+    )
+
+    if assume_straight_pages:
+        assert predictor.crop_orientation_predictor is None
+    else:
+        assert isinstance(predictor.crop_orientation_predictor, NestedObject)
+
+    out = predictor(doc)
+    assert isinstance(out, KIEDocument)
+    assert len(out.pages) == 2
+    # Dimension check
+    with pytest.raises(ValueError):
+        input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8)
+        _ = predictor([input_page])
+
+    orientation = 0
+    assert out.pages[0].orientation["value"] == orientation
+    language = "unknown"
+    assert out.pages[0].language["value"] == language
+
+
+def test_trained_kie_predictor(mock_tilted_payslip):
+    doc = DocumentFile.from_images(mock_tilted_payslip)
+
+    det_predictor = detection_predictor("db_resnet50", pretrained=True, batch_size=2, assume_straight_pages=True)
+    reco_predictor = recognition_predictor("crnn_vgg16_bn", pretrained=True, batch_size=128)
+
+    predictor = KIEPredictor(
+        det_predictor,
+        reco_predictor,
+        assume_straight_pages=True,
+        straighten_pages=True,
+    )
+
+    out = predictor(doc)
+
+    assert isinstance(out, KIEDocument)
+    assert out.pages[0].predictions[CLASS_NAME][0].value == "Mr."
+    geometry_mr = np.array(
+        [[0.08844472, 0.35763523], [0.11625107, 0.34320644], [0.12588427, 0.35771032], [0.09807791, 0.37213911]]
+    )
+    assert np.allclose(np.array(out.pages[0].predictions[CLASS_NAME][0].geometry), geometry_mr)
+
+    assert out.pages[0].predictions[CLASS_NAME][-1].value == "Kabir)"
+    geometry_revised = np.array(
+        [[0.43725992, 0.67232439], [0.49045468, 0.64472149], [0.50570724, 0.66768597], [0.452512473, 0.69528887]]
+    )
+    assert np.allclose(np.array(out.pages[0].predictions[CLASS_NAME][-1].geometry), geometry_revised)
+
+    det_predictor = detection_predictor(
+        "db_resnet50",
+        pretrained=True,
+        batch_size=2,
+        assume_straight_pages=True,
+        preserve_aspect_ratio=True,
+        symmetric_pad=True,
+    )
+
+    predictor = KIEPredictor(
+        det_predictor,
+        reco_predictor,
+        assume_straight_pages=True,
+        straighten_pages=True,
+        preserve_aspect_ratio=True,
+        symmetric_pad=True,
+    )
+
+    out = predictor(doc)
+
+    assert isinstance(out, KIEDocument)
+    assert out.pages[0].predictions[CLASS_NAME][0].value == "Mr."
+
+
 def _test_predictor(predictor):
     # Output checks
     assert isinstance(predictor, OCRPredictor)
@@ -136,6 +247,23 @@ def _test_predictor(predictor):
     _ = predictor([input_page])
+
+
+def _test_kiepredictor(predictor):
+    # Output checks
+    assert isinstance(predictor, KIEPredictor)
+
+    doc = [np.zeros((512, 512, 3), dtype=np.uint8)]
+    out = predictor(doc)
+    # Document
+    assert isinstance(out, KIEDocument)
+
+    # The input doc has 1 page
+    assert len(out.pages) == 1
+    # Dimension check
+    with pytest.raises(ValueError):
+        input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8)
+        _ = predictor([input_page])
+
+
 @pytest.mark.parametrize(
     "det_arch, reco_arch",
     [
@@ -160,3 +288,21 @@ def test_zoo_models(det_arch, reco_arch):
     # passing detection model as recognition model
     with pytest.raises(ValueError):
         models.ocr_predictor(reco_arch=det_model, pretrained=True)
+
+    # KIE predictor
+    predictor = models.kie_predictor(det_arch, reco_arch, pretrained=True)
+    _test_kiepredictor(predictor)
+
+    # passing model instance directly
+    det_model = detection.__dict__[det_arch](pretrained=True)
+    reco_model = recognition.__dict__[reco_arch](pretrained=True)
+    predictor = models.kie_predictor(det_model, reco_model)
+    _test_kiepredictor(predictor)
+
+    # passing recognition model as detection model
+    with pytest.raises(ValueError):
+        models.kie_predictor(det_arch=reco_model, pretrained=True)
+
+    # passing detection model as recognition model
+    with pytest.raises(ValueError):
+        models.kie_predictor(reco_arch=det_model, pretrained=True)