Skip to content

Commit

Permalink
Fix the anchor configuration mechanism (#202)
Browse files Browse the repository at this point in the history
* Add unit-test for #170

* Fixing unit-test

* Fix initializing anchor_grids in load_from_ultralytics
  • Loading branch information
zhiqwang authored Oct 20, 2021
1 parent fab48d7 commit 6d08630
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 2 deletions.
66 changes: 65 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
import cv2
import numpy as np
import pytest
import torch
from torch import Tensor
from yolort import models
from yolort.models import YOLO
from yolort.utils import (
FeatureExtractor,
get_image_from_url,
load_from_ultralytics,
read_image_to_tensor,
)
from yolort.utils.image_utils import box_cxcywh_to_xyxy
from yolort.v5 import letterbox, scale_coords
from yolort.v5 import (
letterbox,
load_yolov5_model,
scale_coords,
non_max_suppression,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -46,6 +53,63 @@ def test_load_from_ultralytics(
assert len(model_info["strides"]) == 4 if use_p6 else 3


@pytest.mark.parametrize(
"arch, version, upstream_version, hash_prefix",
[("yolov5s-VOC", "r4.0", "v5.0", "23818cff")],
)
def test_load_from_ultralytics_voc(
arch: str,
version: str,
upstream_version: str,
hash_prefix: str,
):
img_path = "test/assets/bus.jpg"
checkpoint_path = f"{arch}_{upstream_version}_{hash_prefix}"

base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"

torch.hub.download_url_to_file(
model_url,
checkpoint_path,
hash_prefix=hash_prefix,
)

# Preprocess
img_raw = cv2.imread(img_path)
img = letterbox(img_raw, new_shape=(320, 320))[0]
img = read_image_to_tensor(img)

conf = 0.25
iou = 0.45

# Define YOLOv5 model
model_yolov5 = load_yolov5_model(checkpoint_path)
model_yolov5.conf = conf # confidence threshold (0-1)
model_yolov5.iou = iou # NMS IoU threshold (0-1)
model_yolov5.eval()
with torch.no_grad():
outs = model_yolov5(img[None])[0]
outs = non_max_suppression(outs, conf, iou, agnostic=True)
out_from_yolov5 = outs[0]

# Define yolort model
model_yolort = YOLO.load_from_yolov5(
checkpoint_path,
score_thresh=conf,
version=version,
)
model_yolort.eval()
with torch.no_grad():
out_from_yolort = model_yolort(img[None])

torch.testing.assert_allclose(out_from_yolort[0]["boxes"], out_from_yolov5[:, :4])
torch.testing.assert_allclose(out_from_yolort[0]["scores"], out_from_yolov5[:, 4])
torch.testing.assert_allclose(
out_from_yolort[0]["labels"], out_from_yolov5[:, 5].to(dtype=torch.int64)
)


def test_read_image_to_tensor():
N, H, W = 3, 720, 360
img = np.random.randint(0, 255, (H, W, N), dtype="uint8") # As a dummy image
Expand Down
5 changes: 4 additions & 1 deletion yolort/utils/update_module_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from yolort.models import yolo
from yolort.v5 import load_yolov5_model, get_yolov5_size
from .image_utils import to_numpy


def load_from_ultralytics(checkpoint_path: str, version: str = "r6.0"):
Expand All @@ -30,7 +31,9 @@ def load_from_ultralytics(checkpoint_path: str, version: str = "r6.0"):
strides = checkpoint_yolov5.stride
anchor_grids = checkpoint_yolov5.yaml["anchors"]
if isinstance(anchor_grids, int):
anchor_grids = [list(range(anchor_grids * 2))] * len(strides)
anchor_grids = (
to_numpy(checkpoint_yolov5.model[-1].anchor_grid).reshape(3, -1).tolist()
)

depth_multiple = checkpoint_yolov5.yaml["depth_multiple"]
width_multiple = checkpoint_yolov5.yaml["width_multiple"]
Expand Down

0 comments on commit 6d08630

Please sign in to comment.