Skip to content

Commit

Permalink
[TOPI] [Relay] Sparse Conv2d Implementation for 3x3 kernels (apache#8605
Browse files Browse the repository at this point in the history
)

* [topi] add spconv2d_3x3 nhwc

* [relay] sparse_conv2d: add kernel_size attr

* [relay] add strategy for spconv2d_3x3 nhwc

* [relay] pass to convert spconv2d with const args

* [relay] convert sparse conv2d pass fixes

* use array for sparse conv2d attr

* fixup 1x1 tests; new 3x3 tests
Tantalus13A98B5F authored and Andrew Zhao Luo committed Sep 1, 2021
1 parent 3862ce6 commit 934b4e5
Showing 12 changed files with 548 additions and 50 deletions.
4 changes: 4 additions & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
@@ -1066,12 +1066,16 @@ struct SparseTransposeAttrs : public tvm::AttrsNode<SparseTransposeAttrs> {
/*! \brief Attributes for sparse_dense operator */
struct SparseConv2DAttrs : public tvm::AttrsNode<SparseConv2DAttrs> {
std::string layout;
Array<IndexExpr> kernel_size;

TVM_DECLARE_ATTRS(SparseConv2DAttrs, "relay.attrs.SparseConv2DAttrs") {
TVM_ATTR_FIELD(layout).set_default("NHWC").describe(
"Dimension ordering of input data. Can be 'NCHW', 'NHWC'"
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively.");
TVM_ATTR_FIELD(kernel_size)
.set_default(Array<IndexExpr>{1, 1})
.describe("Kernel size for SparseConv2D, 1x1 or 3x3. ");
}
};

15 changes: 8 additions & 7 deletions python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
@@ -254,13 +254,14 @@ def ref_input(self):

@ref_input.setter
def ref_input(self, val):
warnings.warn(
"You are specifying fixed input for tuning the operator. "
"Be sure your input always fits the operator. Some "
"operators may conduct layout transformation during tuning, "
"thus can lead to unexpected behaviors. ",
RuntimeWarning,
)
if val is not None:
warnings.warn(
"You are specifying fixed input for tuning the operator. "
"Be sure your input always fits the operator. Some "
"operators may conduct layout transformation during tuning, "
"thus can lead to unexpected behaviors. ",
RuntimeWarning,
)
self._ref_input = val

def set_task(self, task):
58 changes: 38 additions & 20 deletions python/tvm/relay/analysis/sparse_conv2d.py
Original file line number Diff line number Diff line change
@@ -54,7 +54,9 @@ def _search_conv2d_op_weight(expr):
return _ffi_api.search_conv2d_op_weight(expr)


def process_params(expr, params, block_size, sparsity_threshold, layout):
def process_params(
expr, params, block_size, sparsity_threshold, layout, kernel_size, reg_task_input=True
):
"""Process parameters of conv2d from dense to sparse.
Parameters
@@ -86,14 +88,18 @@ def process_params(expr, params, block_size, sparsity_threshold, layout):
for name in weight_names:
name = str(name)
w_np = params[name].numpy()
# currently only support conv2d_1*1
if not (
(w_np.shape[0] == 1 and w_np.shape[1] == 1)
or (w_np.shape[2] == 1 and w_np.shape[3] == 1)
):

if layout == "NHWC": # HWIO
weight_kernel = (w_np.shape[0], w_np.shape[1])
elif layout == "NCHW": # OIHW
weight_kernel = (w_np.shape[2], w_np.shape[3])
if weight_kernel[0] != weight_kernel[1]:
continue
sparsity = 1.0 - (np.count_nonzero(w_np) / w_np.size)
if sparsity >= sparsity_threshold:

if weight_kernel[0] == kernel_size == 1:
sparsity = 1.0 - (np.count_nonzero(w_np) / w_np.size)
if sparsity < sparsity_threshold:
continue
if layout == "NHWC":
w_np = w_np.squeeze().T
elif layout == "NCHW":
@@ -108,19 +114,31 @@ def process_params(expr, params, block_size, sparsity_threshold, layout):
)
else:
sparse_weight_data = sparse_weight.data
elif weight_kernel[0] == kernel_size == 3:
if layout == "NHWC": # HWIO
w_np = w_np.reshape((-1, w_np.shape[-1])).T
elif layout == "NCHW": # OIHW
w_np = w_np.reshape((w_np.shape[0], -1))
sparse_weight = sp.bsr_matrix(w_np, blocksize=block_size)
if 1 - (sparse_weight.nnz / w_np.size) < sparsity_threshold:
continue
sparse_weight_data = sparse_weight.data
else:
continue

# remove dense weight
del params[name]
memo.weight_name.append(name)
memo.weight_shape.append(
list(sparse_weight_data.shape)
+ list(sparse_weight.indices.shape)
+ list(sparse_weight.indptr.shape)
)
params[name + ".data"] = tvm.nd.array(sparse_weight_data)
params[name + ".indices"] = tvm.nd.array(sparse_weight.indices)
params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr)

# remove dense weight
del params[name]
memo.weight_name.append(name)
memo.weight_shape.append(
list(sparse_weight_data.shape)
+ list(sparse_weight.indices.shape)
+ list(sparse_weight.indptr.shape)
)
params[name + ".data"] = tvm.nd.array(sparse_weight_data)
params[name + ".indices"] = tvm.nd.array(sparse_weight.indices)
params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr)

if reg_task_input:
prefix = "sparse_conv2d_bsr_%d_%d_%d_%d_%d_%d_" % (
w_np.shape[0],
w_np.shape[1],
44 changes: 40 additions & 4 deletions python/tvm/relay/data_dep_optimization/bsr_conv2d.py
Original file line number Diff line number Diff line change
@@ -23,8 +23,8 @@
from .utils import _run_opt_pass


def convert(func, params, blocksize, sparsity_threshold, layout="NHWC"):
"""Convert a dense func and according parameters to block sparse
def convert(func, params, blocksize, sparsity_threshold, layout="NHWC", kernel_size=1):
"""Convert a conv2d func and according parameters to block sparse
Parameters
----------
@@ -49,10 +49,46 @@ def convert(func, params, blocksize, sparsity_threshold, layout="NHWC"):
params: Dict[Srting, tvm.nd.array]
New params with BSR matrix for mutated Expr
"""
weight_info = process_params(func, params, blocksize, sparsity_threshold, layout)
weight_info = process_params(func, params, blocksize, sparsity_threshold, layout, kernel_size)
new_func = _run_opt_pass(
func,
relay.transform.Conv2dToSparse(weight_info.weight_name, weight_info.weight_shape, layout),
relay.transform.Conv2dToSparse(
weight_info.weight_name, weight_info.weight_shape, layout, kernel_size
),
)

return new_func, params


def convert2(func, params, blocksize, sparsity_threshold, layout, kernel_size):
"""Convert a freezed conv2d func to block sparse
Parameters
----------
func : relay.Expr
Expr will be optimized to sparse operation, with params freezed
params : Dict[Srting, tvm.nd.array]
Parameters of the Expr (not used in this pass)
blocksize : Tuple(int, int)
Blocksize for BSR matrix
sparsity_threshold : float
Minimal sparsity requirement for converting.
If weight sparsity is lower than this threshold,
the dense operation will be kept.
layout : str
layout of network
kernel_size : int
kernel size of the conv2d, for filtering
Returns
-------
new_func: relay.Expr
Mutated Expr with sparse operations
params: Dict[Srting, tvm.nd.array]
New params with BSR matrix for mutated Expr (not modified)
"""
new_func = _run_opt_pass(
func, relay.transform.Conv2dToSparse2(layout, kernel_size, blocksize, sparsity_threshold)
)
return new_func, params
6 changes: 5 additions & 1 deletion python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
@@ -198,7 +198,11 @@ def compute_sparse_transpose(attrs, inputs, out_type):
@reg.register_compute("nn.sparse_conv2d")
def compute_sparse_conv2d(attrs, inputs, out_type):
"""Compute definition of sparse_conv2d"""
return [topi.nn.sparse_conv2d(inputs[0], inputs[1], inputs[2], inputs[3], attrs["layout"])]
return [
topi.nn.sparse_conv2d(
inputs[0], inputs[1], inputs[2], inputs[3], attrs["layout"], attrs["kernel_size"]
)
]


reg.register_strategy("nn.sparse_conv2d", strategy.sparse_conv2d_strategy)
25 changes: 25 additions & 0 deletions python/tvm/relay/op/strategy/x86.py
Original file line number Diff line number Diff line change
@@ -565,6 +565,31 @@ def sparse_dense_strategy_cpu(attrs, inputs, out_type, target):
return strategy


@sparse_conv2d_strategy.register("cpu")
def sparse_conv2d_strategy_cpu(attrs, inputs, out_type, target):
"""sparse conv2d x86 strategy"""
strategy = _op.OpStrategy()
if attrs["kernel_size"][0] == 1:
strategy.add_implementation(
wrap_compute_sparse_conv2d(topi.nn.sparse_conv2d),
wrap_topi_schedule(topi.generic.schedule_sparse_conv2d),
name="sparse_conv2d.generic",
)
elif attrs["kernel_size"][0] == 3:
if attrs["layout"] == "NHWC":
strategy.add_implementation(
wrap_compute_sparse_conv2d(topi.x86.spconv2d_3x3_nhwc),
wrap_topi_schedule(topi.x86.schedule_spconv2d_3x3_nhwc),
name="conv3x3_spNHWC.x86",
)
elif attrs["layout"] == "NCHW":
strategy.add_implementation(
wrap_compute_sparse_conv2d(topi.x86.spconv2d_3x3_nchw),
wrap_topi_schedule(topi.x86.schedule_spconv2d_3x3_nchw),
)
return strategy


@roi_align_strategy.register("cpu")
def roi_align_strategy_cpu(attrs, inputs, out_type, target):
"""roi_align x86 strategy"""
24 changes: 22 additions & 2 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
@@ -1093,7 +1093,7 @@ def DenseToSparse(weight_name, weight_shape):
return _ffi_api.DenseToSparse(weight_name, weight_shape)


def Conv2dToSparse(weight_name, weight_shape, layout):
def Conv2dToSparse(weight_name, weight_shape, layout, kernel_size):
"""
Rewrite qualified ```nn.conv2d operation``` to ```nn.sparse_conv2d```
@@ -1113,7 +1113,27 @@ def Conv2dToSparse(weight_name, weight_shape, layout):
ret : tvm.transform.Pass
The registered DenseToSparse pass.
"""
return _ffi_api.Conv2dToSparse(weight_name, weight_shape, layout)
return _ffi_api.Conv2dToSparse(weight_name, weight_shape, layout, kernel_size)


def Conv2dToSparse2(layout, kernel_size, blocksize, sparsity_threshold):
"""
Rewrite freezed ```nn.conv2d``` operation to ```nn.sparse_conv2d```
Parameters
----------
layout : str
layout of data
kernel_size : int
kernel size of conv2d
Returns
-------
ret : tvm.transform.Pass
The registered DenseToSparse pass.
"""
return _ffi_api.Conv2dToSparse2(layout, kernel_size, *blocksize, sparsity_threshold)


def SimplifyFCTranspose(target_weight_name):
21 changes: 12 additions & 9 deletions python/tvm/topi/nn/sparse.py
Original file line number Diff line number Diff line change
@@ -566,7 +566,9 @@ def _compute_block(i, nb_j, j, h, w): # pylint: disable=C0103
)


def sparse_conv2d(dense_data, sparse_data, sparse_indices, sparse_indptr, layout="NHWC"):
def sparse_conv2d(
dense_data, sparse_data, sparse_indices, sparse_indptr, layout="NHWC", kernel_size=1
):
"""
Computes sparse-conv2d(1*1) of ``data`` and
``(weight_data, weight_indices, weight_indptr)``
@@ -598,14 +600,15 @@ def sparse_conv2d(dense_data, sparse_data, sparse_indices, sparse_indptr, layout
4-D with shape [M, H, W, N] (layout=NHWC)
4-D with shape [M, N, H ,W] (layout=NCHW)
"""
if layout == "NHWC":
return _sparse_conv2d_bsr_compute_nhwc(
dense_data, sparse_data, sparse_indices, sparse_indptr
)
elif layout == "NCHW":
return _sparse_conv2d_bsr_compute_nchw(
dense_data, sparse_data, sparse_indices, sparse_indptr
)
if kernel_size == 1:
if layout == "NHWC":
return _sparse_conv2d_bsr_compute_nhwc(
dense_data, sparse_data, sparse_indices, sparse_indptr
)
elif layout == "NCHW":
return _sparse_conv2d_bsr_compute_nchw(
dense_data, sparse_data, sparse_indices, sparse_indptr
)
else:
raise ValueError("Unsupport Layout %s" % layout)

162 changes: 161 additions & 1 deletion python/tvm/topi/x86/sparse.py
Original file line number Diff line number Diff line change
@@ -16,8 +16,10 @@
# under the License.

"""sparse_dense schedule on x86"""
from tvm import te
from functools import partial, reduce
from tvm import te, tir, autotvm

from ..transform import reshape
from ..utils import traverse_inline, get_const_int
from .utils import get_fp32_len

@@ -60,3 +62,161 @@ def _callback(op):

traverse_inline(s, outs[0].op, _callback)
return s


@autotvm.register_topi_compute("conv3x3_spNHWC.x86")
def spconv2d_3x3_nhwc(cfg, data, wdat, wind, wptr, layout="NHWC"):
"""Sparse Conv2d 3x3 compute (NHWC)."""
assert layout == "NHWC"
nsamples, imh, imw, chanin = [i.value for i in data.shape]
nelems, bsrr, bsrc = [i.value for i in wdat.shape]
chanout = (wptr.shape[0].value - 1) * bsrr

imglen, chanlen = nsamples * imh * imw, 9 * chanin
cfg.define_split("tile_y", imglen, num_outputs=3)
cfg.define_split("tile_x", chanout // bsrr, num_outputs=2)
cfg.add_flop(imglen * (nelems * bsrc * bsrr * 2 - chanout))
if cfg.is_fallback:
cfg["tile_y"] = autotvm.task.space.SplitEntity([-1, 160, 8])
cfg["tile_x"] = autotvm.task.space.SplitEntity([-1, 4])

idxsplit = lambda x, y: reduce(lambda a, b: a[:-1] + [a[-1] % b, a[-1] // b], y, [x])

@partial(te.compute, (imglen, chanlen), name="Im2Col")
def im2col(row, col):
j_w, j_h, j_n = idxsplit(row, [imw, imh])
j_c, k_w, k_h = idxsplit(col, [chanin, 3])
i_h, i_w = j_h + k_h - 1, j_w + k_w - 1
return tir.if_then_else(
tir.all(i_h >= 0, i_h < imh, i_w >= 0, i_w < imw), data[j_n, i_h, i_w, j_c], 0
)

@partial(te.compute, (imglen, chanout // bsrr, bsrr, bsrc), name="CC")
def matmul(drow, wrow, brow, bcol):
row_start, row_end = wptr[wrow], wptr[wrow + 1]
elem_idx = te.reduce_axis((0, row_end - row_start), name="elem_idx")
elem = row_start + elem_idx
return te.sum(
im2col[drow, wind[elem] * bsrc + bcol] * wdat[elem, brow, bcol], axis=elem_idx
)

sum_bsrc = te.reduce_axis((0, bsrc), name="k")
ret = te.compute(
(imglen, chanout),
lambda y, x: te.sum(matmul[y, x // bsrr, x % bsrr, sum_bsrc], axis=sum_bsrc),
name="C",
tag="conv3x3_spNHWC",
)
return reshape(ret, (nsamples, imh, imw, chanout))


@autotvm.register_topi_schedule("conv3x3_spNHWC.x86")
def schedule_spconv2d_3x3_nhwc(cfg, outs):
"""Sparse Conv2d 3x3 schedule (NHWC)."""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if op.tag == "conv3x3_spNHWC":
(matmul,) = op.input_tensors
# wptr, wind, im2col, wdat
_, _, im2col, _ = matmul.op.input_tensors
(data,) = im2col.op.input_tensors
bsrr = matmul.shape[-2].value
chanin = data.shape[-1].value

mm_y, mm_x = s[op].op.axis
y_t, y_o, y_i = cfg["tile_y"].apply(s, op, mm_y)
x_o, x_i = s[op].split(mm_x, factor=bsrr)
x_t, x_o = cfg["tile_x"].apply(s, op, x_o)
(sum_ax,) = s[op].op.reduce_axis
s[op].reorder(y_t, x_t, y_o, x_o, y_i, x_i, sum_ax)
s[op].unroll(sum_ax)
s[op].vectorize(x_i)
s[op].unroll(y_i)

s[matmul].compute_at(s[op], x_o)
y_i, x_i, bsrr, bsrc = s[matmul].op.axis
(sum_ax,) = s[matmul].op.reduce_axis
s[matmul].reorder(x_i, sum_ax, y_i, bsrr, bsrc)
s[matmul].unroll(bsrc)
s[matmul].vectorize(bsrr)
s[matmul].unroll(y_i)

s[im2col].compute_at(s[op], y_o)
y_i, sum_ax = s[im2col].op.axis
_, k_i = s[im2col].split(sum_ax, factor=chanin)
s[im2col].vectorize(k_i)

traverse_inline(s, outs[0].op, _callback)
return s


@autotvm.register_topi_compute("conv3x3_spNCHW.x86")
def spconv2d_3x3_nchw(cfg, data, wdat, wind, wptr, layout="NCHW"):
"""Sparse Conv2d 3x3 compute (NCHW)."""
nsamples, chanin, imgh, imgw = [i.value for i in data.shape]
nelems, veclen, bsrc = [i.value for i in wdat.shape]
chanout = (wptr.shape[0].value - 1) * veclen
assert bsrc == 1 and layout == "NCHW"

cfg.add_flop(nsamples * imgh * imgw * (nelems * veclen * bsrc * 2 - chanout))
cfg.define_split("tile_hw", imgh * imgw, num_outputs=3)
cfg.define_split("tile_ckk", chanin * 9, num_outputs=3)

@partial(te.compute, (nsamples, chanin * 3 * 3, imgh * imgw), name="im2col")
def im2col(nsamples, ckk, imglen):
j_h, j_w = imglen // imgw, imglen % imgw
i_c, k_h, k_w = ckk // 9, ckk // 3 % 3, ckk % 3
i_h, i_w = j_h + k_h - 1, j_w + k_w - 1
return tir.if_then_else(
tir.all(i_h >= 0, i_h < imgh, i_w >= 0, i_w < imgw), data[nsamples, i_c, i_h, i_w], 0
)

@partial(
te.compute,
(nsamples, chanout // veclen, veclen, bsrc, imgh * imgw),
name="CC",
tag="conv3x3_spNCHW",
)
def matmul(nsamples, f_o, f_i, bsrk, imglen):
row_start, row_end = wptr[f_o], wptr[f_o + 1]
elem_idx = te.reduce_axis((0, row_end - row_start), name="elem_idx")
elem = row_start + elem_idx
return te.sum(
im2col[nsamples, wind[elem] * bsrc + bsrk, imglen] * wdat[elem, f_i, bsrk],
axis=elem_idx,
)

return reshape(matmul, [nsamples, chanout, imgh, imgw])


@autotvm.register_topi_schedule("conv3x3_spNCHW.x86")
def schedule_spconv2d_3x3_nchw(cfg, outs):
"""Sparse Conv2d 3x3 schedule (NCHW)."""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if op.tag == "conv3x3_spNCHW":
# wptr, wind, im2col, wdat
_, _, im2col, _ = op.input_tensors

n_samples, f_o, f_i, b_c, imglen = s[op].op.axis
(sum_ax,) = s[op].op.reduce_axis
hw1, hw2, hw3 = cfg["tile_hw"].apply(s, op, imglen)
s[op].reorder(n_samples, hw1, f_o, hw2, sum_ax, f_i, b_c, hw3)
s[op].unroll(f_i)
s[op].unroll(b_c)
s[op].vectorize(hw3)

s[im2col].compute_at(s[op], hw1)
n_samples, ckk, imglen = s[im2col].op.axis
ckk1, ckk2, ckk3 = cfg["tile_ckk"].apply(s, im2col, ckk)
hw2, hw3 = s[im2col].split(imglen, factor=cfg["tile_hw"].size[-1])
s[im2col].reorder(n_samples, ckk1, ckk2, hw2, ckk3, hw3)
s[im2col].unroll(ckk3)
s[im2col].vectorize(hw3)

traverse_inline(s, outs[0].op, _callback)
return s
3 changes: 2 additions & 1 deletion src/relay/op/nn/sparse.cc
Original file line number Diff line number Diff line change
@@ -274,10 +274,11 @@ bool SparseConv2dRel(const Array<Type>& types, int num_inputs, const Attrs& attr
}

Expr MakeSparseConv2d(Expr data, Expr weight_data, Expr weight_indices, Expr weight_indptr,
std::string layout) {
std::string layout, Array<IndexExpr> kernel_size) {
static const Op& op = Op::Get("nn.sparse_conv2d");
auto attrs = make_object<SparseConv2DAttrs>();
attrs->layout = std::move(layout);
attrs->kernel_size = std::move(kernel_size);
return Call(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {});
}

173 changes: 168 additions & 5 deletions src/relay/transforms/convert_sparse_conv2d.cc
Original file line number Diff line number Diff line change
@@ -73,10 +73,12 @@ TVM_REGISTER_GLOBAL("relay.analysis.search_conv2d_op_weight").set_body_typed(Sea
class Conv2dToSparseConv2dMutator : public ExprRewriter {
public:
Conv2dToSparseConv2dMutator(const Array<ObjectRef>& weight_name,
const Array<Array<PrimExpr>>& weight_shape, const String& layout)
const Array<Array<PrimExpr>>& weight_shape, const String& layout,
int kernel_size)
: conv2d_op_(Op::Get("nn.conv2d")), sparse_conv2d_op_(Op::Get("nn.sparse_conv2d")) {
ICHECK_EQ(weight_name.size(), weight_shape.size());
layout_ = layout;
kernel_size_ = kernel_size;
for (size_t i = 0; i < weight_name.size(); ++i) {
ICHECK(weight_name[i]->IsInstance<runtime::StringObj>());
std::string k = weight_name[i].as<runtime::StringObj>()->data;
@@ -112,6 +114,7 @@ class Conv2dToSparseConv2dMutator : public ExprRewriter {
Var weight_indptr(prefix + ".indptr", ws_indptr_type);
auto attrs = make_object<SparseConv2DAttrs>();
attrs->layout = std::move(layout_);
attrs->kernel_size = Array<IndexExpr>{kernel_size_, kernel_size_};
return Call(sparse_conv2d_op_, {data, weight_data, weight_indices, weight_indptr},
Attrs(attrs));
}
@@ -126,22 +129,168 @@ class Conv2dToSparseConv2dMutator : public ExprRewriter {
const Op& sparse_conv2d_op_;
std::unordered_map<std::string, std::vector<int>> target_weights_;
String layout_;
int kernel_size_;
}; // class Conv2dToSparseConv2dAlter

Expr Conv2dToSparse(const Expr& e, const Array<ObjectRef>& weight_name,
const Array<Array<PrimExpr>>& weight_shape, const String& layout) {
auto rewriter = Conv2dToSparseConv2dMutator(weight_name, weight_shape, layout);
const Array<Array<PrimExpr>>& weight_shape, const String& layout,
int kernel_size) {
auto rewriter = Conv2dToSparseConv2dMutator(weight_name, weight_shape, layout, kernel_size);
return PostOrderRewrite(e, &rewriter);
}

template <typename elemTy, size_t... Is>
auto unpack_to_tuple_internal(elemTy* arr, std::index_sequence<Is...>) {
return std::make_tuple(arr[Is]...);
}

template <int N, typename elemTy>
auto unpack_to_tuple(elemTy* arr) {
return unpack_to_tuple_internal(arr, std::make_index_sequence<N>{});
}

struct Range {
size_t dim;
explicit Range(size_t d) : dim(d) {}

struct iterpoint {
size_t val, lim;
iterpoint(size_t v1, size_t v2) : val(v1), lim(v2) {}

size_t operator*() const { return val; }

iterpoint operator/(const iterpoint& rhs) const {
return iterpoint(val * rhs.lim + rhs.val, lim * rhs.lim);
}
};

struct iterator {
size_t val, lim;
iterator(size_t v1, size_t v2) : val(v1), lim(v2) {}

bool operator!=(const iterator& rhs) const { return val != rhs.val; }

void operator++() { ++val; }

iterpoint operator*() const { return iterpoint(val, lim); }
};

iterator begin() { return iterator(0, dim); }

iterator end() { return iterator(dim, dim); }
};

// Mutate ```nn.conv2d``` to ```nn.sparse_conv2d```
class Conv2dToSparseConv2dMutator2 : public ExprRewriter {
public:
Conv2dToSparseConv2dMutator2(const String& layout, int kernel_size, int blockH, int blockW,
double sparse_thresh)
: sparse_conv2d_op_(Op::Get("nn.sparse_conv2d")),
dev_cpu0_{DLDeviceType::kDLCPU, 0},
layout_(layout),
kernel_size_(kernel_size),
blockH_(blockH),
blockW_(blockW),
sparse_thresh_(sparse_thresh) {}

Expr Rewrite_(const CallNode* pre, const Expr& post) override {
// check op type & attrs
const auto pre_attrs = pre->attrs.as<Conv2DAttrs>();
if (!pre_attrs || pre_attrs->data_layout != layout_ ||
pre_attrs->strides[0].as<IntImmNode>()->value != 1 ||
pre_attrs->kernel_size[0].as<IntImmNode>()->value != kernel_size_)
return post;
// check constant weight
const auto pre_weight_node = pre->args[1].as<ConstantNode>();
if (!pre_weight_node) return post;

// check weight dtype & shape
auto&& pre_weight = pre_weight_node->data;
auto dtype = pre_weight.DataType(), itype = runtime::DataType::Int(32);
ICHECK(dtype.code() == DataType::kFloat && dtype.bits() == 32); // float32 only
auto pre_weight_shape = unpack_to_tuple<4>(pre_weight.Shape().data());
int O, I, H, W;
if (layout_ == "NCHW") {
std::tie(O, I, H, W) = pre_weight_shape;
} else { // NHWC
std::tie(H, W, I, O) = pre_weight_shape;
}
int CO = O, CI = H * W * I;

// copy to vector
std::vector<float> pre_weight_data(CO * CI);
pre_weight.CopyToBytes(pre_weight_data.data(), pre_weight_data.size() * sizeof(float));
if (layout_ == "NHWC") {
std::vector<float> tmp(pre_weight_data.size());
for (auto i : Range(CO))
for (auto j : Range(CI)) tmp[*(i / j)] = pre_weight_data[*(j / i)];
std::swap(tmp, pre_weight_data);
}
// convert to BSR
std::vector<float> wdata, block(blockH_ * blockW_);
std::vector<int32_t> windices, windptr;
for (auto bh : Range(CO / blockH_)) {
windptr.push_back(windices.size());
for (auto bw : Range(CI / blockW_)) {
int cntnnz = 0;
for (auto i : Range(blockH_))
for (auto j : Range(blockW_)) {
auto tmp = pre_weight_data[*(bh / i / bw / j)];
if (tmp) cntnnz++;
block[*(i / j)] = tmp;
}
if (cntnnz) {
wdata.insert(wdata.end(), block.begin(), block.end());
windices.push_back(*bw);
}
}
}
windptr.push_back(windices.size());
double sprate = 1 - 1.0 * wdata.size() / pre_weight_data.size();
if (sprate < sparse_thresh_) return post;

// constrct return data
int nnz = windices.size();
auto weight_data = runtime::NDArray::Empty({nnz, blockH_, blockW_}, dtype, dev_cpu0_);
auto weight_indices = runtime::NDArray::Empty({nnz}, itype, dev_cpu0_);
auto weight_indptr = runtime::NDArray::Empty({CO / blockH_ + 1}, itype, dev_cpu0_);
weight_data.CopyFromBytes(wdata.data(), wdata.size() * sizeof(float));
weight_indices.CopyFromBytes(windices.data(), windices.size() * sizeof(int32_t));
weight_indptr.CopyFromBytes(windptr.data(), windptr.size() * sizeof(int32_t));

// construct return call
auto args = runtime::Array<relay::Expr>{post.as<CallNode>()->args[0], Constant(weight_data),
Constant(weight_indices), Constant(weight_indptr)};
auto attrs = make_object<SparseConv2DAttrs>();
attrs->layout = layout_;
attrs->kernel_size = Array<IndexExpr>{kernel_size_, kernel_size_};
return Call(sparse_conv2d_op_, args, Attrs(attrs));
}

private:
const Op& sparse_conv2d_op_;
DLDevice dev_cpu0_;
String layout_;
int kernel_size_, blockH_, blockW_;
double sparse_thresh_;
}; // class Conv2dToSparseConv2dMutator2

Expr Conv2dToSparse2(const Expr& e, const String& layout, int kernel_size, int blockH, int blockW,
double sparse_thresh) {
auto rewriter = Conv2dToSparseConv2dMutator2(layout, kernel_size, blockH, blockW, sparse_thresh);
return PostOrderRewrite(e, &rewriter);
}

namespace transform {

// Convert a model with seperate weight info (already sparsified).
Pass Conv2dToSparse(const Array<ObjectRef>& weight_name, const Array<Array<PrimExpr>>& weight_shape,
const String& layout) {
const String& layout, int kernel_size) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
// Remove FreeVar warnings
auto f0 = Downcast<Function>(Conv2dToSparse(f, weight_name, weight_shape, layout));
auto f0 =
Downcast<Function>(Conv2dToSparse(f, weight_name, weight_shape, layout, kernel_size));
Array<Var> sparse_params = FreeVars(f0);
auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs);
Array<Var> params = FreeVars(f1);
@@ -155,6 +304,20 @@ Pass Conv2dToSparse(const Array<ObjectRef>& weight_name, const Array<Array<PrimE

TVM_REGISTER_GLOBAL("relay._transform.Conv2dToSparse").set_body_typed(Conv2dToSparse);

// Convert a model with freezed params (sparsified in the pass).
Pass Conv2dToSparse2(const String& layout, int kernel_size, int blockH, int blockW,
double sparse_thresh) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
auto f0 = Downcast<Function>(
Conv2dToSparse2(f, layout, kernel_size, blockH, blockW, sparse_thresh));
return f0;
};
return CreateFunctionPass(pass_func, 5, "Conv2dToSparse2", {"DeadCodeElimination"});
}

TVM_REGISTER_GLOBAL("relay._transform.Conv2dToSparse2").set_body_typed(Conv2dToSparse2);

} // namespace transform

} // namespace relay
63 changes: 63 additions & 0 deletions tests/python/relay/test_sparse_conv2d_convert.py
Original file line number Diff line number Diff line change
@@ -25,6 +25,7 @@
from tvm.ir import IRModule
from tvm import relay
from tvm.topi.sparse.utils import random_bsr_matrix
from tvm.relay.build_module import bind_params_by_name


def run_func(func, params, x):
@@ -100,6 +101,68 @@ def test_bsr_sparse_conv2d_nhwc():
np.testing.assert_allclose(sparse_output, dense_output, atol=1e-5, rtol=1e-5)


def test_bsr_sparse_conv2d_3x3_nchw():
data = relay.var("data", shape=(1, 64, 32, 32), dtype="float32")
x = relay.nn.relu(data)
w = relay.var("weight", shape=(128, 64, 3, 3), dtype="float32")
y = relay.nn.conv2d(
x, w, channels=128, kernel_size=3, padding=1, data_layout="NCHW", kernel_layout="OIHW"
)
z = relay.nn.relu(y)
func = relay.Function(relay.analysis.free_vars(z), z)

params = {
"weight": tvm.nd.array(
np.array(random_bsr_matrix(128, 64 * 9, 16, 1, 0.1, "float32").todense()).reshape(
128, 64, 3, 3
)
)
}

x_np = np.random.randn(1, 64, 32, 32).astype("float32")
# dense output
dense_output = run_func(func, params, x_np)
# sparse
func = bind_params_by_name(func, params)
sparse_func, params = relay.data_dep_optimization.bsr_conv2d.convert2(
func, {}, (16, 1), 0.2, "NCHW", 3
)
sparse_output = run_func(sparse_func, params, x_np)
np.testing.assert_allclose(sparse_output, dense_output, atol=1e-5, rtol=1e-5)


def test_bsr_sparse_conv2d_3x3_nhwc():
data = relay.var("data", shape=(1, 32, 32, 64), dtype="float32")
x = relay.nn.relu(data)
w = relay.var("weight", shape=(3, 3, 64, 128), dtype="float32")
y = relay.nn.conv2d(
x, w, channels=128, kernel_size=3, padding=1, data_layout="NHWC", kernel_layout="HWIO"
)
z = relay.nn.relu(y)
func = relay.Function(relay.analysis.free_vars(z), z)

params = {
"weight": tvm.nd.array(
np.array(random_bsr_matrix(128, 64 * 9, 16, 1, 0.1, "float32").todense()).T.reshape(
3, 3, 64, 128
)
)
}

x_np = np.random.randn(1, 32, 32, 64).astype("float32")
# dense output
dense_output = run_func(func, params, x_np)
# sparse
func = bind_params_by_name(func, params)
sparse_func, params = relay.data_dep_optimization.bsr_conv2d.convert2(
func, {}, (16, 1), 0.2, "NHWC", 3
)
sparse_output = run_func(sparse_func, params, x_np)
np.testing.assert_allclose(sparse_output, dense_output, atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
test_bsr_sparse_conv2d_nhwc()
test_bsr_sparse_conv2d_nchw()
test_bsr_sparse_conv2d_3x3_nhwc()
test_bsr_sparse_conv2d_3x3_nchw()

0 comments on commit 934b4e5

Please sign in to comment.