Skip to content

Commit

Permalink
Docs
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterJH5574 committed Mar 18, 2022
1 parent fde75c1 commit b6453da
Showing 1 changed file with 27 additions and 16 deletions.
43 changes: 27 additions & 16 deletions src/te/operation/create_primfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,22 +106,22 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
f_push_block_vars(compute_op->axis);
f_push_block_vars(compute_op->reduce_axis);

// Step 2. Declare buffer and update op2buffers
// Step 2.
// - Declare buffers
// - Update `op2buffers`
// - Add the non-argument tensors to `alloc_buffer` of the root block
Array<Buffer> buffers;
for (const te::Tensor& tensor : tensors) {
Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global");
info->tensor2buffers[tensor] = buffer;
buffers.push_back(buffer);
}

// Step 3. Add Buffer to root_alloc
for (const te::Tensor& tensor : tensors) {
if (!info->IsArg(tensor)) {
info->root_alloc.push_back(info->tensor2buffers[tensor]);
}
}

// Step 4. Calculate indices for BufferStore
// Step 3. Calculate indices for BufferStore
Array<PrimExpr> indices;
indices.reserve(compute_op->axis.size());
for (const IterVar& iter_var : compute_op->axis) {
Expand All @@ -130,7 +130,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
indices.push_back(it->second);
}

// Step 5. Create block body.
// Step 4. Create block body.
String block_name{nullptr};
Optional<Stmt> init = NullOpt;
Stmt body;
Expand All @@ -144,6 +144,9 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
lhs.reserve(n_buffers);
rhs.reserve(n_buffers);

// Make the LHS operands and RHS operands:
// - A LHS operand is the buffer storing the reduction result, with corresponding indices.
// - A RHS operand is the value to be reduced.
for (int i = 0; i < n_buffers; ++i) {
const PrimExpr& left = BufferLoad(buffers[i], indices);
const PrimExpr& right =
Expand All @@ -160,6 +163,11 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
body_stmts.reserve(n_buffers);
init_stmts.reserve(n_buffers);

// - When there is only one buffer, we directly create a BufferStore which stores "combiner(lhs,
// rhs)" into the target buffer position.
// - In case there are multiple buffers, to avoid incorrect results, we create some intermediate
// variables and use LetStmts to bind the variables with "combiner(lhs, rhs)". After that, we
// then store the value of the variables into the target buffer positions.
for (int i = 0; i < n_buffers; ++i) {
const Buffer& buffer = buffers[i];
init_stmts.push_back(BufferStore(buffer, reduce->combiner->identity_element[i], indices));
Expand All @@ -176,6 +184,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
init = SeqStmt::Flatten(init_stmts);
body = SeqStmt::Flatten(body_stmts);
if (n_buffers > 1) {
// When there are multiple buffers, we wrap the body with LetStmts.
for (int i = n_buffers - 1; i >= 0; --i) {
PrimExpr value = reduce->combiner.get()->operator()(lhs, rhs)[i];
body = LetStmt(temp_vars[i], std::move(value), std::move(body));
Expand All @@ -189,7 +198,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
body = BufferStore(info->tensor2buffers[tensors[0]], analyzer->Simplify(compute_body), indices);
}

// Step 6. Add script_parsing_detect_access attr for auto complete the whole IR.
// Step 5. Add script_parsing_detect_access attr for auto complete the whole IR.
Map<String, ObjectRef> annotations;
auto mutate_attr = [&info](const ObjectRef& value) -> ObjectRef {
if (const auto* tensor_value = value.as<te::TensorNode>()) {
Expand All @@ -213,7 +222,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
// Set script_parsing_detect_access
annotations.Set(tir::attr::script_parsing_detect_access, IntImm(DataType::Int(32), 3));

// Step 7. Create Block and BlockRealize.
// Step 6. Create Block and BlockRealize.
return BlockRealize(/*iter_values=*/std::move(bindings),
/*predicate=*/Bool(true),
/*block=*/
Expand All @@ -228,12 +237,6 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
/*annotations=*/std::move(annotations)));
}

/*!
 * \brief Check whether two Reduce nodes agree on every field compared below.
 *
 * The fields examined are `combiner`, `source`, `axis`, `condition` and `init`,
 * each compared through `same_as`. `init` is additionally treated as matching
 * when it is empty on both nodes.
 *
 * \param a The first Reduce node.
 * \param b The second Reduce node.
 * \return true if all of the compared fields match, false otherwise.
 */
inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) {
  // Bail out on the first field that differs.
  if (!a->combiner.same_as(b->combiner)) {
    return false;
  }
  if (!a->source.same_as(b->source)) {
    return false;
  }
  if (!a->axis.same_as(b->axis)) {
    return false;
  }
  if (!a->condition.same_as(b->condition)) {
    return false;
  }
  // Two empty `init` arrays are accepted without consulting `same_as`.
  const bool both_init_empty = a->init.empty() && b->init.empty();
  return both_init_empty || a->init.same_as(b->init);
}

Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* info,
arith::Analyzer* analyzer) {
// Step 1. Creating loop vars for block bindings.
Expand All @@ -246,15 +249,23 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in
// Step 2. Generate block bodies.
Array<Stmt> seq_stmt;
if (compute_op->body[0]->IsInstance<ReduceNode>()) {
auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> bool {
return a->combiner.same_as(b->combiner) && //
a->source.same_as(b->source) && //
a->axis.same_as(b->axis) && //
a->condition.same_as(b->condition) && //
((a->init.empty() && b->init.empty()) || a->init.same_as(b->init));
};

PrimExpr expr_body = compute_op->body[0];
Array<te::Tensor> tensors = {compute_op.output(0)};
const tir::ReduceNode* reduce = expr_body.as<tir::ReduceNode>();
// specially handle reduction inline for multiple reductions.
for (size_t k = 1; k < compute_op->body.size(); ++k) {
const tir::ReduceNode* reduce_ = compute_op->body[k].as<tir::ReduceNode>();
ICHECK(reduce_);
ICHECK(ReduceEqual(reduce_, reduce)) << "The Reduce inputs of ComputeOp should "
<< "have the same attribute except value_index";
ICHECK(f_reducer_equal(reduce_, reduce))
<< "The Reduce inputs of ComputeOp should have the same attribute except value_index";
tensors.push_back(compute_op.output(k));
}

Expand Down

0 comments on commit b6453da

Please sign in to comment.