Commit a772de8

[skip ci][CI][Fix] Fixing lint (#10445)
A linting issue was introduced in #10423, fixing this up.

Change-Id: I06c518194e30dcaa755005f06b8b7280c237d386
lhutton1 authored Mar 2, 2022
1 parent 8f6fa8f commit a772de8
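
The change itself is purely cosmetic re-wrapping in the style the black formatter produces. As a minimal sketch of reproducing the check locally (assuming black is the formatter behind the failing lint job, which this page does not confirm, and that it is installed):

    # Sketch only: run the formatter in check mode against the touched file.
    # Assumes black is TVM's Python formatter; --check exits nonzero (and
    # check=True then raises CalledProcessError) if the file would be rewritten.
    import subprocess

    subprocess.run(
        ["python3", "-m", "black", "--check", "python/tvm/topi/cuda/conv2d_transpose.py"],
        check=True,
    )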
Showing 1 changed file with 7 additions and 5 deletions.
python/tvm/topi/cuda/conv2d_transpose.py
@@ -59,7 +59,9 @@ def conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype, output_
     stride_height, stride_width = stride
     outpad_height, outpad_width = output_padding
     assert outpad_height < stride_height and outpad_width < stride_width
-    assert inp_channels % groups == 0, f"input channels {inp_channels} must divide group size {groups}"
+    assert (
+        inp_channels % groups == 0
+    ), f"input channels {inp_channels} must divide group size {groups}"
     cfg.stride = stride
     pad_top, pad_left, pad_bottom, pad_right = nn.get_pad_tuple(
         padding, (kernel_height, kernel_width)
@@ -112,14 +114,14 @@ def conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype, output_
     data_out = te.compute(
         (batch, out_channels * groups, out_height, out_width),
         lambda b, c, h, w: te.sum(
-            data[
-                b, c // out_channels * (inp_channels // groups) + dc, h + dh, w + dw
-            ].astype(out_dtype)
+            data[b, c // out_channels * (inp_channels // groups) + dc, h + dh, w + dw].astype(
+                out_dtype
+            )
             * kernel[
                 c // out_channels * (inp_channels // groups) + dc,
                 c % out_channels,
                 kernel_height - 1 - dh,
-                kernel_width - 1 - dw
+                kernel_width - 1 - dw,
             ].astype(out_dtype),
             axis=[dc, dh, dw],
         ),
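
For context, the assert that was re-wrapped above guards the grouped-convolution requirement that the input channel count be a multiple of groups. A minimal standalone sketch of the same invariant, with illustrative values that are not from the commit:

    # Illustrative values only: a grouped (transposed) convolution splits
    # inp_channels into `groups` equal slices, so the division must be exact.
    inp_channels, groups = 16, 4
    assert (
        inp_channels % groups == 0
    ), f"groups ({groups}) must evenly divide input channels ({inp_channels})"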
