From 4b7d78d157330e455e8b6c34973ab8608a011e90 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Thu, 22 Feb 2024 11:22:37 -0600
Subject: [PATCH] [Relax] Handle dynamic arguments in legalization of
 nn.attention (#16592)

Prior to this commit, when using causal_mask="BottomRight" in
`R.nn.attention`, the legalization would assume that the query and
key/value sequence lengths were static integers. This commit updates
the legalization to allow dynamic shapes.
---
 python/tvm/relax/transform/legalize_ops/nn.py |  2 +-
 .../relax/test_transform_legalize_ops_nn.py   | 24 +++++++++++++++++++
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 87eea97a8b04..f80d28099c82 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -486,7 +486,7 @@ def _te_attention(
         if causal_mask == "TopLeft":
             offset = tir.IntImm("int32", 0)
         elif causal_mask == "BottomRight":
-            offset = tir.IntImm("int32", abs(seq_len - seq_len_kv))
+            offset = tir.abs(seq_len - seq_len_kv).astype("int32")
         else:
             raise NotImplementedError()
         p_masked = topi.trilu(p, k=offset, upper=False)

diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 45e6bd878a95..29171daaae3a 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -3270,6 +3270,30 @@ def main(q: R.Tensor((4, 16, 32, 8), dtype="float32"), k: R.Tensor((4, 8, 32, 8)
     tvm.ir.assert_structural_equal(mod, Expected)


+def test_dynamic_attention():
+    """The sequence lengths may be dynamic
+
+    In previous implementations, the `seq_len` and `seq_len_kv` were
+    assumed to be static integers, and produced an exception during
+    legalization.
+    """
+
+    @tvm.script.ir_module
+    class Attention:
+        @R.function
+        def main(
+            q: R.Tensor((4, "seq_len", 32, 8), "float32"),
+            k: R.Tensor((4, "seq_len_kv", 32, 8), "float32"),
+            v: R.Tensor((4, "seq_len_kv", 32, 16), "float32"),
+            bias: R.Tensor((4, 32, "seq_len", "seq_len_kv"), "float32"),
+        ):
+            scale = T.FloatImm("float32", 0.1)
+            gv = R.nn.attention(q, k, v, bias, scale=scale, causal_mask="BottomRight")
+            return gv
+
+    LegalizeOps()(Attention)
+
+
 def test_nll_loss():
     # fmt: off
     @tvm.script.ir_module
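
Not part of the patch: a minimal sketch of what the one-line change buys, assuming a
TVM build where `tvm.tir` is available (the variable names are illustrative). With
dynamic shapes, `seq_len` and `seq_len_kv` are symbolic TIR expressions rather than
Python integers, so building the offset with Python's built-in `abs()` raised an
exception during legalization; `tir.abs` expresses the same offset symbolically.

    import tvm
    from tvm import tir

    # Symbolic sequence lengths, as they appear when the shapes are dynamic.
    seq_len = tir.Var("seq_len", "int64")
    seq_len_kv = tir.Var("seq_len_kv", "int64")

    # Old form: tir.IntImm("int32", abs(seq_len - seq_len_kv)) expects a concrete
    # integer value, which symbolic variables cannot provide.
    # New form: a PrimExpr that legalization can carry through unchanged.
    offset = tir.abs(seq_len - seq_len_kv).astype("int32")
    print(offset)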