# common.py
import torch
import torch.nn as nn

def l2_norm(input, axis=1):
    """L2-normalize `input` along `axis` so each slice has unit L2 norm."""
    # clamp guards against division by zero for all-zero slices
    norm = torch.norm(input, 2, axis, True).clamp(min=1e-12)
    output = torch.div(input, norm)
    return output

def norm_feature(feature, p=2, dim=1):
    """Soft normalization: divide `feature` by sqrt(2 * ||feature||_p) along
    `dim`, so a slice with norm n comes out with norm sqrt(n / 2) instead of 1."""
    norm = torch.norm(feature, p=p, dim=dim, keepdim=True).clamp(min=1e-12)
    feature = torch.div(feature, (2.0 * norm) ** 0.5)
    return feature

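# Worked example (illustrative, not part of the original file): a 1x4 row of
# 2s has L2 norm 4, so norm_feature divides by sqrt(2 * 4) ~ 2.828 and the
# output row has norm sqrt(4 / 2) = sqrt(2):
#
#     row = torch.full((1, 4), 2.0)
#     torch.norm(norm_feature(row), dim=1)   # tensor([1.4142])
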
class ChannelShuffleCustom(nn.Module):
    """Channel shuffle (ShuffleNet-style) applied only during training.

    Channels are split into `groups` and interleaved so information mixes
    across channel groups; at eval time the input passes through unchanged.
    """

    def __init__(self, groups=16):
        super().__init__()
        self.groups = groups

    def forward(self, x):
        if not self.training:
            return x  # identity at eval time
        batch, channels, height, width = x.size()
        assert channels % self.groups == 0, "channels must be divisible by groups"
        channels_per_group = channels // self.groups
        # reshape -> swap the two group axes -> flatten back to interleave channels
        x = x.view(batch, channels_per_group, self.groups, height, width)
        x = torch.transpose(x, 1, 2).contiguous()
        x = x.view(batch, channels, height, width)
        return x

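# Illustration (not part of the original file): with 6 channels and groups=2,
# a training-mode forward interleaves channels [0, 1, 2, 3, 4, 5] into
# [0, 2, 4, 1, 3, 5]:
#
#     m = ChannelShuffleCustom(groups=2).train()
#     m(torch.arange(6.0).view(1, 6, 1, 1)).flatten()
#     # tensor([0., 2., 4., 1., 3., 5.])
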
class RandomZero(nn.Module):
    """Randomly suppresses a fraction `p` of the channels during training.

    Selected channels are scaled by 1e-8 (near zero, avoiding exact zeros
    downstream); at eval time the input passes through unchanged.
    """

    def __init__(self, p=0.1):
        super().__init__()
        self.p = p

    def forward(self, x):
        if not self.training:
            return x
        mask = torch.ones_like(x)
        channel_size = x.size(1)
        zero_index = int(self.p * channel_size)
        # sample over all channels; the original randperm(channel_size - 1)
        # could never select the last channel
        perm = torch.randperm(channel_size)[:zero_index]
        mask[:, perm, :, :] = 1e-8
        return x * mask

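# Note (added): unlike nn.Dropout2d, RandomZero suppresses the same channel
# indices for every sample in the batch and does not rescale the surviving
# channels by 1 / (1 - p).
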
class RandomReplace(nn.Module):
    """Randomly replaces a fraction `p` of the channels with their right
    neighbor (channel i takes the values of channel i + 1) during training;
    identity at eval time."""

    def __init__(self, p=0.1):
        super().__init__()
        self.p = p

    def forward(self, x):
        if not self.training:
            return x
        channel_size = x.size(1)
        replaced_index = int(self.p * channel_size)
        # channel_size - 1 is intentional here: perm + 1 must stay in range
        perm = torch.randperm(channel_size - 1)[:replaced_index]
        # clone first so the caller's tensor is not mutated in place
        x = x.clone()
        x[:, perm, :, :] = x[:, perm + 1, :, :]
        return x
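

# Minimal smoke test (added for illustration; not part of the original file).
# Checks that each augmentation preserves shape in training mode, is the
# identity in eval mode, and that l2_norm produces unit-norm rows.
if __name__ == "__main__":
    x = torch.randn(2, 16, 4, 4)
    for module in (
        ChannelShuffleCustom(groups=4),
        RandomZero(p=0.25),
        RandomReplace(p=0.25),
    ):
        module.train()
        y = module(x)
        assert y.shape == x.shape, module
        module.eval()
        assert torch.equal(module(x), x), "eval mode should be identity"
    print("l2_norm row norms:", torch.norm(l2_norm(torch.randn(3, 8)), dim=1))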