From 92db3977464cd457a52c0cdb1d9a3c446debe773 Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 2 Feb 2024 09:19:50 +0100 Subject: [PATCH 1/5] prototype --- doctr/models/kie_predictor/pytorch.py | 9 ++++++++- doctr/models/kie_predictor/tensorflow.py | 9 ++++++++- doctr/models/predictor/pytorch.py | 9 ++++++++- doctr/models/predictor/tensorflow.py | 9 ++++++++- doctr/models/zoo.py | 12 +++++++++++- tests/pytorch/test_models_zoo_pt.py | 8 ++++++++ tests/tensorflow/test_models_zoo_tf.py | 8 ++++++++ 7 files changed, 59 insertions(+), 5 deletions(-) diff --git a/doctr/models/kie_predictor/pytorch.py b/doctr/models/kie_predictor/pytorch.py index 152b2acaca..ccacbd4e29 100644 --- a/doctr/models/kie_predictor/pytorch.py +++ b/doctr/models/kie_predictor/pytorch.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import Any, Dict, List, Union +from typing import Any, Callable, Dict, List, Union import numpy as np import torch @@ -36,6 +36,7 @@ class KIEPredictor(nn.Module, _KIEPredictor): page. Doing so will slightly deteriorate the overall latency. detect_language: if True, the language prediction will be added to the predictions for each page. Doing so will slightly deteriorate the overall latency. 
+ callbacks: list of callbacks to be applied to the OCR pipelines `loc_preds` **kwargs: keyword args of `DocumentBuilder` """ @@ -49,6 +50,7 @@ def __init__( symmetric_pad: bool = True, detect_orientation: bool = False, detect_language: bool = False, + callbacks: List[Callable] = [], **kwargs: Any, ) -> None: nn.Module.__init__(self) @@ -59,6 +61,7 @@ def __init__( ) self.detect_orientation = detect_orientation self.detect_language = detect_language + self.callbacks = callbacks @torch.inference_mode() def forward( @@ -106,6 +109,10 @@ def forward( # Rectify crops if aspect ratio dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()} + # Apply callbacks to loc_preds if any + for callback in self.callbacks: + dict_loc_preds = callback(dict_loc_preds) + # Crop images crops = {} for class_name in dict_loc_preds.keys(): diff --git a/doctr/models/kie_predictor/tensorflow.py b/doctr/models/kie_predictor/tensorflow.py index 21ce9953ee..161eee52f3 100644 --- a/doctr/models/kie_predictor/tensorflow.py +++ b/doctr/models/kie_predictor/tensorflow.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import Any, Dict, List, Union +from typing import Any, Callable, Dict, List, Union import numpy as np import tensorflow as tf @@ -36,6 +36,7 @@ class KIEPredictor(NestedObject, _KIEPredictor): page. Doing so will slightly deteriorate the overall latency. detect_language: if True, the language prediction will be added to the predictions for each page. Doing so will slightly deteriorate the overall latency. 
+ callbacks: list of callbacks to be applied to the OCR pipelines `loc_preds` **kwargs: keyword args of `DocumentBuilder` """ @@ -51,6 +52,7 @@ def __init__( symmetric_pad: bool = True, detect_orientation: bool = False, detect_language: bool = False, + callbacks: List[Callable] = [], **kwargs: Any, ) -> None: self.det_predictor = det_predictor @@ -60,6 +62,7 @@ def __init__( ) self.detect_orientation = detect_orientation self.detect_language = detect_language + self.callbacks = callbacks def __call__( self, @@ -103,6 +106,10 @@ def __call__( # Rectify crops if aspect ratio dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()} + # Apply callbacks to loc_preds if any + for callback in self.callbacks: + dict_loc_preds = callback(dict_loc_preds) + # Crop images crops = {} for class_name in dict_loc_preds.keys(): diff --git a/doctr/models/predictor/pytorch.py b/doctr/models/predictor/pytorch.py index a1773d1e31..81b629b44b 100644 --- a/doctr/models/predictor/pytorch.py +++ b/doctr/models/predictor/pytorch.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import Any, List, Union +from typing import Any, Callable, List, Union import numpy as np import torch @@ -36,6 +36,7 @@ class OCRPredictor(nn.Module, _OCRPredictor): page. Doing so will slightly deteriorate the overall latency. detect_language: if True, the language prediction will be added to the predictions for each page. Doing so will slightly deteriorate the overall latency. 
+ callbacks: list of callbacks to be applied to the OCR pipelines `loc_preds` **kwargs: keyword args of `DocumentBuilder` """ @@ -49,6 +50,7 @@ def __init__( symmetric_pad: bool = True, detect_orientation: bool = False, detect_language: bool = False, + callbacks: List[Callable] = [], **kwargs: Any, ) -> None: nn.Module.__init__(self) @@ -59,6 +61,7 @@ def __init__( ) self.detect_orientation = detect_orientation self.detect_language = detect_language + self.callbacks = callbacks @torch.inference_mode() def forward( @@ -108,6 +111,10 @@ def forward( # Rectify crops if aspect ratio loc_preds = self._remove_padding(pages, loc_preds) + # Apply callbacks to loc_preds if any + for callback in self.callbacks: + loc_preds = callback(loc_preds) + # Crop images crops, loc_preds = self._prepare_crops( pages, diff --git a/doctr/models/predictor/tensorflow.py b/doctr/models/predictor/tensorflow.py index 5461d3020b..4d876d5cab 100644 --- a/doctr/models/predictor/tensorflow.py +++ b/doctr/models/predictor/tensorflow.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import Any, List, Union +from typing import Any, Callable, List, Union import numpy as np import tensorflow as tf @@ -36,6 +36,7 @@ class OCRPredictor(NestedObject, _OCRPredictor): page. Doing so will slightly deteriorate the overall latency. detect_language: if True, the language prediction will be added to the predictions for each page. Doing so will slightly deteriorate the overall latency. 
+ callbacks: list of callbacks to be applied to the OCR pipelines `loc_preds` **kwargs: keyword args of `DocumentBuilder` """ @@ -51,6 +52,7 @@ def __init__( symmetric_pad: bool = True, detect_orientation: bool = False, detect_language: bool = False, + callbacks: List[Callable] = [], **kwargs: Any, ) -> None: self.det_predictor = det_predictor @@ -60,6 +62,7 @@ def __init__( ) self.detect_orientation = detect_orientation self.detect_language = detect_language + self.callbacks = callbacks def __call__( self, @@ -105,6 +108,10 @@ def __call__( # Rectify crops if aspect ratio loc_preds = self._remove_padding(pages, loc_preds) + # Apply callbacks to loc_preds if any + for callback in self.callbacks: + loc_preds = callback(loc_preds) + # Crop images crops, loc_preds = self._prepare_crops( pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages diff --git a/doctr/models/zoo.py b/doctr/models/zoo.py index a351589037..5094d889d4 100644 --- a/doctr/models/zoo.py +++ b/doctr/models/zoo.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import Any +from typing import Any, Callable, List from .detection.zoo import detection_predictor from .kie_predictor import KIEPredictor @@ -26,6 +26,7 @@ def _predictor( detect_orientation: bool = False, straighten_pages: bool = False, detect_language: bool = False, + callbacks: List[Callable] = [], **kwargs, ) -> OCRPredictor: # Detection @@ -56,6 +57,7 @@ def _predictor( detect_orientation=detect_orientation, straighten_pages=straighten_pages, detect_language=detect_language, + callbacks=callbacks, **kwargs, ) @@ -72,6 +74,7 @@ def ocr_predictor( detect_orientation: bool = False, straighten_pages: bool = False, detect_language: bool = False, + callbacks: List[Callable] = [], **kwargs: Any, ) -> OCRPredictor: """End-to-end OCR architecture using one model for localization, and another for text recognition. 
@@ -105,6 +108,7 @@ def ocr_predictor( Doing so will improve performances for documents with page-uniform rotations. detect_language: if True, the language prediction will be added to the predictions for each page. Doing so will slightly deteriorate the overall latency. + callbacks: list of callbacks to be applied to the OCR pipelines `loc_preds` kwargs: keyword args of `OCRPredictor` Returns: @@ -123,6 +127,7 @@ def ocr_predictor( detect_orientation=detect_orientation, straighten_pages=straighten_pages, detect_language=detect_language, + callbacks=callbacks, **kwargs, ) @@ -140,6 +145,7 @@ def _kie_predictor( detect_orientation: bool = False, straighten_pages: bool = False, detect_language: bool = False, + callbacks: List[Callable] = [], **kwargs, ) -> KIEPredictor: # Detection @@ -170,6 +176,7 @@ def _kie_predictor( detect_orientation=detect_orientation, straighten_pages=straighten_pages, detect_language=detect_language, + callbacks=callbacks, **kwargs, ) @@ -186,6 +193,7 @@ def kie_predictor( detect_orientation: bool = False, straighten_pages: bool = False, detect_language: bool = False, + callbacks: List[Callable] = [], **kwargs: Any, ) -> KIEPredictor: """End-to-end KIE architecture using one model for localization, and another for text recognition. @@ -219,6 +227,7 @@ def kie_predictor( Doing so will improve performances for documents with page-uniform rotations. detect_language: if True, the language prediction will be added to the predictions for each page. Doing so will slightly deteriorate the overall latency. 
+ callbacks: list of callbacks to be applied to the OCR pipelines `loc_preds` kwargs: keyword args of `OCRPredictor` Returns: @@ -237,5 +246,6 @@ def kie_predictor( detect_orientation=detect_orientation, straighten_pages=straighten_pages, detect_language=detect_language, + callbacks=callbacks, **kwargs, ) diff --git a/tests/pytorch/test_models_zoo_pt.py b/tests/pytorch/test_models_zoo_pt.py index fa3f23b9d1..88b0b0f4d1 100644 --- a/tests/pytorch/test_models_zoo_pt.py +++ b/tests/pytorch/test_models_zoo_pt.py @@ -16,6 +16,12 @@ from doctr.models.recognition.zoo import recognition_predictor +# Create a dummy callback +class _DummyCallback: + def __call__(self, loc_preds): + return loc_preds + + @pytest.mark.parametrize( "assume_straight_pages, straighten_pages", [ @@ -120,6 +126,7 @@ def test_trained_ocr_predictor(mock_payslip): straighten_pages=True, preserve_aspect_ratio=True, symmetric_pad=True, + callbacks=[_DummyCallback()], ) out = predictor(doc) @@ -203,6 +210,7 @@ def test_trained_kie_predictor(mock_payslip): assume_straight_pages=True, straighten_pages=True, preserve_aspect_ratio=False, + callbacks=[_DummyCallback()], ) out = predictor(doc) diff --git a/tests/tensorflow/test_models_zoo_tf.py b/tests/tensorflow/test_models_zoo_tf.py index 32e7988560..db4295346e 100644 --- a/tests/tensorflow/test_models_zoo_tf.py +++ b/tests/tensorflow/test_models_zoo_tf.py @@ -16,6 +16,12 @@ from doctr.utils.repr import NestedObject +# Create a dummy callback +class _DummyCallback: + def __call__(self, loc_preds): + return loc_preds + + @pytest.mark.parametrize( "assume_straight_pages, straighten_pages", [ @@ -91,6 +97,7 @@ def test_trained_ocr_predictor(mock_payslip): assume_straight_pages=True, straighten_pages=True, preserve_aspect_ratio=False, + callbacks=[_DummyCallback()], ) out = predictor(doc) @@ -201,6 +208,7 @@ def test_trained_kie_predictor(mock_payslip): assume_straight_pages=True, straighten_pages=True, preserve_aspect_ratio=False, + callbacks=[_DummyCallback()], 
) out = predictor(doc) From 3c2ba002e882c9beeda9212c12e6d035950e7c29 Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 5 Feb 2024 09:09:55 +0100 Subject: [PATCH 2/5] update --- demo/app.py | 5 ++- demo/backend/pytorch.py | 3 ++ demo/backend/tensorflow.py | 3 ++ docs/source/using_doctr/using_models.rst | 45 +++++++++++++++++++ .../differentiable_binarization/pytorch.py | 7 ++- .../differentiable_binarization/tensorflow.py | 7 ++- doctr/models/detection/linknet/pytorch.py | 5 ++- doctr/models/kie_predictor/pytorch.py | 11 ++--- doctr/models/kie_predictor/tensorflow.py | 11 ++--- doctr/models/predictor/base.py | 6 ++- doctr/models/predictor/pytorch.py | 11 ++--- doctr/models/predictor/tensorflow.py | 11 ++--- doctr/models/preprocessor/pytorch.py | 1 - doctr/models/preprocessor/tensorflow.py | 1 - doctr/models/zoo.py | 12 +---- scripts/detect_text.py | 2 + tests/pytorch/test_models_zoo_pt.py | 6 ++- tests/tensorflow/test_models_zoo_tf.py | 6 ++- 18 files changed, 103 insertions(+), 50 deletions(-) diff --git a/demo/app.py b/demo/app.py index 4f156d87e3..bd624ea382 100644 --- a/demo/app.py +++ b/demo/app.py @@ -80,6 +80,9 @@ def main(det_archs, reco_archs): # Binarization threshold bin_thresh = st.sidebar.slider("Binarization threshold", min_value=0.1, max_value=0.9, value=0.3, step=0.1) st.sidebar.write("\n") + # Box threshold + box_thresh = st.sidebar.slider("Box threshold", min_value=0.1, max_value=0.9, value=0.1, step=0.1) + st.sidebar.write("\n") if st.sidebar.button("Analyze page"): if uploaded_file is None: @@ -88,7 +91,7 @@ def main(det_archs, reco_archs): else: with st.spinner("Loading model..."): predictor = load_predictor( - det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, forward_device + det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, box_thresh, forward_device ) with st.spinner("Analyzing..."): diff --git a/demo/backend/pytorch.py b/demo/backend/pytorch.py index f15f1e6fcf..f9b8f443f6 100644 --- 
a/demo/backend/pytorch.py +++ b/demo/backend/pytorch.py @@ -35,6 +35,7 @@ def load_predictor( assume_straight_pages: bool, straighten_pages: bool, bin_thresh: float, + box_thresh: float, device: torch.device, ) -> OCRPredictor: """Load a predictor from doctr.models @@ -46,6 +47,7 @@ def load_predictor( assume_straight_pages: whether to assume straight pages or not straighten_pages: whether to straighten rotated pages or not bin_thresh: binarization threshold for the segmentation map + box_thresh: minimal objectness score to consider a box device: torch.device, the device to load the predictor on Returns: @@ -62,6 +64,7 @@ def load_predictor( detect_orientation=not assume_straight_pages, ).to(device) predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh + predictor.det_predictor.model.postprocessor.box_thresh = box_thresh return predictor diff --git a/demo/backend/tensorflow.py b/demo/backend/tensorflow.py index 08894b688e..3676af83cf 100644 --- a/demo/backend/tensorflow.py +++ b/demo/backend/tensorflow.py @@ -34,6 +34,7 @@ def load_predictor( assume_straight_pages: bool, straighten_pages: bool, bin_thresh: float, + box_thresh: float, device: tf.device, ) -> OCRPredictor: """Load a predictor from doctr.models @@ -45,6 +46,7 @@ def load_predictor( assume_straight_pages: whether to assume straight pages or not straighten_pages: whether to straighten rotated pages or not bin_thresh: binarization threshold for the segmentation map + box_thresh: threshold for the detection boxes device: tf.device, the device to load the predictor on Returns: @@ -62,6 +64,7 @@ def load_predictor( detect_orientation=not assume_straight_pages, ) predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh + predictor.det_predictor.model.postprocessor.box_thresh = box_thresh return predictor diff --git a/docs/source/using_doctr/using_models.rst b/docs/source/using_doctr/using_models.rst index dba1965ba2..1d50dcdbc1 100644 --- a/docs/source/using_doctr/using_models.rst 
+++ b/docs/source/using_doctr/using_models.rst @@ -398,3 +398,48 @@ For reference, here is a sample XML byte string output: + + +Advanced options +^^^^^^^^^^^^^^^^ +We provide a few advanced options to customize the behavior of the predictor to your needs: + +* Modify the binarization threshold for the detection model. +* Modify the box threshold for the detection model. + +This is useful to detect (possibly fewer) text regions more accurately with a higher threshold, or to detect more text regions with a lower threshold. + + +.. code:: python3 + + import numpy as np + from doctr.models import ocr_predictor + predictor = ocr_predictor('db_resnet50', 'crnn_vgg16_bn', pretrained=True) + + # Modify the binarization threshold and the box threshold + predictor.det_predictor.model.postprocessor.bin_thresh = 0.5 + predictor.det_predictor.model.postprocessor.box_thresh = 0.2 + + input_page = (255 * np.random.rand(800, 600, 3)).astype(np.uint8) + out = predictor([input_page]) + + +* Add a hook to the `ocr_predictor` to manipulate the location predictions before the crops are passed to the recognition model. + +.. 
code:: python3 + + from doctr.models import ocr_predictor + + class CustomHook: + def __call__(self, loc_preds): + # Manipulate the location predictions here + # The output structure needs to be the same as the input location predictions + return loc_preds + + my_hook = CustomHook() + + predictor = ocr_predictor(pretrained=True) + # Add a hook in the middle of the pipeline + predictor.add_hook(my_hook) + # You can also add multiple hooks which will be executed sequentially + [predictor.add_hook(hook) for hook in [my_hook, my_hook, my_hook]] diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index f11408bd3d..cfc8267a3d 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -100,6 +100,8 @@ class DBNet(_DBNet, nn.Module): feature extractor: the backbone serving as feature extractor head_chans: the number of channels in the head deform_conv: whether to use deformable convolution + bin_thresh: threshold for binarization + box_thresh: minimal objectness score to consider a box assume_straight_pages: if True, fit straight bounding boxes only exportable: onnx exportable returns only logits cfg: the configuration dict of the model @@ -112,6 +114,7 @@ def __init__( head_chans: int = 256, deform_conv: bool = False, bin_thresh: float = 0.3, + box_thresh: float = 0.1, assume_straight_pages: bool = True, exportable: bool = False, cfg: Optional[Dict[str, Any]] = None, @@ -160,7 +163,9 @@ def __init__( nn.ConvTranspose2d(head_chans // 4, num_classes, 2, stride=2), ) - self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh) + self.postprocessor = DBPostProcessor( + assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh + ) for n, m in self.named_modules(): # Don't override the initialization of the backbone diff --git 
a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index 38835ead79..f782bf2078 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -112,6 +112,8 @@ class DBNet(_DBNet, keras.Model, NestedObject): ---- feature extractor: the backbone serving as feature extractor fpn_channels: number of channels each extracted feature maps is mapped to + bin_tresh: threshold for binarization + box_thresh: minimal objectness score to consider a box assume_straight_pages: if True, fit straight bounding boxes only exportable: onnx exportable returns only logits cfg: the configuration dict of the model @@ -125,6 +127,7 @@ def __init__( feature_extractor: IntermediateLayerGetter, fpn_channels: int = 128, # to be set to 256 to represent the author's initial idea bin_thresh: float = 0.3, + box_thresh: float = 0.1, assume_straight_pages: bool = True, exportable: bool = False, cfg: Optional[Dict[str, Any]] = None, @@ -159,7 +162,9 @@ def __init__( layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"), ]) - self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh) + self.postprocessor = DBPostProcessor( + assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh + ) def compute_loss( self, diff --git a/doctr/models/detection/linknet/pytorch.py b/doctr/models/detection/linknet/pytorch.py index ecf72beda6..537fd57256 100644 --- a/doctr/models/detection/linknet/pytorch.py +++ b/doctr/models/detection/linknet/pytorch.py @@ -91,6 +91,8 @@ class LinkNet(nn.Module, _LinkNet): Args: ---- feature extractor: the backbone serving as feature extractor + bin_thresh: threshold for binarization of the output feature map + box_thresh: minimal objectness score to consider a box head_chans: number of channels in the head 
layers assume_straight_pages: if True, fit straight bounding boxes only exportable: onnx exportable returns only logits @@ -102,6 +104,7 @@ def __init__( self, feat_extractor: IntermediateLayerGetter, bin_thresh: float = 0.1, + box_thresh: float = 0.1, head_chans: int = 32, assume_straight_pages: bool = True, exportable: bool = False, @@ -142,7 +145,7 @@ def __init__( ) self.postprocessor = LinkNetPostProcessor( - assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh + assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh ) for n, m in self.named_modules(): diff --git a/doctr/models/kie_predictor/pytorch.py b/doctr/models/kie_predictor/pytorch.py index ccacbd4e29..81f342931d 100644 --- a/doctr/models/kie_predictor/pytorch.py +++ b/doctr/models/kie_predictor/pytorch.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import Any, Callable, Dict, List, Union +from typing import Any, Dict, List, Union import numpy as np import torch @@ -36,7 +36,6 @@ class KIEPredictor(nn.Module, _KIEPredictor): page. Doing so will slightly deteriorate the overall latency. detect_language: if True, the language prediction will be added to the predictions for each page. Doing so will slightly deteriorate the overall latency. 
- callbacks: list of callbacks to be applied to the OCR pipelines `loc_preds` **kwargs: keyword args of `DocumentBuilder` """ @@ -50,7 +49,6 @@ def __init__( symmetric_pad: bool = True, detect_orientation: bool = False, detect_language: bool = False, - callbacks: List[Callable] = [], **kwargs: Any, ) -> None: nn.Module.__init__(self) @@ -61,7 +59,6 @@ def __init__( ) self.detect_orientation = detect_orientation self.detect_language = detect_language - self.callbacks = callbacks @torch.inference_mode() def forward( @@ -109,9 +106,9 @@ def forward( # Rectify crops if aspect ratio dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()} - # Apply callbacks to loc_preds if any - for callback in self.callbacks: - dict_loc_preds = callback(dict_loc_preds) + # Apply hooks to loc_preds if any + for hook in self.hooks: + dict_loc_preds = hook(dict_loc_preds) # Crop images crops = {} diff --git a/doctr/models/kie_predictor/tensorflow.py b/doctr/models/kie_predictor/tensorflow.py index 161eee52f3..e266c0c0c5 100644 --- a/doctr/models/kie_predictor/tensorflow.py +++ b/doctr/models/kie_predictor/tensorflow.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import Any, Callable, Dict, List, Union +from typing import Any, Dict, List, Union import numpy as np import tensorflow as tf @@ -36,7 +36,6 @@ class KIEPredictor(NestedObject, _KIEPredictor): page. Doing so will slightly deteriorate the overall latency. detect_language: if True, the language prediction will be added to the predictions for each page. Doing so will slightly deteriorate the overall latency. 
- callbacks: list of callbacks to be applied to the OCR pipelines `loc_preds` **kwargs: keyword args of `DocumentBuilder` """ @@ -52,7 +51,6 @@ def __init__( symmetric_pad: bool = True, detect_orientation: bool = False, detect_language: bool = False, - callbacks: List[Callable] = [], **kwargs: Any, ) -> None: self.det_predictor = det_predictor @@ -62,7 +60,6 @@ def __init__( ) self.detect_orientation = detect_orientation self.detect_language = detect_language - self.callbacks = callbacks def __call__( self, @@ -106,9 +103,9 @@ def __call__( # Rectify crops if aspect ratio dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()} - # Apply callbacks to loc_preds if any - for callback in self.callbacks: - dict_loc_preds = callback(dict_loc_preds) + # Apply hooks to loc_preds if any + for hook in self.hooks: + dict_loc_preds = hook(dict_loc_preds) # Crop images crops = {} diff --git a/doctr/models/predictor/base.py b/doctr/models/predictor/base.py index cf844f2053..5e136d0fc6 100644 --- a/doctr/models/predictor/base.py +++ b/doctr/models/predictor/base.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. 
-from typing import Any, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple import numpy as np @@ -48,6 +48,7 @@ def __init__( self.doc_builder = DocumentBuilder(**kwargs) self.preserve_aspect_ratio = preserve_aspect_ratio self.symmetric_pad = symmetric_pad + self.hooks = [] @staticmethod def _generate_crops( @@ -149,3 +150,6 @@ def _process_predictions( _idx += page_boxes.shape[0] return loc_preds, text_preds + + def add_hook(self, hook: Callable) -> None: + self.hooks.append(hook) diff --git a/doctr/models/predictor/pytorch.py b/doctr/models/predictor/pytorch.py index 81b629b44b..be22024116 100644 --- a/doctr/models/predictor/pytorch.py +++ b/doctr/models/predictor/pytorch.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import Any, Callable, List, Union +from typing import Any, List, Union import numpy as np import torch @@ -36,7 +36,6 @@ class OCRPredictor(nn.Module, _OCRPredictor): page. Doing so will slightly deteriorate the overall latency. detect_language: if True, the language prediction will be added to the predictions for each page. Doing so will slightly deteriorate the overall latency. 
- callbacks: list of callbacks to be applied to the OCR pipelines `loc_preds` **kwargs: keyword args of `DocumentBuilder` """ @@ -50,7 +49,6 @@ def __init__( symmetric_pad: bool = True, detect_orientation: bool = False, detect_language: bool = False, - callbacks: List[Callable] = [], **kwargs: Any, ) -> None: nn.Module.__init__(self) @@ -61,7 +59,6 @@ def __init__( ) self.detect_orientation = detect_orientation self.detect_language = detect_language - self.callbacks = callbacks @torch.inference_mode() def forward( @@ -111,9 +108,9 @@ def forward( # Rectify crops if aspect ratio loc_preds = self._remove_padding(pages, loc_preds) - # Apply callbacks to loc_preds if any - for callback in self.callbacks: - loc_preds = callback(loc_preds) + # Apply hooks to loc_preds if any + for hook in self.hooks: + loc_preds = hook(loc_preds) # Crop images crops, loc_preds = self._prepare_crops( diff --git a/doctr/models/predictor/tensorflow.py b/doctr/models/predictor/tensorflow.py index 4d876d5cab..424615656a 100644 --- a/doctr/models/predictor/tensorflow.py +++ b/doctr/models/predictor/tensorflow.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import Any, Callable, List, Union +from typing import Any, List, Union import numpy as np import tensorflow as tf @@ -36,7 +36,6 @@ class OCRPredictor(NestedObject, _OCRPredictor): page. Doing so will slightly deteriorate the overall latency. detect_language: if True, the language prediction will be added to the predictions for each page. Doing so will slightly deteriorate the overall latency. 
- callbacks: list of callbacks to be applied to the OCR pipelines `loc_preds` **kwargs: keyword args of `DocumentBuilder` """ @@ -52,7 +51,6 @@ def __init__( symmetric_pad: bool = True, detect_orientation: bool = False, detect_language: bool = False, - callbacks: List[Callable] = [], **kwargs: Any, ) -> None: self.det_predictor = det_predictor @@ -62,7 +60,6 @@ def __init__( ) self.detect_orientation = detect_orientation self.detect_language = detect_language - self.callbacks = callbacks def __call__( self, @@ -108,9 +105,9 @@ def __call__( # Rectify crops if aspect ratio loc_preds = self._remove_padding(pages, loc_preds) - # Apply callbacks to loc_preds if any - for callback in self.callbacks: - loc_preds = callback(loc_preds) + # Apply hooks to loc_preds if any + for hook in self.hooks: + loc_preds = hook(loc_preds) # Crop images crops, loc_preds = self._prepare_crops( diff --git a/doctr/models/preprocessor/pytorch.py b/doctr/models/preprocessor/pytorch.py index 294a06a994..b155425f49 100644 --- a/doctr/models/preprocessor/pytorch.py +++ b/doctr/models/preprocessor/pytorch.py @@ -35,7 +35,6 @@ def __init__( batch_size: int, mean: Tuple[float, float, float] = (0.5, 0.5, 0.5), std: Tuple[float, float, float] = (1.0, 1.0, 1.0), - fp16: bool = False, **kwargs: Any, ) -> None: super().__init__() diff --git a/doctr/models/preprocessor/tensorflow.py b/doctr/models/preprocessor/tensorflow.py index f876ccce3d..431f95b11f 100644 --- a/doctr/models/preprocessor/tensorflow.py +++ b/doctr/models/preprocessor/tensorflow.py @@ -35,7 +35,6 @@ def __init__( batch_size: int, mean: Tuple[float, float, float] = (0.5, 0.5, 0.5), std: Tuple[float, float, float] = (1.0, 1.0, 1.0), - fp16: bool = False, **kwargs: Any, ) -> None: self.batch_size = batch_size diff --git a/doctr/models/zoo.py b/doctr/models/zoo.py index 5094d889d4..a351589037 100644 --- a/doctr/models/zoo.py +++ b/doctr/models/zoo.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License 2.0. 
# See LICENSE or go to for full license details. -from typing import Any, Callable, List +from typing import Any from .detection.zoo import detection_predictor from .kie_predictor import KIEPredictor @@ -26,7 +26,6 @@ def _predictor( detect_orientation: bool = False, straighten_pages: bool = False, detect_language: bool = False, - callbacks: List[Callable] = [], **kwargs, ) -> OCRPredictor: # Detection @@ -57,7 +56,6 @@ def _predictor( detect_orientation=detect_orientation, straighten_pages=straighten_pages, detect_language=detect_language, - callbacks=callbacks, **kwargs, ) @@ -74,7 +72,6 @@ def ocr_predictor( detect_orientation: bool = False, straighten_pages: bool = False, detect_language: bool = False, - callbacks: List[Callable] = [], **kwargs: Any, ) -> OCRPredictor: """End-to-end OCR architecture using one model for localization, and another for text recognition. @@ -108,7 +105,6 @@ def ocr_predictor( Doing so will improve performances for documents with page-uniform rotations. detect_language: if True, the language prediction will be added to the predictions for each page. Doing so will slightly deteriorate the overall latency. 
- callbacks: list of callbacks to be applied to the OCR pipelines `loc_preds` kwargs: keyword args of `OCRPredictor` Returns: @@ -127,7 +123,6 @@ def ocr_predictor( detect_orientation=detect_orientation, straighten_pages=straighten_pages, detect_language=detect_language, - callbacks=callbacks, **kwargs, ) @@ -145,7 +140,6 @@ def _kie_predictor( detect_orientation: bool = False, straighten_pages: bool = False, detect_language: bool = False, - callbacks: List[Callable] = [], **kwargs, ) -> KIEPredictor: # Detection @@ -176,7 +170,6 @@ def _kie_predictor( detect_orientation=detect_orientation, straighten_pages=straighten_pages, detect_language=detect_language, - callbacks=callbacks, **kwargs, ) @@ -193,7 +186,6 @@ def kie_predictor( detect_orientation: bool = False, straighten_pages: bool = False, detect_language: bool = False, - callbacks: List[Callable] = [], **kwargs: Any, ) -> KIEPredictor: """End-to-end KIE architecture using one model for localization, and another for text recognition. @@ -227,7 +219,6 @@ def kie_predictor( Doing so will improve performances for documents with page-uniform rotations. detect_language: if True, the language prediction will be added to the predictions for each page. Doing so will slightly deteriorate the overall latency. 
- callbacks: list of callbacks to be applied to the OCR pipelines `loc_preds` kwargs: keyword args of `OCRPredictor` Returns: @@ -246,6 +237,5 @@ def kie_predictor( detect_orientation=detect_orientation, straighten_pages=straighten_pages, detect_language=detect_language, - callbacks=callbacks, **kwargs, ) diff --git a/scripts/detect_text.py b/scripts/detect_text.py index 3a5b18a903..573080c59b 100644 --- a/scripts/detect_text.py +++ b/scripts/detect_text.py @@ -62,6 +62,7 @@ def main(args): detection_model = detection.__dict__[args.detection]( pretrained=True, bin_thresh=args.bin_thresh, + box_thresh=args.box_thresh, ) model = ocr_predictor(detection_model, args.recognition, pretrained=True) path = Path(args.path) @@ -86,6 +87,7 @@ def parse_args(): parser.add_argument("path", type=str, help="Path to process: PDF, image, directory") parser.add_argument("--detection", type=str, default="db_resnet50", help="Text detection model to use for analysis") parser.add_argument("--bin-thresh", type=float, default=0.3, help="Binarization threshold for the detection model.") + parser.add_argument("--box-thresh", type=float, default=0.1, help="Threshold for the detection boxes.") parser.add_argument( "--recognition", type=str, default="crnn_vgg16_bn", help="Text recognition model to use for analysis" ) diff --git a/tests/pytorch/test_models_zoo_pt.py b/tests/pytorch/test_models_zoo_pt.py index 88b0b0f4d1..3c6267ab7c 100644 --- a/tests/pytorch/test_models_zoo_pt.py +++ b/tests/pytorch/test_models_zoo_pt.py @@ -126,8 +126,9 @@ def test_trained_ocr_predictor(mock_payslip): straighten_pages=True, preserve_aspect_ratio=True, symmetric_pad=True, - callbacks=[_DummyCallback()], ) + # test hooks + predictor.add_hook(_DummyCallback()) out = predictor(doc) @@ -210,8 +211,9 @@ def test_trained_kie_predictor(mock_payslip): assume_straight_pages=True, straighten_pages=True, preserve_aspect_ratio=False, - callbacks=[_DummyCallback()], ) + # test hooks + predictor.add_hook(_DummyCallback()) 
out = predictor(doc) diff --git a/tests/tensorflow/test_models_zoo_tf.py b/tests/tensorflow/test_models_zoo_tf.py index db4295346e..906d6d0f5d 100644 --- a/tests/tensorflow/test_models_zoo_tf.py +++ b/tests/tensorflow/test_models_zoo_tf.py @@ -97,8 +97,9 @@ def test_trained_ocr_predictor(mock_payslip): assume_straight_pages=True, straighten_pages=True, preserve_aspect_ratio=False, - callbacks=[_DummyCallback()], ) + # test hooks + predictor.add_hook(_DummyCallback()) out = predictor(doc) @@ -208,8 +209,9 @@ def test_trained_kie_predictor(mock_payslip): assume_straight_pages=True, straighten_pages=True, preserve_aspect_ratio=False, - callbacks=[_DummyCallback()], ) + # test hooks + predictor.add_hook(_DummyCallback()) out = predictor(doc) From 775c1d27535011ee535505a9c97bfedf306e67ff Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 5 Feb 2024 09:19:51 +0100 Subject: [PATCH 3/5] mypy --- doctr/models/predictor/base.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/doctr/models/predictor/base.py b/doctr/models/predictor/base.py index 5e136d0fc6..49d00893f7 100644 --- a/doctr/models/predictor/base.py +++ b/doctr/models/predictor/base.py @@ -48,7 +48,7 @@ def __init__( self.doc_builder = DocumentBuilder(**kwargs) self.preserve_aspect_ratio = preserve_aspect_ratio self.symmetric_pad = symmetric_pad - self.hooks = [] + self.hooks: List[Callable] = [] @staticmethod def _generate_crops( @@ -152,4 +152,10 @@ def _process_predictions( return loc_preds, text_preds def add_hook(self, hook: Callable) -> None: + """Add a hook to the predictor + + Args: + ---- + hook: a callable that takes as input the `loc_preds` and returns the modified `loc_preds` + """ self.hooks.append(hook) From 4445321266bf55c92d4574c463490699fbba8300 Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 5 Feb 2024 09:41:51 +0100 Subject: [PATCH 4/5] docstring + missing box thresh --- .../detection/differentiable_binarization/tensorflow.py | 2 +- 
doctr/models/detection/linknet/tensorflow.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index f782bf2078..64bda8bcdc 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -112,7 +112,7 @@ class DBNet(_DBNet, keras.Model, NestedObject): ---- feature extractor: the backbone serving as feature extractor fpn_channels: number of channels each extracted feature maps is mapped to - bin_tresh: threshold for binarization + bin_thresh: threshold for binarization box_thresh: minimal objectness score to consider a box assume_straight_pages: if True, fit straight bounding boxes only exportable: onnx exportable returns only logits diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py index d9858b559f..ff11dbe477 100644 --- a/doctr/models/detection/linknet/tensorflow.py +++ b/doctr/models/detection/linknet/tensorflow.py @@ -98,6 +98,8 @@ class LinkNet(_LinkNet, keras.Model): ---- feature extractor: the backbone serving as feature extractor fpn_channels: number of channels each extracted feature maps is mapped to + bin_thresh: threshold for binarization of the output feature map + box_thresh: minimal objectness score to consider a box assume_straight_pages: if True, fit straight bounding boxes only exportable: onnx exportable returns only logits cfg: the configuration dict of the model @@ -111,6 +113,7 @@ def __init__( feat_extractor: IntermediateLayerGetter, fpn_channels: int = 64, bin_thresh: float = 0.1, + box_thresh: float = 0.1, assume_straight_pages: bool = True, exportable: bool = False, cfg: Optional[Dict[str, Any]] = None, @@ -152,7 +155,9 @@ def __init__( ), ]) - self.postprocessor = LinkNetPostProcessor(assume_straight_pages=assume_straight_pages, 
bin_thresh=bin_thresh) + self.postprocessor = LinkNetPostProcessor( + assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh + ) def compute_loss( self, From 70cac27b304158731be93d99cc8cef32eb13836c Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 7 Feb 2024 15:37:17 +0100 Subject: [PATCH 5/5] update doc --- docs/source/using_doctr/using_models.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/source/using_doctr/using_models.rst b/docs/source/using_doctr/using_models.rst index 1d50dcdbc1..e906338f56 100644 --- a/docs/source/using_doctr/using_models.rst +++ b/docs/source/using_doctr/using_models.rst @@ -433,7 +433,8 @@ This is useful to detect (possible less) text regions more accurately with a hig class CustomHook: def __call__(self, loc_preds): # Manipulate the location predictions here - # The outpout structure needs to be the same as the input location predictions + # 1. The output structure needs to be the same as the input location predictions + # 2. Be aware that the coordinates are relative and need to be between 0 and 1 return loc_preds my_hook = CustomHook() @@ -442,4 +443,5 @@ This is useful to detect (possible less) text regions more accurately with a hig # Add a hook in the middle of the pipeline predictor.add_hook(my_hook) # You can also add multiple hooks which will be executed sequentially - [predictor.add_hook(hook) for hook in [my_hook, my_hook, my_hook]] + for hook in [my_hook, my_hook, my_hook]: + predictor.add_hook(hook)