From fcf4c83ae8018833e208f38a778b5dcbf51bfa41 Mon Sep 17 00:00:00 2001
From: Matthew Brookhart <mbrookhart@octoml.ai>
Date: Wed, 11 Nov 2020 09:43:41 -0700
Subject: [PATCH 1/4] add a regression test for fusing dynamic take

---
 tests/python/relay/test_pass_fuse_ops.py | 27 ++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py
index a3146de55d5a..30ee29525daa 100644
--- a/tests/python/relay/test_pass_fuse_ops.py
+++ b/tests/python/relay/test_pass_fuse_ops.py
@@ -14,6 +14,8 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import numpy as np
+
 import tvm
 from tvm import relay
 from tvm.relay import transform
@@ -757,6 +759,31 @@ def create_diamond_func(inp):
     assert tvm.ir.structural_equal(fused, expected)
 
 
+def test_fuse_dynamic_squeeze_slice_take():
+    input_data = [
+        np.random.random([1, 2, 4]).astype("float32"),
+        np.array([0]).astype("int64"),
+    ]
+
+    x = relay.var("p0107", shape=(relay.Any(), relay.Any(), 4), dtype="float32")
+    take_val = relay.var("p166", shape=(relay.Any(),), dtype="int64")
+
+    squeeze = relay.op.squeeze(x, axis=[0])
+    strided_slice = relay.op.strided_slice(
+        squeeze, begin=[0, 0], end=[15130, 9223372036854775807], strides=[1, 1]
+    )
+    take = relay.op.take(strided_slice, take_val, axis=0)
+
+    mod = tvm.IRModule.from_expr(take)
+    ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
+
+    result = ex.evaluate()(*input_data)
+
+    np_result = np.squeeze(input_data[0][:, input_data[1][0], :], axis=0)
+
+    assert np.allclose(result.asnumpy(), np_result)
+
+
 if __name__ == "__main__":
     test_fuse_simple()
     test_conv2d_fuse()

From f6fc702f4dbf507ef23838f3fa6d3812a49dbeff Mon Sep 17 00:00:00 2001
From: Matthew Brookhart <mbrookhart@octoml.ai>
Date: Wed, 25 Nov 2020 15:59:59 -0700
Subject: [PATCH 2/4] add legalize for take that stops fusion on dynamic inputs

---
 python/tvm/relay/op/_transform.py | 19 +++++++++++++++++++
 python/tvm/topi/transform.py      | 22 ++++++++++++++++++++++
 2 files changed, 41 insertions(+)

diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index e42b8bbae814..7843f840824c 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -332,6 +332,25 @@ def take_shape_func(attrs, inputs, out_ndims):
     return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])]
 
 
+@_reg.register_legalize("take")
+def legalize_dyn_topk(attrs, inputs, types):
+    """Legalize take op.
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    return topi.take_legalize(attrs, inputs, types)
+
+
 @script
 def _argwhere_shape_func_1d(condition):
     out = output_tensor((2,), "int64")
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index cdf9ce5c9275..353b587da70c 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -426,6 +426,28 @@ def take(a, indices, axis=None, mode="clip"):
     return cpp.take(a, indices, int(axis), mode)
 
 
+@tvm.target.generic_func
+def take_legalize(attrs, inputs, types):
+    """Legalizes dyn.topk op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    if tvm.relay.ty.is_dynamic(inputs[0].checked_type):
+        return tvm.relay.take(tvm.relay.annotation.stop_fusion(inputs[0]), inputs[1], **attrs)
+    return None
+
+
 def gather(data, axis, indices):
     """Gather values along given axis from given indices.
 

From 7a740702f1f9709e9e709970a2ce7deb979e9b45 Mon Sep 17 00:00:00 2001
From: Matthew Brookhart <mbrookhart@octoml.ai>
Date: Wed, 25 Nov 2020 16:08:17 -0700
Subject: [PATCH 3/4] fix lint

---
 python/tvm/topi/transform.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index 353b587da70c..9f0b4a2cba98 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -443,7 +443,7 @@ def take_legalize(attrs, inputs, types):
     result : tvm.relay.Expr
         The legalized expr
     """
-    if tvm.relay.ty.is_dynamic(inputs[0].checked_type):
+    if tvm.relay.ty.is_dynamic(types[0]):
         return tvm.relay.take(tvm.relay.annotation.stop_fusion(inputs[0]), inputs[1], **attrs)
     return None
 

From 45a4ccd59386936687222c19e84f474e9d57a6e7 Mon Sep 17 00:00:00 2001
From: Matthew Brookhart <mbrookhart@octoml.ai>
Date: Wed, 25 Nov 2020 17:54:48 -0700
Subject: [PATCH 4/4] fix typo

---
 python/tvm/relay/op/_transform.py | 2 +-
 python/tvm/topi/transform.py      | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 7843f840824c..439d44b5790b 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -338,7 +338,7 @@ def legalize_dyn_topk(attrs, inputs, types):
     Parameters
     ----------
     attrs : tvm.ir.Attrs
-        Attributes of current convolution
+        Attributes of current op
     inputs : list of tvm.relay.Expr
         The args of the Relay expr to be legalized
     types : list of types
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index 9f0b4a2cba98..6ddbc73e4666 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -433,7 +433,7 @@ def take_legalize(attrs, inputs, types):
     Parameters
     ----------
     attrs : tvm.ir.Attrs
-        Attributes of current convolution
+        Attributes of current op
     inputs : list of tvm.relay.Expr
         The args of the Relay expr to be legalized
     types : list of types