[TOPI] Tunable Template for Conv2D HWCN on CUDA #4168

Merged: 3 commits, Oct 24, 2019 (diff below shows changes from 2 commits)
2 changes: 1 addition & 1 deletion python/tvm/autotvm/task/task.py
@@ -226,7 +226,7 @@ def args_to_workload(x, topi_compute_func=None):
     elif x is None:
         workload = 0
     else:
-        raise RuntimeError('Do not support type "%s" in argument. Consider to use'
+        raise RuntimeError('Do not support type "%s" in argument. Consider to use '
                            'primitive types only' % type(x))
     return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload

7 changes: 5 additions & 2 deletions python/tvm/autotvm/task/topi_integration.py
@@ -176,9 +176,12 @@ def _topi_nn_conv2d(*args, **kwargs):
     args = deserialize_args(args)
     A, W = args[:2]
     layout = args[-2]
-    assert layout == 'NCHW', "only support NCHW currently"
+    assert layout == 'NCHW' or layout == 'HWCN', "only support NCHW/HWCN currently"
     C = topi.nn.conv2d(*args, **kwargs)
-    s = topi.generic.schedule_conv2d_nchw([C])
+    if layout == 'NCHW':
+        s = topi.generic.schedule_conv2d_nchw([C])
+    else:
+        s = topi.generic.schedule_conv2d_hwcn([C])
     return s, [A, W, C]
 
 @register("topi_nn_depthwise_conv2d_nchw")
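For context, a rough sketch of how an HWCN workload reaches this template: Relay task extraction records a "topi_nn_conv2d" task whose layout argument is 'HWCN', and a tuner then drives the CUDA schedule registered later in this PR. The network, shapes, tuning budget, and log file name below are invented for illustration, and running it assumes a CUDA-enabled TVM build with a GPU available for measurements.

    import tvm
    from tvm import relay, autotvm

    # A single HWCN conv2d wrapped in a Relay function (shapes are illustrative).
    data = relay.var('data', shape=(224, 224, 3, 1))           # HWCN input
    weight = relay.var('weight', shape=(7, 7, 3, 64))          # HWIO kernel
    conv = relay.nn.conv2d(data, weight, kernel_size=(7, 7), channels=64,
                           strides=(2, 2), padding=(3, 3),
                           data_layout='HWCN', kernel_layout='HWIO')
    mod = relay.Module.from_expr(relay.Function([data, weight], conv))

    # Extraction now yields a "topi_nn_conv2d" task whose layout argument is 'HWCN'.
    tasks = autotvm.task.extract_from_program(mod['main'], target='cuda',
                                              params={}, ops=(relay.op.nn.conv2d,))
    for task in tasks:
        tuner = autotvm.tuner.XGBTuner(task)
        tuner.tune(n_trial=50,
                   measure_option=autotvm.measure_option(
                       builder=autotvm.LocalBuilder(),
                       runner=autotvm.LocalRunner(number=10, min_repeat_ms=100)),
                   callbacks=[autotvm.callback.log_to_file('conv2d_hwcn.log')])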
12 changes: 7 additions & 5 deletions python/tvm/relay/op/nn/_nn.py
@@ -153,14 +153,14 @@ def compute_conv2d(attrs, inputs, out_type, target):
     out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
                  else out_dtype)
 
-    assert layout in ["NCHW", "NHWC", "NCHW4c"]
+    assert layout in ["NCHW", "NHWC", "NCHW4c", "HWCN"]
     (dilation_h, dilation_w) = dilation
     if dilation_h < 1 or dilation_w < 1:
         raise ValueError("dilation should be positive value")
 
     def _get_out_depth():
         weight_shape = get_const_tuple(inputs[1].shape)
-        if kernel_layout == "HWOI":
+        if kernel_layout.startswith("HW"):
             return weight_shape[2] * weight_shape[3]
         return weight_shape[0] * weight_shape[1]
@@ -192,11 +192,13 @@ def schedule_conv2d(attrs, outs, target):
     with target:
         if groups == 1 and layout == "NCHW":
             return topi.generic.schedule_conv2d_nchw(outs)
-        if groups == 1 and layout == "NCHW4c":
+        elif groups == 1 and layout == "NCHW4c":
             return topi.generic.schedule_conv2d_nchw(outs)
-        if groups == 1 and layout == "NHWC":
+        elif groups == 1 and layout == "NHWC":
             return topi.generic.schedule_conv2d_nhwc(outs)
-        if groups != 1:
+        elif groups == 1 and layout == "HWCN":
+            return topi.generic.schedule_conv2d_hwcn(outs)
+        elif groups != 1:
             # collect in_channels to distinguish depthwise and group conv2d
             op = _find_conv2d_op(outs[0].op)
             assert op is not None
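A minimal compile-only sketch of the path these two hooks serve: compute_conv2d now accepts the HWCN layout and schedule_conv2d dispatches it to topi.generic.schedule_conv2d_hwcn. The shapes and target string are assumptions, and with no AutoTVM log loaded the schedule runs with its fallback configuration.

    import tvm
    from tvm import relay

    data = relay.var('data', shape=(56, 56, 64, 1), dtype='float32')     # HWCN
    weight = relay.var('weight', shape=(3, 3, 64, 64), dtype='float32')  # HWIO
    conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), channels=64,
                           padding=(1, 1),
                           data_layout='HWCN', kernel_layout='HWIO')
    func = relay.Function([data, weight], conv)

    with relay.build_config(opt_level=3):
        # compute_conv2d takes the HWCN branch; schedule_conv2d returns the
        # conv2d_hwcn schedule (fallback split when no tuning log is applied).
        graph, lib, params = relay.build(relay.Module.from_expr(func), target='cuda')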
4 changes: 2 additions & 2 deletions src/pass/vectorize_loop.cc
@@ -368,7 +368,7 @@ class Vectorizer : public IRMutator {
     CHECK(!op->extent.type().is_vector());
     Expr extent = Mutate(op->extent);
     if (extent.type().is_vector()) {
-      LOG(WARNING) << "Detect vectorized extent type, scalarizing...";
+      // LOG(WARNING) << "Detect vectorized extent type, scalarizing...";
Contributor: @merrymercy Do we need this warning?

Contributor (PR author): The reason we want to remove this warning is that AutoTVM may trigger it with some candidate configs during the tuning process, which makes the output noisy. If we really need to keep this warning, one alternative is to hide all warnings in AutoTVM, although I am not sure that is doable since they are all managed by the same logging system.

       return Scalarize(s);
     }
     Stmt body = Mutate(op->body);
@@ -386,7 +386,7 @@ class Vectorizer : public IRMutator {
     CHECK(!op->condition.type().is_vector());
     Expr condition = this->Mutate(op->condition);
     if (condition.type().is_vector()) {
-      LOG(WARNING) << "Detect vector condition in Vectorized Loop, scalarizing...";
+      // LOG(WARNING) << "Detect vector condition in Vectorized Loop, scalarizing...";
       return Scalarize(s);
     }
     Stmt then_case = this->Mutate(op->then_case);
87 changes: 50 additions & 37 deletions topi/python/topi/cuda/conv2d_hwcn.py
@@ -17,9 +17,14 @@
 # pylint: disable=invalid-name, too-many-locals, too-many-statements
 """Schedule for conv2d_hwcn with auto fusion"""
 import tvm
-from .. import tag
+from tvm import autotvm
+from tvm.autotvm.task.space import SplitEntity
 
-def schedule_conv2d_hwcn(outs):
+from .. import generic, tag
+
+
+@autotvm.register_topi_schedule(generic.schedule_conv2d_hwcn, ["cuda", "gpu"], ["direct"])
+def schedule_conv2d_hwcn(cfg, outs):
     """Schedule for conv2d_hwcn and any element-wise operations.
 
     Parameters
@@ -51,36 +56,44 @@ def schedule(Apad, W, B):
             sch[B].set_scope("local")
             BL = B
 
-        tile = 8
-        num_thread = 8
-        block_factor = tile * num_thread
-        step = 8
-        vthread = 2
-
-        block_x = tvm.thread_axis("blockIdx.x")
-        block_y = tvm.thread_axis("blockIdx.y")
-        block_z = tvm.thread_axis("blockIdx.z")
-        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
-        thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
-        thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
-        thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")
-
         hi, wi, fi, ni = sch[Out].op.axis
-        bz = sch[Out].fuse(hi, wi)
-        by, fi = sch[Out].split(fi, factor=block_factor)
-        bx, ni = sch[Out].split(ni, factor=block_factor)
-        tyz, fi = sch[Out].split(fi, nparts=vthread)
-        txz, ni = sch[Out].split(ni, nparts=vthread)
-        ty, fi = sch[Out].split(fi, nparts=num_thread)
-        tx, ni = sch[Out].split(ni, nparts=num_thread)
+
+        # Create tuning space
+        n_thread_cand = [1, 2, 4, 8, 16, 32]
+        vthread_cand = [1, 2, 4, 8]
+
+        cfg.define_split(
+            'tile_fi',
+            fi,
+            num_outputs=4,
+            filter=lambda x:
+            (x.size[1] in vthread_cand and x.size[2] in n_thread_cand))
+        cfg.define_split(
+            'tile_ni',
+            ni,
+            num_outputs=4,
+            filter=lambda x:
+            (x.size[1] in vthread_cand and x.size[2] in n_thread_cand))
+
+        if cfg.is_fallback:
+            cfg['tile_fi'] = SplitEntity([-1, 2, 8, 4])
+            cfg['tile_ni'] = SplitEntity([-1, 2, 8, 4])
+
+        # Scheduling
+        step = 8
+
+        bz = sch[Out].fuse(hi, wi)  # FIXME: Does it assume square images?
+        by, tyz, ty, fi = cfg['tile_fi'].apply(sch, Out, fi)
+        bx, txz, tx, ni = cfg['tile_ni'].apply(sch, Out, ni)
         sch[Out].reorder(bz, by, bx, tyz, txz, ty, tx, fi, ni)
-        sch[Out].bind(bz, block_z)
-        sch[Out].bind(by, block_y)
-        sch[Out].bind(bx, block_x)
-        sch[Out].bind(tyz, thread_yz)
-        sch[Out].bind(txz, thread_xz)
-        sch[Out].bind(ty, thread_y)
-        sch[Out].bind(tx, thread_x)
+
+        sch[Out].bind(bz, tvm.thread_axis('blockIdx.z'))
+        sch[Out].bind(by, tvm.thread_axis('blockIdx.y'))
+        sch[Out].bind(bx, tvm.thread_axis('blockIdx.x'))
+        sch[Out].bind(tyz, tvm.thread_axis('vthread'))
+        sch[Out].bind(txz, tvm.thread_axis('vthread'))
+        sch[Out].bind(ty, tvm.thread_axis('threadIdx.y'))
+        sch[Out].bind(tx, tvm.thread_axis('threadIdx.x'))
 
         # Schedule BL local write
         sch[BL].compute_at(sch[Out], tx)
@@ -98,21 +111,21 @@ def schedule(Apad, W, B):
         sch[WL].compute_at(sch[BL], rci)
         # Schedule for A's shared memory load
         yi, xi, ci, ni = sch[AA].op.axis
-        ty, ci = sch[AA].split(ci, nparts=num_thread)
-        tx, ni = sch[AA].split(ni, nparts=num_thread)
+        ty, ci = sch[AA].split(ci, nparts=cfg['tile_fi'].size[2])
+        tx, ni = sch[AA].split(ni, nparts=cfg['tile_ni'].size[2])
         _, ni = sch[AA].split(ni, factor=4)
         sch[AA].reorder(ty, tx, yi, xi, ci, ni)
-        sch[AA].bind(ty, thread_y)
-        sch[AA].bind(tx, thread_x)
+        sch[AA].bind(ty, tvm.thread_axis('threadIdx.y'))
+        sch[AA].bind(tx, tvm.thread_axis('threadIdx.x'))
         sch[AA].vectorize(ni)
         # Schedule for W's shared memory load
         yi, xi, ci, fi = sch[WW].op.axis
-        ty, ci = sch[WW].split(ci, nparts=num_thread)
-        tx, fi = sch[WW].split(fi, nparts=num_thread)
+        ty, ci = sch[WW].split(ci, nparts=cfg['tile_fi'].size[2])
+        tx, fi = sch[WW].split(fi, nparts=cfg['tile_ni'].size[2])
         _, fi = sch[WW].split(fi, factor=4)
         sch[WW].reorder(ty, tx, yi, xi, ci, fi)
-        sch[WW].bind(ty, thread_y)
-        sch[WW].bind(tx, thread_x)
+        sch[WW].bind(ty, tvm.thread_axis('threadIdx.y'))
+        sch[WW].bind(tx, tvm.thread_axis('threadIdx.x'))
         sch[WW].vectorize(fi)
 
     scheduled_ops = []
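The define_split calls above constrain the vthread and thread factors through a filter and pin an explicit fallback configuration for untuned runs. A self-contained toy template, not part of this PR and with invented names and shapes, that exercises the same pattern end to end:

    import tvm
    from tvm import autotvm
    from tvm.autotvm.task.space import SplitEntity

    @autotvm.template
    def tiled_copy(n):
        A = tvm.placeholder((n,), name='A')
        B = tvm.compute((n,), lambda i: A[i], name='B')
        s = tvm.create_schedule(B.op)

        cfg = autotvm.get_config()
        i = s[B].op.axis[0]
        # 3-way split whose middle factor is restricted by a filter, plus an
        # explicit fallback entity used when no tuning record is available.
        cfg.define_split('tile_i', i, num_outputs=3,
                         filter=lambda x: x.size[1] in [1, 2, 4, 8])
        if cfg.is_fallback:
            cfg['tile_i'] = SplitEntity([-1, 4, 8])
        io, vi, ii = cfg['tile_i'].apply(s, B, i)
        return s, [A, B]

    task = autotvm.task.create(tiled_copy, args=(1024,), target='llvm')
    print(task.config_space)   # only splits whose middle factor is 1, 2, 4, or 8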
18 changes: 18 additions & 0 deletions topi/python/topi/generic/nn.py
@@ -34,6 +34,24 @@ def _default_schedule(outs, auto_inline):
     return s
 
 
+@tvm.target.generic_func
+def schedule_conv2d_hwcn(outs):
+    """Schedule for conv2d_hwcn
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of conv2d_hwcn
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
 @tvm.target.generic_func
 def schedule_conv2d_nchw(outs):
     """Schedule for conv2d_nchw
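For readers unfamiliar with the generic_func mechanism this default plugs into: backends override it by registering against their target keys (the CUDA version in this PR does so indirectly through autotvm.register_topi_schedule). A purely illustrative registration, with an assumed "cpu" key and a deliberately naive schedule:

    import tvm
    import topi

    @topi.generic.schedule_conv2d_hwcn.register(["cpu"])
    def schedule_conv2d_hwcn_cpu(outs):
        # Naive schedule: wrap the output ops with no fusion or tiling at all.
        outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
        return tvm.create_schedule([x.op for x in outs])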
4 changes: 2 additions & 2 deletions topi/python/topi/nn/conv2d.py
@@ -64,9 +64,9 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
     # default declaration
     if layout == 'NCHW':
         return conv2d_nchw(input, filter, strides, padding, dilation, out_dtype)
-    if layout == 'HWCN':
+    elif layout == 'HWCN':
         return conv2d_hwcn(input, filter, strides, padding, dilation, out_dtype)
-    if layout == 'NHWC':
+    elif layout == 'NHWC':
         return conv2d_nhwc(input, filter, strides, padding, dilation, out_dtype)
     raise ValueError("not support this layout {} yet".format(layout))

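A quick sketch of this default declaration in isolation: with no target-specific override in scope, topi.nn.conv2d simply dispatches on the layout string. The shapes below are made up for illustration.

    import tvm
    import topi

    # Positional args: data, kernel, strides, padding, dilation, layout, out_dtype
    A = tvm.placeholder((32, 32, 64, 4), name='A', dtype='float32')    # HWCN data
    W = tvm.placeholder((3, 3, 64, 128), name='W', dtype='float32')    # HWIO kernel
    B = topi.nn.conv2d(A, W, 1, 1, 1, 'HWCN', 'float32')
    print(B.shape)   # [32, 32, 128, 4]: spatial size preserved by 3x3 kernel, pad 1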