Skip to content

Commit

Permalink
Fix bug in replace_head for YoloX (#1411)
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe authored Aug 24, 2023
1 parent 8c7dc64 commit 8cdb2aa
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,12 @@ def __init__(
self.n_anchors = 1
self.grid = [torch.zeros(1)] * self.detection_layers_num # init grid

self.register_buffer("stride", torch.tensor(stride), persistent=False)
if torch.is_tensor(stride):
stride = stride.clone().detach()
else:
stride = torch.tensor(stride)

self.register_buffer("stride", stride, persistent=False)

self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
Expand Down Expand Up @@ -691,7 +696,7 @@ def replace_head(self, new_num_classes=None, new_head=None):

new_last_layer = DetectX(
num_classes=new_num_classes,
stride=self._head.anchors.stride,
stride=self.strides,
activation_func_type=activation_type,
channels=[width_mult(v) for v in (256, 512, 1024)],
depthwise=isinstance(old_detectx.cls_convs[0][0], GroupedConvBlock),
Expand Down
17 changes: 11 additions & 6 deletions tests/unit_tests/replace_head_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import os
import shutil
import unittest

import torch
Expand All @@ -14,6 +12,17 @@ def setUp(self) -> None:
self.device = "cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu"
super_gradients.init_trainer()

def test_yolox_replace_head(self):
input = torch.randn(1, 3, 640, 640).to(self.device)
for model in [Models.YOLOX_S, Models.YOLOX_M, Models.YOLOX_L, Models.YOLOX_T]:
model = models.get(model, pretrained_weights="coco").to(self.device).eval()
num_classes = 100
model.replace_head(new_num_classes=num_classes)
outputs = model.forward(input)
self.assertEqual(outputs[0].size(4), num_classes + 5)
self.assertEqual(outputs[1].size(4), num_classes + 5)
self.assertEqual(outputs[2].size(4), num_classes + 5)

def test_ppyolo_replace_head(self):
input = torch.randn(1, 3, 640, 640).to(self.device)
for model in [Models.PP_YOLOE_S, Models.PP_YOLOE_M, Models.PP_YOLOE_L, Models.PP_YOLOE_X]:
Expand All @@ -37,10 +46,6 @@ def test_dekr_replace_head(self):
self.assertEqual(heatmap.size(1), 20 + 1)
self.assertEqual(offsets.size(1), 20 * 2)

def tearDown(self) -> None:
if os.path.exists("~/.cache/torch/hub/"):
shutil.rmtree("~/.cache/torch/hub/")


if __name__ == "__main__":
unittest.main()

0 comments on commit 8cdb2aa

Please sign in to comment.