Implementation of the MNASNet family of models #829

Merged · 23 commits · Jun 24, 2019
Changes from 13 commits
1 change: 1 addition & 0 deletions torchvision/models/__init__.py
@@ -6,3 +6,4 @@
from .densenet import *
from .googlenet import *
from .mobilenet import *
from .mnasnet import *
163 changes: 163 additions & 0 deletions torchvision/models/mnasnet.py
@@ -0,0 +1,163 @@
import math

import torch
import torch.nn as nn

__all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3']

# Paper suggests 0.9997 momentum, in TensorFlow convention. The equivalent
# PyTorch momentum is 1 - 0.9997.
_BN_MOMENTUM = 1 - 0.9997
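# Illustration of the conversion (added for clarity, not part of this diff's
# logic): PyTorch updates running stats as
#   running = (1 - momentum) * running + momentum * batch_stat
# while TensorFlow-style batch norm uses
#   running = decay * running + (1 - decay) * batch_stat
# so a TF decay of 0.9997 corresponds to a PyTorch momentum of 0.0003.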


class _InvertedResidual(nn.Module):

    def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor,
                 bn_momentum=0.1):
        super(_InvertedResidual, self).__init__()
        assert stride in [1, 2]
        assert kernel_size in [3, 5]
        mid_ch = in_ch * expansion_factor
        self.apply_residual = (in_ch == out_ch and stride == 1)
        self.layers = nn.Sequential(
            # Pointwise
            nn.Conv2d(in_ch, mid_ch, 1, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Depthwise
            nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2,
                      stride=stride, groups=mid_ch, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Linear pointwise. Note that there's no activation.
            nn.Conv2d(mid_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch, momentum=bn_momentum))

    def forward(self, input):
        if self.apply_residual:
            return self.layers(input) + input
        else:
            return self.layers(input)


def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats,
           bn_momentum):
    """ Creates a stack of inverted residuals. """
    assert repeats >= 1
    # First one has no skip, because feature map size changes.
    first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor,
                              bn_momentum=bn_momentum)
    remaining = []
    for _ in range(1, repeats):
        remaining.append(
            _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor,
                              bn_momentum=bn_momentum))
    return nn.Sequential(first, *remaining)
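# For example (illustrative), _stack(16, 24, 3, 2, 3, 3, _BN_MOMENTUM) yields
# one strided _InvertedResidual(16 -> 24) with no residual skip, followed by
# two stride-1 _InvertedResidual(24 -> 24) blocks that do apply the skip.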


def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
    """ Asymmetric rounding to make `val` divisible by `divisor`. With default
    bias, will round up, unless the number is no more than 10% greater than the
    smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """
    assert 0.0 < round_up_bias < 1.0
    new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
    return new_val if new_val >= round_up_bias * val else new_val + divisor
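# Worked examples (illustrative): for (83, 8), int(83 + 4) // 8 * 8 = 80, and
# 80 >= 0.9 * 83, so 80 is returned. For (84, 8) the same rounding already
# yields 88. For (10, 8) plain rounding gives 8, but 8 < 0.9 * 10, so the
# round-up bias bumps the result to 16.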


def _scale_depths(depths, alpha):
    """ Scales tensor depths as in reference MobileNet code, preferring to
    round up rather than down. """
    return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
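# For instance (illustrative), with alpha=0.5 the base depths scale as
# _scale_depths([24, 40, 80, 96, 192, 320], 0.5) == [16, 24, 40, 48, 96, 160].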


class MNASNet(torch.nn.Module):
    """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf.
    >>> model = MNASNet(1000, 1.0)
    >>> x = torch.rand(1, 3, 224, 224)
    >>> y = model(x)
    >>> y.dim()
    2
    >>> y.nelement()
    1000
    """

    def __init__(self, num_classes, alpha, dropout=0.2):
        super(MNASNet, self).__init__()
        self.alpha = alpha
        self.num_classes = num_classes
        self.dropout = dropout
        depths = _scale_depths([24, 40, 80, 96, 192, 320], alpha)
        layers = [
            # First layer: regular conv.
            nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            # Depthwise separable, no skip.
            nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
            nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
            # MNASNet blocks: stacks of inverted residuals.
            _stack(16, depths[0], 3, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[0], depths[1], 5, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[1], depths[2], 5, 2, 6, 3, _BN_MOMENTUM),
            _stack(depths[2], depths[3], 3, 1, 6, 2, _BN_MOMENTUM),
            _stack(depths[3], depths[4], 5, 2, 6, 4, _BN_MOMENTUM),
            _stack(depths[4], depths[5], 3, 1, 6, 1, _BN_MOMENTUM),
            # Final mapping to classifier input.
            nn.Conv2d(depths[5], 1280, 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
        ]
        self.layers = nn.Sequential(*layers)
        self.classifier = nn.Linear(1280, self.num_classes)

        self._initialize_weights()

    def features(self, x):
        return self.layers(x)

    def forward(self, x):
        x = self.features(x)
        # Equivalent to global avgpool and removing H and W dimensions.
        x = x.mean([2, 3])
        if self.dropout > 0.0:
            x = nn.functional.dropout(x, p=self.dropout,
                                      training=self.training, inplace=True)
        return self.classifier(x)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He initialization, fan-out mode.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1.0)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


def mnasnet0_5(num_classes):
    """ MNASNet with depth multiplier of 0.5. """
    return MNASNet(num_classes, alpha=0.5)


def mnasnet0_75(num_classes):
    """ MNASNet with depth multiplier of 0.75. """
    return MNASNet(num_classes, alpha=0.75)


def mnasnet1_0(num_classes):
    """ MNASNet with depth multiplier of 1.0. """
    return MNASNet(num_classes, alpha=1.0)


def mnasnet1_3(num_classes):
    """ MNASNet with depth multiplier of 1.3. """
    return MNASNet(num_classes, alpha=1.3)
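
A quick smoke test of the factory functions above (an illustrative sketch,
assuming a torchvision build that includes this PR; the `from .mnasnet import *`
line added to `__init__.py` makes the functions importable):

    import torch
    from torchvision.models import mnasnet1_0

    model = mnasnet1_0(num_classes=1000)
    model.eval()
    with torch.no_grad():
        logits = model(torch.rand(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 1000])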