diff --git a/ivy_tests/test_ivy/conftest.py b/ivy_tests/test_ivy/conftest.py index 9bc479895dc31..c352003e359db 100644 --- a/ivy_tests/test_ivy/conftest.py +++ b/ivy_tests/test_ivy/conftest.py @@ -127,6 +127,7 @@ def run_around_tests(request, on_device, backend_fw, compile_graph, implicit): request.function.test_data, backend_fw.backend, request.function.ground_truth_backend, + on_device, ) except Exception as e: test_globals.teardown_api_test() diff --git a/ivy_tests/test_ivy/helpers/function_testing.py b/ivy_tests/test_ivy/helpers/function_testing.py index 4d917d3d63b30..7150dc2c9a957 100644 --- a/ivy_tests/test_ivy/helpers/function_testing.py +++ b/ivy_tests/test_ivy/helpers/function_testing.py @@ -103,14 +103,6 @@ def empty_func(*args, **kwargs): is_torch_native_array = empty_func -# ToDo, this is temporary until unsupported_dtype is embedded -# into helpers.get_dtypes -def _assert_dtypes_are_valid(input_dtypes: Union[List[ivy.Dtype], List[str]]): - for dtype in input_dtypes: - if dtype not in ivy.valid_dtypes + ivy.valid_complex_dtypes: - raise Exception(f"{dtype} is not a valid data type.") - - # Function testing @@ -213,7 +205,6 @@ def test_function( >>> x2 = np.array([-3, 15, 24]) >>> test_function(input_dtypes, test_flags, fw, fn_name, x1=x1, x2=x2) """ - _assert_dtypes_are_valid(input_dtypes) # split the arguments into their positional and keyword components args_np, kwargs_np = kwargs_to_args_n_kwargs( num_positional_args=test_flags.num_positional_args, kwargs=all_as_kwargs_np @@ -1037,8 +1028,6 @@ def test_method( ret_gt optional, return value from the Ground Truth function """ - _assert_dtypes_are_valid(method_input_dtypes) - init_input_dtypes = ivy.default(init_input_dtypes, []) # Constructor arguments # @@ -1344,8 +1333,6 @@ def test_frontend_method( """ if isinstance(frontend, list): frontend, frontend_proc = frontend - _assert_dtypes_are_valid(init_input_dtypes) - _assert_dtypes_are_valid(method_input_dtypes) # split the arguments into their positional and keyword components diff --git a/ivy_tests/test_ivy/helpers/globals.py b/ivy_tests/test_ivy/helpers/globals.py index 1a763b7a61e49..91ac91469d40f 100644 --- a/ivy_tests/test_ivy/helpers/globals.py +++ b/ivy_tests/test_ivy/helpers/globals.py @@ -36,6 +36,7 @@ CURRENT_BACKEND: callable = _Notsetval CURRENT_FRONTEND: callable = _Notsetval CURRENT_RUNNING_TEST = _Notsetval +CURRENT_DEVICE = _Notsetval CURRENT_FRONTEND_STR = "" @@ -139,28 +140,34 @@ def _get_ivy_torch(version=None): # Setup -def setup_api_test(test_data: TestData, backend: str, ground_truth_backend: str): +def setup_api_test( + test_data: TestData, backend: str, ground_truth_backend: str, device: str +): _set_test_data(test_data) _set_backend(backend) + _set_device(device) _set_ground_truth_backend(ground_truth_backend) def teardown_api_test(): _unset_test_data() _unset_backend() + _unset_device() _unset_ground_truth_backend() -def setup_frontend_test(test_data: TestData, frontend: str, backend: str): +def setup_frontend_test(test_data: TestData, frontend: str, backend: str, device: str): _set_test_data(test_data) _set_frontend(frontend) _set_backend(backend) + _set_device(device) def teardown_frontend_test(): _unset_test_data() _unset_frontend() _unset_backend() + _unset_device() def _set_test_data(test_data: TestData): @@ -199,6 +206,13 @@ def _set_ground_truth_backend(framework: str): CURRENT_GROUND_TRUTH_BACKEND = FWS_DICT[framework] +def _set_device(device: str): + global CURRENT_DEVICE + if CURRENT_DEVICE is not _Notsetval: + raise InterruptedTest(CURRENT_RUNNING_TEST) + CURRENT_DEVICE = device + + # Teardown @@ -220,3 +234,8 @@ def _unset_backend(): def _unset_ground_truth_backend(): global CURRENT_GROUND_TRUTH_BACKEND CURRENT_GROUND_TRUTH_BACKEND = _Notsetval + + +def _unset_device(): + global CURRENT_DEVICE + CURRENT_DEVICE = _Notsetval diff --git a/ivy_tests/test_ivy/helpers/hypothesis_helpers/dtype_helpers.py b/ivy_tests/test_ivy/helpers/hypothesis_helpers/dtype_helpers.py index 0a6ccfbe6c441..9ba9c34869105 100644 --- a/ivy_tests/test_ivy/helpers/hypothesis_helpers/dtype_helpers.py +++ b/ivy_tests/test_ivy/helpers/hypothesis_helpers/dtype_helpers.py @@ -14,6 +14,59 @@ from .. import globals as test_globals +_dtype_kind_keys = { + "valid", + "numeric", + "float", + "unsigned", + "integer", + "signed_integer", + "complex", + "real_and_complex", + "float_and_complex", + "bool", +} + + +def _get_fn_dtypes(framework, kind="valid"): + return test_globals.CURRENT_RUNNING_TEST.supported_device_dtypes[framework.backend][ + test_globals.CURRENT_DEVICE + ][kind] + + +def _get_type_dict(framework, kind): + if kind == "valid": + return framework.valid_dtypes + elif kind == "numeric": + return framework.valid_numeric_dtypes + elif kind == "integer": + return framework.valid_int_dtypes + elif kind == "float": + return framework.valid_float_dtypes + elif kind == "unsigned": + return framework.valid_int_dtypes + elif kind == "signed_integer": + return tuple( + set(framework.valid_int_dtypes).difference(framework.valid_uint_dtypes) + ) + elif kind == "complex": + return framework.valid_complex_dtypes + elif kind == "real_and_complex": + return tuple( + set(framework.valid_numeric_dtypes).union(framework.valid_complex_dtypes) + ) + elif kind == "float_and_complex": + return tuple( + set(framework.valid_float_dtypes).union(framework.valid_complex_dtypes) + ) + elif kind == "bool": + return tuple( + set(framework.valid_dtypes).difference(framework.valid_numeric_dtypes) + ) + else: + raise RuntimeError("{} is an unknown kind!".format(kind)) + + def make_json_pickable(s): s = s.replace("builtins.bfloat16", "ivy.bfloat16") s = s.replace("jax._src.device_array.reconstruct_device_array", "jax.numpy.array") @@ -22,7 +75,7 @@ def make_json_pickable(s): @st.composite def get_dtypes( - draw, kind, index=0, full=True, none=False, key=None, prune_function=True + draw, kind="valid", index=0, full=True, none=False, key=None, prune_function=True ): """ Draws a valid dtypes for the test function. For frontend tests, @@ -51,91 +104,45 @@ def get_dtypes( dtype string """ - def _get_type_dict(framework): - return { - "valid": framework.valid_dtypes, - "numeric": framework.valid_numeric_dtypes, - "float": framework.valid_float_dtypes, - "integer": framework.valid_int_dtypes, - "unsigned": framework.valid_uint_dtypes, - "signed_integer": tuple( - set(framework.valid_int_dtypes).difference(framework.valid_uint_dtypes) - ), - "complex": framework.valid_complex_dtypes, - "real_and_complex": tuple( - set(framework.valid_numeric_dtypes).union( - framework.valid_complex_dtypes - ) - ), - "float_and_complex": tuple( - set(framework.valid_float_dtypes).union(framework.valid_complex_dtypes) - ), - "bool": tuple( - set(framework.valid_dtypes).difference(framework.valid_numeric_dtypes) - ), - } - - # TODO refactor this so we run the intersection in a chained clean way - backend_dtypes = _get_type_dict(ivy)[kind] - - if test_globals.CURRENT_FRONTEND is not test_globals._Notsetval or isinstance( - test_globals.CURRENT_FRONTEND_STR, list - ): # NOQA - if isinstance(test_globals.CURRENT_FRONTEND_STR, list): - process = test_globals.CURRENT_FRONTEND_STR[1] - try: - process.stdin.write("1" + "\n") - process.stdin.flush() - except Exception as e: - print( - "Something bad happened to the subprocess, here are the logs:\n\n" - ) - print(process.stdout.readlines()) - raise e - frontend_ret = process.stdout.readline() - if frontend_ret: - frontend_ret = jsonpickle.loads(make_json_pickable(frontend_ret)) - else: - print(process.stderr.readlines()) - raise Exception - fw_dtypes = frontend_ret[kind] - valid_dtypes = tuple(set(fw_dtypes).intersection(backend_dtypes)) + if prune_function: + retrieval_fn = _get_fn_dtypes + if test_globals.CURRENT_RUNNING_TEST is not test_globals._Notsetval: + valid_dtypes = set(retrieval_fn(test_globals.CURRENT_BACKEND())) else: - fw_dtypes = _get_type_dict(test_globals.CURRENT_FRONTEND())[kind] - valid_dtypes = tuple(set(fw_dtypes).intersection(backend_dtypes)) + raise RuntimeError( + "No function is set to prune, calling " + "prune_function=True without a function is redundant." + ) else: - valid_dtypes = backend_dtypes + retrieval_fn = _get_type_dict + valid_dtypes = set(retrieval_fn(ivy, kind)) + + # The function may be called from a frontend test or an IVY api test + # In the case of a IVY api test, the function should make sure it returns a valid + # dtypes for the backend and also for the ground truth backend, if it is called from + # a frontend test, we should also count for the frontend support data types + # In conclusion, the following operations will get the intersection of + # FN_DTYPES & BACKEND_DTYPES & FRONTEND_DTYPES & GROUND_TRUTH_DTYPES + # If being called from a frontend test + if test_globals.CURRENT_FRONTEND is not test_globals._Notsetval: # NOQA + frontend_dtypes = retrieval_fn(test_globals.CURRENT_FRONTEND(), kind) + valid_dtypes = valid_dtypes.intersection(frontend_dtypes) + + # Make sure we return dtypes that are compatiable with ground truth backend ground_truth_is_set = ( test_globals.CURRENT_GROUND_TRUTH_BACKEND is not test_globals._Notsetval # NOQA ) if ground_truth_is_set: - gtb_dtypes = _get_type_dict(test_globals.CURRENT_GROUND_TRUTH_BACKEND())[kind] - valid_dtypes = tuple(set(gtb_dtypes).intersection(valid_dtypes)) - - # TODO, do this in a better way... - if ( - prune_function - and test_globals.CURRENT_RUNNING_TEST is not test_globals._Notsetval - ): # NOQA - fn_dtypes = test_globals.CURRENT_RUNNING_TEST.supported_device_dtypes - valid_dtypes = set(valid_dtypes).intersection( - fn_dtypes[test_globals.CURRENT_BACKEND().backend]["cpu"] + valid_dtypes = valid_dtypes.intersection( + retrieval_fn(test_globals.CURRENT_GROUND_TRUTH_BACKEND(), kind) ) - if ground_truth_is_set: - valid_dtypes = tuple( - valid_dtypes.intersection( - fn_dtypes[test_globals.CURRENT_GROUND_TRUTH_BACKEND().backend][ - "cpu" - ] - ) - ) - else: - valid_dtypes = tuple(valid_dtypes) + + valid_dtypes = list(valid_dtypes) if none: - valid_dtypes += (None,) + valid_dtypes.append(None) if full: - return list(valid_dtypes[index:]) + return valid_dtypes[index:] if key is None: return [draw(st.sampled_from(valid_dtypes[index:]))] return [draw(st.shared(st.sampled_from(valid_dtypes[index:]), key=key))] diff --git a/ivy_tests/test_ivy/helpers/testing_helpers.py b/ivy_tests/test_ivy/helpers/testing_helpers.py index 7717f43e4cdf2..fd109e7a1cb02 100644 --- a/ivy_tests/test_ivy/helpers/testing_helpers.py +++ b/ivy_tests/test_ivy/helpers/testing_helpers.py @@ -25,6 +25,10 @@ available_frameworks, ground_truth, ) +from ivy_tests.test_ivy.helpers.hypothesis_helpers.dtype_helpers import ( + _dtype_kind_keys, + _get_type_dict, +) ground_truth = ground_truth() @@ -180,7 +184,13 @@ def _get_method_supported_devices_dtypes( for b in backends: # ToDo can optimize this ? ivy.set_backend(b) _fn = getattr(class_module.__dict__[class_name], method_name) - supported_device_dtypes[b] = ivy.function_supported_devices_and_dtypes(_fn) + devices_and_dtypes = ivy.function_supported_devices_and_dtypes(_fn) + organized_dtypes = {} + for device in devices_and_dtypes.keys(): + organized_dtypes[device] = _partition_dtypes_into_kinds( + ivy, devices_and_dtypes[device] + ) + supported_device_dtypes[b] = organized_dtypes ivy.unset_backend() return supported_device_dtypes @@ -212,15 +222,29 @@ def _get_supported_devices_dtypes(fn_name: str, fn_module: str): backends = available_frameworks() for b in backends: # ToDo can optimize this ? - ivy.set_backend(b) _tmp_mod = importlib.import_module(fn_module) _fn = _tmp_mod.__dict__[fn_name] - supported_device_dtypes[b] = ivy.function_supported_devices_and_dtypes(_fn) + devices_and_dtypes = ivy.function_supported_devices_and_dtypes(_fn) + organized_dtypes = {} + for device in devices_and_dtypes.keys(): + organized_dtypes[device] = _partition_dtypes_into_kinds( + ivy, devices_and_dtypes[device] + ) + supported_device_dtypes[b] = organized_dtypes ivy.unset_backend() return supported_device_dtypes +def _partition_dtypes_into_kinds(framework, dtypes): + partitioned_dtypes = {} + for kind in _dtype_kind_keys: + partitioned_dtypes[kind] = set(_get_type_dict(framework, kind)).intersection( + dtypes + ) + return partitioned_dtypes + + # Decorators diff --git a/ivy_tests/test_ivy/test_frontends/conftest.py b/ivy_tests/test_ivy/test_frontends/conftest.py index b39bc14a61532..30afe96e2c345 100644 --- a/ivy_tests/test_ivy/test_frontends/conftest.py +++ b/ivy_tests/test_ivy/test_frontends/conftest.py @@ -9,7 +9,7 @@ def run_around_tests(request, on_device, backend_fw, frontend, compile_graph, im if hasattr(request.function, "test_data"): try: test_globals.setup_frontend_test( - request.function.test_data, frontend, backend_fw.backend + request.function.test_data, frontend, backend_fw.backend, on_device ) except Exception as e: test_globals.teardown_frontend_test()