Skip to content

Commit

Permalink
feat: Added loading method for PyTorch artefact detection models from HF Hub (#836)
Browse files Browse the repository at this point in the history

* refactor: Refactored FasterRCNN

* feat: Added factory method from_hub

* chore: Updated requirements

* test: Added unittest

* chore: Updated mypy config

* test: Updated unittests

* test: Fixed unittest

* test: Fixed unittest

* feat: Added cfg to model

* test: Fixed unittest
  • Loading branch information
fg-mindee authored Mar 8, 2022
1 parent c9806fa commit 9b31588
Show file tree
Hide file tree
Showing 10 changed files with 83 additions and 10 deletions.
4 changes: 4 additions & 0 deletions doctr/models/obj_detection/factory/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from doctr.file_utils import is_torch_available

if is_torch_available():
from .pytorch import *
50 changes: 50 additions & 0 deletions doctr/models/obj_detection/factory/pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (C) 2022, Mindee.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import json
from typing import Any

import torch
from huggingface_hub import hf_hub_download

from doctr.models import obj_detection

__all__ = ['from_hub']


def from_hub(repo_id: str, **kwargs: Any) -> torch.nn.Module:
    """Instantiate & load a pretrained model from HF hub.

    Example::
        >>> import torch
        >>> from doctr.models.obj_detection import from_hub
        >>> model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn").eval()
        >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
        >>> with torch.no_grad(): out = model(input_tensor)

    Args:
        repo_id: HuggingFace model hub repo
        kwargs: kwargs of `hf_hub_download`, forwarded to both the config and the checkpoint download

    Returns:
        Model loaded with the checkpoint; the parsed `config.json` is attached to it as the `cfg` attribute

    Raises:
        KeyError: if `config.json` lacks one of the entries read here ('arch', 'mean', 'std', 'input_shape',
            'classes') or if its 'arch' entry does not name a member of `doctr.models.obj_detection`
    """

    # Get the config
    cfg_path = hf_hub_download(repo_id, filename='config.json', **kwargs)
    with open(cfg_path, 'rb') as f:
        cfg = json.load(f)

    # Resolve the model builder named by the (remote) config; fail with context instead of a bare KeyError
    arch = cfg['arch']
    try:
        builder = vars(obj_detection)[arch]
    except KeyError:
        raise KeyError(
            f"unknown architecture '{arch}' in config.json of '{repo_id}': not found in doctr.models.obj_detection"
        ) from None

    model = builder(
        pretrained=False,
        image_mean=cfg['mean'],
        image_std=cfg['std'],
        # The last entry of 'input_shape' is passed as the model's `max_size`
        max_size=cfg['input_shape'][-1],
        num_classes=len(cfg['classes']),
    )

    # Load the checkpoint
    # NOTE(review): torch.load unpickles the downloaded file, which can execute arbitrary code.
    # Only call this function with repositories you trust.
    weights_path = hf_hub_download(repo_id, filename='pytorch_model.bin', **kwargs)
    state_dict = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(state_dict)
    # Keep the hub config on the model so callers can read it back (e.g. the class names)
    model.cfg = cfg

    return model
8 changes: 3 additions & 5 deletions doctr/models/obj_detection/faster_rcnn/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
'input_shape': (3, 1024, 1024),
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
'anchor_sizes': [32, 64, 128, 256, 512],
'anchor_aspect_ratios': (0.5, 1., 2.),
'num_classes': 5,
'classes': ["background", "qr_code", "bar_code", "logo", "photo"],
'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/fasterrcnn_mobilenet_v3_large_fpn-d5b2490d.pt',
},
}
Expand All @@ -31,11 +29,11 @@ def _fasterrcnn(arch: str, pretrained: bool, **kwargs: Any) -> FasterRCNN:
"image_mean": default_cfgs[arch]['mean'],
"image_std": default_cfgs[arch]['std'],
"box_detections_per_img": 150,
"box_score_thresh": 0.15,
"box_score_thresh": 0.5,
"box_positive_fraction": 0.35,
"box_nms_thresh": 0.2,
"rpn_nms_thresh": 0.2,
"num_classes": default_cfgs[arch]['num_classes'],
"num_classes": len(default_cfgs[arch]['classes']),
}

# Build the model
Expand Down
4 changes: 4 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,7 @@ ignore_missing_imports = True
[mypy-h5py.*]

ignore_missing_imports = True

[mypy-huggingface_hub.*]

ignore_missing_imports = True
1 change: 1 addition & 0 deletions requirements-pt.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ torchvision>=0.9.0
Pillow>=8.3.2
tqdm>=4.30.0
rapidfuzz>=1.6.0
huggingface-hub>=0.4.0
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ tqdm>=4.30.0
tensorflow-addons>=0.13.0
rapidfuzz>=1.6.0
keras<2.7.0
huggingface-hub>=0.4.0
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
"tensorflow-addons>=0.13.0",
"rapidfuzz>=1.6.0",
"keras<2.7.0",
"huggingface-hub>=0.4.0",
# Testing
"pytest>=5.3.2",
"coverage>=4.5.4",
Expand Down Expand Up @@ -104,6 +105,7 @@ def deps_list(*pkgs):
deps["Pillow"],
deps["tqdm"],
deps["rapidfuzz"],
deps["huggingface-hub"],
]

extras = {}
Expand Down
14 changes: 10 additions & 4 deletions tests/common/test_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,24 @@ def test_headers():
shebang = ["#!usr/bin/python\n"]
blank_line = "\n"

_copyright_str = f"-{datetime.now().year}" if datetime.now().year > 2021 else ""
copyright_notice = [f"# Copyright (C) 2021{_copyright_str}, Mindee.\n"]
starting_year = 2021
current_year = datetime.now().year
year_str = [current_year] + [f"{starting_year}-{current_year}" for year in range(starting_year, current_year)]
if starting_year == current_year:
year_str = year_str[:1]

copyright_notices = [[f"# Copyright (C) {_str}, Mindee.\n"] for _str in year_str]
license_notice = [
"# This program is licensed under the Apache License version 2.\n",
"# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.\n"
]

# Define all header options
headers = [
headers = [[
shebang + [blank_line] + copyright_notice + [blank_line] + license_notice,
copyright_notice + [blank_line] + license_notice
]
] for copyright_notice in copyright_notices]
headers = [_header for year_header in headers for _header in year_header]

excluded_files = ["version.py", "__init__.py"]
invalid_files = []
Expand Down
7 changes: 7 additions & 0 deletions tests/pytorch/test_models_obj_detection_pt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pytest
import torch
from torchvision.models.detection import FasterRCNN

from doctr.models import obj_detection
from doctr.models.obj_detection.factory import from_hub


@pytest.mark.parametrize(
Expand Down Expand Up @@ -32,3 +34,8 @@ def test_detection_models(arch_name, input_shape, pretrained):
target = [{k: v.cuda() for k, v in t.items()} for t in target]
out = model(input_tensor, target)
assert isinstance(out, dict) and all(isinstance(v, torch.Tensor) for v in out.values())


def test_obj_det_from_hub():
    """Check that `from_hub` returns a `FasterRCNN` instance for the reference hub repository."""
    hub_repo = "mindee/fasterrcnn_mobilenet_v3_large_fpn"
    # Switch to eval mode right after loading, as in the documented usage of `from_hub`
    loaded = from_hub(hub_repo).eval()
    assert isinstance(loaded, FasterRCNN)
2 changes: 1 addition & 1 deletion tests/tensorflow/test_transforms_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def test_random_shadow(input_dtype, input_shape):
assert transformed.shape == input_shape
assert transformed.dtype == input_dtype
# The shadow will darken the picture
assert tf.math.reduce_mean(input_t) > tf.math.reduce_mean(transformed)
assert tf.math.reduce_mean(input_t) >= tf.math.reduce_mean(transformed)
assert tf.math.reduce_all(transformed >= 0)
if input_dtype == tf.uint8:
assert tf.reduce_all(transformed <= 255)
Expand Down

0 comments on commit 9b31588

Please sign in to comment.