Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Docker exploration multiversion #10595

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
1990cfd
added requirement handler, for usage check comment in file
Dec 19, 2022
dc81002
more changes and fixes
Dec 20, 2022
5c48a00
Merge remote-tracking branch 'upstream/master' into docker_exploratio…
Dec 20, 2022
c6428bf
experiments
Dec 21, 2022
40d71e4
Merge branch 'unifyai:master' into docker_exploration_multiversion
RickSanchezStoic Dec 21, 2022
3292ead
Merge branch 'unifyai:master' into docker_exploration_multiversion
RickSanchezStoic Jan 4, 2023
98448b6
Merge branch 'unifyai:master' into docker_exploration_multiversion
RickSanchezStoic Jan 5, 2023
5898cd6
multiversion, added dockerfile, added multiversion_testing.py, added …
Jan 6, 2023
ec05b96
Merge branch 'docker_exploration_multiversion' of https://github.com/…
Jan 6, 2023
12210eb
Merge branch 'unifyai:master' into docker_exploration_multiversion
RickSanchezStoic Jan 6, 2023
66c0477
changes
Jan 10, 2023
adbccce
Merge branch 'docker_exploration_multiversion' of https://github.com/…
Jan 10, 2023
336df14
Merge branch 'unifyai:master' into docker_exploration_multiversion
RickSanchezStoic Jan 11, 2023
9ffadd7
mega changes that twirl the world, haha, just changes that make multi…
Jan 20, 2023
80171e9
Merge branch 'docker_exploration_multiversion' of https://github.com/…
Jan 20, 2023
3c69d87
merge_conflicts
Jan 20, 2023
28cc836
minor changes for jax imports
Jan 20, 2023
958b0bf
Merge branch 'unifyai:master' into docker_exploration_multiversion
RickSanchezStoic Jan 20, 2023
60b0973
lint fixes
Jan 23, 2023
c2271f4
Merge branch 'master' into docker_exploration_multiversion
RickSanchezStoic Jan 23, 2023
97faee3
changes
Jan 23, 2023
2d73769
Merge branch 'docker_exploration_multiversion' of https://github.com/…
Jan 23, 2023
1ddab74
lint
Jan 24, 2023
f22605f
lint
Jan 24, 2023
0c24710
lint
Jan 24, 2023
9d7ba12
lint
Jan 24, 2023
c2dfe71
multiversion
Feb 1, 2023
dabaa97
merge_conflicts
Feb 1, 2023
17c95d7
lint errors
Feb 1, 2023
6c3ec9d
Merge branch 'unifyai:master' into docker_exploration_multiversion
RickSanchezStoic Feb 1, 2023
cdb2369
Update ivy.iml
RickSanchezStoic Feb 1, 2023
45fdd63
Update ivy.iml
RickSanchezStoic Feb 1, 2023
1f19e88
Update ivy.iml
RickSanchezStoic Feb 1, 2023
01ada6f
Update misc.xml
RickSanchezStoic Feb 1, 2023
ace42c9
multiversion frontend test
Feb 6, 2023
96abbe0
merge conflicts
Feb 6, 2023
6a5cf99
minor changes
Feb 6, 2023
d3f8989
lint
Feb 6, 2023
68ca06b
changes
Feb 7, 2023
51f62b9
lint
Feb 7, 2023
96c753b
changes
Feb 8, 2023
fc72096
changes
Feb 8, 2023
155bd46
Merge branch 'master' into docker_exploration_multiversion
RickSanchezStoic Feb 8, 2023
89584b2
changes
Feb 8, 2023
5ad9c4f
changes
Feb 8, 2023
4c94a41
multiversion frontend, dtype handler added to subprocess, numpy bfloa…
Feb 16, 2023
b1fca81
lint
Feb 16, 2023
0d410d7
Merge branch 'master' into docker_exploration_multiversion
RickSanchezStoic Feb 16, 2023
ef2a890
lint
Feb 16, 2023
6f6908e
lint
Feb 16, 2023
205032d
lint
Feb 16, 2023
6c5b115
small change
Feb 17, 2023
7ae1454
fixes for jax
Feb 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ivy/functional/backends/jax/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from jaxlib.xla_extension import Buffer
from typing import Iterable, Optional, Union, Sequence, Callable
import multiprocessing as _multiprocessing
# necessary import, because stateful imports jax as soon as you import ivy, however, during multiversion
# jax is not there, and therefore a later import results in some sort of circular import, so haiku is needed
import haiku

from haiku._src.data_structures import FlatMapping

# local
Expand Down
6 changes: 5 additions & 1 deletion ivy/functional/backends/torch/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from numbers import Number
from operator import mul
from typing import Optional, Union, Sequence, Callable, List
import functorch

try:
import functorch
except ImportError:
functorch = () # for torch 1.10.1
import numpy as np
import torch

Expand Down
3 changes: 3 additions & 0 deletions ivy_tests/test_ivy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def pytest_configure(config):
frontend_strs = frontend.split(",")
for i in frontend_strs:
process = subprocess.Popen(


[
"/opt/miniconda/envs/multienv/bin/python",
"multiversion_frontend_test.py",
Expand All @@ -73,6 +75,7 @@ def pytest_configure(config):
)
mod_frontend[i.split("/")[0]] = [i, process]


# compile_graph
raw_value = config.getoption("--compile_graph")
if raw_value == "both":
Expand Down
21 changes: 14 additions & 7 deletions ivy_tests/test_ivy/helpers/function_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@


def framework_comparator(frontend):
if ivy.current_backend_str() != frontend.split("/")[0]:
return False
if frontend.split("/")[0] == "jax":
fw = frontend.split("/")[1] + frontend.split("/")[3]
backend_fw = (
Expand All @@ -28,8 +30,8 @@ def framework_comparator(frontend):
)
else:
return (
frontend.split("/")[1]
== importlib.import_module(frontend.split("/")[1]).__version__
frontend.split("/")[0]
== importlib.import_module(frontend.split("/")[0]).__version__
)


Expand Down Expand Up @@ -633,8 +635,6 @@ def test_frontend_function(
shallow=False,
)

# temporarily set frontend framework as backend
ivy.set_backend(frontend.split("/")[0])
if "/" in frontend:
# multiversion zone, changes made in non-multiversion zone should
# be applied here too
Expand Down Expand Up @@ -673,7 +673,7 @@ def test_frontend_function(
if not isinstance(frontend_ret, tuple):
frontend_ret = (frontend_ret,)
frontend_ret_idxs = ivy.nested_argwhere(
frontend_ret, lambda x: isinstance(x, np.ndarray)
frontend_ret, lambda x: isinstance(x, np.ndarray) or isinstance(x,ivy.Array)
)
frontend_ret_flat = ivy.multi_index_nest(
frontend_ret, frontend_ret_idxs
Expand All @@ -684,6 +684,9 @@ def test_frontend_function(
ivy.unset_backend()
raise e
else:

# temporarily set frontend framework as backend
ivy.set_backend(frontend.split("/")[0])
try:
# create frontend framework args
args_frontend = ivy.nested_map(
Expand Down Expand Up @@ -743,6 +746,8 @@ def test_frontend_function(
frontend_ret, frontend_ret_idxs
)
frontend_ret_np_flat = [ivy.to_numpy(x) for x in frontend_ret_flat]
# unset frontend framework from backend
ivy.unset_backend()
except Exception as e:
ivy.unset_backend()
raise e
Expand All @@ -752,6 +757,8 @@ def test_frontend_function(
# non-multiversion zone, changes made here should be
# applied to multiversion zone too

# temporarily set frontend framework as backend
ivy.set_backend(frontend.split("/")[0])
try:
# create frontend framework args
args_frontend = ivy.nested_map(
Expand Down Expand Up @@ -806,11 +813,11 @@ def test_frontend_function(
frontend_ret, frontend_ret_idxs
)
frontend_ret_np_flat = [ivy.to_numpy(x) for x in frontend_ret_flat]
# unset frontend framework from backend
ivy.unset_backend()
except Exception as e:
ivy.unset_backend()
raise e
# unset frontend framework from backend
ivy.unset_backend()

ret_np_flat = flatten_and_to_np(ret=ret)

Expand Down
14 changes: 7 additions & 7 deletions ivy_tests/test_ivy/helpers/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
testing data to be used by the test helpers to prune unsupported data.
Should not be used inside any of the test functions.
"""
import importlib
import sys
from ... import config

Expand Down Expand Up @@ -35,6 +36,7 @@
CURRENT_BACKEND: callable = _Notsetval
CURRENT_FRONTEND: callable = _Notsetval
CURRENT_RUNNING_TEST = _Notsetval
CURRENT_FRONTEND_STR = ""


@dataclass(frozen=True) # ToDo use kw_only=True when version is updated
Expand Down Expand Up @@ -83,7 +85,8 @@ def __init__(self, test_interruped):
def _get_ivy_numpy(version=None):
"""Import Numpy module from ivy"""
if version:
config.reset_sys_modules_to_base()
if version.split('/')[1]!=importlib.import_module('numpy').__version__:
config.reset_sys_modules_to_base()
config.allow_global_framework_imports(fw=[version])

try:
Expand All @@ -101,13 +104,8 @@ def _get_ivy_jax(version=None):
version.split("/")[2] + "/" + version.split("/")[3],
]
config.allow_global_framework_imports(fw=las)
try:
config.reset_sys_modules_to_base()
import ivy.functional.backends.jax
import ivy.functional.backends.jax

return ivy.functional.backends.jax
except ImportError as e:
raise e
else:
try:
import ivy.functional.backends.jax
Expand Down Expand Up @@ -174,11 +172,13 @@ def _set_test_data(test_data: TestData):

def _set_frontend(framework: str):
global CURRENT_FRONTEND
global CURRENT_FRONTEND_STR
if CURRENT_FRONTEND is not _Notsetval:
raise InterruptedTest(CURRENT_RUNNING_TEST)
if isinstance(framework, list):

CURRENT_FRONTEND = FWS_DICT[framework[0].split("/")[0]]
CURRENT_FRONTEND_STR = framework
else:
CURRENT_FRONTEND = FWS_DICT[framework]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1036,8 +1036,16 @@ def array_values(
values = [complex(*v) for v in values]
else:
values = draw(list_of_size(x=st.booleans(), size=size))
if dtype == "bfloat16":
# check bfloat16 behavior enabled or not
try:
np.dtype("bfloat16")
except Exception:
# enables bfloat16 behavior with possibly no side-effects
import paddle_bfloat # noqa

array = np.asarray(values, dtype=dtype)

if isinstance(shape, (tuple, list)):
return array.reshape(shape)
return np.asarray(array)
Expand Down
39 changes: 36 additions & 3 deletions ivy_tests/test_ivy/helpers/hypothesis_helpers/dtype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,23 @@
from hypothesis import strategies as st
from typing import Optional

try:
import jsonpickle
except ImportError:
pass
# local
import ivy
from . import number_helpers as nh
from . import array_helpers as ah
from .. import globals as test_globals


def make_json_pickable(s):
s = s.replace("builtins.bfloat16", "ivy.bfloat16")
s = s.replace("jax._src.device_array.reconstruct_device_array", "jax.numpy.array")
return s


@st.composite
def get_dtypes(
draw, kind, index=0, full=True, none=False, key=None, prune_function=True
Expand Down Expand Up @@ -67,9 +77,32 @@ def _get_type_dict(framework):

# TODO refactor this so we run the intersection in a chained clean way
backend_dtypes = _get_type_dict(ivy)[kind]
if test_globals.CURRENT_FRONTEND is not test_globals._Notsetval: # NOQA
fw_dtypes = _get_type_dict(test_globals.CURRENT_FRONTEND())[kind]
valid_dtypes = tuple(set(fw_dtypes).intersection(backend_dtypes))

if test_globals.CURRENT_FRONTEND is not test_globals._Notsetval or isinstance(
test_globals.CURRENT_FRONTEND_STR, list
): # NOQA
if isinstance(test_globals.CURRENT_FRONTEND_STR, list):
process = test_globals.CURRENT_FRONTEND_STR[1]
try:
process.stdin.write("1" + "\n")
process.stdin.flush()
except Exception as e:
print(
"Something bad happened to the subprocess, here are the logs:\n\n"
)
print(process.stdout.readlines())
raise e
frontend_ret = process.stdout.readline()
if frontend_ret:
frontend_ret = jsonpickle.loads(make_json_pickable(frontend_ret))
else:
print(process.stderr.readlines())
raise Exception
fw_dtypes = frontend_ret[kind]
valid_dtypes = tuple(set(fw_dtypes).intersection(backend_dtypes))
else:
fw_dtypes = _get_type_dict(test_globals.CURRENT_FRONTEND())[kind]
valid_dtypes = tuple(set(fw_dtypes).intersection(backend_dtypes))
else:
valid_dtypes = backend_dtypes

Expand Down
34 changes: 34 additions & 0 deletions multiversion_frontend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,36 @@ def __init__(self, native_class):
self._native_class = native_class


def _get_type_dict(framework):
return {
"valid": framework.valid_dtypes,
"numeric": framework.valid_numeric_dtypes,
"float": framework.valid_float_dtypes,
"integer": framework.valid_int_dtypes,
"unsigned": framework.valid_uint_dtypes,
"signed_integer": tuple(
set(framework.valid_int_dtypes).difference(framework.valid_uint_dtypes)
),
"complex": framework.valid_complex_dtypes,
"real_and_complex": tuple(
set(framework.valid_numeric_dtypes).union(framework.valid_complex_dtypes)
),
"float_and_complex": tuple(
set(framework.valid_float_dtypes).union(framework.valid_complex_dtypes)
),
"bool": tuple(
set(framework.valid_dtypes).difference(framework.valid_numeric_dtypes)
),
}


def dtype_handler(framework):
framework = importlib.import_module("ivy.functional.backends." + framework)
dtypes = _get_type_dict(framework)
dtypes = jsonpickle.dumps(dtypes)
print(dtypes)


if __name__ == "__main__":

arg_lis = sys.argv
Expand All @@ -80,6 +110,10 @@ def __init__(self, native_class):
while j:
try:
z = input()
if z == "1":
dtype_handler(arg_lis[2].split("/")[0])
continue

pickle_dict = jsonpickle.loads(z)
frontend_fw = input()

Expand Down