-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathquan_ops.py
159 lines (121 loc) · 5.11 KB
/
quan_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Adapted from
# https://github.com/zzzxxxttt/pytorch_DoReFaNet/blob/master/utils/quant_dorefa.py and
# https://github.com/tensorpack/tensorpack/blob/master/examples/DoReFa-Net/dorefa.py#L25
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class SwitchBatchNorm2d(nn.Module):
    """Bank of ``nn.BatchNorm2d`` layers, one per entry of ``bit_list``.

    Adapted from https://github.com/JiahuiYu/slimmable_networks

    Every bit width owns a separate normalisation layer, stored in
    ``bn_dict`` under the string form of that bit width. ``forward`` sends the
    input through the layer keyed by the current value of ``abit``.

    Args:
        num_features: Forwarded to each ``nn.BatchNorm2d``.
        bit_list: Bit widths to create layers for. Its last entry becomes the
            initial value of both ``abit`` and ``wbit``.

    Raises:
        ValueError: If ``abit`` and ``wbit`` compare unequal after being set.
    """

    def __init__(self, num_features, bit_list):
        super().__init__()
        self.bit_list = bit_list

        # Filled by item assignment so the layers keep the order of bit_list.
        self.bn_dict = nn.ModuleDict()
        for bits in bit_list:
            self.bn_dict[str(bits)] = nn.BatchNorm2d(num_features)

        # Activation and weight widths both start at the last listed width.
        self.abit = self.wbit = self.bit_list[-1]
        if self.abit != self.wbit:
            raise ValueError('Currenty only support same activation and weight bit width!')

    def forward(self, x):
        """Normalise ``x`` with the layer selected by ``self.abit``."""
        active_bn = self.bn_dict[str(self.abit)]
        return active_bn(x)
def batchnorm2d_fn(bit_list):
    """Build a ``SwitchBatchNorm2d`` subclass with ``bit_list`` pre-bound.

    Args:
        bit_list: Bit widths used as the default ``bit_list`` of the returned
            class's constructor.

    Returns:
        A class whose constructor needs only ``num_features``; ``bit_list``
        may still be overridden per instance.
    """

    class SwitchBatchNorm2d_(SwitchBatchNorm2d):
        """``SwitchBatchNorm2d`` whose ``bit_list`` defaults to the factory argument."""

        def __init__(self, num_features, bit_list=bit_list):
            super().__init__(num_features=num_features, bit_list=bit_list)

    return SwitchBatchNorm2d_
class SwitchBatchNorm1d(nn.Module):
    """Bank of ``nn.BatchNorm1d`` layers, one per entry of ``bit_list``.

    Adapted from https://github.com/JiahuiYu/slimmable_networks

    Every bit width owns a separate normalisation layer, stored in
    ``bn_dict`` under the string form of that bit width. ``forward`` sends the
    input through the layer keyed by the current value of ``abit``.

    Args:
        num_features: Forwarded to each ``nn.BatchNorm1d``.
        bit_list: Bit widths to create layers for. Its last entry becomes the
            initial value of both ``abit`` and ``wbit``.

    Raises:
        ValueError: If ``abit`` and ``wbit`` compare unequal after being set.
    """

    def __init__(self, num_features, bit_list):
        super().__init__()
        self.bit_list = bit_list

        # Filled by item assignment so the layers keep the order of bit_list.
        self.bn_dict = nn.ModuleDict()
        for bits in bit_list:
            self.bn_dict[str(bits)] = nn.BatchNorm1d(num_features)

        # Activation and weight widths both start at the last listed width.
        self.abit = self.wbit = self.bit_list[-1]
        if self.abit != self.wbit:
            raise ValueError('Currenty only support same activation and weight bit width!')

    def forward(self, x):
        """Normalise ``x`` with the layer selected by ``self.abit``."""
        active_bn = self.bn_dict[str(self.abit)]
        return active_bn(x)
def batchnorm1d_fn(bit_list):
    """Build a ``SwitchBatchNorm1d`` subclass with ``bit_list`` pre-bound.

    Args:
        bit_list: Bit widths used as the default ``bit_list`` of the returned
            class's constructor.

    Returns:
        A class whose constructor needs only ``num_features``; ``bit_list``
        may still be overridden per instance.
    """

    class SwitchBatchNorm1d_(SwitchBatchNorm1d):
        """``SwitchBatchNorm1d`` whose ``bit_list`` defaults to the factory argument."""

        def __init__(self, num_features, bit_list=bit_list):
            super().__init__(num_features=num_features, bit_list=bit_list)

    return SwitchBatchNorm1d_
class qfn(torch.autograd.Function):
    """Uniform ``k``-bit rounding whose backward pass copies the gradient through."""

    @staticmethod
    def forward(ctx, input, k):
        """Round ``input`` to the nearest multiple of ``1 / (2**k - 1)``.

        Args:
            ctx: Autograd context (unused).
            input: Tensor to quantize.
            k: Bit width; ``2**k - 1`` is the scale applied before rounding.

        Returns:
            ``round(input * (2**k - 1)) / (2**k - 1)``.
        """
        levels = float(2 ** k - 1)
        scaled = input * levels
        return torch.round(scaled) / levels

    @staticmethod
    def backward(ctx, grad_output):
        """Return a copy of ``grad_output`` for ``input`` and ``None`` for ``k``."""
        # Rounding is treated as identity for the gradient; k gets no gradient.
        return grad_output.clone(), None
class weight_quantize_fn(nn.Module):
    """Weight quantizer: tanh-squash, normalise, round to ``wbit`` bits, rescale.

    The weights are passed through ``tanh``, divided by their largest absolute
    value, optionally rounded with ``qfn`` and finally multiplied by the
    (detached) mean absolute value of the raw weights.

    Args:
        bit_list: Candidate bit widths; the last entry becomes the initial
            ``wbit``. ``wbit`` must be at most 8, or exactly 32.
    """

    def __init__(self, bit_list):
        super(weight_quantize_fn, self).__init__()
        self.bit_list = bit_list
        self.wbit = self.bit_list[-1]
        assert self.wbit <= 8 or self.wbit == 32

    def forward(self, x):
        """Return the quantized version of the weight tensor ``x``.

        Args:
            x: Raw weight tensor.

        Returns:
            For ``wbit == 32``: ``tanh(x) / max|tanh(x)| * mean|x|`` (no
            rounding). Otherwise the normalised weights are mapped to
            ``[0, 1]``, rounded by ``qfn`` at ``wbit`` bits, mapped back to
            ``[-1, 1]`` and scaled by ``mean|x|``.
        """
        # Scale factor is detached so it contributes no gradient.
        E = torch.mean(torch.abs(x)).detach()
        weight = torch.tanh(x)

        # An all-zero weight tensor has max|tanh(x)| == 0, which used to turn
        # every output into NaN (0 / 0). Clamping to the smallest positive
        # normal value leaves ordinary tensors untouched and yields zeros for
        # the degenerate case.
        peak = torch.clamp(
            torch.max(torch.abs(weight)),
            min=torch.finfo(weight.dtype).tiny,
        )

        if self.wbit == 32:
            # Full precision: normalise and rescale without rounding.
            return weight / peak * E

        # Shift normalised weights from [-1, 1] into [0, 1] for the rounding op,
        # then map the rounded values back to [-1, 1].
        weight = weight / 2 / peak + 0.5
        weight_q = 2 * qfn.apply(weight, self.wbit) - 1
        return weight_q * E
class activation_quantize_fn(nn.Module):
    """Activation quantizer that rounds inputs with ``qfn`` at ``abit`` bits.

    Args:
        bit_list: Candidate bit widths; the last entry becomes the initial
            ``abit``. ``abit`` must be at most 8, or exactly 32.
    """

    def __init__(self, bit_list):
        super().__init__()
        self.bit_list = bit_list
        self.abit = bit_list[-1]
        assert self.abit <= 8 or self.abit == 32

    def forward(self, x):
        """Return ``x`` unchanged at 32 bits, otherwise ``qfn``-rounded ``x``."""
        # 32 bits means full precision: the input object itself is returned.
        return x if self.abit == 32 else qfn.apply(x, self.abit)
class Conv2d_Q(nn.Conv2d):
    """Thin ``nn.Conv2d`` subclass that forwards all constructor arguments unchanged."""

    def __init__(self, *kargs, **kwargs):
        super().__init__(*kargs, **kwargs)
def conv2d_quantize_fn(bit_list):
    """Build a ``Conv2d_Q`` subclass that quantizes its weights on every forward.

    Args:
        bit_list: Bit widths captured by the returned class and handed to its
            ``weight_quantize_fn``.

    Returns:
        A convolution class with the constructor signature shown below.
    """

    class Conv2d_Q_(Conv2d_Q):
        """2-D convolution whose weights pass through ``weight_quantize_fn``."""

        def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                     bias=True):
            super().__init__(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )
            self.bit_list = bit_list
            self.w_bit = self.bit_list[-1]
            self.quantize_fn = weight_quantize_fn(self.bit_list)

        def forward(self, input, order=None):
            """Convolve ``input`` with the quantized weights.

            ``order`` is accepted but not used.
            """
            quantized_weight = self.quantize_fn(self.weight)
            return F.conv2d(
                input,
                quantized_weight,
                bias=self.bias,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups,
            )

    return Conv2d_Q_
class Linear_Q(nn.Linear):
    """Thin ``nn.Linear`` subclass that forwards all constructor arguments unchanged."""

    def __init__(self, *kargs, **kwargs):
        super().__init__(*kargs, **kwargs)
def linear_quantize_fn(bit_list):
    """Build a ``Linear_Q`` subclass that quantizes its weights on every forward.

    Args:
        bit_list: Bit widths captured by the returned class and handed to its
            ``weight_quantize_fn``.

    Returns:
        A linear-layer class taking ``in_features``, ``out_features`` and
        ``bias``.
    """

    class Linear_Q_(Linear_Q):
        """Linear layer whose weights pass through ``weight_quantize_fn``."""

        def __init__(self, in_features, out_features, bias=True):
            super().__init__(in_features, out_features, bias)
            self.bit_list = bit_list
            self.w_bit = self.bit_list[-1]
            self.quantize_fn = weight_quantize_fn(self.bit_list)

        def forward(self, input):
            """Apply the affine map to ``input`` using the quantized weights."""
            quantized_weight = self.quantize_fn(self.weight)
            return F.linear(input, quantized_weight, self.bias)

    return Linear_Q_
# Alias: `batchnorm_fn` refers to the same factory object as `batchnorm2d_fn`.
batchnorm_fn = batchnorm2d_fn