[Frontend] [ONNX] Support sequence_lens of GRU #13587

Merged 1 commit on Dec 15, 2022
57 changes: 54 additions & 3 deletions python/tvm/relay/frontend/common.py
@@ -737,6 +737,7 @@ def gru_cell(
n_act=_op.tanh,
backwards=False,
linear_before_reset=True,
Contributor:

Please add descriptions of the arguments linear_before_reset and sequence_lens.

Contributor Author:

Done.

sequence_lens=None,
):
"""
Common implementation of GRU cell for all frontends of TVM
@@ -765,15 +766,53 @@ def gru_cell(
activation function for new gate. it is tanh by default
backwards : bool
Flag for reverse pass of GRU

linear_before_reset : bool
Flag for applying the linear transformation before multiplying by the output of the reset
gate.
sequence_lens : relay.op
Tensor specifying lengths of the sequences in a batch.
Shape = (batch_size)
Returns
-------
result : List[relay.Expr], relay.Expr
The sequence of computed results and the final hidden state
"""

outputs_list = []
for x_t in input_seqs if not backwards else reversed(input_seqs):

seq_len = len(input_seqs)
input_dtype = infer_type(input_seqs[0]).checked_type.dtype

if sequence_lens is not None:
shape = infer_shape(sequence_lens)
dtype = infer_type(sequence_lens).checked_type.dtype

arange = _op.arange(_op.const(0), _op.const(seq_len), dtype=dtype)
arange = _op.expand_dims(arange, 1)
sequence_lens = _op.broadcast_to(sequence_lens, [seq_len, shape[0]])

# cast to data dtype
mask = _op.less(arange, sequence_lens)
mask = _op.cast(mask, dtype=input_dtype)
mask = _op.expand_dims(mask, 2)
mask_seqs = unbind(mask)

res_mask = _op.greater_equal(arange, sequence_lens)
res_mask = _op.cast(res_mask, dtype=input_dtype)
res_mask = _op.expand_dims(res_mask, 2)
res_mask_seqs = unbind(res_mask)

if backwards:
# need a mask to keep initial_h_B correct
initial_h = hidden_state
initial_h_mask = _op.equal(arange, sequence_lens)
initial_h_mask = _op.cast(initial_h_mask, dtype=input_dtype)
initial_h_mask = _op.expand_dims(initial_h_mask, 2)
initial_h_mask_seqs = unbind(initial_h_mask)

output = _op.zeros(infer_shape(hidden_state), input_dtype)
for i in range(seq_len) if not backwards else reversed(range(seq_len)):
x_t = input_seqs[i]
xwt = _op.nn.dense(x_t, w_inp)
if linear_before_reset:
hwt = _op.nn.dense(hidden_state, w_hid)
@@ -806,9 +845,21 @@ def gru_cell(

hidden_state = (hidden_state - n_gate) * z_gate + n_gate

if sequence_lens is not None:
hidden_state = hidden_state * mask_seqs[i]

outputs_list.append(hidden_state) # [seq_num, (batch, hidden_size)]

return outputs_list, hidden_state
if sequence_lens is not None:
output = output * res_mask_seqs[i] + hidden_state
else:
output = hidden_state

# make sure initial_h_B is correct
if backwards and sequence_lens is not None:
hidden_state = hidden_state + initial_h * initial_h_mask_seqs[i]

return outputs_list, output


def lstm_cell(
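For intuition, the masking scheme introduced above can be sketched in plain NumPy for the forward direction. This is an illustrative reference only, not code from the patch; masked_gru_reference and gru_step are hypothetical names, with gru_step standing in for one application of the GRU equations.

import numpy as np

def masked_gru_reference(inputs, seq_lens, hidden, gru_step):
    # inputs: (seq_len, batch, input_size); seq_lens: (batch,); hidden: (batch, hidden_size)
    seq_len = inputs.shape[0]
    steps = np.arange(seq_len)[:, None]                            # (seq_len, 1)
    mask = (steps < seq_lens[None, :]).astype(inputs.dtype)        # 1 while the step is still valid
    res_mask = (steps >= seq_lens[None, :]).astype(inputs.dtype)   # 1 once the sequence has ended
    outputs, final = [], np.zeros_like(hidden)
    for t in range(seq_len):
        hidden = gru_step(inputs[t], hidden) * mask[t][:, None]    # zero out finished sequences
        outputs.append(hidden)
        final = final * res_mask[t][:, None] + hidden              # keep the last valid hidden state
    return outputs, final

This mirrors what the Relay code does with mask_seqs and res_mask_seqs; the extra initial_h mask in the backwards branch re-injects the initial hidden state at the step where a reversed sequence enters its valid region, so initial_h_B stays correct for shorter sequences.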
18 changes: 12 additions & 6 deletions python/tvm/relay/frontend/onnx.py
@@ -3126,8 +3126,7 @@ def _inputs_helper(cls, inputs, layout):
Wp = inputs[1]
Rp = inputs[2]
Bp = inputs[3]
# Sequence length currently unused as it can be inferred from shapes.
# sequence_lens = inputs['sequence_lens']
sequence_lens = inputs[4]
Hp_0 = inputs[5]

num_directions = infer_shape(Wp)[0]
@@ -3158,11 +3157,11 @@ def _inputs_helper(cls, inputs, layout):
Bs = None
if Bp is not None:
Bs = _op.split(Bp, num_directions)
return X_steps, H_ts, Ws, Rs, Bs, num_directions
return X_steps, H_ts, Ws, Rs, Bs, num_directions, sequence_lens

@classmethod
def _impl_common(cls, inputs, attr, layout):
X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout)
X_steps, H_ts, Ws, Rs, Bs, num_directions, _ = cls._inputs_helper(inputs, layout)
acts = cls._get_activations(attr, 1, num_directions, "RNN")

weights_dicts = []
@@ -3261,7 +3260,7 @@ def _default_activations(cls, num_directions):

@classmethod
def _impl_common(cls, inputs, attr, layout):
X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout)
X_steps, H_ts, Ws, Rs, Bs, num_directions, _ = cls._inputs_helper(inputs, layout)
acts = cls._get_activations(attr, 3, num_directions, "LSTM")

# cell state
@@ -3346,6 +3345,7 @@ def bidir_gru_cell(
input_seqs,
weight_dicts,
acts,
sequence_lens=None,
):
"""
Bidirectional GRU cell
@@ -3356,6 +3356,7 @@
**weight_dicts[0],
rz_act=acts[0],
n_act=acts[1],
sequence_lens=sequence_lens,
)

reverse_outputs, rev_H_t = gru_cell(
@@ -3364,6 +3365,7 @@
rz_act=acts[2],
n_act=acts[3],
backwards=True,
sequence_lens=sequence_lens,
)

final_outputs = []
@@ -3383,7 +3385,9 @@ def _default_activations(cls, num_directions):

@classmethod
def _impl_common(cls, inputs, attr, layout):
X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout)
X_steps, H_ts, Ws, Rs, Bs, num_directions, sequence_lens = cls._inputs_helper(
inputs, layout
)
acts = cls._get_activations(attr, 2, num_directions, "GRU")
linear_before_reset = attr.get("linear_before_reset", 0)

@@ -3412,6 +3416,7 @@ def _impl_common(cls, inputs, attr, layout):
input_seqs=X_steps,
weight_dicts=weights_dicts,
acts=acts,
sequence_lens=sequence_lens,
)
else:
# outputs shape = [seqs_num, (batch_size, hidden_size)]
@@ -3420,6 +3425,7 @@ def _impl_common(cls, inputs, attr, layout):
**weights_dicts[0],
rz_act=acts[0],
n_act=acts[1],
sequence_lens=sequence_lens,
)

# output shape = (seqs_num, num_directions, batch_size, hidden_size)
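To see what the updated converter consumes, here is a minimal sketch of an ONNX graph that feeds the optional sequence_lens input of a GRU node and imports it through the Relay frontend. The shapes and names are arbitrary example values, not taken from the patch.

from onnx import TensorProto, helper
from tvm import relay

seq_len, batch, input_size, hidden_size, directions = 8, 4, 16, 32, 1

gru = helper.make_node(
    "GRU",
    inputs=["X", "W", "R", "B", "sequence_lens", "initial_h"],
    outputs=["Y", "Y_h"],
    hidden_size=hidden_size,
)
graph = helper.make_graph(
    [gru],
    "gru_with_sequence_lens",
    inputs=[
        helper.make_tensor_value_info("X", TensorProto.FLOAT, [seq_len, batch, input_size]),
        helper.make_tensor_value_info("W", TensorProto.FLOAT, [directions, 3 * hidden_size, input_size]),
        helper.make_tensor_value_info("R", TensorProto.FLOAT, [directions, 3 * hidden_size, hidden_size]),
        helper.make_tensor_value_info("B", TensorProto.FLOAT, [directions, 6 * hidden_size]),
        helper.make_tensor_value_info("sequence_lens", TensorProto.INT32, [batch]),
        helper.make_tensor_value_info("initial_h", TensorProto.FLOAT, [directions, batch, hidden_size]),
    ],
    outputs=[
        helper.make_tensor_value_info("Y", TensorProto.FLOAT, [seq_len, directions, batch, hidden_size]),
        helper.make_tensor_value_info("Y_h", TensorProto.FLOAT, [directions, batch, hidden_size]),
    ],
)
model = helper.make_model(graph)

# With this change, inputs[4] (sequence_lens) is forwarded to gru_cell instead of being ignored.
mod, params = relay.frontend.from_onnx(model)

In the bidirectional case, the same sequence_lens tensor is passed to both the forward and the reverse gru_cell, as bidir_gru_cell above shows.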
40 changes: 38 additions & 2 deletions tests/python/frontend/onnx/test_forward.py
@@ -3897,6 +3897,7 @@ def verify_rnn(
atol=1e-5,
target=None,
dev=None,
use_sequence_lens=False,
):
"""verify_rnn"""
if rnn_type == "RNN":
@@ -3954,10 +3955,16 @@ def register(np_arr, name, shape=None):
)
register(b_np, "B")

if use_sequence_lens:
sequence_np = np.random.uniform(0, seq_length, size=(batch_size)).astype("int32")
register(sequence_np, "sequence_lens")

if use_initial_state:
assert use_bias is True, "Initial states must have bias specified."
sequence_np = np.repeat(seq_length, batch_size).astype("int32")
register(sequence_np, "sequence_lens")

if not use_sequence_lens:
sequence_np = np.repeat(seq_length, batch_size).astype("int32")
register(sequence_np, "sequence_lens")

if layout == 1:
initial_h_np = np.random.uniform(size=(batch_size, directions, hidden_size)).astype(
@@ -4211,6 +4218,35 @@ def verify_rnn_helper(target, dev, rnn_type):
# dev=dev,
# )

# Testing with initial state
if rnn_type == "GRU":
verify_rnn(
seq_length=2,
Contributor:

Please increase seq_length (to 4-5 or more) and batch_size (to 4-8) to properly test sequences with different lengths.

Contributor Author:

Done.

batch_size=1,
input_size=16,
hidden_size=32,
use_bias=True,
use_initial_state=True,
rnn_type=rnn_type,
directions=directions,
target=target,
dev=dev,
use_sequence_lens=True,
)
verify_rnn(
seq_length=8,
batch_size=8,
input_size=16,
hidden_size=32,
use_bias=True,
use_initial_state=True,
rnn_type=rnn_type,
directions=directions,
target=target,
dev=dev,
use_sequence_lens=True,
)

# Testing with peepholes
if rnn_type == "LSTM":
verify_rnn(
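As a closing note on what these new tests exercise: for a forward GRU, timesteps at or beyond sequence_lens[b] should come out zeroed in Y, while Y_h should hold the hidden state of the last valid step for each batch element. A hypothetical check along these lines (the helper name and layout-0 output shapes are assumptions, not part of the test suite) could look like:

import numpy as np

def check_sequence_lens_semantics(Y, Y_h, sequence_lens):
    # Y: (seq_length, num_directions, batch_size, hidden_size); forward direction is index 0
    for b, length in enumerate(sequence_lens):
        assert np.allclose(Y[length:, 0, b, :], 0.0)                      # padded steps are masked to zero
        if length > 0:
            assert np.allclose(Y_h[0, b, :], Y[length - 1, 0, b, :])      # final state is the last valid step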