-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added requirement handler, for usage check comment in file (#8887)
* added requirement handler, for usage check comment in file * more changes and fixes * experiments * multiversion, added dockerfile, added multiversion_testing.py, added conda_manipulator * changes * mega changes that twirl the world, haha, just changes that make multiversion testing possible * minor changes for jax imports * lint fixes * changes * lint * lint * lint * lint Co-authored-by: Rishabh Kumar <[email protected]>
- Loading branch information
1 parent
9ec2b69
commit ecabace
Showing
22 changed files
with
888 additions
and
115 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
FROM ubuntu:20.04 | ||
WORKDIR /ivy | ||
|
||
COPY ../docker/multicondaenv.yml . | ||
COPY ../docker/multiversion_framework_directory.py . | ||
COPY ../docker/multiversion_testing_requirements.txt . | ||
COPY ../docker/run_multiversion_framework_directory.sh . | ||
|
||
|
||
|
||
|
||
# Install miniconda | ||
ENV CONDA_DIR /opt/miniconda | ||
|
||
RUN apt clean && \ | ||
rm -rf /var/lib/apt/lists/* && \ | ||
apt-get update && \ | ||
apt-get install -y wget | ||
|
||
RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-py38_22.11.1-1-Linux-x86_64.sh -O ~/miniconda.sh && \ | ||
/bin/bash ~/miniconda.sh -b -p /opt/miniconda | ||
|
||
|
||
ENV PATH=$CONDA_DIR/bin:$PATH | ||
RUN conda env create -f multicondaenv.yml | ||
|
||
RUN pip3 install --no-cache-dir --no-deps -r multiversion_testing_requirements.txt --target=/opt/miniconda/envs/multienv/lib/python3.8/site-packages | ||
RUN ./run_multiversion_framework_directory.sh | ||
# Make RUN commands use the new environment: | ||
SHELL ["conda", "run", "-n", "multienv", "/bin/bash", "-c"] | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
#!/bin/bash | ||
|
||
docker build --progress=plain -t experiment_conda -f MultiversionDockerFile .. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
name: multienv | ||
channels: | ||
- conda-forge | ||
dependencies: | ||
- python=3.8.0 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# flake8: noqa | ||
import os | ||
import subprocess | ||
|
||
|
||
def directory_generator(req, base="/opt/miniconda/fw/"): | ||
for versions in req: | ||
pkg, ver = versions.split("/") | ||
path = base + pkg + "/" + ver | ||
if not os.path.exists(path): | ||
install_pkg(path, pkg + "==" + ver) | ||
|
||
|
||
def install_pkg(path, pkg, base="fw/"): | ||
if pkg.split("==")[0] == "torch": | ||
subprocess.run( | ||
f"pip3 install {pkg} --default-timeout=100 -f https://download.pytorch.org/whl/torch_stable.html --target={path}", | ||
shell=True, | ||
) | ||
elif pkg.split("==")[0] == "jaxlib": | ||
subprocess.run( | ||
f"pip3 install {pkg} --default-timeout=100 -f https://storage.googleapis.com/jax-releases/jax_releases.html --target={path}", | ||
shell=True, | ||
) | ||
else: | ||
subprocess.run( | ||
f"pip3 install {pkg} --default-timeout=100 --target={path}", shell=True | ||
) | ||
|
||
|
||
torch_req = ["torch/1.4.0", "torch/1.5.0", "torch/1.10.1"] | ||
tensorflow_req = [ | ||
"tensorflow/2.2.0", | ||
"tensorflow/2.2.1", | ||
"tensorflow/2.2.2", | ||
"tensorflow/2.4.4", | ||
"tensorflow/2.9.0", | ||
"tensorflow/2.9.1", | ||
] | ||
jax_req = ["jax/0.1.60", "jax/0.1.61"] | ||
jaxlib_req = ["jaxlib/0.1.50", "jaxlib/0.1.60", "jaxlib/0.1.61"] | ||
numpy_req = ["numpy/1.17.3", "numpy/1.17.4", "numpy/1.23.1", "numpy/1.24.0"] | ||
|
||
directory_generator(torch_req) | ||
directory_generator(tensorflow_req) | ||
directory_generator(jax_req) | ||
directory_generator(numpy_req) | ||
directory_generator(jaxlib_req) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
h5py==3.7.0 | ||
pytest==7.1.2 | ||
networkx==2.8.4 | ||
hypothesis==6.55.0 | ||
pymongo==4.3.3 | ||
redis==4.3.4 | ||
matplotlib==3.5.2 | ||
opencv-python==4.6.0.66 # mod_name=cv2 | ||
tensorflow-addons==0.17.1 # mod_name=tensorflow_addons | ||
tensorflow-probability==0.17.0 # mod_name=tensorflow_probability | ||
functorch==0.1.1 | ||
scipy==1.8.1 | ||
dm-haiku==0.0.6 # mod_name=haiku | ||
pydriller | ||
tqdm | ||
coverage | ||
einops | ||
psutil | ||
termcolor | ||
colorama | ||
packaging | ||
nvidia-ml-py<=11.495.46 # mod_name=pynvml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
#!/bin/sh | ||
cd ivy | ||
python multiversion_framework_drectory.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# flake8: noqa | ||
import os | ||
import sys | ||
import importlib | ||
|
||
global_temp_sys_module = {} | ||
|
||
|
||
def allow_global_framework_imports(fw=["numpy/1.23.1"]): | ||
# since no framework installed right now we quickly store a copy of the sys.modules | ||
global global_temp_sys_module | ||
if not global_temp_sys_module: | ||
global_temp_sys_module = sys.modules.copy() | ||
for framework in fw: | ||
sys.path.insert(1, os.path.abspath("/opt/miniconda/fw/" + framework)) | ||
|
||
|
||
def try_except(): | ||
try: | ||
import numpy | ||
except ImportError: | ||
allow_global_framework_imports() | ||
|
||
|
||
def return_global_temp_sys_module(): | ||
return global_temp_sys_module | ||
|
||
|
||
def reset_sys_modules_to_base(): | ||
if global_temp_sys_module != sys.modules: | ||
sys.modules.clear() | ||
sys.modules.update(global_temp_sys_module) | ||
|
||
|
||
# to import a specific pkg along with version name, to be used by the test functions | ||
def custom_import( | ||
pkg, base="/opt/miniconda/fw/", globally_done=None | ||
): # format is pkg_name/version , globally_done means | ||
# if we have imported any framework before globally | ||
if globally_done: # i.e import numpy etc | ||
if pkg == globally_done: | ||
ret = importlib.import_module(pkg.split("/")[0]) | ||
return ret | ||
sys.path.remove(os.path.abspath(base + globally_done)) | ||
temp = sys.modules.copy() | ||
sys.modules.clear() | ||
sys.modules.update(global_temp_sys_module) | ||
sys.path.insert(1, os.path.abspath(base + pkg)) | ||
ret = importlib.import_module(pkg.split("/")[0]) | ||
sys.path.remove(os.path.abspath(base + pkg)) | ||
sys.path.insert(1, os.path.abspath(base + globally_done)) | ||
sys.modules.clear() | ||
sys.modules.update(temp) | ||
return ret | ||
|
||
temp = sys.modules.copy() | ||
sys.path.insert(1, os.path.abspath(base + pkg)) | ||
ret = importlib.import_module(pkg.split("/")[0]) | ||
sys.path.remove(os.path.abspath(base + pkg)) | ||
sys.modules.clear() | ||
sys.modules.update(temp) | ||
|
||
return ret |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
from .. import config | ||
|
||
if hasattr(config, "try_except"): | ||
config.try_except() | ||
from . import helpers | ||
|
||
test_shapes = ((), (1,), (2, 1), (1, 2, 3)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,33 +1,40 @@ | ||
# A list of available backends that can be used for testing. | ||
|
||
available_frameworks = ["numpy", "jax", "tensorflow", "torch"] | ||
|
||
try: | ||
import jax | ||
|
||
assert jax, "jax is imported to see if the user has it installed" | ||
except ImportError: | ||
available_frameworks.remove("jax") | ||
|
||
try: | ||
import tensorflow as tf | ||
|
||
assert tf, "tensorflow is imported to see if the user has it installed" | ||
except ImportError: | ||
available_frameworks.remove("tensorflow") | ||
|
||
try: | ||
import torch | ||
|
||
assert torch, "torch is imported to see if the user has it installed" | ||
except ImportError: | ||
available_frameworks.remove("torch") | ||
|
||
if "tensorflow" in available_frameworks: | ||
ground_truth = "tensorflow" | ||
elif "torch" in available_frameworks: | ||
ground_truth = "torch" | ||
elif "jax" in available_frameworks: | ||
ground_truth = "jax" | ||
else: | ||
ground_truth = "numpy" | ||
def available_frameworks(): | ||
available_frameworks_lis = ["numpy", "jax", "tensorflow", "torch"] | ||
try: | ||
import jax | ||
|
||
assert jax, "jax is imported to see if the user has it installed" | ||
except ImportError: | ||
available_frameworks_lis.remove("jax") | ||
|
||
try: | ||
import tensorflow as tf | ||
|
||
assert tf, "tensorflow is imported to see if the user has it installed" | ||
except ImportError: | ||
available_frameworks_lis.remove("tensorflow") | ||
|
||
try: | ||
import torch | ||
|
||
assert torch, "torch is imported to see if the user has it installed" | ||
except ImportError: | ||
available_frameworks_lis.remove("torch") | ||
return available_frameworks_lis | ||
|
||
|
||
def ground_truth(): | ||
available_framework_lis = available_frameworks() | ||
g_truth = "" | ||
if "tensorflow" in available_framework_lis: | ||
g_truth = "tensorflow" | ||
elif "torch" in available_framework_lis: | ||
g_truth = "torch" | ||
elif "jax" in available_framework_lis: | ||
g_truth = "jax" | ||
else: | ||
g_truth = "numpy" | ||
return g_truth |
Oops, something went wrong.