[Unity][Relax] Make RewriteDataflowReshape only rewrite volume-preserving ops (apache#15112)

* [Unity] Make RewriteDataflowReshape only rewrite volume-preserving ops

The reshape operator expects that the number of elements in the source
is the same as the number of elements in the result. There are operators
whose bodies match the reshape pattern but do not meet this requirement
(e.g. strided_slice), and they should not be converted to reshape; see
the sketch after these notes.

* Move shape verification to IsCallingTIRReshape

* Replace ICHECK_EQ(used_arg_indices.size(), 1) with return

The check for has-reshape-pattern is done after this check, so don't
abort if the check fails; just return.
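
As a quick illustration (not part of the commit message), the rule reduces
to comparing element counts. A minimal Python sketch, assuming static
integer shapes and using the shapes from the test case added below
(is_volume_preserving is a hypothetical helper, not a TVM API):

import math

def is_volume_preserving(src_shape, dst_shape):
    # reshape is only legal when the element counts match
    return math.prod(src_shape) == math.prod(dst_shape)

# expand_dims-like case: 2*4*3 == 2*1*4*1*3, so it may become relax.reshape
assert is_volume_preserving((2, 4, 3), (2, 1, 4, 1, 3))
# strided_slice case from the new test: 1*1024 != 1*1000, so it must not
assert not is_volume_preserving((1, 1024), (1, 1000))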
Krzysztof Parzyszek authored and junrushao committed Jun 22, 2023
1 parent 7693503 commit 54a7b25
Showing 2 changed files with 93 additions and 8 deletions.
51 changes: 45 additions & 6 deletions src/relax/transform/rewrite_dataflow_reshape.cc
@@ -20,12 +20,15 @@
* \file src/relax/transform/rewrite_dataflow_reshape.cc
* \brief Transform all reshape within dataflow block to a relax.reshape operator
*/
#include <tvm/arith/analyzer.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/function.h>

#include <vector>

#include "../op/tensor/manipulate.h"

namespace tvm {
@@ -69,7 +72,7 @@ class DataflowReshapeRewriter : public ExprMutator {
}

Expr VisitExpr_(const CallNode* call) final {
if (!IsCallingTIRReshape(call)) {
if (call->args.size() < 2) {
return GetRef<Call>(call);
}

@@ -85,24 +88,60 @@
// can generate a fused TupleGetItem + reshape function whose input is a tuple. FuseTIR
// then flattens the tuple input so that the fused TIR reshape function ends up having
// multiple input buffers. But only one of them should be accessed and reshaped.
ICHECK_EQ(used_arg_indices.size(), 1);
if (used_arg_indices.size() != 1) {
return GetRef<Call>(call);
}

auto arg = arg_tuple[used_arg_indices[0]];

TensorStructInfo res_sinfo = Downcast<TensorStructInfo>(call->struct_info_);
ICHECK(res_sinfo->shape.defined());
if (!IsCallingTIRReshape(call, arg)) {
return GetRef<Call>(call);
}

TensorStructInfo res_sinfo = Downcast<TensorStructInfo>(call->struct_info_.value());
return reshape(arg, res_sinfo->shape.value());
}

bool IsCallingTIRReshape(const CallNode* call) {
bool IsCallingTIRReshape(const CallNode* call, Expr inp) {
static const Op& call_tir_op = Op::Get("relax.call_tir");
if (call->op != call_tir_op) {
return false;
}
const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
const auto* func = mod_->functions.Get(global_var).as<tir::PrimFuncNode>();
ICHECK_NOTNULL(func);
return HasReshapePattern(GetRef<tir::PrimFunc>(func));
if (!HasReshapePattern(GetRef<tir::PrimFunc>(func))) {
return false;
}

// The reshape operator expects that the number of elements in the source is the same
// as the number of elements in the result. There are operators that could have a reshape
// pattern that don't meet this requirement (e.g. strided_slice), and they should not be
// converted to reshape.
ICHECK(inp->struct_info_.defined() && call->struct_info_.defined());
TensorStructInfo inp_sinfo = Downcast<TensorStructInfo>(inp->struct_info_.value());
TensorStructInfo res_sinfo = Downcast<TensorStructInfo>(call->struct_info_.value());

if (inp_sinfo->IsUnknownDtype() || inp_sinfo->dtype != res_sinfo->dtype) {
return false;
}
ICHECK(inp_sinfo->shape.defined() && res_sinfo->shape.defined());
if (inp_sinfo->IsUnknownNdim() || res_sinfo->IsUnknownNdim()) {
return false;
}
auto product = [](Array<PrimExpr> args) -> PrimExpr {
ICHECK(!args.empty());
PrimExpr p = args[0];
for (int i = 1, e = args.size(); i < e; ++i) p *= args[i];
return p;
};
auto inp_count = product(inp_sinfo->GetShape().value());
auto res_count = product(res_sinfo->GetShape().value());
if (!arith::Analyzer().CanProveEqual(inp_count, res_count)) {
return false;
}

return true;
}

const IRModule& mod_;
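The volume check above compares symbolic element counts with
arith::Analyzer::CanProveEqual. A hedged sketch of the same idea through
TVM's Python arith bindings (assumes a recent TVM build; n is a
hypothetical symbolic dimension, and Analyzer.can_prove over an explicit
tir.EQ stands in for the C++ CanProveEqual):

import tvm
from tvm import tir

n = tir.Var("n", "int64")
analyzer = tvm.arith.Analyzer()
# (n, 4, 3) -> (n, 12) preserves volume: n*4*3 == n*12 simplifies to true
assert analyzer.can_prove(tir.EQ(n * 4 * 3, n * 12))
# (n, 1024) -> (n, 1000) does not, so the rewrite must be skipped
assert not analyzer.can_prove(tir.EQ(n * 1024, n * 1000))
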
50 changes: 48 additions & 2 deletions tests/python/relax/test_transform_rewrite_dataflow_reshape.py
@@ -48,8 +48,7 @@ def reshape(
def expand_dims(
rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"),
expand_dims: T.Buffer(
(T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3)),
"float32",
(T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3)), "float32"
),
):
for i0, i1, i2, i3, i4 in T.grid(
@@ -381,5 +380,52 @@ def main(
tvm.ir.assert_structural_equal(rewritten, Expected)


def test_invalid_reshape():
@tvm.script.ir_module
class Module:
# The strided_slice op has the reshape pattern, but it can take only a part of the input.
# It can't be replaced with the reshape op because reshape expects to preserve the "volume"
# of the input.
@T.prim_func
def strided_slice(
A: T.Buffer((T.int64(1), T.int64(1024)), "int32"),
T_strided_slice: T.Buffer((T.int64(1), T.int64(1000)), "int32"),
):
T.func_attr({"tir.noalias": T.bool(True)})
for ax0, ax1 in T.grid(T.int64(1), T.int64(1000)):
with T.block("T_strided_slice"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(A[v_ax0, v_ax1])
T.writes(T_strided_slice[v_ax0, v_ax1])
T_strided_slice[v_ax0, v_ax1] = A[v_ax0, v_ax1]

@T.prim_func
def add_one(
A: T.Buffer((T.int64(1), T.int64(1000)), "int32"),
T_add_one: T.Buffer((T.int64(1), T.int64(1000)), "int32"),
):
for ax0, ax1 in T.grid(T.int64(1), T.int64(1000)):
with T.block("T_add_one"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(A[v_ax0, v_ax1])
T.writes(T_add_one[v_ax0, v_ax1])
T_add_one[v_ax0, v_ax1] = A[v_ax0, v_ax1] + 1

@R.function
def main(A: R.Tensor((1, 1024), dtype="int32")) -> R.Tensor((1, 1000), dtype="int32"):
with R.dataflow():
cls = Module
S = R.call_tir(
cls.strided_slice, (A,), out_sinfo=R.Tensor((1, 1000), dtype="int32")
)
A = R.call_tir(cls.add_one, (S,), out_sinfo=R.Tensor((1, 1000), dtype="int32"))
R.output(A)
return A

assert relax.analysis.has_reshape_pattern(Module["strided_slice"])
rewritten = relax.transform.RewriteDataflowReshape()(Module)
tvm.ir.assert_structural_equal(rewritten, Module)


if __name__ == "__main__":
tvm.testing.main()
