Skip to content

Commit

Permalink
[TOPI] add dilation operators (#316)
Browse files Browse the repository at this point in the history
* add dilation operators

* fix pylint

* dilate testcases success

* n-D tensor dilation

* support arbitrary dimension
  • Loading branch information
Huyuwei authored and tqchen committed Aug 14, 2017
1 parent ba6664a commit b0c42f3
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 0 deletions.
1 change: 1 addition & 0 deletions topi/python/topi/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from .mapping import *
from .ewise import *
from .conv import *
from .dilate import *
44 changes: 44 additions & 0 deletions topi/python/topi/nn/dilate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# pylint: disable=invalid-name
"""Dilation operators"""
from __future__ import absolute_import as _abs
import tvm


@tvm.tag_scope(tag="dilation")
def dilate(Input, strides):
"""Dilate Input with zeros.
Parameters
----------
Input : tvm.Tensor
n-D, can be any layout.
strides : list / tuple of n ints
Dilation stride on each dimension, 1 means no dilation.
Returns
-------
Output : tvm.Tensor
n-D, the same layout as Input.
"""
n = len(Input.shape)
assert len(strides) == n, \
"Input dimension and strides size dismatch : %d vs %d" %(n, len(strides))
output_size = ()
for i in range(n):
output_size += (tvm.ir_pass.Simplify((Input.shape[i]-1)*strides[i]+1),)

def _dilate(data, *indices):
not_zero = (indices[0]%strides[0]).equal(0)
index_tuple = ()
for i in range(n):
index_tuple += (indices[i]/strides[i],)
not_zero = tvm.all(not_zero, (indices[i]%strides[i]).equal(0))
return tvm.select(not_zero, data[index_tuple], tvm.const(0.0, data.dtype))

Output = tvm.compute(
(output_size),
lambda *indices: _dilate(Input, *indices),
name='DilatedInput')

return Output
1 change: 1 addition & 0 deletions topi/python/topi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@

from .conv2d_hwcn_python import conv2d_hwcn_python
from .conv2d_nchw_python import conv2d_nchw_python
from .dilate_python import dilate_python
33 changes: 33 additions & 0 deletions topi/python/topi/testing/dilate_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# pylint: disable=invalid-name
"""Dilate operation in python"""
import numpy as np


def dilate_python(input_np, strides):
"""Dilate operation.
Parameters
----------
input_np : numpy.ndarray
n-D, can be any layout.
strides : list / tuple of n ints
Dilation stride on each dimension, 1 means no dilation.
Returns
-------
output_np : numpy.ndarray
n-D, the same layout as Input.
"""
n = len(input_np.shape)
assert len(strides) == n, \
"Input dimension and strides size dismatch : %d vs %d" %(n, len(strides))
output_size = ()
no_zero = ()
for i in range(n):
output_size += ((input_np.shape[i]-1)*strides[i]+1,)
no_zero += ((range(0, output_size[i], strides[i])),)
output_np = np.zeros(shape=output_size)
output_np[np.ix_(*no_zero)] = input_np

return output_np
36 changes: 36 additions & 0 deletions topi/tests/python/test_topi_dilate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import tvm
import topi
import numpy as np


def test_dilate():
target = 'llvm'
ctx = tvm.cpu(0)

def _test_dilate(input_size, strides):
Input = tvm.placeholder((input_size))
Output = topi.nn.dilate(Input, strides)
schedule = tvm.create_schedule(Output.op)
input_np = np.random.uniform(size=input_size).astype(Input.dtype)
output_np = topi.testing.dilate_python(input_np, strides)
input_tvm = tvm.nd.array(input_np, ctx=ctx)
output_size = ()
for i in range(len(input_size)):
output_size += (tvm.ir_pass.Simplify(Output.shape[i]).value,)
output_tvm = tvm.nd.array(np.zeros(shape=output_size).astype(Output.dtype), ctx=ctx)
f = tvm.build(schedule, [Input, Output], target)
f(input_tvm, output_tvm)
np.testing.assert_allclose(output_tvm.asnumpy(), output_np, rtol=1e-5)

_test_dilate((32,), (2,))
_test_dilate((32,32), (2,2))
_test_dilate((1,3,32,32), (1,1,1,1))
_test_dilate((1,3,32,32), (2,2,2,2))
_test_dilate((1,32,32,3,3), (1,1,1,1,1))
_test_dilate((1,32,32,3,3), (2,2,2,2,2))
_test_dilate((1,32,32,32,3,3), (1,1,1,2,2,2))
_test_dilate((1,32,32,32,3,3), (2,2,2,1,1,1))


if __name__ == "__main__":
test_dilate()

0 comments on commit b0c42f3

Please sign in to comment.