torch.linalg.eigh fails on GPU #94772
Comments
Just want to emphasize that this is a major issue for experimenting with full-matrix Adagrad and Shampoo. Any suggestions on how to proceed would be much appreciated! cc: @shintaro-iwasaki @0x10cxR1 @kaushik88 @mhfb22 @dmudigere
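For context on why eigh sits on the critical path here: full-matrix Adagrad and Shampoo precondition gradients with inverse matrix roots of accumulated second-moment statistics (p = 2 for full-matrix Adagrad, p = 4 for each Kronecker factor when Shampoo is applied to a matrix parameter). A minimal NumPy sketch of computing such a root through a symmetric eigendecomposition is below; `matrix_inverse_root` and the `eps * I` damping are illustrative assumptions, not the optimizer code used by the posters.

```python
import numpy as np

def matrix_inverse_root(a, p, eps=1e-6):
    """Compute (A + eps*I)^(-1/p) for a symmetric PSD matrix A via eigh (illustrative helper)."""
    a = np.asarray(a, dtype=np.float64)
    n = a.shape[0]
    # Damped statistics matrix; eigh returns ascending eigenvalues and orthonormal eigenvectors
    vals, vecs = np.linalg.eigh(a + eps * np.eye(n))
    # Rounding can push eigenvalues slightly below eps; clamp before taking the negative power
    vals = np.clip(vals, eps, None)
    # Q diag(w^(-1/p)) Q^T, written as column scaling followed by a matmul
    return (vecs * vals ** (-1.0 / p)) @ vecs.T
```

With p = 4 this is the per-factor root a Shampoo-style update would apply on each side of a matrix-shaped gradient.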
also cc: @xwang233
The matrix is all zeros, and unfortunately the cuSOLVER backend doesn't handle this edge case without an error. Could you try initializing
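The suggested fix in this comment is cut off, so the following is only one possible workaround, not the one proposed: the eigendecomposition of an all-zero n × n matrix is known exactly (every eigenvalue is 0 and the identity is a valid set of eigenvectors), so a wrapper can return it without calling the solver at all. `eigh_with_zero_guard` is a hypothetical NumPy helper; for a CUDA tensor the equivalent test would be something like `bool(torch.any(a))`, which forces a host-device synchronization.

```python
import numpy as np

def eigh_with_zero_guard(a):
    """Hypothetical wrapper: answer the all-zero case in closed form, otherwise call eigh."""
    a = np.asarray(a)
    n = a.shape[-1]
    if not np.any(a):
        # Zero matrix: all eigenvalues are 0 and any orthonormal basis is a valid eigenbasis
        return np.zeros(n, dtype=a.dtype), np.eye(n, dtype=a.dtype)
    vals, vecs = np.linalg.eigh(a)
    return vals, vecs
```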
If you are referring to For a typical run I observe:
Thanks for looking into it @IvanYashchuk! FYI, @mikerabbat has observed failure cases of
It'd be good to have the exact matrix for which this one fails. Could you guys provide that example? My money is on the input matrix having repeated eigenvalues. Eigenvalue solvers struggle a bit with repeated eigenvalues...
Thanks @lezcano -- I have saved the numpy array for to the You can restore the matrix
Right. An outer product gives a matrix of rank one. As such, in this case, you have a matrix of shape In other words, your solution to the problem is a bit numerically unstable, similar to when you try to solve a linear system with a matrix that's almost singular. What I would suggest is that, if you know that your matrix has the rank-1 structure, or low-rank structure in general, you use other methods. For a rank-1 symmetric matrix of the form
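The concrete suggestion in this comment is cut off. As one possible illustration (not necessarily what was proposed), a symmetric rank-1 matrix A = v vᵀ has a closed-form eigendecomposition: its only nonzero eigenvalue is ‖v‖² with eigenvector v/‖v‖, and the remaining n − 1 eigenvectors can be any orthonormal basis of the complement of v. The NumPy sketch below builds that basis from a complete QR factorization; `rank_one_eigh` is a hypothetical name.

```python
import numpy as np

def rank_one_eigh(v):
    """Closed-form eigendecomposition of v v^T, eigenvalues ascending like eigh (hypothetical helper)."""
    v = np.asarray(v, dtype=np.float64).ravel()
    n = v.size
    norm = np.linalg.norm(v)
    if norm == 0.0:
        # v v^T is the zero matrix: all eigenvalues are 0, any orthonormal basis works
        return np.zeros(n), np.eye(n)
    u = v / norm
    # Complete QR of the single column u gives an n x n orthogonal Q with Q[:, 0] = +/-u,
    # so Q[:, 1:] is an orthonormal basis of the zero eigenspace (the complement of u)
    q, _ = np.linalg.qr(u.reshape(n, 1), mode="complete")
    vals = np.concatenate([np.zeros(n - 1), [norm ** 2]])
    vecs = np.column_stack([q[:, 1:], u])
    return vals, vecs
```

For v = [3, 4] this gives eigenvalues [0, 25], with [0.6, 0.8] as the eigenvector for 25, without ever handing the singular matrix to an iterative solver.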
@lezcano, thanks for your response. Just as additional context: we will try to dig up some additional failure cases where the matrix is not just low-rank and post our findings as we also investigate further on our side. Thanks!
Note that the difference in behaviour between CPU and CUDA on this exact point, matrices with repeated eigenvalues, is documented here: https://pytorch.org/docs/master/notes/numerical_accuracy.html#extremal-values-in-linalg
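Because an eigenvalue of multiplicity greater than one has a whole subspace of valid eigenvectors, CPU and CUDA backends can legitimately return different eigenvectors for the same input. The NumPy sketch below (the helper `is_valid_eigh` and its tolerance are illustrative assumptions) validates a decomposition by its residual and orthogonality rather than by comparing eigenvectors entry by entry, and shows two different but equally correct eigenbases for diag(2, 2, 5).

```python
import numpy as np

def is_valid_eigh(a, vals, vecs, tol=1e-10):
    """Check A V = V diag(w) and V^T V = I instead of comparing eigenvectors elementwise."""
    n = a.shape[0]
    residual = np.linalg.norm(a @ vecs - vecs * vals)
    orthogonality = np.linalg.norm(vecs.T @ vecs - np.eye(n))
    return residual <= tol * max(1.0, np.linalg.norm(a)) and orthogonality <= tol

a = np.diag([2.0, 2.0, 5.0])  # eigenvalue 2 is repeated
vals, v1 = np.linalg.eigh(a)

# Rotate the two eigenvectors of the repeated eigenvalue by 30 degrees: still a valid eigenbasis
c, s = np.cos(np.pi / 6), np.sin(np.pi / 6)
rot = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
v2 = v1 @ rot
```

Both `v1` and `v2` pass `is_valid_eigh`, even though they differ entrywise, which is why cross-backend tests typically compare reconstructions or eigenvalues instead of raw eigenvectors.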
Hi @lezcano, just want to clarify that on CPU, the algorithm at least succeeds (perhaps with numerical error), while on GPU, the algorithm fails entirely. We are willing to tolerate numerical error in the eigenvectors if the matrix is low-rank, as it will still produce an orthonormal basis for the zero-eigenvalue subspace. I am referring here to the error bounds described here: https://netlib.org/lapack/lug/node90.html. Based on my understanding of your argument, even if we regularize by adding some constant epsilon * I, wouldn't the algorithm still fail? Would it even fail on a diagonal matrix c * I? If you recommend using a different algorithm for the low-rank case, which algorithms would you suggest? Thanks again.
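For reference, the bounds on the LAPACK page linked in this comment have roughly the following form for a symmetric A with eigenpairs (λᵢ, vᵢ), where ε is the machine epsilon and p(n) is a modestly growing function of n (both as defined on that page):

```latex
|\hat\lambda_i - \lambda_i| \le p(n)\,\varepsilon\,\|A\|_2,
\qquad
\theta(\hat v_i, v_i) \le \frac{p(n)\,\varepsilon\,\|A\|_2}{\operatorname{gap}_i},
\qquad
\operatorname{gap}_i = \min_{j \ne i} |\lambda_i - \lambda_j|.
```

Since A + δI has eigenvalues λᵢ + δ, every gapᵢ is unchanged by such a shift, so, at least as far as this bound is concerned, adding δI does not make the individual eigenvectors of a repeated or tightly clustered eigenvalue any better determined. When gapᵢ = 0, only the invariant subspace of the cluster as a whole is meaningful, with an accuracy governed by the gap between that cluster and the rest of the spectrum.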
The error bounds in the link you provide say what I mentioned above. You see that Theoretically, these algorithms may even fail for diagonal matrices, but in this case, since the matrix is exactly low-rank in floating-point precision, the algorithm often succeeds. Note that As to which algorithm to use: that very much depends on your exact problem, and I'd recommend you discuss it with your local linear algebra expert. For large and very sparse matrices, there is the classic randomized SVD approximation described in https://arxiv.org/abs/0909.4061, which I believe would work well for your case, as you can give a bound on the rank of your matrix. There is even a post from Meta describing it, and there is a fast implementation of it in cuSOLVER. Again, different algorithms may be more amenable to different problems.
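The paper linked here (Halko, Martinsson and Tropp) describes a two-stage scheme: a randomized range finder, followed by an exact decomposition of a small projected matrix. A minimal NumPy sketch of that scheme for the symmetric case is below; `randomized_eigh`, its oversampling of 10 and its 2 power iterations are illustrative choices, not the cuSOLVER implementation or the one from the Meta post. PyTorch also ships `torch.svd_lowrank`, whose documentation cites the same paper.

```python
import numpy as np

def randomized_eigh(a, rank, oversample=10, n_iter=2, seed=0):
    """Approximate the `rank` largest-magnitude eigenpairs of a symmetric matrix."""
    a = np.asarray(a, dtype=np.float64)
    n = a.shape[0]
    k = min(n, rank + oversample)
    rng = np.random.default_rng(seed)
    # Stage A: sample the range of A with a Gaussian test matrix
    y = a @ rng.standard_normal((n, k))
    for _ in range(n_iter):
        # Power iterations (A is symmetric, so A^T = A) with re-orthonormalization
        # sharpen the basis when the spectrum decays slowly
        q, _ = np.linalg.qr(y)
        y = a @ q
    q, _ = np.linalg.qr(y)  # orthonormal basis for the approximate range of A
    # Stage B: exact eigendecomposition of the small k x k matrix Q^T A Q
    b = q.T @ a @ q
    b = 0.5 * (b + b.T)  # remove rounding asymmetry before calling eigh
    vals, w = np.linalg.eigh(b)
    order = np.argsort(np.abs(vals))[::-1][:rank]
    return vals[order], q @ w[:, order]
```

The only dense eigensolver call left is on a k × k matrix, with k set by the rank bound rather than by n.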
Thanks @lezcano for the quick response on #105359! I understand that failure is expected on nearly singular matrices, but the matrix becoming unrecoverable afterwards may be an issue for our use case. Is there anything that can be done besides implementing checks on our side to convert to float64 before calling Additionally, I was able to reproduce the error yesterday on CUDA 12.1 as well. It seems to occur only on matrices larger than 512 x 512; is the issue likely to be in the syevd solver?
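The float64 workaround mentioned in this comment could be packaged as a small wrapper. The sketch below uses NumPy, and the name `eigh_float64` is hypothetical; with PyTorch tensors the analogous steps would be `a.to(torch.float64)` before the call and `.to(a.dtype)` on the results. It only reduces rounding error in the solve and does nothing about the unrecoverable CUDA state discussed in this thread.

```python
import numpy as np

def eigh_float64(a):
    """Hypothetical helper: validate, upcast to float64, solve, cast results back."""
    a = np.asarray(a)
    if not np.all(np.isfinite(a)):
        # NaN/Inf inputs cannot produce a meaningful decomposition
        raise ValueError("input contains NaN or Inf")
    a64 = a.astype(np.float64)
    a64 = 0.5 * (a64 + a64.T)  # eigh reads only one triangle; make both triangles agree
    vals, vecs = np.linalg.eigh(a64)
    return vals.astype(a.dtype), vecs.astype(a.dtype)
```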
The issue is indeed on the cuSOLVER side, so there's very little we can do on our end. I don't think there's any other solution for this at the moment.
cc: @dmudigere @xwang233
Thanks for the reminder. We'll take a look
What is it that is left in an unrecoverable state? The tensor? The GPU?
It seems that nothing on the GPU can be accessed anymore without a CUDA illegal memory access error.
It turns out that I may have been incorrect about the error occurring in CUDA 12.1: after changing some environment variables, it no longer seems to fail. I'll investigate this some more and keep everyone updated.
Seems like I was running on PyTorch nightly (2.1.0.dev20230717+cu121) by accident. Here's what I seem to have observed so far (please let me know if you can reproduce this!):
If another matrix were to fail in the last setup, I'm unsure whether it would cause the unrecoverable state or not.
I believe this issue has been fixed starting from cuSOLVER 11.4.5, which shipped with CUDA 12.1 Update 1. Can you please confirm if this matches your observations?
As before, while everything seems fine with PyTorch built from source + CUDA 12.1, if I use PyTorch nightly + CUDA 12.1 instead, the error still occurs:
Is there a way for me to check the cuSOLVER version? My CUDA runtime version seems to be 12.1.105.
You can check the version stated in the filename of the library, for example:
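The comments above say the fix landed in cuSOLVER 11.4.5 (CUDA 12.1 Update 1) and that the version can be read from the library's filename. The stdlib-only sketch below parses the numeric suffix of `libcusolver.so.*` files and compares it against that threshold. The helper names are hypothetical, the library directory you pass in is an assumption about your install, and some packages may only encode the major version in the filename, in which case the check conservatively reports no fix.

```python
import glob
import os
import re

# Threshold taken from the comment above: fixed starting from cuSOLVER 11.4.5.
FIXED_IN = (11, 4, 5)

def parse_soname_version(filename):
    """Return the numeric suffix of a name like 'libcusolver.so.11.4.5' as a tuple, or None."""
    match = re.search(r"\.so\.(\d+(?:\.\d+)*)$", os.path.basename(filename))
    if match is None:
        return None
    return tuple(int(part) for part in match.group(1).split("."))

def find_cusolver_versions(lib_dir):
    """Versions encoded in libcusolver.so.* filenames under lib_dir (lib_dir is an assumption)."""
    paths = glob.glob(os.path.join(lib_dir, "libcusolver.so.*"))
    versions = (parse_soname_version(p) for p in paths)
    return sorted(v for v in versions if v is not None)

def has_cusolver_fix(version):
    """Conservative check: a major-only version such as (11,) compares below (11, 4, 5)."""
    return tuple(version[:3]) >= FIXED_IN
```

For example, `find_cusolver_versions("/usr/local/cuda/lib64")` on a typical toolkit install, then `has_cusolver_fix` on the highest entry.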
I see,
I was able to reproduce the issue using the script from the first post, but that error was definitely resolved in CUDA 12.1 Update 1. Can you share your file, Also, please note that cuSOLVER had this issue in CUDA 12.1, and it was fixed in the update. You can check the version with
The file
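From discussion in triage review:
- we should add a test to prevent regressions
- properly document support wrt different CUDA versions
- possibly add support using MAGMA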
We'll add a test case for those inputs. For "possibly add support using MAGMA", users may refer to this function to prefer MAGMA as the linear algebra backend library. This is a global runtime switch:
https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library
torch.backends.cuda.preferred_linalg_library('magma')
There might be performance hits when switching from the default (cuSOLVER) to MAGMA in
…conditioned, in some cusolver version (#107082)
Related: #94772, #105359
I can locally reproduce this crash with the PyTorch 2.0.1 stable pip binary. The test already passes with the latest CUDA 12.2 release.
Re: #94772 (comment)
> From discussion in triage review:
- [x] we should add a test to prevent regressions
- [x] properly document support wrt different CUDA versions
- [x] possibly add support using MAGMA
Pull Request resolved: #107082
Approved by: https://github.com/lezcano
Reproduced the issue using
A large merge of pytorch/builder commits also references this issue, through the entry "update to CUDA 12.1U1 (pytorch#1476): Should fix pytorch/pytorch#94772 in wheel builds". The same merge includes "Add some more validation checks for torch.linalg.eigh and torch.compile (pytorch#1580)", which was reverted in pytorch#1594.
🐛 Describe the bug
Calling torch.linalg.eigh on a CUDA tensor fails, but the computation succeeds when the tensor is on the CPU. I have experienced this issue on CUDA 11.6, 11.7, and 11.8.
This is a blocker towards executing the Shampoo optimizer.
Minimal replication script
Error trace
cc @ezyang @gchanan @zou3519 @ptrblck @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano @ngimel @hjmshi @mikerabbat @tsunghsienlee @dmudiger
Versions