Skip to content

Commit

Permalink
Move load_from_ultralytics back to yolort.models._utils
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Oct 15, 2021
1 parent c2cfc5e commit 0761eb2
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 147 deletions.
46 changes: 0 additions & 46 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import torch
from torch import Tensor
from yolort import models
from yolort.models import YOLOv5
from yolort.models.anchor_utils import AnchorGenerator
from yolort.models.backbone_utils import darknet_pan_backbone
from yolort.models.box_head import YOLOHead, PostProcess, SetCriterion
Expand Down Expand Up @@ -353,48 +352,3 @@ def test_torchscript(arch):
torch.testing.assert_close(
out[0]["boxes"], out_script[1][0]["boxes"], rtol=0, atol=0
)


@pytest.mark.parametrize(
"arch, version, upstream_version, hash_prefix",
[("yolov5s", "r4.0", "v4.0", "9ca9a642")],
)
def test_load_from_yolov5(
arch: str,
version: str,
upstream_version: str,
hash_prefix: str,
):
img_path = "test/assets/bus.jpg"
checkpoint_path = f"{arch}_{version}_{hash_prefix}"

base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"

torch.hub.download_url_to_file(
model_url,
checkpoint_path,
hash_prefix=hash_prefix,
)

model_yolov5 = YOLOv5.load_from_yolov5(checkpoint_path, version=version)
model_yolov5.eval()
out_from_yolov5 = model_yolov5.predict(img_path)
assert isinstance(out_from_yolov5[0], dict)
assert isinstance(out_from_yolov5[0]["boxes"], Tensor)
assert isinstance(out_from_yolov5[0]["labels"], Tensor)
assert isinstance(out_from_yolov5[0]["scores"], Tensor)

model = models.__dict__[arch](pretrained=True, score_thresh=0.25)
model.eval()
out = model.predict(img_path)

torch.testing.assert_close(
out_from_yolov5[0]["scores"], out[0]["scores"], rtol=0, atol=0
)
torch.testing.assert_close(
out_from_yolov5[0]["labels"], out[0]["labels"], rtol=0, atol=0
)
torch.testing.assert_close(
out_from_yolov5[0]["boxes"], out[0]["boxes"], rtol=0, atol=0
)
84 changes: 84 additions & 0 deletions test/test_models_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
import pytest
import torch
from torch import Tensor

from yolort import models
from yolort.models import YOLOv5
from yolort.models._utils import load_from_ultralytics


@pytest.mark.parametrize(
"arch, version, upstream_version, hash_prefix, use_p6",
[
("yolov5s", "r4.0", "v4.0", "9ca9a642", False),
("yolov5s", "r4.0", "v6.0", "c3b140f3", False),
],
)
def test_load_from_ultralytics(
arch: str,
version: str,
upstream_version: str,
hash_prefix: str,
use_p6: bool,
):
checkpoint_path = f"{arch}_{version}_{hash_prefix}"
base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"

torch.hub.download_url_to_file(
model_url,
checkpoint_path,
hash_prefix=hash_prefix,
)
model_info = load_from_ultralytics(checkpoint_path, version=version)
assert isinstance(model_info, dict)
assert model_info["num_classes"] == 80
assert model_info["size"] == arch.replace("yolov5", "")
assert model_info["use_p6"] == use_p6
assert len(model_info["strides"]) == 4 if use_p6 else 3


@pytest.mark.parametrize(
"arch, version, upstream_version, hash_prefix",
[("yolov5s", "r4.0", "v4.0", "9ca9a642")],
)
def test_load_from_yolov5(
arch: str,
version: str,
upstream_version: str,
hash_prefix: str,
):
img_path = "test/assets/bus.jpg"
checkpoint_path = f"{arch}_{version}_{hash_prefix}"

base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"

torch.hub.download_url_to_file(
model_url,
checkpoint_path,
hash_prefix=hash_prefix,
)

model_yolov5 = YOLOv5.load_from_yolov5(checkpoint_path, version=version)
model_yolov5.eval()
out_from_yolov5 = model_yolov5.predict(img_path)
assert isinstance(out_from_yolov5[0], dict)
assert isinstance(out_from_yolov5[0]["boxes"], Tensor)
assert isinstance(out_from_yolov5[0]["labels"], Tensor)
assert isinstance(out_from_yolov5[0]["scores"], Tensor)

model = models.__dict__[arch](pretrained=True, score_thresh=0.25)
model.eval()
out = model.predict(img_path)

torch.testing.assert_close(
out_from_yolov5[0]["scores"], out[0]["scores"], rtol=0, atol=0
)
torch.testing.assert_close(
out_from_yolov5[0]["labels"], out[0]["labels"], rtol=0, atol=0
)
torch.testing.assert_close(
out_from_yolov5[0]["boxes"], out[0]["boxes"], rtol=0, atol=0
)
32 changes: 0 additions & 32 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,44 +7,12 @@
from yolort.utils import (
FeatureExtractor,
get_image_from_url,
load_from_ultralytics,
read_image_to_tensor,
)
from yolort.utils.image_utils import box_cxcywh_to_xyxy
from yolort.v5 import letterbox, scale_coords


@pytest.mark.parametrize(
"arch, version, upstream_version, hash_prefix, use_p6",
[
("yolov5s", "r4.0", "v4.0", "9ca9a642", False),
("yolov5s", "r4.0", "v6.0", "c3b140f3", False),
],
)
def test_load_from_ultralytics(
arch: str,
version: str,
upstream_version: str,
hash_prefix: str,
use_p6: bool,
):
checkpoint_path = f"{arch}_{version}_{hash_prefix}"
base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"

torch.hub.download_url_to_file(
model_url,
checkpoint_path,
hash_prefix=hash_prefix,
)
model_info = load_from_ultralytics(checkpoint_path, version=version)
assert isinstance(model_info, dict)
assert model_info["num_classes"] == 80
assert model_info["size"] == arch.replace("yolov5", "")
assert model_info["use_p6"] == use_p6
assert len(model_info["strides"]) == 4 if use_p6 else 3


def test_read_image_to_tensor():
N, H, W = 3, 720, 360
img = np.random.randint(0, 255, (H, W, N), dtype="uint8") # As a dummy image
Expand Down
65 changes: 65 additions & 0 deletions yolort/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,71 @@
from torch import nn, Tensor
from torchvision.ops import box_convert, box_iou

from yolort.utils import ModuleStateUpdate
from yolort.v5 import load_yolov5_model, get_yolov5_size


def load_from_ultralytics(checkpoint_path: str, version: str = "r6.0"):
"""
Allows the user to load model state file from the checkpoint trained from
the ultralytics/yolov5.
Args:
checkpoint_path (str): Path of the YOLOv5 checkpoint model.
version (str): upstream version released by the ultralytics/yolov5, Possible
values are ["r3.1", "r4.0", "r6.0"]. Default: "r6.0".
"""

assert version in [
"r3.1",
"r4.0",
"r6.0",
], "Currently does not support this version."

checkpoint_yolov5 = load_yolov5_model(checkpoint_path)
num_classes = checkpoint_yolov5.yaml["nc"]
strides = checkpoint_yolov5.stride
anchor_grids = checkpoint_yolov5.yaml["anchors"]
depth_multiple = checkpoint_yolov5.yaml["depth_multiple"]
width_multiple = checkpoint_yolov5.yaml["width_multiple"]

use_p6 = False
if len(strides) == 4:
use_p6 = True

if use_p6:
inner_block_maps = {"0": "9", "1": "10", "3": "13", "4": "14"}
layer_block_maps = {"0": "17", "1": "18", "2": "20", "3": "21", "4": "23"}
else:
inner_block_maps = {"0": "9", "1": "10", "3": "13", "4": "14"}
layer_block_maps = {"0": "17", "1": "18", "2": "20", "3": "21", "4": "23"}

module_state_updater = ModuleStateUpdate(
arch=None,
depth_multiple=depth_multiple,
width_multiple=width_multiple,
version=version,
num_classes=num_classes,
inner_block_maps=inner_block_maps,
layer_block_maps=layer_block_maps,
use_p6=use_p6,
)
module_state_updater.updating(checkpoint_yolov5)
state_dict = module_state_updater.model.state_dict()

size = get_yolov5_size(depth_multiple, width_multiple)

return {
"num_classes": num_classes,
"depth_multiple": depth_multiple,
"width_multiple": width_multiple,
"strides": strides,
"anchor_grids": anchor_grids,
"use_p6": use_p6,
"size": size,
"state_dict": state_dict,
}


def _evaluate_iou(target, pred):
"""
Expand Down
2 changes: 1 addition & 1 deletion yolort/models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch import nn, Tensor
from torchvision.models.utils import load_state_dict_from_url

from yolort.utils import load_from_ultralytics
from ._utils import load_from_ultralytics
from .anchor_utils import AnchorGenerator
from .backbone_utils import darknet_pan_backbone
from .box_head import YOLOHead, SetCriterion, PostProcess
Expand Down
3 changes: 1 addition & 2 deletions yolort/models/yolo_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
from torchvision.io import read_image

from yolort.data import COCOEvaluator, contains_any_tensor
from yolort.utils import load_from_ultralytics
from . import yolo
from ._utils import _evaluate_iou
from ._utils import _evaluate_iou, load_from_ultralytics
from .transform import YOLOTransform

__all__ = ["YOLOv5"]
Expand Down
6 changes: 3 additions & 3 deletions yolort/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from .hooks import FeatureExtractor
from .image_utils import cv2_imshow, get_image_from_url, read_image_to_tensor
from .update_module_state import load_from_ultralytics
from .update_module_state import ModuleStateUpdate

__all__ = [
"cv2_imshow",
"FeatureExtractor",
"ModuleStateUpdate",
"cv2_imshow",
"get_image_from_url",
"get_callable_dict",
"load_from_ultralytics",
"read_image_to_tensor",
]

Expand Down
64 changes: 1 addition & 63 deletions yolort/utils/update_module_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,69 +5,7 @@
from torch import nn

from yolort.models import yolo
from yolort.v5 import load_yolov5_model, get_yolov5_size


def load_from_ultralytics(checkpoint_path: str, version: str = "r6.0"):
"""
Allows the user to load model state file from the checkpoint trained from
the ultralytics/yolov5.
Args:
checkpoint_path (str): Path of the YOLOv5 checkpoint model.
version (str): upstream version released by the ultralytics/yolov5, Possible
values are ["r3.1", "r4.0", "r6.0"]. Default: "r6.0".
"""

assert version in [
"r3.1",
"r4.0",
"r6.0",
], "Currently does not support this version."

checkpoint_yolov5 = load_yolov5_model(checkpoint_path)
num_classes = checkpoint_yolov5.yaml["nc"]
strides = checkpoint_yolov5.stride
anchor_grids = checkpoint_yolov5.yaml["anchors"]
depth_multiple = checkpoint_yolov5.yaml["depth_multiple"]
width_multiple = checkpoint_yolov5.yaml["width_multiple"]

use_p6 = False
if len(strides) == 4:
use_p6 = True

if use_p6:
inner_block_maps = {"0": "9", "1": "10", "3": "13", "4": "14"}
layer_block_maps = {"0": "17", "1": "18", "2": "20", "3": "21", "4": "23"}
else:
inner_block_maps = {"0": "9", "1": "10", "3": "13", "4": "14"}
layer_block_maps = {"0": "17", "1": "18", "2": "20", "3": "21", "4": "23"}

module_state_updater = ModuleStateUpdate(
arch=None,
depth_multiple=depth_multiple,
width_multiple=width_multiple,
version=version,
num_classes=num_classes,
inner_block_maps=inner_block_maps,
layer_block_maps=layer_block_maps,
use_p6=use_p6,
)
module_state_updater.updating(checkpoint_yolov5)
state_dict = module_state_updater.model.state_dict()

size = get_yolov5_size(depth_multiple, width_multiple)

return {
"num_classes": num_classes,
"depth_multiple": depth_multiple,
"width_multiple": width_multiple,
"strides": strides,
"anchor_grids": anchor_grids,
"use_p6": use_p6,
"size": size,
"state_dict": state_dict,
}
from yolort.v5 import get_yolov5_size


class ModuleStateUpdate:
Expand Down

0 comments on commit 0761eb2

Please sign in to comment.