Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[prototype] Extend detection result customization #1449

Merged
merged 5 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion demo/app.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to do the same on HF Space

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep this would be good afterwards :)

Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def main(det_archs, reco_archs):
# Binarization threshold
bin_thresh = st.sidebar.slider("Binarization threshold", min_value=0.1, max_value=0.9, value=0.3, step=0.1)
st.sidebar.write("\n")
# Box threshold
box_thresh = st.sidebar.slider("Box threshold", min_value=0.1, max_value=0.9, value=0.1, step=0.1)
st.sidebar.write("\n")

if st.sidebar.button("Analyze page"):
if uploaded_file is None:
Expand All @@ -88,7 +91,7 @@ def main(det_archs, reco_archs):
else:
with st.spinner("Loading model..."):
predictor = load_predictor(
det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, forward_device
det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, box_thresh, forward_device
)

with st.spinner("Analyzing..."):
Expand Down
3 changes: 3 additions & 0 deletions demo/backend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def load_predictor(
assume_straight_pages: bool,
straighten_pages: bool,
bin_thresh: float,
box_thresh: float,
device: torch.device,
) -> OCRPredictor:
"""Load a predictor from doctr.models
Expand All @@ -46,6 +47,7 @@ def load_predictor(
assume_straight_pages: whether to assume straight pages or not
straighten_pages: whether to straighten rotated pages or not
bin_thresh: binarization threshold for the segmentation map
box_thresh: minimal objectness score to consider a box
device: torch.device, the device to load the predictor on

Returns:
Expand All @@ -62,6 +64,7 @@ def load_predictor(
detect_orientation=not assume_straight_pages,
).to(device)
predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
predictor.det_predictor.model.postprocessor.box_thresh = box_thresh
return predictor


Expand Down
3 changes: 3 additions & 0 deletions demo/backend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def load_predictor(
assume_straight_pages: bool,
straighten_pages: bool,
bin_thresh: float,
box_thresh: float,
device: tf.device,
) -> OCRPredictor:
"""Load a predictor from doctr.models
Expand All @@ -45,6 +46,7 @@ def load_predictor(
assume_straight_pages: whether to assume straight pages or not
straighten_pages: whether to straighten rotated pages or not
bin_thresh: binarization threshold for the segmentation map
box_thresh: minimal objectness score to consider a box
device: tf.device, the device to load the predictor on

Returns:
Expand All @@ -62,6 +64,7 @@ def load_predictor(
detect_orientation=not assume_straight_pages,
)
predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
predictor.det_predictor.model.postprocessor.box_thresh = box_thresh
return predictor


Expand Down
45 changes: 45 additions & 0 deletions docs/source/using_doctr/using_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -398,3 +398,48 @@ For reference, here is a sample XML byte string output:
</div>
</body>
</html>


Advanced options
^^^^^^^^^^^^^^^^
We provide a few advanced options to customize the behavior of the predictor to your needs:

* Modify the binarization threshold for the detection model.
* Modify the box threshold for the detection model.

This is useful to detect (possibly fewer) text regions more accurately with a higher threshold, or to detect more text regions with a lower threshold.


.. code:: python3

import numpy as np
from doctr.models import ocr_predictor
predictor = ocr_predictor('db_resnet50', 'crnn_vgg16_bn', pretrained=True)

# Modify the binarization threshold and the box threshold
predictor.det_predictor.model.postprocessor.bin_thresh = 0.5
predictor.det_predictor.model.postprocessor.box_thresh = 0.2

input_page = (255 * np.random.rand(800, 600, 3)).astype(np.uint8)
out = predictor([input_page])


* Add a hook to the `ocr_predictor` to manipulate the location predictions before the crops are passed to the recognition model.

.. code:: python3

from doctr.models import ocr_predictor

class CustomHook:
def __call__(self, loc_preds):
felixdittrich92 marked this conversation as resolved.
Show resolved Hide resolved
# Manipulate the location predictions here
# The output structure needs to be the same as the input location predictions
return loc_preds

my_hook = CustomHook()

predictor = ocr_predictor(pretrained=True)
# Add a hook in the middle of the pipeline
predictor.add_hook(my_hook)
# You can also add multiple hooks which will be executed sequentially
[predictor.add_hook(hook) for hook in [my_hook, my_hook, my_hook]]
felixdittrich92 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ class DBNet(_DBNet, nn.Module):
feature extractor: the backbone serving as feature extractor
head_chans: the number of channels in the head
deform_conv: whether to use deformable convolution
bin_thresh: threshold for binarization
box_thresh: minimal objectness score to consider a box
assume_straight_pages: if True, fit straight bounding boxes only
exportable: onnx exportable returns only logits
cfg: the configuration dict of the model
Expand All @@ -112,6 +114,7 @@ def __init__(
head_chans: int = 256,
deform_conv: bool = False,
bin_thresh: float = 0.3,
box_thresh: float = 0.1,
assume_straight_pages: bool = True,
exportable: bool = False,
cfg: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -160,7 +163,9 @@ def __init__(
nn.ConvTranspose2d(head_chans // 4, num_classes, 2, stride=2),
)

self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh)
self.postprocessor = DBPostProcessor(
assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
)

for n, m in self.named_modules():
# Don't override the initialization of the backbone
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ class DBNet(_DBNet, keras.Model, NestedObject):
----
feature_extractor: the backbone serving as feature extractor
fpn_channels: number of channels each extracted feature maps is mapped to
bin_thresh: threshold for binarization
box_thresh: minimal objectness score to consider a box
assume_straight_pages: if True, fit straight bounding boxes only
exportable: onnx exportable returns only logits
cfg: the configuration dict of the model
Expand All @@ -125,6 +127,7 @@ def __init__(
feature_extractor: IntermediateLayerGetter,
fpn_channels: int = 128, # to be set to 256 to represent the author's initial idea
bin_thresh: float = 0.3,
box_thresh: float = 0.1,
assume_straight_pages: bool = True,
exportable: bool = False,
cfg: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -159,7 +162,9 @@ def __init__(
layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
])

self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh)
self.postprocessor = DBPostProcessor(
assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
)

def compute_loss(
self,
Expand Down
5 changes: 4 additions & 1 deletion doctr/models/detection/linknet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ class LinkNet(nn.Module, _LinkNet):
Args:
----
feat_extractor: the backbone serving as feature extractor
bin_thresh: threshold for binarization of the output feature map
box_thresh: minimal objectness score to consider a box
head_chans: number of channels in the head layers
assume_straight_pages: if True, fit straight bounding boxes only
exportable: onnx exportable returns only logits
Expand All @@ -102,6 +104,7 @@ def __init__(
self,
feat_extractor: IntermediateLayerGetter,
bin_thresh: float = 0.1,
box_thresh: float = 0.1,
head_chans: int = 32,
assume_straight_pages: bool = True,
exportable: bool = False,
Expand Down Expand Up @@ -142,7 +145,7 @@ def __init__(
)

self.postprocessor = LinkNetPostProcessor(
assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh
assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
)

for n, m in self.named_modules():
Expand Down
7 changes: 6 additions & 1 deletion doctr/models/detection/linknet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class LinkNet(_LinkNet, keras.Model):
----
feat_extractor: the backbone serving as feature extractor
fpn_channels: number of channels each extracted feature maps is mapped to
bin_thresh: threshold for binarization of the output feature map
box_thresh: minimal objectness score to consider a box
assume_straight_pages: if True, fit straight bounding boxes only
exportable: onnx exportable returns only logits
cfg: the configuration dict of the model
Expand All @@ -111,6 +113,7 @@ def __init__(
feat_extractor: IntermediateLayerGetter,
fpn_channels: int = 64,
bin_thresh: float = 0.1,
box_thresh: float = 0.1,
assume_straight_pages: bool = True,
exportable: bool = False,
cfg: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -152,7 +155,9 @@ def __init__(
),
])

self.postprocessor = LinkNetPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh)
self.postprocessor = LinkNetPostProcessor(
assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
)

def compute_loss(
self,
Expand Down
4 changes: 4 additions & 0 deletions doctr/models/kie_predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ def forward(
# Rectify crops if aspect ratio
dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()}

# Apply hooks to loc_preds if any
for hook in self.hooks:
dict_loc_preds = hook(dict_loc_preds)

# Crop images
crops = {}
for class_name in dict_loc_preds.keys():
Expand Down
4 changes: 4 additions & 0 deletions doctr/models/kie_predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ def __call__(
# Rectify crops if aspect ratio
dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()}

# Apply hooks to loc_preds if any
for hook in self.hooks:
dict_loc_preds = hook(dict_loc_preds)

# Crop images
crops = {}
for class_name in dict_loc_preds.keys():
Expand Down
12 changes: 11 additions & 1 deletion doctr/models/predictor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple

import numpy as np

Expand Down Expand Up @@ -48,6 +48,7 @@ def __init__(
self.doc_builder = DocumentBuilder(**kwargs)
self.preserve_aspect_ratio = preserve_aspect_ratio
self.symmetric_pad = symmetric_pad
self.hooks: List[Callable] = []

@staticmethod
def _generate_crops(
Expand Down Expand Up @@ -149,3 +150,12 @@ def _process_predictions(
_idx += page_boxes.shape[0]

return loc_preds, text_preds

def add_hook(self, hook: Callable) -> None:
"""Add a hook to the predictor

Args:
----
hook: a callable that takes as input the `loc_preds` and returns the modified `loc_preds`
"""
self.hooks.append(hook)
4 changes: 4 additions & 0 deletions doctr/models/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def forward(
# Rectify crops if aspect ratio
loc_preds = self._remove_padding(pages, loc_preds)

# Apply hooks to loc_preds if any
for hook in self.hooks:
loc_preds = hook(loc_preds)

# Crop images
crops, loc_preds = self._prepare_crops(
pages,
Expand Down
4 changes: 4 additions & 0 deletions doctr/models/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def __call__(
# Rectify crops if aspect ratio
loc_preds = self._remove_padding(pages, loc_preds)

# Apply hooks to loc_preds if any
for hook in self.hooks:
loc_preds = hook(loc_preds)

# Crop images
crops, loc_preds = self._prepare_crops(
pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages
Expand Down
1 change: 0 additions & 1 deletion doctr/models/preprocessor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def __init__(
batch_size: int,
mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
std: Tuple[float, float, float] = (1.0, 1.0, 1.0),
fp16: bool = False,
felixdittrich92 marked this conversation as resolved.
Show resolved Hide resolved
**kwargs: Any,
) -> None:
super().__init__()
Expand Down
1 change: 0 additions & 1 deletion doctr/models/preprocessor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def __init__(
batch_size: int,
mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
std: Tuple[float, float, float] = (1.0, 1.0, 1.0),
fp16: bool = False,
**kwargs: Any,
) -> None:
self.batch_size = batch_size
Expand Down
2 changes: 2 additions & 0 deletions scripts/detect_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def main(args):
detection_model = detection.__dict__[args.detection](
pretrained=True,
bin_thresh=args.bin_thresh,
box_thresh=args.box_thresh,
)
model = ocr_predictor(detection_model, args.recognition, pretrained=True)
path = Path(args.path)
Expand All @@ -86,6 +87,7 @@ def parse_args():
parser.add_argument("path", type=str, help="Path to process: PDF, image, directory")
parser.add_argument("--detection", type=str, default="db_resnet50", help="Text detection model to use for analysis")
parser.add_argument("--bin-thresh", type=float, default=0.3, help="Binarization threshold for the detection model.")
parser.add_argument("--box-thresh", type=float, default=0.1, help="Threshold for the detection boxes.")
parser.add_argument(
"--recognition", type=str, default="crnn_vgg16_bn", help="Text recognition model to use for analysis"
)
Expand Down
10 changes: 10 additions & 0 deletions tests/pytorch/test_models_zoo_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
from doctr.models.recognition.zoo import recognition_predictor


# Create a dummy callback
class _DummyCallback:
def __call__(self, loc_preds):
return loc_preds


@pytest.mark.parametrize(
"assume_straight_pages, straighten_pages",
[
Expand Down Expand Up @@ -121,6 +127,8 @@ def test_trained_ocr_predictor(mock_payslip):
preserve_aspect_ratio=True,
symmetric_pad=True,
)
# test hooks
predictor.add_hook(_DummyCallback())

out = predictor(doc)

Expand Down Expand Up @@ -204,6 +212,8 @@ def test_trained_kie_predictor(mock_payslip):
straighten_pages=True,
preserve_aspect_ratio=False,
)
# test hooks
predictor.add_hook(_DummyCallback())

out = predictor(doc)

Expand Down
10 changes: 10 additions & 0 deletions tests/tensorflow/test_models_zoo_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
from doctr.utils.repr import NestedObject


# Create a dummy callback
class _DummyCallback:
def __call__(self, loc_preds):
return loc_preds


@pytest.mark.parametrize(
"assume_straight_pages, straighten_pages",
[
Expand Down Expand Up @@ -92,6 +98,8 @@ def test_trained_ocr_predictor(mock_payslip):
straighten_pages=True,
preserve_aspect_ratio=False,
)
# test hooks
predictor.add_hook(_DummyCallback())

out = predictor(doc)

Expand Down Expand Up @@ -202,6 +210,8 @@ def test_trained_kie_predictor(mock_payslip):
straighten_pages=True,
preserve_aspect_ratio=False,
)
# test hooks
predictor.add_hook(_DummyCallback())

out = predictor(doc)

Expand Down
Loading