Skip to content

Commit

Permalink
[TFLite][Frontend] Generate name when tensor name is missing (#14819)
Browse files Browse the repository at this point in the history
After upgrade to TFLite 2.6, some networks have missing tensor names.
This commit generates names with prefix tvmgen_ from TFLite frontend.
  • Loading branch information
ashutosh-arm authored May 15, 2023
1 parent 265c098 commit 58434d1
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 2 deletions.
9 changes: 7 additions & 2 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -4091,7 +4091,12 @@ def get_tensor_name(subgraph, tensor_idx):
-------
tensor name in UTF-8 encoding
"""
return subgraph.Tensors(tensor_idx).Name().decode("utf-8")
tensor_name = subgraph.Tensors(tensor_idx).Name()
if tensor_name is not None:
tensor_name = tensor_name.decode("utf-8")
else:
tensor_name = "tvmgen_tensor_" + str(tensor_idx)
return tensor_name


def _decode_type(n):
Expand Down Expand Up @@ -4125,7 +4130,7 @@ def _input_type(model):
tensor = subgraph.Tensors(input_)
input_shape = tuple(tensor.ShapeAsNumpy())
tensor_type = tensor.Type()
input_name = tensor.Name().decode("utf8")
input_name = get_tensor_name(subgraph, input_)
shape_dict[input_name] = input_shape
dtype_dict[input_name] = _decode_type(tensor_type)

Expand Down
27 changes: 27 additions & 0 deletions tests/python/contrib/test_cmsisnn/test_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,5 +120,32 @@ def test_cnn_small(test_runner):
)


@tvm.testing.requires_package("tflite")
def test_keyword_scramble():
"""Download keyword_scrambled and test for Relay conversion.
In future, this test can be extended for CMSIS-NN"""
# download the model
base_url = (
"https://github.com/tensorflow/tflite-micro/raw/"
"de8f61a074460e1fa5227d875c95aa303be01240/"
"tensorflow/lite/micro/models"
)
file_to_download = "keyword_scrambled.tflite"
file_saved = "keyword_scrambled.tflite"
model_file = download_testdata("{}/{}".format(base_url, file_to_download), file_saved)

with open(model_file, "rb") as f:
tflite_model_buf = f.read()

input_shape = (1, 96)
dtype = "int8"
in_min, in_max = get_dtype_range(dtype)
rng = np.random.default_rng(12345)
input_data = rng.integers(in_min, high=in_max, size=input_shape, dtype=dtype)

with pytest.raises(tvm.error.OpNotImplemented):
_, _ = _convert_to_relay(tflite_model_buf, input_data, "input")


if __name__ == "__main__":
tvm.testing.main()
1 change: 1 addition & 0 deletions tests/scripts/request_hook/request_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@
"https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/data/image_classification_int8_0.npy": f"{BASE}/tlc-pack/web-data/raw/main/testdata/microTVM/data/image_classification_int8_0.npy",
"https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/data/vww_sample_person.jpg": f"{BASE}/tlc-pack/web-data/testdata/microTVM/data/vww_sample_person.jpg",
"https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/data/vww_sample_not_person.jpg": f"{BASE}/tlc-pack/web-data/testdata/microTVM/data/vww_sample_not_person.jpg",
"https://github.com/tensorflow/tflite-micro/raw/de8f61a074460e1fa5227d875c95aa303be01240/tensorflow/lite/micro/models/keyword_scrambled.tflite": f"{BASE}/models/tflite/keyword_scrambled_8bit.tflite",
}


Expand Down

0 comments on commit 58434d1

Please sign in to comment.