Skip to content

Commit

Permalink
Remove all dependencies except torch (pytorch#369)
Browse files Browse the repository at this point in the history
* clean up req.txt

* yolo

* yolo

* yolo

* yolo

* yolo

* yolo

* Update dev-requirements.txt

* Delete requirements.txt

* Update utils.py

* Update README.md

* Trigger CI

* remove req from ci

* Update doc_build.yml

* remove numpy as a dependency

* yolo

* Update dev-requirements.txt

* Update setup.py
  • Loading branch information
msaroufim authored Jun 17, 2024
1 parent eb1511e commit 1b78dda
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 35 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/doc_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ jobs:
python -m pip install torch
python -m pip install -e .
pip install -r dev-requirements.txt
cd docs
python -m pip install -r requirements.txt
python -m pip install -r docs/requirements.txt
- name: Build docs
env:
TORCHAO_VERSION_DOCS: ${{ github.ref }}
Expand Down
3 changes: 1 addition & 2 deletions .github/workflows/nightly_smoke_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: PyTorch CUDA Nightly Smoke Test

on:
schedule:
# 6 am PST every day
# 6 am PST every day
- cron: "0 14 * * *"
workflow_dispatch:

Expand Down Expand Up @@ -34,7 +34,6 @@ jobs:
script: |
python -m pip install --upgrade pip
pip install ${{ matrix.torch-spec }}
pip install -r requirements.txt
pip install -r dev-requirements.txt
python setup.py install
pytest test --verbose -s
3 changes: 1 addition & 2 deletions .github/workflows/regression_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
torch-spec: 'torch==2.3.0'
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"
- name: CUDA Nightly
- name: CUDA Nightly
runs-on: linux.g5.12xlarge.nvidia.gpu
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
gpu-arch-type: "cuda"
Expand Down Expand Up @@ -65,7 +65,6 @@ jobs:
export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
python -m pip install --upgrade pip
pip install ${{ matrix.torch-spec }}
pip install -r requirements.txt
pip install -r dev-requirements.txt
pip install .
pytest test --verbose -s
12 changes: 3 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,15 @@ From source
```Shell
git clone https://github.com/pytorch/ao
cd ao
pip install -r requirements.txt
pip install -r dev-requirements.txt
python setup.py install
```

There are two options;
-If you plan to be developing the library run:
If you plan to be developing the library run:
```Shell
pip install -r dev-requirements.txt
python setup.py develop
```

If you want to install from source run
```Shell
python setup.py install
```

** Note:
If you are running into any issues while building `ao` cpp extensions you can instead build using

Expand Down
3 changes: 3 additions & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ unittest-xml-reporting
parameterized
packaging
transformers
hypothesis # Avoid test derandomization warning
sentencepiece # for gpt-fast tokenizer
expecttest

# For prototype features and benchmarks
bitsandbytes #needed for testing triton quant / dequant ops for 8-bit optimizers
Expand Down
4 changes: 0 additions & 4 deletions requirements.txt

This file was deleted.

1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def get_extensions():
pytorch_dep += "==" + os.getenv("PYTORCH_VERSION")

requirements = [
"numpy",
pytorch_dep,
]

Expand Down
5 changes: 5 additions & 0 deletions torchao/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import torch
import logging

# torch/nested/_internal/nested_tensor.py:417: UserWarning: Failed to initialize NumPy: No module named 'numpy'
import warnings
warnings.filterwarnings("ignore", message="Failed to initialize NumPy: No module named 'numpy'")


# We use this "hack" to set torchao.__version__ correctly
# the version of ao is dependent on environment variables for multiple architectures
# For local development this will default to whatever is version.txt
Expand Down
25 changes: 10 additions & 15 deletions torchao/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import torch
import torch.utils.benchmark as benchmark
from typing import Tuple
from functools import reduce
from importlib.metadata import version
from math import gcd
from packaging import version
import torch.nn.utils.parametrize as parametrize
import itertools

Expand All @@ -19,6 +18,7 @@
"TORCH_VERSION_AFTER_2_2",
"TORCH_VERSION_AFTER_2_3",
"TORCH_VERSION_AFTER_2_4",
"TORCH_VERSION_AFTER_2_5",
]


Expand Down Expand Up @@ -64,8 +64,9 @@ def wrapper(*args, **kwargs):


def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
import torch.utils.benchmark as benchmark # this avoids importing numpy when torchao module is loaded

# Manual warmup

f(*args, **kwargs)
f(*args, **kwargs)

Expand Down Expand Up @@ -163,20 +164,14 @@ def unwrap_tensor_subclass(model, filter_fn=None):
unwrap_tensor_subclass(child)
return model

if version.parse(torch.__version__) >= version.parse("2.4.0.dev"):
TORCH_VERSION_AFTER_2_4 = True
else:
TORCH_VERSION_AFTER_2_4 = False

if version.parse(torch.__version__) >= version.parse("2.3.0.dev"):
TORCH_VERSION_AFTER_2_3 = True
else:
TORCH_VERSION_AFTER_2_3 = False
def torch_version_at_least(min_version):
return version("torch") >= min_version

if version.parse(torch.__version__) >= version.parse("2.2.0.dev"):
TORCH_VERSION_AFTER_2_2 = True
else:
TORCH_VERSION_AFTER_2_2 = False
TORCH_VERSION_AFTER_2_5 = torch_version_at_least("2.5.0.dev")
TORCH_VERSION_AFTER_2_4 = torch_version_at_least("2.4.0.dev")
TORCH_VERSION_AFTER_2_3 = torch_version_at_least("2.3.0.dev")
TORCH_VERSION_AFTER_2_2 = torch_version_at_least("2.2.0.dev")

def is_fbcode():
return not hasattr(torch.version, "git_version")

0 comments on commit 1b78dda

Please sign in to comment.