From a2c9d58cc2ca05d4aec32c66da6621ce38b2acdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patrick=20Gel=C3=9F?= <38036185+PGelss@users.noreply.github.com> Date: Wed, 17 Jul 2024 14:46:59 +0200 Subject: [PATCH] fix in ortho routines --- scikit_tt/tensor_train.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/scikit_tt/tensor_train.py b/scikit_tt/tensor_train.py index 7c72c2f..c9a6783 100644 --- a/scikit_tt/tensor_train.py +++ b/scikit_tt/tensor_train.py @@ -1141,15 +1141,16 @@ def ortho_left(self, start_index: int=0, # check for correct max_rank argument and set max_ranks max_rank_tf = True - if not isinstance(max_rank, list) and ((isinstance(max_rank, (int, np.int32, np.int64)) and max_rank > 0) or max_rank == np.infty): + if (isinstance(max_rank, (int, np.int32, np.int64)) and max_rank > 0) or max_rank == np.infty: max_ranks = [1] + [max_rank for _ in range(self.order-1)] + [1] else: - if len(max_rank) == self.order+1: - for i in range(self.order+1): - if not ((isinstance(max_rank[i], (int, np.int32, np.int64)) and max_rank[i] > 0) or max_rank[i] == np.infty): - max_rank_tf = False - if max_rank_tf: - max_ranks = max_rank + if isinstance(max_rank, list): + if len(max_rank) == self.order+1: + for i in range(self.order+1): + if not ((isinstance(max_rank[i], (int, np.int32, np.int64)) and max_rank[i] > 0) or max_rank[i] == np.infty): + max_rank_tf = False + if max_rank_tf: + max_ranks = max_rank if max_rank_tf: @@ -1193,7 +1194,7 @@ def ortho_left(self, start_index: int=0, return self else: - raise ValueError('Maximum rank must be a positive integer.') + raise ValueError('Maximum rank(s) must be positive integers.') else: raise ValueError('Threshold must be greater or equal 0.') @@ -1244,15 +1245,16 @@ def ortho_right(self, start_index: Optional[int]=None, # check for correct max_rank argument and set max_ranks max_rank_tf = True - if not isinstance(max_rank, list) and ((isinstance(max_rank, (int, np.int32, np.int64)) and max_rank > 0) or max_rank == np.infty): + if (isinstance(max_rank, (int, np.int32, np.int64)) and max_rank > 0) or max_rank == np.infty: max_ranks = [1] + [max_rank for _ in range(self.order-1)] + [1] else: - if len(max_rank) == self.order+1: - for i in range(self.order+1): - if not ((isinstance(max_rank[i], (int, np.int32, np.int64)) and max_rank[i] > 0) or max_rank[i] == np.infty): - max_rank_tf = False - if max_rank_tf: - max_ranks = max_rank + if isinstance(max_rank, list): + if len(max_rank) == self.order+1: + for i in range(self.order+1): + if not ((isinstance(max_rank[i], (int, np.int32, np.int64)) and max_rank[i] > 0) or max_rank[i] == np.infty): + max_rank_tf = False + if max_rank_tf: + max_ranks = max_rank if max_rank_tf: @@ -1298,7 +1300,7 @@ def ortho_right(self, start_index: Optional[int]=None, return self else: - raise ValueError('Maximum rank must be a positive integer.') + raise ValueError('Maximum rank(s) must be positive integers.') else: raise ValueError('Threshold must be greater or equal 0.')