[TOPI] [Relay] Sparse Conv2d Implementation for 3x3 kernels #8605

Merged: 22 commits, Aug 27, 2021
Changes from 6 commits
3 changes: 3 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -1047,12 +1047,15 @@ struct SparseTransposeAttrs : public tvm::AttrsNode<SparseTransposeAttrs> {
/*! \brief Attributes for sparse_dense operator */
struct SparseConv2DAttrs : public tvm::AttrsNode<SparseConv2DAttrs> {
  std::string layout;
  int kernel_size;

  TVM_DECLARE_ATTRS(SparseConv2DAttrs, "relay.attrs.SparseConv2DAttrs") {
    TVM_ATTR_FIELD(layout).set_default("NHWC").describe(
        "Dimension ordering of input data. Can be 'NCHW', 'NHWC'"
        "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
        "dimensions respectively.");
    TVM_ATTR_FIELD(kernel_size).set_default(1).describe(
        "Kernel size for SparseConv2D, 1x1 or 3x3. ");
  }
};

15 changes: 8 additions & 7 deletions python/tvm/autotvm/measure/measure_methods.py
@@ -254,13 +254,14 @@ def ref_input(self):

    @ref_input.setter
    def ref_input(self, val):
        warnings.warn(
            "You are specifying fixed input for tuning the operator. "
            "Be sure your input always fits the operator. Some "
            "operators may conduct layout transformation during tuning, "
            "thus can lead to unexpected behaviors. ",
            RuntimeWarning,
        )
        if val is not None:
            warnings.warn(
                "You are specifying fixed input for tuning the operator. "
                "Be sure your input always fits the operator. Some "
                "operators may conduct layout transformation during tuning, "
                "thus can lead to unexpected behaviors. ",
                RuntimeWarning,
            )
        self._ref_input = val

    def set_task(self, task):
53 changes: 33 additions & 20 deletions python/tvm/relay/analysis/sparse_conv2d.py
@@ -54,7 +54,7 @@ def _search_conv2d_op_weight(expr):
    return _ffi_api.search_conv2d_op_weight(expr)


def process_params(expr, params, block_size, sparsity_threshold, layout):
def process_params(expr, params, block_size, sparsity_threshold, layout, kernel_size, reg_task_input=True):
"""Process parameters of conv2d from dense to sparse.

Parameters
@@ -86,14 +86,18 @@ def process_params(expr, params, block_size, sparsity_threshold, layout):
    for name in weight_names:
        name = str(name)
        w_np = params[name].numpy()
        # currently only support conv2d_1*1
        if not (
            (w_np.shape[0] == 1 and w_np.shape[1] == 1)
            or (w_np.shape[2] == 1 and w_np.shape[3] == 1)
        ):

        if layout == "NHWC":  # HWIO
            weight_kernel = (w_np.shape[0], w_np.shape[1])
        elif layout == "NCHW":  # OIHW
            weight_kernel = (w_np.shape[2], w_np.shape[3])
        if weight_kernel[0] != weight_kernel[1]:
            continue
        sparsity = 1.0 - (np.count_nonzero(w_np) / w_np.size)
        if sparsity >= sparsity_threshold:

        if weight_kernel[0] == kernel_size == 1:
            sparsity = 1.0 - (np.count_nonzero(w_np) / w_np.size)
            if sparsity < sparsity_threshold:
                continue
            if layout == "NHWC":
                w_np = w_np.squeeze().T
            elif layout == "NCHW":
@@ -108,19 +112,28 @@ def process_params(expr, params, block_size, sparsity_threshold, layout)
                )
            else:
                sparse_weight_data = sparse_weight.data
        elif weight_kernel[0] == kernel_size == 3 and layout == "NHWC":
            w_np = w_np.reshape((-1, w_np.shape[-1])).T
            sparse_weight = sp.bsr_matrix(w_np, blocksize=block_size)
            if 1 - (sparse_weight.nnz / w_np.size) < sparsity_threshold:
                continue
            sparse_weight_data = sparse_weight.data
        else:
            continue

            # remove dense weight
            del params[name]
            memo.weight_name.append(name)
            memo.weight_shape.append(
                list(sparse_weight_data.shape)
                + list(sparse_weight.indices.shape)
                + list(sparse_weight.indptr.shape)
            )
            params[name + ".data"] = tvm.nd.array(sparse_weight_data)
            params[name + ".indices"] = tvm.nd.array(sparse_weight.indices)
            params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr)

        # remove dense weight
        del params[name]
        memo.weight_name.append(name)
        memo.weight_shape.append(
            list(sparse_weight_data.shape)
            + list(sparse_weight.indices.shape)
            + list(sparse_weight.indptr.shape)
        )
        params[name + ".data"] = tvm.nd.array(sparse_weight_data)
        params[name + ".indices"] = tvm.nd.array(sparse_weight.indices)
        params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr)

        if reg_task_input:
            prefix = "sparse_conv2d_bsr_%d_%d_%d_%d_%d_%d_" % (
                w_np.shape[0],
                w_np.shape[1],
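For reference, a minimal standalone sketch of what the new 3x3 NHWC branch above does to a weight tensor before it is stored back into params as name + ".data"/".indices"/".indptr". The channel counts, block shape and synthetic sparsity pattern below are illustrative, not taken from this PR:

import numpy as np
import scipy.sparse as sp

CI, CO, block_size = 16, 32, (16, 1)
w_np = np.random.rand(3, 3, CI, CO).astype("float32")   # HWIO kernel
w_np[np.random.rand(3, 3, CI) < 0.8, :] = 0.0           # zero whole taps so entire blocks become empty

w_flat = w_np.reshape((-1, w_np.shape[-1])).T            # (CO, 9*CI), as in the diff
bsr = sp.bsr_matrix(w_flat, blocksize=block_size)

print(bsr.data.shape)     # (num_nonzero_blocks, 16, 1) -> params[name + ".data"]
print(bsr.indices.shape)  # (num_nonzero_blocks,)       -> params[name + ".indices"]
print(bsr.indptr.shape)   # (CO // 16 + 1,)             -> params[name + ".indptr"]
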
6 changes: 3 additions & 3 deletions python/tvm/relay/data_dep_optimization/bsr_conv2d.py
@@ -23,7 +23,7 @@
from .utils import _run_opt_pass


def convert(func, params, blocksize, sparsity_threshold, layout="NHWC"):
def convert(func, params, blocksize, sparsity_threshold, layout="NHWC", kernel_size=1):
"""Convert a dense func and according parameters to block sparse

    Parameters
@@ -49,10 +49,10 @@ def convert(func, params, blocksize, sparsity_threshold, layout="NHWC"):
    params: Dict[String, tvm.nd.array]
        New params with BSR matrix for mutated Expr
    """
    weight_info = process_params(func, params, blocksize, sparsity_threshold, layout)
    weight_info = process_params(func, params, blocksize, sparsity_threshold, layout, kernel_size)
    new_func = _run_opt_pass(
        func,
        relay.transform.Conv2dToSparse(weight_info.weight_name, weight_info.weight_shape, layout),
        relay.transform.Conv2dToSparse(weight_info.weight_name, weight_info.weight_shape, layout, kernel_size),
    )

    return new_func, params
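
For reference, a hedged end-to-end sketch of how the new kernel_size argument might be used; the network, shapes, block size and threshold below are placeholders, not values taken from this PR:

import numpy as np
import tvm
from tvm import relay

# A single NHWC 3x3 conv2d whose weight is artificially sparsified.
data = relay.var("data", shape=(1, 32, 32, 16), dtype="float32")
weight = relay.var("weight", shape=(3, 3, 16, 32), dtype="float32")  # HWIO
out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=32,
                      data_layout="NHWC", kernel_layout="HWIO")
func = relay.Function(relay.analysis.free_vars(out), out)

w_np = np.random.rand(3, 3, 16, 32).astype("float32")
w_np[np.random.rand(3, 3, 16) < 0.9, :] = 0.0            # synthetic structured sparsity
params = {"weight": tvm.nd.array(w_np)}

# Rewrite the conv2d to nn.sparse_conv2d and replace "weight" with its BSR triple.
func, params = relay.data_dep_optimization.bsr_conv2d.convert(
    func, params, (16, 1), 0.7, layout="NHWC", kernel_size=3
)
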
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/_nn.py
@@ -198,7 +198,7 @@ def compute_sparse_transpose(attrs, inputs, out_type):
@reg.register_compute("nn.sparse_conv2d")
def compute_sparse_conv2d(attrs, inputs, out_type):
"""Compute definition of sparse_conv2d"""
return [topi.nn.sparse_conv2d(inputs[0], inputs[1], inputs[2], inputs[3], attrs["layout"])]
return [topi.nn.sparse_conv2d(inputs[0], inputs[1], inputs[2], inputs[3], attrs["layout"], attrs["kernel_size"])]


reg.register_strategy("nn.sparse_conv2d", strategy.sparse_conv2d_strategy)
25 changes: 25 additions & 0 deletions python/tvm/relay/op/strategy/x86.py
@@ -565,6 +565,31 @@ def sparse_dense_strategy_cpu(attrs, inputs, out_type, target):
    return strategy


@sparse_conv2d_strategy.register("cpu")
def sparse_conv2d_strategy_cpu(attrs, inputs, out_type, target):
    """sparse conv2d x86 strategy"""
    strategy = _op.OpStrategy()
    if attrs["kernel_size"] == 1:
        strategy.add_implementation(
            wrap_compute_sparse_conv2d(topi.nn.sparse_conv2d),
            wrap_topi_schedule(topi.generic.schedule_sparse_conv2d),
            name="sparse_conv2d.generic",
        )
    elif attrs["kernel_size"] == 3:
        if attrs["layout"] == "NHWC":
            strategy.add_implementation(
                wrap_compute_sparse_conv2d(topi.x86.spconv2d_3x3_nhwc),
                wrap_topi_schedule(topi.x86.schedule_spconv2d_3x3_nhwc),
                name="conv3x3_spNHWC.x86",
            )
        elif attrs["layout"] == "NCHW":
            strategy.add_implementation(
                wrap_compute_sparse_conv2d(topi.x86.spconv2d_3x3_nchw),
                wrap_topi_schedule(topi.x86.schedule_spconv2d_3x3_nchw),
            )
    return strategy


@roi_align_strategy.register("cpu")
def roi_align_strategy_cpu(attrs, inputs, out_type, target):
"""roi_align x86 strategy"""
4 changes: 2 additions & 2 deletions python/tvm/relay/transform/transform.py
@@ -1093,7 +1093,7 @@ def DenseToSparse(weight_name, weight_shape):
    return _ffi_api.DenseToSparse(weight_name, weight_shape)


def Conv2dToSparse(weight_name, weight_shape, layout):
def Conv2dToSparse(weight_name, weight_shape, layout, kernel_size):
"""
Rewrite qualified ```nn.conv2d operation``` to ```nn.sparse_conv2d```

@@ -1113,7 +1113,7 @@ def Conv2dToSparse(weight_name, weight_shape, layout):
    ret : tvm.transform.Pass
        The registered DenseToSparse pass.
    """
    return _ffi_api.Conv2dToSparse(weight_name, weight_shape, layout)
    return _ffi_api.Conv2dToSparse(weight_name, weight_shape, layout, kernel_size)


def SimplifyFCTranspose(target_weight_name):
19 changes: 10 additions & 9 deletions python/tvm/topi/nn/sparse.py
@@ -566,7 +566,7 @@ def _compute_block(i, nb_j, j, h, w): # pylint: disable=C0103
    )


def sparse_conv2d(dense_data, sparse_data, sparse_indices, sparse_indptr, layout="NHWC"):
def sparse_conv2d(dense_data, sparse_data, sparse_indices, sparse_indptr, layout="NHWC", kernel_size=1):
"""
Computes sparse-conv2d(1*1) of ``data`` and
``(weight_data, weight_indices, weight_indptr)``
@@ -598,14 +598,15 @@ def sparse_conv2d(dense_data, sparse_data, sparse_indices, sparse_indptr, layout
        4-D with shape [M, H, W, N] (layout=NHWC)
        4-D with shape [M, N, H, W] (layout=NCHW)
    """
if layout == "NHWC":
return _sparse_conv2d_bsr_compute_nhwc(
dense_data, sparse_data, sparse_indices, sparse_indptr
)
elif layout == "NCHW":
return _sparse_conv2d_bsr_compute_nchw(
dense_data, sparse_data, sparse_indices, sparse_indptr
)
if kernel_size == 1:
if layout == "NHWC":
return _sparse_conv2d_bsr_compute_nhwc(
dense_data, sparse_data, sparse_indices, sparse_indptr
)
elif layout == "NCHW":
return _sparse_conv2d_bsr_compute_nchw(
dense_data, sparse_data, sparse_indices, sparse_indptr
)
else:
raise ValueError("Unsupport Layout %s" % layout)

147 changes: 146 additions & 1 deletion python/tvm/topi/x86/sparse.py
@@ -16,8 +16,10 @@
# under the License.

"""sparse_dense schedule on x86"""
from tvm import te
from tvm import te, tir, autotvm
from functools import partial, reduce

from ..transform import reshape
from ..utils import traverse_inline, get_const_int
from .utils import get_fp32_len

@@ -60,3 +62,146 @@ def _callback(op):

    traverse_inline(s, outs[0].op, _callback)
    return s


@autotvm.register_topi_compute("conv3x3_spNHWC.x86")
def spconv2d_3x3_nhwc(cfg, Data, Wdat, Wind, Wptr, layout="NHWC"):
N, H, W, CI = [i.value for i in Data.shape]
nElems, bsrR, bsrC = [i.value for i in Wdat.shape]
CO = (Wptr.shape[0].value - 1) * bsrR

Y, X, K = N*H*W, CO, 9*CI
cfg.define_split("tile_y", Y, num_outputs=3)
cfg.define_split("tile_x", X // bsrR, num_outputs=2)
cfg.add_flop(Y * (nElems * bsrC * bsrR * 2 - X))
if cfg.is_fallback:
cfg["tile_y"] = autotvm.task.space.SplitEntity([-1, 160, 8])
cfg["tile_x"] = autotvm.task.space.SplitEntity([-1, 4])

idxsplit = lambda x,y: reduce(lambda a,b: a[:-1]+[a[-1]%b,a[-1]//b], y, [x])

@partial(te.compute, (Y, K), name="Im2Col")
def Im2Col(row, col):
jw, jh, jn = idxsplit(row, [W, H])
jc, kw, kh = idxsplit(col, [CI, 3])
ih, iw = jh + kh - 1, jw + kw - 1
return tir.if_then_else(
tir.all(0 <= ih, ih < H, 0 <= iw, iw < W),
Data[jn, ih, iw, jc], 0)

@partial(te.compute, (Y, X // bsrR, bsrR, bsrC), name="CC")
def CC(drow, wrow, brow, bcol):
row_start, row_end = Wptr[wrow], Wptr[wrow+1]
elem_idx = te.reduce_axis((0, row_end - row_start), name="elem_idx")
elem = row_start + elem_idx
return te.sum(Im2Col[drow, Wind[elem]*bsrC + bcol] * Wdat[elem, brow, bcol], axis=elem_idx)

k = te.reduce_axis((0, bsrC), name="k")
C = te.compute((Y, X),
lambda y, x: te.sum(CC[y, x // bsrR, x % bsrR, k], axis=k),
name="C", tag="conv3x3_spNHWC")
return reshape(C, (N, H, W, CO))
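
For reference, a NumPy/SciPy sketch (not part of this PR) of the computation spconv2d_3x3_nhwc expresses above: im2col with zero padding of 1, a BSR sparse-times-dense matmul against the packed weight, and a reshape back to NHWC. Shapes, block size and data are made up; the result matches a direct 3x3 convolution with padding 1.

import numpy as np
import scipy.sparse as sp

N, H, W, CI, CO = 1, 8, 8, 4, 8
data = np.random.rand(N, H, W, CI).astype("float32")
w_hwio = np.random.rand(3, 3, CI, CO).astype("float32")

# im2col: one row per output pixel, one column per (kh, kw, ci) tap, zero padded.
cols = np.zeros((N * H * W, 9 * CI), dtype="float32")
padded = np.pad(data, ((0, 0), (1, 1), (1, 1), (0, 0)))
for kh in range(3):
    for kw in range(3):
        patch = padded[:, kh:kh + H, kw:kw + W, :]                   # (N, H, W, CI)
        cols[:, (kh * 3 + kw) * CI:(kh * 3 + kw + 1) * CI] = patch.reshape(-1, CI)

# Pack the kernel the same way process_params does for 3x3 NHWC, then multiply.
w_bsr = sp.bsr_matrix(w_hwio.reshape(-1, CO).T, blocksize=(4, 1))    # (CO, 9*CI)
out_nhwc = np.asarray(w_bsr @ cols.T).T.reshape(N, H, W, CO)
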


@autotvm.register_topi_schedule("conv3x3_spNHWC.x86")
def schedule_spconv2d_3x3_nhwc(cfg, outs):
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if op.tag == "conv3x3_spNHWC":
C = op
CC, = op.input_tensors
Wptr, Wind, Im2Col, Wdat = CC.op.input_tensors
Data, = Im2Col.op.input_tensors
bsrR = CC.shape[-2].value
CI = Data.shape[-1].value

y, x = s[C].op.axis
yt, yo, yi = cfg["tile_y"].apply(s, C, y)
xo, xi = s[C].split(x, factor=bsrR)
xt, xo = cfg["tile_x"].apply(s, C, xo)
(k,) = s[C].op.reduce_axis
s[C].reorder(yt, xt, yo, xo, yi, xi, k)
s[C].unroll(k)
s[C].vectorize(xi)
s[C].unroll(yi)

s[CC].compute_at(s[C], xo)
yi, xi, r, c = s[CC].op.axis
(k,) = s[CC].op.reduce_axis
s[CC].reorder(xi, k, yi, r, c)
s[CC].unroll(c)
s[CC].vectorize(r)
s[CC].unroll(yi)

s[Im2Col].compute_at(s[C], yo)
yi, k = s[Im2Col].op.axis
ko, ki = s[Im2Col].split(k, factor=CI)
s[Im2Col].vectorize(ki)
#s[Im2Col].unroll(yi)

traverse_inline(s, outs[0].op, _callback)
return s


@autotvm.register_topi_compute("conv3x3_spNCHW.x86")
def spconv2d_3x3_nchw(cfg, Data, Wdat, Wind, Wptr, layout="NCHW"):
N, CI, H, W = [i.value for i in Data.shape]
NNZ, VL, bsrC = [i.value for i in Wdat.shape]
CO = (Wptr.shape[0].value - 1) * VL
assert bsrC == 1

cfg.add_flop(N*H*W * (NNZ * VL * bsrC * 2 - CO))
cfg.define_split("tile_hw", H*W, num_outputs=3)
cfg.define_split("tile_ckk", CI*9, num_outputs=3)

@partial(te.compute, (N, CI*3*3, H*W), name="im2col")
def Im2Col(n, ckk, hw):
jh, jw = hw // W, hw % W
ic, kh, kw = ckk // 9, ckk // 3 % 3, ckk % 3
ih, iw = jh + kh - 1, jw + kw - 1
return tir.if_then_else(
tir.all(0 <= ih, ih < H, 0 <= iw, iw < W),
Data[n, ic, ih, iw], 0)

@partial(te.compute, (N, CO // VL, VL, bsrC, H*W), name="CC", tag="conv3x3_spNCHW")
def CC(n, fo, fi, k, hw):
row_start, row_end = Wptr[fo], Wptr[fo+1]
elem_idx = te.reduce_axis((0, row_end - row_start), name="elem_idx")
elem = row_start + elem_idx
return te.sum(Im2Col[n, Wind[elem] * bsrC + k, hw] * Wdat[elem, fi, k],
axis=elem_idx)

return reshape(CC, [N, CO, H, W])
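
As a quick sanity check on the index arithmetic in Im2Col above: the flattened ckk axis is assumed to be ordered (ic, kh, kw), so the decode round-trips (the values below are arbitrary).

ic, kh, kw = 5, 2, 1
ckk = ic * 9 + kh * 3 + kw
assert (ckk // 9, ckk // 3 % 3, ckk % 3) == (ic, kh, kw)
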


@autotvm.register_topi_schedule("conv3x3_spNCHW.x86")
def schedule_spconv2d_3x3_nchw(cfg, outs):
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if op.tag == "conv3x3_spNCHW":
CC = op
Wptr, Wind, im2col, Wdat = op.input_tensors
Data, = im2col.op.input_tensors

n, fo, fi, bc, hw = s[CC].op.axis
kk, = s[CC].op.reduce_axis
hw1, hw2, hw3 = cfg["tile_hw"].apply(s, CC, hw)
s[CC].reorder(n, hw1, fo, hw2, kk, fi, bc, hw3)
s[CC].unroll(fi)
s[CC].unroll(bc)
s[CC].vectorize(hw3)

s[im2col].compute_at(s[CC], hw1)
n, ckk, hw = s[im2col].op.axis
ckk1, ckk2, ckk3 = cfg["tile_ckk"].apply(s, im2col, ckk)
hw2, hw3 = s[im2col].split(hw, factor=cfg["tile_hw"].size[-1])
s[im2col].reorder(n, ckk1, ckk2, hw2, ckk3, hw3)
s[im2col].unroll(ckk3)
s[im2col].vectorize(hw3)

traverse_inline(s, outs[0].op, _callback)
return s