
Commit

[TIR, analysis] Add GetAutoTensorizeMappingInfo to generate transforms for auto tensorization (#11740)

This PR adds a utility function `GetAutoTensorizeMappingInfo` that proposes mappings from the workload block iters to the iters in the tensor intrin. An example use case is conv2d, where the computation block has more iters than the matmul tensor intrin.
vinx13 authored Jun 19, 2022
1 parent 77756ea commit 9bba758
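
A minimal usage sketch of the new API (illustration only, not part of this commit): it builds the `conv2d_nhwc_f16` workload added below and asks for candidate mappings against a matmul tensor intrin. The intrin name `wmma_sync_16x16x16_f16f16f32` and the `tvm.tir.tensor_intrin.cuda` import are assumptions; substitute whichever f16 matmul intrin is registered in your TVM build.

```python
from tvm import te, tir
import tvm.tir.tensor_intrin.cuda  # noqa: F401  # assumed: registers the WMMA intrins
from tvm.meta_schedule.testing import te_workload
from tvm.tir.schedule.analysis import get_auto_tensorize_mapping_info

# Build a conv2d workload; its block carries batch/spatial iters on top of the
# m/n/k iters that a matmul tensor intrin describes.
args = te_workload.conv2d_nhwc_f16(N=1, H=14, W=14, CI=16, CO=16, kernel_size=3, padding=1)
sch = tir.Schedule(te.create_prim_func(list(args)))
block = sch.get_block("conv2d_nhwc")

desc = tir.TensorIntrin.get("wmma_sync_16x16x16_f16f16f32").desc  # assumed intrin name
info = get_auto_tensorize_mapping_info(sch, block, desc)
if info is not None:
    for mapping in info.mappings:  # candidate IndexMaps from conv2d block iters to intrin iters
        print(mapping)
```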
Showing 7 changed files with 630 additions and 31 deletions.
68 changes: 68 additions & 0 deletions python/tvm/meta_schedule/testing/te_workload.py
@@ -701,6 +701,74 @@ def softmax_mn(m, n) -> Tuple[te.Tensor, te.Tensor]: # pylint: disable=invalid-
    return (a, b)


def conv2d_nhwc_f16(  # pylint: disable=invalid-name,missing-docstring
    N: int,
    H: int,
    W: int,
    CI: int,
    CO: int,
    kernel_size: int,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
    groups: int = 1,
):
    inputs = te.placeholder((N, H, W, CI), name="inputs", dtype="float16")
    weight = te.placeholder(
        (kernel_size, kernel_size, CI // groups, CO), name="weight", dtype="float16"
    )
    batch_size, in_h, in_w, _ = inputs.shape
    k_h, k_w, channel_per_group, out_channel = weight.shape
    out_channel_per_group = out_channel // groups

    out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1
    out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1
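    # e.g. in_h = 14, padding = 1, kernel_size = 3, stride = 1, dilation = 1 gives
    # out_h = (14 + 2 * 1 - 1 * (3 - 1) - 1) // 1 + 1 = 14, i.e. a "same"-sized output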
    rh = te.reduce_axis((0, k_h), name="rh")
    rw = te.reduce_axis((0, k_w), name="rw")
    rc = te.reduce_axis((0, channel_per_group), name="rc")

    padded = topi.nn.pad(inputs, [0, padding, padding, 0])
    output = te.compute(
        (batch_size, out_h, out_w, out_channel),
        lambda n, h, w, co: te.sum(
            (
                tir.Cast(
                    value=padded[
                        n,
                        h * stride + rh * dilation,
                        w * stride + rw * dilation,
                        co // out_channel_per_group * channel_per_group + rc,
                    ],
                    dtype="float32",
                )
                * tir.Cast(value=weight[rh, rw, rc, co], dtype="float32")
            ),
            axis=[rh, rw, rc],
        ),
        name="conv2d_nhwc",
    )
    return (inputs, weight, output)


def batch_matmul_nkkm_f16(  # pylint: disable=invalid-name,missing-docstring
    B: int,
    N: int,
    M: int,
    K: int,
) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
    x = te.placeholder((B, N, K), name="X", dtype="float16")
    y = te.placeholder((B, K, M), name="Y", dtype="float16")
    k = te.reduce_axis((0, K), name="k")
    z = te.compute(  # pylint: disable=invalid-name
        (B, N, M),
        lambda b, i, j: te.sum(
            tir.Cast("float32", x[b][i][k]) * tir.Cast("float32", y[b][k][j]), axis=[k]
        ),
        name="Z",
    )
    return (x, y, z)


def create_te_workload(name: str, idx: int) -> tir.PrimFunc:
    workload_func, params = CONFIGS[name]
    return te.create_prim_func(workload_func(*params[idx]))  # type: ignore
34 changes: 34 additions & 0 deletions python/tvm/tir/schedule/analysis.py
@@ -87,3 +87,37 @@ def get_tensorize_loop_mapping(
        TensorizeInfo structure if a valid mapping is found, None otherwise
    """
    return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func)  # type: ignore


@tvm._ffi.register_object("tir.schedule.AutoTensorizeMappingInfo")
class AutoTensorizeMappingInfo(Object):
    """Necessary information used to perform transformations for tensorization."""


def get_auto_tensorize_mapping_info(
    sch: Schedule, block: BlockRV, desc_func: PrimFunc
) -> Optional[AutoTensorizeMappingInfo]:
    """Get mapping info between a target block and an intrinsic description including layout
    transformations to apply.

    Parameters
    ----------
    sch : Schedule
        The schedule to be tensorized
    block : BlockRV
        The compute block for auto tensorization
    desc_func : PrimFunc
        The prim func describing the computation to be tensorized

    Returns
    -------
    auto_tensorize_mapping_info : Optional[AutoTensorizeMappingInfo]
        AutoTensorizeMappingInfo structure if potential mappings are found, None otherwise.

    Note
    ----
    Returning a valid AutoTensorizeMappingInfo doesn't guarantee the block can be tensorized.
    We will need to apply the suggested layout transformations and then match against the tensor
    intrinsics.
    """
    return _ffi_api.GetAutoTensorizeMappingInfo(sch, block, desc_func)  # type: ignore
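
To illustrate the `Note` in the docstring above, a hedged sketch of the follow-up step (not code from this PR): each proposed `IndexMap` is applied with `Schedule.transform_block_layout` on a copy of the schedule and the result is re-checked with `get_tensorize_loop_mapping`. The full auto-tensorization flow in meta schedule also re-indexes and transforms buffer layouts, which is omitted here; `sch`, `block` and `desc` are the objects from the sketch under the commit message.

```python
from tvm.tir.schedule.analysis import (
    get_auto_tensorize_mapping_info,
    get_tensorize_loop_mapping,
)

info = get_auto_tensorize_mapping_info(sch, block, desc)
if info is not None:
    for mapping in info.mappings:
        candidate = sch.copy()  # try each proposed block-iter mapping on a fresh copy
        candidate.transform_block_layout(block, mapping)
        # Only after the transform can the block be matched against the intrin.
        if get_tensorize_loop_mapping(candidate, block, desc) is not None:
            sch = candidate
            break
```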
50 changes: 50 additions & 0 deletions src/tir/schedule/analysis.h
@@ -707,6 +707,56 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
                                                const tir::StmtSRef& block_sref,
                                                const tir::PrimFunc& desc_func);

/*! \brief Necessary information used to perform transformations for tensorization */
class AutoTensorizeMappingInfoNode : public Object {
 public:
  /*! \brief Possible mappings to apply to block iters */
  Array<IndexMap> mappings;

  /* Additional information from AutoTensorizeComparator */

  /*! \brief Mapping from LHS buffer to RHS buffer */
  Map<Buffer, Buffer> lhs_buffer_map;
  /*! \brief Buffer indices on RHS */
  Map<Buffer, Array<PrimExpr>> rhs_buffer_indices;
  /*! \brief Block iters on LHS */
  Array<IterVar> lhs_iters;
  /*! \brief Block iters on RHS */
  Array<IterVar> rhs_iters;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("mappings", &mappings);
    v->Visit("lhs_buffer_map", &lhs_buffer_map);
    v->Visit("rhs_buffer_indices", &rhs_buffer_indices);
    v->Visit("lhs_iters", &lhs_iters);
    v->Visit("rhs_iters", &rhs_iters);
  }

  static constexpr const char* _type_key = "tir.schedule.AutoTensorizeMappingInfo";
  TVM_DECLARE_FINAL_OBJECT_INFO(AutoTensorizeMappingInfoNode, Object);
};

class AutoTensorizeMappingInfo : public ObjectRef {
 public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AutoTensorizeMappingInfo, ObjectRef,
                                            AutoTensorizeMappingInfoNode);
};

/*!
 * \brief Get mapping info between a target block and an intrinsic description including layout
 * transformations to apply.
 * \param self The schedule state
 * \param block_sref The compute block for auto tensorization
 * \param desc_func The prim func describing the computation to be tensorized
 * \return AutoTensorizeMappingInfo structure if a potential mapping is found, NullOpt otherwise.
 * \note Returning a valid AutoTensorizeMappingInfo doesn't guarantee the block can be tensorized.
 * We will need to apply the suggested layout transformations and then match against the tensor
 * intrinsics.
 */
Optional<AutoTensorizeMappingInfo> GetAutoTensorizeMappingInfo(const ScheduleState& self,
                                                               const StmtSRef& block_sref,
                                                               const PrimFunc& desc_func);

}  // namespace tir
}  // namespace tvm


0 comments on commit 9bba758
