Skip to content

Commit

Permalink
[MetaSchedule] Schedule Rule: Auto Inline (apache#9943)
Browse files Browse the repository at this point in the history
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Xiyou Zhou <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>

Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Xiyou Zhou <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
  • Loading branch information
7 people authored and yuanfz98 committed Jan 24, 2022
1 parent 46e249b commit ac85d18
Show file tree
Hide file tree
Showing 9 changed files with 795 additions and 2 deletions.
2 changes: 0 additions & 2 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ class ScheduleRule : public runtime::ObjectRef {
* \brief Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions
* \param into_producer If allows to inline a block into its producer
* \param into_consumer If allows to inline a block into its consumer
* \param into_cache_only If it only allows to inline into a block generated by cache_read/write
* \param inline_const_tensor Always inline constant tensors
* \param disallow_if_then_else Always disallow if-then-else-like constructs
* \param require_ordered Always require the read-to-write mapping to be ordered
Expand All @@ -125,7 +124,6 @@ class ScheduleRule : public runtime::ObjectRef {
*/
TVM_DLL static ScheduleRule AutoInline(bool into_producer, //
bool into_consumer, //
bool into_cache_only, //
bool inline_const_tensor, //
bool disallow_if_then_else, //
bool require_injective, //
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/schedule_rule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@
Meta Schedule schedule rules are used for modification of
blocks in a schedule. See also PostOrderApply.
"""
from .auto_inline import AutoInline
from .schedule_rule import PyScheduleRule, ScheduleRule
67 changes: 67 additions & 0 deletions python/tvm/meta_schedule/schedule_rule/auto_inline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Auto-Inline. Rule that inlines spatial blocks if it satisfies some conditions"""
from typing import List, Optional

from tvm._ffi import register_object

from .. import _ffi_api
from .schedule_rule import ScheduleRule


@register_object("meta_schedule.AutoInline")
class AutoInline(ScheduleRule):
"""Rule that inlines spatial blocks if it satisfies some conditions
Parameters
----------
into_producer : bool
If allows to inline a block into its producer
into_consumer : bool
If allows to inline a block into its consumer
inline_const_tensor : bool
Always inline constant tensors
disallow_if_then_else : bool
Always disallow if-then-else-like constructs
require_injective : bool
Always require the read-to-write mapping to be ordered
require_ordered : bool
Always require the read-to-write mapping to be injective
disallow_op : Optional[List[str]]
The operators that are disallowed in auto inline
"""

def __init__(
self,
into_producer: bool,
into_consumer: bool,
inline_const_tensor: bool,
disallow_if_then_else: bool,
require_injective: bool,
require_ordered: bool,
disallow_op: Optional[List[str]] = None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ScheduleRuleAutoInline, # type: ignore # pylint: disable=no-member
into_producer,
into_consumer,
inline_const_tensor,
disallow_if_then_else,
require_injective,
require_ordered,
disallow_op,
)
47 changes: 47 additions & 0 deletions python/tvm/meta_schedule/testing/schedule_rule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Default schedule rules"""
from tvm.meta_schedule.schedule_rule import (
AutoInline,
ScheduleRule,
)
from tvm.target import Target


def auto_inline(target: Target) -> ScheduleRule:
"""Default schedule rules for auto inline"""
if target.kind.name == "llvm":
return AutoInline(
into_producer=False,
into_consumer=True,
inline_const_tensor=True,
disallow_if_then_else=True,
require_injective=True,
require_ordered=True,
disallow_op=["tir.exp"],
)
if target.kind.name == "cuda":
return AutoInline(
into_producer=True,
into_consumer=True,
inline_const_tensor=True,
disallow_if_then_else=False,
require_injective=False,
require_ordered=False,
disallow_op=None,
)
raise NotImplementedError(f"{target.kind.name} is not supported")
174 changes: 174 additions & 0 deletions src/meta_schedule/schedule_rule/auto_inline.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

/*! \brief The type of inline to be performed on a specific block */
enum class InlineType : int32_t {
/*! \brief No inline opportunity */
kNoInline = 0,
/*! \brief Inline the block into its consumer */
kInlineIntoConsumer = 1,
/*! \brief Inline the block into its producer */
kInlineIntoProducer = 2,
};

/*! \brief The rule that inlines spatial blocks if it satisfies some conditions. */
class AutoInlineNode : public ScheduleRuleNode {
public:
/*! \brief Checks if the specific block should be inlined */
inline InlineType CheckInline(const tir::Schedule& sch, const tir::BlockRV& block_rv);

// Inherited from ScheduleRuleNode
void InitializeWithTuneContext(const TuneContext& context) final {}

// Inherited from ScheduleRuleNode
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
InlineType inline_type = CheckInline(sch, block_rv);
if (inline_type == InlineType::kInlineIntoConsumer) {
sch->ComputeInline(block_rv);
} else if (inline_type == InlineType::kInlineIntoProducer) {
sch->ReverseComputeInline(block_rv);
}
return {sch};
}

public:
/*! \brief If allows to inline a block into its producer */
bool into_producer;
/*! \brief If allows to inline a block into its consumer */
bool into_consumer;
/*! \brief Always inline constant tensors */
bool inline_const_tensor;
/*! \brief Always disallow if-then-else-like constructs */
bool disallow_if_then_else;
/*! \brief Always require the read-to-write mapping to be injective to do auto inline */
bool require_injective;
/*! \brief Always require the read-to-write mapping to be ordered to do auto inline */
bool require_ordered;
/*! \brief The operators that are disallowed in auto inline */
Array<Op> disallow_op;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("into_producer", &into_producer);
v->Visit("into_consumer", &into_consumer);
v->Visit("inline_const_tensor", &inline_const_tensor);
v->Visit("disallow_if_then_else", &disallow_if_then_else);
v->Visit("require_injective", &require_injective);
v->Visit("require_ordered", &require_ordered);
v->Visit("disallow_op", &disallow_op);
}

static constexpr const char* _type_key = "meta_schedule.AutoInline";
TVM_DECLARE_FINAL_OBJECT_INFO(AutoInlineNode, ScheduleRuleNode);
};

inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch,
const tir::BlockRV& block_rv) {
using namespace tvm::tir;
StmtSRef block_sref = sch->GetSRef(block_rv);
ScheduleState state = sch->state();
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
BlockRealize realize = GetBlockRealize(state, block_sref);
// Cond 1. The block has only one write buffer
if (block->writes.size() != 1) {
return InlineType::kNoInline;
}
// Cond 2. For a block that generates a constant tensor, ignore all other conditions
if (inline_const_tensor && block->reads.empty()) {
return InlineType::kInlineIntoConsumer;
}
// Cond 3. The block doesn't contain any disallowed operators
if (!disallow_op.empty() && HasOp(realize, disallow_op)) {
return InlineType::kNoInline;
}
// Cond 4. The block doesn't have any if-then-else-like constructs
if (disallow_if_then_else && HasIfThenElse(realize)) {
return InlineType::kNoInline;
}
// Cond 5. The mapping from read indices to write indices are injective and ordered
if (require_injective || require_ordered) {
const BufferRegion& write_region = block->writes[0];
for (const BufferRegion& read_region : block->reads) {
bool injective, ordered;
auto _ = std::ignore;
std::tie(/*exists=*/_, /*surjective=*/_, injective, ordered, /*no_const_read=*/_,
/*no_shift_read=*/_) = AnalyzeReadWritePattern(read_region, write_region);
if (require_injective && injective == false) {
return InlineType::kNoInline;
}
if (require_ordered && ordered == false) {
return InlineType::kNoInline;
}
}
}
// Last cond: Check inline into the consumers or the spatial producer
tir::StmtSRef scope_block = tir::GetScopeRoot(sch->state(), block_sref, //
/*require_stage_pipeline=*/false, //
/*require_subtree_compact_dataflow=*/false);
if (into_consumer) {
Array<tir::StmtSRef> consumer_srefs = GetConsumers(state, block_sref);
if (!consumer_srefs.empty() && CanComputeInline(state, block_sref)) {
return InlineType::kInlineIntoConsumer;
}
}
if (into_producer) {
Array<tir::StmtSRef> producer_srefs = GetProducers(state, block_sref);
if (producer_srefs.size() == 1 &&
tir::IsCompleteBlock(sch->state(), producer_srefs[0], scope_block) &&
CanReverseComputeInline(state, block_sref)) {
return InlineType::kInlineIntoProducer;
}
}
return InlineType::kNoInline;
}

ScheduleRule ScheduleRule::AutoInline(bool into_producer, //
bool into_consumer, //
bool inline_const_tensor, //
bool disallow_if_then_else, //
bool require_injective, //
bool require_ordered, //
Optional<Array<String>> disallow_op) {
ObjectPtr<AutoInlineNode> n = make_object<AutoInlineNode>();
n->into_producer = into_producer;
n->into_consumer = into_consumer;
n->inline_const_tensor = inline_const_tensor;
n->disallow_if_then_else = disallow_if_then_else;
n->require_injective = require_injective;
n->require_ordered = require_ordered;
n->disallow_op.clear();
if (disallow_op.defined()) {
Array<String> op_names = disallow_op.value();
n->disallow_op.reserve(op_names.size());
for (const String& op_name : op_names) {
n->disallow_op.push_back(Op::Get(op_name));
}
}
return ScheduleRule(n);
}

TVM_REGISTER_NODE_TYPE(AutoInlineNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline")
.set_body_typed(ScheduleRule::AutoInline);

} // namespace meta_schedule
} // namespace tvm
45 changes: 45 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#define TVM_TIR_SCHEDULE_ANALYSIS_H_

#include <tvm/arith/analyzer.h>
#include <tvm/ir/op.h>
#include <tvm/tir/schedule/state.h>

#include <tuple>
Expand Down Expand Up @@ -442,6 +443,50 @@ bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const S
bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& loop_sref, bool preserve_unit_loops);

/*!
* \brief Checks if the given AST contains the specific operators
* \param stmt The AST statement to be checked
* \param ops The list of operators to be checked
* \return A boolean indicating whether the AST contains the specific operators
*/
bool HasOp(const Stmt& stmt, const Array<Op>& ops);

/*!
* \brief Checks if the given AST statement contains if-then-else, including
* 1) IfThenElse statement
* 2) Select expression
* 3) The operator `tir.if_then_else`
* 4) non-constant-true Block predicates
* \param stmt The AST statement to be checked
* \return A boolean indicating whether the statement contains the if-then-else pattern
*/
bool HasIfThenElse(const Stmt& stmt);

/*!
* \brief Given the read/write region, extract the pattern of their index correspondence
* namely, the mapping from read index to the write index.
* \param read_region The read region
* \param write_region The write region
* \return A tuple of booleans, the extracted pattern
* 0) exists: if the pattern is found
* 1) surjective: if the pattern is surjective, i.e. each write index is mapped at least once
* e.g. A[i, j] = B[i, i, j]
* 2) injective: if the pattern is injective, i.e. each write index is mapped at most once.
* e.g. A[i, j] = B[i]
* 3) ordered: if the mapping is ordered
* 4) no_const_read: if there is no constant indexing in the read indices,
* e.g. A[i, j] = B[0, i, j]
* 5) no_shift_read: if there is no constant shift in the read indices,
* e.g. A[i, j] = B[i + 1, j]
*/
std::tuple</*exists=*/bool,
/*surjective=*/bool,
/*injective=*/bool,
/*ordered=*/bool,
/*no_const_read=*/bool,
/*no_shift_read=*/bool>
AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& write_region);

} // namespace tir
} // namespace tvm

Expand Down
Loading

0 comments on commit ac85d18

Please sign in to comment.