Skip to content

Commit

Permalink
automatically set pytorch requires
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Jun 20, 2024
1 parent 1164aab commit ebb1a87
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 9 deletions.
15 changes: 13 additions & 2 deletions backend/dp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

from scikit_build_core import build as _orig

from .find_pytorch import (
find_pytorch,
)
from .find_tensorflow import (
find_tensorflow,
)
Expand Down Expand Up @@ -40,10 +43,18 @@ def __dir__() -> List[str]:
def get_requires_for_build_wheel(
config_settings: dict,
) -> List[str]:
return _orig.get_requires_for_build_wheel(config_settings) + find_tensorflow()[1]
return (
_orig.get_requires_for_build_wheel(config_settings)
+ find_tensorflow()[1]
+ find_pytorch()[1]
)


def get_requires_for_build_editable(
config_settings: dict,
) -> List[str]:
return _orig.get_requires_for_build_editable(config_settings) + find_tensorflow()[1]
return (
_orig.get_requires_for_build_editable(config_settings)
+ find_tensorflow()[1]
+ find_pytorch()[1]
)
40 changes: 37 additions & 3 deletions backend/find_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
get_path,
)
from typing import (
List,
Optional,
Tuple,
)


@lru_cache
def find_pytorch() -> Optional[str]:
def find_pytorch() -> Tuple[Optional[str], List[str]]:
"""Find PyTorch library.
Tries to find PyTorch in the order of:
Expand All @@ -39,9 +41,12 @@ def find_pytorch() -> Optional[str]:
-------
str, optional
PyTorch library path if found.
list of str
TensorFlow requirement if not found. Empty if found.
"""
if os.environ.get("DP_ENABLE_PYTORCH", "1") == "0":
return None
return None, []
requires = []
pt_spec = None

if (pt_spec is None or not pt_spec) and os.environ.get("PYTORCH_ROOT") is not None:
Expand Down Expand Up @@ -73,4 +78,33 @@ def find_pytorch() -> Optional[str]:
# IndexError if submodule_search_locations is an empty list
except (AttributeError, TypeError, IndexError):
pt_install_dir = None
return pt_install_dir
requires.extend(get_pt_requirement()["torch"])
return pt_install_dir, requires


@lru_cache
def get_pt_requirement(pt_version: str = "") -> dict:
"""Get PyTorch requirement when PT is not installed.
If pt_version is not given and the environment variable `PYTORCH_VERSION` is set, use it as the requirement.
Parameters
----------
pt_version : str, optional
PT version
Returns
-------
dict
PyTorch requirement.
"""
if pt_version is None:
return {"torch": []}
if pt_version == "":
pt_version = os.environ.get("PYTORCH_VERSION", "")

return {
"torch": [
f"torch=={pt_version}" if pt_version != "" else "torch>=2a",
],
}
2 changes: 1 addition & 1 deletion backend/read_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def get_argument_from_env() -> Tuple[str, list, list, dict, str]:
tf_version = None

if os.environ.get("DP_ENABLE_PYTORCH", "1") == "1":
pt_install_dir = find_pytorch()
pt_install_dir, _ = find_pytorch()
if pt_install_dir is None:
raise RuntimeError("Cannot find installed PyTorch.")
cmake_args.extend(
Expand Down
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,6 @@ cu12 = [
"nvidia-cudnn-cu12<9",
"nvidia-cuda-nvcc-cu12",
]
torch = [
"torch>=2a",
]

[tool.deepmd_build_backend.scripts]
dp = "deepmd.main:main"
Expand Down

0 comments on commit ebb1a87

Please sign in to comment.