Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Tuple Reduction Support in CreatePrimFunc #10671

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 111 additions & 29 deletions src/te/operation/create_primfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,10 @@ struct CreateFuncInfo {
}
};

BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::Tensor& tensor,
Array<PrimExpr> bindings, PrimExpr expr_body,
CreateFuncInfo* info, arith::Analyzer* analyzer) {
BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
const Array<te::Tensor>& tensors, Array<PrimExpr> bindings,
PrimExpr expr_body, CreateFuncInfo* info,
arith::Analyzer* analyzer) {
// Step 1. Push_back data_par axis and reduce_axis into block_vars.
Array<IterVar> iter_vars;
std::unordered_map<const VarNode*, PrimExpr> var_map;
Expand All @@ -105,16 +106,22 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::
f_push_block_vars(compute_op->axis);
f_push_block_vars(compute_op->reduce_axis);

// Step 2. Declare buffer and update op2buffers
Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global");
info->tensor2buffers[tensor] = buffer;

// Step 3. Add Buffer to root_alloc
if (!info->IsArg(tensor)) {
info->root_alloc.push_back(buffer);
// Step 2.
// - Declare buffers
// - Update `op2buffers`
// - Add the non-argument tensors to `alloc_buffer` of the root block
Array<Buffer> buffers;
for (const te::Tensor& tensor : tensors) {
Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global");
info->tensor2buffers[tensor] = buffer;
buffers.push_back(buffer);

if (!info->IsArg(tensor)) {
info->root_alloc.push_back(info->tensor2buffers[tensor]);
}
}

// Step 4. Calculate indices for BufferStore
// Step 3. Calculate indices for BufferStore
Array<PrimExpr> indices;
indices.reserve(compute_op->axis.size());
for (const IterVar& iter_var : compute_op->axis) {
Expand All @@ -123,26 +130,75 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::
indices.push_back(it->second);
}

// Step 5. Create block body.
// Step 4. Create block body.
String block_name{nullptr};
Optional<Stmt> init = NullOpt;
Stmt body;
if (const auto* reduce = expr_body.as<ReduceNode>()) {
// Case 1. Reduce compute
ICHECK_EQ(reduce->source.size(), 1);
const PrimExpr& lhs = BufferLoad(buffer, indices);
const PrimExpr& rhs = Substitute(info->transformer(reduce->source[0]), var_map);
ICHECK(lhs->dtype == rhs->dtype);
const PrimExpr& reduce_body = reduce->combiner.get()->operator()({lhs}, {rhs})[0];
const PrimExpr& init_body = reduce->combiner->identity_element[0];
body = BufferStore(buffer, analyzer->Simplify(reduce_body), indices);
init = BufferStore(buffer, analyzer->Simplify(init_body), indices);
block_name = compute_op->name;
int n_buffers = buffers.size();

Array<PrimExpr> lhs;
Array<PrimExpr> rhs;
lhs.reserve(n_buffers);
rhs.reserve(n_buffers);

// Make the LHS operands and RHS operands:
// - A LHS operand is the buffer storing the reduction result, with corresponding indices.
// - A RHS operand is the value to be reduced.
for (int i = 0; i < n_buffers; ++i) {
const PrimExpr& left = BufferLoad(buffers[i], indices);
const PrimExpr& right =
analyzer->Simplify(Substitute(info->transformer(reduce->source[i]), var_map));
lhs.push_back(left);
rhs.push_back(right);
ICHECK_EQ(left->dtype, right->dtype);
}

Array<Var> temp_vars;
Array<Stmt> body_stmts;
Array<Stmt> init_stmts;
temp_vars.reserve(n_buffers);
body_stmts.reserve(n_buffers);
init_stmts.reserve(n_buffers);

// - When there is only one buffer, we directly create a BufferStore which stores "combiner(lhs,
// rhs)" into the target buffer position.
// - In case there are multiple buffers, to avoid incorrect results, we create some intermediate
// variables and use LetStmts to bind the variables with "combiner(lhs, rhs)". After that, we
// then store the value of the variables into the target buffer positions.
for (int i = 0; i < n_buffers; ++i) {
const Buffer& buffer = buffers[i];
init_stmts.push_back(BufferStore(buffer, reduce->combiner->identity_element[i], indices));
PrimExpr value{nullptr};
if (n_buffers > 1) {
temp_vars.push_back(Var("v_" + buffer->name, PrimType(lhs[i].dtype())));
value = temp_vars.back();
} else {
value = reduce->combiner.get()->operator()(lhs, rhs)[i];
}
body_stmts.push_back(BufferStore(buffer, value, indices));
}

init = SeqStmt::Flatten(init_stmts);
body = SeqStmt::Flatten(body_stmts);
if (n_buffers > 1) {
// When there are multiple buffers, we wrap the body with LetStmts.
for (int i = n_buffers - 1; i >= 0; --i) {
PrimExpr value = reduce->combiner.get()->operator()(lhs, rhs)[i];
body = LetStmt(temp_vars[i], std::move(value), std::move(body));
}
}
} else {
// Case 2. Data parallel compute
ICHECK_EQ(tensors.size(), 1);
block_name = info->GetUniqueName(tensors[0]->GetNameHint());
const PrimExpr& compute_body = Substitute(info->transformer(expr_body), var_map);
body = BufferStore(buffer, analyzer->Simplify(compute_body), indices);
body = BufferStore(info->tensor2buffers[tensors[0]], analyzer->Simplify(compute_body), indices);
}

// Step 6. Add script_parsing_detect_access attr for auto complete the whole IR.
// Step 5. Add script_parsing_detect_access attr for auto complete the whole IR.
Map<String, ObjectRef> annotations;
auto mutate_attr = [&info](const ObjectRef& value) -> ObjectRef {
if (const auto* tensor_value = value.as<te::TensorNode>()) {
Expand All @@ -166,14 +222,14 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::
// Set script_parsing_detect_access
annotations.Set(tir::attr::script_parsing_detect_access, IntImm(DataType::Int(32), 3));

// Step 7. Create Block and BlockRealize.
// Step 6. Create Block and BlockRealize.
return BlockRealize(/*iter_values=*/std::move(bindings),
/*predicate=*/Bool(true),
/*block=*/
Block(/*iter_vars=*/std::move(iter_vars),
/*reads=*/{},
/*writes=*/{},
/*name_hint=*/info->GetUniqueName(tensor->GetNameHint()),
/*name_hint=*/block_name,
/*body=*/std::move(body),
/*init=*/std::move(init),
/*alloc_buffers=*/{},
Expand All @@ -192,12 +248,38 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in
}
// Step 2. Generate block bodies.
Array<Stmt> seq_stmt;
for (int i = 0; i < compute_op->num_outputs(); ++i) {
const te::Tensor& tensor = compute_op.output(i);
PrimExpr expr_body = compute_op->body[i];
seq_stmt.push_back(GenerateBlockFromTensor(compute_op, tensor, bindings, std::move(expr_body),
info, analyzer));
if (compute_op->body[0]->IsInstance<ReduceNode>()) {
auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> bool {
return a->combiner.same_as(b->combiner) && //
a->source.same_as(b->source) && //
a->axis.same_as(b->axis) && //
a->condition.same_as(b->condition) && //
((a->init.empty() && b->init.empty()) || a->init.same_as(b->init));
};

PrimExpr expr_body = compute_op->body[0];
Array<te::Tensor> tensors = {compute_op.output(0)};
const tir::ReduceNode* reduce = expr_body.as<tir::ReduceNode>();
// specially handle reduction inline for multiple reductions.
for (size_t k = 1; k < compute_op->body.size(); ++k) {
const tir::ReduceNode* reduce_ = compute_op->body[k].as<tir::ReduceNode>();
ICHECK(reduce_);
ICHECK(f_reducer_equal(reduce_, reduce))
<< "The Reduce inputs of ComputeOp should have the same attribute except value_index";
tensors.push_back(compute_op.output(k));
}

seq_stmt.push_back(GenerateBlockFromTensors(compute_op, tensors, bindings, std::move(expr_body),
info, analyzer));
} else {
for (int i = 0; i < compute_op->num_outputs(); ++i) {
const te::Tensor& tensor = compute_op.output(i);
PrimExpr expr_body = compute_op->body[i];
seq_stmt.push_back(GenerateBlockFromTensors(compute_op, {tensor}, bindings,
std::move(expr_body), info, analyzer));
}
}

Stmt body = SeqStmt::Flatten(seq_stmt);

// Step 3. Generate loop nesting.
Expand Down
104 changes: 104 additions & 0 deletions tests/python/unittest/test_te_create_primfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,108 @@ def test_tensor_attr():
tvm.ir.assert_structural_equal(func, rt_func)


def te_argmax_idx_val():
    """Build a TE argmax workload whose reduction tuple is ordered (index, value).

    Returns the list ``[idx, val, max_idx, max_val]``: the two ``(m, n)``
    placeholders followed by the two outputs of the tuple reduction over ``k``.
    """

    def f_combine(x, y):
        # Element 1 of each tuple is the one compared; the comparison is built
        # once per output so both Selects pick from the same side.
        picked_idx = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
        picked_val = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
        return picked_idx, picked_val

    def f_identity(dtype0: tvm.DataType, dtype1: tvm.DataType):
        # Identity element: index -1 paired with the minimum value of dtype1.
        init_idx = tvm.tir.const(-1, dtype0)
        init_val = tvm.te.min_value(dtype1)
        return init_idx, init_val

    argmax_reducer = te.comm_reducer(f_combine, f_identity, name="argmax")

    rows = te.var("m")
    cols = te.var("n")
    idx = te.placeholder((rows, cols), name="idx", dtype="int32")
    val = te.placeholder((rows, cols), name="val", dtype="float32")
    red = te.reduce_axis((0, cols), "k")

    def reduce_row(i):
        return argmax_reducer((idx[i, red], val[i, red]), axis=red)

    out_idx, out_val = te.compute((rows,), reduce_row, name="argmax")
    return [idx, val, out_idx, out_val]


# TVMScript function passed together with te_argmax_idx_val to _check_workload
# in test_argmax_idx_val. In this variant argmax_v0 is the int32 output and
# argmax_v1 is the float32 output.
@T.prim_func
def tir_argmax_idx_val(
    var_idx: T.handle, var_val: T.handle, var_argmax_v0: T.handle, var_argmax_v1: T.handle
) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    m = T.var("int32")
    n = T.var("int32")
    # Inputs are [m, n]; both outputs are [m].
    idx = T.match_buffer(var_idx, [m, n], dtype="int32")
    val = T.match_buffer(var_val, [m, n], dtype="float32")
    argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="int32")
    argmax_v1 = T.match_buffer(var_argmax_v1, [m], dtype="float32")
    for i0, i1 in T.grid(m, n):
        # A single block named "argmax" updates both output buffers.
        with T.block("argmax"):
            # i is remapped with kind "S" and k with kind "R" (presumably spatial /
            # reduction, matching k being the reduce axis in te_argmax_idx_val).
            i, k = T.axis.remap("SR", [i0, i1])
            T.reads(argmax_v1[i], val[i, k], argmax_v0[i], idx[i, k])
            T.writes(argmax_v0[i], argmax_v1[i])
            with T.init():
                # Initial values: index -1 and the float32 minimum.
                argmax_v0[i] = T.int32(-1)
                argmax_v1[i] = T.min_value("float32")
            # Both new values are bound to temporaries before either store, so the
            # second Select still compares against the not-yet-overwritten argmax_v1[i].
            v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k])
            v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k])
            argmax_v0[i] = v_argmax_v0
            argmax_v1[i] = v_argmax_v1


def te_argmax_val_idx():
    """Build a TE argmax workload whose reduction tuple is ordered (value, index).

    Returns the list ``[val, idx, max_val, max_idx]``: the two ``(m, n)``
    placeholders followed by the two outputs of the tuple reduction over ``k``.
    """

    def f_combine(x, y):
        # Element 0 of each tuple is the one compared; the comparison is built
        # once per output so both Selects pick from the same side.
        picked_val = tvm.tir.Select((x[0] >= y[0]), x[0], y[0])
        picked_idx = tvm.tir.Select((x[0] >= y[0]), x[1], y[1])
        return picked_val, picked_idx

    def f_identity(dtype0: tvm.DataType, dtype1: tvm.DataType):
        # Identity element: the minimum value of dtype0 paired with index -1.
        init_val = tvm.te.min_value(dtype0)
        init_idx = tvm.tir.const(-1, dtype1)
        return init_val, init_idx

    argmax_reducer = te.comm_reducer(f_combine, f_identity, name="argmax")

    rows = te.var("m")
    cols = te.var("n")
    val = te.placeholder((rows, cols), name="val", dtype="float32")
    idx = te.placeholder((rows, cols), name="idx", dtype="int32")
    red = te.reduce_axis((0, cols), "k")

    def reduce_row(i):
        return argmax_reducer((val[i, red], idx[i, red]), axis=red)

    out_val, out_idx = te.compute((rows,), reduce_row, name="argmax")
    return [val, idx, out_val, out_idx]


# TVMScript function passed together with te_argmax_val_idx to _check_workload
# in test_argmax_val_idx. In this variant argmax_v0 is the float32 output and
# argmax_v1 is the int32 output.
@T.prim_func
def tir_argmax_val_idx(
    var_val: T.handle, var_idx: T.handle, var_argmax_v0: T.handle, var_argmax_v1: T.handle
) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    m = T.var("int32")
    n = T.var("int32")
    # Inputs are [m, n]; both outputs are [m].
    val = T.match_buffer(var_val, [m, n], dtype="float32")
    idx = T.match_buffer(var_idx, [m, n], dtype="int32")
    argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="float32")
    argmax_v1 = T.match_buffer(var_argmax_v1, [m], dtype="int32")
    for i0, i1 in T.grid(m, n):
        # A single block named "argmax" updates both output buffers.
        with T.block("argmax"):
            # i is remapped with kind "S" and k with kind "R" (presumably spatial /
            # reduction, matching k being the reduce axis in te_argmax_val_idx).
            i, k = T.axis.remap("SR", [i0, i1])
            T.reads(argmax_v0[i], val[i, k], argmax_v1[i], idx[i, k])
            T.writes(argmax_v0[i], argmax_v1[i])
            with T.init():
                # Initial values: the float32 minimum and index -1.
                argmax_v0[i] = T.min_value("float32")
                argmax_v1[i] = T.int32(-1)
            # Both new values are bound to temporaries before either store, so the
            # second Select still compares against the not-yet-overwritten argmax_v0[i].
            v_argmax_v0: T.float32 = T.Select(argmax_v0[i] >= val[i, k], argmax_v0[i], val[i, k])
            v_argmax_v1: T.int32 = T.Select(argmax_v0[i] >= val[i, k], argmax_v1[i], idx[i, k])
            argmax_v0[i] = v_argmax_v0
            argmax_v1[i] = v_argmax_v1


def test_argmax_idx_val():
    """Run _check_workload on the (index, value)-ordered argmax TE workload and its TVMScript counterpart."""
    te_workload = te_argmax_idx_val
    tir_reference = tir_argmax_idx_val
    _check_workload(te_workload, tir_reference)


def test_argmax_val_idx():
    """Run _check_workload on the (value, index)-ordered argmax TE workload and its TVMScript counterpart."""
    te_workload = te_argmax_val_idx
    tir_reference = tir_argmax_val_idx
    _check_workload(te_workload, tir_reference)


if __name__ == "__main__":
test_unique_name()
test_matmul()
Expand All @@ -371,3 +473,5 @@ def test_tensor_attr():
test_constant()
test_select_simplify()
test_tensor_attr()
test_argmax_idx_val()
test_argmax_val_idx()