Skip to content

Commit

Permalink
feat: Added rectangular stride MobileNets (#483)
Browse files Browse the repository at this point in the history
* feat: Added rectangular stride MobileNets

* fix: Fixed CRNN mobilenets

* style: Fixed typing

* chore: Added missing entries in __all__

* fix: Fixed rect stride of MobileNets
  • Loading branch information
fg-mindee authored Sep 20, 2021
1 parent ab7d483 commit a73e24b
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 29 deletions.
78 changes: 73 additions & 5 deletions doctr/models/backbones/mobilenet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@

# Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py

from torch import nn
from torchvision.models import mobilenetv3
from typing import Any, Dict
from doctr.datasets import VOCABS
from ...utils import load_pretrained_params


__all__ = ["mobilenet_v3_small", "mobilenet_v3_large"]
__all__ = ["mobilenet_v3_small", "mobilenet_v3_small_r", "mobilenet_v3_large", "mobilenet_v3_large_r"]


default_cfgs: Dict[str, Dict[str, Any]] = {
Expand All @@ -22,13 +23,29 @@
'vocab': VOCABS['legacy_french'],
'url': 'https://github.com/mindee/doctr/releases/download/v0.3.0/mobilenet_v3_large-a0aea820.pt',
},
'mobilenet_v3_large_r': {
'mean': (0.694, 0.695, 0.693),
'std': (0.299, 0.296, 0.301),
'input_shape': (3, 32, 32),
'rect_stride': ['features.4.block.1.0', 'features.7.block.1.0', 'features.13.block.1.0'],
'vocab': VOCABS['french'],
'url': None,
},
'mobilenet_v3_small': {
'mean': (0.694, 0.695, 0.693),
'std': (0.299, 0.296, 0.301),
'input_shape': (3, 32, 32),
'vocab': VOCABS['legacy_french'],
'url': 'https://github.com/mindee/doctr/releases/download/v0.3.0/mobilenet_v3_small-69c7267d.pt',
}
},
'mobilenet_v3_small_r': {
'mean': (0.694, 0.695, 0.693),
'std': (0.299, 0.296, 0.301),
'input_shape': (3, 32, 32),
'rect_stride': ['features.2.block.1.0', 'features.4.block.1.0', 'features.9.block.1.0'],
'vocab': VOCABS['french'],
'url': None,
},
}


Expand All @@ -40,11 +57,19 @@ def _mobilenet_v3(

kwargs['num_classes'] = kwargs.get('num_classes', len(default_cfgs[arch]['vocab']))

if arch == "mobilenet_v3_small":
if arch.startswith("mobilenet_v3_small"):
model = mobilenetv3.mobilenet_v3_small(**kwargs)
else:
model = mobilenetv3.mobilenet_v3_large(**kwargs)

# Rectangular strides
if isinstance(default_cfgs[arch].get('rect_stride'), list):
for layer_name in default_cfgs[arch]['rect_stride']:
m = model
for child in layer_name.split('.'):
m = getattr(m, child)
m.stride = (2, 1)

# Load pretrained parameters
if pretrained:
load_pretrained_params(model, default_cfgs[arch]['url'])
Expand All @@ -68,12 +93,34 @@ def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.M
pretrained: boolean, True if model is pretrained
Returns:
A mobilenetv3_small model
a torch.nn.Module
"""

return _mobilenet_v3('mobilenet_v3_small', pretrained, **kwargs)


def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
"""MobileNetV3-Small architecture as described in
`"Searching for MobileNetV3",
<https://arxiv.org/pdf/1905.02244.pdf>`_, with rectangular pooling.
Example::
>>> import torch
>>> from doctr.models import mobilenet_v3_small_r
>>> model = mobilenet_v3_small_r(pretrained=False)
>>> input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32)
>>> out = model(input_tensor)
Args:
pretrained: boolean, True if model is pretrained
Returns:
a torch.nn.Module
"""

return _mobilenet_v3('mobilenet_v3_small_r', pretrained, **kwargs)


def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
"""MobileNetV3-Large architecture as described in
`"Searching for MobileNetV3",
Expand All @@ -90,6 +137,27 @@ def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.M
pretrained: boolean, True if model is pretrained
Returns:
A mobilenetv3_large model
a torch.nn.Module
"""
return _mobilenet_v3('mobilenet_v3_large', pretrained, **kwargs)


def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
"""MobileNetV3-Large architecture as described in
`"Searching for MobileNetV3",
<https://arxiv.org/pdf/1905.02244.pdf>`_, with rectangular pooling.
Example::
>>> import torch
>>> from doctr.models import mobilenet_v3_large_r
>>> model = mobilenet_v3_large_r(pretrained=False)
>>> input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32)
>>> out = model(input_tensor)
Args:
pretrained: boolean, True if model is pretrained
Returns:
a torch.nn.Module
"""
return _mobilenet_v3('mobilenet_v3_large_r', pretrained, **kwargs)
85 changes: 72 additions & 13 deletions doctr/models/backbones/mobilenet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from typing import Optional, Tuple, Any, Dict, List
from typing import Optional, Tuple, Any, Dict, List, Union
from ...utils import conv_sequence, load_pretrained_params
from ....datasets import VOCABS


__all__ = ["MobileNetV3", "mobilenet_v3_small", "mobilenet_v3_large"]
__all__ = ["MobileNetV3", "mobilenet_v3_small", "mobilenet_v3_small_r", "mobilenet_v3_large",
"mobilenet_v3_large_r"]


default_cfgs: Dict[str, Dict[str, Any]] = {
Expand All @@ -24,12 +25,26 @@
'vocab': VOCABS['legacy_french'],
'url': 'https://github.com/mindee/doctr/releases/download/v0.3.0/mobilenet_v3_large-d27d66f2.zip'
},
'mobilenet_v3_large_r': {
'mean': (0.694, 0.695, 0.693),
'std': (0.299, 0.296, 0.301),
'input_shape': (32, 32, 3),
'vocab': VOCABS['french'],
'url': None,
},
'mobilenet_v3_small': {
'mean': (0.694, 0.695, 0.693),
'std': (0.299, 0.296, 0.301),
'input_shape': (32, 32, 3),
'vocab': VOCABS['legacy_french'],
'url': 'https://github.com/mindee/doctr/releases/download/v0.3.0/mobilenet_v3_small-d624c4de.zip'
},
'mobilenet_v3_small_r': {
'mean': (0.694, 0.695, 0.693),
'std': (0.299, 0.296, 0.301),
'input_shape': (32, 32, 3),
'vocab': VOCABS['french'],
'url': None,
}
}

Expand Down Expand Up @@ -76,7 +91,7 @@ def __init__(
out_channels: int,
use_se: bool,
activation: str,
stride: int,
stride: Union[int, Tuple[int, int]],
width_mult: float = 1,
) -> None:
self.input_channels = self.adjust_channels(input_channels, width_mult)
Expand Down Expand Up @@ -108,7 +123,8 @@ def __init__(

act_fn = hard_swish if conf.use_hs else tf.nn.relu

self.use_res_connect = conf.stride == 1 and conf.input_channels == conf.out_channels
_is_s1 = (isinstance(conf.stride, tuple) and conf.stride == (1, 1)) or conf.stride == 1
self.use_res_connect = _is_s1 and conf.input_channels == conf.out_channels

_layers = []
# expand
Expand Down Expand Up @@ -196,17 +212,17 @@ def _mobilenet_v3(
input_shape = input_shape or default_cfgs[arch]['input_shape']

# cf. Table 1 & 2 of the paper
if arch == "mobilenet_v3_small":
if arch.startswith("mobilenet_v3_small"):
inverted_residual_setting = [
InvertedResidualConfig(16, 3, 16, 16, True, "RE", 2), # C1
InvertedResidualConfig(16, 3, 72, 24, False, "RE", 2), # C2
InvertedResidualConfig(16, 3, 72, 24, False, "RE", (2, 1) if arch.endswith("_r") else 2), # C2
InvertedResidualConfig(24, 3, 88, 24, False, "RE", 1),
InvertedResidualConfig(24, 5, 96, 40, True, "HS", 2), # C3
InvertedResidualConfig(24, 5, 96, 40, True, "HS", (2, 1) if arch.endswith("_r") else 2), # C3
InvertedResidualConfig(40, 5, 240, 40, True, "HS", 1),
InvertedResidualConfig(40, 5, 240, 40, True, "HS", 1),
InvertedResidualConfig(40, 5, 120, 48, True, "HS", 1),
InvertedResidualConfig(48, 5, 144, 48, True, "HS", 1),
InvertedResidualConfig(48, 5, 288, 96, True, "HS", 2), # C4
InvertedResidualConfig(48, 5, 288, 96, True, "HS", (2, 1) if arch.endswith("_r") else 2), # C4
InvertedResidualConfig(96, 5, 576, 96, True, "HS", 1),
InvertedResidualConfig(96, 5, 576, 96, True, "HS", 1),
]
Expand All @@ -216,16 +232,16 @@ def _mobilenet_v3(
InvertedResidualConfig(16, 3, 16, 16, False, "RE", 1),
InvertedResidualConfig(16, 3, 64, 24, False, "RE", 2), # C1
InvertedResidualConfig(24, 3, 72, 24, False, "RE", 1),
InvertedResidualConfig(24, 5, 72, 40, True, "RE", 2), # C2
InvertedResidualConfig(24, 5, 72, 40, True, "RE", (2, 1) if arch.endswith("_r") else 2), # C2
InvertedResidualConfig(40, 5, 120, 40, True, "RE", 1),
InvertedResidualConfig(40, 5, 120, 40, True, "RE", 1),
InvertedResidualConfig(40, 3, 240, 80, False, "HS", 2), # C3
InvertedResidualConfig(40, 3, 240, 80, False, "HS", (2, 1) if arch.endswith("_r") else 2), # C3
InvertedResidualConfig(80, 3, 200, 80, False, "HS", 1),
InvertedResidualConfig(80, 3, 184, 80, False, "HS", 1),
InvertedResidualConfig(80, 3, 184, 80, False, "HS", 1),
InvertedResidualConfig(80, 3, 480, 112, True, "HS", 1),
InvertedResidualConfig(112, 3, 672, 112, True, "HS", 1),
InvertedResidualConfig(112, 5, 672, 160, True, "HS", 2), # C4
InvertedResidualConfig(112, 5, 672, 160, True, "HS", (2, 1) if arch.endswith("_r") else 2), # C4
InvertedResidualConfig(160, 5, 960, 160, True, "HS", 1),
InvertedResidualConfig(160, 5, 960, 160, True, "HS", 1),
]
Expand Down Expand Up @@ -263,12 +279,34 @@ def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
pretrained: boolean, True if model is pretrained
Returns:
A mobilenetv3_small model
a keras.Model
"""

return _mobilenet_v3('mobilenet_v3_small', pretrained, **kwargs)


def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
"""MobileNetV3-Small architecture as described in
`"Searching for MobileNetV3",
<https://arxiv.org/pdf/1905.02244.pdf>`_, with rectangular pooling.
Example::
>>> import tensorflow as tf
>>> from doctr.models import mobilenet_v3_small_r
>>> model = mobilenet_v3_small_r(pretrained=False)
>>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
>>> out = model(input_tensor)
Args:
pretrained: boolean, True if model is pretrained
Returns:
a keras.Model
"""

return _mobilenet_v3('mobilenet_v3_small_r', pretrained, **kwargs)


def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
"""MobileNetV3-Large architecture as described in
`"Searching for MobileNetV3",
Expand All @@ -285,6 +323,27 @@ def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
pretrained: boolean, True if model is pretrained
Returns:
A mobilenetv3_large model
a keras.Model
"""
return _mobilenet_v3('mobilenet_v3_large', pretrained, **kwargs)


def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
"""MobileNetV3-Large architecture as described in
`"Searching for MobileNetV3",
<https://arxiv.org/pdf/1905.02244.pdf>`_.
Example::
>>> import tensorflow as tf
>>> from doctr.models import mobilenet_v3_large_r
>>> model = mobilenet_v3_large_r(pretrained=False)
>>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
>>> out = model(input_tensor)
Args:
pretrained: boolean, True if model is pretrained
Returns:
a keras.Model
"""
return _mobilenet_v3('mobilenet_v3_large_r', pretrained, **kwargs)
11 changes: 5 additions & 6 deletions doctr/models/recognition/crnn/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models import mobilenet_v3_small, mobilenet_v3_large
from typing import Tuple, Dict, Any, Optional, List

from ...backbones import vgg16_bn, resnet31
from ...backbones import vgg16_bn, resnet31, mobilenet_v3_small_r, mobilenet_v3_large_r
from ...utils import load_pretrained_params
from ..core import RecognitionModel, RecognitionPostProcessor
from ....datasets import VOCABS
Expand All @@ -31,17 +30,17 @@
'crnn_mobilenet_v3_small': {
'mean': (.5, .5, .5),
'std': (1., 1., 1.),
'backbone': mobilenet_v3_small, 'rnn_units': 128, 'lstm_features': 576,
'backbone': mobilenet_v3_small_r, 'rnn_units': 128, 'lstm_features': 576,
'input_shape': (3, 32, 128),
'vocab': VOCABS['legacy_french'],
'vocab': VOCABS['french'],
'url': None,
},
'crnn_mobilenet_v3_large': {
'mean': (.5, .5, .5),
'std': (1., 1., 1.),
'backbone': mobilenet_v3_large, 'rnn_units': 128, 'lstm_features': 960,
'backbone': mobilenet_v3_large_r, 'rnn_units': 128, 'lstm_features': 960,
'input_shape': (3, 32, 128),
'vocab': VOCABS['legacy_french'],
'vocab': VOCABS['french'],
'url': None,
},
}
Expand Down
10 changes: 5 additions & 5 deletions doctr/models/recognition/crnn/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from tensorflow.keras.models import Sequential, Model
from typing import Tuple, Dict, Any, Optional, List

from ...backbones import vgg16_bn, resnet31, mobilenet_v3_small, mobilenet_v3_large
from ...backbones import vgg16_bn, resnet31, mobilenet_v3_small_r, mobilenet_v3_large_r
from ...utils import load_pretrained_params
from ..core import RecognitionModel, RecognitionPostProcessor
from ....datasets import VOCABS
Expand All @@ -29,17 +29,17 @@
'crnn_mobilenet_v3_small': {
'mean': (0.694, 0.695, 0.693),
'std': (0.299, 0.296, 0.301),
'backbone': mobilenet_v3_small, 'rnn_units': 128,
'backbone': mobilenet_v3_small_r, 'rnn_units': 128,
'input_shape': (32, 128, 3),
'vocab': VOCABS['legacy_french'],
'vocab': VOCABS['french'],
'url': None,
},
'crnn_mobilenet_v3_large': {
'mean': (0.694, 0.695, 0.693),
'std': (0.299, 0.296, 0.301),
'backbone': mobilenet_v3_large, 'rnn_units': 128,
'backbone': mobilenet_v3_large_r, 'rnn_units': 128,
'input_shape': (32, 128, 3),
'vocab': VOCABS['legacy_french'],
'vocab': VOCABS['french'],
'url': None,
},
}
Expand Down

0 comments on commit a73e24b

Please sign in to comment.