[Relax] Handle dynamic arguments in legalization of nn.attention (#16592)

Prior to this commit, when using causal_mask="BottomRight" in `R.nn.attention`,
the legalization would assume that the query and key/value sequence lengths
were static integers. This commit updates the legalization to allow dynamic shapes.
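Why the one-line change matters: the old code called Python's built-in `abs`, which needs a concrete number and fails when either sequence length is a symbolic shape variable, whereas `tir.abs` builds the absolute value into the TIR expression itself, deferring evaluation to runtime. A minimal standalone sketch (variable names chosen for illustration, not taken from the patch):

    import tvm
    from tvm import tir

    seq_len = tir.Var("seq_len", "int64")
    seq_len_kv = tir.Var("seq_len_kv", "int64")

    # Python's abs() would fail here: a symbolic PrimExpr has no concrete
    # value at compile time. tir.abs builds a deferred expression instead.
    offset = tir.abs(seq_len - seq_len_kv).astype("int32")
    print(offset)  # a cast-to-int32 of abs(seq_len - seq_len_kv)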
Lunderberg authored Feb 22, 2024
1 parent 8f42597 commit 4b7d78d
Showing 2 changed files with 25 additions and 1 deletion.
python/tvm/relax/transform/legalize_ops/nn.py (2 changes: 1 addition & 1 deletion)
@@ -486,7 +486,7 @@ def _te_attention(
     if causal_mask == "TopLeft":
         offset = tir.IntImm("int32", 0)
     elif causal_mask == "BottomRight":
-        offset = tir.IntImm("int32", abs(seq_len - seq_len_kv))
+        offset = tir.abs(seq_len - seq_len_kv).astype("int32")
     else:
         raise NotImplementedError()
     p_masked = topi.trilu(p, k=offset, upper=False)
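The symbolic offset is then consumed directly by `topi.trilu`, visible in the unchanged context line above, so the mask diagonal is resolved when the kernel runs. A small sketch of the same pattern in isolation (placeholder shapes assumed for illustration):

    import tvm
    from tvm import te, tir, topi

    seq_len = tir.Var("seq_len", "int64")
    seq_len_kv = tir.Var("seq_len_kv", "int64")
    p = te.placeholder((seq_len, seq_len_kv), name="p", dtype="float32")

    # The diagonal offset is a TIR expression rather than a Python int.
    offset = tir.abs(seq_len - seq_len_kv).astype("int32")
    p_masked = topi.trilu(p, k=offset, upper=False)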
tests/python/relax/test_transform_legalize_ops_nn.py (24 changes: 24 additions & 0 deletions)
@@ -3270,6 +3270,30 @@ def main(q: R.Tensor((4, 16, 32, 8), dtype="float32"), k: R.Tensor((4, 8, 32, 8)
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_dynamic_attention():
+    """The sequence lengths may be dynamic
+
+    In previous implementations, the `seq_len` and `seq_len_kv` were
+    assumed to be static integers, and produced an exception during
+    legalization.
+    """
+
+    @tvm.script.ir_module
+    class Attention:
+        @R.function
+        def main(
+            q: R.Tensor((4, "seq_len", 32, 8), "float32"),
+            k: R.Tensor((4, "seq_len_kv", 32, 8), "float32"),
+            v: R.Tensor((4, "seq_len_kv", 32, 16), "float32"),
+            bias: R.Tensor((4, 32, "seq_len", "seq_len_kv"), "float32"),
+        ):
+            scale = T.FloatImm("float32", 0.1)
+            gv = R.nn.attention(q, k, v, bias, scale=scale, causal_mask="BottomRight")
+            return gv
+
+    LegalizeOps()(Attention)
+
+
 def test_nll_loss():
     # fmt: off
     @tvm.script.ir_module
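As a usage note (a sketch, not part of the commit): with this fix, legalizing the `Attention` module above completes without raising, and `R.nn.attention` is lowered to a generated TIR PrimFunc. Assuming the imports from the test file:

    mod = LegalizeOps()(Attention)

    # The legalized module now carries a TIR implementation of the attention
    # op; the test above passes as long as legalization does not raise.
    assert any(isinstance(func, tvm.tir.PrimFunc) for _, func in mod.functions.items())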
