lintfixbot: Auto-commit fixed lint errors in codebase
ivy-branch committed Feb 17, 2023
1 parent 9d1e37d commit e44891b
Showing 4 changed files with 23 additions and 11 deletions.
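All four files receive the same kind of mechanical fix: over-long lines are wrapped, single quotes become double quotes, `=` in statements gains surrounding spaces, and trailing commas are added. This is the style black produces with its defaults (88-character lines). The commit does not name the bot's tooling, so treat the following reproduction as a sketch under that assumption:

    # Sketch only: assumes the lint bot applies black-style formatting,
    # which the commit message does not state explicitly.
    import black

    src = (
        "def quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None):\n"
        "    return ivy.quantile(input, q, axis=dim, keepdims=keepdim,"
        " interpolation=interpolation, out=out)\n"
    )
    # black rewrites the quotes and wraps the over-long return statement,
    # matching the hunks below.
    print(black.format_str(src, mode=black.FileMode()))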
5 changes: 4 additions & 1 deletion ivy/array/array.py
@@ -262,7 +262,10 @@ def __ivy_array_function__(self, func, types, args, kwargs):
                 and (t.__ivy_array_function__ is not ivy.Array.__ivy_array_function__)
                 or (
                     hasattr(ivy.NativeArray, "__ivy_array_function__")
-                    and (t.__ivy_array_function__ is not ivy.NativeArray.__ivy_array_function__)
+                    and (
+                        t.__ivy_array_function__
+                        is not ivy.NativeArray.__ivy_array_function__
+                    )
                 )
             ):
                 return NotImplemented
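For context, `__ivy_array_function__` is ivy's analogue of NumPy's NEP-18 `__array_function__` protocol: a function consults the hook on each argument type and bails out with `NotImplemented` for types it does not recognize, which is exactly the condition being re-wrapped above. A minimal sketch of how such a hook is typically consumed (the `dispatch` helper below is hypothetical, not ivy's actual plumbing):

    # Hypothetical consumer of a NEP-18-style override hook; ivy's real
    # dispatch lives in ivy/array/array.py and the function wrappers.
    def dispatch(func, args, kwargs):
        # Types that opt into the protocol get a chance to handle the call.
        types = [type(a) for a in args if hasattr(type(a), "__ivy_array_function__")]
        for arg in args:
            hook = getattr(arg, "__ivy_array_function__", None)
            if hook is None:
                continue
            result = hook(func, types, args, kwargs)
            if result is not NotImplemented:
                return result
        raise TypeError(f"no implementation of {func.__name__!r} for these types")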
4 changes: 3 additions & 1 deletion ivy/functional/backends/torch/experimental/statistical.py
@@ -97,7 +97,9 @@ def quantile(

         temp = a.reshape((-1,) + tuple(desired_shape))
 
-        return torch.quantile(temp, q, dim=0, keepdim=keepdims, interpolation=interpolation)
+        return torch.quantile(
+            temp, q, dim=0, keepdim=keepdims, interpolation=interpolation
+        )
 
     return torch.quantile(a, q, dim=axis, keepdim=keepdims, interpolation=interpolation)
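The wrapped call is the tuple-axis branch: `torch.quantile` itself only accepts a single integer `dim`, so the backend first collapses the requested axes into one flattened leading dimension and then reduces over `dim=0`. A standalone illustration of the same trick (plain PyTorch, not ivy code):

    import torch

    a = torch.rand(2, 3, 4)
    # Quantile over axes (0, 2): move those axes to the front, flatten them
    # into a single leading dimension, then reduce along dim=0.
    temp = a.permute(0, 2, 1).reshape(-1, 3)
    q = torch.quantile(temp, 0.5, dim=0, keepdim=False, interpolation="linear")
    print(q.shape)  # torch.Size([3])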
10 changes: 7 additions & 3 deletions ivy/functional/frontends/torch/reduction_ops.py
@@ -163,13 +163,17 @@ def aminmax(input, *, dim=None, keepdim=False, out=None):
"tensorflow": ("float16", "bfloat16"),
}


@to_ivy_arrays_and_back
def quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None):
return ivy.quantile(input, q, axis=dim, keepdims=keepdim, interpolation=interpolation, out=out)
def quantile(input, q, dim=None, keepdim=False, *, interpolation="linear", out=None):
return ivy.quantile(
input, q, axis=dim, keepdims=keepdim, interpolation=interpolation, out=out
)


quantile.unsupported_dtypes = {
"torch": ("float16", "bfloat16"),
"numpy": ("float16", "bfloat16"),
"jax": ("float16", "bfloat16"),
"tensorflow": ("float16", "bfloat16"),
}
}
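The frontend function is pure argument translation: torch's `dim` and `keepdim` names are mapped onto ivy's `axis` and `keepdims`, with `interpolation` and `out` forwarded unchanged. A usage sketch (import path and printed repr may vary across ivy versions; a backend is assumed to be set):

    import ivy
    import ivy.functional.frontends.torch as torch_frontend

    ivy.set_backend("numpy")  # any supported backend
    x = ivy.array([[1.0, 2.0], [3.0, 4.0]])
    # torch-style keywords, translated internally to ivy.quantile(...)
    out = torch_frontend.quantile(x, 0.5, dim=1, keepdim=True)
    print(out)  # expected: [[1.5], [3.5]]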
15 changes: 9 additions & 6 deletions
@@ -8,7 +8,10 @@
 from ivy_tests.test_ivy.test_functional.test_core.test_statistical import (
     statistical_dtype_values,
 )
-from ivy_tests.test_ivy.test_functional.test_experimental.test_core.test_statistical import statistical_dtype_values as statistical_dtype_values_experimental
+from ivy_tests.test_ivy.test_functional.test_experimental.test_core.test_statistical import (
+    statistical_dtype_values as statistical_dtype_values_experimental,
+)
 
+
 @handle_frontend_test(
     fn_tree="torch.dist",
@@ -667,12 +670,12 @@ def test_torch_aminmax(
         keepdim=keepdims,
     )
 
+
 @handle_frontend_test(
     fn_tree="torch.quantile",
     dtype_and_x=statistical_dtype_values_experimental(function="quantile"),
-    keepdims=st.booleans()
+    keepdims=st.booleans(),
 )
-
 def test_torch_quantile(
     *,
     dtype_and_x,
@@ -684,7 +687,7 @@ def test_torch_quantile(
 ):
     input_dtype, x, axis, interpolation, q = dtype_and_x
     if type(axis) is tuple:
-        axis=axis[0]
+        axis = axis[0]
     helpers.test_frontend_function(
         input_dtypes=input_dtype,
         frontend=frontend,
@@ -695,5 +698,5 @@ def test_torch_quantile(
         q=q,
         dim=axis,
         keepdim=keepdims,
-        interpolation=interpolation[0]
-    )
+        interpolation=interpolation[0],
+    )

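Besides line wrapping, the test file picks up two pure lint fixes: the trailing comma after `keepdims=st.booleans()` (black's magic trailing comma) and the removal of the blank line between the decorator's closing parenthesis and `def test_torch_quantile` (pycodestyle E304, "blank lines found after function decorator"). A minimal standalone illustration with plain hypothesis rather than ivy's `handle_frontend_test` machinery:

    from hypothesis import given, strategies as st

    @given(keepdim=st.booleans())  # decorator immediately precedes the def (E304)
    def check_keepdim_flag(keepdim):
        # hypothesis drives keepdim through both boolean values
        assert keepdim in (True, False)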