Skip to content

Commit

Permalink
Add generic layout conv2d strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed May 17, 2022
1 parent bebaf12 commit 3205474
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
8 changes: 7 additions & 1 deletion python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,15 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
name="conv2d_NCHWc_int8.cuda",
)
elif is_auto_scheduler_enabled():
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv, need_data_layout=True, need_kernel_layout=True, has_groups=True),
naive_schedule,
name="conv2d_generic_layout",
)
elif target.kind.name == "cuda" and "cudnn" not in target.libs:
# No TVM native kernel applicable
raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
raise RuntimeError("Unsupported conv2d layout {} {} for CUDA".format(layout, kernel_layout))

if (
target.kind.name == "cuda"
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def schedule_bitpack(attrs, outs, target):
def wrap_compute_conv2d(
topi_compute,
need_data_layout=False,
need_kernel_layout=False,
need_out_layout=False,
has_groups=False,
need_auto_scheduler_layout=False,
Expand All @@ -227,6 +228,7 @@ def _compute_conv2d(attrs, inputs, out_type):
strides = get_const_tuple(attrs.strides)
dilation = get_const_tuple(attrs.dilation)
data_layout = attrs.get_str("data_layout")
kernel_layout = attrs.get_str("kernel_layout")
out_layout = attrs.get_str("out_layout")
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
Expand All @@ -235,6 +237,8 @@ def _compute_conv2d(attrs, inputs, out_type):
args.append(attrs.groups)
if need_data_layout:
args.append(data_layout)
if need_kernel_layout:
args.append(kernel_layout)
if need_out_layout:
args.append(out_layout)
args.append(out_dtype)
Expand Down

0 comments on commit 3205474

Please sign in to comment.