Skip to content

Commit

Permalink
Remove unsupported_dtypes (ivy-llc#10745)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dsantra92 authored Feb 22, 2023
1 parent c5172e0 commit a5ed816
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 44 deletions.
14 changes: 0 additions & 14 deletions ivy/functional/backends/tensorflow/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,17 +796,3 @@ def vector_to_skew_symmetric_matrix(
# BS x 3 x 3
ret = tf.concat((row1, row2, row3), -2)
return ret


vector_to_skew_symmetric_matrix.unsupported_dtypes = (
"int8",
"int16",
"int32",
"int64",
"uint8",
"uint16",
"uint32",
"uint64",
"float16",
"float64",
)
5 changes: 4 additions & 1 deletion ivy/functional/backends/torch/experimental/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ def lcm(
lcm.support_native_out = True


@with_unsupported_dtypes(
{"2.9.1 and below": ("bfloat16",)},
backend_version,
)
def fmod(
x1: torch.Tensor,
x2: torch.Tensor,
Expand All @@ -39,7 +43,6 @@ def fmod(


fmod.support_native_out = True
fmod.unsupported_dtypes = ("bfloat16",)


def fmax(
Expand Down
65 changes: 36 additions & 29 deletions ivy/functional/frontends/tensorflow/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def conv3d_transpose(
)


@with_unsupported_dtypes({"2.9.1 and below": ("bfloat16",)}, "tensorflow")
@to_ivy_arrays_and_back
def depthwise_conv2d(
input,
Expand All @@ -164,9 +165,6 @@ def depthwise_conv2d(
)


depthwise_conv2d.unsupported_dtypes = ("bfloat16",)


@to_ivy_arrays_and_back
def batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None):
inv = 1.0 / ivy.sqrt(variance + variance_epsilon)
Expand All @@ -183,22 +181,37 @@ def dropout(x, rate, noise_shape=None, seed=None, name=None):
return ivy.dropout(x, rate, noise_shape=noise_shape, seed=seed)


@with_unsupported_dtypes(
{
"2.9.1": (
"int8",
"int16",
"int32",
"int64",
"bool",
"bfloat16",
)
},
"tensorflow",
)
@to_ivy_arrays_and_back
def silu(features, beta: float = 1.0):
beta = ivy.astype(ivy.array(beta), ivy.dtype(features))
return ivy.multiply(features, ivy.sigmoid(ivy.multiply(beta, features)))


silu.unsupported_dtypes = (
"int8",
"int16",
"int32",
"int64",
"bool",
"bfloat16",
@with_unsupported_dtypes(
{
"2.9.1": (
"int8",
"int16",
"int32",
"int64",
"bool",
)
},
"tensorflow",
)


@to_ivy_arrays_and_back
def sigmoid_cross_entropy_with_logits(labels=None, logits=None, name=None):
ivy.utils.assertions.check_shape(labels, logits)
Expand All @@ -210,15 +223,18 @@ def sigmoid_cross_entropy_with_logits(labels=None, logits=None, name=None):
return ivy.add(ret_val, ivy.log1p(ivy.exp(neg_abs_logits)))


sigmoid_cross_entropy_with_logits.unsupported_dtypes = (
"int8",
"int16",
"int32",
"int64",
"bool",
@with_unsupported_dtypes(
{
"2.9.1": (
"int8",
"int16",
"int32",
"int64",
"bool",
)
},
"tensorflow",
)


@to_ivy_arrays_and_back
def weighted_cross_entropy_with_logits(
labels=None, logits=None, pos_weight=1.0, name=None
Expand All @@ -239,15 +255,6 @@ def weighted_cross_entropy_with_logits(
return ivy.add(first_term, second_term)


weighted_cross_entropy_with_logits.unsupported_dtypes = (
"int8",
"int16",
"int32",
"int64",
"bool",
)


@with_supported_dtypes(
{"2.9.0 and below": ("float32", "float16", "bfloat16")}, "tensorflow"
)
Expand Down

0 comments on commit a5ed816

Please sign in to comment.