From 58434d16f0173ec1ab98edabe7a270e8d6856fac Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi <86472128+ashutosh-arm@users.noreply.github.com> Date: Mon, 15 May 2023 11:44:25 +0100 Subject: [PATCH] [TFLite][Frontend] Generate name when tensor name is missing (#14819) After upgrade to TFLite 2.6, some networks have missing tensor names. This commit generates names with prefix tvmgen_ from TFLite frontend. --- python/tvm/relay/frontend/tflite.py | 9 +++++-- .../contrib/test_cmsisnn/test_networks.py | 27 +++++++++++++++++++ tests/scripts/request_hook/request_hook.py | 1 + 3 files changed, 35 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index f5d9b5bbf29a..9e2e244cb146 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -4091,7 +4091,12 @@ def get_tensor_name(subgraph, tensor_idx): ------- tensor name in UTF-8 encoding """ - return subgraph.Tensors(tensor_idx).Name().decode("utf-8") + tensor_name = subgraph.Tensors(tensor_idx).Name() + if tensor_name is not None: + tensor_name = tensor_name.decode("utf-8") + else: + tensor_name = "tvmgen_tensor_" + str(tensor_idx) + return tensor_name def _decode_type(n): @@ -4125,7 +4130,7 @@ def _input_type(model): tensor = subgraph.Tensors(input_) input_shape = tuple(tensor.ShapeAsNumpy()) tensor_type = tensor.Type() - input_name = tensor.Name().decode("utf8") + input_name = get_tensor_name(subgraph, input_) shape_dict[input_name] = input_shape dtype_dict[input_name] = _decode_type(tensor_type) diff --git a/tests/python/contrib/test_cmsisnn/test_networks.py b/tests/python/contrib/test_cmsisnn/test_networks.py index 9f64be246182..16afffdccefb 100644 --- a/tests/python/contrib/test_cmsisnn/test_networks.py +++ b/tests/python/contrib/test_cmsisnn/test_networks.py @@ -120,5 +120,32 @@ def test_cnn_small(test_runner): ) +@tvm.testing.requires_package("tflite") +def test_keyword_scramble(): + """Download keyword_scrambled and test for Relay conversion. + In future, this test can be extended for CMSIS-NN""" + # download the model + base_url = ( + "https://github.com/tensorflow/tflite-micro/raw/" + "de8f61a074460e1fa5227d875c95aa303be01240/" + "tensorflow/lite/micro/models" + ) + file_to_download = "keyword_scrambled.tflite" + file_saved = "keyword_scrambled.tflite" + model_file = download_testdata("{}/{}".format(base_url, file_to_download), file_saved) + + with open(model_file, "rb") as f: + tflite_model_buf = f.read() + + input_shape = (1, 96) + dtype = "int8" + in_min, in_max = get_dtype_range(dtype) + rng = np.random.default_rng(12345) + input_data = rng.integers(in_min, high=in_max, size=input_shape, dtype=dtype) + + with pytest.raises(tvm.error.OpNotImplemented): + _, _ = _convert_to_relay(tflite_model_buf, input_data, "input") + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/scripts/request_hook/request_hook.py b/tests/scripts/request_hook/request_hook.py index 3c193d84ae4c..dd92a92bc5a4 100644 --- a/tests/scripts/request_hook/request_hook.py +++ b/tests/scripts/request_hook/request_hook.py @@ -212,6 +212,7 @@ "https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/data/image_classification_int8_0.npy": f"{BASE}/tlc-pack/web-data/raw/main/testdata/microTVM/data/image_classification_int8_0.npy", "https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/data/vww_sample_person.jpg": f"{BASE}/tlc-pack/web-data/testdata/microTVM/data/vww_sample_person.jpg", "https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/data/vww_sample_not_person.jpg": f"{BASE}/tlc-pack/web-data/testdata/microTVM/data/vww_sample_not_person.jpg", + "https://github.com/tensorflow/tflite-micro/raw/de8f61a074460e1fa5227d875c95aa303be01240/tensorflow/lite/micro/models/keyword_scrambled.tflite": f"{BASE}/models/tflite/keyword_scrambled_8bit.tflite", }