forked from hillbig/binary_net
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathqst.py
68 lines (49 loc) · 1.55 KB
/
qst.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import numpy
import util
from chainer import cuda
from chainer import function
from chainer.utils import type_check
class QST(function.Function):
    """Quantized with Straight Through estimator Unit.

    Forward: the single input is multiplied by 64, passed to
    ``util._log_quant_cpu`` / ``util._log_quant_gpu`` together with 64 as
    the second argument, and the result is divided by ``64 ** 2``.

    Backward: the upstream gradient is passed through unchanged, except
    that it is set to zero wherever ``abs(x) > 1``.
    """

    def __init__(self):
        # No state to set up.
        pass

    def check_type_forward(self, in_types):
        """Require exactly one input, of dtype ``numpy.float32``."""
        type_check.expect(in_types.size() == 1)
        (x_type,) = in_types
        type_check.expect(x_type.dtype == numpy.float32)

    def forward_cpu(self, x):
        """Return the quantized, rescaled input as a 1-tuple (CPU path)."""
        levels = 64
        quantized = util._log_quant_cpu(x[0] * levels, levels)
        # Quantizer output is divided by levels squared (4096.0).
        return quantized / float(levels ** 2),

    def forward_gpu(self, x):
        """Return the quantized, rescaled input as a 1-tuple (GPU path)."""
        levels = 64
        quantized = util._log_quant_gpu(x[0] * levels, levels)
        # Quantizer output is divided by levels squared (4096.0).
        return quantized / float(levels ** 2),

    def backward_cpu(self, x, gy):
        """Pass ``gy[0]`` through, zeroed where ``abs(x[0]) > 1`` (CPU)."""
        saturated = numpy.abs(x[0]) > 1
        grad = gy[0].copy()  # copy so the upstream gradient is not modified
        grad[saturated] = 0
        return grad,

    def backward_gpu(self, x, gy):
        """Pass ``gy[0]`` through, zeroed where ``abs(x[0]) > 1`` (GPU)."""
        kernel = cuda.elementwise(
            'T x, T gy', 'T gx',
            'gx = abs(x) > 1 ? 0 : gy', 'bst_bwd')
        grad = kernel(x[0], gy[0])
        return grad,
def qst(x):
    """Quantized with Straight Through estimator Unit function.

    Applies a new :class:`QST` function object to ``x``.  As implemented by
    ``QST``, the forward pass computes ``Q(64 * x, 64) / 64 ** 2`` where
    ``Q`` is ``util._log_quant_cpu`` or ``util._log_quant_gpu``; the
    backward pass forwards the incoming gradient unchanged and zeroes it
    wherever ``abs(x) > 1``.

    See: http://arxiv.org/abs/1511.07289
    (NOTE(review): reference kept from the original docstring -- confirm it
    is the intended paper for this unit.)

    Args:
        x (~chainer.Variable): Input variable.

    Returns:
        ~chainer.Variable: Output variable.

    """
    unit = QST()
    return unit(x)