Skip to content

Commit

Permalink
ivy-llc#10505 : Conv2D Operations for TensorFlow frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
AryanSharma9917 committed Mar 7, 2023
1 parent 88d7d00 commit 6bf56d4
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 0 deletions.
29 changes: 29 additions & 0 deletions ivy/functional/frontends/tensorflow/raw_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,35 @@ def EuclideanNorm(*, input, axis, keep_dims=False, name="EuclideanNorm"):
ConcatV2 = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.concat))


@to_ivy_arrays_and_back
def Conv2D(
*,
input,
filter,
strides,
padding,
data_format="NHWC",
dilations=[1, 1, 1, 1],
name="Conv2D",
):
if data_format == "NDHWC":
strides = [1] + strides[1:-1] + [1]
dilations = [1] + dilations[1:-1] + [1]
elif data_format == "NCDHW":
strides = [1, 1] + strides[2:] + [1]
dilations = [1, 1] + dilations[2:] + [1]

return tf_frontend.nn.conv2d(
input,
filter[:, :, None, :, :],
strides,
padding,
data_format=data_format,
dilations=dilations,
name=name,
)


@to_ivy_arrays_and_back
def Conv3D(
*,
Expand Down
43 changes: 43 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2920,6 +2920,49 @@ def test_tensorflow_ConcatV2(
)


# Conv2D
@handle_frontend_test(
fn_tree="tensorflow.nn.conv2d",
x_f_d_df=_x_and_filters(
dtypes=helpers.get_dtypes("float", full=False),
data_format=st.sampled_from(["NHWC"]),
padding=st.sampled_from(["SAME", "VALID"]),
type="2d",
dilation_min=1,
dilation_max=1,
),
test_with_out=st.just(False),
number_positional_args=st.just(0),
)
def test_tensorflow_Conv2D(
*,
x_f_d_df,
test_flags,
frontend,
fn_tree,
on_device,
):
input_dtype, x, filters, dilation, data_format, stride, padding = x_f_d_df

stride = _convolution_broadcast_helper(
stride, num_spatial_dims=2, channel_index=3, name="strides"
)

helpers.test_frontend_function(
input_dtypes=input_dtype,
test_flags=test_flags,
frontend=frontend,
fn_tree=fn_tree,
on_device=on_device,
input=x,
filter=filters[:, :, None, :],
strides=stride,
padding=padding,
data_format=data_format,
dilations=[1, 1, 1, 1],
)


# Conv3D
@handle_frontend_test(
fn_tree="tensorflow.raw_ops.Conv3D",
Expand Down

0 comments on commit 6bf56d4

Please sign in to comment.