[TOPI] add dilation operators #316
Conversation
Thanks! Can we add a unit test case to verify this?
To be consistent with existing interfaces (e.g. MXNet), we should take an Array of Expr (which can be constants) as the argument, so we can pass a tuple of ints directly as the dilation.
It is also helpful to add a padding stage, as per #294.
To verify the logical correctness of the compute declaration, we don't need a CUDA schedule; a plain LLVM CPU schedule on a small workload will suffice.
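A minimal sketch of such a check, assuming a dilate operator along the lines of topi.nn.dilate(data, strides) and using a NumPy strided assignment as the reference (the helper name, shapes, and strides below are illustrative, not part of this PR):

```python
# Illustrative sketch only: verify a dilate compute declaration with a
# plain LLVM CPU schedule, assuming topi.nn.dilate(data, strides).
import numpy as np
import tvm
import topi

def verify_dilate(in_shape, strides):
    A = tvm.placeholder(in_shape, name="A")
    B = topi.nn.dilate(A, strides)
    s = tvm.create_schedule(B.op)
    f = tvm.build(s, [A, B], "llvm")

    a_np = np.random.uniform(size=in_shape).astype(A.dtype)
    # NumPy reference: scatter the input into a zero tensor with the given strides.
    out_shape = tuple((d - 1) * st + 1 for d, st in zip(in_shape, strides))
    b_np = np.zeros(out_shape).astype(A.dtype)
    b_np[tuple(slice(None, None, st) for st in strides)] = a_np

    ctx = tvm.cpu(0)
    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(np.zeros(out_shape, dtype=B.dtype), ctx)
    f(a, b)
    np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

verify_dilate((1, 32, 32, 3), (1, 2, 2, 1))
```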
Actually, I think it is helpful to directly pass in an ndim Array of Expr as the dilate argument, to specify the dilation on each dimension. This avoids the need to specify a layout.
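For illustration, per-dimension dilation factors let the caller encode the layout at the call site, so the operator itself needs no layout flag (hypothetical usage of the eventual topi.nn.dilate API, not the code in this diff):

```python
# Hypothetical usage: the same operator serves NCHW and NHWC inputs,
# the layout is implied by where the caller puts the non-unit factors.
import tvm
import topi

data_nchw = tvm.placeholder((1, 16, 32, 32), name="data_nchw")
out_nchw = topi.nn.dilate(data_nchw, (1, 1, 2, 2))   # dilate H and W of an NCHW tensor

data_nhwc = tvm.placeholder((1, 32, 32, 16), name="data_nhwc")
out_nhwc = topi.nn.dilate(data_nhwc, (1, 2, 2, 1))   # dilate H and W of an NHWC tensor
```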
topi/python/topi/nn/dilate.py
Outdated
Output = tvm.compute(
    (N, (H-1)*stride_h+1, (W-1)*stride_w+1, C),
    lambda n, h, w, c: tvm.select(
        tvm.all(tvm.make.EQ(h%stride_h, 0), tvm.make.EQ(w%stride_w, 0)),
we can use (h%stride_h).equal(0)
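Applied to the snippet above, the compute would read roughly as follows (a self-contained sketch of the suggested change with assumed example shapes and strides, not the merged code):

```python
# Sketch of the suggested form: .equal on expressions instead of tvm.make.EQ,
# assuming an NHWC input and fixed example strides.
import tvm

N, H, W, C = 1, 8, 8, 4
stride_h, stride_w = 2, 2
Input = tvm.placeholder((N, H, W, C), name="Input")

Output = tvm.compute(
    (N, (H - 1) * stride_h + 1, (W - 1) * stride_w + 1, C),
    lambda n, h, w, c: tvm.select(
        tvm.all((h % stride_h).equal(0), (w % stride_w).equal(0)),
        Input(n, h / stride_h, w / stride_w, c),
        tvm.const(0.0, Input.dtype)),
    name="Output")
```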
topi/python/topi/nn/dilate.py
Outdated
    (N, (H-1)*stride_h+1, (W-1)*stride_w+1, C),
    lambda n, h, w, c: tvm.select(
        tvm.all(tvm.make.EQ(h%stride_h, 0), tvm.make.EQ(w%stride_w, 0)),
        Input(n, h/stride_h, w/stride_w, c), tvm.const(0.0)),
tvm.const(0.0, Input.dtype)
This allows the expression to work for other data types
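For instance (an assumed illustration, not code from this diff), with a non-float32 input the zero constant then picks up the right type automatically:

```python
# With the dtype taken from the input, the inserted zeros match the tensor's
# type instead of defaulting to float32.
import tvm

Input = tvm.placeholder((1, 8, 8, 4), dtype="float16", name="Input")
zero = tvm.const(0.0, Input.dtype)   # a float16 zero constant
```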
topi/python/topi/nn/dilate.py
Outdated
    Output : tvm.Tensor
        Output tensor, layout is NCHW.
    """
    N, C, H, W = get_const_tuple(Input.shape)
For this expression to work, we might not need const input shapes. The new dimension calculation can be done symbolically, followed by a tvm.ir_pass.Simplify.
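A sketch of what that could look like, computing the output extents symbolically and simplifying each one (illustrative only; the helper name is made up):

```python
# Compute the dilated output shape symbolically instead of via get_const_tuple,
# then simplify each extent expression.
import tvm

def dilated_shape(input_shape, strides):
    return tuple(tvm.ir_pass.Simplify((d - 1) * s + 1)
                 for d, s in zip(input_shape, strides))

A = tvm.placeholder((tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")), name="A")
out_shape = dilated_shape(A.shape, (1, 1, 2, 2))
```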
topi/python/topi/nn/dilate.py
Outdated
    Parameters
    ----------
    Input : tvm.Tensor
        4-D, can be any layout.
We can directly ask the user to pass in an n-D tensor.
topi/python/topi/nn/dilate.py
Outdated
        4-D, the same layout as Input.
    """
    A, B, C, D = Input.shape
    sa, sb, sc, sd = strides
Use Python's list unpacking trick to support n-D tensors:
function(*list)
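A rough sketch of how unpacking could generalize the 4-D version above to n dimensions (an illustration of the suggestion, not necessarily the code that was merged):

```python
# n-D dilate: build the output shape and the select condition by iterating
# over all dimensions, and unpack index lists with * when indexing the input.
import tvm

def dilate_nd(Input, strides):
    n = len(Input.shape)
    assert n == len(strides), "strides must match the input dimensionality"
    out_shape = tuple((Input.shape[i] - 1) * strides[i] + 1 for i in range(n))

    def _compute(*indices):
        on_grid = tvm.all(*[(indices[i] % strides[i]).equal(0) for i in range(n)])
        index_tuple = [indices[i] / strides[i] for i in range(n)]
        return tvm.select(on_grid, Input(*index_tuple),
                          tvm.const(0.0, Input.dtype))

    return tvm.compute(out_shape, _compute, name="DilatedInput")
```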
I converted a Caffe model to an MXNet model and used MXNet as the frontend, but I get an error in nnvm/python/nnvm/top/nn.py, in compute_conv2d. Can I get some help?
@sgxu You can compose a dilated conv using topi.nn.dilate and topi.nn.conv2d. It won't be difficult, it's just not implemented right now.
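For example, a composition along these lines should work (a sketch that assumes NCHW/OIHW layouts and the conv2d_nchw signature of that time; the exact shapes, padding, and argument names are assumptions):

```python
# Dilated convolution composed from existing TOPI operators:
# dilate the kernel, then run an ordinary conv2d with it.
import tvm
import topi

data = tvm.placeholder((1, 64, 56, 56), name="data")      # NCHW
kernel = tvm.placeholder((64, 64, 3, 3), name="kernel")   # OIHW
dilation = 2

# Inserting zeros between kernel taps turns a 3x3 kernel into an
# effective 5x5 dilated kernel.
dilated_kernel = topi.nn.dilate(kernel, (1, 1, dilation, dilation))
out = topi.nn.conv2d_nchw(data, dilated_kernel, 1, 2)     # stride=1, padding=2
```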
…pache#316) This PR removes the `global_symbol` linkage added by Relay Translator. It also fixes unaddressed comments of apache#262. All tests can pass locally and I believe it is safe to merge this PR directly.
Dilation can be (and has been proven to be) useful in implementing the backward propagation of depthwise convolution, see #294.