[TensorIR][M2a] Compute-Inline,Reverse-Compute-Inline #8170

Merged · 1 commit · Jun 4, 2021
29 changes: 29 additions & 0 deletions include/tvm/tir/schedule/schedule.h
@@ -195,6 +195,35 @@ class ScheduleNode : public runtime::Object {
* \return A list of loops above the given block in its scope, from outer to inner
*/
virtual Array<LoopRV> GetLoops(const BlockRV& block_rv) = 0;
/******** Schedule: loops manipulation ********/
/******** Schedule: compute location ********/
/*!
* \brief Inline a block into its consumer(s). It requires:
* 1) The block is a complete non-root block, which only produces one buffer
* 2) The block must not be the only leaf in the scope.
* 3) The body of the block must be a BufferStore statement in the form of,
* A[i, j, k, ...] = ...
* where the indices of the LHS are all distinct atomic variables,
* and no variables other than those indexing variables are allowed in the statement.
* \param block The block to be inlined to its consumer(s)
*/
virtual void ComputeInline(const BlockRV& block) = 0;
/*!
* \brief Inline a block into its only producer. It requires:
* 1) The block is a complete non-root block, which only produces and consumes one buffer
* 2) The block must not be the only leaf in the scope.
* 3) The only producer of the block is a read-after-write producer and a complete non-root block
* 4) The body of the block must be a BufferStore statement in the form of,
* B[f(i, j, k, ...)] = g(i, j, k, A[i, j, k, ...] ...)
* where the indices of each `BufferLoad` on the RHS are all distinct atomic variables,
* and no variables other than those indexing variables are allowed in the statement.
* \param block The block to be inlined to its producer
*/
virtual void ReverseComputeInline(const BlockRV& block) = 0;
/******** Schedule: loop binding/annotation ********/
/******** Schedule: cache read/write ********/
/******** Schedule: reduction ********/
/******** Schedule: blockize & tensorize ********/
};
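The two primitives declared above can be illustrated, independent of TVM, by a toy substitution pass. This is a hypothetical string-based sketch (single producer/consumer, matching index names), not the PR's implementation:

```python
# Toy IR: a block is (buffer_name, index_vars, rhs_expression_string).
def compute_inline(producer, consumer):
    """Inline `producer` into `consumer` by substituting every load of
    the producer's buffer in the consumer's RHS with the producer's RHS.
    Assumes the consumer loads the buffer with the same index names."""
    buf, idx, prod_rhs = producer
    c_buf, c_idx, c_rhs = consumer
    load = f"{buf}[{', '.join(idx)}]"
    return (c_buf, c_idx, c_rhs.replace(load, f"({prod_rhs})"))

producer = ("B", ["vi", "vj"], "A[vi, vj] * 2.0")
consumer = ("C", ["vi", "vj"], "B[vi, vj] + 1.0")
print(compute_inline(producer, consumer))
# ('C', ['vi', 'vj'], '(A[vi, vj] * 2.0) + 1.0')
```

Reverse-compute-inline works in the opposite direction: the consumer's body is folded into the position of its only producer. The real primitives operate on the TensorIR AST and remap block iteration variables rather than relying on matching names.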

/*!
115 changes: 115 additions & 0 deletions python/tvm/tir/schedule/schedule.py
@@ -256,6 +256,121 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]:
"""
return _ffi_api_schedule.ScheduleGetLoops(self, block) # pylint: disable=no-member

########## Schedule: loops manipulation ##########
########## Schedule: compute location ##########
def compute_inline(self, block: BlockRV) -> None:
"""Inline a block into its consumer(s). It requires:
1) The block is a complete non-root block, which only produces one buffer
2) The block must not be the only leaf in the scope.
3) The body of the block must be a BufferStore statement in the form of,
A[i, j, k, ...] = ...
where the indices of the LHS are all distinct atomic variables,
and no variables other than those indexing variables are allowed in the statement.

Parameters
----------
block : BlockRV
The block to be inlined to its consumer(s)

Examples
--------

Before compute-inline, in TensorIR, the IR is:

.. code-block:: python

@tvm.script.tir
def before_inline(a: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.alloc_buffer((128, 128))
C = tir.match_buffer(c, (128, 128))
with tir.block([128, 128], "B") as [vi, vj]:
B[vi, vj] = A[vi, vj] * 2.0
with tir.block([128, 128], "C") as [vi, vj]:
C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do compute-inline:

.. code-block:: python

sch = tir.Schedule(before_inline, debug_mode=True)
sch.compute_inline(sch.get_block("B"))
print(tvm.script.asscript(sch.mod["main"]))

After applying compute-inline, the IR becomes:

.. code-block:: python

@tvm.script.tir
def after_inline(a: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
C = tir.match_buffer(c, (128, 128))
with tir.block([128, 128], "C") as [vi, vj]:
C[vi, vj] = A[vi, vj] * 2.0 + 1.0

"""
_ffi_api_schedule.ScheduleComputeInline(self, block) # pylint: disable=no-member

def reverse_compute_inline(self, block: BlockRV) -> None:
"""Inline a block into its only producer. It requires:
1) The block is a complete non-root block, which only produces and consumes one buffer
2) The block must not be the only leaf in the scope.
3) The only producer of the block is a read-after-write producer
and a complete non-root block
4) The body of the block must be a BufferStore statement in the form of,
B[f(i, j, k, ...)] = g(i, j, k, A[i, j, k, ...] ...)
where the indices of each `BufferLoad` on the RHS are all distinct atomic variables,
and no variables other than those indexing variables are allowed in the statement.

Parameters
----------
block : BlockRV
The block to be inlined to its producer

Examples
--------

Before reverse-compute-inline, in TensorIR, the IR is:

.. code-block:: python

@tvm.script.tir
def before_inline(a: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.alloc_buffer((128, 128))
C = tir.match_buffer(c, (128, 128))
with tir.block([128, 128], "B") as [vi, vj]:
B[vi, vj] = A[vi, vj] * 2.0
with tir.block([128, 128], "C") as [vi, vj]:
C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do reverse-compute-inline:

.. code-block:: python

sch = tir.Schedule(before_inline, debug_mode=True)
sch.reverse_compute_inline(sch.get_block("C"))
print(tvm.script.asscript(sch.mod["main"]))

After applying reverse-compute-inline, the IR becomes:

.. code-block:: python

@tvm.script.tir
def after_inline(a: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
C = tir.match_buffer(c, (128, 128))
with tir.block([128, 128], "C") as [vi, vj]:
C[vi, vj] = A[vi, vj] * 2.0 + 1.0

"""
_ffi_api_schedule.ScheduleReverseComputeInline(self, block) # pylint: disable=no-member

########## Schedule: loop binding/annotation ##########
########## Schedule: cache read/write ##########
########## Schedule: reduction ##########
########## Schedule: blockize & tensorize ##########


@_register_object("tir.ConcreteSchedule")
class ConcreteSchedule(Schedule):
72 changes: 72 additions & 0 deletions src/support/array.h
@@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_SUPPORT_ARRAY_H_
#define TVM_SUPPORT_ARRAY_H_
#include <tvm/runtime/container.h>

#include <vector>

namespace tvm {
namespace support {

/*!
* \brief Checks if two arrays contain the same objects
* \tparam T The type of objects in the array
* \param a The first array
* \param b The second array
* \return A boolean indicating if they are the same
*/
template <class T>
inline bool ArrayWithSameContent(const Array<T>& a, const Array<T>& b) {
if (a.size() != b.size()) {
return false;
}
int n = a.size();
for (int i = 0; i < n; ++i) {
if (!a[i].same_as(b[i])) {
return false;
}
}
return true;
}

/*!
* \brief Checks if two arrays contain the same objects
* \tparam T The type of objects in the array
* \param a The first array
* \param b The second array
* \return A boolean indicating if they are the same
*/
template <class T>
inline bool ArrayWithSameContent(const std::vector<T*>& a, const std::vector<T*>& b) {
if (a.size() != b.size()) {
return false;
}
int n = a.size();
for (int i = 0; i < n; ++i) {
if (a[i] != b[i]) {
return false;
}
}
return true;
}

} // namespace support
} // namespace tvm
#endif // TVM_SUPPORT_ARRAY_H_
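A Python analogue of `ArrayWithSameContent` clarifies the comparison being made: elements are compared by object identity (`same_as` / raw-pointer equality), not by structural value. Hypothetical sketch:

```python
def array_with_same_content(a, b):
    # Same length and the same objects in the same order; `is` mirrors
    # ObjectRef::same_as / pointer equality in the C++ helpers.
    return len(a) == len(b) and all(x is y for x, y in zip(a, b))

x, y = object(), object()
assert array_with_same_content([x, y], [x, y])
assert not array_with_same_content([x, y], [x, object()])  # identity, not value
assert not array_with_same_content([x], [x, y])            # length mismatch
```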
59 changes: 51 additions & 8 deletions src/tir/schedule/analysis.h
@@ -26,13 +26,13 @@ namespace tir {

/******** Verification ********/
/*!
* \brief Verify the sref tree state is consistent with the IR
* \brief Verifies the sref tree state is consistent with the IR
* \param self The schedule state containing the sref to be verified
* \throw An exception will be thrown if the sref tree is not valid
*/
void VerifySRefTree(const ScheduleState& self);
/*!
* \brief Verify the cached flags in the schedule state, including:
* \brief Verifies the cached flags in the schedule state, including:
* - affine_binding
* - region_cover
* - stage_pipeline
@@ -41,10 +41,53 @@ void VerifySRefTree(const ScheduleState& self);
*/
void VerifyCachedFlags(const ScheduleState& self);

/******** Binding ********/
/******** Scope ********/
/*!
* \brief Gets the sref to the scope root block, exclusive
* \param sref The block or loop sref to be retrieved
* \return The sref to the scope root block. NullOpt if `sref` is the root block of the IR
*/
Optional<StmtSRef> GetScopeRoot(const StmtSRef& sref);

/*!
* \brief Checks whether the scope the specified sref is in is a stage pipeline, and returns the scope root
* \param prim The name of the schedule primitive
* \param self The schedule state
* \param sref The sref whose scope is to be checked
* \throw ScheduleError if the sref is the root of the AST (so it has no scope root), or its
* scope root is not a stage pipeline
* \return The block sref to the scope root
*/
StmtSRef GetScopeRootAndCheckStagePipeline(const ScheduleState& self, const StmtSRef& sref);

/*!
* \brief Checks whether the block is a complete block under the scope
* \param self The schedule state
* \param block_sref The block to be checked
* \param scope_root The sref to the root block of the scope that `block_sref` is in
* \return A boolean indicating if the block is a complete block
* \note Definition of a complete block:
* 1) All block vars are data parallel
* 2) Dominant: the block is the only writer of its output,
* dominating the reader of its output buffers
* 3) No overlap between the buffers the block reads and writes
*/
bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_root);

/*!
* \brief Checks if the block is a complete block
* \param self The schedule state
* \param block_sref The sref to the block whose completeness is to be checked
* \param scope_root_sref The scope root of the block
* \throw ScheduleError If the block is not a complete block
*/
void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_root_sref);
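The completeness conditions above can be sketched as a predicate over a toy block record. Everything here (the dict shape, `scope_writers`) is an illustrative assumption, not TVM's data model:

```python
def is_complete_block(block, scope_writers):
    """block: {'name', 'iter_types', 'reads', 'writes'};
    scope_writers: buffer name -> list of names of blocks writing it."""
    # 1) All block vars are data parallel.
    data_parallel = all(t == "data_par" for t in block["iter_types"])
    # 2) Dominant: the block is the only writer of each output buffer.
    dominant = all(scope_writers.get(buf) == [block["name"]]
                   for buf in block["writes"])
    # 3) No overlap between the buffers read and written.
    disjoint = block["reads"].isdisjoint(block["writes"])
    return data_parallel and dominant and disjoint

blk = {"name": "B", "iter_types": ["data_par", "data_par"],
       "reads": {"A"}, "writes": {"B"}}
assert is_complete_block(blk, {"B": ["B"]})
assert not is_complete_block(blk, {"B": ["B", "other"]})  # not dominant
```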

/******** Binding ********/
/*!
* \brief Verify if the block binding in a specific BlockRealize is an affine binding.
* \brief Verifies if the block binding in a specific BlockRealize is an affine binding.
* The binding can be represented as an injective affine map from the loop iterators.
* \param realize The BlockRealize to be analyzed
* \param loop_var_ranges The ranges of the loop variables
Expand All @@ -55,7 +98,7 @@ bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_va
arith::Analyzer* analyzer);
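As a simplified illustration of the binding check, the sketch below accepts only the trivial affine case where each block var is bound to a distinct loop var; the real analysis also accepts general injective affine maps such as `i * 4 + j`. Hypothetical helper:

```python
def is_trivial_affine_binding(bindings, loop_vars):
    """bindings: block var -> bound expression (here, a loop var name).
    Trivially affine and injective iff the bound values are distinct
    loop variables."""
    used = list(bindings.values())
    return all(v in loop_vars for v in used) and len(set(used)) == len(used)

assert is_trivial_affine_binding({"vi": "i", "vj": "j"}, {"i", "j"})
assert not is_trivial_affine_binding({"vi": "i", "vj": "i"}, {"i", "j"})  # not injective
```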

/*!
* \brief Extract the ranges of loop variables in a path of the sref tree
* \brief Extracts the ranges of loop variables in a path of the sref tree
* \param low_inclusive The lowest node in the path
* \param high_exclusive The highest node in the path, defaults to the scope root if not specified
* \param extra_relax_scope If the scope is not global, the method will look beyond the limit and
@@ -78,22 +121,22 @@ Map<Var, PrimExpr> GetBindings(const BlockRealize& realize);

/******** Block-loop relation ********/
/*!
* \brief Retrieve blocks in a specific function with its name
* \brief Retrieves blocks in a specific function with its name
* \param self The schedule state
* \param name The name of the blocks to be retrieved
* \param func_name The name of the function
* \return A list of blocks with the specific name
*/
Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const String& func_name);
/*!
* \brief Get the parent loops of the block in its scope, from outer to inner
* \brief Gets the parent loops of the block in its scope, from outer to inner
* \param self The schedule state
* \param block_sref The query block
* \return A list of loops above the given block in its scope, from outer to inner
*/
Array<StmtSRef> GetLoops(const StmtSRef& block_sref);
/*!
* \brief Get the leaf blocks of a scope where a specific block/loop is in
* \brief Gets the leaf blocks of a scope where a specific block/loop is in
* \param self The schedule state
* \param parent_sref The StmtSRef that points to the parent block/loop
* \return A list of leaf blocks