Skip to content

Commit

Permalink
Change import behavior (#132)
Browse files Browse the repository at this point in the history
* fix numpy import

* optimize circle ci performance

Co-authored-by: /bin/eash <[email protected]>
  • Loading branch information
liwt31 and /bin/eash authored Sep 19, 2022
1 parent c91ebb0 commit 254b981
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 81 deletions.
4 changes: 3 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ jobs:
name: Run tests
# This assumes pytest is installed via the install-package step above
command: |
pytest --durations=0
pip install pytest-xdist
export RENO_NUM_THREADS=1
pytest -n 4 --durations=0
pip install primme==3.2.*
pytest --durations=0 renormalizer/mps/tests/test_gs.py::test_multistate
- run:
Expand Down
22 changes: 12 additions & 10 deletions renormalizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,22 @@
# Author: Jiajun Ren <[email protected]>
import os
import sys
import warnings

if "numpy" in sys.modules:
raise ImportError("renormalizer should be imported before numpy is imported")

# set environment variables to limit NumPy cpu usage
# Note that this should be done before NumPy is imported
for env in ["MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS", "OMP_NUM_THREADS"]:
os.environ[env] = "1"
del env
reno_num_threads = os.environ.get("RENO_NUM_THREADS")
if reno_num_threads is not None:
# set environment variables to limit NumPy cpu usage
# Note that this should be done before NumPy is imported

# NEP-18 not working. For compatibility of newer NumPy version. See gh-cupy/cupy#2130
os.environ["NUMPY_EXPERIMENTAL_ARRAY_FUNCTION"] = "0"
if "numpy" in sys.modules:
warnings.warn("renormalizer should be imported before numpy for `RENO_NUM_THREADS` to take effect")
else:
for env in ["MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS", "OMP_NUM_THREADS"]:
os.environ[env] = reno_num_threads
del env

del os, sys
del os, sys, warnings

from renormalizer.utils.log import init_log

Expand Down
13 changes: 11 additions & 2 deletions renormalizer/mps/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ def __init__(self):
self.first_mp = False
self._real_dtype = None
self._complex_dtype = None
#self.use_32bits()
self.use_64bits()
if os.environ.get("RENO_FP32") is None:
self.use_64bits()
else:
self.use_32bits()

def free_all_blocks(self):
if not USE_GPU:
Expand Down Expand Up @@ -130,5 +132,12 @@ def dtypes(self):
def dtypes(self, target):
self.real_dtype, self.complex_dtype = target

@property
def canonical_atol(self):
if self.is_32bits:
return 1e-4
else:
return 1e-5


backend = Backend()
12 changes: 8 additions & 4 deletions renormalizer/mps/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,21 +90,25 @@ def r_combine(self):
def l_combine(self):
return self.reshape(self.l_combine_shape)

def check_lortho(self, rtol=1e-5, atol=1e-8):
def check_lortho(self, atol=None):
"""
check L-orthogonal
"""
if atol is None:
atol = backend.canonical_atol
tensm = asxp(self.array.reshape([np.prod(self.shape[:-1]), self.shape[-1]]))
s = tensm.T.conj() @ tensm
return xp.allclose(s, xp.eye(s.shape[0]), rtol=rtol, atol=atol)
return xp.allclose(s, xp.eye(s.shape[0]), atol=atol)

def check_rortho(self, rtol=1e-5, atol=1e-8):
def check_rortho(self, atol=None):
"""
check R-orthogonal
"""
if atol is None:
atol = backend.canonical_atol
tensm = asxp(self.array.reshape([self.shape[0], np.prod(self.shape[1:])]))
s = tensm @ tensm.T.conj()
return xp.allclose(s, xp.eye(s.shape[0]), rtol=rtol, atol=atol)
return xp.allclose(s, xp.eye(s.shape[0]), atol=atol)

def to_complex(self):
# `xp.array` always creates new array, so to_complex means copy, which is
Expand Down
Loading

0 comments on commit 254b981

Please sign in to comment.