forked from hillbig/binary_net
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathbst.py
69 lines (50 loc) · 1.51 KB
/
bst.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import numpy
from chainer import cuda
from chainer import function
from chainer.utils import type_check
class BST(function.Function):

    """Binary activation with a straight-through estimator (STE).

    Forward: each element becomes +1 where ``x >= 0`` and -1 otherwise.
    Backward: the upstream gradient is passed through unchanged where
    ``|x| <= 1`` and zeroed where ``|x| > 1``.
    """

    def check_type_forward(self, in_types):
        # Exactly one float32 input.
        type_check.expect(in_types.size() == 1)
        x_type, = in_types
        type_check.expect(x_type.dtype == numpy.float32)

    def forward_cpu(self, x):
        # Build the +1/-1 array directly in the input dtype (float32, as
        # enforced above) instead of an int64 temporary followed by a cast.
        scalar = x[0].dtype.type
        y = numpy.where(x[0] >= 0, scalar(1), scalar(-1))
        return y,

    def forward_gpu(self, x):
        y = cuda.elementwise(
            'T x', 'T y',
            'y = x >= 0 ? 1 : -1', 'bst_fwd')(
                x[0])
        return y,

    def backward_cpu(self, x, gy):
        # Straight-through: keep the upstream gradient inside [-1, 1] and
        # cancel it outside that range. Copy so gy itself is not mutated.
        gx = gy[0].copy()
        gx[numpy.abs(x[0]) > 1] = 0
        return gx,

    def backward_gpu(self, x, gy):
        gx = cuda.elementwise(
            'T x, T gy', 'T gx',
            'gx = abs(x) > 1 ? 0 : gy', 'bst_bwd')(
                x[0], gy[0])
        return gx,
def bst(x):
    """Binary activation with a straight-through estimator.

    This function is expressed as

    .. math::

        f(x) = \\left \\{ \\begin{array}{ll}
        1 & {\\rm if}~ x \\ge 0 \\\\
        -1 & {\\rm if}~ x < 0,
        \\end{array} \\right.

    In the backward pass the sign function's zero derivative is replaced by
    a straight-through estimator. The upstream gradient is passed through
    unchanged where :math:`|x| \\le 1` and set to zero where
    :math:`|x| > 1`.

    See: http://arxiv.org/abs/1602.02830 (Binarized Neural Networks)

    Args:
        x (~chainer.Variable): Input variable (``float32``).

    Returns:
        ~chainer.Variable: Output variable whose elements are -1 or +1.
    """
    return BST()(x)