Add unidirectional sequence lstm #11183

Merged
180 changes: 179 additions & 1 deletion python/tvm/relay/frontend/tflite.py
@@ -33,7 +33,7 @@
from ..backend.name_transforms import sanitize_name
from .common import ExprTable
from .common import infer_shape as _infer_shape
from .common import to_int_list, shape_of
from .common import lstm_cell, to_int_list, shape_of
from .tflite_flexbuffer import FlexBufferDecoder

__all__ = ["from_tflite"]
@@ -173,6 +173,7 @@ def __init__(self, model, subgraph, exp_tab):
"TRANSPOSE_CONV": self.convert_transpose_conv,
"TRANSPOSE": self.convert_transpose,
"UNPACK": self.convert_unpack,
"UNIDIRECTIONAL_SEQUENCE_LSTM": self.convert_unidirectional_sequence_lstm,
"WHERE": self.convert_select,
"ZEROS_LIKE": self.convert_zeros_like,
}
@@ -220,6 +221,41 @@ def check_unsupported_ops(self):
if len(raise_msg) > 0:
raise tvm.error.OpNotImplemented(raise_msg)

def unbind(self, data, axis=1):
"""
This is a modified version of the unbind helper in common.py.
The ONNX version takes a relay.Expr.Call, whereas this TFLite
version takes a TensorWrapper. This version also splits along
axis 1 by default instead of axis 0.

Parameters
----------
data : tvm.relay.frontend.tflite.TensorWrapper
Input tensor
axis : int
Axis along which tensor is split.

Returns
-------
result : List[relay.Expr]
The sequence of computed tensors
"""
shape = to_int_list(self.get_tensor_shape(data))
if axis >= len(shape):
msg = "Please check input dim, it shouldn't be greater than or equal to rank."
raise AttributeError(msg)

selections = shape[axis]
shape.pop(axis)
timestep = 0 # Reshape to make time step as the first dim
shape.insert(timestep, selections)
res_split = _op.split(
_op.reshape(self.get_expr(data.tensor_idx), tuple(shape)), selections, timestep
)
ret = []
for i in range(selections):
ret.append(_op.squeeze(res_split[i], axis=[timestep]))
return _expr.TupleWrapper(_expr.Tuple(ret), selections)

def convert_op_to_relay(self):
"""Convert TFLite ops to relay ops"""
for op_idx in range(self.subgraph.OperatorsLength()):
@@ -2715,6 +2751,148 @@ def convert_unpack(self, op):

return squeezed

def convert_unidirectional_sequence_lstm(self, op):
"""Long Short Term Memory for TFLite implementation."""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
"TFlite quantized UNIDIRECTIONALSEQUENCELSTM operator is not supported yet."
)

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 24, "input tensors length should be == 24"

# Extract input tensor from saved model
input_tensor = input_tensors[0]

# Extract tensors from input tensors from saved model
# Input weights
input_input_weights = input_tensors[1]
input_forget_weights = input_tensors[2]
input_cell_weights = input_tensors[3]
input_output_weights = input_tensors[4]
# Recurrent weights
recurrent_input_weights = input_tensors[5]
recurrent_forget_weights = input_tensors[6]
recurrent_cell_weights = input_tensors[7]
recurrent_output_weights = input_tensors[8]
# inputs 9, 10, 11, 16, 17, 20, 21, 22, 23 are not occupied
# their locations are -1 in the flatbuffer
# Bias weights
input_gate_bias = input_tensors[12]
forget_gate_bias = input_tensors[13]
cell_gate_bias = input_tensors[14]
output_gate_bias = input_tensors[15]

# State input
output_state_in = input_tensors[18]
cell_state_in = input_tensors[19]

# Extract output tensor from saved model
output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
X_steps = self.unbind(input_tensor, axis=1)
weights_dict = {}

# hidden_state_weights is equivalent to output_state_in in tflite model
out_state_in_shape = tuple(self.get_tensor_shape(output_state_in))
out_state_in_dtype = self.get_tensor_type_str(output_state_in.tensor.Type())
out_state_in_expr = _op.zeros(out_state_in_shape, dtype=out_state_in_dtype)
weights_dict["hidden_state"] = _op.split(out_state_in_expr, 1)[0]

# cell_state_weights is equivalent to cell_state_in in the tflite model
cell_state_in_shape = tuple(self.get_tensor_shape(cell_state_in))
cell_state_in_dtype = self.get_tensor_type_str(cell_state_in.tensor.Type())
cell_state_in_expr = _op.zeros(cell_state_in_shape, dtype=cell_state_in_dtype)
weights_dict["cell_state"] = _op.split(cell_state_in_expr, 1)[0]

# Process weight matrix of input: w_inp
# Concatenate of [input_input_weight, input_forget_weights,
# input_cell_weights, input_output_weights]
input_input_weights_default_values = self.get_tensor_value(input_input_weights)
input_input_weights_op = _op.split(
_op.const(input_input_weights_default_values.tolist()), 1
)
input_output_weights_default_values = self.get_tensor_value(input_output_weights)
input_output_weights_op = _op.split(
_op.const(input_output_weights_default_values.tolist()), 1
)
input_forget_weights_default_values = self.get_tensor_value(input_forget_weights)
input_forget_weights_op = _op.split(
_op.const(input_forget_weights_default_values.tolist()), 1
)
input_cell_weights_default_values = self.get_tensor_value(input_cell_weights)
input_cell_weights_op = _op.split(_op.const(input_cell_weights_default_values.tolist()), 1)
weights_dict["w_inp"] = _op.concatenate(
[
_op.squeeze(input_input_weights_op[0]),
_op.squeeze(input_forget_weights_op[0]),
_op.squeeze(input_cell_weights_op[0]),
_op.squeeze(input_output_weights_op[0]),
],
axis=0,
)

# Process weight matrix of hidden state:
# w_hid to support lstm_cell function. Not used in tflite
recurrent_input_weights_values = self.get_tensor_value(recurrent_input_weights)
recurrent_input_weights_op = _op.split(
_op.const(recurrent_input_weights_values.tolist()), 1
)
recurrent_output_weights_values = self.get_tensor_value(recurrent_output_weights)
recurrent_output_weights_op = _op.split(
_op.const(recurrent_output_weights_values.tolist()), 1
)
recurrent_forget_weights_values = self.get_tensor_value(recurrent_forget_weights)
recurrent_forget_weights_op = _op.split(
_op.const(recurrent_forget_weights_values.tolist()), 1
)
recurrent_cell_weights_values = self.get_tensor_value(recurrent_cell_weights)
recurrent_cell_weights_op = _op.split(_op.const(recurrent_cell_weights_values.tolist()), 1)
weights_dict["w_hid"] = _op.concatenate(
[
recurrent_input_weights_op[0],
recurrent_forget_weights_op[0],
recurrent_cell_weights_op[0],
recurrent_output_weights_op[0],
],
axis=0,
)

# Process weight matrix of bias: b_inp
input_gate_bias_values = self.get_tensor_value(input_gate_bias)
input_gate_bias_op = _op.split(_op.const(input_gate_bias_values.tolist()), 1)
output_gate_bias_values = self.get_tensor_value(output_gate_bias)
output_gate_bias_op = _op.split(_op.const(output_gate_bias_values.tolist()), 1)
forget_gate_bias_values = self.get_tensor_value(forget_gate_bias)
forget_gate_bias_op = _op.split(_op.const(forget_gate_bias_values.tolist()), 1)
cell_gate_bias_values = self.get_tensor_value(cell_gate_bias)
cell_gate_bias_op = _op.split(_op.const(cell_gate_bias_values.tolist()), 1)
weights_dict["b_inp"] = _op.concatenate(
[
input_gate_bias_op[0],
forget_gate_bias_op[0],
cell_gate_bias_op[0],
output_gate_bias_op[0],
],
axis=0,
)

# Process weight matrix of hidden bias:
# b_hid (with the same shape as b_inp)
gate_bias_dtype = self.get_tensor_type_str(input_gate_bias.tensor.Type())
weights_dict["b_hid"] = _op.split(
_op.const(
np.zeros(_infer_shape(weights_dict["b_inp"]), dtype=gate_bias_dtype),
dtype=gate_bias_dtype,
),
1,
)[0]

outputs, _, _ = lstm_cell(input_seqs=X_steps, **weights_dict)

output = _op.stack(outputs, axis=1)
return output

def convert_batch_to_space_nd(self, op):
"""batch_to_space_nd implementation."""

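For reference: the converter above stacks the four LSTM gate weights along axis 0 in the order [input, forget, cell, output] before passing them to lstm_cell as w_inp, w_hid, b_inp and b_hid. A minimal NumPy sketch (not part of this PR, and independent of the relay implementation) of one LSTM step using that packed layout:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h, c, w_inp, w_hid, b_inp, b_hid):
    # x: (batch, in_dim); h, c: (batch, hidden)
    # w_inp: (4 * hidden, in_dim); w_hid: (4 * hidden, hidden); biases: (4 * hidden,)
    gates = x @ w_inp.T + h @ w_hid.T + b_inp + b_hid
    i, f, g, o = np.split(gates, 4, axis=-1)  # gate order: input, forget, cell, output
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h_new = sigmoid(o) * np.tanh(c_new)
    return h_new, c_new

Running this step over each element of X_steps with an all-zero initial state mirrors what the converter asks lstm_cell to compute for the unidirectional case.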
38 changes: 34 additions & 4 deletions tests/python/frontend/tflite/test_forward.py
@@ -1867,7 +1867,7 @@ def tf_function(self, x):
model,
export_dir,
signatures=model.tf_function.get_concrete_function(
tf.TensorSpec(data.shape, tf.float32, name="input"),
tf.TensorSpec(data.shape, tf.float32, name="input")
),
)

@@ -3759,8 +3759,7 @@ def test_forward_prelu():
np.full((32, 3), 0.2, dtype="float32"),
)
_test_prelu(
np.random.uniform(-5, 5, size=(32, 3)).astype("float32"),
np.full((3), 0.2, dtype="float32"),
np.random.uniform(-5, 5, size=(32, 3)).astype("float32"), np.full((3), 0.2, dtype="float32")
)


@@ -4693,6 +4692,36 @@ def representative_dataset():
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)


#######################################################################
# Unidirectional Sequence LSTM
# ----------------------------
def test_forward_unidirectional_sequence_lstm():
"""Test the UnidirectionalSequenceLSTM TFLite"""
if package_version.parse(tf.VERSION) >= package_version.parse("2.1.0"):
tflite_model_file = download_testdata(
"https://github.com/SebastianBoblestETAS/nn_models/blob/ce49c5de64889493161ca4194a20e0fd5eb707e6/lstm_1_in_3_out_2_ts_4.tflite?raw=true",
"lstm_1_in_3_out_2_ts_4.tflite",
)
with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()

data = np.array(
[
[
[0.5488135, 0.71518934, 0.60276335],
[0.5448832, 0.4236548, 0.6458941],
[0.4375872, 0.891773, 0.96366274],
[0.3834415, 0.79172504, 0.5288949],
]
],
dtype="float32",
)

tflite_output = run_tflite_graph(tflite_model_buf, data)
tvm_output = run_tvm_graph(tflite_model_buf, data, "serving_default_input_1:0")
tvm.testing.assert_allclose(tflite_output, tvm_output)


#######################################################################
# Quantized SSD Mobilenet
# -----------------------
@@ -4930,10 +4959,11 @@ def test_prevent_tensorflow_dynamic_range():
test_forward_leaky_relu()
test_forward_relu_n1_to_1()
test_forward_log_softmax()
test_forward_prelu()
test_forward_fully_connected()
test_forward_l2_normalization()
test_forward_local_response_normalization()
test_forward_prelu()
test_forward_unidirectional_sequence_lstm()

# Elemwise
test_all_elemwise()
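For completeness, a hedged end-to-end sketch (not part of this PR) of producing a fused LSTM TFLite model with Keras and importing it through the new converter path. The model and input names here are queried rather than assumed, and whether the TFLite converter emits a fused UNIDIRECTIONAL_SEQUENCE_LSTM op can depend on the TensorFlow version and converter settings:

import numpy as np
import tensorflow as tf
from tvm import relay

# Tiny Keras LSTM: 4 time steps, 3 features, 2 units, full sequence output.
keras_model = tf.keras.Sequential(
    [tf.keras.layers.LSTM(2, return_sequences=True, input_shape=(4, 3))]
)
tflite_buf = tf.lite.TFLiteConverter.from_keras_model(keras_model).convert()

# Query the generated input tensor name instead of hard-coding it.
interpreter = tf.lite.Interpreter(model_content=tflite_buf)
input_name = interpreter.get_input_details()[0]["name"]

# Depending on the installed tflite package version this may instead be
# tflite.Model.Model.GetRootAsModel(...).
import tflite
tflite_model = tflite.Model.GetRootAsModel(tflite_buf, 0)

data = np.random.uniform(size=(1, 4, 3)).astype("float32")
mod, params = relay.frontend.from_tflite(
    tflite_model,
    shape_dict={input_name: data.shape},
    dtype_dict={input_name: "float32"},
)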