[models] ViT add checkpoints and some rework to use pretrained ViT backbone in ViTSTR #1072

Merged: 15 commits, Sep 26, 2022
54 changes: 29 additions & 25 deletions doctr/models/classification/vit/pytorch.py
@@ -24,14 +24,14 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 32, 32),
"classes": list(VOCABS["french"]),
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.5.1/vit_b-13bbe405.pt",
},
"vit_s": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 32, 32),
"classes": list(VOCABS["french"]),
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.5.1/vit_s-ff3c4666.pt",
},
}

@@ -69,7 +69,6 @@ class VisionTransformer(nn.Sequential):
num_heads: number of attention heads
ffd_ratio: multiplier for the hidden dimension of the feedforward layer
input_shape: size of the input image
patch_size: size of the patches to be extracted from the input
dropout: dropout rate
num_classes: number of output classes
include_top: whether the classifier head should be instantiated
@@ -82,15 +81,14 @@ def __init__(
num_heads: int,
ffd_ratio: int,
input_shape: Tuple[int, int, int] = (3, 32, 32),
patch_size: Tuple[int, int] = (4, 4),
dropout: float = 0.0,
num_classes: int = 1000,
include_top: bool = True,
cfg: Optional[Dict[str, Any]] = None,
) -> None:

_layers: List[nn.Module] = [
PatchEmbedding(input_shape, patch_size, d_model),
PatchEmbedding(input_shape, d_model),
EncoderBlock(num_layers, num_heads, d_model, d_model * ffd_ratio, dropout, nn.GELU()),
]
if include_top:
@@ -103,7 +101,7 @@ def __init__(
def _vit(
arch: str,
pretrained: bool,
ignore_keys: Optional[List[str]] = None,
ignore_keys: Optional[List[str]] = [],
**kwargs: Any,
) -> VisionTransformer:

@@ -123,20 +121,28 @@ def _vit(
if pretrained:
# The number of classes is not the same as the number of classes in the pretrained model =>
# remove the last layer weights
_ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
_ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else []
# The model is used as a feature extractor => remove the patch embedding and position weights
_ignore_keys = (
_ignore_keys + ["0.positions", "0.proj.weight"] # type: ignore
if kwargs["input_shape"] != default_cfgs[arch]["input_shape"]
else _ignore_keys
)
load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

return model


def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
"""VisionTransformer-B architecture as described in
def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
"""VisionTransformer-S architecture
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
<https://arxiv.org/pdf/2010.11929.pdf>`_.

NOTE: unofficial config used in ViTSTR and ParSeq

>>> import torch
>>> from doctr.models import vit_b
>>> model = vit_b(pretrained=False)
>>> from doctr.models import vit_s
>>> model = vit_s(pretrained=False)
>>> input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32)
>>> out = model(input_tensor)

@@ -148,27 +154,25 @@ def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
"""

return _vit(
"vit_b",
"vit_s",
pretrained,
d_model=768,
d_model=384,
num_layers=12,
num_heads=12,
num_heads=6,
ffd_ratio=4,
ignore_keys=["head.weight", "head.bias"],
ignore_keys=["2.head.weight", "2.head.bias"],
**kwargs,
)


def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
"""VisionTransformer-S architecture
def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
"""VisionTransformer-B architecture as described in
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
<https://arxiv.org/pdf/2010.11929.pdf>`_.

NOTE: unofficial config used in ViTSTR and ParSeq

>>> import torch
>>> from doctr.models import vit_s
>>> model = vit_s(pretrained=False)
>>> from doctr.models import vit_b
>>> model = vit_b(pretrained=False)
>>> input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32)
>>> out = model(input_tensor)

Expand All @@ -180,12 +184,12 @@ def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
"""

return _vit(
"vit_s",
"vit_b",
pretrained,
d_model=384,
d_model=768,
num_layers=12,
num_heads=6,
num_heads=12,
ffd_ratio=4,
ignore_keys=["head.weight", "head.bias"],
ignore_keys=["2.head.weight", "2.head.bias"],
**kwargs,
)
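To make the ignore_keys mechanic above concrete, here is a minimal sketch of what skipping mismatched weights amounts to. It is illustrative only, not doctr's actual load_pretrained_params: the helper name and the plain state-dict input are assumptions, while the key names come from the hunks above.

```python
from typing import Dict, List

import torch
import torch.nn as nn


def load_filtered_state_dict(
    model: nn.Module, state_dict: Dict[str, torch.Tensor], ignore_keys: List[str]
) -> None:
    """Load pretrained weights while skipping keys whose shapes cannot match.

    "2.head.weight"/"2.head.bias" are dropped when num_classes differs from the
    pretrained config, "0.positions"/"0.proj.weight" when the input shape (and
    hence the patch projection) differs.
    """
    filtered = {k: v for k, v in state_dict.items() if k not in ignore_keys}
    # strict=False keeps the freshly initialised values for the skipped keys
    model.load_state_dict(filtered, strict=False)
```

The `2.` prefix reflects the classifier head sitting at index 2 of the Sequential built by VisionTransformer (patch embedding, encoder block, head), which is why the PR renames the ignore_keys from head.* to 2.head.*.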
10 changes: 4 additions & 6 deletions doctr/models/classification/vit/tensorflow.py
@@ -26,14 +26,14 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 32, 32),
"classes": list(VOCABS["french"]),
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.5.1/vit_s-f87ad69c.zip",
},
"vit_b": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.5.1/vit_b-71da99f5.zip",
},
}

@@ -66,7 +66,6 @@ class VisionTransformer(Sequential):
num_heads: number of attention heads
ffd_ratio: multiplier for the hidden dimension of the feedforward layer
input_shape: size of the input image
patch_size: size of the patches to be extracted from the input
dropout: dropout rate
num_classes: number of output classes
include_top: whether the classifier head should be instantiated
@@ -79,15 +78,14 @@ def __init__(
num_heads: int,
ffd_ratio: int,
input_shape: Tuple[int, int, int] = (32, 32, 3),
patch_size: Tuple[int, int] = (4, 4),
dropout: float = 0.0,
num_classes: int = 1000,
include_top: bool = True,
cfg: Optional[Dict[str, Any]] = None,
) -> None:

_layers = [
PatchEmbedding(input_shape, patch_size, d_model),
PatchEmbedding(input_shape, d_model),
EncoderBlock(num_layers, num_heads, d_model, d_model * ffd_ratio, dropout, activation_fct=GELU()),
]
if include_top:
@@ -129,7 +127,7 @@ def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:

NOTE: unofficial config used in ViTSTR and ParSeq

>>> import tf
>>> import tensorflow as tf
>>> from doctr.models import vit_s
>>> model = vit_s(pretrained=False)
>>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
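A short usage sketch for the TensorFlow checkpoints added above. Treat it as a hypothetical example rather than documented behaviour: it assumes vit_s accepts include_top and input_shape as keyword arguments (as the hunks suggest) and that the released weights load cleanly in this backbone configuration.

```python
import tensorflow as tf

from doctr.models import vit_s

# Backbone-style usage at the recognition input shape (32, 128, 3): the classifier
# head is dropped and the mismatched patch-embedding/position weights are skipped.
model = vit_s(pretrained=True, include_top=False, input_shape=(32, 128, 3))

dummy = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1.0, dtype=tf.float32)
features = model(dummy, training=False)
```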
16 changes: 6 additions & 10 deletions doctr/models/modules/vision_transformer/pytorch.py
@@ -16,19 +16,15 @@
class PatchEmbedding(nn.Module):
"""Compute 2D patch embeddings with cls token and positional encoding"""

def __init__(
self,
input_shape: Tuple[int, int, int],
patch_size: Tuple[int, int],
embed_dim: int,
) -> None:
def __init__(self, input_shape: Tuple[int, int, int], embed_dim: int) -> None:

super().__init__()
channels, height, width = input_shape
# fix patch size if recognition task with 32x128 input
self.patch_size = (4, 8) if height != width else patch_size
self.grid_size = (height // patch_size[0], width // patch_size[1])
self.num_patches = (height // patch_size[0]) * (width // patch_size[1])
# calculate patch size 32x32 -> (4, 4) 32x128 -> (4, 16)
# NOTE: this is different from the original implementation
self.patch_size = (height // 8, width // 8)
self.grid_size = (height // self.patch_size[0], width // self.patch_size[1])
self.num_patches = (height // self.patch_size[0]) * (width // self.patch_size[1])

self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim)) # type: ignore[attr-defined]
self.positions = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim)) # type: ignore[attr-defined]
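The reworked PatchEmbedding drops the explicit patch_size argument and instead always cuts the image into an 8x8 grid of patches; a standalone check of that arithmetic (plain Python, no doctr imports):

```python
from typing import Tuple


def patch_grid(height: int, width: int) -> Tuple[Tuple[int, int], Tuple[int, int], int]:
    # Patch size follows from the input resolution: the image is always split
    # into an 8x8 grid, matching the comment in PatchEmbedding above.
    patch_size = (height // 8, width // 8)
    grid_size = (height // patch_size[0], width // patch_size[1])
    num_patches = grid_size[0] * grid_size[1]
    return patch_size, grid_size, num_patches


print(patch_grid(32, 32))   # ((4, 4), (8, 8), 64)  -> classification input
print(patch_grid(32, 128))  # ((4, 16), (8, 8), 64) -> recognition input
```

Note that the patch count is 64 in both cases, so the positional embedding has shape (1, 65, embed_dim) (64 patches plus the cls token) regardless of the input resolution.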
16 changes: 6 additions & 10 deletions doctr/models/modules/vision_transformer/tensorflow.py
@@ -17,19 +17,15 @@
class PatchEmbedding(layers.Layer, NestedObject):
"""Compute 2D patch embeddings with cls token and positional encoding"""

def __init__(
self,
input_shape: Tuple[int, int, int],
patch_size: Tuple[int, int],
embed_dim: int,
) -> None:
def __init__(self, input_shape: Tuple[int, int, int], embed_dim: int) -> None:

super().__init__()
height, width, _ = input_shape
# fix patch size if recognition task with 32x128 input
self.patch_size = (4, 8) if height != width else patch_size
self.grid_size = (height // patch_size[0], width // patch_size[1])
self.num_patches = (height // patch_size[0]) * (width // patch_size[1])
# calculate patch size 32x32 -> (4, 4) 32x128 -> (4, 16)
# NOTE: this is different from the original implementation
self.patch_size = (height // 8, width // 8)
self.grid_size = (height // self.patch_size[0], width // self.patch_size[1])
self.num_patches = (height // self.patch_size[0]) * (width // self.patch_size[1])

self.cls_token = self.add_weight(shape=(1, 1, embed_dim), initializer="zeros", trainable=True, name="cls_token")
self.positions = self.add_weight(
2 changes: 0 additions & 2 deletions doctr/models/recognition/vitstr/pytorch.py
@@ -48,7 +48,6 @@ class ViTSTR(_ViTSTR, nn.Module):
max_length: maximum word length handled by the model
dropout_prob: dropout probability of the encoder
input_shape: input shape of the image
patch_size: size of the patches
exportable: onnx exportable returns only logits
cfg: dictionary containing information about the model
"""
@@ -60,7 +59,6 @@ def __init__(
embedding_units: int,
max_length: int = 25,
input_shape: Tuple[int, int, int] = (3, 32, 128), # different from paper
patch_size: Tuple[int, int] = (4, 8), # different from paper to match our size
exportable: bool = False,
cfg: Optional[Dict[str, Any]] = None,
) -> None:
7 changes: 2 additions & 5 deletions doctr/models/recognition/vitstr/tensorflow.py
@@ -46,7 +46,6 @@ class ViTSTR(_ViTSTR, Model):
max_length: maximum word length handled by the model
dropout_prob: dropout probability for the encoder and decoder
input_shape: input shape of the image
patch_size: size of the patches
exportable: onnx exportable returns only logits
cfg: dictionary containing information about the model
"""
@@ -61,7 +60,6 @@ def __init__(
max_length: int = 25,
dropout_prob: float = 0.0,
input_shape: Tuple[int, int, int] = (32, 128, 3), # different from paper
patch_size: Tuple[int, int] = (4, 8), # different from paper to match our size
exportable: bool = False,
cfg: Optional[Dict[str, Any]] = None,
) -> None:
@@ -202,10 +200,9 @@ def _vitstr(

kwargs["vocab"] = _cfg["vocab"]

# Feature extractor
# NOTE: switch to IntermediateLayerGetter if pretrained vit models are available
# feature extractor
feat_extractor = backbone_fn(
pretrained=pretrained_backbone,
pretrained=False, # TODO: pretrained_backbone, solve weights shape mismatch
input_shape=_cfg["input_shape"],
include_top=False,
)
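To illustrate the feature-extractor pattern in _vitstr, a minimal sketch of plugging the TensorFlow vit_s backbone into a small recognition-style head. The Dense head, its output size, and the dummy shapes are purely hypothetical, and pretrained is left at False to mirror the TODO above.

```python
import tensorflow as tf
from tensorflow.keras import layers

from doctr.models import vit_s

# Headless ViT backbone at the recognition input shape, as _vitstr builds it.
feat_extractor = vit_s(
    pretrained=False,  # mirrors the TODO above: the pretrained backbone still hits a weight shape mismatch
    input_shape=(32, 128, 3),
    include_top=False,
)

# Hypothetical per-token classification head (the vocab size 126 is arbitrary for this sketch).
head = layers.Dense(126)

features = feat_extractor(tf.zeros((1, 32, 128, 3)), training=False)  # (batch, num_patches + 1, d_model)
logits = head(features)
```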