[Transform] Modify FuseTIR pass to propagate buffer attributes #17075

Merged: 1 commit, Jun 17, 2024
140 changes: 120 additions & 20 deletions src/relax/transform/fuse_tir.cc
@@ -362,6 +362,114 @@ class BlockNameDeduplicator : public tir::StmtMutator {

namespace relax {

static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
int num_inputs) {
Array<Integer> ret;
int last_idx = num_inputs;
for (auto idx : inplace_indices) {
int i = idx.IntValue();
if (i >= 0) {
ret.push_back(Integer(i));
} else {
CHECK_EQ(i, -1) << "The only negative index expected in inplace_indices is -1, but got " << i;
ret.push_back(Integer(last_idx));
last_idx++;
}
}

return ret;
}

class RelaxToTIRVarMapCollector : public ExprVisitor {
public:
explicit RelaxToTIRVarMapCollector(const IRModule& mod) : mod_(mod) {}
static Map<Expr, tir::Buffer> Collect(const IRModule& mod, const Function& func) {
RelaxToTIRVarMapCollector visitor(mod);
visitor(func->body);
return visitor.relax_to_tir_var_map_;
}

private:
void VisitBinding_(const VarBindingNode* binding) final {
current_var_ = binding->var;
ExprVisitor::VisitBinding_(binding);
}

void VisitExpr_(const CallNode* call) {
static const Op& call_tir_op_ = Op::Get("relax.call_tir");
static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");

ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_)
<< "Only call_tir and call_tir_inplace are supported in primitive function, but got: "
<< GetRef<Expr>(call);
CollectVarMapping(call, current_var_, call->op == call_tir_inplace_op_);
}

void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool in_place) {
GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
const auto& buffer_map = prim_func_->buffer_map;
const auto& tir_args = prim_func_->params;

const auto& relax_args = Downcast<Tuple>(call->args[1])->fields;

Array<Expr> relax_results;
if (lhs_var->IsInstance<TupleNode>()) {
relax_results = Downcast<Tuple>(lhs_var)->fields;
} else {
CHECK(lhs_var->IsInstance<VarNode>()) << "The lhs_var is expected to be either tuple or var";
relax_results = {Downcast<Var>(lhs_var)};
}

size_t num_inputs = relax_args.size();
size_t num_outputs = relax_results.size();

Array<Integer> output_idxs;
if (in_place) {
const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
output_idxs = GetInplaceOutputIndices(attrs->inplace_indices, num_inputs);
} else {
for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) {
output_idxs.push_back(i);
}
}

// If the `expr` is already seen (present in the map), validate whether the mapped buffer is
// structurally equal to the `new_buf` passed
auto ValidateBufferCompatibility = [this](tir::Buffer new_buf, Expr expr) {
if (auto it = relax_to_tir_var_map_.find(expr); it != relax_to_tir_var_map_.end()) {
ICHECK(StructuralEqual()((*it).second, new_buf))
<< "Inconsistent buffers " << (*it).second << " and " << new_buf
<< " mapped to the same relax var: " << expr;
}
};
for (size_t i = 0; i < tir_args.size(); ++i) {
const auto& tir_var = tir_args[i];
if (auto tir_buffer = buffer_map.Get(tir_var)) {
if (i < num_inputs) {
const auto& relax_var = relax_args[i];
ValidateBufferCompatibility(tir_buffer.value(), relax_var);
relax_to_tir_var_map_.Set(relax_var, tir_buffer.value());
}
if (auto it = std::find(output_idxs.begin(), output_idxs.end(), i);
it != output_idxs.end()) {
int result_idx = it - output_idxs.begin();
const auto& relax_var = relax_results[result_idx];
ValidateBufferCompatibility(tir_buffer.value(), relax_var);
relax_to_tir_var_map_.Set(relax_var, tir_buffer.value());
}
}
}
}

private:
/*! \brief The IRModule */
const IRModule& mod_;
Map<Expr, tir::Buffer> relax_to_tir_var_map_;
Var current_var_;
};

class FusedTIRConstructor : public ExprVisitor {
public:
/*!
@@ -391,10 +499,11 @@ class FusedTIRConstructor : public ExprVisitor {
: mod_(mod), func_name_(func_name) {}

void VisitExpr_(const FunctionNode* func) final {
auto relax_to_tir_var_map = RelaxToTIRVarMapCollector::Collect(mod_, GetRef<Function>(func));
std::vector<Variant<tir::Var, tir::Buffer>> prim_func_params;
for (const Var& relax_param : func->params) {
size_t size_before = prim_func_params.size();
CollectPrimFuncParams(relax_param, &prim_func_params);
CollectPrimFuncParams(relax_param, &prim_func_params, relax_to_tir_var_map.Get(relax_param));

auto param_buffers = [&]() -> Array<tir::Buffer> {
Array<tir::Buffer> out;
@@ -676,23 +785,6 @@ class FusedTIRConstructor : public ExprVisitor {
MapArgsToBuffer(arg_list, buffer_list);
}

static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
int num_inputs) {
Array<Integer> ret;
int last_idx = num_inputs;
for (auto idx : inplace_indices) {
int i = idx.IntValue();
if (i >= 0) {
ret.push_back(Integer(i));
} else {
ret.push_back(Integer(last_idx));
last_idx++;
}
}

return ret;
}

static Array<tir::Var> GetPrimFuncOutputParams(const tir::PrimFunc& func,
const Array<Integer>& output_indices) {
size_t n = func->params.size();
@@ -799,7 +891,8 @@ class FusedTIRConstructor : public ExprVisitor {
 * \param out The vector into which to collect the params/buffers
 * \param tir_buffer_param Optional TIR buffer whose attributes (storage scope and
 *        axis_separators) are propagated to the buffer created for this parameter
 */
static void CollectPrimFuncParams(const Var& relax_param,
std::vector<Variant<tir::Var, tir::Buffer>>* out) {
std::vector<Variant<tir::Var, tir::Buffer>>* out,
const tvm::runtime::Optional<tir::Buffer>& tir_buffer_param) {
auto struct_info = GetStructInfo(relax_param);

CHECK(!struct_info.as<TupleStructInfoNode>())
@@ -814,7 +907,14 @@
const auto* shape_expr = tensor->shape.as<ShapeExprNode>();
ICHECK(shape_expr) << "FuseTIR expects all Tensor parameters have a known shape.";
DataType dtype = tensor->dtype;
tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name_hint);
tir::Buffer buffer;
if (tir_buffer_param.defined()) {
buffer =
tir::decl_buffer(shape_expr->values, dtype, name_hint, tir_buffer_param.value().scope(),
tir_buffer_param.value()->axis_separators);
} else {
buffer = tir::decl_buffer(shape_expr->values, dtype, name_hint);
}
out->push_back(std::move(buffer));

} else if (const auto* prim_value = struct_info.as<PrimStructInfoNode>()) {
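For readers unfamiliar with the in-place calling convention, below is a small Python mirror of the GetInplaceOutputIndices helper added above. It is an illustrative re-implementation, not code from this PR: a non-negative entry in inplace_indices means the corresponding output reuses that input parameter's buffer, while -1 means a fresh output appended after the inputs.

# Illustrative Python mirror of GetInplaceOutputIndices (sketch only; the pass
# uses the C++ helper shown in the diff above).
def get_inplace_output_indices(inplace_indices, num_inputs):
    ret = []
    last_idx = num_inputs
    for i in inplace_indices:
        if i >= 0:
            # The output aliases the i-th input parameter of the PrimFunc.
            ret.append(i)
        else:
            assert i == -1, "the only negative index expected in inplace_indices is -1"
            # A fresh (non-in-place) output: its parameter slot comes after all inputs.
            ret.append(last_idx)
            last_idx += 1
    return ret

# With 3 inputs and inplace_indices = [1, -1], the two outputs map to parameter
# positions [1, 3]: the first output is written in place into input 1, the second
# is a new buffer appended after the inputs.
assert get_inplace_output_indices([1, -1], 3) == [1, 3]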
128 changes: 128 additions & 0 deletions tests/python/relax/test_transform_fuse_tir.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.

import pytest

import tvm
import tvm.testing
from tvm import relax, topi
@@ -2314,5 +2316,131 @@ def take(
_check(Before, Before)


def test_fuse_with_axis_separators():
@I.ir_module
class Before:
@T.prim_func(private=True)
def add(a: T.handle, b: T.handle, c: T.handle):
A = T.match_buffer(a, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
B = T.match_buffer(b, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])

for iters in T.grid(T.int64(16), T.int64(32)):
with T.block("compute"):
i, j = T.axis.remap("SS", iters)
C[i, j] = A[i, j] + B[i, j]

@R.function(private=True)
def fused_function(
x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
R.func_attr({"Primitive": 1})
cls = Before
with R.dataflow():
w = R.call_tir(
cls.add, [x, y], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
Review comment (Contributor):
Can we add a test case for incompatible usage of a single Relax var? As currently written, we could have a single Relax variable that is used in two separate R.call_tir statements, where the function being called imposes different restrictions on it. For example, if x were used in cls.add1, which requires axis_separators=[1], and cls.add2, which requires axis_separators=[]. We should be able to identify this case and raise an error when it occurs.

(Ideally, that should never happen, but this would be the last point at which we'd have enough information to catch this failure mode at compile-time.)

Reply (Contributor Author):
I've added the test case to check possible inconsistencies.
)
out = R.call_tir(
cls.add, [w, z], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
)
R.output(out)
return out

@R.function
def main(
x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
cls = Before
with R.dataflow():
gv = cls.fused_function(x, y, z)
R.output(gv)
return gv

@I.ir_module
class Expected:
@T.prim_func(private=True)
def fused_function(x: T.handle, y: T.handle, z: T.handle, c: T.handle):
T.func_attr({"tir.noalias": True})
X = T.match_buffer(x, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
Y = T.match_buffer(y, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
Z = T.match_buffer(z, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
Temp = T.alloc_buffer(X.shape, "float32", axis_separators=[1])
for iters in T.grid(*X.shape):
with T.block("compute_Y"):
i, j = T.axis.remap("SS", iters)
Temp[i, j] = X[i, j] + Y[i, j]

for iters in T.grid(*X.shape):
with T.block("compute_Z"):
i, j = T.axis.remap("SS", iters)
C[i, j] = Temp[i, j] + Z[i, j]

@R.function
def main(
x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
cls = Expected
with R.dataflow():
gv = R.call_tir(
cls.fused_function,
[x, y, z],
out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32"),
)
R.output(gv)
return gv

_check(Before, Expected)


def test_fuse_with_axis_separators_inconsistent_buffer_mapping():
@I.ir_module
class Before:
@T.prim_func(private=True)
def mul(a: T.handle, b: T.handle, c: T.handle):
A = T.match_buffer(a, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
B = T.match_buffer(b, [T.int64(16), T.int64(32)], "float32", axis_separators=[])
C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])

for iters in T.grid(T.int64(16), T.int64(32)):
with T.block("compute"):
i, j = T.axis.remap("SS", iters)
C[i, j] = A[i, j] * B[i, j]

@R.function(private=True)
def fused_function(
x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
R.func_attr({"Primitive": 1})
cls = Before
with R.dataflow():
out = R.call_tir(
cls.mul, [x, x], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
)
R.output(out)
return out

@R.function
def main(
x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
cls = Before
with R.dataflow():
gv = cls.fused_function(x)
R.output(gv)
return gv

with pytest.raises(
tvm.TVMError, match=r"Inconsistent buffers.*and.*mapped to the same relax var:.*"
):
relax.transform.FuseTIR()(Before)


if __name__ == "__main__":
tvm.testing.main()
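As a quick way to see the effect of this change outside the test suite, here is a minimal sketch that runs the pass and prints the buffer attributes propagated into the fused PrimFunc. It is not part of the PR; the helper name is hypothetical, and it assumes an IRModule such as the Before module from test_fuse_with_axis_separators above.

import tvm
from tvm import relax

def print_fused_buffer_attrs(mod):
    """Run FuseTIR and print scope/axis_separators for every PrimFunc buffer."""
    fused_mod = relax.transform.FuseTIR()(mod)
    for gvar, func in fused_mod.functions.items():
        if isinstance(func, tvm.tir.PrimFunc):
            for _, buf in func.buffer_map.items():
                print(gvar.name_hint, buf.name, buf.scope(), list(buf.axis_separators))

# Example: print_fused_buffer_attrs(Before) should show axis_separators=[1]
# on the fused function's parameters, matching the Expected module above.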