-
Notifications
You must be signed in to change notification settings - Fork 43
/
Copy pathcapsulelayers.py
158 lines (129 loc) · 6.87 KB
/
capsulelayers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import keras.backend as K
import tensorflow as tf
from keras import initializers, layers
class Length(layers.Layer):
    """
    Compute the length (Euclidean norm) of vectors along the last axis.

    This is used to compute a Tensor that has the same shape as y_true in margin_loss.
    Using this layer as the model's output allows labels to be predicted directly with
    `y_pred = np.argmax(model.predict(x), 1)`.

    inputs: shape=[None, num_vectors, dim_vector]
    output: shape=[None, num_vectors]
    """

    def call(self, inputs, **kwargs):
        # K.epsilon() keeps the argument of sqrt strictly positive. The derivative of sqrt is
        # infinite at 0, so an all-zero vector would otherwise yield NaN gradients when this
        # layer's output feeds a loss.
        return K.sqrt(K.sum(K.square(inputs), -1) + K.epsilon())

    def compute_output_shape(self, input_shape):
        # The last (vector) axis is reduced away by the sum in `call`.
        return input_shape[:-1]

    def get_config(self):
        # The layer has no constructor arguments of its own, so the base config is complete.
        config = super(Length, self).get_config()
        return config
class Mask(layers.Layer):
    """
    Mask a Tensor with shape=[None, num_capsule, dim_vector] either by the capsule with max length or by an
    additional input mask. Except for the max-length capsule (or the specified capsule), all vectors are masked to
    zeros. The masked Tensor is then flattened.

    For example:
        ```
        x = keras.layers.Input(batch_shape=[8, 3, 2])  # batch_size=8, each sample contains 3 capsules with dim_vector=2
        y = keras.layers.Input(batch_shape=[8, 3])  # True labels. 8 samples, 3 classes, one-hot coding.
        out = Mask()(x)  # out.shape=[8, 6]
        # or
        out2 = Mask()([x, y])  # out2.shape=[8, 6]. Masked with true labels y. Of course y can also be manipulated.
        ```
    """

    def call(self, inputs, **kwargs):
        # A pair (capsules, mask) may arrive as a list or as a tuple; accept both.
        if isinstance(inputs, (list, tuple)):
            assert len(inputs) == 2, 'Mask expects [capsules, mask] when called on a sequence of tensors.'
            inputs, mask = inputs
        else:
            # No mask supplied: build a one-hot mask that selects the longest capsule.
            x = K.sqrt(K.sum(K.square(inputs), -1))
            mask = K.one_hot(indices=K.argmax(x, 1), num_classes=x.get_shape().as_list()[1])

        # Broadcast the mask over the vector axis, zero out the other capsules, then flatten per sample.
        masked = K.batch_flatten(inputs * K.expand_dims(mask, -1))
        return masked

    def compute_output_shape(self, input_shape):
        # With two inputs, input_shape is a sequence of shapes and the first one is the capsule tensor's.
        if isinstance(input_shape[0], (list, tuple)):
            return tuple([None, input_shape[0][1] * input_shape[0][2]])
        else:
            return tuple([None, input_shape[1] * input_shape[2]])

    def get_config(self):
        # The layer has no constructor arguments of its own, so the base config is complete.
        config = super(Mask, self).get_config()
        return config
def squash(vectors, axis=-1):
    """
    Capsule non-linearity: rescale each vector so that its length lies in [0, 1) while its direction is kept.
    Long vectors end up with a length close to 1, short vectors are shrunk towards 0.

    :param vectors: N-dim tensor holding the vectors to squash
    :param axis: the axis along which each vector is laid out
    :return: a Tensor with the same shape as `vectors`
    """
    sq_norm = K.sum(K.square(vectors), axis, keepdims=True)
    # |s|^2 / (1 + |s|^2) is the target length of the squashed vector.
    shrink = sq_norm / (1 + sq_norm)
    # Dividing by |s| normalises the vector; epsilon guards against a zero norm.
    factor = shrink / K.sqrt(sq_norm + K.epsilon())
    return factor * vectors
class CapsuleLayer(layers.Layer):
    """
    The capsule layer. It is similar to a Dense layer. A Dense layer has `in_num` inputs, each a scalar (the output
    of a neuron from the former layer), and `out_num` output neurons. CapsuleLayer expands the output of each neuron
    from a scalar to a vector. So its input shape = [None, input_num_capsule, input_dim_capsule] and its output
    shape = [None, num_capsule, dim_capsule]. For a Dense layer, input_dim_capsule = dim_capsule = 1.

    :param num_capsule: number of capsules in this layer
    :param dim_capsule: dimension of the output vectors of the capsules in this layer
    :param channels: number of transformation matrices per output capsule. When non-zero, `input_num_capsule` must
                     be divisible by it and each matrix is shared by `input_num_capsule // channels` consecutive
                     input capsules. When 0, every input capsule gets its own matrix.
    :param routings: number of iterations for the routing algorithm (must be > 0)
    :param kernel_initializer: initializer used for both the transformation weights `W` and the bias `B`
    """

    def __init__(self, num_capsule, dim_capsule, channels, routings=3,
                 kernel_initializer='glorot_uniform',
                 **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings
        self.channels = channels
        self.kernel_initializer = initializers.get(kernel_initializer)

    def build(self, input_shape):
        assert len(input_shape) >= 3, "The input Tensor should have shape=[None, input_num_capsule, input_dim_capsule]"
        self.input_num_capsule = input_shape[1]
        self.input_dim_capsule = input_shape[2]

        if self.channels != 0:
            # Weight sharing only works if the input capsules split evenly into `channels` groups.
            assert self.input_num_capsule % self.channels == 0, \
                "input_num_capsule (%s) must be divisible by channels (%s)" % (self.input_num_capsule, self.channels)
            num_transforms = self.channels
        else:
            num_transforms = self.input_num_capsule

        # Transformation matrices, one [dim_capsule, input_dim_capsule] matrix per (output capsule, transform).
        self.W = self.add_weight(shape=[self.num_capsule, num_transforms,
                                        self.dim_capsule, self.input_dim_capsule],
                                 initializer=self.kernel_initializer,
                                 name='W')
        # Bias added to every output capsule before squashing.
        self.B = self.add_weight(shape=[self.num_capsule, self.dim_capsule],
                                 initializer=self.kernel_initializer,
                                 name='B')

        self.built = True

    def call(self, inputs, training=None):
        # Checked up front: with zero iterations the routing loop would never assign `outputs`.
        assert self.routings > 0, 'The routings should be > 0.'

        # inputs.shape=[None, input_num_capsule, input_dim_capsule]
        # inputs_expand.shape=[None, 1, input_num_capsule, input_dim_capsule]
        inputs_expand = K.expand_dims(inputs, 1)

        # Replicate the inputs once per output capsule.
        # inputs_tiled.shape=[None, num_capsule, input_num_capsule, input_dim_capsule]
        inputs_tiled = K.tile(inputs_expand, [1, self.num_capsule, 1, 1])

        if self.channels != 0:
            # Expand the shared matrices along axis 1 so there is one per input capsule.
            W2 = K.repeat_elements(self.W, self.input_num_capsule // self.channels, 1)
        else:
            W2 = self.W

        # Per-sample prediction vectors: each input capsule is multiplied by its matrix.
        # Expected inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule]
        # NOTE(review): relies on the K.batch_dot axes semantics of the Keras version this file targets.
        inputs_hat = K.map_fn(lambda x: K.batch_dot(x, W2, [2, 3]), elems=inputs_tiled)

        # Routing logits start at zero.
        # b.shape=[None, num_capsule, input_num_capsule]
        b = tf.zeros(shape=[K.shape(inputs_hat)[0], self.num_capsule, self.input_num_capsule])

        for i in range(self.routings):
            # Coupling coefficients: softmax over the output-capsule axis.
            # `axis` replaces the deprecated `dim` keyword, which was removed in TensorFlow 2.
            c = tf.nn.softmax(b, axis=1)

            # Weighted sum of the prediction vectors plus bias, then squash.
            # outputs.shape=[None, num_capsule, dim_capsule]
            outputs = squash(K.batch_dot(c, inputs_hat, [2, 2]) + self.B)

            if i < self.routings - 1:
                # Raise the logits by the agreement between outputs and predictions.
                # The last iteration skips this because `b` is not used afterwards.
                b += K.batch_dot(outputs, inputs_hat, [2, 3])

        return outputs

    def compute_output_shape(self, input_shape):
        return tuple([None, self.num_capsule, self.dim_capsule])

    def get_config(self):
        # Expose the constructor arguments so the layer can be serialized and re-created from its config.
        config = {
            'num_capsule': self.num_capsule,
            'dim_capsule': self.dim_capsule,
            'channels': self.channels,
            'routings': self.routings,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
        }
        base_config = super(CapsuleLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
def PrimaryCap(inputs, dim_capsule, n_channels, kernel_size, strides, padding):
    """
    Build the primary capsules: one Conv2D with `dim_capsule * n_channels` filters, whose output is reshaped into
    vectors of length `dim_capsule` and then squashed.

    :param inputs: 4D tensor, shape=[None, width, height, channels]
    :param dim_capsule: the dim of the output vector of capsule
    :param n_channels: the number of types of capsules
    :param kernel_size: kernel size passed to the Conv2D layer
    :param strides: strides passed to the Conv2D layer
    :param padding: padding mode passed to the Conv2D layer
    :return: output tensor, shape=[None, num_capsule, dim_capsule]
    """
    conv_layer = layers.Conv2D(
        filters=dim_capsule * n_channels,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        name='primarycap_conv2d',
    )
    feature_maps = conv_layer(inputs)

    # Regroup the conv output into capsule vectors of length dim_capsule.
    to_capsules = layers.Reshape(target_shape=[-1, dim_capsule], name='primarycap_reshape')
    capsules = to_capsules(feature_maps)

    squash_layer = layers.Lambda(squash, name='primarycap_squash')
    return squash_layer(capsules)