diff --git a/ivy/functional/frontends/tensorflow/nn.py b/ivy/functional/frontends/tensorflow/nn.py index 2f96c984fc881..3c24579a442ca 100644 --- a/ivy/functional/frontends/tensorflow/nn.py +++ b/ivy/functional/frontends/tensorflow/nn.py @@ -19,13 +19,13 @@ def _reduce_strides_dilations(dim, stride, dilations): @to_ivy_arrays_and_back def atrous_conv2d(value, filters, rate, padding): - return ivy.conv2d(value, filters, 1, padding, dilations=[rate]*2) + return ivy.conv2d(value, filters, 1, padding, dilations=[rate] * 2) @to_ivy_arrays_and_back def atrous_conv2d_transpose(value, filters, output_shape, rate, padding): return ivy.conv2d_transpose( - value, filters, 1, padding, output_shape=output_shape, dilations=[rate]*2 + value, filters, 1, padding, output_shape=output_shape, dilations=[rate] * 2 ) @@ -145,13 +145,15 @@ def depthwise_conv2d( ): strides, dilations = _reduce_strides_dilations(2, strides, dilations) fc = filter.shape[-2] - filter = filter.reshape([*filter.shape[0:2], 1, filter.shape[-2]*filter.shape[-1]]) + filter = filter.reshape( + [*filter.shape[0:2], 1, filter.shape[-2] * filter.shape[-1]] + ) return ivy.conv_general_dilated( input, filter, strides, padding, - data_format='channel_last' if data_format[-1] == 'C' else 'channel_first', + data_format="channel_last" if data_format[-1] == "C" else "channel_first", dilations=dilations, feature_group_count=fc, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py index f09829901ff1e..23e3f75b9fe75 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py @@ -49,7 +49,9 @@ def _x_and_filters( dilations = draw( st.one_of( st.integers(dilation_min, dilation_max), - st.lists(st.integers(dilation_min, dilation_max), min_size=dim, max_size=dim), + st.lists( + st.integers(dilation_min, dilation_max), min_size=dim, max_size=dim + ), ) ) if atrous: @@ -58,7 +60,9 @@ def _x_and_filters( stride = draw( st.one_of( st.integers(stride_min, stride_max), - st.lists(st.integers(stride_min, stride_max), min_size=dim, max_size=dim), + st.lists( + st.integers(stride_min, stride_max), min_size=dim, max_size=dim + ), ) ) fstride = [stride] * dim if isinstance(stride, int) else stride @@ -104,7 +108,9 @@ def _x_and_filters( if transpose: output_shape = [ x_shape[0], - _deconv_length(x_w, fstride[0], filter_shape[0], padding, fdilations[0]), + _deconv_length( + x_w, fstride[0], filter_shape[0], padding, fdilations[0] + ), d_in, ] elif dim == 2: