Skip to content

Commit

Permalink
Added torch frontend function quantile() and its test 2.0 (#10573)
Browse files Browse the repository at this point in the history
Co-authored-by: @AnnaTz
  • Loading branch information
ErukhimovaN authored Feb 16, 2023
1 parent bd3d07e commit 07df9fb
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
11 changes: 11 additions & 0 deletions ivy/functional/frontends/torch/reduction_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,14 @@ def aminmax(input, *, dim=None, keepdim=False, out=None):
"jax": ("float16", "bfloat16"),
"tensorflow": ("float16", "bfloat16"),
}

@to_ivy_arrays_and_back
def quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None):
return ivy.quantile(input, q, axis=dim, keepdims=keepdim, interpolation=interpolation, out=out)

quantile.unsupported_dtypes = {
"torch": ("float16", "bfloat16"),
"numpy": ("float16", "bfloat16"),
"jax": ("float16", "bfloat16"),
"tensorflow": ("float16", "bfloat16"),
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ivy_tests.test_ivy.test_functional.test_core.test_statistical import (
statistical_dtype_values,
)

from ivy_tests.test_ivy.test_functional.test_experimental.test_core.test_statistical import statistical_dtype_values as statistical_dtype_values_experimental

@handle_frontend_test(
fn_tree="torch.dist",
Expand Down Expand Up @@ -666,3 +666,34 @@ def test_torch_aminmax(
dim=axis,
keepdim=keepdims,
)

@handle_frontend_test(
fn_tree="torch.quantile",
dtype_and_x=statistical_dtype_values_experimental(function="quantile"),
keepdims=st.booleans()
)

def test_torch_quantile(
*,
dtype_and_x,
keepdims,
on_device,
fn_tree,
frontend,
test_flags,
):
input_dtype, x, axis, interpolation, q = dtype_and_x
if type(axis) is tuple:
axis=axis[0]
helpers.test_frontend_function(
input_dtypes=input_dtype,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
input=x[0],
q=q,
dim=axis,
keepdim=keepdims,
interpolation=interpolation[0]
)

0 comments on commit 07df9fb

Please sign in to comment.