Skip to content

Commit

Permalink
* wip
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 committed Dec 1, 2018
1 parent d097a66 commit 89284d2
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 30 deletions.
43 changes: 24 additions & 19 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,21 @@ def _required_attr(self, attr, key):

def _get_pad_pair(input1d, kernel1d, stride1d):
if input1d % stride1d == 0:
pad = tvm.select((kernel1d - stride1d) > 0, (kernel1d - stride1d), relay.const(0))
pad = max(kernel1d - stride1d, 0)
else:
pad = tvm.select((kernel1d - (input1d % stride1d)) > 0, (kernel1d - (input1d % stride1d)), relay.const(0))
pad = max(kernel1d - (input1d % stride1d), 0)

pad_before = pad // relay.const(2)
pad_before = pad // 2
pad_after = pad - pad_before

return [pad_before, pad_after]

def _get_name_hint(node):
if hasattr(node, "name_hint"):
return node.name_hint
else:
return ''

def _math_name_picker(surfix):
def _impl(attr):
return 'broadcast_' + surfix
Expand Down Expand Up @@ -318,7 +324,7 @@ def _impl(inputs, attr, params):
attr['data_format'] = "NCHW"
attr['strides'] = [attr['strides'][ii] for ii in (0, 3, 1, 2)]
flip_layout = True
print("W Shape:", weights_shape)

if attr['data_format'] == 'NHWC':
kernel_h, kernel_w, _, depth_mult = weights_shape
attr['kernel_shape'] = (weights_shape[0], weights_shape[1])
Expand Down Expand Up @@ -647,11 +653,11 @@ def _impl(inputs, attr, params):
new_input = []
new_input.append(inputs.pop(0))
new_input.append(inputs.pop(0))
return AttrCvt(
op_name="take",
extras={'axis':axis},
ignores=['Tindices', 'Tparams', 'validate_indices', \
'Taxis', '_class'])(new_input, attr)
return AttrCvt(op_name="take",
extras={'axis': tvm.const(axis)},
ignores=['Tindices', 'Tparams', 'validate_indices', \
'Taxis', '_class'])(new_input, attr)
return out
return _impl

def _infer_out_shapes(inputs, params):
Expand Down Expand Up @@ -785,9 +791,12 @@ def _transpose():
def _impl(inputs, attr, params):
# If perm is not specified, axes is left empty,
# otherwise its value is get from params
param_name = inputs[1].name_hint
axes = params.get(param_name, tvm.nd.array([])).asnumpy()
return _op.transpose(inputs[0], axes=tuple(axes))
param_name = _get_name_hint(inputs[1])
if param_name in params:
axes = tuple(params.get(param_name).asnumpy())
else:
axes = None
return _op.transpose(inputs[0], axes=axes)
return _impl

def _rank():
Expand All @@ -799,7 +808,7 @@ def _impl(inputs, attr, params):
params[name] = tvm.nd.array([len(input_shapes[0])])
return [_expr.var(name,
shape=params[name].shape,
dtype=params[name].dtype)]
dtype='int32')]

return _impl

Expand All @@ -813,7 +822,7 @@ def _impl(inputs, attr, params):
params[name] = tvm.nd.array([start, limit, delta])
return [_expr.var(name,
shape=params[name].shape,
dtype=params[name].dtype)]
dtype='int32')]
return _impl

def _elu():
Expand Down Expand Up @@ -873,7 +882,7 @@ def _impl(inputs, attr, params):
'MatMul' : _matmul(),
'MaxPool' : _pooling('max_pool'),
'Add' : _elemwise('add'),
'Sub' : _elemwise('sub'),
'Sub' : _elemwise('subtract'),
'Mul' : _elemwise('multiply'),
'Maximum' : _elemwise('max'),
'Minimum' : _elemwise('min'),
Expand Down Expand Up @@ -974,7 +983,6 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
final_op = None
# Parse the nodes to re-create TF graph using Symbol API of NNVM
for node in graph.node:
print("Node: ", node.name, "Node Op:", node.op)
# Tensorflow doesn't have seperate list for params extraction.
# Operator name 'Const' is treated as a parameter to build NNVM params dict.

Expand Down Expand Up @@ -1070,9 +1078,6 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
out = op
out = out[0] if len(out) == 1 else _expr.Tuple(out)
func = _expr.Function(ir_pass.free_vars(out), out)
print("OP:", op)
print("Func:", func)
print("Shape:", relay.ir_pass.infer_type(op[0]).checked_type)

return func, self._params

Expand Down
23 changes: 12 additions & 11 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,19 +939,20 @@ def test_forward_l2_normalize():
# transpose
# ---------
def _test_forward_transpose(ishape, axes=None):
input = np.random.uniform(size=ishape).astype(np.float32)
data = np.random.uniform(size=ishape).astype(np.float32)

with tf.Graph().as_default():
in1 = tf.placeholder(shape=input.shape, dtype=input.dtype, name="transpose_data")
in1 = tf.placeholder(shape=data.shape, dtype=data.dtype, name="transpose_data")

if axes is None:
tf.transpose(in1)
else:
tf.transpose(in1, perm=axes)

compare_tf_with_tvm(input, 'transpose_data:0', 'transpose:0')
compare_tf_with_tvm(data, 'transpose_data:0', 'transpose:0')

def test_forward_transpose():
_test_forward_transpose((2, 3, 4), (1, 2, 0))
_test_forward_transpose((2, 3, 4))
_test_forward_transpose((7, 8, 8, 10))
_test_forward_transpose((2, 3, 4), (1, 2, 0))
Expand Down Expand Up @@ -1056,14 +1057,6 @@ def test_forward_rel_ops():
# Main
# ----
if __name__ == '__main__':
# NN
test_forward_convolution()
#test_forward_pooling()
#if tf.__version__ == '1.4.1':
# _test_forward_concat_v2()
#test_forward_lrn()
#test_forward_l2_normalize()
exit(0)
# Transforms
test_forward_transpose()
test_forward_reshape()
Expand Down Expand Up @@ -1108,3 +1101,11 @@ def test_forward_rel_ops():

# Relational ops
test_forward_rel_ops()

# NN
#test_forward_convolution()
#test_forward_pooling()
#if tf.__version__ == '1.4.1':
# _test_forward_concat_v2()
#test_forward_lrn()
#test_forward_l2_normalize()

0 comments on commit 89284d2

Please sign in to comment.