Add handling of multiclass format in detection dataset loading #4

Open · wants to merge 41 commits into main
41 commits
24d7009  feat: add handling of multiclass format in detection dataset loading … (aminemindee, Aug 31, 2022)
b713128  feat: commit to rebase (aminemindee, Sep 7, 2022)
64dc994  fix: fix loss computation and make training work (aminemindee, Sep 8, 2022)
95c017e  feat: make loss computation vectorized and change target building to … (aminemindee, Sep 13, 2022)
f630280  add multi class integration in prediction pipeline (aminemindee, Sep 19, 2022)
1563f1e  feat: add multiclass to pytorch and fix tests (aminemindee, Sep 22, 2022)
80475ac  fix: fix all pr reviews (aminemindee, Oct 3, 2022)
fa887c2  fix: second review comments (aminemindee, Oct 5, 2022)
31165b5  fix: pr comments (aminemindee, Oct 10, 2022)
6cef57c  fix: fix CI evaluate script (aminemindee, Oct 11, 2022)
79a320a  feat: add doc about Pages changes and multilabel dataset for training (aminemindee, Oct 11, 2022)
863fd98  feat: fix api dockerfile and make it work with new changes (aminemindee, Oct 11, 2022)
fcee26a  fix reference tests (aminemindee, Oct 11, 2022)
cb19734  fix: fix api doc and dockerfile to create requirement txt inside and … (aminemindee, Oct 12, 2022)
133316f  fix api docker with doctr 0.6.0 (aminemindee, Oct 13, 2022)
75c1cd3  Up major version to 1.0.0 (aminemindee, Oct 14, 2022)
1cc2204  refactor: refactor invert dict list and list dict function into one s… (aminemindee, Oct 14, 2022)
3b22ebe  fix: style and mypy (aminemindee, Oct 14, 2022)
3747094  docs: make it more clear for new data format (aminemindee, Oct 18, 2022)
a96b83c  explain why python version was upped (aminemindee, Oct 18, 2022)
b48a243  add assert on length of tuple (aminemindee, Oct 18, 2022)
f5f4dce  add doc and simple name to invert data structure function (aminemindee, Oct 18, 2022)
47bcdbf  feat: add class names can be obtained from model config (aminemindee, Oct 18, 2022)
b8c1418  fix: prioritize class_names from dataset over model config (aminemindee, Oct 18, 2022)
b9810a0  fix: fix show samples in training (aminemindee, Oct 18, 2022)
829e266  fix: add check when target is dict and all values are numpy arrays (aminemindee, Oct 19, 2022)
223722d  fix: make detection target always dict and remove unnecessary made co… (aminemindee, Oct 19, 2022)
f4f4aed  fix: script detection evaluation tests and dataset tests with target … (aminemindee, Oct 19, 2022)
be40ef6  fix tests also on pytorch (aminemindee, Oct 19, 2022)
2d24441  feat: Add kie predictor and io elements and visualization that come w… (aminemindee, Oct 26, 2022)
20bfd41  fix: revert ocr predictor to old format (aminemindee, Oct 26, 2022)
b07abc0  fix tests and add test for kie predictor (aminemindee, Oct 26, 2022)
b74d7f2  up project version to 0.7.0 (aminemindee, Oct 26, 2022)
f3c5d65  update api to fix it and add kie route (aminemindee, Oct 26, 2022)
aac4347  fix api version (aminemindee, Oct 26, 2022)
d112c2d  feat: sort class names to always have the same order. (aminemindee, Oct 26, 2022)
d8991d9  sort imports to avoid cyclic imports (aminemindee, Oct 26, 2022)
2442a4f  fix class_names default, use of tf_is_available avoid and copyright d… (aminemindee, Oct 27, 2022)
7cb23dd  feat: update readme and doc with kie predictor (aminemindee, Nov 10, 2022)
b396f6d  feat: add loading backbone pretrained for multiclass detection, new e… (aminemindee, Dec 5, 2022)
73d10a0  fix mypy (aminemindee, Dec 5, 2022)
8 changes: 0 additions & 8 deletions .github/workflows/docker.yml
@@ -28,14 +28,6 @@ jobs:
with:
python-version: ${{ matrix.python }}
architecture: x64
-      - name: Install poetry
-        uses: abatilo/[email protected]
-        with:
-          poetry-version: 1.1.13
-      - name: Lock the requirements
-        run: |
-          cd api
-          make lock
- name: Build & run docker
run: cd api && docker-compose up -d --build
- name: Ping server
4 changes: 3 additions & 1 deletion .github/workflows/scripts.yml
@@ -140,7 +140,9 @@ jobs:
python -m pip install --upgrade pip
pip install -e .[torch] --upgrade
- name: Run evaluation script
-        run: python scripts/evaluate.py db_resnet50 crnn_vgg16_bn --samples 10
+        run: |
+          python scripts/evaluate.py db_resnet50 crnn_vgg16_bn --samples 10
+          python scripts/evaluate_kie.py db_resnet50 crnn_vgg16_bn --samples 10

test-collectenv:
runs-on: ${{ matrix.os }}
32 changes: 30 additions & 2 deletions README.md
@@ -99,6 +99,31 @@ You can also export them as a nested dict, more appropriate for JSON format:
json_output = result.export()
```

### Use the KIE predictor
The KIE predictor is a more flexible predictor than the OCR one: its detection model can detect multiple classes in a document. For example, you can have a detection model that detects just dates and addresses in a document.

The KIE predictor makes it possible to pair a multiclass detector with a recognition model, with the whole pipeline already set up for you.

```python
from doctr.io import DocumentFile
from doctr.models import kie_predictor

# Model
model = kie_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)
# PDF
doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
# Analyze
result = model(doc)

predictions = result.pages[0].predictions
for class_name, list_predictions in predictions.items():
    for prediction in list_predictions:
        print(f"Prediction for {class_name}: {prediction}")
```
The KIE predictor's results for each page are a dictionary mapping each class name to the list of predictions for that class.
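
Since each prediction exposes its text and geometry, the per-page output is easy to reshape. A minimal sketch, reusing the `result` from the example above (class names illustrative):

```python
# Collect the detected values per class into a plain dict
page = result.pages[0]
extracted = {class_name: [p.value for p in preds] for class_name, preds in page.predictions.items()}
print(extracted)  # e.g. {"date": ["2022-10-26"], "address": ["..."]}
```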


### If you are looking for support from the Mindee team
[![Bad OCR test detection image asking the developer if they need help](https://github.com/mindee/doctr/releases/download/v0.5.1/doctr-need-help.png)](https://mindee.com/product/doctr)

@@ -247,7 +272,10 @@ Looking to integrate docTR into your API? Here is a template to get you started
#### Deploy your API locally
Specific dependencies are required to run the API template, which you can install as follows:
```shell
-pip install -r api/requirements.txt
+cd api/
+pip install poetry
+make lock
+pip install -r requirements.txt
```
You can now run your API locally:

@@ -262,7 +290,7 @@ PORT=8002 docker-compose up -d --build

#### What you have deployed

-Your API should now be running locally on your port 8002. Access your automatically-built documentation at [http://localhost:8002/redoc](http://localhost:8002/redoc) and enjoy your three functional routes ("/detection", "/recognition", "/ocr"). Here is an example with Python to send a request to the OCR route:
+Your API should now be running locally on port 8002. Access your automatically-built documentation at [http://localhost:8002/redoc](http://localhost:8002/redoc) and enjoy your four functional routes ("/detection", "/recognition", "/ocr", "/kie"). Here is an example with Python to send a request to the OCR route:

```python
import requests
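# The full snippet is truncated in this view; below is a minimal sketch of the
# request, assuming the port and upload field used by the routes in this PR
with open("/path/to/your/doc.jpg", "rb") as f:
    data = f.read()
print(requests.post("http://localhost:8002/ocr", files={"file": data}).json())
```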
15 changes: 8 additions & 7 deletions api/Dockerfile
@@ -7,17 +7,18 @@ ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1
ENV PYTHONPATH "${PYTHONPATH}:/app"

-# copy requirements file
-COPY requirements.txt /app/requirements.txt
+RUN apt-get update \
+    && apt-get install --no-install-recommends ffmpeg libsm6 libxext6 make -y \
+    && apt-get autoremove -y \
+    && rm -rf /var/lib/apt/lists/*

COPY pyproject.toml /app/pyproject.toml
+COPY Makefile /app/Makefile

-RUN apt-get update \
-    && apt-get install --no-install-recommends ffmpeg libsm6 libxext6 -y \
-    && pip install --upgrade pip setuptools wheel \
+RUN pip install --upgrade pip setuptools wheel poetry \
+    && make lock \
    && pip install -r /app/requirements.txt \
+    && pip cache purge \
-    && apt-get autoremove -y \
-    && rm -rf /var/lib/apt/lists/* \
    && rm -rf /root/.cache/pip

# copy project
2 changes: 1 addition & 1 deletion api/Makefile
@@ -5,7 +5,7 @@
lock:
poetry lock
poetry export -f requirements.txt --without-hashes --output requirements.txt
-poetry export -f requirements.txt --without-hashes --dev --output requirements-dev.txt
+poetry export -f requirements.txt --without-hashes --with dev --output requirements-dev.txt

# Run the docker
run:
3 changes: 2 additions & 1 deletion api/app/main.py
@@ -9,7 +9,7 @@
from fastapi.openapi.utils import get_openapi

from app import config as cfg
-from app.routes import detection, ocr, recognition
+from app.routes import detection, kie, ocr, recognition

app = FastAPI(title=cfg.PROJECT_NAME, description=cfg.PROJECT_DESCRIPTION, debug=cfg.DEBUG, version=cfg.VERSION)

@@ -18,6 +18,7 @@
app.include_router(recognition.router, prefix="/recognition", tags=["recognition"])
app.include_router(detection.router, prefix="/detection", tags=["detection"])
app.include_router(ocr.router, prefix="/ocr", tags=["ocr"])
+app.include_router(kie.router, prefix="/kie", tags=["kie"])


# Middleware
3 changes: 2 additions & 1 deletion api/app/routes/detection.py
@@ -9,6 +9,7 @@

from app.schemas import DetectionOut
from app.vision import det_predictor
+from doctr.file_utils import CLASS_NAME
from doctr.io import decode_img_as_tensor

router = APIRouter()
@@ -19,4 +20,4 @@ async def text_detection(file: UploadFile = File(...)):
"""Runs docTR text detection model to analyze the input image"""
img = decode_img_as_tensor(file.file.read())
boxes = det_predictor([img])[0]
return [DetectionOut(box=box.tolist()) for box in boxes[:, :-1]]
return [DetectionOut(box=box.tolist()) for box in boxes[CLASS_NAME][:, :-1]]
29 changes: 29 additions & 0 deletions api/app/routes/kie.py
@@ -0,0 +1,29 @@
# Copyright (C) 2022, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Dict, List

from fastapi import APIRouter, File, UploadFile, status

from app.schemas import OCROut
from app.vision import kie_predictor
from doctr.io import decode_img_as_tensor

router = APIRouter()


@router.post("/", response_model=Dict[str, List[OCROut]], status_code=status.HTTP_200_OK, summary="Perform KIE")
async def perform_kie(file: UploadFile = File(...)):
    """Runs docTR KIE model to analyze the input image"""
    img = decode_img_as_tensor(file.file.read())
    out = kie_predictor([img])

    return {
        class_name: [
            OCROut(box=(*prediction.geometry[0], *prediction.geometry[1]), value=prediction.value)
            for prediction in out.pages[0].predictions[class_name]
        ]
        for class_name in out.pages[0].predictions.keys()
    }
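
Once the API is running, this route can be exercised with a few lines of Python. A minimal sketch, assuming the port from the README section above and a local image file:

```python
import requests

# POST an image to the KIE route; the response maps each class name to a
# list of {"box": [...], "value": "..."} objects (the OCROut schema above)
with open("sample.jpg", "rb") as f:
    response = requests.post("http://localhost:8002/kie", files={"file": f.read()})
print(response.json())
```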
4 changes: 3 additions & 1 deletion api/app/routes/ocr.py
@@ -22,5 +22,7 @@ async def perform_ocr(file: UploadFile = File(...)):

    return [
        OCROut(box=(*word.geometry[0], *word.geometry[1]), value=word.value)
-        for word in out.pages[0].blocks[0].lines[0].words
+        for block in out.pages[0].blocks
+        for line in block.lines
+        for word in line.words
    ]
3 changes: 2 additions & 1 deletion api/app/vision.py
@@ -9,8 +9,9 @@
if any(gpu_devices):
    tf.config.experimental.set_memory_growth(gpu_devices[0], True)

-from doctr.models import ocr_predictor
+from doctr.models import kie_predictor, ocr_predictor

predictor = ocr_predictor(pretrained=True)
det_predictor = predictor.det_predictor
reco_predictor = predictor.reco_predictor
+kie_predictor = kie_predictor(pretrained=True)
6 changes: 3 additions & 3 deletions api/pyproject.toml
@@ -4,16 +4,16 @@ build-backend = "poetry.masonry.api"

[tool.poetry]
name = "doctr-api"
version = "0.5.2a0"
version = "0.7.1a0"
description = "Backend template for your OCR API with docTR"
authors = ["Mindee <[email protected]>"]
license = "Apache-2.0"

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
python = ">=3.8.2,<3.11" # pypdfium2 needs a python version above 3.8.2
tensorflow = ">=2.9.0,<3.0.0"
tensorflow-addons = ">=0.17.1"
python-doctr = ">=0.2.0"
python-doctr = { version = ">=0.7.0", extras = ['tf'] }
# Fastapi: minimum version required to avoid pydantic error
# cf. https://github.com/tiangolo/fastapi/issues/4168
fastapi = ">=0.73.0"
29 changes: 29 additions & 0 deletions api/tests/routes/test_kie.py
@@ -0,0 +1,29 @@
import numpy as np
import pytest
from scipy.optimize import linear_sum_assignment

from doctr.utils.metrics import box_iou


@pytest.mark.asyncio
async def test_perform_kie(test_app_asyncio, mock_detection_image):

    response = await test_app_asyncio.post("/kie", files={"file": mock_detection_image})
    assert response.status_code == 200
    json_response = response.json()

    gt_boxes = np.array([[1240, 430, 1355, 470], [1360, 430, 1495, 470]], dtype=np.float32)
    gt_boxes[:, [0, 2]] = gt_boxes[:, [0, 2]] / 1654
    gt_boxes[:, [1, 3]] = gt_boxes[:, [1, 3]] / 2339
    gt_labels = ["Hello", "world!"]

    # Check that IoU with GT is reasonable
    assert isinstance(json_response, dict) and len(list(json_response.values())[0]) == gt_boxes.shape[0]
    pred_boxes = np.array([elt["box"] for json_out in json_response.values() for elt in json_out])
    pred_labels = np.array([elt["value"] for json_out in json_response.values() for elt in json_out])
    iou_mat = box_iou(gt_boxes, pred_boxes)
    gt_idxs, pred_idxs = linear_sum_assignment(-iou_mat)
    is_kept = iou_mat[gt_idxs, pred_idxs] >= 0.8
    gt_idxs, pred_idxs = gt_idxs[is_kept], pred_idxs[is_kept]
    assert gt_idxs.shape[0] == gt_boxes.shape[0]
    assert all(gt_labels[gt_idx] == pred_labels[pred_idx] for gt_idx, pred_idx in zip(gt_idxs, pred_idxs))
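
The test matches predictions to ground truth with the Hungarian algorithm on a negated IoU matrix. A tiny standalone illustration of that matching step (IoU values made up):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Illustrative IoU matrix: rows are GT boxes, columns are predicted boxes
iou_mat = np.array([[0.90, 0.10], [0.20, 0.85]])
gt_idxs, pred_idxs = linear_sum_assignment(-iou_mat)  # negate to maximize total IoU
print(gt_idxs, pred_idxs)  # [0 1] [0 1]: GT 0 matches pred 0, GT 1 matches pred 1
```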
2 changes: 2 additions & 0 deletions docs/source/modules/models.rst
@@ -81,6 +81,8 @@ doctr.models.zoo

.. autofunction:: doctr.models.ocr_predictor

+.. autofunction:: doctr.models.kie_predictor


doctr.models.factory
--------------------
2 changes: 1 addition & 1 deletion doctr/__init__.py
@@ -1,3 +1,3 @@
-from . import datasets, io, models, transforms, utils
+from . import io, datasets, models, transforms, utils
from .file_utils import is_tf_available, is_torch_available
from .version import __version__ # noqa: F401
11 changes: 10 additions & 1 deletion doctr/datasets/datasets/base.py
@@ -8,6 +8,9 @@
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

+import numpy as np

+from doctr.file_utils import copy_tensor
from doctr.io.image import get_img_shape
from doctr.utils.data import download_from_url

@@ -55,7 +58,13 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
            img = self.img_transforms(img)

        if self.sample_transforms is not None:
-            img, target = self.sample_transforms(img, target)
+            if isinstance(target, dict) and all([isinstance(item, np.ndarray) for item in target.values()]):
+                img_transformed = copy_tensor(img)
+                for class_name, bboxes in target.items():
+                    img_transformed, target[class_name] = self.sample_transforms(img, bboxes)
+                img = img_transformed
+            else:
+                img, target = self.sample_transforms(img, target)

        return img, target

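
For intuition, the multiclass branch above receives the detection target as a dict mapping each class name to its boxes. A minimal sketch of such a target (class names and coordinates illustrative):

```python
import numpy as np

# One box array per class; every value is an np.ndarray, which is exactly
# what the isinstance check above tests for before the per-class transforms
target = {
    "date": np.array([[0.10, 0.10, 0.40, 0.20]], dtype=np.float32),
    "address": np.array([[0.55, 0.50, 0.90, 0.60]], dtype=np.float32),
}
```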
6 changes: 6 additions & 0 deletions doctr/datasets/datasets/pytorch.py
@@ -25,6 +25,12 @@ def _read_sample(self, index: int) -> Tuple[torch.Tensor, Any]:
        if isinstance(target, dict):
            assert "boxes" in target, "Target should contain 'boxes' key"
            assert "labels" in target, "Target should contain 'labels' key"
+        elif isinstance(target, tuple):
+            assert len(target) == 2
+            assert isinstance(target[0], str) or isinstance(
+                target[0], np.ndarray
+            ), "first element of the tuple should be a string or a numpy array"
+            assert isinstance(target[1], list), "second element of the tuple should be a list"
        else:
            assert isinstance(target, str) or isinstance(
                target, np.ndarray
6 changes: 6 additions & 0 deletions doctr/datasets/datasets/tensorflow.py
@@ -25,6 +25,12 @@ def _read_sample(self, index: int) -> Tuple[tf.Tensor, Any]:
        if isinstance(target, dict):
            assert "boxes" in target, "Target should contain 'boxes' key"
            assert "labels" in target, "Target should contain 'labels' key"
+        elif isinstance(target, tuple):
+            assert len(target) == 2
+            assert isinstance(target[0], str) or isinstance(
+                target[0], np.ndarray
+            ), "first element of the tuple should be a string or a numpy array"
+            assert isinstance(target[1], list), "second element of the tuple should be a list"
        else:
            assert isinstance(target, str) or isinstance(
                target, np.ndarray
47 changes: 39 additions & 8 deletions doctr/datasets/detection.py
@@ -5,14 +5,14 @@

import json
import os
-from typing import Any, List, Tuple
+from typing import Any, Dict, List, Tuple, Type, Union

import numpy as np

-from doctr.io.image import get_img_shape
-from doctr.utils.geometry import convert_to_relative_coords
+from doctr.file_utils import CLASS_NAME

from .datasets import AbstractDataset
+from .utils import pre_transform_multiclass

__all__ = ["DetectionDataset"]

@@ -41,24 +41,55 @@ def __init__(
    ) -> None:
        super().__init__(
            img_folder,
-            pre_transforms=lambda img, boxes: (img, convert_to_relative_coords(boxes, get_img_shape(img))),
+            pre_transforms=pre_transform_multiclass,
            **kwargs,
        )

        # File existence check
+        self._class_names: List = []
        if not os.path.exists(label_path):
            raise FileNotFoundError(f"unable to locate {label_path}")
        with open(label_path, "rb") as f:
            labels = json.load(f)

-        self.data: List[Tuple[str, np.ndarray]] = []
+        self.data: List[Tuple[str, Tuple[np.ndarray, List[str]]]] = []
        np_dtype = np.float32
        for img_name, label in labels.items():
            # File existence check
            if not os.path.exists(os.path.join(self.root, img_name)):
                raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}")

-            polygons: np.ndarray = np.asarray(label["polygons"], dtype=np_dtype)
-            geoms = polygons if use_polygons else np.concatenate((polygons.min(axis=1), polygons.max(axis=1)), axis=1)
+            geoms, polygons_classes = self.format_polygons(label["polygons"], use_polygons, np_dtype)

-            self.data.append((img_name, np.asarray(geoms, dtype=np_dtype)))
+            self.data.append((img_name, (np.asarray(geoms, dtype=np_dtype), polygons_classes)))

+    def format_polygons(
+        self, polygons: Union[List, Dict], use_polygons: bool, np_dtype: Type
+    ) -> Tuple[np.ndarray, List[str]]:
+        """format polygons into an array
+
+        Args:
+            polygons: the bounding boxes
+            use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
+            np_dtype: dtype of array
+
+        Returns:
+            geoms: bounding boxes as np array
+            polygons_classes: list of classes for each bounding box
+        """
+        if isinstance(polygons, list):
+            self._class_names += [CLASS_NAME]
+            polygons_classes = [CLASS_NAME for _ in polygons]
+            _polygons: np.ndarray = np.asarray(polygons, dtype=np_dtype)
+        elif isinstance(polygons, dict):
+            self._class_names += list(polygons.keys())
+            polygons_classes = [k for k, v in polygons.items() for _ in v]
+            _polygons = np.concatenate([np.asarray(poly, dtype=np_dtype) for poly in polygons.values() if poly], axis=0)
+        else:
+            raise TypeError(f"polygons should be a dictionary or list, it was {type(polygons)}")
+        geoms = _polygons if use_polygons else np.concatenate((_polygons.min(axis=1), _polygons.max(axis=1)), axis=1)
+        return geoms, polygons_classes
+
+    @property
+    def class_names(self):
+        return sorted(list(set(self._class_names)))
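
For reference, `format_polygons` accepts both the legacy single-class label layout and the new multiclass one. A sketch of the two accepted `polygons` layouts in a labels JSON file, written as Python dicts (file names and coordinates illustrative):

```python
# Legacy single-class layout: polygons is a list of 4-point boxes,
# which format_polygons maps to the default CLASS_NAME
labels_single = {
    "img_1.jpg": {"polygons": [[[0.1, 0.1], [0.4, 0.1], [0.4, 0.2], [0.1, 0.2]]]},
}

# Multiclass layout: polygons maps class name -> list of 4-point boxes;
# class_names then exposes the sorted set of all classes seen
labels_multi = {
    "img_1.jpg": {
        "polygons": {
            "address": [[[0.5, 0.5], [0.9, 0.5], [0.9, 0.6], [0.5, 0.6]]],
            "date": [[[0.1, 0.1], [0.4, 0.1], [0.4, 0.2], [0.1, 0.2]]],
        }
    },
}
```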