Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support upstream yolov5 v6 release #194

Merged
merged 13 commits into from
Oct 10, 2021
153 changes: 55 additions & 98 deletions notebooks/how-to-align-with-ultralytics-yolov5.ipynb

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions yolort/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def yolov5s(
"""
Args:
upstream_version (str): model released by the upstream YOLOv5. Possible values
are 'r3.1' and 'r4.0'. Default: 'r4.0'.
are ["r3.1", "r4.0"]. Default: "r4.0".
export_friendly (bool): Deciding whether to use (ONNX/TVM) export friendly mode.
Default: False.
"""
Expand All @@ -40,7 +40,7 @@ def yolov5m(
"""
Args:
upstream_version (str): model released by the upstream YOLOv5. Possible values
are 'r3.1' and 'r4.0'. Default: 'r4.0'.
are ["r3.1", "r4.0"]. Default: "r4.0".
export_friendly (bool): Deciding whether to use (ONNX/TVM) export friendly mode.
Default: False.
"""
Expand All @@ -63,7 +63,7 @@ def yolov5l(
"""
Args:
upstream_version (str): model released by the upstream YOLOv5. Possible values
are 'r3.1' and 'r4.0'. Default: 'r4.0'.
are ["r3.1", "r4.0"]. Default: "r4.0".
export_friendly (bool): Deciding whether to use (ONNX/TVM) export friendly mode.
Default: False.
"""
Expand All @@ -86,7 +86,7 @@ def yolotr(
"""
Args:
upstream_version (str): model released by the upstream YOLOv5. Possible values
are 'r3.1' and 'r4.0'. Default: 'r4.0'.
are "r4.0". Default: "r4.0".
export_friendly (bool): Deciding whether to use (ONNX/TVM) export friendly mode.
Default: False.
"""
Expand Down
18 changes: 17 additions & 1 deletion yolort/models/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import Tuple
from typing import Tuple, Optional

import torch
from torch import nn, Tensor
Expand Down Expand Up @@ -56,6 +56,22 @@ def _evaluate_iou(target, pred):
return box_iou(target["boxes"], pred["boxes"]).diag().mean()


def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v


def encode_single(reference_boxes: Tensor, anchors: Tensor) -> Tensor:
"""
Encode a set of anchors with respect to some
Expand Down
13 changes: 6 additions & 7 deletions yolort/models/backbone_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class BackboneWithPAN(nn.Module):
in_channels_list (List[int]): number of channels for each feature map
that is returned, in the order they are present in the OrderedDict
depth_multiple (float): depth multiplier
version (str): ultralytics release version: r3.1 or r4.0
version (str): ultralytics release version: ["r3.1", "r4.0", "r6.0"]

Attributes:
out_channels (int): the number of channels in the PAN
Expand Down Expand Up @@ -61,7 +61,7 @@ def darknet_pan_backbone(
Constructs a specified DarkNet backbone with PAN on top. Freezes the specified number of
layers in the backbone.

Examples::
Examples:

>>> from models.backbone_utils import darknet_pan_backbone
>>> backbone = darknet_pan_backbone('darknet3_1', pretrained=True, trainable_layers=3)
Expand All @@ -70,10 +70,9 @@ def darknet_pan_backbone(
>>> # compute the output
>>> output = backbone(x)
>>> print([(k, v.shape) for k, v in output.items()])
>>> # returns
>>> [('0', torch.Size([1, 128, 8, 8])),
>>> ('1', torch.Size([1, 256, 4, 4])),
>>> ('2', torch.Size([1, 512, 2, 2]))]
[('0', torch.Size([1, 128, 8, 8])),
('1', torch.Size([1, 256, 4, 4])),
('2', torch.Size([1, 512, 2, 2]))]

Args:
backbone_name (string): darknet architecture. Possible values are 'DarkNet', 'darknet_s_r3_1',
Expand All @@ -83,7 +82,7 @@ def darknet_pan_backbone(
pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet
trainable_layers (int): number of trainable (not frozen) darknet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
version (str): ultralytics release version: r3.1 or r4.0
version (str): ultralytics release version: ["r3.1", "r4.0", "r6.0"]
"""
backbone = darknet.__dict__[backbone_name](pretrained=pretrained).features

Expand Down
273 changes: 23 additions & 250 deletions yolort/models/darknet.py
Original file line number Diff line number Diff line change
@@ -1,256 +1,29 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from typing import Callable, List, Optional, Any

import torch
from torch import nn, Tensor
from torch.hub import load_state_dict_from_url

from yolort.v5 import Conv, SPP, Focus, BottleneckCSP, C3


__all__ = [
"DarkNet",
from .darknetv5 import (
DarkNetV5,
darknet_s_r3_1,
darknet_m_r3_1,
darknet_l_r3_1,
darknet_s_r4_0,
darknet_m_r4_0,
darknet_l_r4_0,
)
from .darknetv6 import (
DarkNetV6,
darknet_s_r6_0,
darknet_m_r6_0,
darknet_l_r6_0,
)

__all__ = (
"DarkNetV5",
"DarkNetV6",
"darknet_s_r3_1",
"darknet_m_r3_1",
"darknet_l_r3_1",
"darknet_s_r4_0",
"darknet_m_r4_0",
"darknet_l_r4_0",
]

# Checkpoint URLs keyed by architecture name; `_darknet` looks the URL up here
# when `pretrained=True`. Every entry is currently None, so requesting
# pretrained weights raises NotImplementedError in `_darknet`.
model_urls = {
"darknet_s_r3.1": None,
"darknet_m_r3.1": None,
"darknet_l_r3.1": None,
"darknet_s_r4.0": None,
"darknet_m_r4.0": None,
"darknet_l_r4.0": None,
} # TODO: add checkpoint weights


def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
:param divisor:
:param min_value:
:return:
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v


class DarkNet(nn.Module):
    """
    DarkNet feature extractor followed by a pooled classification head.
    """

    def __init__(
        self,
        depth_multiple: float,
        width_multiple: float,
        version: str,
        block: Optional[Callable[..., nn.Module]] = None,
        stages_repeats: Optional[List[int]] = None,
        stages_out_channels: Optional[List[int]] = None,
        num_classes: int = 1000,
        round_nearest: int = 8,
    ) -> None:
        """
        DarkNet main class

        Args:
            depth_multiple (float): Depth multiplier - scales the repeat count of each stage
            width_multiple (float): Width multiplier - adjusts number of channels in each layer by this amount
            version (str): ultralytics release version: r3.1 or r4.0
            block: Callable building the block of each stage. Default: None, in which
                case the entry of ``_block`` registered for ``version`` is used
            stages_repeats (List[int]): Unscaled repeat count of each stage.
                Default: None, meaning [3, 9, 9]
            stages_out_channels (List[int]): Unscaled output channels of each stage.
                Default: None, meaning [128, 256, 512]
            num_classes (int): Number of classes
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
                Set to 1 to turn off rounding
        """
        super().__init__()

        if block is None:
            block = _block[version]

        repeats = [3, 9, 9] if stages_repeats is None else stages_repeats
        widths = [128, 256, 512] if stages_out_channels is None else stages_out_channels

        def scaled(channels: int) -> int:
            # Apply the width multiplier, then snap to a multiple of round_nearest.
            return _make_divisible(channels * width_multiple, round_nearest)

        # Stem: Focus layer on the 3-channel input.
        prev_channels = scaled(64)
        stack: List[nn.Module] = [Focus(3, prev_channels, k=3, version=version)]

        # One stride-2 Conv followed by one block per stage.
        for base_repeat, base_width in zip(repeats, widths):
            num_repeats = max(round(base_repeat * depth_multiple), 1)
            stage_channels = scaled(base_width)
            stack.append(
                Conv(prev_channels, stage_channels, k=3, s=2, version=version)
            )
            stack.append(block(stage_channels, stage_channels, n=num_repeats))
            prev_channels = stage_channels

        # Final stride-2 Conv and SPP at the widest channel count.
        head_channels = scaled(1024)
        stack.append(Conv(prev_channels, head_channels, k=3, s=2, version=version))
        stack.append(SPP(head_channels, head_channels, k=(5, 9, 13), version=version))

        self.features = nn.Sequential(*stack)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Linear(head_channels, head_channels),
            nn.Hardswish(inplace=True),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(head_channels, num_classes),
        )

        in_place_activations = (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6)
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                # Convolutions keep their default initialisation.
                continue
            if isinstance(module, nn.BatchNorm2d):
                module.eps = 1e-3
                module.momentum = 0.03
            elif isinstance(module, in_place_activations):
                module.inplace = True

    def _forward_impl(self, x: Tensor) -> Tensor:
        # features -> global average pool -> flatten -> classifier
        pooled = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(pooled, 1))

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


# Default stage block class for each upstream release version; `DarkNet`
# looks it up by its `version` argument when no explicit `block` is passed.
_block = {
"r3.1": BottleneckCSP,
"r4.0": C3,
}


def _darknet(
    arch: str, pretrained: bool, progress: bool, *args: Any, **kwargs: Any
) -> DarkNet:
    """
    Build a DarkNet and optionally load its pretrained weights.

    Args:
        arch (str): Key into ``model_urls`` identifying the checkpoint
        pretrained (bool): If True, load the checkpoint registered for ``arch``
        progress (bool): If True, displays a progress bar of the download to stderr
        *args, **kwargs: Forwarded unchanged to the ``DarkNet`` constructor

    Raises:
        NotImplementedError: If ``pretrained`` is True but no URL is
            registered for ``arch``.
    """
    model = DarkNet(*args, **kwargs)
    if not pretrained:
        return model

    checkpoint_url = model_urls[arch]
    if checkpoint_url is None:
        raise NotImplementedError(
            "pretrained {} is not supported as of now".format(arch)
        )

    weights = load_state_dict_from_url(checkpoint_url, progress=progress)
    model.load_state_dict(weights)
    return model


def darknet_s_r3_1(
    pretrained: bool = False, progress: bool = True, **kwargs: Any
) -> DarkNet:
    """
    Constructs the small DarkNet variant of release 3.1
    (depth multiple 0.33, width multiple 0.5).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    arch = "darknet_s_r3.1"
    depth_multiple, width_multiple = 0.33, 0.5
    return _darknet(
        arch, pretrained, progress, depth_multiple, width_multiple, "r3.1", **kwargs
    )


def darknet_m_r3_1(
    pretrained: bool = False, progress: bool = True, **kwargs: Any
) -> DarkNet:
    """
    Constructs the medium DarkNet variant of release 3.1
    (depth multiple 0.67, width multiple 0.75).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``DarkNet`` constructor
    """
    return _darknet(
        "darknet_m_r3.1", pretrained, progress, 0.67, 0.75, "r3.1", **kwargs
    )


def darknet_l_r3_1(
    pretrained: bool = False, progress: bool = True, **kwargs: Any
) -> DarkNet:
    """
    Constructs the large DarkNet variant of release 3.1
    (depth multiple 1.0, width multiple 1.0).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``DarkNet`` constructor
    """
    return _darknet("darknet_l_r3.1", pretrained, progress, 1.0, 1.0, "r3.1", **kwargs)


def darknet_s_r4_0(
    pretrained: bool = False, progress: bool = True, **kwargs: Any
) -> DarkNet:
    """
    Constructs the small DarkNet variant of release 4.0
    (depth multiple 0.33, width multiple 0.5).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``DarkNet`` constructor
    """
    return _darknet("darknet_s_r4.0", pretrained, progress, 0.33, 0.5, "r4.0", **kwargs)


def darknet_m_r4_0(
    pretrained: bool = False, progress: bool = True, **kwargs: Any
) -> DarkNet:
    """
    Constructs the medium DarkNet variant of release 4.0
    (depth multiple 0.67, width multiple 0.75).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``DarkNet`` constructor
    """
    return _darknet(
        "darknet_m_r4.0", pretrained, progress, 0.67, 0.75, "r4.0", **kwargs
    )


def darknet_l_r4_0(
    pretrained: bool = False, progress: bool = True, **kwargs: Any
) -> DarkNet:
    """
    Constructs the large DarkNet variant of release 4.0
    (depth multiple 1.0, width multiple 1.0).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``DarkNet`` constructor
    """
    return _darknet("darknet_l_r4.0", pretrained, progress, 1.0, 1.0, "r4.0", **kwargs)
"darknet_s_r6_0",
"darknet_m_r6_0",
"darknet_l_r6_0",
)
Loading