[relay][frontend] TensorFlow saved model support
yongwww committed Feb 12, 2019
1 parent 89deaa6 commit 8671af9
Showing 2 changed files with 197 additions and 15 deletions.
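
In short, the commit adds a `TFParser` helper (new file below) that turns a frozen `*.pb`/`*.pbtxt` file or a saved-model directory into a frozen `GraphDef`, and reworks shape handling in the Relay TensorFlow frontend. A minimal sketch of the intended end-to-end flow (the model path is hypothetical; the two-value return follows `from_tensorflow` as shown in the diff below):

```python
from tvm import relay
from tvm.relay.frontend.tensorflow_parser import TFParser

# Hypothetical directory holding a TensorFlow saved model (contains `variables/`).
parser = TFParser("./exported_saved_model")
graph_def = parser.parse()  # frozen GraphDef with training nodes removed

# Hand the frozen graph to the existing Relay TensorFlow frontend.
net, params = relay.frontend.from_tensorflow(graph_def, layout="NHWC")
```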
59 changes: 44 additions & 15 deletions python/tvm/relay/frontend/tensorflow.py
@@ -4,6 +4,7 @@
 from __future__ import print_function
 
 import logging
+import warnings
 # Numpy support
 import numpy as np
 
@@ -411,7 +412,7 @@ def _impl(inputs, attr, params):
 def _decode_image():
     def _impl(inputs, attr, params):
         # Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
-        print("DecodeJpeg: It's a pass through, please handle preprocessing before input")
+        warnings.warn("DecodeJpeg: It's a pass through, please handle preprocessing before input")
         return inputs[0]
     return _impl
 
@@ -1172,6 +1173,7 @@ class GraphProto(object):
     def __init__(self):
         self._nodes = {}
         self._params = {}
+        self._input_shapes = {}
         self._output_shapes = {}
         self._num_param = 0
         self._num_rnn_layer = False
@@ -1223,36 +1225,55 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
             raise NotImplementedError( \
                 "The following operators are not implemented: {}".format(missing_operators))
 
+        for node in graph.node:
+            if node.op == 'Placeholder':
+                if shape and node.name in shape:
+                    self._input_shapes[node.name] = list(shape[node.name])
+                    continue
+                self._input_shapes[node.name] = \
+                    tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
+                for idx, dim in enumerate(self._input_shapes[node.name]):
+                    if dim < 0:
+                        self._input_shapes[node.name][idx] = 1
+                        warnings.warn("Use 1 instead of -1 in shape of operator %s."
+                                      % node.name)
+
+            # Ignore user's input shape for Non placeholder
+            elif node.op == 'Const':
+                tensor_value = node.attr['value'].tensor
+                self._input_shapes[node.name] = \
+                    tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
+                if shape and node.name in shape:
+                    warnings.warn("Ignore the passed shape. Shape in graphdef "
+                                  "will be used for operator %s." % node.name)
+
         # Parse the nodes to re-create TF graph using Relay operators.
         for node in graph.node:
-            # Tensorflow doesn't have seperate list for params extraction.
+            # Tensorflow doesn't have separate list for params extraction.
             # Operator name 'Const' is treated as a parameter to build params dict.
 
-            input_shapes = {}
             attr = self._parse_attr(node.attr)
 
-            #Variable converted to Const will not have only value attr
+            # Variable converted to Const will not have only value attr
             if 'value' in attr and node.op == 'Const':
                 tensor_value = attr['value']
-                self._output_shapes[node.name] = \
-                    [tensor_util.TensorShapeProtoToList( \
-                        tensor_value.tensor_shape)]
+                self._output_shapes[node.name] = [self._input_shapes[node.name]]
+            elif shape and node.name in shape:
+                # Give priority to user argument.
+                self._output_shapes[node.name] = [shape[node.name]]
             elif '_output_shapes' in attr:
                 self._output_shapes[node.name] = \
                     [tensor_util.TensorShapeProtoToList(tshape) \
                     for tshape in attr['_output_shapes']]
-            elif shape:
+            else:
                 # Keep the list indexable to avoid key error.
                 # Actual value will be filled after node creation.
                 self._output_shapes[node.name] = [None]
-            else:
-                raise NotImplementedError( \
-                    "Please freeze the graph with add_shapes=True")
 
             if node.op == "Placeholder":
-                self._output_shapes[node.name] = [shape[node.name]]
+                self._output_shapes[node.name] = [self._input_shapes[node.name]]
                 self._nodes[node.name] = [_expr.var(node.name,
-                                                    shape=self._output_shapes[node.name][0],
+                                                    shape=self._input_shapes[node.name],
                                                     dtype=attr['dtype'].name)]
 
             elif node.op == "Const":
@@ -1268,7 +1289,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
 
             else:
                 # Pass the parsed shapes instead
-                attr["_output_shapes"] = self._output_shapes[node.name]
+                attr["_output_shapes"] = output_shapes = self._output_shapes[node.name]
 
                 # Pass the node name too in attr
                 attr["_node_name"] = node.name
@@ -1295,7 +1316,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
 
                 op = self._convert_operator(node.op, inputs, attr, graph)
 
-                # Check is op is converted to param
+                # Check if op is converted to param
                 if isinstance(op, np.ndarray):
                     self._params[node.name] = tvm.nd.array(op)
                     op = [_expr.var(node.name,
@@ -1311,6 +1332,14 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
 
                 self._nodes[node.name] = op
 
+                # Infer shapes even without specifying "add_shapes=True"
+                if output_shapes == [None]:
+                    out_type = ir_pass.infer_type(self._nodes[node.name][0])
+                    self._output_shapes[node.name] = [get_const_tuple(out_type.checked_type.shape)]
+
+                if self._output_shapes[node.name] and shape and node.name in shape:
+                    assert self._output_shapes[node.name] == list(shape[node.name])
+
                 # Infer shapes if passed explicitely
                 node_output = self._nodes[node.name]
                 out_type = ir_pass.infer_type(node_output[0])
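Taken together, the tensorflow.py changes above move shape handling into a dedicated pre-pass: placeholder shapes come from the caller's `shape` argument when supplied, otherwise from the graph with any unknown `-1` dimension defaulted to 1 plus a warning; shapes passed for `Const` nodes are ignored in favor of the graphdef; and graphs frozen without `add_shapes=True` no longer raise, since missing output shapes are recovered via `ir_pass.infer_type`. A hedged sketch of the caller-visible behavior, assuming a hypothetical frozen graph whose placeholder `input` has shape `(-1, 224, 224, 3)`:

```python
from tvm import relay

# `graph_def` is a frozen tf.GraphDef loaded elsewhere (e.g. via TFParser below).

# Without `shape`: the unknown batch dimension is replaced by 1, with a
# warning ("Use 1 instead of -1 in shape of operator input.").
net, params = relay.frontend.from_tensorflow(graph_def)

# With `shape`: the user-supplied placeholder shape takes priority.
net, params = relay.frontend.from_tensorflow(
    graph_def, shape={"input": (8, 224, 224, 3)})
```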
153 changes: 153 additions & 0 deletions python/tvm/relay/frontend/tensorflow_parser.py
@@ -0,0 +1,153 @@
"""TF: Tensorflow parser"""
from __future__ import absolute_import as _abs
from __future__ import print_function
import os
from tensorflow.core.framework import graph_pb2
from tvm.contrib import util


class TFParser(object):
"""A Wrapper to handle tensorflow models parsing
TensorFlow is needed
```
parser = TfParser(model_dir)
graph = parser.parse()
```
Parameters
----------
model_dir : tensorflow frozen pb file or a directory that contains saved
model or checkpoints.
"""

    def __init__(self, model_dir):
        self._tmp_dir = util.tempdir()
        self._model_dir = model_dir
        self._graph = graph_pb2.GraphDef()

    def _set_graph(self, graph):
        """Set Graph"""
        self._graph = graph

    def _get_graph(self):
        """Get Graph"""
        return self._graph

    def _load_pb_file(self):
        """Load single pb file"""
        graph = self._get_graph()
        with open(self._model_dir, "rb") as f:
            graph.ParseFromString(f.read())
        return graph

    def _get_tag_set(self):
        """Return the tag set of saved model, multiple metagraphs are not supported"""
        try:
            from tensorflow.contrib.saved_model.python.saved_model import reader
        except ImportError:
            raise ImportError(
                "InputConfiguration: Unable to import saved_model.reader which is "
                "required to get tag set from saved model.")
        tag_sets = reader.get_saved_model_tag_sets(self._model_dir)
        return tag_sets[0]

    def _get_output_names(self):
        """Return the concatenated output names"""
        try:
            import tensorflow as tf
        except ImportError:
            raise ImportError(
                "InputConfiguration: Unable to import tensorflow which is "
                "required to restore from saved model.")
        tags = self._get_tag_set()
        with tf.Session() as sess:
            meta_graph_def = tf.saved_model.loader.load(sess,
                                                        tags,
                                                        self._model_dir)
            output_names = set()
            for k in meta_graph_def.signature_def.keys():
                outputs_tensor_info = meta_graph_def.signature_def[k].outputs
                for output_tensor in outputs_tensor_info.values():
                    output_names.add(output_tensor.name)
            output_names = [i.replace(":0", "") for i in output_names]
            return ",".join(output_names)

    def _load_saved_model(self):
        """Load the tensorflow saved model."""
        try:
            from tensorflow.python.tools import freeze_graph
            from tensorflow.python.framework import ops
            from tensorflow.python.framework import graph_util
        except ImportError:
            raise ImportError(
                "InputConfiguration: Unable to import tensorflow which is "
                "required to restore from saved model.")

        saved_model_dir = self._model_dir
        output_graph_filename = self._tmp_dir.relpath("tf_frozen_model.pb")
        input_saved_model_dir = saved_model_dir
        output_node_names = self._get_output_names()

        input_binary = False
        input_saver_def_path = False
        restore_op_name = None
        filename_tensor_name = None
        clear_devices = True
        input_meta_graph = False
        checkpoint_path = None
        input_graph_filename = None
        saved_model_tags = ",".join(self._get_tag_set())

        freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
                                  input_binary, checkpoint_path, output_node_names,
                                  restore_op_name, filename_tensor_name,
                                  output_graph_filename, clear_devices, "", "", "",
                                  input_meta_graph, input_saved_model_dir,
                                  saved_model_tags)

        with ops.Graph().as_default():
            output_graph_def = graph_pb2.GraphDef()
            with open(output_graph_filename, "rb") as f:
                output_graph_def.ParseFromString(f.read())
            output_graph_def = graph_util.remove_training_nodes(output_graph_def)
            return output_graph_def

    def _load_ckpt(self):
        """TODO: Load checkpoint model."""
        raise RuntimeError("InputConfiguration: Loading tf checkpoint model is "
                           "not supported yet.")

    def parse(self):
        """Parse tensorflow models: checkpoints, saved models, and single pb
        file.
        """
        graph = None

        if os.path.isdir(self._model_dir):
            ckpt = os.path.join(self._model_dir, "checkpoint")
            if not os.path.isfile(ckpt):
                if not os.path.isdir(os.path.join(self._model_dir, "variables")):
                    raise RuntimeError("InputConfiguration: Invalid model path.")
                graph = self._load_saved_model()
            else:
                graph = self._load_ckpt()
        elif os.path.isfile(self._model_dir):
            # Only .pb or .pbtxt is a valid suffix name.
            if self._model_dir.endswith(".pb") or \
               self._model_dir.endswith(".pbtxt"):
                cur_dir = os.path.dirname(self._model_dir)
            else:
                raise RuntimeError("InputConfiguration: Invalid model format.")

            # It is a saved model if `variables` directory is present at the
            # same directory with the pb or pbtxt file.
            if os.path.isdir(os.path.join(cur_dir, "variables")):
                self._model_dir = cur_dir
                graph = self._load_saved_model()
            else:
                graph = self._load_pb_file()
        else:
            raise RuntimeError("InputConfiguration: Unrecognized model "
                               "file or path.")

        self._set_graph(graph)
        return graph
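
For reference, a short sketch of how `parse()` dispatches on its input; all paths here are hypothetical:

```python
from tvm.relay.frontend.tensorflow_parser import TFParser

# Single frozen .pb (or .pbtxt) file -> _load_pb_file().
graph_def = TFParser("model.pb").parse()

# Directory containing `variables/` -> _load_saved_model(), which first
# freezes the saved model into a temporary pb via freeze_graph.
graph_def = TFParser("./exported_saved_model").parse()

# Directory containing a `checkpoint` file -> _load_ckpt(), which
# currently raises RuntimeError (checkpoints are not supported yet):
#   TFParser("./training_checkpoints").parse()
```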
