Skip to content

Commit

Permalink
address comment
Browse files Browse the repository at this point in the history
  • Loading branch information
jinhongyii committed Jul 19, 2021
1 parent 87c98e2 commit ac7afc7
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 54 deletions.
2 changes: 1 addition & 1 deletion python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def after_split(a: ty.handle, b: ty.handle) -> None:
"""
# it will be checked later in C++ implementation
# sthat there is at most one None or -1 in `factors`
# that there is at most one None or -1 in `factors`
return _ffi_api_schedule.ScheduleSplit(self, loop, factors) # type: ignore # pylint: disable=no-member

########## Schedule: compute location ##########
Expand Down
6 changes: 5 additions & 1 deletion src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,11 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
Array<PrimExpr> factors;
factors.reserve(factor_rvs.size());
for (const Optional<ExprRV>& factor_rv : factor_rvs) {
factors.push_back(this->Get(factor_rv.value_or(Integer(-1))));
if (factor_rv.defined()) {
factors.push_back(Integer(-1));
} else {
factors.push_back(this->Get(factor_rv.value()));
}
}
Array<StmtSRef> results;
TVM_TIR_SCHEDULE_BEGIN();
Expand Down
8 changes: 1 addition & 7 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,9 @@ inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const {
}
const ObjectRef& obj = (*it).second;
const auto* int_imm = TVM_TYPE_AS(int_imm, obj, IntImmNode);
if (int_imm == nullptr) {
LOG(FATAL) << "ValueError: ExprRV's corresponding type is invalid: "
<< (obj.defined() ? obj->GetTypeKey() : "None");
}
return Integer(int_imm->value);
});
PrimExpr simplified = this->analyzer_->Simplify(transformed);
CHECK(is_const_int(transformed)) << "ValueError: The ExprRV does not have a specific value";
return simplified;
return this->analyzer_->Simplify(transformed);
}

inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const {
Expand Down
88 changes: 43 additions & 45 deletions src/tir/schedule/primitive/loop_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,11 @@ class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator {
}

Stmt VisitStmt_(const BlockRealizeNode* op) final {
Stmt res = StmtMutator::VisitStmt_(op);
if (op->block->iter_vars.empty()) {
const BlockRealizeNode* realize = TVM_TYPE_AS(realize, res, BlockRealizeNode);
BlockRealize realize = Downcast<BlockRealize>(StmtMutator::VisitStmt_(op));
if (realize->block->iter_vars.empty()) {
opaque_blocks_->Set(op->block, realize->block);
}
return res;
return std::move(realize);
}

/*! \brief The substitute function */
Expand All @@ -87,18 +86,19 @@ class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator {
/*! \brief Simplify the binding of block realize and update the opaque block reuse mapping */
class IterMapSimplifyBlockBinding : public StmtExprMutator {
public:
explicit IterMapSimplifyBlockBinding(const Map<Var, Range>& loop_map,
Map<Block, Block>* opaque_blocks)
: opaque_blocks_(opaque_blocks), loop_var2extent_(std::move(loop_map)) {}
explicit IterMapSimplifyBlockBinding(MapNode* opaque_blocks,
Map<Var, Range> loop_var2extent)
: opaque_blocks_(opaque_blocks), loop_var2extent_(loop_var2extent) {}

static For SimplifyBindings(const Stmt& stmt, const Array<StmtSRef>& loop_srefs,
static For SimplifyBindings(Stmt stmt, const Array<StmtSRef>& loop_srefs,
Map<Block, Block>* opaque_blocks) {
Map<Var, Range> loop_var2extent;
for (const StmtSRef& sref : loop_srefs) {
const ForNode* loop = TVM_SREF_TO_FOR(loop, sref);
loop_var2extent.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
}
return Downcast<For>(IterMapSimplifyBlockBinding(loop_var2extent, opaque_blocks)(stmt));
return Downcast<For>(IterMapSimplifyBlockBinding(opaque_blocks->CopyOnWrite(),
std::move(loop_var2extent))(std::move(stmt)));
}

private:
Expand All @@ -112,16 +112,15 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator {
Stmt VisitStmt_(const BlockRealizeNode* op) final {
// skip opaque block and update mapping
if (op->iter_values.empty()) {
Stmt res = StmtMutator::VisitStmt_(op);
const BlockRealizeNode* realize = res.as<BlockRealizeNode>();
MapNode* mutable_map = opaque_blocks_->CopyOnWrite();
for (const std::pair<Block, Block>& entry : *opaque_blocks_) {
Block block = op->block;
BlockRealize realize = Downcast<BlockRealize>(StmtMutator::VisitStmt_(op));
for (const std::pair<ObjectRef, ObjectRef>& entry : *opaque_blocks_) {
if (entry.second.same_as(op->block)) {
mutable_map->at(entry.first) = realize->block;
opaque_blocks_->at(entry.first) = realize->block;
break;
}
}
return res;
return std::move(realize);
}
Array<PrimExpr> v = arith::IterMapSimplify(/*indices=*/op->iter_values,
/*input_iters=*/loop_var2extent_,
Expand All @@ -137,7 +136,7 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator {
}

/*! \brief The reuse mapping */
Map<Block, Block>* opaque_blocks_;
MapNode* opaque_blocks_;
/*! \brief The range of loops */
Map<Var, Range> loop_var2extent_;
};
Expand Down Expand Up @@ -272,7 +271,6 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
// order with before.
// Step 1. Check correctness
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
ICHECK(loop) << "the input sref does not point to a loop";
if (!loop->annotations.empty() || loop->thread_binding.defined()) {
throw HasAnnotationOrThreadBindingError(self->mod, GetRef<For>(loop));
}
Expand All @@ -283,8 +281,8 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
}
PrimExpr tot_length = 1;
int infer_index = -1;
size_t n = factors.size();
for (size_t i = 0; i < n; i++) {
int n = factors.size();
for (int i = 0; i < n; i++) {
if (!analyzer.CanProve(factors[i] == -1)) {
tot_length *= factors[i];
} else if (infer_index != -1) {
Expand All @@ -302,43 +300,42 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
throw WrongFactorProductError(self->mod, GetRef<For>(loop));
}
// Step 3. Replace all occurrences of the original loop var with new variables
PrimExpr substitute_value = 0;
std::vector<Var> new_loop_vars;
new_loop_vars.reserve(n);
for (size_t i = 0; i < n; i++) {
new_loop_vars.push_back(loop->loop_var.copy_with_suffix("_" + std::to_string(i)));
}
PrimExpr substitute_value = 0;
for (size_t i = 0; i < n; i++) {
substitute_value *= inferred_factors[i];
substitute_value += new_loop_vars[i];
const PrimExpr& factor = inferred_factors[i];
Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i));
substitute_value = substitute_value * factor + var;
analyzer.Bind(var, Range::FromMinExtent(0, factor));
new_loop_vars.emplace_back(std::move(var));
}
Map<Block, Block> opaque_block_reuse;
auto f_substitute = [&](const Var& v) -> Optional<PrimExpr> {
if (v.same_as(loop->loop_var)) {
return substitute_value;
} else {
return NullOpt;
}
};
Stmt new_stmt =
SubstituteVarAndCollectOpaqueBlock(f_substitute, &opaque_block_reuse)(std::move(loop->body));
for (size_t i = 0; i < n; i++) {
analyzer.Bind(new_loop_vars[i], Range::FromMinExtent(0, inferred_factors[i]));
}
Stmt new_stmt = loop->body;
new_stmt = SubstituteVarAndCollectOpaqueBlock(
[&](const Var& v) -> Optional<PrimExpr> {
if (v.same_as(loop->loop_var)) {
return substitute_value;
} else {
return NullOpt;
}
},
&opaque_block_reuse
)(std::move(new_stmt));
// Step 4. Update predicate to guard the loop
new_stmt =
BlockPredicateAppender(/*predicate=*/substitute_value < loop->extent, &analyzer)(new_stmt);
// Step 5. Generate nested loops to replace the original loop and simplify the binding
for (int i = n - 1; i >= 0; i--) {
new_stmt = For(new_loop_vars[i], 0, inferred_factors[i], loop->kind, new_stmt);
new_stmt = For(new_loop_vars[i], 0, inferred_factors[i], ForKind::kSerial, new_stmt);
}

new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(new_stmt, GetLoops(loop_sref),
new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(std::move(new_stmt), GetLoops(loop_sref),
&opaque_block_reuse);
self->Replace(loop_sref, new_stmt, opaque_block_reuse);
Array<StmtSRef> result_srefs;
result_srefs.reserve(n);
for (size_t i = 0; i < n; i++) {
for (int i = 0; i < n; i++) {
result_srefs.push_back(self->stmt2ref.at(new_stmt.get()));
const ForNode* outer_loop = TVM_TYPE_AS(outer_loop, new_stmt, ForNode);
new_stmt = outer_loop->body;
Expand Down Expand Up @@ -387,11 +384,11 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
Array<PrimExpr> substitute_value;
substitute_value.resize(loops.size());
PrimExpr tot = fused_var;
for (int i = loops.size() - 1; i >= 0; i--) {
for (int i = static_cast<int>(loops.size()) - 1; i >= 0; i--) {
substitute_value.Set(i, floormod(tot, loops[i]->extent));
tot = floordiv(tot, loops[i]->extent);
}
Stmt loop_body = loops.back()->body;
Stmt new_stmt = loops.back()->body;
Map<Block, Block> opaque_block_reuse;
auto f_substitute = [&](const Var& v) -> Optional<PrimExpr> {
for (size_t i = 0; i < loops.size(); i++) {
Expand All @@ -401,16 +398,17 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
}
return NullOpt;
};
Stmt new_stmt =
SubstituteVarAndCollectOpaqueBlock(f_substitute, &opaque_block_reuse)(std::move(loop_body));
new_stmt =
SubstituteVarAndCollectOpaqueBlock(f_substitute, &opaque_block_reuse)(std::move(new_stmt));
// Step 3. Generate a loop to replace the original loops
PrimExpr fused_extent = 1;
for (size_t i = 0; i < loops.size(); i++) {
fused_extent *= loops[i]->extent;
}
fused_extent = analyzer.Simplify(fused_extent);
new_stmt = For(fused_var, 0, fused_extent, ForKind::kSerial, new_stmt);
new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(new_stmt, GetLoops(loop_srefs[0]),
new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(std::move(new_stmt), GetLoops
(loop_srefs[0]),
&opaque_block_reuse);
self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse);
return self->stmt2ref.at(new_stmt.get());
Expand Down

0 comments on commit ac7afc7

Please sign in to comment.