From 706b28bfdb327afd754ddb426fa2d221918bc0fb Mon Sep 17 00:00:00 2001 From: RickSanchezStoic <57310695+RickSanchezStoic@users.noreply.github.com> Date: Fri, 17 Feb 2023 21:06:49 +0530 Subject: [PATCH] Docker exploration multiversion (#10595) * added requirement handler, for usage check comment in file * more changes and fixes * experiments * multiversion, added dockerfile, added multiversion_testing.py, added conda_manipulator * changes * mega changes that twirl the world, haha, just changes that make multiversion testing possible * minor changes for jax imports * lint fixes * changes * lint * lint * lint * lint * multiversion * lint errors * Update ivy.iml * Update ivy.iml * Update ivy.iml * Update misc.xml * multiversion frontend test * minor changes * lint * changes * lint * changes * changes * changes * changes * multiversion frontend, dtype handler added to subprocess, numpy bfloat16 dependency on tensorflow or jax eliminatedgit add -u! * lint * lint * lint * lint * small change * fixes for jax --------- Co-authored-by: Rishabh Kumar --- ivy/functional/backends/jax/general.py | 4 ++ ivy/functional/backends/torch/general.py | 6 ++- ivy_tests/test_ivy/conftest.py | 3 ++ .../test_ivy/helpers/function_testing.py | 21 ++++++---- ivy_tests/test_ivy/helpers/globals.py | 14 +++---- .../hypothesis_helpers/array_helpers.py | 8 ++++ .../hypothesis_helpers/dtype_helpers.py | 39 +++++++++++++++++-- multiversion_frontend_test.py | 34 ++++++++++++++++ 8 files changed, 111 insertions(+), 18 deletions(-) diff --git a/ivy/functional/backends/jax/general.py b/ivy/functional/backends/jax/general.py index e4547c14bc3aa..fe391589c7a81 100644 --- a/ivy/functional/backends/jax/general.py +++ b/ivy/functional/backends/jax/general.py @@ -11,6 +11,10 @@ from jaxlib.xla_extension import Buffer from typing import Iterable, Optional, Union, Sequence, Callable import multiprocessing as _multiprocessing +# necessary import, because stateful imports jax as soon as you import ivy, however, during multiversion +# jax is not there, and therefore a later import results in some sort of circular import, so haiku is needed +import haiku + from haiku._src.data_structures import FlatMapping # local diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 31c3361bc0bad..42e426bbcb624 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -4,7 +4,11 @@ from numbers import Number from operator import mul from typing import Optional, Union, Sequence, Callable, List -import functorch + +try: + import functorch +except ImportError: + functorch = () # for torch 1.10.1 import numpy as np import torch diff --git a/ivy_tests/test_ivy/conftest.py b/ivy_tests/test_ivy/conftest.py index 9bc479895dc31..93b90e87c8f38 100644 --- a/ivy_tests/test_ivy/conftest.py +++ b/ivy_tests/test_ivy/conftest.py @@ -60,6 +60,8 @@ def pytest_configure(config): frontend_strs = frontend.split(",") for i in frontend_strs: process = subprocess.Popen( + + [ "/opt/miniconda/envs/multienv/bin/python", "multiversion_frontend_test.py", @@ -73,6 +75,7 @@ def pytest_configure(config): ) mod_frontend[i.split("/")[0]] = [i, process] + # compile_graph raw_value = config.getoption("--compile_graph") if raw_value == "both": diff --git a/ivy_tests/test_ivy/helpers/function_testing.py b/ivy_tests/test_ivy/helpers/function_testing.py index ae320fd1dd4e9..f2a6c47e61bb4 100644 --- a/ivy_tests/test_ivy/helpers/function_testing.py +++ b/ivy_tests/test_ivy/helpers/function_testing.py @@ -14,6 +14,8 @@ def framework_comparator(frontend): + if ivy.current_backend_str() != frontend.split("/")[0]: + return False if frontend.split("/")[0] == "jax": fw = frontend.split("/")[1] + frontend.split("/")[3] backend_fw = ( @@ -28,8 +30,8 @@ def framework_comparator(frontend): ) else: return ( - frontend.split("/")[1] - == importlib.import_module(frontend.split("/")[1]).__version__ + frontend.split("/")[0] + == importlib.import_module(frontend.split("/")[0]).__version__ ) @@ -633,8 +635,6 @@ def test_frontend_function( shallow=False, ) - # temporarily set frontend framework as backend - ivy.set_backend(frontend.split("/")[0]) if "/" in frontend: # multiversion zone, changes made in non-multiversion zone should # be applied here too @@ -673,7 +673,7 @@ def test_frontend_function( if not isinstance(frontend_ret, tuple): frontend_ret = (frontend_ret,) frontend_ret_idxs = ivy.nested_argwhere( - frontend_ret, lambda x: isinstance(x, np.ndarray) + frontend_ret, lambda x: isinstance(x, np.ndarray) or isinstance(x,ivy.Array) ) frontend_ret_flat = ivy.multi_index_nest( frontend_ret, frontend_ret_idxs @@ -684,6 +684,9 @@ def test_frontend_function( ivy.unset_backend() raise e else: + + # temporarily set frontend framework as backend + ivy.set_backend(frontend.split("/")[0]) try: # create frontend framework args args_frontend = ivy.nested_map( @@ -743,6 +746,8 @@ def test_frontend_function( frontend_ret, frontend_ret_idxs ) frontend_ret_np_flat = [ivy.to_numpy(x) for x in frontend_ret_flat] + # unset frontend framework from backend + ivy.unset_backend() except Exception as e: ivy.unset_backend() raise e @@ -752,6 +757,8 @@ def test_frontend_function( # non-multiversion zone, changes made here should be # applied to multiversion zone too + # temporarily set frontend framework as backend + ivy.set_backend(frontend.split("/")[0]) try: # create frontend framework args args_frontend = ivy.nested_map( @@ -806,11 +813,11 @@ def test_frontend_function( frontend_ret, frontend_ret_idxs ) frontend_ret_np_flat = [ivy.to_numpy(x) for x in frontend_ret_flat] + # unset frontend framework from backend + ivy.unset_backend() except Exception as e: ivy.unset_backend() raise e - # unset frontend framework from backend - ivy.unset_backend() ret_np_flat = flatten_and_to_np(ret=ret) diff --git a/ivy_tests/test_ivy/helpers/globals.py b/ivy_tests/test_ivy/helpers/globals.py index 41707303d3946..0e8eef65245b8 100644 --- a/ivy_tests/test_ivy/helpers/globals.py +++ b/ivy_tests/test_ivy/helpers/globals.py @@ -3,6 +3,7 @@ testing data to be used by the test helpers to prune unsupported data. Should not be used inside any of the test functions. """ +import importlib import sys from ... import config @@ -35,6 +36,7 @@ CURRENT_BACKEND: callable = _Notsetval CURRENT_FRONTEND: callable = _Notsetval CURRENT_RUNNING_TEST = _Notsetval +CURRENT_FRONTEND_STR = "" @dataclass(frozen=True) # ToDo use kw_only=True when version is updated @@ -83,7 +85,8 @@ def __init__(self, test_interruped): def _get_ivy_numpy(version=None): """Import Numpy module from ivy""" if version: - config.reset_sys_modules_to_base() + if version.split('/')[1]!=importlib.import_module('numpy').__version__: + config.reset_sys_modules_to_base() config.allow_global_framework_imports(fw=[version]) try: @@ -101,13 +104,8 @@ def _get_ivy_jax(version=None): version.split("/")[2] + "/" + version.split("/")[3], ] config.allow_global_framework_imports(fw=las) - try: - config.reset_sys_modules_to_base() - import ivy.functional.backends.jax + import ivy.functional.backends.jax - return ivy.functional.backends.jax - except ImportError as e: - raise e else: try: import ivy.functional.backends.jax @@ -174,11 +172,13 @@ def _set_test_data(test_data: TestData): def _set_frontend(framework: str): global CURRENT_FRONTEND + global CURRENT_FRONTEND_STR if CURRENT_FRONTEND is not _Notsetval: raise InterruptedTest(CURRENT_RUNNING_TEST) if isinstance(framework, list): CURRENT_FRONTEND = FWS_DICT[framework[0].split("/")[0]] + CURRENT_FRONTEND_STR = framework else: CURRENT_FRONTEND = FWS_DICT[framework] diff --git a/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py b/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py index 0a43163ac3092..0228a18852f02 100644 --- a/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py +++ b/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py @@ -1036,8 +1036,16 @@ def array_values( values = [complex(*v) for v in values] else: values = draw(list_of_size(x=st.booleans(), size=size)) + if dtype == "bfloat16": + # check bfloat16 behavior enabled or not + try: + np.dtype("bfloat16") + except Exception: + # enables bfloat16 behavior with possibly no side-effects + import paddle_bfloat # noqa array = np.asarray(values, dtype=dtype) + if isinstance(shape, (tuple, list)): return array.reshape(shape) return np.asarray(array) diff --git a/ivy_tests/test_ivy/helpers/hypothesis_helpers/dtype_helpers.py b/ivy_tests/test_ivy/helpers/hypothesis_helpers/dtype_helpers.py index c6762f1ce0aeb..0a6ccfbe6c441 100644 --- a/ivy_tests/test_ivy/helpers/hypothesis_helpers/dtype_helpers.py +++ b/ivy_tests/test_ivy/helpers/hypothesis_helpers/dtype_helpers.py @@ -3,6 +3,10 @@ from hypothesis import strategies as st from typing import Optional +try: + import jsonpickle +except ImportError: + pass # local import ivy from . import number_helpers as nh @@ -10,6 +14,12 @@ from .. import globals as test_globals +def make_json_pickable(s): + s = s.replace("builtins.bfloat16", "ivy.bfloat16") + s = s.replace("jax._src.device_array.reconstruct_device_array", "jax.numpy.array") + return s + + @st.composite def get_dtypes( draw, kind, index=0, full=True, none=False, key=None, prune_function=True @@ -67,9 +77,32 @@ def _get_type_dict(framework): # TODO refactor this so we run the intersection in a chained clean way backend_dtypes = _get_type_dict(ivy)[kind] - if test_globals.CURRENT_FRONTEND is not test_globals._Notsetval: # NOQA - fw_dtypes = _get_type_dict(test_globals.CURRENT_FRONTEND())[kind] - valid_dtypes = tuple(set(fw_dtypes).intersection(backend_dtypes)) + + if test_globals.CURRENT_FRONTEND is not test_globals._Notsetval or isinstance( + test_globals.CURRENT_FRONTEND_STR, list + ): # NOQA + if isinstance(test_globals.CURRENT_FRONTEND_STR, list): + process = test_globals.CURRENT_FRONTEND_STR[1] + try: + process.stdin.write("1" + "\n") + process.stdin.flush() + except Exception as e: + print( + "Something bad happened to the subprocess, here are the logs:\n\n" + ) + print(process.stdout.readlines()) + raise e + frontend_ret = process.stdout.readline() + if frontend_ret: + frontend_ret = jsonpickle.loads(make_json_pickable(frontend_ret)) + else: + print(process.stderr.readlines()) + raise Exception + fw_dtypes = frontend_ret[kind] + valid_dtypes = tuple(set(fw_dtypes).intersection(backend_dtypes)) + else: + fw_dtypes = _get_type_dict(test_globals.CURRENT_FRONTEND())[kind] + valid_dtypes = tuple(set(fw_dtypes).intersection(backend_dtypes)) else: valid_dtypes = backend_dtypes diff --git a/multiversion_frontend_test.py b/multiversion_frontend_test.py index 1ae7a2eb56a0f..0208558e2c97e 100644 --- a/multiversion_frontend_test.py +++ b/multiversion_frontend_test.py @@ -58,6 +58,36 @@ def __init__(self, native_class): self._native_class = native_class +def _get_type_dict(framework): + return { + "valid": framework.valid_dtypes, + "numeric": framework.valid_numeric_dtypes, + "float": framework.valid_float_dtypes, + "integer": framework.valid_int_dtypes, + "unsigned": framework.valid_uint_dtypes, + "signed_integer": tuple( + set(framework.valid_int_dtypes).difference(framework.valid_uint_dtypes) + ), + "complex": framework.valid_complex_dtypes, + "real_and_complex": tuple( + set(framework.valid_numeric_dtypes).union(framework.valid_complex_dtypes) + ), + "float_and_complex": tuple( + set(framework.valid_float_dtypes).union(framework.valid_complex_dtypes) + ), + "bool": tuple( + set(framework.valid_dtypes).difference(framework.valid_numeric_dtypes) + ), + } + + +def dtype_handler(framework): + framework = importlib.import_module("ivy.functional.backends." + framework) + dtypes = _get_type_dict(framework) + dtypes = jsonpickle.dumps(dtypes) + print(dtypes) + + if __name__ == "__main__": arg_lis = sys.argv @@ -80,6 +110,10 @@ def __init__(self, native_class): while j: try: z = input() + if z == "1": + dtype_handler(arg_lis[2].split("/")[0]) + continue + pickle_dict = jsonpickle.loads(z) frontend_fw = input()