-
Notifications
You must be signed in to change notification settings - Fork 463
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Added loading method for PyTorch artefact detection models from…
… HF Hub (#836) * refactor: Refactored FasterRCNN * feat: Added factory method from_hub * chore: Updated requirements * test: Added unittest * chore: Updated mypy config * test: Updated unittests * test: Fixed unittest * test: Fixed unittest * feat: Added cfg to model * test: Fixed unittest
- Loading branch information
Showing
10 changed files
with
83 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from doctr.file_utils import is_torch_available | ||
|
||
if is_torch_available(): | ||
from .pytorch import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Copyright (C) 2022, Mindee. | ||
|
||
# This program is licensed under the Apache License version 2. | ||
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details. | ||
|
||
import json | ||
from typing import Any | ||
|
||
import torch | ||
from huggingface_hub import hf_hub_download | ||
|
||
from doctr.models import obj_detection | ||
|
||
__all__ = ['from_hub'] | ||
|
||
|
||
def from_hub(repo_id: str, **kwargs: Any) -> torch.nn.Module: | ||
"""Instantiate & load a pretrained model from HF hub. | ||
Example:: | ||
>>> from doctr.models.obj_detection import from_hub | ||
>>> model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn").eval() | ||
>>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) | ||
>>> with torch.no_grad(): out = model(input_tensor) | ||
Args: | ||
repo_id: HuggingFace model hub repo | ||
kwargs: kwargs of `hf_hub_download` | ||
Returns: | ||
Model loaded with the checkpoint | ||
""" | ||
|
||
# Get the config | ||
with open(hf_hub_download(repo_id, filename='config.json', **kwargs), 'rb') as f: | ||
cfg = json.load(f) | ||
|
||
model = obj_detection.__dict__[cfg['arch']]( | ||
pretrained=False, | ||
image_mean=cfg['mean'], | ||
image_std=cfg['std'], | ||
max_size=cfg['input_shape'][-1], | ||
num_classes=len(cfg['classes']), | ||
) | ||
|
||
# Load the checkpoint | ||
state_dict = torch.load(hf_hub_download(repo_id, filename='pytorch_model.bin', **kwargs), map_location='cpu') | ||
model.load_state_dict(state_dict) | ||
model.cfg = cfg | ||
|
||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,3 +14,4 @@ torchvision>=0.9.0 | |
Pillow>=8.3.2 | ||
tqdm>=4.30.0 | ||
rapidfuzz>=1.6.0 | ||
huggingface-hub>=0.4.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,3 +15,4 @@ tqdm>=4.30.0 | |
tensorflow-addons>=0.13.0 | ||
rapidfuzz>=1.6.0 | ||
keras<2.7.0 | ||
huggingface-hub>=0.4.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters