Skip to content

Commit

Permalink
* review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 committed Jan 25, 2019
1 parent 6779e9c commit 3cbc68a
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 13 deletions.
19 changes: 10 additions & 9 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,6 @@ def __call__(self, inputs, attrs, *args):
self._ignores.append('_node_name')
self._ignores.append('is_training')
self._ignores.append('_target_layout')
# Retain the names
#try:
# attrs['name'] = attrs['_node_name']
#except KeyError:
# pass

# apply custom check
if self._custom_check:
Expand Down Expand Up @@ -513,13 +508,19 @@ def _impl(inputs, attr, params):
# Shape operator is already pruned, hence
# try to infer shape by precompute prune if possible.
if all(in_node in params for in_node in inputs[1].list_input_names()):
graph = _graph.create(_op.Group(inputs[1]))
params_pre = {k: params[k] for k in inputs[1].list_input_names()}
params_new = build_module._run_graph(graph, params_pre)
func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
with relay.build_config(opt_level=0):
graph, lib, params = relay.build(func, target="llvm", params=params)
ctx = tvm.context("llvm", 0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**params)
m.run()
params_new = m.get_output(0)
inputs.pop(1)
return AttrCvt(
op_name="reshape",
extras={'newshape':tuple(params_new[0].asnumpy().flatten())},
extras={'newshape':tuple(params_new.asnumpy().flatten())},
ignores=['Tshape'])(inputs, attr)
else:
raise RuntimeError("Reshape with dynamic shape input not supported yet.")
Expand Down
2 changes: 1 addition & 1 deletion tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm'
shape=shape_dict,
outputs=out_names)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(sym, target, params=params)
graph, lib, params = relay.build(sym, target, params=params)

ctx = tvm.context(target, 0)
from tvm.contrib import graph_runtime
Expand Down
3 changes: 0 additions & 3 deletions tutorials/frontend/from_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@

# Tensorflow imports
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util

# Tensorflow utility functions
import tvm.relay.testing.tf as tf_testing
Expand Down

0 comments on commit 3cbc68a

Please sign in to comment.