Merge branch 'master' of https://github.com/dmlc/tvm
honghua.cao committed Sep 15, 2020
2 parents 1dac9dc + 7b744b3 commit c1c25eb
Showing 12 changed files with 282 additions and 40 deletions.
2 changes: 1 addition & 1 deletion python/tvm/autotvm/task/task.py
@@ -56,7 +56,7 @@ def _encode(x):
        return ("TENSOR", get_const_tuple(x.shape), x.dtype)
    if isinstance(x, (tuple, list, container.Array)):
        return tuple([_encode(a) for a in x])
    if isinstance(x, (str, int, float, np.int, np.float, expr.Var)):
    if isinstance(x, (str, int, float, np.int, np.float, expr.Var, expr.Any)):
        return x
    if isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
        return x.value
21 changes: 18 additions & 3 deletions python/tvm/relay/op/_tensor.py
@@ -19,6 +19,7 @@

from tvm.te.hybrid import script
from tvm import topi
from tvm.runtime import convert

from .op import register_compute, register_shape_func
from .op import register_broadcast_schedule, register_injective_schedule
@@ -156,11 +157,22 @@ def _full_shape_func(shape):
    return out


@script
def _convert_shape(shape):
    out = output_tensor((len(shape),), "int64")
    for i in const_range(len(shape)):
        out[i] = int64(shape[i])
    return out


def full_shape_func(attrs, inputs, out_ndims):
"""
Shape func for full.
"""
return [_full_shape_func(inputs[1])]
if len(inputs) > 1:
return [_full_shape_func(inputs[1])]

return [_convert_shape(convert(attrs.shape))]


def no_data_full_shape_func(attrs, inputs, out_ndims):
@@ -216,9 +228,9 @@ def elemwise_shape_func(attrs, inputs, _):


register_shape_func("cast", False, elemwise_shape_func)
register_shape_func("zeros", False, full_shape_func)
register_shape_func("zeros", False, no_data_full_shape_func)
register_shape_func("zeros_like", False, elemwise_shape_func)
register_shape_func("ones", False, full_shape_func)
register_shape_func("ones", False, no_data_full_shape_func)
register_shape_func("ones_like", False, elemwise_shape_func)
register_shape_func("full", False, full_shape_func)
register_shape_func("full_like", False, elemwise_shape_func)
@@ -257,3 +269,6 @@ def elemwise_shape_func(attrs, inputs, _):
register_shape_func("floor", False, elemwise_shape_func)
register_shape_func("log", False, elemwise_shape_func)
register_shape_func("device_copy", False, elemwise_shape_func)
register_shape_func("clip", False, elemwise_shape_func)
register_shape_func("log2", False, elemwise_shape_func)
register_shape_func("sigmoid", False, elemwise_shape_func)
50 changes: 50 additions & 0 deletions python/tvm/relay/op/_transform.py
@@ -745,3 +745,53 @@ def adv_index_shape_func(attrs, inputs, _):
    Only allow single index tensor.
    """
    return [_adv_index_shape_func(inputs)]


@script
def _repeat_shape_func(data_shape, repeats, axis):
    out = output_tensor((data_shape.shape[0],), "int64")

    for i in const_range(data_shape.shape[0]):
        if i == axis:
            out[i] = int64(data_shape[i] * repeats)
        else:
            out[i] = data_shape[i]

    return out

@_reg.register_shape_func("repeat", False)
def repeat_shape_func(attrs, inputs, _):
"""
Shape func for repeat.
"""
axis = get_const_int(attrs.axis)
if axis < 0:
axis = inputs[0].shape[0] + axis
return [_repeat_shape_func(inputs[0], attrs.repeats, convert(axis))]


@_reg.register_shape_func("broadcast_to_like", False)
def broadcast_to_like_shape_func(attrs, inputs, _):
    return [topi.math.identity(inputs[1])]


@script
def _stack_shape_func(data_shape, axis, num_inputs):
    out = output_tensor((data_shape.shape[0] + 1,), "int64")

    for i in const_range(data_shape.shape[0] + 1):
        if i == axis:
            out[i] = int64(num_inputs)
        elif i < axis:
            out[i] = data_shape[i]
        else:
            out[i] = data_shape[i - 1]

    return out

@_reg.register_shape_func("stack", False)
def stack_shape_func(attrs, inputs, _):
    axis = get_const_int(attrs.axis)
    if axis < 0:
        axis += inputs[0].shape[0] + 1
    return [_stack_shape_func(inputs[0], convert(axis), convert(len(inputs)))]
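
A plain-Python sketch of the shape arithmetic the two hybrid scripts above implement (illustrative only; these helper names are made up for this note):

def repeat_out_shape(data_shape, repeats, axis):
    out = list(data_shape)
    out[axis] *= repeats              # the repeated axis grows by the repeat factor
    return out

def stack_out_shape(data_shape, axis, num_inputs):
    out = list(data_shape)
    out.insert(axis, num_inputs)      # stacking inserts a new axis of length num_inputs
    return out

assert repeat_out_shape([2, 3], 4, 1) == [2, 12]
assert stack_out_shape([2, 3], 0, 5) == [5, 2, 3]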
43 changes: 43 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -773,6 +773,49 @@ def conv2d_NCHWc_shape_func(attrs, inputs, _):
    ]


@script
def _conv2d_transpose_nchw_shape_func(dshape, kshape, strides,
                                      padding, dilation, output_padding):
    out = output_tensor((dshape.shape[0],), "int64")
    kheight = kshape[2]
    kwidth = kshape[3]
    dilated_kh = (kheight - 1) * dilation[0] + 1
    dilated_kw = (kwidth - 1) * dilation[1] + 1

    out_height = strides[0] * (dshape[2] - 1) + dilated_kh - \
        2 * padding[0] + output_padding[0]
    out_width = strides[1] * (dshape[3] - 1) + dilated_kw - \
        2 * padding[1] + output_padding[1]

    out[0] = dshape[0]
    out[1] = kshape[1]
    out[2] = out_height
    out[3] = out_width
    return out


@reg.register_shape_func("nn.conv2d_transpose", False)
def conv2d_transpose_nchw_shape_func(attrs, inputs, _):
"""
Shape function for conv2d_transpose op.
"""
strides = get_const_tuple(attrs.strides)
padding = get_const_tuple(attrs.padding)
dilation = get_const_tuple(attrs.dilation)
output_padding = get_const_tuple(attrs.output_padding)

return [
_conv2d_transpose_nchw_shape_func(
inputs[0],
inputs[1],
convert(strides),
convert(padding),
convert(dilation),
convert(output_padding)
)
]


@script
def _pool2d_shape_func(data_shape, pool_size, strides, padding, height_axis, width_axis):
    out = output_tensor((data_shape.shape[0],), "int64")
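
As a sanity check on the transposed-convolution output size computed above, here is the same formula in plain Python (illustrative only, assuming symmetric padding as in the hybrid script):

def transpose_out_dim(in_dim, kernel, stride, pad, dilation, out_pad):
    dilated_k = (kernel - 1) * dilation + 1
    return stride * (in_dim - 1) + dilated_k - 2 * pad + out_pad

# e.g. a 4x4 input with a 3x3 kernel, stride 2, padding 1 gives a 7x7 output
assert transpose_out_dim(4, 3, 2, 1, 1, 0) == 7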
17 changes: 17 additions & 0 deletions python/tvm/relay/op/vision/_vision.py
@@ -20,6 +20,8 @@

from tvm import topi
from tvm.te.hybrid import script
from tvm.runtime import convert

from .. import op as reg
from .. import strategy
from ..op import OpPattern
@@ -81,3 +83,18 @@ def nms_shape_func(attrs, inputs, _):
    if attrs.return_indices:
        return _nms_shape_func(inputs[0])
    return [topi.math.identity(inputs[0])]


@script
def _roi_align_shape_func(data_shape, rois_shape, pooled_size):
    out = output_tensor((4,), "int64")
    out[0] = rois_shape[0]
    out[1] = data_shape[1]
    out[2] = int64(pooled_size[0])
    out[3] = int64(pooled_size[1])
    return out

@reg.register_shape_func("vision.roi_align", False)
def roi_align_shape_func(attrs, inputs, _):
    return [_roi_align_shape_func(inputs[0], inputs[1],
                                  convert(attrs.pooled_size))]
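
The shape func above yields (num_rois, channels, pooled_h, pooled_w); a plain-Python restatement for reference (illustrative only, the helper name is made up):

def roi_align_out_shape(data_shape, rois_shape, pooled_size):
    # rows of the rois tensor decide the batch of the output, channels follow the data
    return [rois_shape[0], data_shape[1], pooled_size[0], pooled_size[1]]

assert roi_align_out_shape([1, 256, 56, 56], [100, 5], (7, 7)) == [100, 256, 7, 7]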
40 changes: 20 additions & 20 deletions python/tvm/topi/scatter.py
@@ -32,8 +32,8 @@ def _scatter_1d(data, indices, updates):
@hybrid.script
def _scatter_2d(data, indices, updates, axis):
    out = output_tensor(data.shape, data.dtype)
    for i in const_range(data.shape[0]):
        for j in const_range(data.shape[1]):
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            out[i, j] = data[i, j]
    if axis == 0:
        for i in range(indices.shape[0]):
@@ -54,14 +54,14 @@ def _scatter_2d(data, indices, updates, axis):
@hybrid.script
def _scatter_3d(data, indices, updates, axis):
    out = output_tensor(data.shape, data.dtype)
    for i in const_range(data.shape[0]):
        for j in const_range(data.shape[1]):
            for k in const_range(data.shape[2]):
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            for k in range(data.shape[2]):
                out[i, j, k] = data[i, j, k]
    if axis == 0:
        for i in range(indices.shape[0]):
            for j in range(indices.shape[1]):
                for k in const_range(indices.shape[2]):
                for k in range(indices.shape[2]):
                    out[
                        indices[i, j, k]
                        if indices[i, j, k] >= 0
@@ -72,7 +72,7 @@ def _scatter_3d(data, indices, updates, axis):
    elif axis == 1:
        for i in range(indices.shape[0]):
            for j in range(indices.shape[1]):
                for k in const_range(indices.shape[2]):
                for k in range(indices.shape[2]):
                    out[
                        i,
                        indices[i, j, k]
@@ -83,7 +83,7 @@ def _scatter_3d(data, indices, updates, axis):
    else:
        for i in range(indices.shape[0]):
            for j in range(indices.shape[1]):
                for k in const_range(indices.shape[2]):
                for k in range(indices.shape[2]):
                    out[
                        i,
                        j,
@@ -98,17 +98,17 @@ def _scatter_3d(data, indices, updates, axis):
@hybrid.script
def _scatter_4d(data, indices, updates, axis):
    out = output_tensor(data.shape, data.dtype)
    for i in const_range(data.shape[0]):
        for j in const_range(data.shape[1]):
            for k in const_range(data.shape[2]):
                for l in const_range(data.shape[3]):
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            for k in range(data.shape[2]):
                for l in range(data.shape[3]):
                    out[i, j, k, l] = data[i, j, k, l]

    if axis == 0:
        for i in range(indices.shape[0]):
            for j in range(indices.shape[1]):
                for k in const_range(indices.shape[2]):
                    for l in const_range(indices.shape[3]):
                for k in range(indices.shape[2]):
                    for l in range(indices.shape[3]):
                        out[
                            indices[i, j, k, l]
                            if indices[i, j, k, l] >= 0
@@ -120,8 +120,8 @@ def _scatter_4d(data, indices, updates, axis):
    elif axis == 1:
        for i in range(indices.shape[0]):
            for j in range(indices.shape[1]):
                for k in const_range(indices.shape[2]):
                    for l in const_range(indices.shape[3]):
                for k in range(indices.shape[2]):
                    for l in range(indices.shape[3]):
                        out[
                            i,
                            indices[i, j, k, l]
@@ -133,8 +133,8 @@ def _scatter_4d(data, indices, updates, axis):
    elif axis == 2:
        for i in range(indices.shape[0]):
            for j in range(indices.shape[1]):
                for k in const_range(indices.shape[2]):
                    for l in const_range(indices.shape[3]):
                for k in range(indices.shape[2]):
                    for l in range(indices.shape[3]):
                        out[
                            i,
                            j,
@@ -146,8 +146,8 @@ def _scatter_4d(data, indices, updates, axis):
    else:
        for i in range(indices.shape[0]):
            for j in range(indices.shape[1]):
                for k in const_range(indices.shape[2]):
                    for l in const_range(indices.shape[3]):
                for k in range(indices.shape[2]):
                    for l in range(indices.shape[3]):
                        out[
                            i,
                            j,
11 changes: 11 additions & 0 deletions python/tvm/topi/x86/conv2d.py
@@ -140,6 +140,17 @@ def _pack_data(cfg, data, kernel):
    ic_chunk = ic // ic_bn
    oc_chunk = oc // oc_bn

    # Handle dynamic shape to pass tuning dispatch.
    if isinstance(n, tvm.tir.Any):
        n = tvm.te.size_var("n")
    if isinstance(ih, tvm.tir.Any):
        ih = tvm.te.size_var("ih")
    if isinstance(iw, tvm.tir.Any):
        iw = tvm.te.size_var("iw")
    if isinstance(ic, tvm.tir.Any):
        raise RuntimeError("Dynamic input channel is not supported for conv2d.")


    data = te.compute(
        (n, ic_chunk, ih, iw, ic_bn),
        lambda bs, c, h, w, vc: data[bs, c * ic_bn + vc, h, w],
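
A hedged, standalone illustration of the guard added above (not the commit's code): symbolic dimensions arrive as tvm.tir.Any, which te.compute cannot use directly, so they are swapped for named size vars; only the input-channel count must stay static because it determines the packing layout.

import tvm
from tvm import te

def normalize_dim(dim, name):
    # replace a dynamic (Any) dimension with a named symbolic size var
    return te.size_var(name) if isinstance(dim, tvm.tir.Any) else dim

n = normalize_dim(tvm.tir.Any(), "n")   # symbolic but named, usable in te.compute shapes
h = normalize_dim(224, "ih")            # static dims pass through unchanged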
3 changes: 2 additions & 1 deletion python/tvm/topi/x86/dense.py
@@ -236,7 +236,8 @@ def dense_blas_common(cfg, data, weight, bias, out_dtype, lib):
    """Compute dense using a BLAS library"""
    M, K = get_const_tuple(data.shape)
    N, _ = get_const_tuple(weight.shape)
    cfg.add_flop(M * K * N * 2)
    if isinstance(M, int) and isinstance(K, int) and isinstance(N, int):
        cfg.add_flop(M * K * N * 2)
    if data.dtype == "uint8" and weight.dtype == "int8" and out_dtype == "int32":
        if not hasattr(lib, "matmul_u8s8s32"):
            raise NotImplementedError(
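
The guard above only records FLOPs when M, K, and N are concrete Python ints; with symbolic (dynamic) dimensions the product is not a number, so the accounting is simply skipped. A trivial plain-Python restatement (illustrative only; the helper name is made up):

def dense_flops(M, K, N):
    # one multiply and one add per output element of the M x N result
    if isinstance(M, int) and isinstance(K, int) and isinstance(N, int):
        return 2 * M * K * N
    return None   # dynamic shape: skip FLOP accounting

assert dense_flops(32, 64, 128) == 524288
assert dense_flops("m", 64, 128) is None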
10 changes: 7 additions & 3 deletions python/tvm/topi/x86/roi_align.py
@@ -25,7 +25,7 @@


@hybrid.script
def roi_align_nchw_ir(data, rois, w_pc, pos_pc, pooled_size, spatial_scale, sample_ratio):
def roi_align_nchw_ir(data, rois, num_rois, w_pc, pos_pc, pooled_size, spatial_scale, sample_ratio):
"""Hybrid routing fo ROI align operator in NCHW layout.
Parameters
@@ -37,6 +37,10 @@ def roi_align_nchw_ir(data, rois, w_pc, pos_pc, pooled_size, spatial_scale, samp
        2-D with shape [num_roi, 5]. The last dimension should be in format of
        [batch_index, w_start, h_start, w_end, h_end]
    num_rois : tvm.tir.IntImm or tvm.tir.Var
        Number of roi. We need to pass it in since hybrid script doesn't support
        binding variable to symbolic dim.
    w_pc : tvm.te.Tensor or numpy NDArray
        3-D weight pre-calculation buffer
@@ -61,7 +65,6 @@ def roi_align_nchw_ir(data, rois, w_pc, pos_pc, pooled_size, spatial_scale, samp
    channels = data.shape[1]
    height = data.shape[2]
    width = data.shape[3]
    num_rois = rois.shape[0]
    pooled_size_h = pooled_size[0]
    pooled_size_w = pooled_size[1]
    output = output_tensor((num_rois, channels, pooled_size_h, pooled_size_w), data.dtype)
@@ -235,6 +238,7 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
    _, _, height, width = get_const_tuple(data.shape)
    max_roi_bin_grid_h = math.ceil(height / pooled_size[0])
    max_roi_bin_grid_w = math.ceil(width / pooled_size[1])
    num_rois = rois.shape[0]
    max_pc_shape = (
        rois.shape[0],
        max_roi_bin_grid_h * max_roi_bin_grid_w * pooled_size[0] * pooled_size[1],
@@ -247,7 +251,7 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
    spatial_scale = tvm.tir.const(spatial_scale, "float32")
    sample_ratio = tvm.tir.const(sample_ratio, "int32")
    return roi_align_nchw_ir(
        data, rois, w_pc_buffer, pos_pc_buffer, pooled_size, spatial_scale, sample_ratio
        data, rois, num_rois, w_pc_buffer, pos_pc_buffer, pooled_size, spatial_scale, sample_ratio
    )

