Skip to content

Commit

Permalink
add more yolov8 download utils (#887)
Browse files Browse the repository at this point in the history
  • Loading branch information
pranavdurai10 authored Jun 7, 2023
1 parent 0077a14 commit 3397858
Show file tree
Hide file tree
Showing 11 changed files with 29 additions and 18 deletions.
2 changes: 0 additions & 2 deletions sahi/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def __init__(
load_at_init: bool = True,
image_size: int = None,
):

self._processor = processor
self._image_shapes = []
super().__init__(
Expand Down Expand Up @@ -66,7 +65,6 @@ def num_categories(self) -> int:
return self.model.config.num_labels

def load_model(self):

from transformers import AutoModelForObjectDetection, AutoProcessor

model = AutoModelForObjectDetection.from_pretrained(self.model_path)
Expand Down
1 change: 0 additions & 1 deletion sahi/models/torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ def _create_object_prediction_list_from_original_predictions(
full_shape = None if full_shape_list is None else full_shape_list[0]

for ind in range(len(boxes)):

if masks is not None:
mask = np.array(masks[ind])
else:
Expand Down
1 change: 0 additions & 1 deletion sahi/scripts/coco_error_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,6 @@ def _analyse_results(

classname_to_export_path_list = {}
for k, catId in enumerate(present_cat_ids):

nm = cocoGt.loadCats(catId)[0]
print(f'--------------saving {k + 1}-{nm["name"]}---------------')
analyze_result = analyze_results[k]
Expand Down
1 change: 0 additions & 1 deletion sahi/utils/cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ def read_video_frame(video_capture, frame_skip_interval):
cv2.imshow("Prediction of {}".format(str(video_file_name)), cv2.WINDOW_AUTOSIZE)

while video_capture.isOpened:

frame_num = video_capture.get(cv2.CAP_PROP_POS_FRAMES)
video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_num + frame_skip_interval)

Expand Down
1 change: 0 additions & 1 deletion sahi/utils/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,6 @@ def increment_path(path, exist_ok=True, sep=""):


def download_from_url(from_url: str, to_path: str):

Path(to_path).parent.mkdir(parents=True, exist_ok=True)

if not os.path.exists(to_path):
Expand Down
3 changes: 0 additions & 3 deletions sahi/utils/mmdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class MmdetTestConstants:


def download_mmdet_cascade_mask_rcnn_model(destination_path: Optional[str] = None):

if destination_path is None:
destination_path = MmdetTestConstants.MMDET_CASCADEMASKRCNN_MODEL_PATH

Expand All @@ -41,7 +40,6 @@ def download_mmdet_cascade_mask_rcnn_model(destination_path: Optional[str] = Non


def download_mmdet_retinanet_model(destination_path: Optional[str] = None):

if destination_path is None:
destination_path = MmdetTestConstants.MMDET_RETINANET_MODEL_PATH

Expand All @@ -51,7 +49,6 @@ def download_mmdet_retinanet_model(destination_path: Optional[str] = None):


def download_mmdet_yolox_tiny_model(destination_path: Optional[str] = None):

if destination_path is None:
destination_path = MmdetTestConstants.MMDET_YOLOX_TINY_MODEL_PATH

Expand Down
3 changes: 0 additions & 3 deletions sahi/utils/yolonas.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class YoloNasTestConstants:


def download_yolonas_s_model(destination_path: Optional[str] = None):

if destination_path is None:
destination_path = YoloNasTestConstants.YOLONAS_S_MODEL_PATH

Expand All @@ -30,7 +29,6 @@ def download_yolonas_s_model(destination_path: Optional[str] = None):


def download_yolonas_m_model(destination_path: Optional[str] = None):

if destination_path is None:
destination_path = YoloNasTestConstants.YOLONAS_M_MODEL_PATH

Expand All @@ -44,7 +42,6 @@ def download_yolonas_m_model(destination_path: Optional[str] = None):


def download_yolonas_l_model(destination_path: Optional[str] = None):

if destination_path is None:
destination_path = YoloNasTestConstants.YOLONAS_L_MODEL_PATH

Expand Down
2 changes: 0 additions & 2 deletions sahi/utils/yolov5.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class Yolov5TestConstants:


def download_yolov5n_model(destination_path: Optional[str] = None):

if destination_path is None:
destination_path = Yolov5TestConstants.YOLOV5N_MODEL_PATH

Expand All @@ -30,7 +29,6 @@ def download_yolov5n_model(destination_path: Optional[str] = None):


def download_yolov5s6_model(destination_path: Optional[str] = None):

if destination_path is None:
destination_path = Yolov5TestConstants.YOLOV5S6_MODEL_PATH

Expand Down
31 changes: 29 additions & 2 deletions sahi/utils/yolov8.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ class Yolov8TestConstants:
YOLOV8M_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m.pt"
YOLOV8M_MODEL_PATH = "tests/data/models/yolov8/yolov8m.pt"

YOLOV8M_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l.pt"
YOLOV8M_MODEL_PATH = "tests/data/models/yolov8/yolov8l.pt"

def download_yolov8n_model(destination_path: Optional[str] = None):

def download_yolov8n_model(destination_path: Optional[str] = None):
if destination_path is None:
destination_path = Yolov8TestConstants.YOLOV8N_MODEL_PATH

Expand All @@ -30,7 +32,6 @@ def download_yolov8n_model(destination_path: Optional[str] = None):


def download_yolov8s_model(destination_path: Optional[str] = None):

if destination_path is None:
destination_path = Yolov8TestConstants.YOLOV8S_MODEL_PATH

Expand All @@ -41,3 +42,29 @@ def download_yolov8s_model(destination_path: Optional[str] = None):
Yolov8TestConstants.YOLOV8S_MODEL_URL,
destination_path,
)


def download_yolov8m_model(destination_path: Optional[str] = None):
if destination_path is None:
destination_path = Yolov8TestConstants.YOLOV8M_MODEL_PATH

Path(destination_path).parent.mkdir(parents=True, exist_ok=True)

if not path.exists(destination_path):
urllib.request.urlretrieve(
Yolov8TestConstants.YOLOV8M_MODEL_URL,
destination_path,
)


def download_yolov8l_model(destination_path: Optional[str] = None):
if destination_path is None:
destination_path = Yolov8TestConstants.YOLOV8L_MODEL_PATH

Path(destination_path).parent.mkdir(parents=True, exist_ok=True)

if not path.exists(destination_path):
urllib.request.urlretrieve(
Yolov8TestConstants.YOLOV8L_MODEL_URL,
destination_path,
)
1 change: 0 additions & 1 deletion tests/test_cocoutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,6 @@ def test_get_imageid2annotationlist_mapping(self):
self.assertEqual(len(imageid2annotationlist_mapping), 2)

def check_image_id(image_id):

image_ids = [annotationlist["image_id"] for annotationlist in imageid2annotationlist_mapping[image_id]]
self.assertEqual(image_ids, [image_id] * len(image_ids))

Expand Down
1 change: 0 additions & 1 deletion tests/test_mmdetectionmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@


def download_mmdet_yolox_tiny_model():

download_from_url(MMDET_YOLOX_TINY_MODEL_URL, MMDET_YOLOX_TINY_MODEL_PATH)
download_from_url(MMDET_YOLOX_TINY_CONFIG_URL, MMDET_YOLOX_TINY_CONFIG_PATH)

Expand Down

0 comments on commit 3397858

Please sign in to comment.