Tensorize mapping proposer
vinx13 committed Jun 15, 2022
1 parent 28e778c commit dc57daf
Showing 5 changed files with 177 additions and 129 deletions.
68 changes: 68 additions & 0 deletions python/tvm/meta_schedule/testing/te_workload.py
@@ -701,6 +701,74 @@ def softmax_mn(m, n) -> Tuple[te.Tensor, te.Tensor]:  # pylint: disable=invalid-name
    return (a, b)


def conv2d_nhwc_f16(  # pylint: disable=invalid-name,missing-docstring
    N: int,
    H: int,
    W: int,
    CI: int,
    CO: int,
    kernel_size: int,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
    groups: int = 1,
):
    inputs = te.placeholder((N, H, W, CI), name="inputs", dtype="float16")
    weight = te.placeholder(
        (kernel_size, kernel_size, CI // groups, CO), name="weight", dtype="float16"
    )
    batch_size, in_h, in_w, _ = inputs.shape
    k_h, k_w, channel_per_group, out_channel = weight.shape
    out_channel_per_group = out_channel // groups

    out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1
    out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1
    rh = te.reduce_axis((0, k_h), name="rh")
    rw = te.reduce_axis((0, k_w), name="rw")
    rc = te.reduce_axis((0, channel_per_group), name="rc")

    padded = topi.nn.pad(inputs, [0, padding, padding, 0])
    output = te.compute(
        (batch_size, out_h, out_w, out_channel),
        lambda n, h, w, co: te.sum(
            (
                tir.Cast(
                    value=padded[
                        n,
                        h * stride + rh * dilation,
                        w * stride + rw * dilation,
                        co // out_channel_per_group * channel_per_group + rc,
                    ],
                    dtype="float32",
                )
                * tir.Cast(value=weight[rh, rw, rc, co], dtype="float32")
            ),
            axis=[rh, rw, rc],
        ),
        name="conv2d_nhwc",
    )
    return (inputs, weight, output)
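
For reference, a minimal sketch of how such a workload is lowered before scheduling (not part of this diff; the shapes are arbitrary, chosen only for illustration):

    from tvm import te
    from tvm.meta_schedule.testing.te_workload import conv2d_nhwc_f16

    # fp16 operands, fp32 accumulation: 1x14x14x256 -> 1x14x14x256, 3x3 kernel, pad 1.
    inputs, weight, output = conv2d_nhwc_f16(
        N=1, H=14, W=14, CI=256, CO=256, kernel_size=3, padding=1
    )
    func = te.create_prim_func([inputs, weight, output])  # -> tir.PrimFunc
    print(func.script())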


def batch_matmul_nkkm_f16(  # pylint: disable=invalid-name,missing-docstring
    B: int,
    N: int,
    M: int,
    K: int,
) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
    x = te.placeholder((B, N, K), name="X", dtype="float16")
    y = te.placeholder((B, K, M), name="Y", dtype="float16")
    k = te.reduce_axis((0, K), name="k")
    z = te.compute(  # pylint: disable=invalid-name
        (B, N, M),
        lambda b, i, j: te.sum(
            tir.Cast("float32", x[b][i][k]) * tir.Cast("float32", y[b][k][j]), axis=[k]
        ),
        name="Z",
    )
    return (x, y, z)


def create_te_workload(name: str, idx: int) -> tir.PrimFunc:
    workload_func, params = CONFIGS[name]
    return te.create_prim_func(workload_func(*params[idx]))  # type: ignore
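
A hedged sketch of how the new workloads reach create_te_workload, assuming a CONFIGS entry is registered for them (the "GMM-f16" key and parameter list below are hypothetical, not part of this diff):

    # Hypothetical registration; actual CONFIGS entries live elsewhere in te_workload.py.
    CONFIGS["GMM-f16"] = (batch_matmul_nkkm_f16, [(1, 128, 128, 128)])
    func = create_te_workload("GMM-f16", idx=0)  # batch_matmul_nkkm_f16(1, 128, 128, 128)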
27 changes: 25 additions & 2 deletions python/tvm/tir/schedule/analysis.py
@@ -91,10 +91,33 @@ def get_tensorize_loop_mapping(

@tvm._ffi.register_object("tir.schedule.AutoTensorizeMappingInfo")
class AutoTensorizeMappingInfo(Object):
-    """TODO"""
+    """Necessary information used to perform transformations for tensorization."""


-def get_tensorize_layout_info(
+def get_auto_tensorize_mapping_info(
    sch: Schedule, block: BlockRV, desc_func: PrimFunc
) -> Optional[AutoTensorizeMappingInfo]:
    """Get mapping info between a target block and an intrinsic description, including the
    layout transformations to apply.

    Parameters
    ----------
    sch : Schedule
        The schedule to be tensorized
    block : BlockRV
        The compute block for auto tensorization
    desc_func : PrimFunc
        The prim func describing the computation to be tensorized

    Returns
    -------
    auto_tensorize_mapping_info : Optional[AutoTensorizeMappingInfo]
        An AutoTensorizeMappingInfo structure if potential mappings are found, None otherwise.

    Note
    ----
    Returning a valid AutoTensorizeMappingInfo does not guarantee that the block can be
    tensorized. We will need to apply the suggested layout transformations and then match
    against the tensor intrinsics.
    """
    return _ffi_api.GetAutoTensorizeMappingInfo(sch, block, desc_func)  # type: ignore
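
A hedged usage sketch for the renamed query; the tensor intrin name below is an assumption for illustration, not part of this diff:

    from tvm import tir
    from tvm.tir.schedule.analysis import get_auto_tensorize_mapping_info

    sch = tir.Schedule(mod)  # mod: an IRModule containing the target block
    block = sch.get_block("conv2d_nhwc")
    desc = tir.TensorIntrin.get("wmma_sync_16x16x16_f16f16f32").desc  # intrin name assumed
    info = get_auto_tensorize_mapping_info(sch, block, desc)
    if info is not None:
        # Candidate layout mappings; applying one (e.g. via sch.transform_block_layout)
        # and re-matching the intrinsic is still required before tensorizing.
        print(info.mappings)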
18 changes: 12 additions & 6 deletions src/tir/schedule/analysis.h
@@ -711,16 +711,22 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
class AutoTensorizeMappingInfoNode : public Object {
 public:
  /*! \brief Possible mappings to apply to block iters */
-  Array<IndexMap> mapping;
+  Array<IndexMap> mappings;

  /* Additional information from AutoTensorizeExtractor */

  /*! \brief Mapping from LHS buffer to RHS buffer */
  Map<Buffer, Buffer> lhs_buffer_map;

-  Map<Buffer, Array<PrimExpr>> rhs_indices_map;
-  Array<IterVar> lhs_iters, rhs_iters;
+  /*! \brief Buffer indices on RHS */
+  Map<Buffer, Array<PrimExpr>> rhs_buffer_indices;
+  /*! \brief Block iters on LHS */
+  Array<IterVar> lhs_iters;
+  /*! \brief Block iters on RHS */
+  Array<IterVar> rhs_iters;

  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("mapping", &mapping);
-    v->Visit("rhs_indices_map", &rhs_indices_map);
+    v->Visit("mappings", &mappings);
+    v->Visit("rhs_buffer_indices", &rhs_buffer_indices);
    v->Visit("lhs_iters", &lhs_iters);
    v->Visit("rhs_iters", &rhs_iters);
  }
64 changes: 27 additions & 37 deletions src/tir/schedule/analysis/analysis.cc
@@ -2265,33 +2265,25 @@ TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping")
      return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func);
    });

/******** Auto Tensorization ********/

/*! \brief IndexMap proposer for layout transformation in auto tensorization. */
-class MappingProposer {
+class AutoTensorizeMappingProposer {
 public:
-  static Array<IndexMap> ProposeMappings(const AutoTensorizeExtractor* extractor) {
-    MappingProposer proposer(extractor);
+  static Array<IndexMap> ProposeMappings(const AutoTensorizeExtractor* extractor,
+                                         arith::Analyzer* analyzer) {
+    AutoTensorizeMappingProposer proposer(extractor, analyzer);
    proposer.CollectFeasibleSet();
-    proposer.ProposeAllFuseMapping();
-    return proposer.mappings_;
+    return proposer.ProposeAllFuseMapping();
  }

 private:
-  explicit MappingProposer(const AutoTensorizeExtractor* extractor) : extractor_(extractor) {}
+  explicit AutoTensorizeMappingProposer(const AutoTensorizeExtractor* extractor,
+                                        arith::Analyzer* analyzer)
+      : extractor_(extractor), analyzer_(analyzer) {}

  using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>;

-  std::string to_string(const VarSet& vs) {
-    std::ostringstream os;
-    for (const auto& v : vs) {
-      os << v << ", ";
-    }
-    return os.str();
-  };

  void CollectFeasibleSet() {
    // Collect the set of potential iter var mappings between the workload and the tensor intrin.
    // We analyze the appearance of each variable in the buffer indices of each buffer on LHS and
@@ -2375,12 +2367,11 @@ class MappingProposer {
    }

    for (const auto& iter : extractor_->lhs_iters_) {
-      // lhs_representers.push_back(iter->var.copy_with_suffix("_l"));
      lhs_feasible_vars_[iter->var] = mask_to_rhs_vars[lhs_buffer_masks[iter->var.get()]];
    }
  }

-  void ProposeAllFuseMapping() {
+  Array<IndexMap> ProposeAllFuseMapping() {
    // Now we have calculated potential mapping for each iter var on LHS. For iters on LHS mapped to
    // the same iter on RHS, they will be fused in the original order in LHS block iters. We will
    // generate IndexMap to represent such fusion on LHS. For example, if n, h, w on LHS are mapped
@@ -2416,33 +2407,34 @@
      } else if (rhs_candidates.size() == 1) {
        Var rhs_var = *rhs_candidates.begin();
        PrimExpr fused_lhs = fused_lhs_iters.at(rhs_var);
-        PrimExpr updated_fused_lhs = fused_lhs * lhs_iter_extents.at(lhs_iter_var) + index_map_src[i];
+        PrimExpr updated_fused_lhs =
+            fused_lhs * lhs_iter_extents.at(lhs_iter_var) + index_map_src[i];
        fused_lhs_iters.Set(rhs_var, updated_fused_lhs);
      } else {
        // non-unique mapping is not supported
-        return;
+        return {};
      }
    }
-    arith::Analyzer analyzer;
    for (const auto& iter : extractor_->rhs_iters_) {
-      index_map_tgt.push_back(analyzer.Simplify(fused_lhs_iters[iter->var]));
+      index_map_tgt.push_back(analyzer_->Simplify(fused_lhs_iters[iter->var]));
    }
-    mappings_.push_back(IndexMap(index_map_src, index_map_tgt));
-    LOG(INFO) << mappings_[0];
+    // At most one mapping is supported.
+    return {IndexMap(index_map_src, index_map_tgt)};
  }

- public:
-  Array<Var> lhs_representers;
-  std::unordered_map<Buffer, Buffer, ObjectHash, ObjectEqual> lhs_buffer_map_;
-  // std::unordered_map<Var, VarSet, ObjectPtrHash, ObjectPtrEqual> rhs_feasible_vars_;
-  std::unordered_map<Var, VarSet, ObjectPtrHash, ObjectPtrEqual> lhs_feasible_vars_;
-  Array<IndexMap> mappings_;
+ private:
+  // The extractor that has extracted information for auto tensorization from the workload and the
+  // tensor intrin.
+  const AutoTensorizeExtractor* extractor_;
+  // The arithmetic analyzer.
+  arith::Analyzer* analyzer_;
+  /*! \brief Potential mappings on RHS for each variable on LHS */
+  std::unordered_map<Var, VarSet, ObjectPtrHash, ObjectPtrEqual> lhs_feasible_vars_;
};

Optional<AutoTensorizeMappingInfo> GetAutoTensorizeMappingInfo(const tir::ScheduleState& self,
-                                                           const tir::StmtSRef& block_sref,
-                                                           const tir::PrimFunc& desc_func) {
+                                                               const tir::StmtSRef& block_sref,
+                                                               const tir::PrimFunc& desc_func) {
  arith::Analyzer analyzer;
  const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
  // Step 1. Analyze desc_func, extract its block, loops and loop vars
@@ -2455,15 +2447,14 @@ Optional<AutoTensorizeMappingInfo> GetAutoTensorizeMappingInfo(const tir::Schedu
  if (!extractor.VisitStmt(block->block, desc_info.desc_block->block)) {
    return NullOpt;
  }
-  Array<IndexMap> mappings = MappingProposer::ProposeMappings(&extractor);
+  Array<IndexMap> mappings = AutoTensorizeMappingProposer::ProposeMappings(&extractor, &analyzer);
  if (mappings.empty()) {
    return NullOpt;
  }
  ObjectPtr<AutoTensorizeMappingInfoNode> ret = make_object<AutoTensorizeMappingInfoNode>();
-  // Only using 1 layout now
-  ret->mapping = std::move(mappings);
+  ret->mappings = std::move(mappings);
  ret->lhs_buffer_map = std::move(extractor.lhs_buffer_map_);
-  ret->rhs_indices_map = std::move(extractor.rhs_buffer_indices_map_);
+  ret->rhs_buffer_indices = std::move(extractor.rhs_buffer_indices_map_);
  ret->lhs_iters = std::move(extractor.lhs_iters_);
  ret->rhs_iters = std::move(extractor.rhs_iters_);
  return AutoTensorizeMappingInfo(ret);
@@ -2476,6 +2467,5 @@ TVM_REGISTER_GLOBAL("tir.schedule.GetAutoTensorizeMappingInfo")
      return GetAutoTensorizeMappingInfo(sch->state(), sch->GetSRef(block), desc_func);
    });

} // namespace tir
} // namespace tvm
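
To make the all-fuse proposal concrete: if the LHS iters n, h, w all map to a single RHS iter and c maps to the other, ProposeAllFuseMapping fuses them in their original LHS order. A hedged Python illustration (the extents 14 and 14 are made up; the real extents come from the LHS block iters):

    from tvm import tir

    # (n, h, w) -> one fused RHS iter, c -> the other, mirroring the proposer's
    # fused_lhs * extent + next_iter accumulation.
    index_map = tir.IndexMap.from_func(lambda n, h, w, c: [(n * 14 + h) * 14 + w, c])
    # sch.transform_block_layout(block, index_map) would apply it before re-matching.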