-
-
Notifications
You must be signed in to change notification settings - Fork 4.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
387 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,386 @@ | ||
""" RepViT | ||
Paper: `RepViT: Revisiting Mobile CNN From ViT Perspective` | ||
- https://arxiv.org/abs/2307.09283 | ||
@misc{wang2023repvit, | ||
title={RepViT: Revisiting Mobile CNN From ViT Perspective}, | ||
author={Ao Wang and Hui Chen and Zijia Lin and Hengjun Pu and Guiguang Ding}, | ||
year={2023}, | ||
eprint={2307.09283}, | ||
archivePrefix={arXiv}, | ||
primaryClass={cs.CV} | ||
} | ||
Adapted from official impl at https://github.com/jameslahm/RepViT | ||
""" | ||
|
||
__all__ = ['RepViT'] | ||
|
||
import torch.nn as nn | ||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | ||
from ._registry import register_model, generate_default_cfgs | ||
from ._builder import build_model_with_cfg | ||
from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple | ||
from ._manipulate import checkpoint_seq | ||
|
||
import torch | ||
|
||
|
||
class ConvNorm(nn.Sequential):
    """Bias-free 2D convolution followed by 2D batch normalization.

    The convolution is registered under the name ``c`` and the norm under
    ``bn``.

    Args:
        in_dim: Number of input channels of the convolution.
        out_dim: Number of output channels of the convolution and the norm.
        ks: Convolution kernel size.
        stride: Convolution stride.
        pad: Convolution padding.
        dilation: Convolution dilation.
        groups: Number of convolution groups.
        bn_weight_init: Constant the BatchNorm weight is initialised to
            (the BatchNorm bias is always initialised to 0).
    """

    def __init__(self, in_dim, out_dim, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
        super().__init__()
        self.add_module('c', nn.Conv2d(in_dim, out_dim, ks, stride, pad, dilation, groups, bias=False))
        self.add_module('bn', nn.BatchNorm2d(out_dim))
        nn.init.constant_(self.bn.weight, bn_weight_init)
        nn.init.constant_(self.bn.bias, 0)

    @torch.no_grad()
    def fuse(self):
        """Fold the BatchNorm running statistics into one biased convolution.

        Returns:
            A new ``nn.Conv2d`` with bias, created on the same device and with
            the same dtype as the original convolution weight. ``self`` is
            left unmodified.
        """
        c, bn = self.c, self.bn
        # Per-output-channel scale applied by the norm using running statistics.
        scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = c.weight * scale[:, None, None, None]
        b = bn.bias - bn.running_mean * scale
        m = nn.Conv2d(
            c.in_channels,
            c.out_channels,
            c.kernel_size,
            stride=c.stride,
            padding=c.padding,
            dilation=c.dilation,
            groups=c.groups,
            device=c.weight.device,
            # Keep the fused layer in the precision of the source layer so a
            # half/double model does not end up with mixed dtypes after fusing.
            dtype=c.weight.dtype,
        )
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
|
||
|
||
class NormLinear(nn.Sequential):
    """1D batch normalization followed by a linear layer.

    The norm is registered under the name ``bn`` and the linear layer under
    ``l``.

    Args:
        in_dim: Number of input features.
        out_dim: Number of output features.
        bias: Whether the linear layer has a bias term.
        std: Standard deviation passed to ``trunc_normal_`` for the linear
            weight initialisation.
    """

    def __init__(self, in_dim, out_dim, bias=True, std=0.02):
        super().__init__()
        self.add_module('bn', nn.BatchNorm1d(in_dim))
        self.add_module('l', nn.Linear(in_dim, out_dim, bias=bias))
        trunc_normal_(self.l.weight, std=std)
        if bias:
            nn.init.constant_(self.l.bias, 0)

    @torch.no_grad()
    def fuse(self):
        """Fold the BatchNorm running statistics into one linear layer.

        Returns:
            A new ``nn.Linear`` (always with bias), created on the same device
            and with the same dtype as the original linear weight. ``self`` is
            left unmodified.
        """
        bn, linear = self.bn, self.l
        # The norm computes ``scale * x + shift`` from its running statistics.
        scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
        shift = bn.bias - bn.running_mean * scale
        w = linear.weight * scale[None, :]
        # Push the norm's shift through the linear map; add the layer's own
        # bias on top when it has one.
        b = linear.weight @ shift
        if linear.bias is not None:
            b = b + linear.bias
        m = nn.Linear(
            linear.in_features,
            linear.out_features,
            device=linear.weight.device,
            # Match the source precision so fusing a half/double model does
            # not silently produce a float32 head.
            dtype=linear.weight.dtype,
        )
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
|
||
|
||
class RepVGGDW(nn.Module):
    """Depthwise token mixer with three parallel branches.

    The output is the sum of a ``kernel_size`` depthwise ``ConvNorm``, a 1x1
    depthwise ``ConvNorm`` and the input itself.

    Args:
        ed: Number of channels (input and output).
        kernel_size: Kernel size of the main depthwise convolution. The
            padding is ``(kernel_size - 1) // 2``, so the residual sum in
            ``forward`` only lines up for odd kernel sizes.
    """

    def __init__(self, ed, kernel_size):
        super().__init__()
        self.conv = ConvNorm(ed, ed, kernel_size, 1, (kernel_size - 1) // 2, groups=ed)
        self.conv1 = ConvNorm(ed, ed, 1, 1, 0, groups=ed)
        self.dim = ed

    def forward(self, x):
        return self.conv(x) + self.conv1(x) + x

    @torch.no_grad()
    def fuse(self):
        """Merge the three branches into a single depthwise convolution.

        Returns:
            The fused main-branch convolution, with the 1x1 branch and the
            identity branch added into its weight and bias.
        """
        conv = self.conv.fuse()
        conv1 = self.conv1.fuse()

        # Derive the padding from the real kernel shape instead of assuming
        # a 3x3 kernel, so any odd kernel size can be fused.
        k_h, k_w = conv.weight.shape[2:]
        pad_h, pad_w = (k_h - 1) // 2, (k_w - 1) // 2

        # Embed the 1x1 kernel at the centre of a zero kernel of full size.
        conv1_w = nn.functional.pad(conv1.weight, [pad_w, pad_w, pad_h, pad_h])

        # The skip connection is a kernel with a single 1 at its centre.
        identity = torch.zeros_like(conv.weight)
        identity[:, :, pad_h, pad_w] = 1

        final_conv_w = conv.weight + conv1_w + identity
        final_conv_b = conv.bias + conv1.bias

        conv.weight.data.copy_(final_conv_w)
        conv.bias.data.copy_(final_conv_b)
        return conv
|
||
|
||
class RepViTMlp(nn.Module):
    """Channel MLP made of two 1x1 ``ConvNorm`` layers with an activation between.

    Args:
        in_dim: Number of input (and output) channels.
        hidden_dim: Number of channels between the two convolutions.
        act_layer: Zero-argument callable that builds the activation module.
    """

    def __init__(self, in_dim, hidden_dim, act_layer):
        super().__init__()
        self.conv1 = ConvNorm(in_dim, hidden_dim, ks=1, stride=1, pad=0)
        self.act = act_layer()
        # The last norm's weight starts at 0.
        self.conv2 = ConvNorm(hidden_dim, in_dim, ks=1, stride=1, pad=0, bn_weight_init=0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.act(x)
        return self.conv2(x)
|
||
|
||
class RepViTBlock(nn.Module):
    """Token mixer, optional squeeze-excite, then a residual channel MLP.

    Args:
        in_dim: Number of channels (input and output).
        mlp_ratio: Multiplier giving the channel MLP's hidden width.
        kernel_size: Kernel size of the depthwise token mixer.
        use_se: Whether to insert a ``SqueezeExcite`` layer after the token
            mixer; an identity is used otherwise.
        act_layer: Zero-argument callable that builds the activation module.
    """

    def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer):
        super().__init__()

        self.token_mixer = RepVGGDW(in_dim, kernel_size)
        if use_se:
            self.se = SqueezeExcite(in_dim, 0.25)
        else:
            self.se = nn.Identity()
        self.channel_mixer = RepViTMlp(in_dim, in_dim * mlp_ratio, act_layer)

    def forward(self, x):
        mixed = self.se(self.token_mixer(x))
        # Residual connection wraps the channel mixer only.
        return mixed + self.channel_mixer(mixed)
|
||
|
||
class RepViTStem(nn.Module):
    """Input stem: two stride-2 3x3 ``ConvNorm`` layers with an activation between.

    Args:
        in_chs: Number of input channels.
        out_chs: Number of output channels; the first convolution produces
            ``out_chs // 2`` channels.
        act_layer: Zero-argument callable that builds the activation module.

    Attributes:
        stride: Total spatial reduction of the stem (4).
    """

    def __init__(self, in_chs, out_chs, act_layer):
        super().__init__()
        mid_chs = out_chs // 2
        self.conv1 = ConvNorm(in_chs, mid_chs, ks=3, stride=2, pad=1)
        self.act1 = act_layer()
        self.conv2 = ConvNorm(mid_chs, out_chs, ks=3, stride=2, pad=1)
        self.stride = 4

    def forward(self, x):
        x = self.conv1(x)
        x = self.act1(x)
        return self.conv2(x)
|
||
|
||
class RepViTDownsample(nn.Module):
    """Stage transition that halves the resolution and changes the width.

    Runs a ``RepViTBlock`` (without squeeze-excite), a stride-2 depthwise
    ``ConvNorm``, a 1x1 ``ConvNorm`` that maps ``in_dim`` to ``out_dim``, and
    finally a residual ``RepViTMlp``.

    Args:
        in_dim: Number of input channels.
        mlp_ratio: Hidden-width multiplier for the block and the final MLP.
        out_dim: Number of output channels.
        kernel_size: Kernel size of the depthwise convolutions.
        act_layer: Zero-argument callable that builds the activation module.
    """

    def __init__(self, in_dim, mlp_ratio, out_dim, kernel_size, act_layer):
        super().__init__()
        padding = (kernel_size - 1) // 2
        self.pre_block = RepViTBlock(in_dim, mlp_ratio, kernel_size, use_se=False, act_layer=act_layer)
        self.spatial_downsample = ConvNorm(in_dim, in_dim, kernel_size, 2, padding, groups=in_dim)
        self.channel_downsample = ConvNorm(in_dim, out_dim, 1, 1)
        self.ffn = RepViTMlp(out_dim, out_dim * mlp_ratio, act_layer)

    def forward(self, x):
        x = self.channel_downsample(self.spatial_downsample(self.pre_block(x)))
        # Residual connection around the feed-forward part only.
        return self.ffn(x) + x
|
||
|
||
class RepViTClassifier(nn.Module):
    """Classification head with an optional second (distillation) head.

    Args:
        dim: Number of input features.
        num_classes: Number of output classes. A value of 0 (or less) builds
            identity heads, so the module passes features through unchanged.
        distillation: Whether to create the additional ``head_dist`` head.
    """

    def __init__(self, dim, num_classes, distillation=False):
        super().__init__()
        self.num_classes = num_classes
        self.distillation = distillation
        # Support the "no classifier" configuration instead of asserting, so
        # the head can be removed by requesting zero classes.
        self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
        if distillation:
            self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x):
        """Apply the head(s).

        With distillation enabled this returns the tuple
        ``(head(x), head_dist(x))`` in training mode and the mean of both
        outputs otherwise. Without distillation it returns ``head(x)``.
        """
        if not self.distillation:
            return self.head(x)
        main_out, dist_out = self.head(x), self.head_dist(x)
        if self.training:
            return main_out, dist_out
        return (main_out + dist_out) / 2

    @torch.no_grad()
    def fuse(self):
        """Collapse the head(s) into a single layer.

        Returns:
            ``nn.Identity()`` when there is no classifier; otherwise the fused
            main head, averaged with the fused distillation head when
            distillation is enabled.
        """
        if not self.num_classes > 0:
            return nn.Identity()
        head = self.head.fuse()
        if self.distillation:
            head_dist = self.head_dist.fuse()
            # Averaging the two linear layers equals averaging their outputs.
            head.weight += head_dist.weight
            head.bias += head_dist.bias
            head.weight /= 2
            head.bias /= 2
        return head
|
||
|
||
class RepViTStage(nn.Module):
    """One network stage: optional downsampling followed by ``depth`` blocks.

    Squeeze-excite is enabled on every other block, starting with the first.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels; must equal ``in_dim`` when
            ``downsample`` is false.
        depth: Number of ``RepViTBlock`` modules in the stage.
        mlp_ratio: Hidden-width multiplier of the channel MLPs.
        act_layer: Zero-argument callable that builds the activation module.
        kernel_size: Kernel size of the depthwise convolutions.
        downsample: Whether the stage starts with a ``RepViTDownsample``.
    """

    def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True):
        super().__init__()
        if not downsample:
            assert in_dim == out_dim
            self.downsample = nn.Identity()
        else:
            self.downsample = RepViTDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer)

        # Even block indices (0, 2, ...) get squeeze-excite, odd ones do not.
        self.blocks = nn.Sequential(
            *[
                RepViTBlock(out_dim, mlp_ratio, kernel_size, idx % 2 == 0, act_layer)
                for idx in range(depth)
            ]
        )

    def forward(self, x):
        return self.blocks(self.downsample(x))
|
||
|
||
class RepViT(nn.Module):
    """RepViT network: stem, a sequence of stages and a classifier head.

    Args:
        in_chans: Number of input image channels.
        img_size: Accepted for interface compatibility; it does not influence
            any layer that is built.
        embed_dim: Channel width of each stage; its length sets the number of
            stages.
        depth: Number of blocks in each stage (indexed per stage).
        mlp_ratio: Hidden-width multiplier for the channel MLPs, expanded to
            one value per stage with ``to_ntuple``.
        global_pool: ``'avg'`` applies global average pooling before the
            head; any other value passes features through unpooled.
        kernel_size: Kernel size of the depthwise convolutions.
        num_classes: Number of classifier outputs.
        act_layer: Zero-argument callable that builds the activation module.
        distillation: Whether the classifier has a distillation head.

    Attributes:
        feature_info: One dict per stage with its channel count, cumulative
            stride (``reduction``) and module name.
    """

    def __init__(
        self,
        in_chans=3,
        img_size=224,
        embed_dim=(48,),
        depth=(2,),
        mlp_ratio=2,
        global_pool='avg',
        kernel_size=3,
        num_classes=1000,
        act_layer=nn.GELU,
        distillation=True,
    ):
        super().__init__()
        self.grad_checkpointing = False
        self.global_pool = global_pool
        self.embed_dim = embed_dim
        # Record the class count at construction time as well, so the
        # attribute exists before ``reset_classifier`` is ever called.
        self.num_classes = num_classes

        in_dim = embed_dim[0]
        self.stem = RepViTStem(in_chans, in_dim, act_layer)
        stride = self.stem.stride

        num_stages = len(embed_dim)
        mlp_ratios = to_ntuple(num_stages)(mlp_ratio)

        self.feature_info = []
        stages = []
        for i in range(num_stages):
            # Only the first stage keeps the stem's resolution.
            downsample = i != 0
            stages.append(
                RepViTStage(
                    in_dim,
                    embed_dim[i],
                    depth[i],
                    mlp_ratio=mlp_ratios[i],
                    act_layer=act_layer,
                    kernel_size=kernel_size,
                    downsample=downsample,
                )
            )
            stride *= 2 if downsample else 1
            self.feature_info += [dict(num_chs=embed_dim[i], reduction=stride, module=f'stages.{i}')]
            in_dim = embed_dim[i]
        self.stages = nn.Sequential(*stages)

        self.head = RepViTClassifier(embed_dim[-1], num_classes, distillation)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable or disable gradient checkpointing over the stages."""
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        """Return the classifier head module."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=None, distillation=False):
        """Replace the head with a freshly initialised classifier.

        Args:
            num_classes: Number of outputs of the new classifier.
            global_pool: New pooling mode; the current one is kept when None.
            distillation: Whether the new classifier has a distillation head.
        """
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head = RepViTClassifier(self.embed_dim[-1], num_classes, distillation)

    def forward_features(self, x):
        """Run the stem and all stages, returning the final feature map."""
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x)
        else:
            x = self.stages(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool the features (when ``global_pool == 'avg'``) and apply the head.

        With ``pre_logits=True`` the features are returned without applying
        the head.
        """
        if self.global_pool == 'avg':
            x = nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x

    @torch.no_grad()
    def fuse(self):
        """Fuse the model in place.

        Walks the module tree and replaces every child that exposes a
        ``fuse`` method with the module that method returns, then continues
        into the replacement. Children without ``fuse`` are descended into.
        """
        def fuse_children(net):
            for child_name, child in net.named_children():
                if hasattr(child, 'fuse'):
                    fused = child.fuse()
                    setattr(net, child_name, fused)
                    fuse_children(fused)
                else:
                    fuse_children(child)

        fuse_children(self)
|
||
|
||
def _cfg(url='', **kwargs):
    """Build a pretrained-config dict for a RepViT variant.

    Args:
        url: Location of the pretrained weights.
        **kwargs: Extra entries; they override the defaults below.

    Returns:
        A dict with the default data/config fields merged with ``kwargs``.
    """
    return {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'pool_size': None,
        'crop_pct': 0.95,
        'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN,
        'std': IMAGENET_DEFAULT_STD,
        # Point at the modules that actually own ``weight``/``bias``:
        # the stem's first ConvNorm keeps its conv under ``c`` and each
        # NormLinear head keeps its linear layer under ``l``.
        'first_conv': 'stem.conv1.c',
        'classifier': ('head.head.l', 'head.head_dist.l'),
        **kwargs,
    }
|
||
|
||
# Pretrained configs for the registered variants; the three checkpoint URLs
# share one naming pattern and differ only in the variant index.
default_cfgs = generate_default_cfgs(
    {
        f'repvit_m{idx}': _cfg(
            url=f'https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m{idx}_distill_300_timm.pth'
        )
        for idx in (1, 2, 3)
    }
)
|
||
|
||
def _create_repvit(variant, pretrained=False, **kwargs):
    """Instantiate ``RepViT`` for ``variant`` through ``build_model_with_cfg``.

    All remaining keyword arguments are forwarded unchanged.
    """
    return build_model_with_cfg(RepViT, variant, pretrained, **kwargs)
|
||
|
||
@register_model
def repvit_m1(pretrained=False, **kwargs):
    """Construct a RepViT-M1 model; ``kwargs`` override the preset arguments."""
    arch = {'embed_dim': (48, 96, 192, 384), 'depth': (2, 2, 14, 2)}
    arch.update(kwargs)
    return _create_repvit('repvit_m1', pretrained=pretrained, **arch)
|
||
|
||
@register_model
def repvit_m2(pretrained=False, **kwargs):
    """Construct a RepViT-M2 model; ``kwargs`` override the preset arguments."""
    arch = {'embed_dim': (64, 128, 256, 512), 'depth': (2, 2, 12, 2)}
    arch.update(kwargs)
    return _create_repvit('repvit_m2', pretrained=pretrained, **arch)
|
||
|
||
@register_model
def repvit_m3(pretrained=False, **kwargs):
    """Construct a RepViT-M3 model; ``kwargs`` override the preset arguments."""
    arch = {'embed_dim': (64, 128, 256, 512), 'depth': (4, 4, 18, 2)}
    arch.update(kwargs)
    return _create_repvit('repvit_m3', pretrained=pretrained, **arch)