[TIR] Tuple Reduction Support in CreatePrimFunc (#10671)
* [CreatePrimFunc] Support multi-source ReduceNode (#64)

* initial

* assert structural equal test

* Enhancement and tests

* Fix dtype

* Docs

Co-authored-by: Andrew Liu <[email protected]>
MasterJH5574 and hypercubestart authored Mar 22, 2022
1 parent 5b5bf75 commit ae61603
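A tuple reduction in TE is a te.compute whose body is a multi-valued ReduceNode built with te.comm_reducer, for example an argmax that carries an (index, value) pair through the reduction. Before this commit, CreatePrimFunc rejected such ops (see the removed ICHECK_EQ(reduce->source.size(), 1) in the diff below); with this change they lower to a single block with one init/update pair per output buffer. A minimal sketch of the kind of workload this enables, adapted from the argmax test added in this commit (the static shapes and the final print are illustrative assumptions, not part of the change):

import tvm
from tvm import te

def f_combine(x, y):
    # Keep the (index, value) pair whose value is larger.
    idx = tvm.tir.Select(x[1] >= y[1], x[0], y[0])
    val = tvm.tir.Select(x[1] >= y[1], x[1], y[1])
    return idx, val

def f_identity(dtype0, dtype1):
    # Identity element of the reducer: index -1 and the smallest representable value.
    return tvm.tir.const(-1, dtype0), tvm.te.min_value(dtype1)

argmax = te.comm_reducer(f_combine, f_identity, name="argmax")

m, n = 128, 64  # illustrative static shapes; the tests below use symbolic m and n
idx = te.placeholder((m, n), name="idx", dtype="int32")
val = te.placeholder((m, n), name="val", dtype="float32")
k = te.reduce_axis((0, n), "k")
max_idx, max_val = te.compute(
    (m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name="argmax"
)

# Lower the two-output reduction directly to a TIR PrimFunc.
func = te.create_prim_func([idx, val, max_idx, max_val])
print(func.script())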
Showing 2 changed files with 215 additions and 29 deletions.
140 changes: 111 additions & 29 deletions src/te/operation/create_primfunc.cc
@@ -83,9 +83,10 @@ struct CreateFuncInfo {
}
};

BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::Tensor& tensor,
Array<PrimExpr> bindings, PrimExpr expr_body,
CreateFuncInfo* info, arith::Analyzer* analyzer) {
BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
const Array<te::Tensor>& tensors, Array<PrimExpr> bindings,
PrimExpr expr_body, CreateFuncInfo* info,
arith::Analyzer* analyzer) {
// Step 1. Push_back data_par axis and reduce_axis into block_vars.
Array<IterVar> iter_vars;
std::unordered_map<const VarNode*, PrimExpr> var_map;
@@ -105,16 +106,22 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::
f_push_block_vars(compute_op->axis);
f_push_block_vars(compute_op->reduce_axis);

// Step 2. Declare buffer and update op2buffers
Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global");
info->tensor2buffers[tensor] = buffer;

// Step 3. Add Buffer to root_alloc
if (!info->IsArg(tensor)) {
info->root_alloc.push_back(buffer);
// Step 2.
// - Declare buffers
// - Update `tensor2buffers`
// - Add the non-argument tensors to `alloc_buffer` of the root block
Array<Buffer> buffers;
for (const te::Tensor& tensor : tensors) {
Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global");
info->tensor2buffers[tensor] = buffer;
buffers.push_back(buffer);

if (!info->IsArg(tensor)) {
info->root_alloc.push_back(info->tensor2buffers[tensor]);
}
}

// Step 4. Calculate indices for BufferStore
// Step 3. Calculate indices for BufferStore
Array<PrimExpr> indices;
indices.reserve(compute_op->axis.size());
for (const IterVar& iter_var : compute_op->axis) {
@@ -123,26 +130,75 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::
indices.push_back(it->second);
}

// Step 5. Create block body.
// Step 4. Create block body.
String block_name{nullptr};
Optional<Stmt> init = NullOpt;
Stmt body;
if (const auto* reduce = expr_body.as<ReduceNode>()) {
// Case 1. Reduce compute
ICHECK_EQ(reduce->source.size(), 1);
const PrimExpr& lhs = BufferLoad(buffer, indices);
const PrimExpr& rhs = Substitute(info->transformer(reduce->source[0]), var_map);
ICHECK(lhs->dtype == rhs->dtype);
const PrimExpr& reduce_body = reduce->combiner.get()->operator()({lhs}, {rhs})[0];
const PrimExpr& init_body = reduce->combiner->identity_element[0];
body = BufferStore(buffer, analyzer->Simplify(reduce_body), indices);
init = BufferStore(buffer, analyzer->Simplify(init_body), indices);
block_name = compute_op->name;
int n_buffers = buffers.size();

Array<PrimExpr> lhs;
Array<PrimExpr> rhs;
lhs.reserve(n_buffers);
rhs.reserve(n_buffers);

// Make the LHS operands and RHS operands:
// - An LHS operand is a load of the buffer storing the reduction result, at the corresponding indices.
// - An RHS operand is the value to be reduced.
for (int i = 0; i < n_buffers; ++i) {
const PrimExpr& left = BufferLoad(buffers[i], indices);
const PrimExpr& right =
analyzer->Simplify(Substitute(info->transformer(reduce->source[i]), var_map));
lhs.push_back(left);
rhs.push_back(right);
ICHECK_EQ(left->dtype, right->dtype);
}

Array<Var> temp_vars;
Array<Stmt> body_stmts;
Array<Stmt> init_stmts;
temp_vars.reserve(n_buffers);
body_stmts.reserve(n_buffers);
init_stmts.reserve(n_buffers);

// - When there is only one buffer, we directly create a BufferStore which stores "combiner(lhs,
// rhs)" into the target buffer position.
// - When there are multiple buffers, a combiner result may depend on buffer values that an earlier
// store has already overwritten. To keep the results correct, we first bind every "combiner(lhs,
// rhs)" to an intermediate variable with a LetStmt, and only then store the values of the
// variables into the target buffer positions.
for (int i = 0; i < n_buffers; ++i) {
const Buffer& buffer = buffers[i];
init_stmts.push_back(BufferStore(buffer, reduce->combiner->identity_element[i], indices));
PrimExpr value{nullptr};
if (n_buffers > 1) {
temp_vars.push_back(Var("v_" + buffer->name, PrimType(lhs[i].dtype())));
value = temp_vars.back();
} else {
value = reduce->combiner.get()->operator()(lhs, rhs)[i];
}
body_stmts.push_back(BufferStore(buffer, value, indices));
}

init = SeqStmt::Flatten(init_stmts);
body = SeqStmt::Flatten(body_stmts);
if (n_buffers > 1) {
// When there are multiple buffers, we wrap the body with LetStmts.
for (int i = n_buffers - 1; i >= 0; --i) {
PrimExpr value = reduce->combiner.get()->operator()(lhs, rhs)[i];
body = LetStmt(temp_vars[i], std::move(value), std::move(body));
}
}
} else {
// Case 2. Data parallel compute
ICHECK_EQ(tensors.size(), 1);
block_name = info->GetUniqueName(tensors[0]->GetNameHint());
const PrimExpr& compute_body = Substitute(info->transformer(expr_body), var_map);
body = BufferStore(buffer, analyzer->Simplify(compute_body), indices);
body = BufferStore(info->tensor2buffers[tensors[0]], analyzer->Simplify(compute_body), indices);
}

// Step 6. Add script_parsing_detect_access attr for auto complete the whole IR.
// Step 5. Add script_parsing_detect_access attr for auto complete the whole IR.
Map<String, ObjectRef> annotations;
auto mutate_attr = [&info](const ObjectRef& value) -> ObjectRef {
if (const auto* tensor_value = value.as<te::TensorNode>()) {
@@ -166,14 +222,14 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::
// Set script_parsing_detect_access
annotations.Set(tir::attr::script_parsing_detect_access, IntImm(DataType::Int(32), 3));

// Step 7. Create Block and BlockRealize.
// Step 6. Create Block and BlockRealize.
return BlockRealize(/*iter_values=*/std::move(bindings),
/*predicate=*/Bool(true),
/*block=*/
Block(/*iter_vars=*/std::move(iter_vars),
/*reads=*/{},
/*writes=*/{},
/*name_hint=*/info->GetUniqueName(tensor->GetNameHint()),
/*name_hint=*/block_name,
/*body=*/std::move(body),
/*init=*/std::move(init),
/*alloc_buffers=*/{},
@@ -192,12 +248,38 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in
}
// Step 2. Generate block bodies.
Array<Stmt> seq_stmt;
for (int i = 0; i < compute_op->num_outputs(); ++i) {
const te::Tensor& tensor = compute_op.output(i);
PrimExpr expr_body = compute_op->body[i];
seq_stmt.push_back(GenerateBlockFromTensor(compute_op, tensor, bindings, std::move(expr_body),
info, analyzer));
if (compute_op->body[0]->IsInstance<ReduceNode>()) {
auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> bool {
return a->combiner.same_as(b->combiner) && //
a->source.same_as(b->source) && //
a->axis.same_as(b->axis) && //
a->condition.same_as(b->condition) && //
((a->init.empty() && b->init.empty()) || a->init.same_as(b->init));
};

PrimExpr expr_body = compute_op->body[0];
Array<te::Tensor> tensors = {compute_op.output(0)};
const tir::ReduceNode* reduce = expr_body.as<tir::ReduceNode>();
// Specially handle multi-output reductions: every body must be a ReduceNode sharing the same reducer.
for (size_t k = 1; k < compute_op->body.size(); ++k) {
const tir::ReduceNode* reduce_ = compute_op->body[k].as<tir::ReduceNode>();
ICHECK(reduce_);
ICHECK(f_reducer_equal(reduce_, reduce))
<< "The Reduce inputs of ComputeOp should have the same attributes except value_index";
tensors.push_back(compute_op.output(k));
}

seq_stmt.push_back(GenerateBlockFromTensors(compute_op, tensors, bindings, std::move(expr_body),
info, analyzer));
} else {
for (int i = 0; i < compute_op->num_outputs(); ++i) {
const te::Tensor& tensor = compute_op.output(i);
PrimExpr expr_body = compute_op->body[i];
seq_stmt.push_back(GenerateBlockFromTensors(compute_op, {tensor}, bindings,
std::move(expr_body), info, analyzer));
}
}

Stmt body = SeqStmt::Flatten(seq_stmt);

// Step 3. Generate loop nesting.
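The new unit tests below verify the lowering by structural equality against hand-written TVMScript. Each test goes through the _check_workload helper defined earlier in test_te_create_primfunc.py, which is not part of this diff; the check amounts to roughly the following sketch (helper body assumed, not shown in the commit):

def check_argmax_idx_val():
    # Lower the TE tuple-reduction workload with CreatePrimFunc ...
    func = te.create_prim_func(te_argmax_idx_val())
    # ... and compare it against the expected hand-written TVMScript function.
    tvm.ir.assert_structural_equal(func, tir_argmax_idx_val)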
104 changes: 104 additions & 0 deletions tests/python/unittest/test_te_create_primfunc.py
@@ -359,6 +359,108 @@ def test_tensor_attr():
tvm.ir.assert_structural_equal(func, rt_func)


def te_argmax_idx_val():
def f_combine(x, y):
lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
return lhs, rhs

def f_identity(dtype0: tvm.DataType, dtype1: tvm.DataType):
return tvm.tir.const(-1, dtype0), tvm.te.min_value(dtype1)

argmax = te.comm_reducer(f_combine, f_identity, name="argmax")

m = te.var("m")
n = te.var("n")
idx = te.placeholder((m, n), name="idx", dtype="int32")
val = te.placeholder((m, n), name="val", dtype="float32")
k = te.reduce_axis((0, n), "k")
max_idx, max_val = te.compute(
(m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name="argmax"
)
return [idx, val, max_idx, max_val]


@T.prim_func
def tir_argmax_idx_val(
var_idx: T.handle, var_val: T.handle, var_argmax_v0: T.handle, var_argmax_v1: T.handle
) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
m = T.var("int32")
n = T.var("int32")
idx = T.match_buffer(var_idx, [m, n], dtype="int32")
val = T.match_buffer(var_val, [m, n], dtype="float32")
argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="int32")
argmax_v1 = T.match_buffer(var_argmax_v1, [m], dtype="float32")
for i0, i1 in T.grid(m, n):
with T.block("argmax"):
i, k = T.axis.remap("SR", [i0, i1])
T.reads(argmax_v1[i], val[i, k], argmax_v0[i], idx[i, k])
T.writes(argmax_v0[i], argmax_v1[i])
with T.init():
argmax_v0[i] = T.int32(-1)
argmax_v1[i] = T.min_value("float32")
v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k])
v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k])
argmax_v0[i] = v_argmax_v0
argmax_v1[i] = v_argmax_v1


def te_argmax_val_idx():
def f_combine(x, y):
lhs = tvm.tir.Select((x[0] >= y[0]), x[0], y[0])
rhs = tvm.tir.Select((x[0] >= y[0]), x[1], y[1])
return lhs, rhs

def f_identity(dtype0: tvm.DataType, dtype1: tvm.DataType):
return tvm.te.min_value(dtype0), tvm.tir.const(-1, dtype1)

argmax = te.comm_reducer(f_combine, f_identity, name="argmax")

m = te.var("m")
n = te.var("n")
val = te.placeholder((m, n), name="val", dtype="float32")
idx = te.placeholder((m, n), name="idx", dtype="int32")
k = te.reduce_axis((0, n), "k")
max_val, max_idx = te.compute(
(m,), lambda i: argmax((val[i, k], idx[i, k]), axis=k), name="argmax"
)
return [val, idx, max_val, max_idx]


@T.prim_func
def tir_argmax_val_idx(
var_val: T.handle, var_idx: T.handle, var_argmax_v0: T.handle, var_argmax_v1: T.handle
) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
m = T.var("int32")
n = T.var("int32")
val = T.match_buffer(var_val, [m, n], dtype="float32")
idx = T.match_buffer(var_idx, [m, n], dtype="int32")
argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="float32")
argmax_v1 = T.match_buffer(var_argmax_v1, [m], dtype="int32")
for i0, i1 in T.grid(m, n):
with T.block("argmax"):
i, k = T.axis.remap("SR", [i0, i1])
T.reads(argmax_v0[i], val[i, k], argmax_v1[i], idx[i, k])
T.writes(argmax_v0[i], argmax_v1[i])
with T.init():
argmax_v0[i] = T.min_value("float32")
argmax_v1[i] = T.int32(-1)
v_argmax_v0: T.float32 = T.Select(argmax_v0[i] >= val[i, k], argmax_v0[i], val[i, k])
v_argmax_v1: T.int32 = T.Select(argmax_v0[i] >= val[i, k], argmax_v1[i], idx[i, k])
argmax_v0[i] = v_argmax_v0
argmax_v1[i] = v_argmax_v1


def test_argmax_idx_val():
_check_workload(te_argmax_idx_val, tir_argmax_idx_val)


def test_argmax_val_idx():
_check_workload(te_argmax_val_idx, tir_argmax_val_idx)


if __name__ == "__main__":
test_unique_name()
test_matmul()
@@ -371,3 +473,5 @@ def test_tensor_attr():
test_constant()
test_select_simplify()
test_tensor_attr()
test_argmax_idx_val()
test_argmax_val_idx()
