Skip to content

Commit

Permalink
feat: ✨ expose thresh bin thresh in DBNet (#1110)
Browse files Browse the repository at this point in the history
* Migrate static data from github to monitoring middleware.

* feat: expose bin thresh

* fix: merging changes

* feat: expose thresh in script

* Update scripts/detect_text.py

* fix: predictor wrapper

* fix: lint

* fix lint

* fix: update mypy

* fix: mypy

* fix: sort

Co-authored-by: Marvin Amuzu <[email protected]>
Co-authored-by: ianaré <[email protected]>
  • Loading branch information
3 people authored Dec 5, 2022
1 parent acb9f64 commit b5ed162
Show file tree
Hide file tree
Showing 27 changed files with 47 additions and 29 deletions.
2 changes: 1 addition & 1 deletion doctr/datasets/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
if is_tf_available():
from .tensorflow import *
elif is_torch_available():
from .pytorch import * # type: ignore[misc]
from .pytorch import * # type: ignore[assignment]
2 changes: 1 addition & 1 deletion doctr/datasets/generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
if is_tf_available():
from .tensorflow import *
elif is_torch_available():
from .pytorch import * # type: ignore[misc]
from .pytorch import * # type: ignore[assignment]
2 changes: 1 addition & 1 deletion doctr/models/classification/magc_resnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
if is_tf_available():
from .tensorflow import *
elif is_torch_available():
from .pytorch import * # type: ignore[misc]
from .pytorch import * # type: ignore[assignment]
2 changes: 1 addition & 1 deletion doctr/models/classification/predictor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
if is_tf_available():
from .tensorflow import *
elif is_torch_available():
from .pytorch import * # type: ignore[misc]
from .pytorch import * # type: ignore[assignment]
2 changes: 1 addition & 1 deletion doctr/models/classification/resnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
if is_tf_available():
from .tensorflow import *
elif is_torch_available():
from .pytorch import * # type: ignore[misc]
from .pytorch import * # type: ignore[assignment]
2 changes: 1 addition & 1 deletion doctr/models/classification/vit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
if is_tf_available():
from .tensorflow import *
elif is_torch_available():
from .pytorch import * # type: ignore[misc]
from .pytorch import * # type: ignore[assignment]
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
if is_tf_available():
from .tensorflow import *
elif is_torch_available():
from .pytorch import * # type: ignore[misc]
from .pytorch import * # type: ignore[assignment]
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def __init__(
head_chans: int = 256,
deform_conv: bool = False,
num_classes: int = 1,
bin_thresh: float = 0.3,
assume_straight_pages: bool = True,
exportable: bool = False,
cfg: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -166,7 +167,7 @@ def __init__(
nn.ConvTranspose2d(head_chans // 4, num_classes, 2, stride=2),
)

self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages)
self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh)

for n, m in self.named_modules():
# Don't override the initialization of the backbone
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def __init__(
feature_extractor: IntermediateLayerGetter,
fpn_channels: int = 128, # to be set to 256 to represent the author's initial idea
num_classes: int = 1,
bin_thresh: float = 0.3,
assume_straight_pages: bool = True,
exportable: bool = False,
cfg: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -158,7 +159,7 @@ def __init__(
]
)

self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages)
self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh)

def compute_loss(self, out_map: tf.Tensor, thresh_map: tf.Tensor, target: List[np.ndarray]) -> tf.Tensor:
"""Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/linknet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
if is_tf_available():
from .tensorflow import *
elif is_torch_available():
from .pytorch import * # type: ignore[misc]
from .pytorch import * # type: ignore[assignment]
5 changes: 4 additions & 1 deletion doctr/models/detection/linknet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(
self,
feat_extractor: IntermediateLayerGetter,
num_classes: int = 1,
bin_thresh: float = 0.1,
head_chans: int = 32,
assume_straight_pages: bool = True,
exportable: bool = False,
Expand Down Expand Up @@ -139,7 +140,9 @@ def __init__(
nn.ConvTranspose2d(head_chans, num_classes, kernel_size=2, stride=2),
)

self.postprocessor = LinkNetPostProcessor(assume_straight_pages=self.assume_straight_pages)
self.postprocessor = LinkNetPostProcessor(
assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh
)

for n, m in self.named_modules():
# Don't override the initialization of the backbone
Expand Down
3 changes: 2 additions & 1 deletion doctr/models/detection/linknet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def __init__(
feat_extractor: IntermediateLayerGetter,
fpn_channels: int = 64,
num_classes: int = 1,
bin_thresh: float = 0.1,
assume_straight_pages: bool = True,
exportable: bool = False,
cfg: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -159,7 +160,7 @@ def __init__(
]
)

self.postprocessor = LinkNetPostProcessor(assume_straight_pages=assume_straight_pages)
self.postprocessor = LinkNetPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh)

def compute_loss(
self,
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/predictor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
if is_tf_available():
from .tensorflow import *
else:
from .pytorch import * # type: ignore[misc]
from .pytorch import * # type: ignore[assignment]
8 changes: 7 additions & 1 deletion doctr/models/factory/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
from pathlib import Path
from typing import Any

from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, snapshot_download
from huggingface_hub import ( # type: ignore[attr-defined]
HfApi,
HfFolder,
Repository,
hf_hub_download,
snapshot_download,
)

from doctr import models
from doctr.file_utils import is_tf_available, is_torch_available
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/modules/transformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
if is_tf_available():
from .tensorflow import *
elif is_torch_available():
from .pytorch import * # type: ignore[misc]
from .pytorch import * # type: ignore[assignment]
2 changes: 1 addition & 1 deletion doctr/models/modules/vision_transformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
if is_tf_available():
from .tensorflow import *
elif is_torch_available():
from .pytorch import * # type: ignore[misc]
from .pytorch import * # type: ignore[assignment]
4 changes: 2 additions & 2 deletions doctr/models/modules/vision_transformer/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def __init__(self, input_shape: Tuple[int, int, int], embed_dim: int) -> None:
self.grid_size = (self.patch_size[0], self.patch_size[1])
self.num_patches = self.patch_size[0] * self.patch_size[1]

self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim)) # type: ignore[attr-defined]
self.positions = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim)) # type: ignore[attr-defined]
self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
self.positions = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
self.proj = nn.Linear((channels * self.patch_size[0] * self.patch_size[1]), embed_dim)

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/predictor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
if is_tf_available():
from .tensorflow import *
else:
from .pytorch import * # type: ignore[misc]
from .pytorch import * # type: ignore[assignment]
2 changes: 1 addition & 1 deletion doctr/models/preprocessor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
if is_tf_available():
from .tensorflow import *
elif is_torch_available():
from .pytorch import * # type: ignore[misc]
from .pytorch import * # type: ignore[assignment]
2 changes: 1 addition & 1 deletion doctr/models/recognition/crnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
if is_tf_available():
from .tensorflow import *
elif is_torch_available():
from .pytorch import * # type: ignore[misc]
from .pytorch import * # type: ignore[assignment]
2 changes: 1 addition & 1 deletion doctr/models/recognition/master/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
if is_tf_available():
from .tensorflow import *
elif is_torch_available():
from .pytorch import * # type: ignore[misc]
from .pytorch import * # type: ignore[assignment]
2 changes: 1 addition & 1 deletion doctr/models/recognition/predictor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
if is_tf_available():
from .tensorflow import *
else:
from .pytorch import * # type: ignore[misc]
from .pytorch import * # type: ignore[assignment]
2 changes: 1 addition & 1 deletion doctr/models/recognition/sar/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
if is_tf_available():
from .tensorflow import *
elif is_torch_available():
from .pytorch import * # type: ignore[misc]
from .pytorch import * # type: ignore[assignment]
2 changes: 1 addition & 1 deletion doctr/models/recognition/vitstr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
if is_tf_available():
from .tensorflow import *
elif is_torch_available():
from .pytorch import * # type: ignore[misc]
from .pytorch import * # type: ignore[assignment]
2 changes: 1 addition & 1 deletion doctr/models/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
if is_tf_available():
from .tensorflow import *
elif is_torch_available():
from .pytorch import * # type: ignore[misc]
from .pytorch import * # type: ignore[assignment]
2 changes: 1 addition & 1 deletion doctr/transforms/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
if is_tf_available():
from .tensorflow import *
elif is_torch_available():
from .pytorch import * # type: ignore[misc]
from .pytorch import * # type: ignore[assignment]
10 changes: 8 additions & 2 deletions scripts/detect_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from doctr.file_utils import is_tf_available
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from doctr.models import detection, ocr_predictor

# Enable GPU growth if using TF
if is_tf_available():
Expand Down Expand Up @@ -60,7 +60,12 @@ def _process_file(model, file_path: Path, out_format: str) -> None:


def main(args):
model = ocr_predictor(args.detection, args.recognition, pretrained=True)

detection_model = detection.__dict__[args.detection](
pretrained=True,
bin_thresh=args.bin_thresh,
)
model = ocr_predictor(detection_model, args.recognition, pretrained=True)
path = Path(args.path)

os.makedirs(name="output", exist_ok=True)
Expand All @@ -82,6 +87,7 @@ def parse_args():
)
parser.add_argument("path", type=str, help="Path to process: PDF, image, directory")
parser.add_argument("--detection", type=str, default="db_resnet50", help="Text detection model to use for analysis")
parser.add_argument("--bin-thresh", type=float, default=0.3, help="Binarization threshold for the detection model.")
parser.add_argument(
"--recognition", type=str, default="crnn_vgg16_bn", help="Text recognition model to use for analysis"
)
Expand Down

0 comments on commit b5ed162

Please sign in to comment.