Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update get_valid strategy #10684

Merged
merged 11 commits into from
Feb 21, 2023
1 change: 1 addition & 0 deletions ivy_tests/test_ivy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def run_around_tests(request, on_device, backend_fw, compile_graph, implicit):
request.function.test_data,
backend_fw.backend,
request.function.ground_truth_backend,
on_device,
)
except Exception as e:
test_globals.teardown_api_test()
Expand Down
13 changes: 0 additions & 13 deletions ivy_tests/test_ivy/helpers/function_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,6 @@ def empty_func(*args, **kwargs):
is_torch_native_array = empty_func


# ToDo, this is temporary until unsupported_dtype is embedded
# into helpers.get_dtypes
def _assert_dtypes_are_valid(input_dtypes: Union[List[ivy.Dtype], List[str]]):
for dtype in input_dtypes:
if dtype not in ivy.valid_dtypes + ivy.valid_complex_dtypes:
raise Exception(f"{dtype} is not a valid data type.")


# Function testing


Expand Down Expand Up @@ -213,7 +205,6 @@ def test_function(
>>> x2 = np.array([-3, 15, 24])
>>> test_function(input_dtypes, test_flags, fw, fn_name, x1=x1, x2=x2)
"""
_assert_dtypes_are_valid(input_dtypes)
# split the arguments into their positional and keyword components
args_np, kwargs_np = kwargs_to_args_n_kwargs(
num_positional_args=test_flags.num_positional_args, kwargs=all_as_kwargs_np
Expand Down Expand Up @@ -1037,8 +1028,6 @@ def test_method(
ret_gt
optional, return value from the Ground Truth function
"""
_assert_dtypes_are_valid(method_input_dtypes)

init_input_dtypes = ivy.default(init_input_dtypes, [])

# Constructor arguments #
Expand Down Expand Up @@ -1344,8 +1333,6 @@ def test_frontend_method(
"""
if isinstance(frontend, list):
frontend, frontend_proc = frontend
_assert_dtypes_are_valid(init_input_dtypes)
_assert_dtypes_are_valid(method_input_dtypes)

# split the arguments into their positional and keyword components

Expand Down
23 changes: 21 additions & 2 deletions ivy_tests/test_ivy/helpers/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
CURRENT_BACKEND: callable = _Notsetval
CURRENT_FRONTEND: callable = _Notsetval
CURRENT_RUNNING_TEST = _Notsetval
CURRENT_DEVICE = _Notsetval
CURRENT_FRONTEND_STR = ""


Expand Down Expand Up @@ -139,28 +140,34 @@ def _get_ivy_torch(version=None):
# Setup


def setup_api_test(test_data: TestData, backend: str, ground_truth_backend: str):
def setup_api_test(
test_data: TestData, backend: str, ground_truth_backend: str, device: str
):
_set_test_data(test_data)
_set_backend(backend)
_set_device(device)
_set_ground_truth_backend(ground_truth_backend)


def teardown_api_test():
_unset_test_data()
_unset_backend()
_unset_device()
_unset_ground_truth_backend()


def setup_frontend_test(test_data: TestData, frontend: str, backend: str):
def setup_frontend_test(test_data: TestData, frontend: str, backend: str, device: str):
_set_test_data(test_data)
_set_frontend(frontend)
_set_backend(backend)
_set_device(device)


def teardown_frontend_test():
_unset_test_data()
_unset_frontend()
_unset_backend()
_unset_device()


def _set_test_data(test_data: TestData):
Expand Down Expand Up @@ -199,6 +206,13 @@ def _set_ground_truth_backend(framework: str):
CURRENT_GROUND_TRUTH_BACKEND = FWS_DICT[framework]


def _set_device(device: str):
global CURRENT_DEVICE
if CURRENT_DEVICE is not _Notsetval:
raise InterruptedTest(CURRENT_RUNNING_TEST)
CURRENT_DEVICE = device


# Teardown


Expand All @@ -220,3 +234,8 @@ def _unset_backend():
def _unset_ground_truth_backend():
global CURRENT_GROUND_TRUTH_BACKEND
CURRENT_GROUND_TRUTH_BACKEND = _Notsetval


def _unset_device():
global CURRENT_DEVICE
CURRENT_DEVICE = _Notsetval
159 changes: 83 additions & 76 deletions ivy_tests/test_ivy/helpers/hypothesis_helpers/dtype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,59 @@
from .. import globals as test_globals


_dtype_kind_keys = {
"valid",
"numeric",
"float",
"unsigned",
"integer",
"signed_integer",
"complex",
"real_and_complex",
"float_and_complex",
"bool",
}


def _get_fn_dtypes(framework, kind="valid"):
return test_globals.CURRENT_RUNNING_TEST.supported_device_dtypes[framework.backend][
test_globals.CURRENT_DEVICE
][kind]


def _get_type_dict(framework, kind):
if kind == "valid":
return framework.valid_dtypes
elif kind == "numeric":
return framework.valid_numeric_dtypes
elif kind == "integer":
return framework.valid_int_dtypes
elif kind == "float":
return framework.valid_float_dtypes
elif kind == "unsigned":
return framework.valid_int_dtypes
elif kind == "signed_integer":
return tuple(
set(framework.valid_int_dtypes).difference(framework.valid_uint_dtypes)
)
elif kind == "complex":
return framework.valid_complex_dtypes
elif kind == "real_and_complex":
return tuple(
set(framework.valid_numeric_dtypes).union(framework.valid_complex_dtypes)
)
elif kind == "float_and_complex":
return tuple(
set(framework.valid_float_dtypes).union(framework.valid_complex_dtypes)
)
elif kind == "bool":
return tuple(
set(framework.valid_dtypes).difference(framework.valid_numeric_dtypes)
)
else:
raise RuntimeError("{} is an unknown kind!".format(kind))


def make_json_pickable(s):
s = s.replace("builtins.bfloat16", "ivy.bfloat16")
s = s.replace("jax._src.device_array.reconstruct_device_array", "jax.numpy.array")
Expand All @@ -22,7 +75,7 @@ def make_json_pickable(s):

@st.composite
def get_dtypes(
draw, kind, index=0, full=True, none=False, key=None, prune_function=True
draw, kind="valid", index=0, full=True, none=False, key=None, prune_function=True
):
"""
Draws a valid dtypes for the test function. For frontend tests,
Expand Down Expand Up @@ -51,91 +104,45 @@ def get_dtypes(
dtype string
"""

def _get_type_dict(framework):
return {
"valid": framework.valid_dtypes,
"numeric": framework.valid_numeric_dtypes,
"float": framework.valid_float_dtypes,
"integer": framework.valid_int_dtypes,
"unsigned": framework.valid_uint_dtypes,
"signed_integer": tuple(
set(framework.valid_int_dtypes).difference(framework.valid_uint_dtypes)
),
"complex": framework.valid_complex_dtypes,
"real_and_complex": tuple(
set(framework.valid_numeric_dtypes).union(
framework.valid_complex_dtypes
)
),
"float_and_complex": tuple(
set(framework.valid_float_dtypes).union(framework.valid_complex_dtypes)
),
"bool": tuple(
set(framework.valid_dtypes).difference(framework.valid_numeric_dtypes)
),
}

# TODO refactor this so we run the intersection in a chained clean way
backend_dtypes = _get_type_dict(ivy)[kind]

if test_globals.CURRENT_FRONTEND is not test_globals._Notsetval or isinstance(
test_globals.CURRENT_FRONTEND_STR, list
): # NOQA
if isinstance(test_globals.CURRENT_FRONTEND_STR, list):
process = test_globals.CURRENT_FRONTEND_STR[1]
try:
process.stdin.write("1" + "\n")
process.stdin.flush()
except Exception as e:
print(
"Something bad happened to the subprocess, here are the logs:\n\n"
)
print(process.stdout.readlines())
raise e
frontend_ret = process.stdout.readline()
if frontend_ret:
frontend_ret = jsonpickle.loads(make_json_pickable(frontend_ret))
else:
print(process.stderr.readlines())
raise Exception
fw_dtypes = frontend_ret[kind]
valid_dtypes = tuple(set(fw_dtypes).intersection(backend_dtypes))
if prune_function:
retrieval_fn = _get_fn_dtypes
if test_globals.CURRENT_RUNNING_TEST is not test_globals._Notsetval:
valid_dtypes = set(retrieval_fn(test_globals.CURRENT_BACKEND()))
else:
fw_dtypes = _get_type_dict(test_globals.CURRENT_FRONTEND())[kind]
valid_dtypes = tuple(set(fw_dtypes).intersection(backend_dtypes))
raise RuntimeError(
"No function is set to prune, calling "
"prune_function=True without a function is redundant."
)
else:
valid_dtypes = backend_dtypes
retrieval_fn = _get_type_dict
valid_dtypes = set(retrieval_fn(ivy, kind))

# The function may be called from a frontend test or an IVY api test
# In the case of a IVY api test, the function should make sure it returns a valid
# dtypes for the backend and also for the ground truth backend, if it is called from
# a frontend test, we should also count for the frontend support data types
# In conclusion, the following operations will get the intersection of
# FN_DTYPES & BACKEND_DTYPES & FRONTEND_DTYPES & GROUND_TRUTH_DTYPES

# If being called from a frontend test
if test_globals.CURRENT_FRONTEND is not test_globals._Notsetval: # NOQA
frontend_dtypes = retrieval_fn(test_globals.CURRENT_FRONTEND(), kind)
valid_dtypes = valid_dtypes.intersection(frontend_dtypes)

# Make sure we return dtypes that are compatiable with ground truth backend
ground_truth_is_set = (
test_globals.CURRENT_GROUND_TRUTH_BACKEND is not test_globals._Notsetval # NOQA
)
if ground_truth_is_set:
gtb_dtypes = _get_type_dict(test_globals.CURRENT_GROUND_TRUTH_BACKEND())[kind]
valid_dtypes = tuple(set(gtb_dtypes).intersection(valid_dtypes))

# TODO, do this in a better way...
if (
prune_function
and test_globals.CURRENT_RUNNING_TEST is not test_globals._Notsetval
): # NOQA
fn_dtypes = test_globals.CURRENT_RUNNING_TEST.supported_device_dtypes
valid_dtypes = set(valid_dtypes).intersection(
fn_dtypes[test_globals.CURRENT_BACKEND().backend]["cpu"]
valid_dtypes = valid_dtypes.intersection(
retrieval_fn(test_globals.CURRENT_GROUND_TRUTH_BACKEND(), kind)
)
if ground_truth_is_set:
valid_dtypes = tuple(
valid_dtypes.intersection(
fn_dtypes[test_globals.CURRENT_GROUND_TRUTH_BACKEND().backend][
"cpu"
]
)
)
else:
valid_dtypes = tuple(valid_dtypes)

valid_dtypes = list(valid_dtypes)
if none:
valid_dtypes += (None,)
valid_dtypes.append(None)
if full:
return list(valid_dtypes[index:])
return valid_dtypes[index:]
if key is None:
return [draw(st.sampled_from(valid_dtypes[index:]))]
return [draw(st.shared(st.sampled_from(valid_dtypes[index:]), key=key))]
Expand Down
30 changes: 27 additions & 3 deletions ivy_tests/test_ivy/helpers/testing_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
available_frameworks,
ground_truth,
)
from ivy_tests.test_ivy.helpers.hypothesis_helpers.dtype_helpers import (
_dtype_kind_keys,
_get_type_dict,
)

ground_truth = ground_truth()

Expand Down Expand Up @@ -180,7 +184,13 @@ def _get_method_supported_devices_dtypes(
for b in backends: # ToDo can optimize this ?
ivy.set_backend(b)
_fn = getattr(class_module.__dict__[class_name], method_name)
supported_device_dtypes[b] = ivy.function_supported_devices_and_dtypes(_fn)
devices_and_dtypes = ivy.function_supported_devices_and_dtypes(_fn)
organized_dtypes = {}
for device in devices_and_dtypes.keys():
organized_dtypes[device] = _partition_dtypes_into_kinds(
ivy, devices_and_dtypes[device]
)
supported_device_dtypes[b] = organized_dtypes
ivy.unset_backend()
return supported_device_dtypes

Expand Down Expand Up @@ -212,15 +222,29 @@ def _get_supported_devices_dtypes(fn_name: str, fn_module: str):

backends = available_frameworks()
for b in backends: # ToDo can optimize this ?

ivy.set_backend(b)
_tmp_mod = importlib.import_module(fn_module)
_fn = _tmp_mod.__dict__[fn_name]
supported_device_dtypes[b] = ivy.function_supported_devices_and_dtypes(_fn)
devices_and_dtypes = ivy.function_supported_devices_and_dtypes(_fn)
organized_dtypes = {}
for device in devices_and_dtypes.keys():
organized_dtypes[device] = _partition_dtypes_into_kinds(
ivy, devices_and_dtypes[device]
)
supported_device_dtypes[b] = organized_dtypes
ivy.unset_backend()
return supported_device_dtypes


def _partition_dtypes_into_kinds(framework, dtypes):
partitioned_dtypes = {}
for kind in _dtype_kind_keys:
partitioned_dtypes[kind] = set(_get_type_dict(framework, kind)).intersection(
dtypes
)
return partitioned_dtypes


# Decorators


Expand Down
2 changes: 1 addition & 1 deletion ivy_tests/test_ivy/test_frontends/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def run_around_tests(request, on_device, backend_fw, frontend, compile_graph, im
if hasattr(request.function, "test_data"):
try:
test_globals.setup_frontend_test(
request.function.test_data, frontend, backend_fw.backend
request.function.test_data, frontend, backend_fw.backend, on_device
)
except Exception as e:
test_globals.teardown_frontend_test()
Expand Down