
[TOPI] Support grouped conv1d (#9832)
* [TOPI] Support grouped conv1d

Generalize the conv2d compute statement to a generic convNd that
supports any layout and groups. Replace some existing conv2d and conv1d
compute statements with this generic compute. Also add a topi
group_conv1d compute that uses the generic convNd compute. Existing
schedules for conv1d work with group_conv1d, so they are reused.

* permute reduction axis order

* formatting
Tristan Konolige authored Jan 7, 2022
1 parent afc29e6 commit f6f252f
Showing 15 changed files with 484 additions and 327 deletions.
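The heart of the change is a single layout- and group-agnostic convolution compute in TOPI that conv1d and the new group_conv1d both call into. As a rough illustration of what that compute boils down to for the grouped 1-D NCW case, here is a minimal tensor-expression sketch (not the committed implementation; it assumes the input is already padded and that both channel counts divide evenly by `groups`):

    from tvm import te

    def group_conv1d_ncw_sketch(data, kernel, stride=1, dilation=1, groups=1, out_dtype="float32"):
        """Grouped 1-D convolution over pre-padded NCW data with an OIW kernel."""
        batch, in_channels, in_width = data.shape
        out_channels, channels_per_group, kernel_w = kernel.shape
        out_width = (in_width - (kernel_w - 1) * dilation - 1) // stride + 1
        rc = te.reduce_axis((0, channels_per_group), name="rc")
        rw = te.reduce_axis((0, kernel_w), name="rw")
        return te.compute(
            (batch, out_channels, out_width),
            lambda n, f, x: te.sum(
                data[
                    n,
                    # output channel f belongs to group f // (out_channels // groups)
                    f // (out_channels // groups) * channels_per_group + rc,
                    x * stride + rw * dilation,
                ].astype(out_dtype)
                * kernel[f, rc, rw].astype(out_dtype),
                axis=[rc, rw],
            ),
            tag="group_conv1d_ncw",
        )

    # e.g. 8 input channels, 16 filters, 4 groups -> 2 input channels per group
    data = te.placeholder((1, 8, 32), name="data")
    kernel = te.placeholder((16, 2, 3), name="kernel")
    out = group_conv1d_ncw_sketch(data, kernel, groups=4)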
7 changes: 5 additions & 2 deletions python/tvm/autotvm/task/topi_integration.py
@@ -233,7 +233,10 @@ def wrapper(outs, *args, **kwargs):
"""wrapper function for topi schedule"""
workload = get_workload(outs, task_name)
if workload is None:
- raise RuntimeError("Cannot find workload in attribute of this schedule")
raise RuntimeError(
f"Cannot find TOPI workload {task_name}. "
"Is it registered with `register_topi_compute`?"
)
tgt = Target.current()
cfg = DispatchContext.current.query(tgt, workload)
return topi_schedule(cfg, outs, *args, **kwargs)
@@ -253,7 +256,7 @@ def traverse(tensors):
for t in tensors:
op = t.op
wkl = traverse(op.input_tensors)
- if wkl:
if wkl is not None:
return wkl

if "workload" in op.attrs:
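The one-line change in the second hunk is subtle but real: workloads are tuples, and an empty tuple is falsy, so `if wkl:` could silently discard a workload that `traverse` did find. A tiny illustration (the empty workload is hypothetical):

    wkl = ()              # hypothetical degenerate workload returned by traverse()
    if wkl:               # falsy -> the found workload would be dropped
        print("kept by truthiness test")
    if wkl is not None:   # explicit None check keeps it
        print("kept by None test")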
21 changes: 0 additions & 21 deletions python/tvm/relay/frontend/onnx.py
@@ -526,23 +526,6 @@ def _impl_v1(cls, inputs, attr, params):
raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"]))
attr.pop("auto_pad")

- # Check if the requested convolution is a group conv1d, if so convert it to conv2d.
- # TODO(jwfromm) Remove once proper group_conv1d is supported.
- group_conv1d = False
- if dimension_picker("conv")(attr) == "conv1d" and attr.get("group") != 1:
- group_conv1d = True
- # Expand input from NCW to NCHW
- data = _op.expand_dims(data, axis=2)
- # Expand kernel from OIW to OIHW
- kernel = _op.expand_dims(kernel, axis=2)
- # Add new value to kernel_shape, strides, dilation, pads, if needed
- attr["kernel_shape"] = [1] + list(attr["kernel_shape"])
- if "strides" in attr:
- attr["strides"] = [1] + list(attr["strides"])
- if "dilations" in attr:
- attr["dilations"] = [1] + list(attr["dilations"])
- if "pads" in attr:
- attr["pads"] = [0, attr["pads"][0], 0, attr["pads"][1]]
attr["channels"] = kernel_shapes[0][0]
out = AttrCvt(
op_name=dimension_picker("conv"),
@@ -555,10 +538,6 @@ def _impl_v1(cls, inputs, attr, params):
custom_check=dimension_constraint(),
)([data, kernel], attr, params)

- # If this was a group_conv1d, squish output back to NCW.
- if group_conv1d:
- out = _op.squeeze(out, axis=[2])

use_bias = len(inputs) == 3
if use_bias:
out = _op.nn.bias_add(out, inputs[2])
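With grouped conv1d now supported end to end, the importer no longer has to pad the problem out to conv2d and squeeze the result back; a grouped ONNX Conv over NCW input can lower straight to `relay.nn.conv1d`. A hypothetical example of the direct form (shapes and attributes are illustrative, not taken from the commit):

    from tvm import relay

    data = relay.var("data", shape=(1, 8, 32), dtype="float32")      # NCW
    weight = relay.var("weight", shape=(16, 2, 3), dtype="float32")  # OIW, 8 // 4 = 2 input channels per group
    out = relay.nn.conv1d(
        data,
        weight,
        strides=1,
        padding=(1, 1),
        dilation=1,
        groups=4,
        channels=16,
        kernel_size=3,
    )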
42 changes: 29 additions & 13 deletions python/tvm/relay/op/strategy/cuda.py
@@ -689,20 +689,36 @@ def conv1d_strategy_cuda(attrs, inputs, out_type, target):
if dilation[0] < 1:
raise ValueError("dilation should be a positive value")
strategy = _op.OpStrategy()
if layout == "NCW":
strategy.add_implementation(
wrap_compute_conv1d(topi.cuda.conv1d_ncw),
wrap_topi_schedule(topi.cuda.schedule_conv1d_ncw),
name="conv1d_ncw.cuda",
)
elif layout == "NWC":
strategy.add_implementation(
wrap_compute_conv1d(topi.cuda.conv1d_nwc),
wrap_topi_schedule(topi.cuda.schedule_conv1d_nwc),
name="conv1d_nwc.cuda",
)
if attrs.groups == 1:
if layout == "NCW":
strategy.add_implementation(
wrap_compute_conv1d(topi.cuda.conv1d_ncw),
wrap_topi_schedule(topi.cuda.schedule_conv1d_ncw),
name="conv1d_ncw.cuda",
)
elif layout == "NWC":
strategy.add_implementation(
wrap_compute_conv1d(topi.cuda.conv1d_nwc),
wrap_topi_schedule(topi.cuda.schedule_conv1d_nwc),
name="conv1d_nwc.cuda",
)
else:
raise ValueError("Unsupported conv1d layout {}".format(layout))
else:
raise ValueError("Unsupported conv1d layout {}".format(layout))
if layout == "NCW":
strategy.add_implementation(
wrap_compute_group_conv1d(topi.cuda.group_conv1d_ncw),
wrap_topi_schedule(topi.cuda.schedule_group_conv1d_ncw),
name="group_conv1d_ncw.cuda",
)
elif layout == "NWC":
strategy.add_implementation(
wrap_compute_group_conv1d(topi.cuda.group_conv1d_nwc),
wrap_topi_schedule(topi.cuda.schedule_group_conv1d_nwc),
name="group_conv1d_nwc.cuda",
)
else:
raise ValueError("Unsupported conv1d layout {}".format(layout))
return strategy


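With this strategy in place, any grouped conv1d routed to the CUDA backend picks the new `group_conv1d_ncw.cuda` / `group_conv1d_nwc.cuda` implementations whenever `attrs.groups != 1`. A hypothetical end-to-end check (illustrative shapes; requires a CUDA-enabled TVM build):

    import tvm
    from tvm import relay

    data = relay.var("data", shape=(1, 8, 32), dtype="float32")
    weight = relay.var("weight", shape=(16, 2, 3), dtype="float32")
    y = relay.nn.conv1d(data, weight, padding=(1, 1), groups=4, channels=16, kernel_size=3)
    func = relay.Function(relay.analysis.free_vars(y), y)
    with tvm.transform.PassContext(opt_level=3):
        # groups != 1 dispatches to the group_conv1d implementations added above
        lib = relay.build(tvm.IRModule.from_expr(func), target="cuda")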
43 changes: 43 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -637,6 +637,49 @@ def conv1d_strategy(attrs, inputs, out_type, target):
return strategy


def wrap_compute_group_conv1d(topi_compute):
"""wrap conv1d topi compute"""

def _compute_group_conv1d(attrs, inputs, out_type):
"""Compute definition of conv1d"""
strides = get_const_tuple(attrs.strides)
padding = get_const_tuple(attrs.padding)
dilation = get_const_tuple(attrs.dilation)
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
return [
topi_compute(inputs[0], inputs[1], strides, padding, dilation, attrs.groups, out_dtype)
]

return _compute_group_conv1d


@override_native_generic_func("group_conv1d_strategy")
def group_conv1d_strategy(attrs, inputs, out_type, target):
"""group_conv1d generic strategy"""
logger.warning("group_conv1d is not optimized for this platform.")
layout = attrs.data_layout
dilation = get_const_tuple(attrs.dilation)
if dilation[0] < 1:
raise ValueError("dilation should be a positive value")
strategy = _op.OpStrategy()
if layout == "NCW":
strategy.add_implementation(
wrap_compute_group_conv1d(topi.nn.group_conv1d_ncw),
wrap_topi_schedule(topi.generic.schedule_group_conv1d_ncw),
name="group_conv1d_ncw.generic",
)
elif layout == "NWC":
strategy.add_implementation(
wrap_compute_group_conv1d(topi.nn.group_conv1d_nwc),
wrap_topi_schedule(topi.generic.schedule_group_conv1d_nwc),
name="group_conv1d_nwc.generic",
)
else:
raise ValueError("Unsupported conv1d layout {}".format(layout))
return strategy


# conv1d_transpose
def wrap_compute_conv1d_transpose(topi_compute):
"""wrap conv1d_transpose topi compute"""
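The only signature difference the new wrapper accounts for is the extra `groups` argument that `wrap_compute_group_conv1d` forwards to the TOPI compute. The generic path can also be exercised directly at the TOPI level; a minimal sketch (hypothetical shapes and arguments, and the generic schedule is just the unoptimized fallback):

    import tvm
    from tvm import te, topi

    data = te.placeholder((1, 8, 32), name="data")      # NCW
    kernel = te.placeholder((16, 2, 3), name="kernel")  # OIW, 4 groups
    with tvm.target.Target("llvm"):
        # positional args: strides, padding, dilation, groups
        out = topi.nn.group_conv1d_ncw(data, kernel, 1, 1, 1, 4)
        s = topi.generic.schedule_group_conv1d_ncw([out])
    f = tvm.build(s, [data, kernel, out], target="llvm")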
43 changes: 30 additions & 13 deletions python/tvm/relay/op/strategy/x86.py
@@ -360,24 +360,41 @@ def conv3d_strategy_cpu(attrs, inputs, out_type, target):
def conv1d_strategy_cpu(attrs, inputs, out_type, target):
"""conv1d x86 strategy"""
layout = attrs.data_layout
groups = attrs.groups
dilation = get_const_tuple(attrs.dilation)
if dilation[0] < 1:
raise ValueError("dilation should be a positive value")
strategy = _op.OpStrategy()
if layout == "NCW":
strategy.add_implementation(
wrap_compute_conv1d(topi.nn.conv1d_ncw),
wrap_topi_schedule(topi.x86.schedule_conv1d_ncw),
name="conv1d_ncw.x86",
)
elif layout == "NWC":
strategy.add_implementation(
wrap_compute_conv1d(topi.nn.conv1d_nwc),
wrap_topi_schedule(topi.x86.schedule_conv1d_nwc),
name="conv1d_nwc.x86",
)
if groups == 1:
if layout == "NCW":
strategy.add_implementation(
wrap_compute_conv1d(topi.nn.conv1d_ncw),
wrap_topi_schedule(topi.x86.schedule_conv1d_ncw),
name="conv1d_ncw.x86",
)
elif layout == "NWC":
strategy.add_implementation(
wrap_compute_conv1d(topi.nn.conv1d_nwc),
wrap_topi_schedule(topi.x86.schedule_conv1d_nwc),
name="conv1d_nwc.x86",
)
else:
raise ValueError("Unsupported conv1d layout {}".format(layout))
else:
raise ValueError("Unsupported conv1d layout {}".format(layout))
if layout == "NCW":
strategy.add_implementation(
wrap_compute_group_conv1d(topi.nn.group_conv1d_ncw),
wrap_topi_schedule(topi.x86.schedule_group_conv1d_ncw),
name="group_conv1d_ncw.x86",
)
elif layout == "NWC":
strategy.add_implementation(
wrap_compute_group_conv1d(topi.nn.group_conv1d_nwc),
wrap_topi_schedule(topi.x86.schedule_group_conv1d_nwc),
name="group_conv1d_nwc.x86",
)
else:
raise ValueError("Unsupported conv1d layout {}".format(layout))
return strategy


40 changes: 34 additions & 6 deletions python/tvm/topi/cuda/conv1d.py
@@ -29,8 +29,7 @@ def conv1d_ncw(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"
return nn.conv1d_ncw(data, kernel, strides, padding, dilation, out_dtype)


@autotvm.register_topi_schedule("conv1d_ncw.cuda")
def schedule_conv1d_ncw(cfg, outs):
def _schedule_conv1d_ncw(cfg, outs):
"""TOPI schedule callback of conv1d ncw for cuda gpu
Parameters
@@ -51,7 +50,7 @@ def schedule_conv1d_ncw(cfg, outs):
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if op.tag == "conv1d_ncw":
if op.tag == "conv1d_ncw" or op.tag == "group_conv1d_ncw":
pad_data = op.input_tensors[0]
kernel = op.input_tensors[1]
conv = op.output(0)
@@ -140,13 +139,27 @@ def _callback(op):
return s


@autotvm.register_topi_schedule("conv1d_ncw.cuda")
def schedule_conv1d_ncw(cfg, outs):
return _schedule_conv1d_ncw(cfg, outs)


@autotvm.register_topi_compute("group_conv1d_ncw.cuda")
def group_conv1d_ncw(cfg, data, kernel, strides, padding, dilation, groups, out_dtype="float32"):
return nn.group_conv1d_ncw(data, kernel, strides, padding, dilation, groups, out_dtype)


@autotvm.register_topi_schedule("group_conv1d_ncw.cuda")
def schedule_group_conv1d_ncw(cfg, outs):
return _schedule_conv1d_ncw(cfg, outs)


@autotvm.register_topi_compute("conv1d_nwc.cuda")
def conv1d_nwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"):
return nn.conv1d_nwc(data, kernel, strides, padding, dilation, out_dtype)


@autotvm.register_topi_schedule("conv1d_nwc.cuda")
def schedule_conv1d_nwc(cfg, outs):
def _schedule_conv1d_nwc(cfg, outs):
"""TOPI schedule callback of conv1d nwc for cuda gpu
Parameters
@@ -167,7 +180,7 @@ def schedule_conv1d_nwc(cfg, outs):
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if op.tag == "conv1d_nwc":
if op.tag == "conv1d_nwc" or op.tag == "group_conv1d_nwc":
pad_data = op.input_tensors[0]
kernel = op.input_tensors[1]
conv = op.output(0)
@@ -254,3 +267,18 @@ def _callback(op):
traverse_inline(s, outs[0].op, _callback)

return s


@autotvm.register_topi_schedule("conv1d_nwc.cuda")
def schedule_conv1d_nwc(cfg, outs):
return _schedule_conv1d_nwc(cfg, outs)


@autotvm.register_topi_compute("group_conv1d_nwc.cuda")
def group_conv1d_nwc(cfg, data, kernel, strides, padding, dilation, groups, out_dtype="float32"):
return nn.group_conv1d_nwc(data, kernel, strides, padding, dilation, groups, out_dtype)


@autotvm.register_topi_schedule("group_conv1d_nwc.cuda")
def schedule_group_conv1d_nwc(cfg, outs):
return _schedule_conv1d_nwc(cfg, outs)
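The reuse promised in the commit message comes from two small moves visible above: the schedule bodies are factored into the unregistered `_schedule_conv1d_ncw` / `_schedule_conv1d_nwc`, and their callbacks now also match the `group_conv1d_*` compute tags, so the grouped operator inherits the existing tiling unchanged. A stripped-down sketch of that pattern (the split factor and thread bindings are placeholders, not the committed schedule):

    from tvm import te
    from tvm.topi.utils import traverse_inline

    def _shared_conv1d_schedule(outs):
        s = te.create_schedule([x.op for x in outs])

        def _callback(op):
            # the same tiling fires for both the plain and the grouped compute tags
            if op.tag in ("conv1d_ncw", "group_conv1d_ncw"):
                conv = op.output(0)
                n, f, x = s[conv].op.axis
                bx, tx = s[conv].split(x, factor=64)
                s[conv].bind(bx, te.thread_axis("blockIdx.x"))
                s[conv].bind(tx, te.thread_axis("threadIdx.x"))

        traverse_inline(s, outs[0].op, _callback)
        return s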
34 changes: 34 additions & 0 deletions python/tvm/topi/generic/nn.py
@@ -54,6 +54,40 @@ def schedule_conv1d_nwc(outs):
return _default_schedule(outs, False)


def schedule_group_conv1d_ncw(outs):
"""Schedule for group_conv1d_ncw
Parameters
----------
outs: Array of Tensor
The computation graph description of group_conv1d_ncw
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)


def schedule_group_conv1d_nwc(outs):
"""Schedule for group_conv1d_nwc
Parameters
----------
outs: Array of Tensor
The computation graph description of group_conv1d_nwc
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)


def schedule_conv2d_hwcn(outs):
"""Schedule for conv2d_hwcn
(The remaining 8 changed files are not shown.)
