-
Notifications
You must be signed in to change notification settings - Fork 65
/
Copy pathmodel_v2.py
49 lines (39 loc) · 2.19 KB
/
model_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import tensorflow as tf
class CopyModel(tf.keras.Model):
    """Recurrent model for the sequence-copy task.

    The input sequence is fed to a recurrent cell one step at a time, followed
    by a single end-of-sequence marker. The cell is then run for
    ``seq_length`` further steps on all-zero inputs. The outputs of those
    recall steps, passed through a sigmoid, form the prediction.

    Every cell input is ``vector_dim + 1`` wide. The extra last channel is 0
    for data steps and recall steps, and 1 for the end-of-sequence marker.
    """

    def __init__(self, batch_size, vector_dim, model_type, cell_params):
        """Build the model and its recurrent cell.

        Args:
            batch_size: Fixed batch size. The marker and padding tensors are
                pre-built with this leading dimension, so inputs to ``call``
                must use exactly this batch size.
            vector_dim: Width of one data vector, excluding the marker channel.
            model_type: ``'LSTM'`` for a stack of Keras ``LSTMCell`` layers, or
                ``'NTM'`` for ``ntm.ntm_cell_v2.NTMCell``.
            cell_params: Dict of cell hyper-parameters. ``'LSTM'`` reads
                ``rnn_size`` and ``rnn_num_layers``. ``'NTM'`` reads
                ``rnn_size``, ``memory_size``, ``memory_vector_dim``,
                ``read_head_num`` and ``write_head_num``.

        Raises:
            ValueError: If ``model_type`` is neither ``'LSTM'`` nor ``'NTM'``.
            KeyError: If ``cell_params`` lacks a key the chosen cell reads.
        """
        super().__init__()
        self.batch_size = batch_size
        self.vector_dim = vector_dim
        # End-of-sequence marker: one-hot on the extra channel (index vector_dim).
        self.eof = tf.one_hot([self.vector_dim] * batch_size, depth=self.vector_dim + 1)
        # All-zero cell input used for every recall step.
        self.zero = tf.zeros([batch_size, vector_dim + 1], dtype=tf.float32)
        self.model_type = model_type
        self.cell_params = cell_params
        if self.model_type == 'LSTM':
            self.cell = tf.keras.layers.StackedRNNCells(
                [tf.keras.layers.LSTMCell(units=self.cell_params['rnn_size'])
                 for _ in range(self.cell_params['rnn_num_layers'])])
        elif self.model_type == 'NTM':
            # Imported only on this branch, so the LSTM variant never imports
            # the NTM module.
            from ntm.ntm_cell_v2 import NTMCell
            self.cell = NTMCell(rnn_size=self.cell_params['rnn_size'],
                                memory_size=self.cell_params['memory_size'],
                                memory_vector_dim=self.cell_params['memory_vector_dim'],
                                read_head_num=self.cell_params['read_head_num'],
                                write_head_num=self.cell_params['write_head_num'],
                                addressing_mode='content_and_location',
                                output_dim=self.vector_dim)
        else:
            raise ValueError('Model type not supported')

    @tf.function
    def call(self, inputs):
        """Run the copy model over one batch.

        Args:
            inputs: Pair ``(x, seq_length)``.
                ``x`` is batch-major: it is transposed with
                ``perm=[1, 0, 2]`` and then unstacked along time. It is assumed
                to be ``(batch_size, seq_length, vector_dim)`` (TODO confirm
                against callers).
                ``seq_length`` is the number of data steps read from ``x`` and
                the number of recall steps produced.

        Returns:
            Tensor of shape ``(batch_size, seq_length, vector_dim)``, provided
            the cell output is at least ``vector_dim`` wide. Each entry is the
            sigmoid of the first ``vector_dim`` channels of the cell output at
            one recall step.
        """
        x, seq_length = inputs
        # Time-major view of x so that each step can be read by index.
        x_list = tf.TensorArray(dtype=tf.float32, size=seq_length)
        x_list = x_list.unstack(tf.transpose(x, perm=[1, 0, 2]))
        # Marker channel (all zeros) appended to every data step. It does not
        # depend on t, so it is built once rather than on every iteration.
        marker_off = tf.zeros([self.batch_size, 1])
        state = self.cell.get_initial_state(batch_size=self.batch_size, dtype=tf.float32)

        # Write phase: feed the data steps. Their outputs are discarded.
        for t in range(seq_length):
            output, state = self.cell(tf.concat([x_list.read(t), marker_off], axis=1), state)

        # Signal the end of the input sequence.
        output, state = self.cell(self.eof, state)

        # Recall phase: step on zero inputs and collect the predictions.
        output_list = tf.TensorArray(dtype=tf.float32, size=seq_length)
        for t in range(seq_length):
            output, state = self.cell(self.zero, state)
            # NOTE(review): for 'LSTM' the cell output is rnn_size wide and is
            # sliced directly with no projection layer. If rnn_size < vector_dim,
            # this slice is narrower than vector_dim.
            output_list = output_list.write(t, output[:, 0:self.vector_dim])

        # Back to batch-major (batch, time, channels).
        y_pred = tf.sigmoid(tf.transpose(output_list.stack(), perm=[1, 0, 2]))
        return y_pred