TOPI attrs
junrushao committed Mar 3, 2022
1 parent c9dc5c6 · commit 11cb179
Showing 5 changed files with 69 additions and 28 deletions.
python/tvm/topi/cuda/conv2d_nhwc_winograd.py (19 additions, 8 deletions)
@@ -20,15 +20,17 @@
 """Winograd template for cuda backend"""

 import tvm
-from tvm import te
-from tvm import autotvm
+from tvm import autotvm, te

 from .. import nn
-from ..utils import get_const_int, get_const_tuple, traverse_inline
 from ..nn.winograd_util import winograd_transform_matrices
-from .tensor_intrin import intrin_wmma_load_matrix_A
-from .tensor_intrin import intrin_wmma_load_matrix_W
-from .tensor_intrin import intrin_wmma_store_matrix
-from .tensor_intrin import intrin_wmma_gemm
+from ..utils import get_const_int, get_const_tuple, traverse_inline
+from .tensor_intrin import (
+    intrin_wmma_gemm,
+    intrin_wmma_load_matrix_A,
+    intrin_wmma_load_matrix_W,
+    intrin_wmma_store_matrix,
+)


 def _infer_tile_size(data, kernel):
@@ -332,7 +334,13 @@ def nhwc_winograd_cuda(
     assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1

     pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW))
-    data_pad = nn.pad(data, (0, pt, pl, 0), (0, pb, pr, 0), name="data_pad")
+    data_pad = nn.pad(
+        data,
+        (0, pt, pl, 0),
+        (0, pb, pr, 0),
+        name="data_pad",
+        attrs={"schedule_rule": "None"},
+    )

     r = KW
     m = tile_size
@@ -388,6 +396,7 @@ def nhwc_winograd_cuda(
             idxdiv(p, (nH * nW)), idxmod(idxdiv(p, nW), nH) * m + eps, idxmod(p, nW) * m + nu, c
         ],
         name="d",
+        attrs={"schedule_rule": "None"},
     )

     # Transform data
@@ -399,6 +408,7 @@ def nhwc_winograd_cuda(
             input_tile[p][ci][r_a][r_b] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]
         ),
         name="data_pack",
+        attrs={"schedule_rule": "meta_schedule.winograd_data_pack.cuda"},
     )

     # Convert data type of input feature maps and weights for tensorcore
@@ -430,6 +440,7 @@ def nhwc_winograd_cuda(
             bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b]
         ),
         name="inverse",
+        attrs={"schedule_rule": "meta_schedule.winograd_inverse"},
     )

     # Output
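Note: the "schedule_rule" strings added above are ordinary compute-op attributes, not a registered API; a scheduler is expected to look them up by key on each op. A minimal sketch of such a lookup, assuming a build that includes this patch (the helper name sched_rule_of is illustrative, not part of TVM):

    def sched_rule_of(op):
        """Return the schedule_rule annotation of a te compute op, or None."""
        attrs = getattr(op, "attrs", None)  # placeholder ops carry no attrs
        if attrs is not None and "schedule_rule" in attrs:
            return str(attrs["schedule_rule"])
        return None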
python/tvm/topi/cuda/conv2d_winograd.py (14 additions, 6 deletions)
@@ -18,15 +18,14 @@
 """Winograd template for cuda backend"""

 import logging
+
 import tvm
-from tvm import te
-from tvm import autotvm
+from tvm import autotvm, te

 from .. import nn
-from ..utils import get_const_int, get_const_tuple, traverse_inline
+from ..nn.conv2d import _conv2d_winograd_nhwc_impl, conv2d_winograd_nhwc
 from ..nn.winograd_util import winograd_transform_matrices
-from ..nn.conv2d import conv2d_winograd_nhwc, _conv2d_winograd_nhwc_impl
-
+from ..utils import get_const_int, get_const_tuple, traverse_inline

 logger = logging.getLogger("conv2d_winograd")

@@ -78,7 +77,13 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_
     assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1

     pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW))
-    data_pad = nn.pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad")
+    data_pad = nn.pad(
+        data,
+        (0, 0, pt, pl),
+        (0, 0, pb, pr),
+        name="data_pad",
+        attrs={"schedule_rule": "None"},
+    )

     r = KW
     m = tile_size
@@ -113,6 +118,7 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_
             idxmod(idxdiv(p, nW), nH) * m + eps
         ][idxmod(p, nW) * m + nu],
         name="d",
+        attrs={"schedule_rule": "None"},
     )

     # transform data
@@ -124,6 +130,7 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_
             input_tile[ci][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]
         ),
         name="data_pack",
+        attrs={"schedule_rule": "meta_schedule.winograd_data_pack.cuda"},
     )

     # do batch gemm
@@ -145,6 +152,7 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_
             bgemm[r_a][r_b][co][p] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b]
         ),
         name="inverse",
+        attrs={"schedule_rule": "meta_schedule.winograd_inverse"},
     )

     # output
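Note: both CUDA Winograd templates rely on the same mechanism: te.compute (and nn.pad, via the change further below) forwards an attrs dict onto the resulting compute op, where it is visible as op.attrs. A minimal self-contained sketch of that mechanism, assuming a recent TVM build:

    from tvm import te

    A = te.placeholder((4, 4), name="A")
    B = te.compute(
        (4, 4),
        lambda i, j: A[i, j] * 2.0,
        name="double",
        attrs={"schedule_rule": "meta_schedule.winograd_inverse"},
    )
    # The annotation round-trips through the op node.
    print(B.op.attrs["schedule_rule"])  # -> meta_schedule.winograd_inverse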
python/tvm/topi/nn/conv2d.py (17 additions, 6 deletions)
@@ -19,11 +19,11 @@
 """Conv2D operators"""
 from __future__ import absolute_import as _abs

-from collections import namedtuple
 import re
-from typing import Union, Sequence, Optional
-import numpy as np
+from collections import namedtuple
+from typing import Optional, Sequence, Union

+import numpy as np
 import tvm
 from tvm import auto_scheduler, te
@@ -1019,7 +1019,11 @@ def _conv2d_winograd_nhwc_impl(

     pad_extra = (nW - 1) * m + alpha - (H + pad_t + pad_b)
     data_pad = pad(
-        data, (0, pad_t, pad_l, 0), (0, pad_b + pad_extra, pad_r + pad_extra, 0), name="data_pad"
+        data,
+        (0, pad_t, pad_l, 0),
+        (0, pad_b + pad_extra, pad_r + pad_extra, 0),
+        name="data_pad",
+        attrs={"schedule_rule": "None"},
     )

     if not pre_computed:
@@ -1044,6 +1048,7 @@ def _conv2d_winograd_nhwc_impl(
             (p % nW) * m + nu
         ][ci],
         name="input_tile",
+        attrs={"schedule_rule": "None"},
     )

     # transform data
@@ -1055,7 +1060,10 @@ def _conv2d_winograd_nhwc_impl(
             input_tile[r_a][r_b][p][ci] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]
         ),
         name="data_pack",
-        attrs={"auto_scheduler_simplify_const_tensor_indices": ["eps", "nu", "r_a", "r_b"]},
+        attrs={
+            "auto_scheduler_simplify_const_tensor_indices": ["eps", "nu", "r_a", "r_b"],
+            "schedule_rule": "meta_schedule.winograd_data_pack.cpu",
+        },
         # the attrs are necessary hints for the auto-scheduler
     )

@@ -1082,7 +1090,10 @@ def _conv2d_winograd_nhwc_impl(
             bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b]
         ),
         name="inverse",
-        attrs={"auto_scheduler_simplify_const_tensor_indices": ["vh", "vw", "r_a", "r_b"]},
+        attrs={
+            "auto_scheduler_simplify_const_tensor_indices": ["vh", "vw", "r_a", "r_b"],
+            "schedule_rule": "meta_schedule.winograd_inverse",
+        },
         # the attrs are necessary hints for the auto-scheduler
     )

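Note: as the two hunks above show, one op can carry the auto-scheduler hint and the meta-schedule rule side by side; each consumer reads only the key it understands and ignores the rest. A sketch of inspecting such a doubly annotated op (the names X, Y and the attr values here are illustrative):

    from tvm import te

    X = te.placeholder((6, 6), name="X")
    Y = te.compute(
        (6, 6),
        lambda i, j: X[i, j] + 1.0,
        name="Y",
        attrs={
            "auto_scheduler_simplify_const_tensor_indices": ["i", "j"],
            "schedule_rule": "meta_schedule.winograd_inverse",
        },
    )
    # Both entries coexist on the same op.
    for key, value in Y.op.attrs.items():
        print(key, value)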
python/tvm/topi/nn/pad.py (5 additions, 3 deletions)
@@ -16,14 +16,16 @@
 # under the License.
 """Pad the data by constant value """
 from __future__ import absolute_import as _abs
+
 import tvm
 from tvm import te
-from ..utils import equal_const_int
+
 from .. import tag
+from ..utils import equal_const_int


 @tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
-def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
+def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput", attrs=None):
     """Pad Input with zeros.

     Parameters
Expand Down Expand Up @@ -85,7 +87,7 @@ def _pad(*indices):
return tvm.tir.if_then_else(not_zero, data(*index_tuple), pad_value)
return data(*index_tuple)

return te.compute(out_shape, _pad, name=name)
return te.compute(out_shape, _pad, name=name, attrs=attrs)


@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
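Note: with the new optional parameter, callers can attach arbitrary attributes to the padding stage; omitting attrs keeps the old behavior, since te.compute treats attrs=None as "no attributes". A usage sketch, assuming a build that includes this patch:

    from tvm import te, topi

    data = te.placeholder((1, 32, 32, 16), name="data")
    data_pad = topi.nn.pad(
        data,
        (0, 1, 1, 0),
        (0, 1, 1, 0),
        name="data_pad",
        attrs={"schedule_rule": "None"},
    )
    print(data_pad.op.attrs["schedule_rule"])  # -> None (as a string)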
python/tvm/topi/utils.py (14 additions, 5 deletions)
@@ -17,14 +17,15 @@
 # pylint: disable=invalid-name
 """Common topi utilities"""
 from __future__ import absolute_import as _abs
-from numbers import Integral
-import numpy as np

+from numbers import Integral
+
+import numpy as np
 import tvm
 from tvm import te
-from tvm.tir import layout, bijective_layout
-from . import tag, cpp
+from tvm.tir import bijective_layout, layout
+
+from . import cpp, tag


 class InvalidShapeError(ValueError):
@@ -347,7 +348,15 @@ def select_array(i, j):
                 )
             return now

-    return te.compute(matrix.shape, select_array, name=name, attrs={"const_matrix": True})
+    return te.compute(
+        matrix.shape,
+        select_array,
+        name=name,
+        attrs={
+            "const_matrix": True,
+            "schedule_rule": "meta_schedule.compute_inline",
+        },
+    )


 def get_max_power2_factor(n, max_value=None):
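Note: after this change, every tensor built by const_matrix (e.g. the Winograd transform matrices A and B) is tagged for inlining by the meta schedule in addition to the existing const_matrix handling. A quick inspection sketch, assuming a build with this patch (the matrix contents are arbitrary):

    import numpy as np
    from tvm.topi.utils import const_matrix

    B = const_matrix(np.ones((4, 2), dtype="float32"), name="B")
    print(B.op.attrs["const_matrix"])   # truthy; the bool is stored as an immediate
    print(B.op.attrs["schedule_rule"])  # -> meta_schedule.compute_inline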
