Skip to content

Commit

Permalink
[RELAY][FRONTEND] Tensorflow frontend. (apache#2216)
Browse files Browse the repository at this point in the history
* [RELAY][FRONTEND] Tensorflow frontend support.

* 	* LSTM removed for a while.

* 	* basic ops are good.

* 	* nn wip

* 	* wip

* 	* python2.7 corrections.

* * NN ops are good.

* * e2e models working good

* 	* all good except LSTM

* 	* rebase, tutorials and CI trigger.

* 	* CI errors.

* 	* enable opt_level=3

* 	* Docstrings cleanup. testing.tf utils moved to relay from nnvm.

* 	* tutorials update.

* 	* LSTM work good now.

* 	* Rebase

* 	* CI error

* 	* enable PTB.

* 	* rebase.

* 	* tutorials

* Update python/tvm/relay/frontend/tensorflow.py

Co-Authored-By: srkreddy1238 <[email protected]>

* 	* review comments.

* 	CI fix.

* 	* review comments.
  • Loading branch information
srkreddy1238 authored and AWS Neo committed Feb 20, 2019
1 parent 55a48c5 commit b488ade
Show file tree
Hide file tree
Showing 10 changed files with 2,895 additions and 25 deletions.
2 changes: 1 addition & 1 deletion docs/frontend/tensorflow.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ instructions to generate protobuf from checkpoint.

### Add Shapes:
While freezing of protobuf add additional option ```add_shapes=True``` to embed output shapes of each node into graph.
You may use ```nnvm.testing.tf.AddShapesToGraphDef``` from nnvm for the same.
You may use ```tvm.relay.testing.tf.AddShapesToGraphDef``` from nnvm for the same.
Please refer to [tensorflow tutorial](https://github.com/dmlc/tvm/blob/master/tutorials/nnvm/from_tensorflow.py).

### Explicit Shape:
Expand Down
30 changes: 15 additions & 15 deletions nnvm/tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tensorflow.python.ops import init_ops
from tensorflow.core.framework import graph_pb2

import nnvm.testing.tf
import tvm.relay.testing.tf as tf_testing

#######################################################################
# Generic run functions for TVM & tensorflow
Expand Down Expand Up @@ -784,9 +784,9 @@ def test_forward_pad():
def test_forward_inception_v3():
'''test inception V3 model'''
with tf.Graph().as_default():
graph_def = nnvm.testing.tf.get_workload('InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb')
graph_def = tf_testing.get_workload('InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb')
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
graph_def = tf_testing.ProcessGraphDefParam(graph_def)

data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')

Expand All @@ -801,9 +801,9 @@ def test_forward_inception_v3():
def test_forward_inception_v1():
'''test inception V1 model'''
with tf.Graph().as_default():
graph_def = nnvm.testing.tf.get_workload("InceptionV1/classify_image_graph_def-with_shapes.pb")
graph_def = tf_testing.get_workload("InceptionV1/classify_image_graph_def-with_shapes.pb")
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
graph_def = tf_testing.ProcessGraphDefParam(graph_def)

# Build an image from random data.
from PIL import Image
Expand Down Expand Up @@ -838,18 +838,18 @@ def test_forward_mobilenet():
'''test mobilenet model'''
# MobilenetV2
with tf.Graph().as_default():
graph_def = nnvm.testing.tf.get_workload(
graph_def = tf_testing.get_workload(
"https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz",
"mobilenet_v2_1.4_224_frozen.pb")
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
graph_def = tf_testing.ProcessGraphDefParam(graph_def)

data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
out_node = 'MobilenetV2/Predictions/Reshape_1'

with tf.Session() as sess:
# Add shapes to the graph.
graph_def = nnvm.testing.tf.AddShapesToGraphDef(sess, out_node)
graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0')
tvm_output = run_tvm_graph(graph_def, data, 'input')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
Expand All @@ -861,9 +861,9 @@ def test_forward_resnetv2():
'''test resnet model'''
if is_gpu_available():
with tf.Graph().as_default():
graph_def = nnvm.testing.tf.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb")
graph_def = tf_testing.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb")
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
graph_def = tf_testing.ProcessGraphDefParam(graph_def)

data = np.random.uniform(size=(128, 224, 224, 3)).astype('float32')
out_node = 'ArgMax'
Expand All @@ -879,7 +879,7 @@ def test_forward_resnetv2():
dir(tf.contrib)
def test_forward_ptb():
'''test ptb model'''
config = nnvm.testing.tf.get_config()
config = tf_testing.get_config()
num_steps = config.num_steps
num_hidden = config.hidden_size
num_layers = config.num_layers
Expand Down Expand Up @@ -936,7 +936,7 @@ def _get_sample(data, state):
"float32")).asnumpy()
state_output = model.get_output(1, tvm.nd.empty(out_state_shape,
"float32")).asnumpy()
sample = nnvm.testing.tf.pick_from_weight(tvm_output[0])
sample = tf_testing.pick_from_weight(tvm_output[0])

return sample, state_output

Expand All @@ -956,10 +956,10 @@ def _get_sample(data, state):
return samples, state

with tf.Graph().as_default():
word_to_id, id_to_word, graph_def = nnvm.testing.tf.get_workload_ptb()
word_to_id, id_to_word, graph_def = tf_testing.get_workload_ptb()
vocab_size = len(word_to_id)
# Call the utility to import the graph definition into default graph.
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
graph_def = tf_testing.ProcessGraphDefParam(graph_def)
sess = tf.Session()

#TVM graph module creation
Expand All @@ -975,7 +975,7 @@ def _get_sample(data, state):
for word in seed_for_sample],
in_state, params, cnt_sample)
tvm_sample_str = _pretty_print(tvm_samples, False, id_to_word)
tf_samples, tf_state = nnvm.testing.tf.do_tf_sample(sess,
tf_samples, tf_state = tf_testing.do_tf_sample(sess,
[word_to_id[word] for word in seed_for_sample],
in_state, cnt_sample)
tf_sample_str = _pretty_print(tf_samples, False, id_to_word)
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/frontend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
from .tflite import from_tflite
from .coreml import from_coreml
from .caffe2 import from_caffe2
from .tensorflow import from_tensorflow
Loading

0 comments on commit b488ade

Please sign in to comment.