From 4dab53fd438ca2116e8dc1984bf84b24155ee247 Mon Sep 17 00:00:00 2001 From: "Boblest Sebastian (ETAS-DEV/XPC-Fe1)" Date: Fri, 29 Apr 2022 15:35:40 +0200 Subject: [PATCH 01/16] UnidirectionalLSTM added --- python/tvm/relay/frontend/tflite.py | 169 +++++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 30 ++++ 2 files changed, 199 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index b696bd6d056b..7cf9d4fe1df1 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -173,6 +173,7 @@ def __init__(self, model, subgraph, exp_tab): "TRANSPOSE_CONV": self.convert_transpose_conv, "TRANSPOSE": self.convert_transpose, "UNPACK": self.convert_unpack, + "UNIDIRECTIONAL_SEQUENCE_LSTM": self.convert_unidirectional_sequence_lstm, "WHERE": self.convert_select, "ZEROS_LIKE": self.convert_zeros_like, } @@ -220,6 +221,38 @@ def check_unsupported_ops(self): if len(raise_msg) > 0: raise tvm.error.OpNotImplemented(raise_msg) + def unbind(self, data, axis=1): + """ + This is a slightly modified version compared to the one in common.py + + Parameters + ---------- + data : relay.Expr + Input tensor + axis : int + Axis along which tensor is split. + Returns + ------- + result : List[relay.Expr] + The sequence of computed tensors + """ + shape = to_int_list(self.get_tensor_shape(data)) + if axis >= len(shape): + msg = "Please check input dim, it shouldn't be greater than or equal to rank." + raise AttributeError(msg) + + selections = shape[axis] + shape.pop(axis) + timestep = 0 # Reshape to make time step as the first dim + shape.insert(timestep, selections) + res_split = _op.split( + _op.reshape(self.get_expr(data.tensor_idx), tuple(shape)), selections, timestep + ) + ret = [] + for i in range(selections): + ret.append(_op.squeeze(res_split[i], axis=[timestep])) + return _expr.TupleWrapper(_expr.Tuple(ret), selections) + def convert_op_to_relay(self): """Convert TFLite ops to relay ops""" for op_idx in range(self.subgraph.OperatorsLength()): @@ -2715,6 +2748,142 @@ def convert_unpack(self, op): return squeezed + def convert_unidirectional_sequence_lstm(self, op): + """Long Short Term Memory for TFLite implementation.""" + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + "TFlite quantized UNIDIRECTIONALSEQUENCELSTM operator is not supported yet." + ) + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) >= 2, "input tensors length should be >= 2" + + # Extract input tensor from saved model + input_tensor = input_tensors[0] + + # Extract tensors from input tensors from saved model + # Input weights + input_input_weights = input_tensors[1] + input_forget_weights = input_tensors[2] + input_cell_weights = input_tensors[3] + input_output_weights = input_tensors[4] + # Recurrent weights + recurrent_input_weights = input_tensors[5] + recurrent_forget_weights = input_tensors[6] + recurrent_cell_weights = input_tensors[7] + recurrent_output_weights = input_tensors[8] + # Bias weights + input_gate_bias = input_tensors[12] + forget_gate_bias = input_tensors[13] + cell_gate_bias = input_tensors[14] + output_gate_bias = input_tensors[15] + # State input + output_state_in = input_tensors[18] + cell_state_in = input_tensors[19] + + # Extract output tensor from saved model + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "output tensors length should be 1" + X_steps = self.unbind(input_tensor, axis=1) + weights_dict = {} + + # hidden_state_weights is equivalent to output_state_in in tflite model + out_state_in_shape = tuple(self.get_tensor_shape(output_state_in)) + out_state_in_dtype = self.get_tensor_type_str(output_state_in.tensor.Type()) + out_state_in_expr = _op.zeros(out_state_in_shape, dtype=out_state_in_dtype) + weights_dict["hidden_state"] = _op.split(out_state_in_expr, 1)[0] + + # cell_state_weights is equivalent to output_state_in tflite model + cell_state_in_shape = tuple(self.get_tensor_shape(cell_state_in)) + cell_state_in_dtype = self.get_tensor_type_str(cell_state_in.tensor.Type()) + cell_state_in_expr = _op.zeros(cell_state_in_shape, dtype=cell_state_in_dtype) + weights_dict["cell_state"] = _op.split(cell_state_in_expr, 1)[0] + + # Process weight matrix of input: w_inp + # Concatenate of [input_input_weight, input_forget_weights, input_cell_weights, input_output_weights] + input_input_weights_default_values = self.get_tensor_value(input_input_weights) + input_input_weights_op = _op.split( + _op.const(input_input_weights_default_values.tolist()), 1 + ) + input_output_weights_default_values = self.get_tensor_value(input_output_weights) + input_output_weights_op = _op.split( + _op.const(input_output_weights_default_values.tolist()), 1 + ) + input_forget_weights_default_values = self.get_tensor_value(input_forget_weights) + input_forget_weights_op = _op.split( + _op.const(input_forget_weights_default_values.tolist()), 1 + ) + input_cell_weights_default_values = self.get_tensor_value(input_cell_weights) + input_cell_weights_op = _op.split(_op.const(input_cell_weights_default_values.tolist()), 1) + weights_dict["w_inp"] = _op.concatenate( + [ + _op.squeeze(input_input_weights_op[0]), + _op.squeeze(input_forget_weights_op[0]), + _op.squeeze(input_cell_weights_op[0]), + _op.squeeze(input_output_weights_op[0]), + ], + axis=0, + ) + + # Process weight matrix of hidden state: w_hid to support lstm_cell function. Not used in tflite + recurrent_input_weights_values = self.get_tensor_value(recurrent_input_weights) + recurrent_input_weights_op = _op.split( + _op.const(recurrent_input_weights_values.tolist()), 1 + ) + recurrent_output_weights_values = self.get_tensor_value(recurrent_output_weights) + recurrent_output_weights_op = _op.split( + _op.const(recurrent_output_weights_values.tolist()), 1 + ) + recurrent_forget_weights_values = self.get_tensor_value(recurrent_forget_weights) + recurrent_forget_weights_op = _op.split( + _op.const(recurrent_forget_weights_values.tolist()), 1 + ) + recurrent_cell_weights_values = self.get_tensor_value(recurrent_cell_weights) + recurrent_cell_weights_op = _op.split(_op.const(recurrent_cell_weights_values.tolist()), 1) + weights_dict["w_hid"] = _op.concatenate( + [ + recurrent_input_weights_op[0], + recurrent_forget_weights_op[0], + recurrent_cell_weights_op[0], + recurrent_output_weights_op[0], + ], + axis=0, + ) + + # Process weight matrix of bias: b_inp + input_gate_bias_values = self.get_tensor_value(input_gate_bias) + input_gate_bias_op = _op.split(_op.const(input_gate_bias_values.tolist()), 1) + output_gate_bias_values = self.get_tensor_value(output_gate_bias) + output_gate_bias_op = _op.split(_op.const(output_gate_bias_values.tolist()), 1) + forget_gate_bias_values = self.get_tensor_value(forget_gate_bias) + forget_gate_bias_op = _op.split(_op.const(forget_gate_bias_values.tolist()), 1) + cell_gate_bias_values = self.get_tensor_value(cell_gate_bias) + cell_gate_bias_op = _op.split(_op.const(cell_gate_bias_values.tolist()), 1) + weights_dict["b_inp"] = _op.concatenate( + [ + input_gate_bias_op[0], + forget_gate_bias_op[0], + cell_gate_bias_op[0], + output_gate_bias_op[0], + ], + axis=0, + ) + + # Process wieght matrix of hidden bias: b_hid (with the same shape as b_inp) + gate_bias_dtype = self.get_tensor_type_str(input_gate_bias.tensor.Type()) + weights_dict["b_hid"] = _op.split( + _op.const( + np.zeros(_infer_shape(weights_dict["b_inp"]), dtype=gate_bias_dtype), + dtype=gate_bias_dtype, + ), + 1, + )[0] + + outputs, H, C = lstm_cell(input_seqs=X_steps, **weights_dict) + + output = _op.stack(outputs, axis=1) + return output + def convert_batch_to_space_nd(self, op): """batch_to_space_nd implementation.""" diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 8c8ca0eab2ff..ad4efaa174a1 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -4693,6 +4693,36 @@ def representative_dataset(): tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) +####################################################################### +# Unidirectional Sequence LSTM +# --------------------- +def test_unidirectional_sequence_lstm(): + """Test the UnidirectionalSequenceLSTM TFLite""" + if package_version.parse(tf.VERSION) >= package_version.parse("2.1.0"): + tflite_model_file = download_testdata( + "https://github.com/SebastianBoblestETAS/nn_models/blob/ce49c5de64889493161ca4194a20e0fd5eb707e6/lstm_1_in_3_out_2_ts_4.tflite?raw=true", + "lstm_1_in_3_out_2_ts_4.tflite", + ) + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + + data = np.array( + [ + [ + [0.5488135, 0.71518934, 0.60276335], + [0.5448832, 0.4236548, 0.6458941], + [0.4375872, 0.891773, 0.96366274], + [0.3834415, 0.79172504, 0.5288949], + ] + ], + dtype="float32", + ) + + tflite_output = run_tflite_graph(tflite_model_buf, data) + tvm_output = run_tvm_graph(tflite_model_buf, data, "serving_default_input_1:0") + tvm.testing.assert_allclose(tflite_output, tvm_output) + + ####################################################################### # Quantized SSD Mobilenet # ----------------------- From a063a909992a2c5131d7506cde9827b7e9fc8ec4 Mon Sep 17 00:00:00 2001 From: "Boblest Sebastian (ETAS-DEV/XPC-Fe1)" Date: Fri, 29 Apr 2022 15:47:39 +0200 Subject: [PATCH 02/16] fixed missing import --- python/tvm/relay/frontend/tflite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 7cf9d4fe1df1..a5805343d52b 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -33,7 +33,7 @@ from ..backend.name_transforms import sanitize_name from .common import ExprTable from .common import infer_shape as _infer_shape -from .common import to_int_list, shape_of +from .common import lstm_cell, to_int_list, shape_of from .tflite_flexbuffer import FlexBufferDecoder __all__ = ["from_tflite"] From 61e390a20465ab9b55562f9e0178a4532b15689f Mon Sep 17 00:00:00 2001 From: "Boblest Sebastian (ETAS-DEV/XPC-Fe1)" Date: Tue, 3 May 2022 06:47:00 +0200 Subject: [PATCH 03/16] fixed pylint warnings --- python/tvm/relay/frontend/tflite.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index a5805343d52b..7477730bb610 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -2800,7 +2800,8 @@ def convert_unidirectional_sequence_lstm(self, op): weights_dict["cell_state"] = _op.split(cell_state_in_expr, 1)[0] # Process weight matrix of input: w_inp - # Concatenate of [input_input_weight, input_forget_weights, input_cell_weights, input_output_weights] + # Concatenate of [input_input_weight, input_forget_weights, + # input_cell_weights, input_output_weights] input_input_weights_default_values = self.get_tensor_value(input_input_weights) input_input_weights_op = _op.split( _op.const(input_input_weights_default_values.tolist()), 1 @@ -2825,7 +2826,8 @@ def convert_unidirectional_sequence_lstm(self, op): axis=0, ) - # Process weight matrix of hidden state: w_hid to support lstm_cell function. Not used in tflite + # Process weight matrix of hidden state: + # w_hid to support lstm_cell function. Not used in tflite recurrent_input_weights_values = self.get_tensor_value(recurrent_input_weights) recurrent_input_weights_op = _op.split( _op.const(recurrent_input_weights_values.tolist()), 1 @@ -2869,7 +2871,8 @@ def convert_unidirectional_sequence_lstm(self, op): axis=0, ) - # Process wieght matrix of hidden bias: b_hid (with the same shape as b_inp) + # Process weight matrix of hidden bias: + # b_hid (with the same shape as b_inp) gate_bias_dtype = self.get_tensor_type_str(input_gate_bias.tensor.Type()) weights_dict["b_hid"] = _op.split( _op.const( @@ -2879,7 +2882,7 @@ def convert_unidirectional_sequence_lstm(self, op): 1, )[0] - outputs, H, C = lstm_cell(input_seqs=X_steps, **weights_dict) + outputs, _, _ = lstm_cell(input_seqs=X_steps, **weights_dict) output = _op.stack(outputs, axis=1) return output From 393823e189a759ccb583b76a367857ab7379309c Mon Sep 17 00:00:00 2001 From: "Boblest Sebastian (ETAS-DEV/XPC-Fe1)" Date: Tue, 3 May 2022 09:45:56 +0200 Subject: [PATCH 04/16] black formatted tflite.py --- python/tvm/relay/frontend/tflite.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 7477730bb610..70548a200da3 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -2800,7 +2800,7 @@ def convert_unidirectional_sequence_lstm(self, op): weights_dict["cell_state"] = _op.split(cell_state_in_expr, 1)[0] # Process weight matrix of input: w_inp - # Concatenate of [input_input_weight, input_forget_weights, + # Concatenate of [input_input_weight, input_forget_weights, # input_cell_weights, input_output_weights] input_input_weights_default_values = self.get_tensor_value(input_input_weights) input_input_weights_op = _op.split( @@ -2826,7 +2826,7 @@ def convert_unidirectional_sequence_lstm(self, op): axis=0, ) - # Process weight matrix of hidden state: + # Process weight matrix of hidden state: # w_hid to support lstm_cell function. Not used in tflite recurrent_input_weights_values = self.get_tensor_value(recurrent_input_weights) recurrent_input_weights_op = _op.split( @@ -2871,7 +2871,7 @@ def convert_unidirectional_sequence_lstm(self, op): axis=0, ) - # Process weight matrix of hidden bias: + # Process weight matrix of hidden bias: # b_hid (with the same shape as b_inp) gate_bias_dtype = self.get_tensor_type_str(input_gate_bias.tensor.Type()) weights_dict["b_hid"] = _op.split( From 3a6ec7c4af4201f40f29b93282743a524621ecdf Mon Sep 17 00:00:00 2001 From: "Boblest Sebastian (ETAS-DEV/XPC-Fe1)" Date: Thu, 5 May 2022 09:19:15 +0200 Subject: [PATCH 05/16] corrections according to reviewer comments --- python/tvm/relay/frontend/tflite.py | 2 +- tests/python/frontend/tflite/test_forward.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 70548a200da3..1a02c4a24c3e 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -2756,7 +2756,7 @@ def convert_unidirectional_sequence_lstm(self, op): ) input_tensors = self.get_input_tensors(op) - assert len(input_tensors) >= 2, "input tensors length should be >= 2" + assert len(input_tensors) >= 20, "input tensors length should be >= 20" # Extract input tensor from saved model input_tensor = input_tensors[0] diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index ad4efaa174a1..6a839ffcbc1e 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -4696,7 +4696,7 @@ def representative_dataset(): ####################################################################### # Unidirectional Sequence LSTM # --------------------- -def test_unidirectional_sequence_lstm(): +def test_forward_unidirectional_sequence_lstm(): """Test the UnidirectionalSequenceLSTM TFLite""" if package_version.parse(tf.VERSION) >= package_version.parse("2.1.0"): tflite_model_file = download_testdata( @@ -4964,6 +4964,8 @@ def test_prevent_tensorflow_dynamic_range(): test_forward_fully_connected() test_forward_l2_normalization() test_forward_local_response_normalization() + test_forward_unidirectional_sequence_lstm() + # Elemwise test_all_elemwise() From 1117175af8445327f7abed1396cbebcdd2f82e49 Mon Sep 17 00:00:00 2001 From: "Boblest Sebastian (ETAS-DEV/XPC-Fe1)" Date: Thu, 5 May 2022 11:45:10 +0200 Subject: [PATCH 06/16] fixed black formatting --- tests/python/frontend/tflite/test_forward.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 6a839ffcbc1e..9558497e569c 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1867,7 +1867,7 @@ def tf_function(self, x): model, export_dir, signatures=model.tf_function.get_concrete_function( - tf.TensorSpec(data.shape, tf.float32, name="input"), + tf.TensorSpec(data.shape, tf.float32, name="input") ), ) @@ -3759,8 +3759,7 @@ def test_forward_prelu(): np.full((32, 3), 0.2, dtype="float32"), ) _test_prelu( - np.random.uniform(-5, 5, size=(32, 3)).astype("float32"), - np.full((3), 0.2, dtype="float32"), + np.random.uniform(-5, 5, size=(32, 3)).astype("float32"), np.full((3), 0.2, dtype="float32") ) @@ -4966,7 +4965,6 @@ def test_prevent_tensorflow_dynamic_range(): test_forward_local_response_normalization() test_forward_unidirectional_sequence_lstm() - # Elemwise test_all_elemwise() test_forward_add_n() From 0ae9067bf6430bcbc5ac2f301163aa78f50c75d2 Mon Sep 17 00:00:00 2001 From: "Boblest Sebastian (ETAS-DEV/XPC-Fe1)" Date: Thu, 5 May 2022 16:19:13 +0200 Subject: [PATCH 07/16] just to trigger the CI again --- tests/python/frontend/tflite/test_forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 9558497e569c..8b0244d75eda 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -4959,10 +4959,10 @@ def test_prevent_tensorflow_dynamic_range(): test_forward_leaky_relu() test_forward_relu_n1_to_1() test_forward_log_softmax() - test_forward_prelu() test_forward_fully_connected() test_forward_l2_normalization() test_forward_local_response_normalization() + test_forward_prelu() test_forward_unidirectional_sequence_lstm() # Elemwise From a5779c8d0e89b82a8ae53f755203caebd0f7927c Mon Sep 17 00:00:00 2001 From: "Boblest Sebastian (ETAS-DEV/XPC-Fe1)" Date: Mon, 9 May 2022 10:37:28 +0200 Subject: [PATCH 08/16] assertion now tests that there are exactly 24 input tensors. --- python/tvm/relay/frontend/tflite.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 1a02c4a24c3e..4ca535957213 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -33,7 +33,7 @@ from ..backend.name_transforms import sanitize_name from .common import ExprTable from .common import infer_shape as _infer_shape -from .common import lstm_cell, to_int_list, shape_of +from .common import lstm_cell, to_int_list, shape_of, infer_value from .tflite_flexbuffer import FlexBufferDecoder __all__ = ["from_tflite"] @@ -2756,7 +2756,7 @@ def convert_unidirectional_sequence_lstm(self, op): ) input_tensors = self.get_input_tensors(op) - assert len(input_tensors) >= 20, "input tensors length should be >= 20" + assert len(input_tensors) == 24, "input tensors length should be == 24" # Extract input tensor from saved model input_tensor = input_tensors[0] @@ -2777,6 +2777,10 @@ def convert_unidirectional_sequence_lstm(self, op): forget_gate_bias = input_tensors[13] cell_gate_bias = input_tensors[14] output_gate_bias = input_tensors[15] + + # unused + + # State input output_state_in = input_tensors[18] cell_state_in = input_tensors[19] From cd91cf18a9a005688e69b193a052ab3c7e5074a3 Mon Sep 17 00:00:00 2001 From: "Boblest Sebastian (ETAS-DEV/XPC-Fe1)" Date: Mon, 9 May 2022 10:45:55 +0200 Subject: [PATCH 09/16] black formatted tflite.py --- python/tvm/relay/frontend/tflite.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 4ca535957213..306ee3ddd8cd 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -2780,7 +2780,6 @@ def convert_unidirectional_sequence_lstm(self, op): # unused - # State input output_state_in = input_tensors[18] cell_state_in = input_tensors[19] From 24dc203011c533c2d8f4a97a115150f0a1553b0a Mon Sep 17 00:00:00 2001 From: "Boblest Sebastian (ETAS-DEV/XPC-Fe1)" Date: Mon, 9 May 2022 10:52:43 +0200 Subject: [PATCH 10/16] added explanatory comment regarding unused imports --- python/tvm/relay/frontend/tflite.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 306ee3ddd8cd..1ab61fe7ab0a 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -2772,14 +2772,14 @@ def convert_unidirectional_sequence_lstm(self, op): recurrent_forget_weights = input_tensors[6] recurrent_cell_weights = input_tensors[7] recurrent_output_weights = input_tensors[8] + # inputs 9, 10, 11, 16, 17, 20, 21, 22, 23 are not occupied + # there locations are -1 in the flatbuffer # Bias weights input_gate_bias = input_tensors[12] forget_gate_bias = input_tensors[13] cell_gate_bias = input_tensors[14] output_gate_bias = input_tensors[15] - # unused - # State input output_state_in = input_tensors[18] cell_state_in = input_tensors[19] From 6de5e4b705698c2aef47befd56fe80c054ab156c Mon Sep 17 00:00:00 2001 From: "Boblest Sebastian (ETAS-DEV/XPC-Fe1)" Date: Mon, 9 May 2022 12:41:32 +0200 Subject: [PATCH 11/16] removed unused import --- python/tvm/relay/frontend/tflite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 1ab61fe7ab0a..3cb8ed706160 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -33,7 +33,7 @@ from ..backend.name_transforms import sanitize_name from .common import ExprTable from .common import infer_shape as _infer_shape -from .common import lstm_cell, to_int_list, shape_of, infer_value +from .common import lstm_cell, to_int_list, shape_of from .tflite_flexbuffer import FlexBufferDecoder __all__ = ["from_tflite"] From 20d7548a67611fc3c920d80b88dd9015c6a95b0c Mon Sep 17 00:00:00 2001 From: "Boblest Sebastian (ETAS-DEV/XPC-Fe1)" Date: Mon, 9 May 2022 17:35:26 +0200 Subject: [PATCH 12/16] nothing --- python/tvm/relay/frontend/tflite.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 3cb8ed706160..de1321e0806a 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -199,7 +199,6 @@ def check_unsupported_ops(self): qnn_out_cnt = len( [_.qnn_params for _ in self.get_output_tensors(op) if _.qnn_params is not None] ) - if qnn_in_cnt == 0 and qnn_out_cnt == 0 and qnn_weight_cnt > 0: dynamic_range_ops_set.add(op_code_str) From e989c402ae2c40b1660d8dbba48071ab875b6d31 Mon Sep 17 00:00:00 2001 From: "Boblest Sebastian (ETAS-DEV/XPC-Fe1)" Date: Mon, 9 May 2022 17:35:36 +0200 Subject: [PATCH 13/16] nothing --- python/tvm/relay/frontend/tflite.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index de1321e0806a..3cb8ed706160 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -199,6 +199,7 @@ def check_unsupported_ops(self): qnn_out_cnt = len( [_.qnn_params for _ in self.get_output_tensors(op) if _.qnn_params is not None] ) + if qnn_in_cnt == 0 and qnn_out_cnt == 0 and qnn_weight_cnt > 0: dynamic_range_ops_set.add(op_code_str) From 8e17b20052c8d1fd9463964a3984b18d9cb71a7a Mon Sep 17 00:00:00 2001 From: "Boblest Sebastian (ETAS-DEV/XPC-Fe1)" Date: Tue, 24 May 2022 16:05:48 +0200 Subject: [PATCH 14/16] added some details in a comment about the differences in unbind regarding to the version in common.py --- python/tvm/relay/frontend/tflite.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 3cb8ed706160..f9cfb7bebb43 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -223,7 +223,8 @@ def check_unsupported_ops(self): def unbind(self, data, axis=1): """ - This is a slightly modified version compared to the one in common.py + This is a modified version compared to the one in common.py. + In onnx files the timestep index is shape[0], in tflite it is shape[1]. Parameters ---------- From efecbcd472912c29ec697cf1be6824d11da44c65 Mon Sep 17 00:00:00 2001 From: "Boblest Sebastian (ETAS-DEV/XPC-Fe1)" Date: Wed, 25 May 2022 16:20:56 +0200 Subject: [PATCH 15/16] improved comment on unbind --- python/tvm/relay/frontend/tflite.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index f9cfb7bebb43..22f36de62cd7 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -224,11 +224,13 @@ def check_unsupported_ops(self): def unbind(self, data, axis=1): """ This is a modified version compared to the one in common.py. - In onnx files the timestep index is shape[0], in tflite it is shape[1]. + The onnx version takes a relay.Expr.Call, the tflite + version a TensorWrapper. Also this version by default splits + along axis 1 and not axis 0 as the onnx version. Parameters ---------- - data : relay.Expr + data : tvm.relay.frontend.tflite.TensorWrapper Input tensor axis : int Axis along which tensor is split. From 089324ab9d033cd1b221dde4ce068db960d7fd50 Mon Sep 17 00:00:00 2001 From: "Boblest Sebastian (ETAS-DEV/XPC-Fe1)" Date: Fri, 27 May 2022 14:15:49 +0200 Subject: [PATCH 16/16] fix of black issue --- python/tvm/relay/frontend/tflite.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 22f36de62cd7..d1d764619434 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -224,8 +224,8 @@ def check_unsupported_ops(self): def unbind(self, data, axis=1): """ This is a modified version compared to the one in common.py. - The onnx version takes a relay.Expr.Call, the tflite - version a TensorWrapper. Also this version by default splits + The onnx version takes a relay.Expr.Call, the tflite + version a TensorWrapper. Also this version by default splits along axis 1 and not axis 0 as the onnx version. Parameters