Skip to content

Commit

Permalink
* CI: GPU fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 committed Mar 7, 2021
1 parent bead81f commit 97ee7cb
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 20 deletions.
21 changes: 1 addition & 20 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,6 @@
__all__ = ["from_tensorflow"]


# TODO: Better to differentiate the parser ?
v2_ops = [
"Enter",
"Exit",
"Merge",
]


def from_tensorflow(tf_input, layout="NHWC", shape=None, outputs=None):
"""Load tensorflow graph which is a python tensorflow graph object into relay.
The companion parameters will be handled automatically.
Expand Down Expand Up @@ -62,17 +54,6 @@ def from_tensorflow(tf_input, layout="NHWC", shape=None, outputs=None):

parser = TFParser(tf_input, outputs)
graph = parser.parse()
is_v2 = False

for node in graph.node:
if node.op in v2_ops:
is_v2 = True
break

if not is_v2:
g = v1.GraphProto()
else:
raise ImportError("TF 2.x parser yet to be supported")

g = v1.GraphProto()
mod, params = g.from_tensorflow(graph, layout, shape, outputs)
return mod, params
4 changes: 4 additions & 0 deletions tests/python/frontend/tensorflow/test_forward_v2_x.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def _test_pooling(input_shape, **kwargs):
_test_pooling_iteration(input_shape, **kwargs)


@tvm.testing.uses_gpu
def test_forward_pooling():
""" Pooling """
# TensorFlow only supports NDHWC for max_pool3d on CPU
Expand Down Expand Up @@ -567,6 +568,7 @@ def _test_convolution(
compare_tf_with_tvm_v2([input_data], concrete_func)


@tvm.testing.uses_gpu
def test_forward_convolution():
if is_gpu_available():
_test_convolution("conv", [4, 176, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], "SAME", "NCHW")
Expand Down Expand Up @@ -705,6 +707,7 @@ def _test_convolution3d(
compare_tf_with_tvm_v2([input_data], concrete_func)


@tvm.testing.uses_gpu
def test_forward_convolution3d():
if is_gpu_available():
_test_convolution3d(
Expand Down Expand Up @@ -749,6 +752,7 @@ def _test_biasadd(tensor_in_sizes, data_format):
compare_tf_with_tvm_v2([input_data, bias_data], concrete_func)


@tvm.testing.uses_gpu
def test_forward_biasadd():
if is_gpu_available():
_test_biasadd([4, 176, 8, 8], "NCHW")
Expand Down

0 comments on commit 97ee7cb

Please sign in to comment.