Skip to content

Commit

Permalink
frontend tensorflow.raw_ops Update (#5011)
Browse files Browse the repository at this point in the history
  • Loading branch information
JerryGCDing authored Sep 27, 2022
1 parent c913a85 commit b47cc53
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 5 deletions.
8 changes: 8 additions & 0 deletions ivy/functional/frontends/tensorflow/raw_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ def FloorDiv(*, x, y, name="FloorDiv"):
return ivy.floor_divide(x, y)


def Greater(*, x, y, name="Greater"):
return ivy.greater(x, y)


def GreaterEqual(*, x, y, name="GreaterEqual"):
return ivy.greater_equal(x, y)


def Less(*, x, y, name="Less"):
return ivy.less(x, y)

Expand Down
56 changes: 51 additions & 5 deletions ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,16 +1357,12 @@ def test_tensorflow_Relu(dtype_and_x, as_variable, fw, native_array):
),
transpose_a=st.booleans(),
transpose_b=st.booleans(),
num_positional_args=helpers.num_positional_args(
fn_name="ivy.functional.frontends.tensorflow.MatMul"
),
)
def test_tensroflow_MatMul(
dtype_and_x,
transpose_a,
transpose_b,
as_variable,
num_positional_args,
native_array,
fw,
):
Expand All @@ -1376,7 +1372,7 @@ def test_tensroflow_MatMul(
input_dtypes=input_dtype,
as_variable_flags=as_variable,
with_out=False,
num_positional_args=num_positional_args,
num_positional_args=0,
native_array_flags=native_array,
fw=fw,
frontend="tensorflow",
Expand All @@ -1386,3 +1382,53 @@ def test_tensroflow_MatMul(
transpose_a=transpose_a,
transpose_b=transpose_b,
)


# Greater
@handle_cmd_line_args
@given(
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("numeric"),
num_arrays=2,
shared_dtype=True,
),
)
def test_tensorflow_Greater(dtype_and_x, as_variable, fw, native_array):
input_dtype, x = dtype_and_x
helpers.test_frontend_function(
input_dtypes=input_dtype,
as_variable_flags=as_variable,
with_out=False,
num_positional_args=0,
native_array_flags=native_array,
fw=fw,
frontend="tensorflow",
fn_tree="raw_ops.Greater",
x=x[0],
y=x[1],
)


# GreaterEqual
@handle_cmd_line_args
@given(
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("numeric"),
num_arrays=2,
shared_dtype=True,
),
)
def test_tensorflow_GreaterEqual(dtype_and_x, as_variable, fw, native_array):
input_dtype, x = dtype_and_x
helpers.test_frontend_function(
input_dtypes=input_dtype,
as_variable_flags=as_variable,
with_out=False,
num_positional_args=0,
native_array_flags=native_array,
fw=fw,
frontend="tensorflow",
fn_tree="raw_ops.GreaterEqual",
x=x[0],
y=x[1],
)

0 comments on commit b47cc53

Please sign in to comment.