Skip to content

Commit

Permalink
Multi cuda gpu (#13189)
Browse files Browse the repository at this point in the history
* multicuda changes as well as multiversion changes

* multiversion framework directory changes

* check

* changes

* run tests

* changes

* Update run_tests.py

---------

Co-authored-by: Rishabh Kumar <[email protected]>
  • Loading branch information
RickSanchezStoic and Rishabh Kumar authored Mar 26, 2023
1 parent 86f2483 commit 2ed827a
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 3 deletions.
13 changes: 12 additions & 1 deletion ivy_tests/test_ivy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Dict
import subprocess
import importlib

from .. import config as env_config
mod_frontend = {
"tensorflow": None,
"numpy": None,
Expand Down Expand Up @@ -69,6 +69,16 @@ def pytest_configure(config):
else:
backend_strs = raw_value.split(",")

# env specification for multiversion backend
env_val=config.getoption("--env")
if env_val:
# check if multiversion format in backend argument
if [True if '/' in x else False for x in backend_strs][0]:
raise Exception("--env and '/' naming in backend can't be used together")
else:
env_val=env_val.split(',')
env_config.allow_global_framework_imports(fw=env_val)

# frontend
frontend = config.getoption("--frontend")
if frontend:
Expand Down Expand Up @@ -258,6 +268,7 @@ def pytest_addoption(parser):
parser.addoption("--compile_graph", action="store_true")
parser.addoption("--with_implicit", action="store_true")
parser.addoption("--frontend", action="store", default=None)
parser.addoption("--env", action="store", default=None)
parser.addoption("--ground_truth", action="store", default=None)
parser.addoption("--skip-variable-testing", action="store_true")
parser.addoption("--skip-native-array-testing", action="store_true")
Expand Down
2 changes: 1 addition & 1 deletion ivy_tests/test_ivy/helpers/function_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def framework_comparator(frontend):
elif frontend.split("/")[0] == "torch":
return (
frontend.split("/")[1]
== importlib.import_module(frontend.split("/")[1]).__version__.split("+")[0]
== importlib.import_module(frontend.split("/")[0]).__version__.split("+")[0]
)
else:
return (
Expand Down
2 changes: 1 addition & 1 deletion multiversion_frontend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def test_frontend_method():
try:
ivy.set_backend(arg_lis[2].split("/")[0])
except: # noqa: E722
raise Exception(f"lalalalal {fw_lis}")
raise
import numpy

try:
Expand Down

0 comments on commit 2ed827a

Please sign in to comment.