# compact_bilinear_pooling.py
from __future__ import absolute_import, division, print_function

import numpy as np
import tensorflow as tf

from sequential_fft import sequential_batch_fft, sequential_batch_ifft


# FFT/IFFT dispatch helpers: either run the whole batch through TensorFlow's
# batched tf.fft / tf.ifft (TF 1.x API), or forward it in fixed-size chunks
# through the sequential kernels from sequential_fft to limit GPU memory use.
def _fft(bottom, sequential, compute_size):
    if sequential:
        return sequential_batch_fft(bottom, compute_size)
    else:
        return tf.fft(bottom)


def _ifft(bottom, sequential, compute_size):
    if sequential:
        return sequential_batch_ifft(bottom, compute_size)
    else:
        return tf.ifft(bottom)
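

# Why FFTs appear in compact bilinear pooling: the count sketch of an outer
# product equals the circular convolution of the two individual sketches, and
# circular convolution is cheap in the frequency domain. A minimal NumPy
# illustration of that frequency-domain step (a reference sketch only; this
# helper is not part of the original module and is not used below):
def _circular_conv_reference(a, b):
    # ifft(fft(a) * fft(b)) is the circular convolution of 1D arrays a and b
    return np.real(np.fft.ifft(np.fft.fft(a) * np.fft.fft(b)))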
def _generate_sketch_matrix(rand_h, rand_s, output_dim):
    """
    Return a sparse matrix used for the tensor sketch operation in compact
    bilinear pooling.

    Args:
        rand_h: a 1D numpy array containing indices in the interval
            `[0, output_dim)`.
        rand_s: a 1D numpy array of 1 and -1 entries, having the same shape
            as `rand_h`.
        output_dim: the output dimension of compact bilinear pooling.

    Returns:
        a sparse matrix of shape [input_dim, output_dim] for tensor sketch.
    """
    # Generate a sparse matrix for tensor count sketch: row i has a single
    # nonzero entry rand_s[i] in column rand_h[i], so right-multiplying by
    # this matrix scatters each input channel into a random output bin with
    # a random sign.
    rand_h = rand_h.astype(np.int64)
    rand_s = rand_s.astype(np.float32)
    assert(rand_h.ndim == 1 and rand_s.ndim == 1 and len(rand_h) == len(rand_s))
    assert(np.all(rand_h >= 0) and np.all(rand_h < output_dim))

    input_dim = len(rand_h)
    indices = np.concatenate((np.arange(input_dim)[..., np.newaxis],
                              rand_h[..., np.newaxis]), axis=1)
    sparse_sketch_matrix = tf.sparse_reorder(
        tf.SparseTensor(indices, rand_s, [input_dim, output_dim]))
    return sparse_sketch_matrix
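

# A minimal NumPy reference (illustration only; this hypothetical helper is
# not part of the original module) for what multiplying by the sketch matrix
# above computes: the count sketch Psi(x)[j] = sum of rand_s[i] * x[i] over
# all i with rand_h[i] == j.
def _count_sketch_reference(x, rand_h, rand_s, output_dim):
    y = np.zeros(output_dim, dtype=np.float32)
    np.add.at(y, rand_h, rand_s * x)  # scatter-add each signed input channel
    return y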
def compact_bilinear_pooling_layer(bottom1, bottom2, output_dim, sum_pool=True,
                                   rand_h_1=None, rand_s_1=None,
                                   rand_h_2=None, rand_s_2=None,
                                   seed_h_1=1, seed_s_1=3,
                                   seed_h_2=5, seed_s_2=7,
                                   sequential=True, compute_size=128):
    """
    Compute compact bilinear pooling over two bottom inputs. References:

    Yang Gao, et al. "Compact Bilinear Pooling." in Proceedings of IEEE
    Conference on Computer Vision and Pattern Recognition (2016).
    Akira Fukui, et al. "Multimodal Compact Bilinear Pooling for Visual Question
    Answering and Visual Grounding." arXiv preprint arXiv:1606.01847 (2016).

    Args:
        bottom1: 1st input, a 4D Tensor of shape
            [batch_size, height, width, input_dim1].
        bottom2: 2nd input, a 4D Tensor of shape
            [batch_size, height, width, input_dim2].
        output_dim: output dimension of compact bilinear pooling.
        sum_pool: (Optional) If True, sum the output along the height and width
            dimensions and return a tensor of shape [batch_size, output_dim].
            Otherwise return [batch_size, height, width, output_dim].
            Default: True.
        rand_h_1: (Optional) a 1D numpy array containing indices in the
            interval `[0, output_dim)`. Automatically generated from
            `seed_h_1` if it is None.
        rand_s_1: (Optional) a 1D numpy array of 1 and -1 entries, having the
            same shape as `rand_h_1`. Automatically generated from `seed_s_1`
            if it is None.
        rand_h_2: (Optional) a 1D numpy array containing indices in the
            interval `[0, output_dim)`. Automatically generated from
            `seed_h_2` if it is None.
        rand_s_2: (Optional) a 1D numpy array of 1 and -1 entries, having the
            same shape as `rand_h_2`. Automatically generated from `seed_s_2`
            if it is None.
        sequential: (Optional) if True, use the sequential FFT and IFFT
            instead of tf.fft and tf.ifft to avoid out-of-memory (OOM) errors.
            Note: the sequential FFT and IFFT are only available on GPU.
            Default: True.
        compute_size: (Optional) the maximum sub-batch size forwarded through
            the FFT or IFFT at a time. A larger compute_size may be faster but
            can cause OOM errors and FFT failures. This parameter is only
            effective when sequential == True. Default: 128.

    Returns:
        Compact bilinear pooled results of shape [batch_size, output_dim] or
        [batch_size, height, width, output_dim], depending on `sum_pool`.
    """
    # Static shapes are needed to construct the count sketch matrices
    input_dim1 = bottom1.get_shape().as_list()[-1]
    input_dim2 = bottom2.get_shape().as_list()[-1]

    # Step 0: Generate vectors and sketch matrices for tensor count sketch.
    # This is only done once during graph construction; the sketch matrices
    # then stay fixed for every forward pass.
    if rand_h_1 is None:
        np.random.seed(seed_h_1)
        rand_h_1 = np.random.randint(output_dim, size=input_dim1)
    if rand_s_1 is None:
        np.random.seed(seed_s_1)
        rand_s_1 = 2*np.random.randint(2, size=input_dim1) - 1
    sparse_sketch_matrix1 = _generate_sketch_matrix(rand_h_1, rand_s_1, output_dim)
    if rand_h_2 is None:
        np.random.seed(seed_h_2)
        rand_h_2 = np.random.randint(output_dim, size=input_dim2)
    if rand_s_2 is None:
        np.random.seed(seed_s_2)
        rand_s_2 = 2*np.random.randint(2, size=input_dim2) - 1
    sparse_sketch_matrix2 = _generate_sketch_matrix(rand_h_2, rand_s_2, output_dim)
    # Step 1: Flatten the input tensors and count sketch
    bottom1_flat = tf.reshape(bottom1, [-1, input_dim1])
    bottom2_flat = tf.reshape(bottom2, [-1, input_dim2])

    # Essentially:
    #   sketch1 = bottom1 * sparse_sketch_matrix
    #   sketch2 = bottom2 * sparse_sketch_matrix
    # But tensorflow only supports left multiplying a sparse matrix, so:
    #   sketch1 = (sparse_sketch_matrix.T * bottom1.T).T
    #   sketch2 = (sparse_sketch_matrix.T * bottom2.T).T
    sketch1 = tf.transpose(tf.sparse_tensor_dense_matmul(
        sparse_sketch_matrix1, bottom1_flat, adjoint_a=True, adjoint_b=True))
    sketch2 = tf.transpose(tf.sparse_tensor_dense_matmul(
        sparse_sketch_matrix2, bottom2_flat, adjoint_a=True, adjoint_b=True))
    # Step 2: FFT. The count sketch of the outer product of the two inputs
    # equals the circular convolution of their individual sketches, which is
    # computed here in the frequency domain via the convolution theorem.
    fft1 = _fft(tf.complex(real=sketch1, imag=tf.zeros_like(sketch1)),
                sequential, compute_size)
    fft2 = _fft(tf.complex(real=sketch2, imag=tf.zeros_like(sketch2)),
                sequential, compute_size)

    # Step 3: Elementwise product
    fft_product = tf.multiply(fft1, fft2)

    # Step 4: Inverse FFT and reshape back
    # Compute output shape dynamically: [batch_size, height, width, output_dim]
    cbp_flat = tf.real(_ifft(fft_product, sequential, compute_size))
    output_shape = tf.add(tf.multiply(tf.shape(bottom1), [1, 1, 1, 0]),
                          [0, 0, 0, output_dim])
    cbp = tf.reshape(cbp_flat, output_shape)

    # Set the static shape of the output
    cbp.set_shape(bottom1.get_shape().as_list()[:-1] + [output_dim])

    # Step 5: Sum pool over spatial dimensions, if specified
    if sum_pool:
        cbp = tf.reduce_sum(cbp, reduction_indices=[1, 2])

    return cbp
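

# A minimal usage sketch (assumptions: TensorFlow 1.x and illustrative shapes;
# sequential=False keeps the example runnable without the GPU-only sequential
# FFT kernels):
if __name__ == "__main__":
    bottom1 = tf.placeholder(tf.float32, [32, 14, 14, 512])
    bottom2 = tf.placeholder(tf.float32, [32, 14, 14, 512])
    # 16000-d compact bilinear features, sum-pooled over the 14x14 grid
    cbp = compact_bilinear_pooling_layer(bottom1, bottom2, output_dim=16000,
                                         sequential=False)
    print(cbp.get_shape().as_list())  # expected: [32, 16000]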