Skip to content

Commit

Permalink
feat: add rotated linknet_resnet18 tensorflow ckpts (#817)
Browse files Browse the repository at this point in the history
* feat: add dice loss in linknet

* fix: typing

* feat: add rotated linknet_resnet18 ckpt

* fix: requested changes

* feat: add aspect ratio for ocr predictor

* fix: requested changes

* fix: tests

* fix: default args

* fix: new ckpts

* fix: linknet

* fix: postprocessing

* fix: isort
  • Loading branch information
charlesmindee authored Mar 9, 2022
1 parent a8e1908 commit 436053d
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 4 deletions.
2 changes: 1 addition & 1 deletion doctr/models/detection/linknet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,10 @@ def bitmap_to_boxes(

if self.assume_straight_pages:
# compute relative polygon to get rid of img shape
x, y, w, h = _box
xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height
boxes.append([xmin, ymin, xmax, ymax, score])
else:
_box = cv2.boxPoints(cv2.minAreaRect(contour))
# compute relative box to get rid of img shape
_box[:, 0] /= width
_box[:, 1] /= height
Expand Down
35 changes: 34 additions & 1 deletion doctr/models/detection/linknet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from .base import LinkNetPostProcessor, _LinkNet

__all__ = ['LinkNet', 'linknet_resnet18', 'linknet_resnet34', 'linknet_resnet50']
__all__ = ['LinkNet', 'linknet_resnet18', 'linknet_resnet34', 'linknet_resnet50', 'linknet_resnet18_rotation']


default_cfgs: Dict[str, Dict[str, Any]] = {
Expand All @@ -29,6 +29,12 @@
'input_shape': (1024, 1024, 3),
'url': None,
},
'linknet_resnet18_rotation': {
'mean': (0.798, 0.785, 0.772),
'std': (0.264, 0.2749, 0.287),
'input_shape': (1024, 1024, 3),
'url': 'https://github.com/mindee/doctr/releases/download/v0.5.0/linknet_resnet18-a48e6ed3.zip',
},
'linknet_resnet34': {
'mean': (0.798, 0.785, 0.772),
'std': (0.264, 0.2749, 0.287),
Expand Down Expand Up @@ -286,6 +292,33 @@ def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
)


def linknet_resnet18_rotation(pretrained: bool = False, **kwargs: Any) -> LinkNet:
    """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
    <https://arxiv.org/pdf/1707.03718.pdf>`_, built through ``_linknet`` with the ``resnet18`` backbone and the
    ``'linknet_resnet18_rotation'`` architecture name.

    Example::
        >>> import tensorflow as tf
        >>> from doctr.models import linknet_resnet18_rotation
        >>> model = linknet_resnet18_rotation(pretrained=True)
        >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
        >>> out = model(input_tensor)

    Args:
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: additional keyword arguments, forwarded unchanged to ``_linknet``

    Returns:
        text detection architecture
    """

    return _linknet(
        # Architecture name; presumably used by `_linknet` to select the matching
        # entry of `default_cfgs` (mean/std/input_shape/url) -- confirm in `_linknet`.
        'linknet_resnet18_rotation',
        pretrained,
        resnet18,
        # NOTE(review): these look like the names of the backbone layers whose
        # outputs feed the LinkNet decoder -- confirm against `_linknet`.
        ['resnet_block_1', 'resnet_block_3', 'resnet_block_5', 'resnet_block_7'],
        **kwargs,
    )


def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
"""LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
<https://arxiv.org/pdf/1707.03718.pdf>`_.
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/detection/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@


if is_tf_available():
ARCHS = ['db_resnet50', 'db_mobilenet_v3_large', 'linknet_resnet18']
ROT_ARCHS = []
ARCHS = ['db_resnet50', 'db_mobilenet_v3_large', 'linknet_resnet18', 'linknet_resnet18_rotation']
ROT_ARCHS = ['linknet_resnet18_rotation']
elif is_torch_available():
ARCHS = ['db_resnet34', 'db_resnet50', 'db_mobilenet_v3_large', 'linknet_resnet18', 'db_resnet50_rotation']
ROT_ARCHS = ['db_resnet50_rotation']
Expand Down

0 comments on commit 436053d

Please sign in to comment.