Skip to content

Commit

Permalink
frontend tensorflow.raw_ops Update (#4860)
Browse files Browse the repository at this point in the history
  • Loading branch information
JerryGCDing authored Sep 24, 2022
1 parent fad0c76 commit ef1d4a7
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 2 deletions.
33 changes: 31 additions & 2 deletions ivy/functional/frontends/tensorflow/raw_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
import ivy


def AddN(*, inputs, name="AddN"):
inputs = ivy.array(inputs)
return ivy.sum(inputs, axis=0, dtype=inputs.dtype)


def Acos(*, x, name="Acos"):
return ivy.acos(x)

Expand Down Expand Up @@ -63,6 +68,16 @@ def Cosh(*, x, name="cosh"):
return ivy.cosh(x)


def Equal(*, x, y, incompatible_shape_error=True, name="Equal"):
if incompatible_shape_error:
return ivy.equal(x, y)

try:
ivy.equal(x, y)
except (ivy.exceptions.IvyError, ivy.exceptions.IvyBackendException):
return ivy.array(False)


def Exp(*, x, name="Exp"):
return ivy.exp(x)

Expand Down Expand Up @@ -95,11 +110,11 @@ def Log(*, x, name="Log"):
return ivy.log(x)


def LogicalOr(*, x, y, name=None):
def LogicalOr(*, x, y, name="LogicalOr"):
return ivy.logical_or(x, y)


def LogicalNot(*, x, name=None):
def LogicalNot(*, x, name="LogicalNot"):
return ivy.logical_not(x)


Expand All @@ -111,6 +126,20 @@ def Minimum(*, x, y, name="Minimum"):
return ivy.minimum(x, y)


def Neg(*, x, name="Neg"):
return ivy.negative(x)


def NotEqual(*, x, y, incompatible_shape_error=True, name="NotEqual"):
if incompatible_shape_error:
return ivy.not_equal(x, y)

try:
ivy.not_equal(x, y)
except (ivy.exceptions.IvyError, ivy.exceptions.IvyBackendException):
return ivy.array(False)


def Reshape(*, tensor, shape, name="Reshape"):
return ivy.reshape(tensor, shape)

Expand Down
123 changes: 123 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1147,3 +1147,126 @@ def test_tensorflow_Shape(
fn_tree="raw_ops.Shape",
input=np.asarray(x, dtype=input_dtype),
)


# AddN
@handle_cmd_line_args
@given(
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("numeric"),
min_num_dims=1,
),
num_positional_args=helpers.num_positional_args(
fn_name="ivy.functional.frontends.tensorflow.AddN"
),
)
def test_tensorflow_AddN(
dtype_and_x, as_variable, num_positional_args, native_array, fw
):
input_dtype, x = dtype_and_x
helpers.test_frontend_function(
input_dtypes=input_dtype,
as_variable_flags=as_variable,
with_out=False,
num_positional_args=num_positional_args,
native_array_flags=native_array,
fw=fw,
frontend="tensorflow",
fn_tree="raw_ops.AddN",
inputs=x,
)


# Neg
@handle_cmd_line_args
@given(
dtype_and_x=helpers.dtype_and_values(
available_dtypes=[
"bfloat16",
"float32",
"float64",
"int8",
"int16",
"int32",
"int64",
],
),
num_positional_args=helpers.num_positional_args(
fn_name="ivy.functional.frontends.tensorflow.Neg"
),
)
def test_tensorflow_Neg(
dtype_and_x, as_variable, num_positional_args, native_array, fw
):
input_dtype, x = dtype_and_x
helpers.test_frontend_function(
input_dtypes=input_dtype,
as_variable_flags=as_variable,
with_out=False,
num_positional_args=num_positional_args,
native_array_flags=native_array,
fw=fw,
frontend="tensorflow",
fn_tree="raw_ops.Neg",
x=x,
)


# Equal
@handle_cmd_line_args
@given(
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("numeric"),
num_arrays=2,
shared_dtype=True,
),
num_positional_args=helpers.num_positional_args(
fn_name="ivy.functional.frontends.tensorflow.Equal"
),
)
def test_tensorflow_Equal(
dtype_and_x, as_variable, num_positional_args, native_array, fw
):
input_dtype, x = dtype_and_x
helpers.test_frontend_function(
input_dtypes=input_dtype,
as_variable_flags=as_variable,
with_out=False,
num_positional_args=num_positional_args,
native_array_flags=native_array,
fw=fw,
frontend="tensorflow",
fn_tree="raw_ops.Equal",
x=np.array(x[0], dtype=input_dtype[0]),
y=np.array(x[1], dtype=input_dtype[1]),
)


# NotEqual
@handle_cmd_line_args
@given(
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("numeric"),
num_arrays=2,
shared_dtype=True,
),
num_positional_args=helpers.num_positional_args(
fn_name="ivy.functional.frontends.tensorflow.NotEqual"
),
)
def test_tensorflow_NotEqual(
dtype_and_x, as_variable, num_positional_args, native_array, fw
):
input_dtype, x = dtype_and_x
helpers.test_frontend_function(
input_dtypes=input_dtype,
as_variable_flags=as_variable,
with_out=False,
num_positional_args=num_positional_args,
native_array_flags=native_array,
fw=fw,
frontend="tensorflow",
fn_tree="raw_ops.NotEqual",
x=np.array(x[0], dtype=input_dtype[0]),
y=np.array(x[1], dtype=input_dtype[1]),
)

0 comments on commit ef1d4a7

Please sign in to comment.