Skip to content

Commit

Permalink
[TENSORFLOW] Tensorflow 2.x support
Browse files Browse the repository at this point in the history
    * Frontend supports concreate function along with graphdef.
    * New test cases added to validate TF2.x functions.
    * E2E testcases will use TFHub inputs.
  • Loading branch information
srkreddy1238 committed Feb 22, 2021
1 parent d16f282 commit d821106
Show file tree
Hide file tree
Showing 2 changed files with 3,376 additions and 1 deletion.
73 changes: 72 additions & 1 deletion python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from .common import infer_channels as _infer_channels
from .common import infer_value as _infer_value


__all__ = ["from_tensorflow"]


Expand Down Expand Up @@ -2462,6 +2463,7 @@ def _impl(inputs, attr, params, mod):
"Round": AttrCvt("round"),
"Rsqrt": _rsqrt(),
"Select": _where(),
"SelectV2": _where(),
"Selu": _selu(),
"Shape": _shape(),
"Sigmoid": AttrCvt("sigmoid"),
Expand Down Expand Up @@ -3656,13 +3658,58 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
return func, self._params


# Ref. taken from Tensorflow JS
def _build_signature_def(frozen_graph, input_nodes, output_nodes):
try:
from tensorflow.core.protobuf import meta_graph_pb2
except ImportError as e:
raise ImportError("Unable to import tensorflow which is required {}".format(e))
signature = meta_graph_pb2.SignatureDef()
for input_tensor in input_nodes:
op_name = input_tensor.name.split(":")[0]
# The graph freezing may turn the original inputs into constants, or remove
# them from the graph, so we need to ignore those.
try:
op = frozen_graph.get_operation_by_name(op_name)
if op.type != "Const":
signature.inputs[input_tensor.name].name = input_tensor.name
signature.inputs[input_tensor.name].dtype = input_tensor.dtype.as_datatype_enum
signature.inputs[input_tensor.name].tensor_shape.CopyFrom(
input_tensor.shape.as_proto()
)
except KeyError:
# The original input was removed when the graph was frozen.
continue
for output_tensor in output_nodes:
if hasattr(output_tensor, "name"):
signature.outputs[output_tensor.name].name = output_tensor.name
signature.outputs[output_tensor.name].dtype = output_tensor.dtype.as_datatype_enum
signature.outputs[output_tensor.name].tensor_shape.CopyFrom(
output_tensor.shape.as_proto()
)
else: # just the tensor name string array
signature.outputs[output_tensor].name = output_tensor
return signature


def _run_grappler(config, graph_def, graph, signature_def):
try:
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.training.saver import export_meta_graph
except ImportError as e:
raise ImportError("Unable to import tensorflow which is required {}".format(e))
meta_graph = export_meta_graph(graph_def=graph_def, graph=graph)
meta_graph.signature_def["not_used_key"].CopyFrom(signature_def)
return tf_optimizer.OptimizeGraph(config, meta_graph)


def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
"""Load tensorflow graph which is a python tensorflow graph object into relay.
The companion parameters will be handled automatically.
Parameters
----------
graph : GraphDef object
graph : GraphDef object or concrete function
Tensorflow GraphDef
layout : target layout to be used (Optional)
Expand All @@ -3682,6 +3729,30 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
params : dict of str to tvm.nd.NDArray
Dict of converted parameters stored in tvm.nd.NDArray format
"""
from tensorflow.python.eager.function import ConcreteFunction

if isinstance(graph, ConcreteFunction):
try:
from tensorflow.python.framework import convert_to_constants
from tensorflow.core.protobuf import config_pb2
except ImportError as e:
raise ImportError("Unable to import tensorflow which is required {}".format(e))
concrete_func = graph
graph = convert_to_constants.convert_variables_to_constants_v2(concrete_func).graph
signature = _build_signature_def(graph, concrete_func.inputs, concrete_func.outputs)
graph_def = graph.as_graph_def()

# Some optimization
config = config_pb2.ConfigProto()
rewriter_config = config.graph_options.rewrite_options
rewriter_config.optimizers[:] = [
"debug_stripper",
"arithmetic",
"dependency",
"arithmetic",
"dependency",
]
graph = _run_grappler(config, graph_def, graph, signature)

g = GraphProto()
mod, params = g.from_tensorflow(graph, layout, shape, outputs)
Expand Down
Loading

0 comments on commit d821106

Please sign in to comment.