Skip to content

Commit

Permalink
Move schedule to topi/gpu dir
Browse files Browse the repository at this point in the history
  • Loading branch information
echuraev committed Aug 6, 2021
1 parent 479aac5 commit a44c81e
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 28 deletions.
6 changes: 3 additions & 3 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,9 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
elif layout == "NHWC":
assert kernel_layout == "HWIO"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
name="conv2d_nhwc.cuda",
wrap_compute_conv2d(topi.gpu.conv2d_nhwc),
wrap_topi_schedule(topi.gpu.schedule_conv2d_nhwc),
name="conv2d_nhwc.gpu",
)

N, H, W, _ = get_const_tuple(data.shape)
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/relay/op/strategy/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
elif layout == "NHWC":
assert kernel_layout == "HWIO"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
name="conv2d_nhwc.cuda",
wrap_compute_conv2d(topi.gpu.conv2d_nhwc),
wrap_topi_schedule(topi.gpu.schedule_conv2d_nhwc),
name="conv2d_nhwc.gpu",
)
N, H, W, _ = get_const_tuple(data.shape)
KH, KW, CI, CO = get_const_tuple(kernel.shape)
Expand Down
21 changes: 0 additions & 21 deletions python/tvm/topi/cuda/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from ..nn.utils import get_pad_tuple
from ..utils import get_const_tuple, traverse_inline
from .conv2d_direct import schedule_direct_cuda
from .conv2d_nhwc import schedule_conv2d_nhwc_direct


@autotvm.register_topi_compute("conv2d_nchw.cuda")
Expand All @@ -48,26 +47,6 @@ def _callback(op):
return s


@autotvm.register_topi_compute("conv2d_nhwc.cuda")
def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"):
"""Compute conv2d with NHWC layout"""
return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)


@autotvm.register_topi_schedule("conv2d_nhwc.cuda")
def schedule_conv2d_nhwc(cfg, outs):
"""Create the schedule for conv2d_nhwc"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if op.tag == "conv2d_nhwc":
schedule_conv2d_nhwc_direct(cfg, s, op.output(0))

traverse_inline(s, outs[0].op, _callback)
return s


@autotvm.register_topi_compute("conv2d_cudnn.cuda")
def conv2d_cudnn(
cfg, data, kernel, strides, padding, dilation, groups=1, layout="NCHW", out_dtype="float32"
Expand Down
43 changes: 43 additions & 0 deletions python/tvm/topi/gpu/conv2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Schedule for conv2d operator"""
from tvm import te, autotvm

from .. import nn
from ..utils import traverse_inline
from .conv2d_nhwc import schedule_conv2d_nhwc_direct


@autotvm.register_topi_compute("conv2d_nhwc.gpu")
def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"):
"""Compute conv2d with NHWC layout"""
return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)


@autotvm.register_topi_schedule("conv2d_nhwc.gpu")
def schedule_conv2d_nhwc(cfg, outs):
"""Create the schedule for conv2d_nhwc"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if op.tag == "conv2d_nhwc":
schedule_conv2d_nhwc_direct(cfg, s, op.output(0))

traverse_inline(s, outs[0].op, _callback)
return s
File renamed without changes.
2 changes: 1 addition & 1 deletion tests/python/topi/python/test_topi_conv2d_nhwc.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

_conv2d_nhwc_implement = {
"llvm": (topi.nn.conv2d_nhwc, topi.generic.schedule_conv2d_nhwc),
"cuda": (topi.cuda.conv2d_nhwc, topi.cuda.schedule_conv2d_nhwc),
"cuda": (topi.gpu.conv2d_nhwc, topi.gpu.schedule_conv2d_nhwc),
"cpu": (topi.nn.conv2d_nhwc, topi.x86.schedule_conv2d_nhwc),
"arm_cpu": (
topi.arm_cpu.conv2d_nhwc_spatial_pack,
Expand Down

0 comments on commit a44c81e

Please sign in to comment.