diff --git a/demo/app.py b/demo/app.py index 4f156d87e3..bd624ea382 100644 --- a/demo/app.py +++ b/demo/app.py @@ -80,6 +80,9 @@ def main(det_archs, reco_archs): # Binarization threshold bin_thresh = st.sidebar.slider("Binarization threshold", min_value=0.1, max_value=0.9, value=0.3, step=0.1) st.sidebar.write("\n") + # Box threshold + box_thresh = st.sidebar.slider("Box threshold", min_value=0.1, max_value=0.9, value=0.1, step=0.1) + st.sidebar.write("\n") if st.sidebar.button("Analyze page"): if uploaded_file is None: @@ -88,7 +91,7 @@ def main(det_archs, reco_archs): else: with st.spinner("Loading model..."): predictor = load_predictor( - det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, forward_device + det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, box_thresh, forward_device ) with st.spinner("Analyzing..."): diff --git a/demo/backend/pytorch.py b/demo/backend/pytorch.py index f15f1e6fcf..f9b8f443f6 100644 --- a/demo/backend/pytorch.py +++ b/demo/backend/pytorch.py @@ -35,6 +35,7 @@ def load_predictor( assume_straight_pages: bool, straighten_pages: bool, bin_thresh: float, + box_thresh: float, device: torch.device, ) -> OCRPredictor: """Load a predictor from doctr.models @@ -46,6 +47,7 @@ def load_predictor( assume_straight_pages: whether to assume straight pages or not straighten_pages: whether to straighten rotated pages or not bin_thresh: binarization threshold for the segmentation map + box_thresh: minimal objectness score to consider a box device: torch.device, the device to load the predictor on Returns: @@ -62,6 +64,7 @@ def load_predictor( detect_orientation=not assume_straight_pages, ).to(device) predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh + predictor.det_predictor.model.postprocessor.box_thresh = box_thresh return predictor diff --git a/demo/backend/tensorflow.py b/demo/backend/tensorflow.py index 08894b688e..3676af83cf 100644 --- a/demo/backend/tensorflow.py +++ b/demo/backend/tensorflow.py @@ -34,6 +34,7 @@ def load_predictor( assume_straight_pages: bool, straighten_pages: bool, bin_thresh: float, + box_thresh: float, device: tf.device, ) -> OCRPredictor: """Load a predictor from doctr.models @@ -45,6 +46,7 @@ def load_predictor( assume_straight_pages: whether to assume straight pages or not straighten_pages: whether to straighten rotated pages or not bin_thresh: binarization threshold for the segmentation map + box_thresh: threshold for the detection boxes device: tf.device, the device to load the predictor on Returns: @@ -62,6 +64,7 @@ def load_predictor( detect_orientation=not assume_straight_pages, ) predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh + predictor.det_predictor.model.postprocessor.box_thresh = box_thresh return predictor diff --git a/docs/source/using_doctr/using_models.rst b/docs/source/using_doctr/using_models.rst index dba1965ba2..e906338f56 100644 --- a/docs/source/using_doctr/using_models.rst +++ b/docs/source/using_doctr/using_models.rst @@ -398,3 +398,50 @@ For reference, here is a sample XML byte string output: + + +Advanced options +^^^^^^^^^^^^^^^^ +We provide a few advanced options to customize the behavior of the predictor to your needs: + +* Modify the binarization threshold for the detection model. +* Modify the box threshold for the detection model. + +This is useful to detect (possible less) text regions more accurately with a higher threshold, or to detect more text regions with a lower threshold. + + +.. code:: python3 + + import numpy as np + from doctr.models import ocr_predictor + predictor = ocr_predictor('db_resnet50', 'crnn_vgg16_bn', pretrained=True) + + # Modify the binarization threshold and the box threshold + predictor.det_predictor.model.postprocessor.bin_thresh = 0.5 + predictor.det_predictor.model.postprocessor.box_thresh = 0.2 + + input_page = (255 * np.random.rand(800, 600, 3)).astype(np.uint8) + out = predictor([input_page]) + + +* Add a hook to the `ocr_predictor` to manipulate the location predictions before the crops are passed to the recognition model. + +.. code:: python3 + + from doctr.model import ocr_predictor + + class CustomHook: + def __call__(self, loc_preds): + # Manipulate the location predictions here + # 1. The outpout structure needs to be the same as the input location predictions + # 2. Be aware that the coordinates are relative and needs to be between 0 and 1 + return loc_preds + + my_hook = CustomHook() + + predictor = ocr_predictor(pretrained=True) + # Add a hook in the middle of the pipeline + predictor.add_hook(my_hook) + # You can also add multiple hooks which will be executed sequentially + for hook in [my_hook, my_hook, my_hook]: + predictor.add_hook(hook) diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index f11408bd3d..cfc8267a3d 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -100,6 +100,8 @@ class DBNet(_DBNet, nn.Module): feature extractor: the backbone serving as feature extractor head_chans: the number of channels in the head deform_conv: whether to use deformable convolution + bin_thresh: threshold for binarization + box_thresh: minimal objectness score to consider a box assume_straight_pages: if True, fit straight bounding boxes only exportable: onnx exportable returns only logits cfg: the configuration dict of the model @@ -112,6 +114,7 @@ def __init__( head_chans: int = 256, deform_conv: bool = False, bin_thresh: float = 0.3, + box_thresh: float = 0.1, assume_straight_pages: bool = True, exportable: bool = False, cfg: Optional[Dict[str, Any]] = None, @@ -160,7 +163,9 @@ def __init__( nn.ConvTranspose2d(head_chans // 4, num_classes, 2, stride=2), ) - self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh) + self.postprocessor = DBPostProcessor( + assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh + ) for n, m in self.named_modules(): # Don't override the initialization of the backbone diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index 38835ead79..64bda8bcdc 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -112,6 +112,8 @@ class DBNet(_DBNet, keras.Model, NestedObject): ---- feature extractor: the backbone serving as feature extractor fpn_channels: number of channels each extracted feature maps is mapped to + bin_thresh: threshold for binarization + box_thresh: minimal objectness score to consider a box assume_straight_pages: if True, fit straight bounding boxes only exportable: onnx exportable returns only logits cfg: the configuration dict of the model @@ -125,6 +127,7 @@ def __init__( feature_extractor: IntermediateLayerGetter, fpn_channels: int = 128, # to be set to 256 to represent the author's initial idea bin_thresh: float = 0.3, + box_thresh: float = 0.1, assume_straight_pages: bool = True, exportable: bool = False, cfg: Optional[Dict[str, Any]] = None, @@ -159,7 +162,9 @@ def __init__( layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"), ]) - self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh) + self.postprocessor = DBPostProcessor( + assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh + ) def compute_loss( self, diff --git a/doctr/models/detection/linknet/pytorch.py b/doctr/models/detection/linknet/pytorch.py index ecf72beda6..537fd57256 100644 --- a/doctr/models/detection/linknet/pytorch.py +++ b/doctr/models/detection/linknet/pytorch.py @@ -91,6 +91,8 @@ class LinkNet(nn.Module, _LinkNet): Args: ---- feature extractor: the backbone serving as feature extractor + bin_thresh: threshold for binarization of the output feature map + box_thresh: minimal objectness score to consider a box head_chans: number of channels in the head layers assume_straight_pages: if True, fit straight bounding boxes only exportable: onnx exportable returns only logits @@ -102,6 +104,7 @@ def __init__( self, feat_extractor: IntermediateLayerGetter, bin_thresh: float = 0.1, + box_thresh: float = 0.1, head_chans: int = 32, assume_straight_pages: bool = True, exportable: bool = False, @@ -142,7 +145,7 @@ def __init__( ) self.postprocessor = LinkNetPostProcessor( - assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh + assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh ) for n, m in self.named_modules(): diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py index d9858b559f..ff11dbe477 100644 --- a/doctr/models/detection/linknet/tensorflow.py +++ b/doctr/models/detection/linknet/tensorflow.py @@ -98,6 +98,8 @@ class LinkNet(_LinkNet, keras.Model): ---- feature extractor: the backbone serving as feature extractor fpn_channels: number of channels each extracted feature maps is mapped to + bin_thresh: threshold for binarization of the output feature map + box_thresh: minimal objectness score to consider a box assume_straight_pages: if True, fit straight bounding boxes only exportable: onnx exportable returns only logits cfg: the configuration dict of the model @@ -111,6 +113,7 @@ def __init__( feat_extractor: IntermediateLayerGetter, fpn_channels: int = 64, bin_thresh: float = 0.1, + box_thresh: float = 0.1, assume_straight_pages: bool = True, exportable: bool = False, cfg: Optional[Dict[str, Any]] = None, @@ -152,7 +155,9 @@ def __init__( ), ]) - self.postprocessor = LinkNetPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh) + self.postprocessor = LinkNetPostProcessor( + assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh + ) def compute_loss( self, diff --git a/doctr/models/kie_predictor/pytorch.py b/doctr/models/kie_predictor/pytorch.py index 152b2acaca..81f342931d 100644 --- a/doctr/models/kie_predictor/pytorch.py +++ b/doctr/models/kie_predictor/pytorch.py @@ -106,6 +106,10 @@ def forward( # Rectify crops if aspect ratio dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()} + # Apply hooks to loc_preds if any + for hook in self.hooks: + dict_loc_preds = hook(dict_loc_preds) + # Crop images crops = {} for class_name in dict_loc_preds.keys(): diff --git a/doctr/models/kie_predictor/tensorflow.py b/doctr/models/kie_predictor/tensorflow.py index 21ce9953ee..e266c0c0c5 100644 --- a/doctr/models/kie_predictor/tensorflow.py +++ b/doctr/models/kie_predictor/tensorflow.py @@ -103,6 +103,10 @@ def __call__( # Rectify crops if aspect ratio dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()} + # Apply hooks to loc_preds if any + for hook in self.hooks: + dict_loc_preds = hook(dict_loc_preds) + # Crop images crops = {} for class_name in dict_loc_preds.keys(): diff --git a/doctr/models/predictor/base.py b/doctr/models/predictor/base.py index cf844f2053..49d00893f7 100644 --- a/doctr/models/predictor/base.py +++ b/doctr/models/predictor/base.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import Any, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple import numpy as np @@ -48,6 +48,7 @@ def __init__( self.doc_builder = DocumentBuilder(**kwargs) self.preserve_aspect_ratio = preserve_aspect_ratio self.symmetric_pad = symmetric_pad + self.hooks: List[Callable] = [] @staticmethod def _generate_crops( @@ -149,3 +150,12 @@ def _process_predictions( _idx += page_boxes.shape[0] return loc_preds, text_preds + + def add_hook(self, hook: Callable) -> None: + """Add a hook to the predictor + + Args: + ---- + hook: a callable that takes as input the `loc_preds` and returns the modified `loc_preds` + """ + self.hooks.append(hook) diff --git a/doctr/models/predictor/pytorch.py b/doctr/models/predictor/pytorch.py index a1773d1e31..be22024116 100644 --- a/doctr/models/predictor/pytorch.py +++ b/doctr/models/predictor/pytorch.py @@ -108,6 +108,10 @@ def forward( # Rectify crops if aspect ratio loc_preds = self._remove_padding(pages, loc_preds) + # Apply hooks to loc_preds if any + for hook in self.hooks: + loc_preds = hook(loc_preds) + # Crop images crops, loc_preds = self._prepare_crops( pages, diff --git a/doctr/models/predictor/tensorflow.py b/doctr/models/predictor/tensorflow.py index 5461d3020b..424615656a 100644 --- a/doctr/models/predictor/tensorflow.py +++ b/doctr/models/predictor/tensorflow.py @@ -105,6 +105,10 @@ def __call__( # Rectify crops if aspect ratio loc_preds = self._remove_padding(pages, loc_preds) + # Apply hooks to loc_preds if any + for hook in self.hooks: + loc_preds = hook(loc_preds) + # Crop images crops, loc_preds = self._prepare_crops( pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages diff --git a/doctr/models/preprocessor/pytorch.py b/doctr/models/preprocessor/pytorch.py index 294a06a994..b155425f49 100644 --- a/doctr/models/preprocessor/pytorch.py +++ b/doctr/models/preprocessor/pytorch.py @@ -35,7 +35,6 @@ def __init__( batch_size: int, mean: Tuple[float, float, float] = (0.5, 0.5, 0.5), std: Tuple[float, float, float] = (1.0, 1.0, 1.0), - fp16: bool = False, **kwargs: Any, ) -> None: super().__init__() diff --git a/doctr/models/preprocessor/tensorflow.py b/doctr/models/preprocessor/tensorflow.py index f876ccce3d..431f95b11f 100644 --- a/doctr/models/preprocessor/tensorflow.py +++ b/doctr/models/preprocessor/tensorflow.py @@ -35,7 +35,6 @@ def __init__( batch_size: int, mean: Tuple[float, float, float] = (0.5, 0.5, 0.5), std: Tuple[float, float, float] = (1.0, 1.0, 1.0), - fp16: bool = False, **kwargs: Any, ) -> None: self.batch_size = batch_size diff --git a/scripts/detect_text.py b/scripts/detect_text.py index 3a5b18a903..573080c59b 100644 --- a/scripts/detect_text.py +++ b/scripts/detect_text.py @@ -62,6 +62,7 @@ def main(args): detection_model = detection.__dict__[args.detection]( pretrained=True, bin_thresh=args.bin_thresh, + box_thresh=args.box_thresh, ) model = ocr_predictor(detection_model, args.recognition, pretrained=True) path = Path(args.path) @@ -86,6 +87,7 @@ def parse_args(): parser.add_argument("path", type=str, help="Path to process: PDF, image, directory") parser.add_argument("--detection", type=str, default="db_resnet50", help="Text detection model to use for analysis") parser.add_argument("--bin-thresh", type=float, default=0.3, help="Binarization threshold for the detection model.") + parser.add_argument("--box-thresh", type=float, default=0.1, help="Threshold for the detection boxes.") parser.add_argument( "--recognition", type=str, default="crnn_vgg16_bn", help="Text recognition model to use for analysis" ) diff --git a/tests/pytorch/test_models_zoo_pt.py b/tests/pytorch/test_models_zoo_pt.py index fa3f23b9d1..3c6267ab7c 100644 --- a/tests/pytorch/test_models_zoo_pt.py +++ b/tests/pytorch/test_models_zoo_pt.py @@ -16,6 +16,12 @@ from doctr.models.recognition.zoo import recognition_predictor +# Create a dummy callback +class _DummyCallback: + def __call__(self, loc_preds): + return loc_preds + + @pytest.mark.parametrize( "assume_straight_pages, straighten_pages", [ @@ -121,6 +127,8 @@ def test_trained_ocr_predictor(mock_payslip): preserve_aspect_ratio=True, symmetric_pad=True, ) + # test hooks + predictor.add_hook(_DummyCallback()) out = predictor(doc) @@ -204,6 +212,8 @@ def test_trained_kie_predictor(mock_payslip): straighten_pages=True, preserve_aspect_ratio=False, ) + # test hooks + predictor.add_hook(_DummyCallback()) out = predictor(doc) diff --git a/tests/tensorflow/test_models_zoo_tf.py b/tests/tensorflow/test_models_zoo_tf.py index 32e7988560..906d6d0f5d 100644 --- a/tests/tensorflow/test_models_zoo_tf.py +++ b/tests/tensorflow/test_models_zoo_tf.py @@ -16,6 +16,12 @@ from doctr.utils.repr import NestedObject +# Create a dummy callback +class _DummyCallback: + def __call__(self, loc_preds): + return loc_preds + + @pytest.mark.parametrize( "assume_straight_pages, straighten_pages", [ @@ -92,6 +98,8 @@ def test_trained_ocr_predictor(mock_payslip): straighten_pages=True, preserve_aspect_ratio=False, ) + # test hooks + predictor.add_hook(_DummyCallback()) out = predictor(doc) @@ -202,6 +210,8 @@ def test_trained_kie_predictor(mock_payslip): straighten_pages=True, preserve_aspect_ratio=False, ) + # test hooks + predictor.add_hook(_DummyCallback()) out = predictor(doc)