Skip to content

Commit

Permalink
lintfixbot: Auto-commit fixed lint errors in codebase
Browse files Browse the repository at this point in the history
  • Loading branch information
ivy-branch committed Feb 16, 2023
1 parent 3b6f404 commit c1669e1
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
10 changes: 6 additions & 4 deletions ivy/functional/frontends/tensorflow/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ def _reduce_strides_dilations(dim, stride, dilations):

@to_ivy_arrays_and_back
def atrous_conv2d(value, filters, rate, padding):
return ivy.conv2d(value, filters, 1, padding, dilations=[rate]*2)
return ivy.conv2d(value, filters, 1, padding, dilations=[rate] * 2)


@to_ivy_arrays_and_back
def atrous_conv2d_transpose(value, filters, output_shape, rate, padding):
return ivy.conv2d_transpose(
value, filters, 1, padding, output_shape=output_shape, dilations=[rate]*2
value, filters, 1, padding, output_shape=output_shape, dilations=[rate] * 2
)


Expand Down Expand Up @@ -145,13 +145,15 @@ def depthwise_conv2d(
):
strides, dilations = _reduce_strides_dilations(2, strides, dilations)
fc = filter.shape[-2]
filter = filter.reshape([*filter.shape[0:2], 1, filter.shape[-2]*filter.shape[-1]])
filter = filter.reshape(
[*filter.shape[0:2], 1, filter.shape[-2] * filter.shape[-1]]
)
return ivy.conv_general_dilated(
input,
filter,
strides,
padding,
data_format='channel_last' if data_format[-1] == 'C' else 'channel_first',
data_format="channel_last" if data_format[-1] == "C" else "channel_first",
dilations=dilations,
feature_group_count=fc,
)
Expand Down
12 changes: 9 additions & 3 deletions ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def _x_and_filters(
dilations = draw(
st.one_of(
st.integers(dilation_min, dilation_max),
st.lists(st.integers(dilation_min, dilation_max), min_size=dim, max_size=dim),
st.lists(
st.integers(dilation_min, dilation_max), min_size=dim, max_size=dim
),
)
)
if atrous:
Expand All @@ -58,7 +60,9 @@ def _x_and_filters(
stride = draw(
st.one_of(
st.integers(stride_min, stride_max),
st.lists(st.integers(stride_min, stride_max), min_size=dim, max_size=dim),
st.lists(
st.integers(stride_min, stride_max), min_size=dim, max_size=dim
),
)
)
fstride = [stride] * dim if isinstance(stride, int) else stride
Expand Down Expand Up @@ -104,7 +108,9 @@ def _x_and_filters(
if transpose:
output_shape = [
x_shape[0],
_deconv_length(x_w, fstride[0], filter_shape[0], padding, fdilations[0]),
_deconv_length(
x_w, fstride[0], filter_shape[0], padding, fdilations[0]
),
d_in,
]
elif dim == 2:
Expand Down

0 comments on commit c1669e1

Please sign in to comment.