Skip to content

Commit

Permalink
[TIR, Schedule] Add schedule primitive PadEinsum
Browse files Browse the repository at this point in the history
Co-authored-by: Bohan Hou <[email protected]>
  • Loading branch information
vinx13 and spectrometerHBH committed Sep 9, 2022
1 parent da48e13 commit af3c595
Show file tree
Hide file tree
Showing 15 changed files with 847 additions and 40 deletions.
20 changes: 20 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,7 @@ class ScheduleNode : public runtime::Object {
BufferIndexType buffer_index_type,
const Array<IntImm>& axis_separators) = 0;

/******** Schedule: Padding ********/
/*!
* \brief Decompose a padding block into a block filling const pad values and a block
* writing in-bound values.
Expand All @@ -636,6 +637,25 @@ class ScheduleNode : public runtime::Object {
*/
virtual BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) = 0;

/*!
* \brief Pad the computation of Einsum.
* \param block_rv The block that matches the Einsum pattern.
* \param padding The padding for each block iter.
* \details This schedule primitive identifies the Einsum pattern in the block body, and finds its
* producer blocks. It then pads the computation of the Einsum pattern and its producer blocks.
* The output buffer and the producer buffers are resized according to the padding size. It
* requires the output buffer and the producer buffers to be allocated inside the PrimFunc.
*
* The padding is a list of non-negative integers, where each element corresponds to the padding
* for one block iter, in the order of the block iters. The block and its producer blocks should
* have trivial bindings, i.e. each block iter is bound to a single loop variable. After padding,
* the block iter extent and the extent of the corresponding outer loop are extended by the
* padding size.
*
* The sizes of the producer buffers are inferred from the padding size of the Einsum computation.
* The producer buffers are padded by the initial value of the corresponding reduction.
*/
virtual void PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) = 0;

/******** Schedule: Misc ********/
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
virtual void EnterPostproc() = 0;
Expand Down
122 changes: 122 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2783,6 +2783,128 @@ def can_decompose_padding(self, block: Union[BlockRV, str], loop: LoopRV) -> boo
"""Check whether the block match padding pattern and can be decomposed."""
return _ffi_api.CanDecomposePadding(self, block, loop) # type: ignore # pylint: disable=no-member

@type_checked
def pad_einsum(self, block: Union[BlockRV, str], padding: List[int]) -> None:
"""Pad the computation of Einsum.
This schedule primitive identifies the Einsum pattern in the block body, and finds its
producer blocks. It then pads the computation of the Einsum pattern and its producer blocks.
The output buffer and the producer buffers are resized according to the padding size. It
requires the output buffer and the producer buffers to be allocated inside the PrimFunc.
The padding is a list of non-negative integers, each element corresponds to the padding for
each block iter in the order of block iters. The block and its producer blocks should have
trivial bindings, i.e. each block iter is bound to a single loop variable. After padding,
the block iter extent and the corresponding outer loop are extended by the padding size.
The sizes of the producer buffers are inferred from the padding size of the Einsum
computation. The producer buffers are padded by the initial value of the corresponding
reduction.
Parameters
----------
block : Union[BlockRV, str]
The block that matches the Einsum pattern.
padding : List[int]
The padding for each block iter.
Examples
--------
Before applying pad-einsum, in TensorIR, the IR is:
.. code-block:: python
@T.prim_func
def before_pad_einsum(
A: T.Buffer[(128, 127), "float32"],
B: T.Buffer[(127, 127), "float32"],
C: T.Buffer[(128, 127), "float32"],
) -> None:
A_shared = T.alloc_buffer((128, 127), "float32", scope="shared")
B_shared = T.alloc_buffer((127, 127), "float32", scope="shared")
C_shared = T.alloc_buffer((128, 127), "float32", scope="shared")
for i0, i1 in T.grid(128, 127):
with T.block("A"):
i, j = T.axis.remap("SS", [i0, i1])
A_shared[i, j] = A[i, j]
for i0, i1 in T.grid(127, 127):
with T.block("B"):
i, j = T.axis.remap("SS", [i0, i1])
B_shared[i, j] = B[i, j]
for i0, i1, i2 in T.grid(128, 127, 127):
with T.block("C_shared"):
i, j, k = T.axis.remap("SSR", [i0, i1, i2])
with T.init():
C_shared[i, j] = T.float32(0)
C_shared[i, j] = C_shared[i, j] + A_shared[i, k] * B_shared[k, j]
for i0, i1 in T.grid(128, 127):
with T.block("C"):
i, j = T.axis.remap("SS", [i0, i1])
C[i, j] = C_shared[i, j]
Create the schedule and do pad-einsum with specified block:
.. code-block:: python
sch = tir.Schedule(before_pad_einsum, debug_mask="all")
block = sch.get_block("C_shared")
sch.pad_einsum(block, [0, 1, 1])
print(sch.mod["main"].script())
After applying pad-einsum, the IR becomes:
.. code-block:: python
@T.prim_func
def after_pad_einsum(
A: T.Buffer[(128, 127), "float32"],
B: T.Buffer[(127, 127), "float32"],
C: T.Buffer[(128, 127), "float32"],
) -> None:
A_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
B_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
C_shared_padded = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
for i0, i1 in T.grid(128, 128):
with T.block("A"):
i, j = T.axis.remap("SS", [i0, i1])
T.reads(A[i, j])
T.writes(A_shared_padded[i, j])
A_shared_padded[i, j] = T.if_then_else(
j < 127, A[i, j], T.float32(0), dtype="float32"
)
for i0, i1 in T.grid(128, 128):
with T.block("B"):
i, j = T.axis.remap("SS", [i0, i1])
T.reads(B[i, j])
T.writes(B_shared_padded[i, j])
B_shared_padded[i, j] = T.if_then_else(
i < 127 and j < 127, B[i, j], T.float32(0), dtype="float32"
)
for i0, i1, i2 in T.grid(128, 128, 128):
with T.block("C_shared"):
i, j, k = T.axis.remap("SSR", [i0, i1, i2])
T.reads(A_shared_padded[i, k], B_shared_padded[k, j])
T.writes(C_shared_padded[i, j])
with T.init():
C_shared_padded[i, j] = T.float32(0)
C_shared_padded[i, j] = (
C_shared_padded[i, j] + A_shared_padded[i, k] * B_shared_padded[k, j]
)
for i0, i1 in T.grid(128, 127):
with T.block("C"):
i, j = T.axis.remap("SS", [i0, i1])
T.reads(C_shared_padded[i, j])
T.writes(C[i, j])
C[i, j] = C_shared_padded[i, j]
"""
block = self._normalize_block_arg(block)
return _ffi_api.SchedulePadEinsum( # type: ignore # pylint: disable=no-member
self, block, padding
)

########## Schedule: Misc ##########

@type_checked
Expand Down
9 changes: 9 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,15 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize,
void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref,
arith::Analyzer* analyzer);

/*!
* \brief Check whether a block has a trivial binding, i.e. each block var is bound to an outer
* loop, from outer to inner.
* \param self The schedule state
* \param block_sref The block to be checked
* \throw ScheduleError If the block does not have trivial bindings
*/
void CheckBlockHasTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref);

/******** Block-loop relation ********/

/*!
Expand Down
29 changes: 29 additions & 0 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,35 @@ void CheckAffineBinding(const ScheduleState& self, Block block) {
CheckPartialAffineBinding(self, std::move(block), NullOpt);
}

void CheckBlockHasTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) {
class NotTrivialBindingError : public ScheduleError {
public:
explicit NotTrivialBindingError(IRModule mod, Block block)
: mod_(std::move(mod)), block_(std::move(block)) {}

String FastErrorString() const final {
return "ScheduleError: The binding values of the block are not variables of outer loops.";
}

String DetailRenderTemplate() const final {
std::ostringstream os;
os << "The binding values of the {0} are not variables of outer loops.";
return os.str();
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }

private:
IRModule mod_;
Block block_;
};

if (!IsTrivialBinding(self, block_sref)) {
throw NotTrivialBindingError(self->mod, GetRef<Block>(block_sref->StmtAs<BlockNode>()));
}
}

Map<Var, Range> LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive,
const Optional<StmtSRef>& high_exclusive,
const runtime::StorageScope& extra_relax_scope) {
Expand Down
6 changes: 6 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,12 @@ BlockRV ConcreteScheduleNode::DecomposePadding(const BlockRV& block_rv, const Lo
return CreateRV<BlockRV>(result);
}

void ConcreteScheduleNode::PadEinsum(const BlockRV& block_rv, const Array<Integer>& padded_shape) {
TVM_TIR_SCHEDULE_BEGIN();
tir::PadEinsum(state_, this->GetSRef(block_rv), padded_shape);
TVM_TIR_SCHEDULE_END("pad-einsum", this->error_render_level_);
this->state_->DebugVerify();
}
/******** Schedule: Misc ********/

} // namespace tir
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class ConcreteScheduleNode : public ScheduleNode {
/******** Schedule: Reduction ********/
BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override;
BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) override;
/******** Schedule: Padding ********/
void PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) override;
/******** Schedule: Block annotation ********/
void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
int offset) override;
Expand Down
11 changes: 10 additions & 1 deletion src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int
TVM_DLL void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
const IndexMap& index_map);

/******** Schedule: Padding decomposition ********/
/******** Schedule: Padding ********/
/*!
* \brief Decompose a padding block into a block filling const pad values and a block
* writing in-bound values.
Expand All @@ -501,6 +501,15 @@ TVM_DLL void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref
TVM_DLL StmtSRef DecomposePadding(ScheduleState self, const StmtSRef& block_sref,
const StmtSRef& loop_sref);

/*!
* \brief Pad the computation of Einsum.
* \param self The state of the schedule
* \param block_sref The block sref that matches the Einsum pattern.
* \param padding The padding for each block iter, given in the order of the block iters.
*/
TVM_DLL void PadEinsum(ScheduleState self, const StmtSRef& block_sref,
const Array<Integer>& padding);

/******** Schedule: Misc ********/

} // namespace tir
Expand Down
36 changes: 1 addition & 35 deletions src/tir/schedule/primitive/layout_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -278,40 +278,6 @@ class IndexMapNotApplicableToBlockIterError : public ScheduleError {
IndexMap index_map_;
};

class NotTrivialBindingError : public ScheduleError {
public:
explicit NotTrivialBindingError(IRModule mod, Block block)
: mod_(std::move(mod)), block_(std::move(block)) {}

static void CheckBlockHasTrivialBinding(const IRModule& mod, const BlockRealize& block_realize,
std::unordered_set<const VarNode*> outer_loop_vars) {
// Step 2: Check all the binding values are loops vars
for (const PrimExpr& iter_value : block_realize->iter_values) {
const VarNode* loop_var = iter_value.as<VarNode>();
if (!loop_var || !outer_loop_vars.count(loop_var)) {
throw NotTrivialBindingError(mod, block_realize->block);
}
}
}

String FastErrorString() const final {
return "ScheduleError: The binding values of the block are not variables of outer loops.";
}

String DetailRenderTemplate() const final {
std::ostringstream os;
os << "The binding values of the {0} are not variables of outer loops.";
return os.str();
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }

private:
IRModule mod_;
Block block_;
};

class OpaqueNewIterTypeError : public ScheduleError {
public:
explicit OpaqueNewIterTypeError(IRModule mod, Block block, PrimExpr iter_value)
Expand Down Expand Up @@ -363,7 +329,7 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
}

BlockRealize block_realize = GetBlockRealize(self, block_sref);
NotTrivialBindingError::CheckBlockHasTrivialBinding(self->mod, block_realize, loop_vars);
CheckBlockHasTrivialBinding(self, block_sref);

// Step 3: Collect information of block iter vars
Array<PrimExpr> block_vars; // iter_var->var of each block iter
Expand Down
Loading

0 comments on commit af3c595

Please sign in to comment.