diff --git a/doctr/datasets/datasets/__init__.py b/doctr/datasets/datasets/__init__.py index 059f261e82..c7110f5669 100644 --- a/doctr/datasets/datasets/__init__.py +++ b/doctr/datasets/datasets/__init__.py @@ -3,4 +3,4 @@ if is_tf_available(): from .tensorflow import * elif is_torch_available(): - from .pytorch import * # type: ignore[misc] + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/datasets/generator/__init__.py b/doctr/datasets/generator/__init__.py index 059f261e82..c7110f5669 100644 --- a/doctr/datasets/generator/__init__.py +++ b/doctr/datasets/generator/__init__.py @@ -3,4 +3,4 @@ if is_tf_available(): from .tensorflow import * elif is_torch_available(): - from .pytorch import * # type: ignore[misc] + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/classification/magc_resnet/__init__.py b/doctr/models/classification/magc_resnet/__init__.py index 059f261e82..c7110f5669 100644 --- a/doctr/models/classification/magc_resnet/__init__.py +++ b/doctr/models/classification/magc_resnet/__init__.py @@ -3,4 +3,4 @@ if is_tf_available(): from .tensorflow import * elif is_torch_available(): - from .pytorch import * # type: ignore[misc] + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/classification/predictor/__init__.py b/doctr/models/classification/predictor/__init__.py index 059f261e82..c7110f5669 100644 --- a/doctr/models/classification/predictor/__init__.py +++ b/doctr/models/classification/predictor/__init__.py @@ -3,4 +3,4 @@ if is_tf_available(): from .tensorflow import * elif is_torch_available(): - from .pytorch import * # type: ignore[misc] + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/classification/resnet/__init__.py b/doctr/models/classification/resnet/__init__.py index 059f261e82..c7110f5669 100644 --- a/doctr/models/classification/resnet/__init__.py +++ b/doctr/models/classification/resnet/__init__.py @@ -3,4 +3,4 @@ if is_tf_available(): from .tensorflow import * elif is_torch_available(): - from .pytorch import * # type: ignore[misc] + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/classification/vit/__init__.py b/doctr/models/classification/vit/__init__.py index 059f261e82..c7110f5669 100644 --- a/doctr/models/classification/vit/__init__.py +++ b/doctr/models/classification/vit/__init__.py @@ -3,4 +3,4 @@ if is_tf_available(): from .tensorflow import * elif is_torch_available(): - from .pytorch import * # type: ignore[misc] + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/detection/differentiable_binarization/__init__.py b/doctr/models/detection/differentiable_binarization/__init__.py index 059f261e82..c7110f5669 100644 --- a/doctr/models/detection/differentiable_binarization/__init__.py +++ b/doctr/models/detection/differentiable_binarization/__init__.py @@ -3,4 +3,4 @@ if is_tf_available(): from .tensorflow import * elif is_torch_available(): - from .pytorch import * # type: ignore[misc] + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index 611ec6ec07..c12c191588 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -120,6 +120,7 @@ def __init__( head_chans: int = 256, deform_conv: bool = False, num_classes: int = 1, + bin_thresh: float = 0.3, assume_straight_pages: bool = True, exportable: bool = False, cfg: Optional[Dict[str, Any]] = None, @@ -166,7 +167,7 @@ def __init__( nn.ConvTranspose2d(head_chans // 4, num_classes, 2, stride=2), ) - self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages) + self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh) for n, m in self.named_modules(): # Don't override the initialization of the backbone diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index 264bf946f8..ee8c3ea01f 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -122,6 +122,7 @@ def __init__( feature_extractor: IntermediateLayerGetter, fpn_channels: int = 128, # to be set to 256 to represent the author's initial idea num_classes: int = 1, + bin_thresh: float = 0.3, assume_straight_pages: bool = True, exportable: bool = False, cfg: Optional[Dict[str, Any]] = None, @@ -158,7 +159,7 @@ def __init__( ] ) - self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages) + self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh) def compute_loss(self, out_map: tf.Tensor, thresh_map: tf.Tensor, target: List[np.ndarray]) -> tf.Tensor: """Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes diff --git a/doctr/models/detection/linknet/__init__.py b/doctr/models/detection/linknet/__init__.py index 059f261e82..c7110f5669 100644 --- a/doctr/models/detection/linknet/__init__.py +++ b/doctr/models/detection/linknet/__init__.py @@ -3,4 +3,4 @@ if is_tf_available(): from .tensorflow import * elif is_torch_available(): - from .pytorch import * # type: ignore[misc] + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/detection/linknet/pytorch.py b/doctr/models/detection/linknet/pytorch.py index d430a98224..34c493fe2e 100644 --- a/doctr/models/detection/linknet/pytorch.py +++ b/doctr/models/detection/linknet/pytorch.py @@ -102,6 +102,7 @@ def __init__( self, feat_extractor: IntermediateLayerGetter, num_classes: int = 1, + bin_thresh: float = 0.1, head_chans: int = 32, assume_straight_pages: bool = True, exportable: bool = False, @@ -139,7 +140,9 @@ def __init__( nn.ConvTranspose2d(head_chans, num_classes, kernel_size=2, stride=2), ) - self.postprocessor = LinkNetPostProcessor(assume_straight_pages=self.assume_straight_pages) + self.postprocessor = LinkNetPostProcessor( + assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh + ) for n, m in self.named_modules(): # Don't override the initialization of the backbone diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py index f68a4a72d9..1b35ef2b1e 100644 --- a/doctr/models/detection/linknet/tensorflow.py +++ b/doctr/models/detection/linknet/tensorflow.py @@ -120,6 +120,7 @@ def __init__( feat_extractor: IntermediateLayerGetter, fpn_channels: int = 64, num_classes: int = 1, + bin_thresh: float = 0.1, assume_straight_pages: bool = True, exportable: bool = False, cfg: Optional[Dict[str, Any]] = None, @@ -159,7 +160,7 @@ def __init__( ] ) - self.postprocessor = LinkNetPostProcessor(assume_straight_pages=assume_straight_pages) + self.postprocessor = LinkNetPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh) def compute_loss( self, diff --git a/doctr/models/detection/predictor/__init__.py b/doctr/models/detection/predictor/__init__.py index 6a3fee30ac..ff30c3b2e7 100644 --- a/doctr/models/detection/predictor/__init__.py +++ b/doctr/models/detection/predictor/__init__.py @@ -3,4 +3,4 @@ if is_tf_available(): from .tensorflow import * else: - from .pytorch import * # type: ignore[misc] + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/factory/hub.py b/doctr/models/factory/hub.py index 1796f91963..7b6d2687c0 100644 --- a/doctr/models/factory/hub.py +++ b/doctr/models/factory/hub.py @@ -13,7 +13,13 @@ from pathlib import Path from typing import Any -from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, snapshot_download +from huggingface_hub import ( # type: ignore[attr-defined] + HfApi, + HfFolder, + Repository, + hf_hub_download, + snapshot_download, +) from doctr import models from doctr.file_utils import is_tf_available, is_torch_available diff --git a/doctr/models/modules/transformer/__init__.py b/doctr/models/modules/transformer/__init__.py index 059f261e82..c7110f5669 100644 --- a/doctr/models/modules/transformer/__init__.py +++ b/doctr/models/modules/transformer/__init__.py @@ -3,4 +3,4 @@ if is_tf_available(): from .tensorflow import * elif is_torch_available(): - from .pytorch import * # type: ignore[misc] + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/modules/vision_transformer/__init__.py b/doctr/models/modules/vision_transformer/__init__.py index 059f261e82..c7110f5669 100644 --- a/doctr/models/modules/vision_transformer/__init__.py +++ b/doctr/models/modules/vision_transformer/__init__.py @@ -3,4 +3,4 @@ if is_tf_available(): from .tensorflow import * elif is_torch_available(): - from .pytorch import * # type: ignore[misc] + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/modules/vision_transformer/pytorch.py b/doctr/models/modules/vision_transformer/pytorch.py index 28090cec6e..273ca5e3cc 100644 --- a/doctr/models/modules/vision_transformer/pytorch.py +++ b/doctr/models/modules/vision_transformer/pytorch.py @@ -27,8 +27,8 @@ def __init__(self, input_shape: Tuple[int, int, int], embed_dim: int) -> None: self.grid_size = (self.patch_size[0], self.patch_size[1]) self.num_patches = self.patch_size[0] * self.patch_size[1] - self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim)) # type: ignore[attr-defined] - self.positions = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim)) # type: ignore[attr-defined] + self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim)) + self.positions = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim)) self.proj = nn.Linear((channels * self.patch_size[0] * self.patch_size[1]), embed_dim) def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: diff --git a/doctr/models/predictor/__init__.py b/doctr/models/predictor/__init__.py index 6a3fee30ac..ff30c3b2e7 100644 --- a/doctr/models/predictor/__init__.py +++ b/doctr/models/predictor/__init__.py @@ -3,4 +3,4 @@ if is_tf_available(): from .tensorflow import * else: - from .pytorch import * # type: ignore[misc] + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/preprocessor/__init__.py b/doctr/models/preprocessor/__init__.py index 059f261e82..c7110f5669 100644 --- a/doctr/models/preprocessor/__init__.py +++ b/doctr/models/preprocessor/__init__.py @@ -3,4 +3,4 @@ if is_tf_available(): from .tensorflow import * elif is_torch_available(): - from .pytorch import * # type: ignore[misc] + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/recognition/crnn/__init__.py b/doctr/models/recognition/crnn/__init__.py index 059f261e82..c7110f5669 100644 --- a/doctr/models/recognition/crnn/__init__.py +++ b/doctr/models/recognition/crnn/__init__.py @@ -3,4 +3,4 @@ if is_tf_available(): from .tensorflow import * elif is_torch_available(): - from .pytorch import * # type: ignore[misc] + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/recognition/master/__init__.py b/doctr/models/recognition/master/__init__.py index 059f261e82..c7110f5669 100644 --- a/doctr/models/recognition/master/__init__.py +++ b/doctr/models/recognition/master/__init__.py @@ -3,4 +3,4 @@ if is_tf_available(): from .tensorflow import * elif is_torch_available(): - from .pytorch import * # type: ignore[misc] + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/recognition/predictor/__init__.py b/doctr/models/recognition/predictor/__init__.py index 6a3fee30ac..ff30c3b2e7 100644 --- a/doctr/models/recognition/predictor/__init__.py +++ b/doctr/models/recognition/predictor/__init__.py @@ -3,4 +3,4 @@ if is_tf_available(): from .tensorflow import * else: - from .pytorch import * # type: ignore[misc] + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/recognition/sar/__init__.py b/doctr/models/recognition/sar/__init__.py index 059f261e82..c7110f5669 100644 --- a/doctr/models/recognition/sar/__init__.py +++ b/doctr/models/recognition/sar/__init__.py @@ -3,4 +3,4 @@ if is_tf_available(): from .tensorflow import * elif is_torch_available(): - from .pytorch import * # type: ignore[misc] + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/recognition/vitstr/__init__.py b/doctr/models/recognition/vitstr/__init__.py index 059f261e82..c7110f5669 100644 --- a/doctr/models/recognition/vitstr/__init__.py +++ b/doctr/models/recognition/vitstr/__init__.py @@ -3,4 +3,4 @@ if is_tf_available(): from .tensorflow import * elif is_torch_available(): - from .pytorch import * # type: ignore[misc] + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/utils/__init__.py b/doctr/models/utils/__init__.py index 059f261e82..c7110f5669 100644 --- a/doctr/models/utils/__init__.py +++ b/doctr/models/utils/__init__.py @@ -3,4 +3,4 @@ if is_tf_available(): from .tensorflow import * elif is_torch_available(): - from .pytorch import * # type: ignore[misc] + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/transforms/modules/__init__.py b/doctr/transforms/modules/__init__.py index 1950176a6d..4053ff5520 100644 --- a/doctr/transforms/modules/__init__.py +++ b/doctr/transforms/modules/__init__.py @@ -5,4 +5,4 @@ if is_tf_available(): from .tensorflow import * elif is_torch_available(): - from .pytorch import * # type: ignore[misc] + from .pytorch import * # type: ignore[assignment] diff --git a/scripts/detect_text.py b/scripts/detect_text.py index f634fff82a..7aa8b834fd 100644 --- a/scripts/detect_text.py +++ b/scripts/detect_text.py @@ -14,7 +14,7 @@ from doctr.file_utils import is_tf_available from doctr.io import DocumentFile -from doctr.models import ocr_predictor +from doctr.models import detection, ocr_predictor # Enable GPU growth if using TF if is_tf_available(): @@ -60,7 +60,12 @@ def _process_file(model, file_path: Path, out_format: str) -> None: def main(args): - model = ocr_predictor(args.detection, args.recognition, pretrained=True) + + detection_model = detection.__dict__[args.detection]( + pretrained=True, + bin_thresh=args.bin_thresh, + ) + model = ocr_predictor(detection_model, args.recognition, pretrained=True) path = Path(args.path) os.makedirs(name="output", exist_ok=True) @@ -82,6 +87,7 @@ def parse_args(): ) parser.add_argument("path", type=str, help="Path to process: PDF, image, directory") parser.add_argument("--detection", type=str, default="db_resnet50", help="Text detection model to use for analysis") + parser.add_argument("--bin-thresh", type=float, default=0.3, help="Binarization threshold for the detection model.") parser.add_argument( "--recognition", type=str, default="crnn_vgg16_bn", help="Text recognition model to use for analysis" )