Commit
[models] add ViTSTR TF and PT and update ViT to work as backbone (#1055)
felixdittrich92 authored Sep 21, 2022
1 parent 1a78f4c commit e538cc2
Showing 20 changed files with 865 additions and 47 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -164,6 +164,7 @@ Credits where it's due: this repository is implementing, among others, architect
- CRNN: [An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition](https://arxiv.org/pdf/1507.05717.pdf).
- SAR: [Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition](https://arxiv.org/pdf/1811.00751.pdf).
- MASTER: [MASTER: Multi-Aspect Non-local Network for Scene Text Recognition](https://arxiv.org/pdf/1910.02562.pdf).
- ViTSTR: [Vision Transformer for Fast and Efficient Scene Text Recognition](https://arxiv.org/pdf/2105.08582.pdf).


## More goodies
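The new ViTSTR reference above lands together with `vitstr_small` and `vitstr_base` recognition models (see the docs changes below). As a rough sketch of how the small variant could be exercised in PyTorch, assuming the top-level `vitstr_small` export, the 32x128 crop shape, and the dict-style output used by doctr's other recognizers:

```python
import torch

from doctr.models import vitstr_small  # assumed export added by this PR

# Build the model from scratch; no checkpoint is relied on in this sketch
model = vitstr_small(pretrained=False).eval()

# doctr recognition models typically consume 32x128 word crops (assumed here)
dummy_crop = torch.rand((1, 3, 32, 128), dtype=torch.float32)

with torch.no_grad():
    out = model(dummy_crop)

# doctr's other recognizers return a dict with decoded predictions when no target is given
print(out)
```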
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -48,6 +48,7 @@ Text recognition models
* SAR from `"Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_
* CRNN from `"An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_
* MASTER from `"MASTER: Multi-Aspect Non-local Network for Scene Text Recognition" <https://arxiv.org/pdf/1910.02562.pdf>`_
* ViTSTR from `"Vision Transformer for Fast and Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_


Supported datasets
6 changes: 6 additions & 0 deletions docs/source/modules/models.rst
@@ -29,6 +29,8 @@ doctr.models.classification

.. autofunction:: doctr.models.classification.magc_resnet31

.. autofunction:: doctr.models.classification.vit_s

.. autofunction:: doctr.models.classification.vit_b

.. autofunction:: doctr.models.classification.crop_orientation_predictor
@@ -67,6 +69,10 @@ doctr.models.recognition

.. autofunction:: doctr.models.recognition.master

.. autofunction:: doctr.models.recognition.vitstr_small

.. autofunction:: doctr.models.recognition.vitstr_base

.. autofunction:: doctr.models.recognition.recognition_predictor


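If the recognition zoo is updated accordingly elsewhere in this PR (not visible in this excerpt), the new architectures would also be selectable by name through the predictor factory. A hedged sketch, assuming "vitstr_small" becomes a valid arch string:

```python
import numpy as np

from doctr.models import recognition_predictor

# Assumption: "vitstr_small" is registered as a recognition architecture by this PR
predictor = recognition_predictor(arch="vitstr_small", pretrained=False)

# The predictor consumes a list of word-crop images as (H, W, C) uint8 numpy arrays
crops = [np.zeros((32, 128, 3), dtype=np.uint8)]
print(predictor(crops))  # other doctr recognizers yield [(text, confidence)] per crop
```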
65 changes: 54 additions & 11 deletions doctr/models/classification/vit/pytorch.py
@@ -15,11 +15,18 @@

from ...utils.pytorch import load_pretrained_params

__all__ = ["vit_b"]
__all__ = ["vit_s", "vit_b"]


default_cfgs: Dict[str, Dict[str, Any]] = {
"vit": {
"vit_b": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 32, 32),
"classes": list(VOCABS["french"]),
"url": None,
},
"vit_s": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 32, 32),
@@ -57,25 +64,25 @@ class VisionTransformer(nn.Sequential):
<https://arxiv.org/pdf/2010.11929.pdf>`_.
Args:
input_shape: size of the input image
patch_size: size of the patches to be extracted from the input
d_model: dimension of the transformer layers
num_layers: number of transformer layers
num_heads: number of attention heads
ffd_ratio: multiplier for the hidden dimension of the feedforward layer
input_shape: size of the input image
patch_size: size of the patches to be extracted from the input
dropout: dropout rate
num_classes: number of output classes
include_top: whether the classifier head should be instantiated
"""

def __init__(
self,
d_model: int,
num_layers: int,
num_heads: int,
ffd_ratio: int,
input_shape: Tuple[int, int, int] = (3, 32, 32),
patch_size: Tuple[int, int] = (4, 4),
d_model: int = 768,
num_layers: int = 12,
num_heads: int = 12,
ffd_ratio: int = 4,
dropout: float = 0.0,
num_classes: int = 1000,
include_top: bool = True,
@@ -128,8 +135,40 @@ def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
<https://arxiv.org/pdf/2010.11929.pdf>`_.
>>> import torch
>>> from doctr.models import vit
>>> model = vit(pretrained=False)
>>> from doctr.models import vit_b
>>> model = vit_b(pretrained=False)
>>> input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32)
>>> out = model(input_tensor)
Args:
pretrained: boolean, True if model is pretrained
Returns:
A feature extractor model
"""

return _vit(
"vit_b",
pretrained,
d_model=768,
num_layers=12,
num_heads=12,
ffd_ratio=4,
ignore_keys=["head.weight", "head.bias"],
**kwargs,
)


def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
"""VisionTransformer-S architecture
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
<https://arxiv.org/pdf/2010.11929.pdf>`_.
NOTE: unofficial config used in ViTSTR and ParSeq
>>> import torch
>>> from doctr.models import vit_s
>>> model = vit_s(pretrained=False)
>>> input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32)
>>> out = model(input_tensor)
@@ -141,8 +180,12 @@ def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
"""

return _vit(
"vit",
"vit_s",
pretrained,
d_model=384,
num_layers=12,
num_heads=6,
ffd_ratio=4,
ignore_keys=["head.weight", "head.bias"],
**kwargs,
)
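To make the new variant concrete: `vit_s` keeps the 12-layer depth of `vit_b` but halves the embedding width (384 vs. 768) and the head count (6 vs. 12). The sketch below compares the two factories and shows how `vit_s` might be used headless as a recognition backbone; forwarding `include_top` and `input_shape` through the factory kwargs is an assumption based on the constructor above, not something this diff guarantees.

```python
import torch

from doctr.models.classification import vit_b, vit_s

count = lambda m: sum(p.numel() for p in m.parameters())

small, base = vit_s(pretrained=False), vit_b(pretrained=False)
print(f"vit_s: {count(small) / 1e6:.1f}M params, vit_b: {count(base) / 1e6:.1f}M params")

# Headless backbone on a 32x128 recognition crop
# (assumes include_top and input_shape are forwarded to VisionTransformer)
backbone = vit_s(pretrained=False, include_top=False, input_shape=(3, 32, 128))
features = backbone(torch.rand((1, 3, 32, 128), dtype=torch.float32))
print(features.shape)  # expected: (batch, num_patches + 1, 384) patch-embedding sequence
```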
64 changes: 53 additions & 11 deletions doctr/models/classification/vit/tensorflow.py
@@ -17,11 +17,18 @@

from ...utils import load_pretrained_params

__all__ = ["vit_b"]
__all__ = ["vit_s", "vit_b"]


default_cfgs: Dict[str, Dict[str, Any]] = {
"vit": {
"vit_s": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 32, 32),
"classes": list(VOCABS["french"]),
"url": None,
},
"vit_b": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
@@ -54,25 +61,25 @@ class VisionTransformer(Sequential):
<https://arxiv.org/pdf/2010.11929.pdf>`_.
Args:
input_shape: size of the input image
patch_size: size of the patches to be extracted from the input
d_model: dimension of the transformer layers
num_layers: number of transformer layers
num_heads: number of attention heads
ffd_ratio: multiplier for the hidden dimension of the feedforward layer
input_shape: size of the input image
patch_size: size of the patches to be extracted from the input
dropout: dropout rate
num_classes: number of output classes
include_top: whether the classifier head should be instantiated
"""

def __init__(
self,
d_model: int,
num_layers: int,
num_heads: int,
ffd_ratio: int,
input_shape: Tuple[int, int, int] = (32, 32, 3),
patch_size: Tuple[int, int] = (4, 4),
d_model: int = 768,
num_layers: int = 12,
num_heads: int = 12,
ffd_ratio: int = 4,
dropout: float = 0.0,
num_classes: int = 1000,
include_top: bool = True,
@@ -115,14 +122,45 @@ def _vit(
return model


def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
"""VisionTransformer-S architecture
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
<https://arxiv.org/pdf/2010.11929.pdf>`_.
NOTE: unofficial config used in ViTSTR and ParSeq
>>> import tensorflow as tf
>>> from doctr.models import vit_s
>>> model = vit_s(pretrained=False)
>>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
>>> out = model(input_tensor)
Args:
pretrained: boolean, True if model is pretrained
Returns:
A feature extractor model
"""

return _vit(
"vit_s",
pretrained,
d_model=384,
num_layers=12,
num_heads=6,
ffd_ratio=4,
**kwargs,
)


def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
"""VisionTransformer-B architecture as described in
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
<https://arxiv.org/pdf/2010.11929.pdf>`_.
>>> import tensorflow as tf
>>> from doctr.models import vit
>>> model = vit(pretrained=False)
>>> from doctr.models import vit_b
>>> model = vit_b(pretrained=False)
>>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
>>> out = model(input_tensor)
@@ -134,7 +172,11 @@ def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
"""

return _vit(
"vit",
"vit_b",
pretrained,
d_model=768,
num_layers=12,
num_heads=12,
ffd_ratio=4,
**kwargs,
)
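An analogous sketch on the TensorFlow side, keeping the default classification head; note the channels-last input shape expected by the TF config.

```python
import tensorflow as tf

from doctr.models.classification import vit_s

model = vit_s(pretrained=False)

# Channels-last input, matching the (32, 32, 3) default shape of the TF config
input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
out = model(input_tensor, training=False)
print(out.shape)  # expected: (1, num_classes) with the default head
```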
1 change: 1 addition & 0 deletions doctr/models/classification/zoo.py
@@ -25,6 +25,7 @@
"resnet50",
"resnet34_wide",
"vgg16_bn_r",
"vit_s",
"vit_b",
]
ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_orientation"]
44 changes: 42 additions & 2 deletions doctr/models/modules/vision_transformer/pytorch.py
@@ -4,6 +4,7 @@
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.


import math
from typing import Tuple

import torch
@@ -24,14 +25,53 @@ def __init__(

super().__init__()
channels, height, width = input_shape
self.patch_size = patch_size
# switch to a rectangular (4, 8) patch size for non-square inputs (e.g. 32x128 recognition crops)
self.patch_size = (4, 8) if height != width else patch_size
self.grid_size = (height // patch_size[0], width // patch_size[1])
self.num_patches = (height // patch_size[0]) * (width // patch_size[1])

self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim)) # type: ignore[attr-defined]
self.positions = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim)) # type: ignore[attr-defined]
self.proj = nn.Linear((channels * self.patch_size[0] * self.patch_size[1]), embed_dim)

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
100 % borrowed from:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_vit.py
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py
"""

num_patches = embeddings.shape[1] - 1
num_positions = self.positions.shape[1] - 1
if num_patches == num_positions and height == width:
return self.positions
class_pos_embed = self.positions[:, 0]
patch_pos_embed = self.positions[:, 1:]
dim = embeddings.shape[-1]
h0 = float(height // self.patch_size[0])
w0 = float(width // self.patch_size[1])
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
mode="bicubic",
align_corners=False,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
)
assert int(h0) == patch_pos_embed.shape[-2], "height of interpolated patch embedding doesn't match"
assert int(w0) == patch_pos_embed.shape[-1], "width of interpolated patch embedding doesn't match"

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, H, W = x.shape
assert H % self.patch_size[0] == 0, "Image height must be divisible by patch height"
@@ -53,6 +93,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# concatenate cls_tokens to patches
embeddings = torch.cat([cls_tokens, patches], dim=1) # (batch_size, num_patches + 1, d_model)
# add positions to embeddings
embeddings += self.positions # (batch_size, num_patches + 1, d_model)
embeddings += self.interpolate_pos_encoding(embeddings, H, W) # (batch_size, num_patches + 1, d_model)

return embeddings
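To see why the interpolation helper matters, the sketch below pushes both a square 32x32 classification input and a rectangular 32x128 recognition crop through the patch embedding. The class name `PatchEmbedding` and the keyword names `input_shape`, `patch_size`, and `embed_dim` are inferred from the constructor body above and should be treated as assumptions.

```python
import torch

from doctr.models.modules.vision_transformer.pytorch import PatchEmbedding  # assumed class name

# Square input: (4, 4) patches give an 8x8 grid of 64 patches, plus one class token
square_embed = PatchEmbedding(input_shape=(3, 32, 32), patch_size=(4, 4), embed_dim=384)
print(square_embed(torch.rand(1, 3, 32, 32)).shape)  # expected: torch.Size([1, 65, 384])

# Rectangular 32x128 crop: the module switches to (4, 8) patches (an 8x16 grid of 128 patches)
# and bicubically interpolates the learned position encodings onto that grid
rect_embed = PatchEmbedding(input_shape=(3, 32, 128), patch_size=(4, 4), embed_dim=384)
print(rect_embed(torch.rand(1, 3, 32, 128)).shape)  # expected: torch.Size([1, 129, 384])
```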