Fixed OOM error in Torch frontend test
Jamie Line authored and JamieLine committed Sep 27, 2022
1 parent fbe2b30 commit c913a85
Showing 1 changed file with 57 additions and 20 deletions.
@@ -1,4 +1,6 @@
 # global
+import math
+
 import numpy as np
 from hypothesis import assume, given, strategies as st

@@ -8,6 +10,51 @@
 from ivy_tests.test_ivy.helpers import handle_cmd_line_args


+# helpers
+@st.composite
+def _get_repeat_interleaves_args(
+    draw, *, available_dtypes, valid_axis, max_num_dims, max_dim_size
+):
+    values_dtype, values, axis, shape = draw(
+        helpers.dtype_values_axis(
+            available_dtypes=available_dtypes,
+            valid_axis=valid_axis,
+            force_int_axis=True,
+            shape=draw(
+                helpers.get_shape(
+                    allow_none=False,
+                    min_num_dims=0,
+                    max_num_dims=max_num_dims,
+                    min_dim_size=1,
+                    max_dim_size=max_dim_size,
+                )
+            ),
+            ret_shape=True,
+        )
+    )
+
+    if axis is None:
+        generate_repeats_as_integer = draw(st.booleans())
+        num_repeats = 1 if generate_repeats_as_integer else math.prod(tuple(shape))
+    else:
+        num_repeats = shape[axis]
+
+    repeats_dtype, repeats = draw(
+        helpers.dtype_and_values(
+            available_dtypes=helpers.get_dtypes("integer"),
+            min_value=0,
+            max_value=10,
+            shape=[num_repeats],
+        )
+    )
+
+    # Output size is an optional parameter accepted by Torch for optimisation
+    use_output_size = draw(st.booleans())
+    output_size = np.sum(repeats) if use_output_size else None
+
+    return [values_dtype, repeats_dtype], values, repeats, axis, output_size
+
+
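# Editorial sketch (not part of the committed file): how the arguments drawn by
# the helper above map onto torch.repeat_interleave. Assumes PyTorch >= 1.10,
# where the output_size keyword is available.
import torch

x = torch.tensor([[1, 2], [3, 4]])
# With dim given, repeats needs one entry per slice along that dim, which is
# why the strategy sets num_repeats = shape[axis].
repeats = torch.tensor([1, 3])
out = torch.repeat_interleave(x, repeats, dim=1)  # -> [[1, 2, 2, 2], [3, 4, 4, 4]]
# output_size is only a pre-computed hint equal to repeats.sum(); Torch can use
# it to size the result up front, and it never changes the output itself.
hinted = torch.repeat_interleave(x, repeats, dim=1, output_size=4)
assert torch.equal(out, hinted) and out.shape == (2, 4)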
 # flip
 @handle_cmd_line_args
 @given(
@@ -686,47 +733,37 @@ def test_torch_logcumsumexp(

 @handle_cmd_line_args
 @given(
-    dtype_and_input_and_dim=helpers.dtype_values_axis(
+    dtype_values_repeats_axis_output_size=_get_repeat_interleaves_args(
         available_dtypes=helpers.get_dtypes("valid"),
         valid_axis=True,
         max_num_dims=4,
         max_dim_size=4,
     ),
-    dtype_and_repeats=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("integer"),
-        # Torch requires this.
-        max_num_dims=1,
-        min_num_dims=0,
-    ),
-    # Generating the output size as a strategy would be much more
-    # complicated than necessary.
-    use_output_size=st.booleans(),
     num_positional_args=helpers.num_positional_args(
         fn_name="ivy.functional.frontends.torch.repeat_interleave",
     ),
 )
 def test_torch_repeat_interleave(
-    dtype_and_input_and_dim,
-    dtype_and_repeats,
-    use_output_size,
+    dtype_values_repeats_axis_output_size,
     as_variable,
     with_out,
     num_positional_args,
     native_array,
     fw,
 ):
-    input_dtype, input, dim = dtype_and_input_and_dim
-    repeat_dtype, repeats = dtype_and_repeats
-    output_size = np.sum(repeats) if use_output_size else None
+    dtype, values, repeats, axis, output_size = dtype_values_repeats_axis_output_size

     helpers.test_frontend_function(
-        input_dtypes=input_dtype + repeat_dtype,
+        input_dtypes=dtype,
         with_out=with_out,
         num_positional_args=num_positional_args,
         as_variable_flags=as_variable,
         native_array_flags=native_array,
         fw=fw,
         frontend="torch",
         fn_tree="repeat_interleave",
-        input=input[0],
-        repeats=repeats[0],
-        dim=dim,
+        input=values,
+        repeats=repeats,
+        dim=axis,
         output_size=output_size,
     )
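A rough worst-case estimate, based only on the bounds visible in the new strategy (max_num_dims=4, max_dim_size=4, repeat counts drawn from 0-10), suggests why the drawn cases now stay small; the figures below are an editorial back-of-the-envelope check, not part of the commit.

# Largest input the strategy can draw: 4 dims, each of size at most 4.
max_input_elems = 4 ** 4                  # 256 elements
# Every element is repeated at most 10 times, whether dim is set or None.
max_output_elems = max_input_elems * 10   # 2560 elements
print(max_input_elems, max_output_elems)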
