# imitation_agent_network.py
from __future__ import print_function

import numpy as np

import tensorflow as tf


def weight_ones(shape, name):
    """ Weight variable initialized to ones (not used by the current graph). """
    initial = tf.constant(1.0, shape=shape, name=name)
    return tf.Variable(initial)


def weight_xavi_init(shape, name):
    """ Weight variable with Xavier (Glorot) initialization. """
    initial = tf.get_variable(name=name, shape=shape,
                              initializer=tf.contrib.layers.xavier_initializer())
    return initial


def bias_variable(shape, name):
    """ Bias variable initialized to the constant 0.1. """
    initial = tf.constant(0.1, shape=shape, name=name)
    return tf.Variable(initial)


class Network(object):

    def __init__(self, dropout, image_shape):
        """ We put a few counters to see how many times we called each function """
        self._dropout_vec = dropout
        self._image_shape = image_shape
        self._count_conv = 0
        self._count_pool = 0
        self._count_bn = 0
        self._count_activations = 0
        self._count_dropouts = 0
        self._count_fc = 0
        self._count_lstm = 0
        self._count_soft_max = 0
        self._conv_kernels = []
        self._conv_strides = []
        self._weights = {}
        self._features = {}

    """ Our conv is currently using bias """

    def conv(self, x, kernel_size, stride, output_size, padding_in='SAME'):
        self._count_conv += 1

        filters_in = x.get_shape()[-1]
        shape = [kernel_size, kernel_size, filters_in, output_size]

        weights = weight_xavi_init(shape, 'W_c_' + str(self._count_conv))
        bias = bias_variable([output_size], name='B_c_' + str(self._count_conv))

        self._weights['W_conv' + str(self._count_conv)] = weights
        self._conv_kernels.append(kernel_size)
        self._conv_strides.append(stride)

        conv_res = tf.add(tf.nn.conv2d(x, weights, [1, stride, stride, 1], padding=padding_in,
                                       name='conv2d_' + str(self._count_conv)), bias,
                          name='add_' + str(self._count_conv))

        self._features['conv_block' + str(self._count_conv - 1)] = conv_res

        return conv_res

    def max_pool(self, x, ksize=3, stride=2):
        self._count_pool += 1
        return tf.nn.max_pool(x, ksize=[1, ksize, ksize, 1], strides=[1, stride, stride, 1],
                              padding='SAME', name='max_pool' + str(self._count_pool))

    def bn(self, x):
        self._count_bn += 1
        return tf.contrib.layers.batch_norm(x, is_training=False,
                                            updates_collections=None,
                                            scope='bn' + str(self._count_bn))

    def activation(self, x):
        self._count_activations += 1
        return tf.nn.relu(x, name='relu' + str(self._count_activations))

    def dropout(self, x):
        print("Dropout", self._count_dropouts)
        self._count_dropouts += 1
        # Each call consumes the next keep probability from the dropout vector.
        output = tf.nn.dropout(x, self._dropout_vec[self._count_dropouts - 1],
                               name='dropout' + str(self._count_dropouts))

        return output

    def fc(self, x, output_size):
        self._count_fc += 1

        filters_in = x.get_shape()[-1]
        shape = [filters_in, output_size]

        weights = weight_xavi_init(shape, 'W_f_' + str(self._count_fc))
        bias = bias_variable([output_size], name='B_f_' + str(self._count_fc))

        return tf.nn.xw_plus_b(x, weights, bias, name='fc_' + str(self._count_fc))

    def conv_block(self, x, kernel_size, stride, output_size, padding_in='SAME'):
        """ Convolution followed by batch normalization, dropout and ReLU. """
        print(" === Conv", self._count_conv, " : ", kernel_size, stride, output_size)

        with tf.name_scope("conv_block" + str(self._count_conv)):
            x = self.conv(x, kernel_size, stride, output_size, padding_in=padding_in)
            x = self.bn(x)
            x = self.dropout(x)

            return self.activation(x)

    def fc_block(self, x, output_size):
        """ Fully connected layer followed by dropout and ReLU. """
        print(" === FC", self._count_fc, " : ", output_size)
        with tf.name_scope("fc" + str(self._count_fc + 1)):
            x = self.fc(x, output_size)
            x = self.dropout(x)
            self._features['fc_block' + str(self._count_fc + 1)] = x

            return self.activation(x)

    def get_weigths_dict(self):
        return self._weights

    def get_feat_tensors_dict(self):
        return self._features


def load_imitation_learning_network(input_image, input_data, input_size, dropout):
    """ Build the branched imitation learning graph and return the list of branch
    outputs: four [Steer, Gas, Brake] control branches and one [Speed] branch. """
    branches = []

    x = input_image

    network_manager = Network(dropout, tf.shape(x))

    """conv1"""  # kernel size, stride, number of feature maps
    xc = network_manager.conv_block(x, 5, 2, 32, padding_in='VALID')
    print(xc)
    xc = network_manager.conv_block(xc, 3, 1, 32, padding_in='VALID')
    print(xc)
    """conv2"""
    xc = network_manager.conv_block(xc, 3, 2, 64, padding_in='VALID')
    print(xc)
    xc = network_manager.conv_block(xc, 3, 1, 64, padding_in='VALID')
    print(xc)
    """conv3"""
    xc = network_manager.conv_block(xc, 3, 2, 128, padding_in='VALID')
    print(xc)
    xc = network_manager.conv_block(xc, 3, 1, 128, padding_in='VALID')
    print(xc)
    """conv4"""
    xc = network_manager.conv_block(xc, 3, 1, 256, padding_in='VALID')
    print(xc)
    xc = network_manager.conv_block(xc, 3, 1, 256, padding_in='VALID')
    print(xc)
    """mp3 (default values)"""

    """ reshape """
    x = tf.reshape(xc, [-1, int(np.prod(xc.get_shape()[1:]))], name='reshape')
    print(x)

    """ fc1 """
    x = network_manager.fc_block(x, 512)
    print(x)
    """ fc2 """
    x = network_manager.fc_block(x, 512)

    """Process Control"""

    """ Speed (measurements) """
    with tf.name_scope("Speed"):
        speed = input_data[1]  # get the speed from input data
        speed = network_manager.fc_block(speed, 128)
        speed = network_manager.fc_block(speed, 128)

    """ Joint sensory """
    j = tf.concat([x, speed], 1)
    j = network_manager.fc_block(j, 512)

    """Start BRANCHING"""
    branch_config = [["Steer", "Gas", "Brake"], ["Steer", "Gas", "Brake"],
                     ["Steer", "Gas", "Brake"], ["Steer", "Gas", "Brake"], ["Speed"]]

    for i in range(0, len(branch_config)):
        with tf.name_scope("Branch_" + str(i)):
            if branch_config[i][0] == "Speed":
                # the speed branch only uses the image features as input
                branch_output = network_manager.fc_block(x, 256)
                branch_output = network_manager.fc_block(branch_output, 256)
            else:
                branch_output = network_manager.fc_block(j, 256)
                branch_output = network_manager.fc_block(branch_output, 256)

            branches.append(network_manager.fc(branch_output, len(branch_config[i])))

        print(branch_output)

    return branches
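

# ---------------------------------------------------------------------------
# Minimal usage sketch: builds the graph with placeholder inputs so the branch
# outputs can be inspected. The 88x200x3 image resolution, the [batch, 1]
# speed measurement and the all-ones keep probabilities below are assumptions
# for illustration only; substitute the values used by your own training or
# inference setup.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    image_size = (88, 200, 3)  # assumed (height, width, channels)

    with tf.Graph().as_default():
        input_image = tf.placeholder(
            tf.float32,
            shape=[None, image_size[0], image_size[1], image_size[2]],
            name='input_image')
        input_speed = tf.placeholder(tf.float32, shape=[None, 1], name='input_speed')
        # The network reads the speed measurement from index 1 of input_data.
        input_data = [None, input_speed]

        # One keep probability per dropout call: 8 conv blocks + 15 fc blocks.
        # All ones means no units are dropped (inference-style setting).
        dropout_vec = [1.0] * 23

        # input_size is not used inside load_imitation_learning_network.
        branches = load_imitation_learning_network(input_image, input_data,
                                                   image_size, dropout_vec)

        # Four control branches with [Steer, Gas, Brake] plus one speed branch.
        for i, branch in enumerate(branches):
            print('Branch', i, '->', branch)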