Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add yolov8 to SAHI #833

Merged
merged 6 commits into from
Feb 23, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ jobs:
run: >
pip install pycocotools==2.0.6

- name: Install ultralytics
run: >
pip install ultralytics
NguyenTheAn marked this conversation as resolved.
Show resolved Hide resolved

- name: Unittest for SAHI+YOLOV5/MMDET/Detectron2 on all platforms
run: |
python -m unittest
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/ci_torch1.10.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ jobs:
run: >
pip install pycocotools==2.0.6

- name: Install ultralytics
run: >
pip install ultralytics

- name: Unittest for SAHI+YOLOV5/MMDET/Detectron2 on all platforms
run: |
python -m unittest
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/package_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ jobs:
run: >
pip install pycocotools==2.0.6

- name: Install ultralytics
run: >
pip install ultralytics

- name: Install latest SAHI package
run: >
pip install --upgrade --force-reinstall sahi
Expand Down
1 change: 1 addition & 0 deletions sahi/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from sahi.utils.file import import_model_class

MODEL_TYPE_TO_MODEL_CLASS_NAME = {
"yolov8": "Yolov8DetectionModel",
"mmdet": "MmdetDetectionModel",
"yolov5": "Yolov5DetectionModel",
"detectron2": "Detectron2DetectionModel",
Expand Down
160 changes: 160 additions & 0 deletions sahi/models/yolov8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# OBSS SAHI Tool
# Code written by AnNT, 2023.

import logging
from typing import Any, Dict, List, Optional

import numpy as np

logger = logging.getLogger(__name__)

from sahi.models.base import DetectionModel
from sahi.prediction import ObjectPrediction
from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
from sahi.utils.import_utils import check_requirements


class Yolov8DetectionModel(DetectionModel):
def check_dependencies(self) -> None:
check_requirements(["ultralytics"])

def load_model(self):
"""
Detection model is initialized and set to self.model.
"""

from ultralytics import YOLO

try:
model = YOLO(self.model_path)
model.to(self.device)
self.set_model(model)
except Exception as e:
raise TypeError("model_path is not a valid yolov8 model path: ", e)

def set_model(self, model: Any):
"""
Sets the underlying YOLOv8 model.
Args:
model: Any
A YOLOv8 model
"""

# if model.__class__.__module__ not in ["yolov5.models.common", "models.common"]:
# raise Exception(f"Not a yolov5 model: {type(model)}")

# model.conf = self.confidence_threshold
self.model = model

# set category_mapping
if not self.category_mapping:
category_mapping = {str(ind): category_name for ind, category_name in enumerate(self.category_names)}
self.category_mapping = category_mapping

def perform_inference(self, image: np.ndarray):
"""
Prediction is performed using self.model and the prediction result is set to self._original_predictions.
Args:
image: np.ndarray
A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
"""

# Confirm model is loaded
if self.model is None:
raise ValueError("Model is not loaded, load it by calling .load_model()")
prediction_result = self.model(image, verbose=False)
prediction_result = [
result.boxes.boxes[result.boxes.boxes[:, 4] >= self.confidence_threshold] for result in prediction_result
]

self._original_predictions = prediction_result

@property
def category_names(self):
return self.model.names.values()

@property
def num_categories(self):
"""
Returns number of categories
"""
return len(self.model.names)

@property
def has_mask(self):
"""
Returns if model output contains segmentation mask
"""
return False # fix when yolov5 supports segmentation models

def _create_object_prediction_list_from_original_predictions(
self,
shift_amount_list: Optional[List[List[int]]] = [[0, 0]],
full_shape_list: Optional[List[List[int]]] = None,
):
"""
self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
self._object_prediction_list_per_image.
Args:
shift_amount_list: list of list
To shift the box and mask predictions from sliced image to full sized image, should
be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...]
full_shape_list: list of list
Size of the full image after shifting, should be in the form of
List[[height, width],[height, width],...]
"""
original_predictions = self._original_predictions

# compatilibty for sahi v0.8.15
shift_amount_list = fix_shift_amount_list(shift_amount_list)
full_shape_list = fix_full_shape_list(full_shape_list)

# handle all predictions
object_prediction_list_per_image = []
for image_ind, image_predictions_in_xyxy_format in enumerate(original_predictions):
shift_amount = shift_amount_list[image_ind]
full_shape = None if full_shape_list is None else full_shape_list[image_ind]
object_prediction_list = []

# process predictions
for prediction in image_predictions_in_xyxy_format.cpu().detach().numpy():
x1 = prediction[0]
y1 = prediction[1]
x2 = prediction[2]
y2 = prediction[3]
bbox = [x1, y1, x2, y2]
score = prediction[4]
category_id = int(prediction[5])
category_name = self.category_mapping[str(category_id)]

# fix negative box coords
bbox[0] = max(0, bbox[0])
bbox[1] = max(0, bbox[1])
bbox[2] = max(0, bbox[2])
bbox[3] = max(0, bbox[3])

# fix out of image box coords
if full_shape is not None:
bbox[0] = min(full_shape[1], bbox[0])
bbox[1] = min(full_shape[0], bbox[1])
bbox[2] = min(full_shape[1], bbox[2])
bbox[3] = min(full_shape[0], bbox[3])

# ignore invalid predictions
if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]):
logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
continue

object_prediction = ObjectPrediction(
bbox=bbox,
category_id=category_id,
score=score,
bool_mask=None,
category_name=category_name,
shift_amount=shift_amount,
full_shape=full_shape,
)
object_prediction_list.append(object_prediction)
object_prediction_list_per_image.append(object_prediction_list)

self._object_prediction_list_per_image = object_prediction_list_per_image
43 changes: 43 additions & 0 deletions sahi/utils/yolov8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import urllib.request
from os import path
from pathlib import Path
from typing import Optional


class Yolov8TestConstants:
YOLOV8N_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt"
YOLOV8N_MODEL_PATH = "tests/data/models/yolov8/yolov8n.pt"

YOLOV8S_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s.pt"
YOLOV8S_MODEL_PATH = "tests/data/models/yolov8/yolov8s.pt"

YOLOV8M_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m.pt"
YOLOV8M_MODEL_PATH = "tests/data/models/yolov8/yolov8m.pt"


def download_yolov8n_model(destination_path: Optional[str] = None):

if destination_path is None:
destination_path = Yolov8TestConstants.YOLOV8N_MODEL_PATH

Path(destination_path).parent.mkdir(parents=True, exist_ok=True)

if not path.exists(destination_path):
urllib.request.urlretrieve(
Yolov8TestConstants.YOLOV8N_MODEL_URL,
destination_path,
)


def download_yolov8s_model(destination_path: Optional[str] = None):

if destination_path is None:
destination_path = Yolov8TestConstants.YOLOV8S_MODEL_PATH

Path(destination_path).parent.mkdir(parents=True, exist_ok=True)

if not path.exists(destination_path):
urllib.request.urlretrieve(
Yolov8TestConstants.YOLOV8S_MODEL_URL,
destination_path,
)
140 changes: 140 additions & 0 deletions tests/test_yolov8model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# OBSS SAHI Tool
# Code written by Fatih C Akyon, 2020.

import unittest

import numpy as np

from sahi.utils.cv import read_image
from sahi.utils.yolov8 import Yolov8TestConstants, download_yolov8n_model, download_yolov8s_model

MODEL_DEVICE = "cpu"
CONFIDENCE_THRESHOLD = 0.3
IMAGE_SIZE = 320


class TestYolov8DetectionModel(unittest.TestCase):
def test_load_model(self):
from sahi.models.yolov8 import Yolov8DetectionModel

download_yolov8n_model()

yolov8_detection_model = Yolov8DetectionModel(
model_path=Yolov8TestConstants.YOLOV8N_MODEL_PATH,
confidence_threshold=CONFIDENCE_THRESHOLD,
device=MODEL_DEVICE,
category_remapping=None,
load_at_init=True,
)

self.assertNotEqual(yolov8_detection_model.model, None)

def test_set_model(self):

from ultralytics import YOLO

from sahi.models.yolov8 import Yolov8DetectionModel

download_yolov8n_model()

yolo_model = YOLO(Yolov8TestConstants.YOLOV8N_MODEL_PATH)

yolov8_detection_model = Yolov8DetectionModel(
model=yolo_model,
confidence_threshold=CONFIDENCE_THRESHOLD,
device=MODEL_DEVICE,
category_remapping=None,
load_at_init=True,
)

self.assertNotEqual(yolov8_detection_model.model, None)

def test_perform_inference(self):
from sahi.models.yolov8 import Yolov8DetectionModel

# init model
download_yolov8n_model()

yolov8_detection_model = Yolov8DetectionModel(
model_path=Yolov8TestConstants.YOLOV8N_MODEL_PATH,
confidence_threshold=CONFIDENCE_THRESHOLD,
device=MODEL_DEVICE,
category_remapping=None,
load_at_init=True,
image_size=IMAGE_SIZE,
)

# prepare image
image_path = "tests/data/small-vehicles1.jpeg"
image = read_image(image_path)

# perform inference
yolov8_detection_model.perform_inference(image)
original_predictions = yolov8_detection_model.original_predictions

boxes = original_predictions

# find box of first car detection with conf greater than 0.5
for box in boxes[0]:
if box[5].item() == 2: # if category car
if box[4].item() > 0.5:
break

# compare
desired_bbox = [448, 309, 497, 342]
predicted_bbox = list(map(int, box[:4].tolist()))
margin = 2
for ind, point in enumerate(predicted_bbox):
assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin
self.assertEqual(len(yolov8_detection_model.category_names), 80)
for box in boxes[0]:
self.assertGreaterEqual(box[4].item(), CONFIDENCE_THRESHOLD)

def test_convert_original_predictions(self):
from sahi.models.yolov8 import Yolov8DetectionModel

# init model
download_yolov8n_model()

yolov8_detection_model = Yolov8DetectionModel(
model_path=Yolov8TestConstants.YOLOV8N_MODEL_PATH,
confidence_threshold=CONFIDENCE_THRESHOLD,
device=MODEL_DEVICE,
category_remapping=None,
load_at_init=True,
image_size=IMAGE_SIZE,
)

# prepare image
image_path = "tests/data/small-vehicles1.jpeg"
image = read_image(image_path)

# perform inference
yolov8_detection_model.perform_inference(image)

# convert predictions to ObjectPrediction list
yolov8_detection_model.convert_original_predictions()
object_prediction_list = yolov8_detection_model.object_prediction_list

# compare
self.assertEqual(len(object_prediction_list), 11)
self.assertEqual(object_prediction_list[0].category.id, 2)
self.assertEqual(object_prediction_list[0].category.name, "car")
desired_bbox = [448, 309, 49, 33]
predicted_bbox = object_prediction_list[0].bbox.to_xywh()
margin = 2
for ind, point in enumerate(predicted_bbox):
assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin
self.assertEqual(object_prediction_list[2].category.id, 2)
self.assertEqual(object_prediction_list[2].category.name, "car")
desired_bbox = [835, 307, 37, 37]
predicted_bbox = object_prediction_list[2].bbox.to_xywh()
for ind, point in enumerate(predicted_bbox):
assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin

for object_prediction in object_prediction_list:
self.assertGreaterEqual(object_prediction.score.value, CONFIDENCE_THRESHOLD)


if __name__ == "__main__":
unittest.main()