Skip to content

Commit

Permalink
Docker exploration multiversion (#10595)
Browse files Browse the repository at this point in the history
* added requirement handler, for usage check comment in file

* more changes and fixes

* experiments

* multiversion, added dockerfile, added multiversion_testing.py, added conda_manipulator

* changes

* mega changes that twirl the world, haha, just changes that make multiversion testing possible

* minor changes for jax imports

* lint fixes

* changes

* lint

* lint

* lint

* lint

* multiversion

* lint errors

* Update ivy.iml

* Update ivy.iml

* Update ivy.iml

* Update misc.xml

* multiversion frontend test

* minor changes

* lint

* changes

* lint

* changes

* changes

* changes

* changes

* multiversion frontend, dtype handler added to subprocess, numpy bfloat16 dependency on tensorflow or jax eliminatedgit add -u!

* lint

* lint

* lint

* lint

* small change

* fixes for jax

---------

Co-authored-by: Rishabh Kumar <[email protected]>
  • Loading branch information
RickSanchezStoic and Rishabh Kumar authored Feb 17, 2023
1 parent b753ed1 commit 706b28b
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 18 deletions.
4 changes: 4 additions & 0 deletions ivy/functional/backends/jax/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from jaxlib.xla_extension import Buffer
from typing import Iterable, Optional, Union, Sequence, Callable
import multiprocessing as _multiprocessing
# necessary import, because stateful imports jax as soon as you import ivy, however, during multiversion
# jax is not there, and therefore a later import results in some sort of circular import, so haiku is needed
import haiku

from haiku._src.data_structures import FlatMapping

# local
Expand Down
6 changes: 5 additions & 1 deletion ivy/functional/backends/torch/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from numbers import Number
from operator import mul
from typing import Optional, Union, Sequence, Callable, List
import functorch

try:
import functorch
except ImportError:
functorch = () # for torch 1.10.1
import numpy as np
import torch

Expand Down
3 changes: 3 additions & 0 deletions ivy_tests/test_ivy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def pytest_configure(config):
frontend_strs = frontend.split(",")
for i in frontend_strs:
process = subprocess.Popen(


[
"/opt/miniconda/envs/multienv/bin/python",
"multiversion_frontend_test.py",
Expand All @@ -73,6 +75,7 @@ def pytest_configure(config):
)
mod_frontend[i.split("/")[0]] = [i, process]


# compile_graph
raw_value = config.getoption("--compile_graph")
if raw_value == "both":
Expand Down
21 changes: 14 additions & 7 deletions ivy_tests/test_ivy/helpers/function_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@


def framework_comparator(frontend):
if ivy.current_backend_str() != frontend.split("/")[0]:
return False
if frontend.split("/")[0] == "jax":
fw = frontend.split("/")[1] + frontend.split("/")[3]
backend_fw = (
Expand All @@ -28,8 +30,8 @@ def framework_comparator(frontend):
)
else:
return (
frontend.split("/")[1]
== importlib.import_module(frontend.split("/")[1]).__version__
frontend.split("/")[0]
== importlib.import_module(frontend.split("/")[0]).__version__
)


Expand Down Expand Up @@ -633,8 +635,6 @@ def test_frontend_function(
shallow=False,
)

# temporarily set frontend framework as backend
ivy.set_backend(frontend.split("/")[0])
if "/" in frontend:
# multiversion zone, changes made in non-multiversion zone should
# be applied here too
Expand Down Expand Up @@ -673,7 +673,7 @@ def test_frontend_function(
if not isinstance(frontend_ret, tuple):
frontend_ret = (frontend_ret,)
frontend_ret_idxs = ivy.nested_argwhere(
frontend_ret, lambda x: isinstance(x, np.ndarray)
frontend_ret, lambda x: isinstance(x, np.ndarray) or isinstance(x,ivy.Array)
)
frontend_ret_flat = ivy.multi_index_nest(
frontend_ret, frontend_ret_idxs
Expand All @@ -684,6 +684,9 @@ def test_frontend_function(
ivy.unset_backend()
raise e
else:

# temporarily set frontend framework as backend
ivy.set_backend(frontend.split("/")[0])
try:
# create frontend framework args
args_frontend = ivy.nested_map(
Expand Down Expand Up @@ -743,6 +746,8 @@ def test_frontend_function(
frontend_ret, frontend_ret_idxs
)
frontend_ret_np_flat = [ivy.to_numpy(x) for x in frontend_ret_flat]
# unset frontend framework from backend
ivy.unset_backend()
except Exception as e:
ivy.unset_backend()
raise e
Expand All @@ -752,6 +757,8 @@ def test_frontend_function(
# non-multiversion zone, changes made here should be
# applied to multiversion zone too

# temporarily set frontend framework as backend
ivy.set_backend(frontend.split("/")[0])
try:
# create frontend framework args
args_frontend = ivy.nested_map(
Expand Down Expand Up @@ -806,11 +813,11 @@ def test_frontend_function(
frontend_ret, frontend_ret_idxs
)
frontend_ret_np_flat = [ivy.to_numpy(x) for x in frontend_ret_flat]
# unset frontend framework from backend
ivy.unset_backend()
except Exception as e:
ivy.unset_backend()
raise e
# unset frontend framework from backend
ivy.unset_backend()

ret_np_flat = flatten_and_to_np(ret=ret)

Expand Down
14 changes: 7 additions & 7 deletions ivy_tests/test_ivy/helpers/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
testing data to be used by the test helpers to prune unsupported data.
Should not be used inside any of the test functions.
"""
import importlib
import sys
from ... import config

Expand Down Expand Up @@ -35,6 +36,7 @@
CURRENT_BACKEND: callable = _Notsetval
CURRENT_FRONTEND: callable = _Notsetval
CURRENT_RUNNING_TEST = _Notsetval
CURRENT_FRONTEND_STR = ""


@dataclass(frozen=True) # ToDo use kw_only=True when version is updated
Expand Down Expand Up @@ -83,7 +85,8 @@ def __init__(self, test_interruped):
def _get_ivy_numpy(version=None):
"""Import Numpy module from ivy"""
if version:
config.reset_sys_modules_to_base()
if version.split('/')[1]!=importlib.import_module('numpy').__version__:
config.reset_sys_modules_to_base()
config.allow_global_framework_imports(fw=[version])

try:
Expand All @@ -101,13 +104,8 @@ def _get_ivy_jax(version=None):
version.split("/")[2] + "/" + version.split("/")[3],
]
config.allow_global_framework_imports(fw=las)
try:
config.reset_sys_modules_to_base()
import ivy.functional.backends.jax
import ivy.functional.backends.jax

return ivy.functional.backends.jax
except ImportError as e:
raise e
else:
try:
import ivy.functional.backends.jax
Expand Down Expand Up @@ -174,11 +172,13 @@ def _set_test_data(test_data: TestData):

def _set_frontend(framework: str):
global CURRENT_FRONTEND
global CURRENT_FRONTEND_STR
if CURRENT_FRONTEND is not _Notsetval:
raise InterruptedTest(CURRENT_RUNNING_TEST)
if isinstance(framework, list):

CURRENT_FRONTEND = FWS_DICT[framework[0].split("/")[0]]
CURRENT_FRONTEND_STR = framework
else:
CURRENT_FRONTEND = FWS_DICT[framework]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1036,8 +1036,16 @@ def array_values(
values = [complex(*v) for v in values]
else:
values = draw(list_of_size(x=st.booleans(), size=size))
if dtype == "bfloat16":
# check bfloat16 behavior enabled or not
try:
np.dtype("bfloat16")
except Exception:
# enables bfloat16 behavior with possibly no side-effects
import paddle_bfloat # noqa

array = np.asarray(values, dtype=dtype)

if isinstance(shape, (tuple, list)):
return array.reshape(shape)
return np.asarray(array)
Expand Down
39 changes: 36 additions & 3 deletions ivy_tests/test_ivy/helpers/hypothesis_helpers/dtype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,23 @@
from hypothesis import strategies as st
from typing import Optional

try:
import jsonpickle
except ImportError:
pass
# local
import ivy
from . import number_helpers as nh
from . import array_helpers as ah
from .. import globals as test_globals


def make_json_pickable(s):
s = s.replace("builtins.bfloat16", "ivy.bfloat16")
s = s.replace("jax._src.device_array.reconstruct_device_array", "jax.numpy.array")
return s


@st.composite
def get_dtypes(
draw, kind, index=0, full=True, none=False, key=None, prune_function=True
Expand Down Expand Up @@ -67,9 +77,32 @@ def _get_type_dict(framework):

# TODO refactor this so we run the intersection in a chained clean way
backend_dtypes = _get_type_dict(ivy)[kind]
if test_globals.CURRENT_FRONTEND is not test_globals._Notsetval: # NOQA
fw_dtypes = _get_type_dict(test_globals.CURRENT_FRONTEND())[kind]
valid_dtypes = tuple(set(fw_dtypes).intersection(backend_dtypes))

if test_globals.CURRENT_FRONTEND is not test_globals._Notsetval or isinstance(
test_globals.CURRENT_FRONTEND_STR, list
): # NOQA
if isinstance(test_globals.CURRENT_FRONTEND_STR, list):
process = test_globals.CURRENT_FRONTEND_STR[1]
try:
process.stdin.write("1" + "\n")
process.stdin.flush()
except Exception as e:
print(
"Something bad happened to the subprocess, here are the logs:\n\n"
)
print(process.stdout.readlines())
raise e
frontend_ret = process.stdout.readline()
if frontend_ret:
frontend_ret = jsonpickle.loads(make_json_pickable(frontend_ret))
else:
print(process.stderr.readlines())
raise Exception
fw_dtypes = frontend_ret[kind]
valid_dtypes = tuple(set(fw_dtypes).intersection(backend_dtypes))
else:
fw_dtypes = _get_type_dict(test_globals.CURRENT_FRONTEND())[kind]
valid_dtypes = tuple(set(fw_dtypes).intersection(backend_dtypes))
else:
valid_dtypes = backend_dtypes

Expand Down
34 changes: 34 additions & 0 deletions multiversion_frontend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,36 @@ def __init__(self, native_class):
self._native_class = native_class


def _get_type_dict(framework):
return {
"valid": framework.valid_dtypes,
"numeric": framework.valid_numeric_dtypes,
"float": framework.valid_float_dtypes,
"integer": framework.valid_int_dtypes,
"unsigned": framework.valid_uint_dtypes,
"signed_integer": tuple(
set(framework.valid_int_dtypes).difference(framework.valid_uint_dtypes)
),
"complex": framework.valid_complex_dtypes,
"real_and_complex": tuple(
set(framework.valid_numeric_dtypes).union(framework.valid_complex_dtypes)
),
"float_and_complex": tuple(
set(framework.valid_float_dtypes).union(framework.valid_complex_dtypes)
),
"bool": tuple(
set(framework.valid_dtypes).difference(framework.valid_numeric_dtypes)
),
}


def dtype_handler(framework):
framework = importlib.import_module("ivy.functional.backends." + framework)
dtypes = _get_type_dict(framework)
dtypes = jsonpickle.dumps(dtypes)
print(dtypes)


if __name__ == "__main__":

arg_lis = sys.argv
Expand All @@ -80,6 +110,10 @@ def __init__(self, native_class):
while j:
try:
z = input()
if z == "1":
dtype_handler(arg_lis[2].split("/")[0])
continue

pickle_dict = jsonpickle.loads(z)
frontend_fw = input()

Expand Down

0 comments on commit 706b28b

Please sign in to comment.