From 8110914cd189841ff7d6260d369635ed513efd8b Mon Sep 17 00:00:00 2001 From: Toon Baeyens Date: Tue, 6 Dec 2022 15:28:09 +0100 Subject: [PATCH 1/5] start custom sort rule --- include/Spectra/Util/SelectionRule.h | 170 +++++++-------------------- 1 file changed, 43 insertions(+), 127 deletions(-) diff --git a/include/Spectra/Util/SelectionRule.h b/include/Spectra/Util/SelectionRule.h index 9d12a34..390b830 100644 --- a/include/Spectra/Util/SelectionRule.h +++ b/include/Spectra/Util/SelectionRule.h @@ -30,7 +30,8 @@ namespace Spectra { /// /// The enumeration of selection rules of desired eigenvalues. /// -enum class SortRule + +enum class SortRuleType { LargestMagn, ///< Select eigenvalues with largest magnitude. Magnitude ///< means the absolute value for real numbers and norm for @@ -55,140 +56,55 @@ enum class SortRule BothEnds ///< Select eigenvalues half from each end of the spectrum. When ///< `nev` is odd, compute more from the high end. Only for symmetric eigen solvers. + }; /// \cond - // When comparing eigenvalues, we first calculate the "target" to sort. // For example, if we want to choose the eigenvalues with // largest magnitude, the target will be -abs(x). // The minus sign is due to the fact that std::sort() sorts in ascending order. -// Default target: throw an exception -template -class SortingTarget -{ -public: - static ElemType get(const Scalar& val) - { - using std::abs; - throw std::invalid_argument("incompatible selection rule"); - return -abs(val); - } -}; - -// Specialization for SortRule::LargestMagn -// This covers [float, double, complex] x [SortRule::LargestMagn] -template -class SortingTarget -{ -public: - static ElemType get(const Scalar& val) - { - using std::abs; - return -abs(val); - } -}; - -// Specialization for SortRule::LargestReal -// This covers [complex] x [SortRule::LargestReal] -template -class SortingTarget, SortRule::LargestReal> -{ -public: - static RealType get(const std::complex& val) - { - return -val.real(); - } -}; - -// Specialization for SortRule::LargestImag -// This covers [complex] x [SortRule::LargestImag] -template -class SortingTarget, SortRule::LargestImag> -{ -public: - static RealType get(const std::complex& val) - { - using std::abs; - return -abs(val.imag()); - } -}; - -// Specialization for SortRule::LargestAlge -// This covers [float, double] x [SortRule::LargestAlge] -template -class SortingTarget -{ -public: - static Scalar get(const Scalar& val) - { - return -val; - } -}; - -// Here SortRule::BothEnds is the same as SortRule::LargestAlge, but -// we need some additional steps, which are done in -// SymEigsSolver.h => retrieve_ritzpair(). -// There we move the smallest values to the proper locations. -template -class SortingTarget -{ -public: - static Scalar get(const Scalar& val) - { - return -val; - } -}; - -// Specialization for SortRule::SmallestMagn -// This covers [float, double, complex] x [SortRule::SmallestMagn] -template -class SortingTarget -{ -public: - static ElemType get(const Scalar& val) - { - using std::abs; - return abs(val); - } -}; - -// Specialization for SortRule::SmallestReal -// This covers [complex] x [SortRule::SmallestReal] -template -class SortingTarget, SortRule::SmallestReal> -{ -public: - static RealType get(const std::complex& val) - { - return val.real(); - } -}; - -// Specialization for SortRule::SmallestImag -// This covers [complex] x [SortRule::SmallestImag] -template -class SortingTarget, SortRule::SmallestImag> -{ -public: - static RealType get(const std::complex& val) - { - using std::abs; - return abs(val.imag()); - } -}; - -// Specialization for SortRule::SmallestAlge -// This covers [float, double] x [SortRule::SmallestAlge] template -class SortingTarget +struct SortRule { -public: - static Scalar get(const Scalar& val) - { - return val; - } + std::function(Scalar)> get; + + inline static SortRule LargestMagn{[](Scalar x) { using std::abs; return -std::abs(x); }}; + inline static SortRule LargestReal{[](Scalar x) { + if constexpr (Eigen::NumTraits::IsComplex) + return -x.real(); + else + return -x; + }}; + inline static SortRule LargestImag{[](Scalar x) { + static_assert::IsComplex, "LargestImag is only for complex numbers.">; + return -x.imag(); + }}; + inline static SortRule LargestAlge{[](Scalar x) { + static_assert::IsComplex, "LargestAlge is only for real numbers.">; + return -x; + }}; + inline static SortRule BothEnds{[](Scalar x) { + static_assert::IsComplex, "LargestAlge is only for real numbers.">; + return -x; + }}; + + inline static SortRule SmallestMagn{[](Scalar x) { using std::abs; return -std::abs(x); }}; + inline static SortRule SmallestReal{[](Scalar x) { + if constexpr (Eigen::NumTraits::IsComplex) + return x.real(); + else + return x; + }}; + inline static SortRule SmallestImag{[](Scalar x) { + static_assert::IsComplex, "SmallestImag is only for complex numbers.">; + return x.imag(); + }}; + inline static SortRule SmallestAlge{[](Scalar x) { + static_assert::IsComplex, "SmallestAlge is only for real numbers.">; + return x; + }}; }; // Sort eigenvalues @@ -225,7 +141,7 @@ class SortEigenvalue // Sort values[:len] according to the selection rule, and return the indices template -std::vector argsort(SortRule selection, const Eigen::Matrix& values, Eigen::Index len) +std::vector argsort(SortRule selection, const Eigen::Matrix& values, Eigen::Index len) { using Index = Eigen::Index; @@ -269,7 +185,7 @@ std::vector argsort(SortRule selection, const Eigen::Matrix::BothEnds) { std::vector ind_copy(ind); for (Index i = 0; i < len; i++) From fb0657e4060623a6331ad51b9488bf9aea8de197 Mon Sep 17 00:00:00 2001 From: Toon Baeyens Date: Wed, 7 Dec 2022 12:03:35 +0100 Subject: [PATCH 2/5] Add support for custom sort rule --- include/Spectra/DavidsonSymEigsSolver.h | 4 +- include/Spectra/GenEigsBase.h | 99 +--------- include/Spectra/GenEigsComplexShiftSolver.h | 2 +- include/Spectra/GenEigsRealShiftSolver.h | 2 +- include/Spectra/LinAlg/RitzPairs.h | 4 +- include/Spectra/SymEigsBase.h | 12 +- include/Spectra/SymEigsShiftSolver.h | 2 +- include/Spectra/SymGEigsShiftSolver.h | 6 +- include/Spectra/Util/SelectionRule.h | 206 ++++++++------------ 9 files changed, 104 insertions(+), 233 deletions(-) diff --git a/include/Spectra/DavidsonSymEigsSolver.h b/include/Spectra/DavidsonSymEigsSolver.h index fe9f56c..79c9df3 100644 --- a/include/Spectra/DavidsonSymEigsSolver.h +++ b/include/Spectra/DavidsonSymEigsSolver.h @@ -57,9 +57,9 @@ class DavidsonSymEigsSolver : public JDSymEigsBase /// /// \param selection Spectrum section to target (e.g. lowest, etc.) /// \return Matrix with the initial orthonormal basis - Matrix setup_initial_search_space(SortRule selection) const + Matrix setup_initial_search_space(const EigenvalueSorter &selection) const { - std::vector indices_sorted = argsort(selection, m_diagonal); + std::vector indices_sorted = selection.argsort(m_diagonal); Matrix initial_basis = Matrix::Zero(this->m_matrix_operator.rows(), this->m_initial_search_space_size); diff --git a/include/Spectra/GenEigsBase.h b/include/Spectra/GenEigsBase.h index 9f71b7c..009a82f 100644 --- a/include/Spectra/GenEigsBase.h +++ b/include/Spectra/GenEigsBase.h @@ -83,7 +83,7 @@ class GenEigsBase static bool is_conj(const Complex& v1, const Complex& v2) { return v1 == Eigen::numext::conj(v2); } // Implicitly restarted Arnoldi factorization - void restart(Index k, SortRule selection) + void restart(Index k, const EigenvalueSorter& selection) { using std::norm; @@ -96,7 +96,7 @@ class GenEigsBase for (Index i = k; i < m_ncv; i++) { - if (is_complex(m_ritz_val[i]) && is_conj(m_ritz_val[i], m_ritz_val[i + 1])) + if (i + 1 < m_ncv && is_complex(m_ritz_val[i]) && is_conj(m_ritz_val[i], m_ritz_val[i + 1])) { // H - mu * I = Q1 * R1 // H <- R1 * Q1 + mu * I = Q1' * H * Q1 @@ -193,55 +193,14 @@ class GenEigsBase } // Retrieves and sorts Ritz values and Ritz vectors - void retrieve_ritzpair(SortRule selection) + void retrieve_ritzpair(const EigenvalueSorter & selection) { UpperHessenbergEigen decomp(m_fac.matrix_H()); const ComplexVector& evals = decomp.eigenvalues(); ComplexMatrix evecs = decomp.eigenvectors(); // Sort Ritz values and put the wanted ones at the beginning - std::vector ind; - switch (selection) - { - case SortRule::LargestMagn: - { - SortEigenvalue sorting(evals.data(), m_ncv); - sorting.swap(ind); - break; - } - case SortRule::LargestReal: - { - SortEigenvalue sorting(evals.data(), m_ncv); - sorting.swap(ind); - break; - } - case SortRule::LargestImag: - { - SortEigenvalue sorting(evals.data(), m_ncv); - sorting.swap(ind); - break; - } - case SortRule::SmallestMagn: - { - SortEigenvalue sorting(evals.data(), m_ncv); - sorting.swap(ind); - break; - } - case SortRule::SmallestReal: - { - SortEigenvalue sorting(evals.data(), m_ncv); - sorting.swap(ind); - break; - } - case SortRule::SmallestImag: - { - SortEigenvalue sorting(evals.data(), m_ncv); - sorting.swap(ind); - break; - } - default: - throw std::invalid_argument("unsupported selection rule"); - } + std::vector ind = selection.argsort(evals.data(), m_ncv); // Copy the Ritz values and vectors to m_ritz_val and m_ritz_vec, respectively for (Index i = 0; i < m_ncv; i++) @@ -258,51 +217,9 @@ class GenEigsBase protected: // Sorts the first nev Ritz pairs in the specified order // This is used to return the final results - virtual void sort_ritzpair(SortRule sort_rule) + virtual void sort_ritzpair(const EigenvalueSorter & sort_rule) { - std::vector ind; - switch (sort_rule) - { - case SortRule::LargestMagn: - { - SortEigenvalue sorting(m_ritz_val.data(), m_nev); - sorting.swap(ind); - break; - } - case SortRule::LargestReal: - { - SortEigenvalue sorting(m_ritz_val.data(), m_nev); - sorting.swap(ind); - break; - } - case SortRule::LargestImag: - { - SortEigenvalue sorting(m_ritz_val.data(), m_nev); - sorting.swap(ind); - break; - } - case SortRule::SmallestMagn: - { - SortEigenvalue sorting(m_ritz_val.data(), m_nev); - sorting.swap(ind); - break; - } - case SortRule::SmallestReal: - { - SortEigenvalue sorting(m_ritz_val.data(), m_nev); - sorting.swap(ind); - break; - } - case SortRule::SmallestImag: - { - SortEigenvalue sorting(m_ritz_val.data(), m_nev); - sorting.swap(ind); - break; - } - default: - throw std::invalid_argument("unsupported sorting rule"); - } - + std::vector ind = sort_rule.argsort(m_ritz_val.data(), m_nev); ComplexVector new_ritz_val(m_ncv); ComplexMatrix new_ritz_vec(m_ncv, m_nev); BoolArray new_ritz_conv(m_nev); @@ -414,8 +331,8 @@ class GenEigsBase /// /// \return Number of converged eigenvalues. /// - Index compute(SortRule selection = SortRule::LargestMagn, Index maxit = 1000, - Scalar tol = 1e-10, SortRule sorting = SortRule::LargestMagn) + Index compute(const EigenvalueSorter& selection = SortRule::LargestMagn, Index maxit = 1000, + Scalar tol = 1e-10, const EigenvalueSorter& sorting = SortRule::LargestMagn) { // The m-step Arnoldi factorization m_fac.factorize_from(1, m_ncv, m_nmatop); diff --git a/include/Spectra/GenEigsComplexShiftSolver.h b/include/Spectra/GenEigsComplexShiftSolver.h index 90040a8..e6f47f3 100644 --- a/include/Spectra/GenEigsComplexShiftSolver.h +++ b/include/Spectra/GenEigsComplexShiftSolver.h @@ -51,7 +51,7 @@ class GenEigsComplexShiftSolver : public GenEigsBase const Scalar m_sigmai; // First transform back the Ritz values, and then sort - void sort_ritzpair(SortRule sort_rule) override + void sort_ritzpair(const EigenvalueSorter &sort_rule) override { using std::abs; using std::sqrt; diff --git a/include/Spectra/GenEigsRealShiftSolver.h b/include/Spectra/GenEigsRealShiftSolver.h index a1d08b8..781c3dd 100644 --- a/include/Spectra/GenEigsRealShiftSolver.h +++ b/include/Spectra/GenEigsRealShiftSolver.h @@ -45,7 +45,7 @@ class GenEigsRealShiftSolver : public GenEigsBase const Scalar m_sigma; // First transform back the Ritz values, and then sort - void sort_ritzpair(SortRule sort_rule) override + void sort_ritzpair(const EigenvalueSorter &sort_rule) override { // The eigenvalues we get from the iteration is nu = 1 / (lambda - sigma) // So the eigenvalues of the original problem is lambda = 1 / nu + sigma diff --git a/include/Spectra/LinAlg/RitzPairs.h b/include/Spectra/LinAlg/RitzPairs.h index 09bda61..7841662 100644 --- a/include/Spectra/LinAlg/RitzPairs.h +++ b/include/Spectra/LinAlg/RitzPairs.h @@ -52,9 +52,9 @@ class RitzPairs /// Sort the eigen pairs according to the selection rule /// /// \param selection Sorting rule - void sort(SortRule selection) + void sort(const EigenvalueSorter &selection) { - std::vector ind = argsort(selection, m_values); + std::vector ind = selection.argsort(m_values); RitzPairs temp = *this; for (Index i = 0; i < size(); i++) { diff --git a/include/Spectra/SymEigsBase.h b/include/Spectra/SymEigsBase.h index 6357248..a9d2899 100644 --- a/include/Spectra/SymEigsBase.h +++ b/include/Spectra/SymEigsBase.h @@ -174,14 +174,14 @@ class SymEigsBase } // Retrieves and sorts Ritz values and Ritz vectors - void retrieve_ritzpair(SortRule selection) + void retrieve_ritzpair(const EigenvalueSorter &selection) { TridiagEigen decomp(m_fac.matrix_H()); const Vector& evals = decomp.eigenvalues(); const Matrix& evecs = decomp.eigenvectors(); // Sort Ritz values and put the wanted ones at the beginning - std::vector ind = argsort(selection, evals, m_ncv); + std::vector ind = selection.argsort(evals.data(), m_ncv); // Copy the Ritz values and vectors to m_ritz_val and m_ritz_vec, respectively for (Index i = 0; i < m_ncv; i++) @@ -198,13 +198,9 @@ class SymEigsBase protected: // Sorts the first nev Ritz pairs in the specified order // This is used to return the final results - virtual void sort_ritzpair(SortRule sort_rule) + virtual void sort_ritzpair(const EigenvalueSorter &sort_rule) { - if ((sort_rule != SortRule::LargestAlge) && (sort_rule != SortRule::LargestMagn) && - (sort_rule != SortRule::SmallestAlge) && (sort_rule != SortRule::SmallestMagn)) - throw std::invalid_argument("unsupported sorting rule"); - - std::vector ind = argsort(sort_rule, m_ritz_val, m_nev); + std::vector ind = sort_rule.argsort(m_ritz_val.data(), m_nev); Vector new_ritz_val(m_ncv); Matrix new_ritz_vec(m_ncv, m_nev); diff --git a/include/Spectra/SymEigsShiftSolver.h b/include/Spectra/SymEigsShiftSolver.h index a2bfea9..0a20080 100644 --- a/include/Spectra/SymEigsShiftSolver.h +++ b/include/Spectra/SymEigsShiftSolver.h @@ -160,7 +160,7 @@ class SymEigsShiftSolver : public SymEigsBase const Scalar m_sigma; // First transform back the Ritz values, and then sort - void sort_ritzpair(SortRule sort_rule) override + void sort_ritzpair(const EigenvalueSorter &sort_rule) override { // The eigenvalues we get from the iteration is nu = 1 / (lambda - sigma) // So the eigenvalues of the original problem is lambda = 1 / nu + sigma diff --git a/include/Spectra/SymGEigsShiftSolver.h b/include/Spectra/SymGEigsShiftSolver.h index c7dc50f..82a9d4c 100644 --- a/include/Spectra/SymGEigsShiftSolver.h +++ b/include/Spectra/SymGEigsShiftSolver.h @@ -167,7 +167,7 @@ class SymGEigsShiftSolver : } // First transform back the Ritz values, and then sort - void sort_ritzpair(SortRule sort_rule) override + void sort_ritzpair(const EigenvalueSorter &sort_rule) override { // The eigenvalues we get from the iteration is nu = 1 / (lambda - sigma) // So the eigenvalues of the original problem is lambda = 1 / nu + sigma @@ -329,7 +329,7 @@ class SymGEigsShiftSolver : } // First transform back the Ritz values, and then sort - void sort_ritzpair(SortRule sort_rule) override + void sort_ritzpair(const EigenvalueSorter &sort_rule) override { // The eigenvalues we get from the iteration is nu = lambda / (lambda - sigma) // So the eigenvalues of the original problem is lambda = sigma * nu / (nu - 1) @@ -421,7 +421,7 @@ class SymGEigsShiftSolver : } // First transform back the Ritz values, and then sort - void sort_ritzpair(SortRule sort_rule) override + void sort_ritzpair(const EigenvalueSorter &sort_rule) override { // The eigenvalues we get from the iteration is nu = (lambda + sigma) / (lambda - sigma) // So the eigenvalues of the original problem is lambda = sigma * (nu + 1) / (nu - 1) diff --git a/include/Spectra/Util/SelectionRule.h b/include/Spectra/Util/SelectionRule.h index 390b830..bb67408 100644 --- a/include/Spectra/Util/SelectionRule.h +++ b/include/Spectra/Util/SelectionRule.h @@ -31,7 +31,7 @@ namespace Spectra { /// The enumeration of selection rules of desired eigenvalues. /// -enum class SortRuleType +enum class SortRule { LargestMagn, ///< Select eigenvalues with largest magnitude. Magnitude ///< means the absolute value for real numbers and norm for @@ -66,148 +66,106 @@ enum class SortRuleType // The minus sign is due to the fact that std::sort() sorts in ascending order. template -struct SortRule +struct EigenvalueSorter { + bool both_ends; std::function(Scalar)> get; - inline static SortRule LargestMagn{[](Scalar x) { using std::abs; return -std::abs(x); }}; - inline static SortRule LargestReal{[](Scalar x) { - if constexpr (Eigen::NumTraits::IsComplex) - return -x.real(); - else - return -x; - }}; - inline static SortRule LargestImag{[](Scalar x) { - static_assert::IsComplex, "LargestImag is only for complex numbers.">; - return -x.imag(); - }}; - inline static SortRule LargestAlge{[](Scalar x) { - static_assert::IsComplex, "LargestAlge is only for real numbers.">; - return -x; - }}; - inline static SortRule BothEnds{[](Scalar x) { - static_assert::IsComplex, "LargestAlge is only for real numbers.">; - return -x; - }}; - - inline static SortRule SmallestMagn{[](Scalar x) { using std::abs; return -std::abs(x); }}; - inline static SortRule SmallestReal{[](Scalar x) { - if constexpr (Eigen::NumTraits::IsComplex) - return x.real(); - else - return x; - }}; - inline static SortRule SmallestImag{[](Scalar x) { - static_assert::IsComplex, "SmallestImag is only for complex numbers.">; - return x.imag(); - }}; - inline static SortRule SmallestAlge{[](Scalar x) { - static_assert::IsComplex, "SmallestAlge is only for real numbers.">; - return x; - }}; -}; - -// Sort eigenvalues -template -class SortEigenvalue -{ -private: using Index = Eigen::Index; using IndexArray = std::vector; - const T* m_evals; - IndexArray m_index; - -public: - // Sort indices according to the eigenvalues they point to - inline bool operator()(Index i, Index j) + template + EigenvalueSorter(SortRule rule, typename std::enable_if::IsComplex>::type* = nullptr) { - return SortingTarget::get(m_evals[i]) < SortingTarget::get(m_evals[j]); + both_ends = false; + if (rule == SortRule::LargestMagn) + get = [](Scalar x) { using std::abs; return -std::abs(x); }; + else if (rule == SortRule::LargestReal) + get = [](Scalar x) { + return -x.real(); + }; + else if (rule == SortRule::LargestImag) + get = [](Scalar x) { + return -x.imag(); + }; + else if (rule == SortRule::SmallestMagn) + get = [](Scalar x) { using std::abs; return -std::abs(x); }; + else if (rule == SortRule::SmallestReal) + get = [](Scalar x) { + return x.real(); + }; + else if (rule == SortRule::SmallestImag) + get = [](Scalar x) { + return x.imag(); + }; + else + throw std::invalid_argument("unsupported selection rule for complex types"); } - SortEigenvalue(const T* start, Index size) : - m_evals(start), m_index(size) + template + EigenvalueSorter(SortRule rule, typename std::enable_if::IsComplex>::type* = nullptr) { - for (Index i = 0; i < size; i++) - { - m_index[i] = i; - } - std::sort(m_index.begin(), m_index.end(), *this); + both_ends = rule == SortRule::BothEnds; + if (rule == SortRule::LargestMagn) + get = [](Scalar x) { using std::abs; return -std::abs(x); }; + else if (rule == SortRule::LargestReal) + get = [](Scalar x) { + return -x; + }; + else if (rule == SortRule::LargestAlge || rule == SortRule::BothEnds) + get = [](Scalar x) { + return -x; + }; + else if (rule == SortRule::SmallestMagn) + get = [](Scalar x) { using std::abs; return -std::abs(x); }; + else if (rule == SortRule::SmallestReal) + get = [](Scalar x) { + return x; + }; + else if (rule == SortRule::SmallestAlge) + get = [](Scalar x) { + return x; + }; + else + throw std::invalid_argument("unsupported selection rule for real types"); } - inline IndexArray index() const { return m_index; } - inline void swap(IndexArray& other) { m_index.swap(other); } -}; - -// Sort values[:len] according to the selection rule, and return the indices -template -std::vector argsort(SortRule selection, const Eigen::Matrix& values, Eigen::Index len) -{ - using Index = Eigen::Index; - - // Sort Ritz values and put the wanted ones at the beginning - std::vector ind; - switch (selection) + IndexArray argsort(const Scalar* data, Index size) const { - case SortRule::LargestMagn: - { - SortEigenvalue sorting(values.data(), len); - sorting.swap(ind); - break; - } - case SortRule::BothEnds: - case SortRule::LargestAlge: - { - SortEigenvalue sorting(values.data(), len); - sorting.swap(ind); - break; - } - case SortRule::SmallestMagn: - { - SortEigenvalue sorting(values.data(), len); - sorting.swap(ind); - break; - } - case SortRule::SmallestAlge: + IndexArray index; + index.resize(size); + for (Index i = 0; i < size; i++) + index[i] = i; + std::sort(index.begin(), index.end(), [&](Index i, Index j) { return get(data[i]) < get(data[j]); }); + + // For SortRule::BothEnds, the eigenvalues are sorted according to the + // SortRule::LargestAlge rule, so we need to move those smallest values to the left + // The order would be + // Largest => Smallest => 2nd largest => 2nd smallest => ... + // We keep this order since the first k values will always be + // the wanted collection, no matter k is nev_updated (used in SymEigsBase::restart()) + // or is nev (used in SymEigsBase::sort_ritzpair()) + if (both_ends) { - SortEigenvalue sorting(values.data(), len); - sorting.swap(ind); - break; + IndexArray index_copy(index); + for (Index i = 0; i < size; i++) + { + // If i is even, pick values from the left (large values) + // If i is odd, pick values from the right (small values) + if (i % 2 == 0) + index[i] = index_copy[i / 2]; + else + index[i] = index_copy[size - 1 - i / 2]; + } } - default: - throw std::invalid_argument("unsupported selection rule"); + return index; } - // For SortRule::BothEnds, the eigenvalues are sorted according to the - // SortRule::LargestAlge rule, so we need to move those smallest values to the left - // The order would be - // Largest => Smallest => 2nd largest => 2nd smallest => ... - // We keep this order since the first k values will always be - // the wanted collection, no matter k is nev_updated (used in SymEigsBase::restart()) - // or is nev (used in SymEigsBase::sort_ritzpair()) - if (&selection == &SortRule::BothEnds) + IndexArray argsort(const Eigen::Matrix& values) const { - std::vector ind_copy(ind); - for (Index i = 0; i < len; i++) - { - // If i is even, pick values from the left (large values) - // If i is odd, pick values from the right (small values) - if (i % 2 == 0) - ind[i] = ind_copy[i / 2]; - else - ind[i] = ind_copy[len - 1 - i / 2]; - } + return argsort(values.data(), values.size()); } - - return ind; -} - -// Default vector length -template -std::vector argsort(SortRule selection, const Eigen::Matrix& values) -{ - return argsort(selection, values, values.size()); -} +}; /// \endcond From 72767462f03a0741953c85852434b1b4753cfe92 Mon Sep 17 00:00:00 2001 From: Toon Baeyens Date: Wed, 7 Dec 2022 12:07:06 +0100 Subject: [PATCH 3/5] run workflow on all branches --- .github/workflows/Basic.yml | 6 +----- .github/workflows/checkformat.yml | 7 ++----- .github/workflows/codecov.yml | 7 ++----- 3 files changed, 5 insertions(+), 15 deletions(-) diff --git a/.github/workflows/Basic.yml b/.github/workflows/Basic.yml index 47355e8..8f6b343 100644 --- a/.github/workflows/Basic.yml +++ b/.github/workflows/Basic.yml @@ -1,10 +1,6 @@ name: Basic CI -on: - push: - branches: [ master ] - pull_request: - branches: [ master ] +on: [push, pull_request] jobs: build: diff --git a/.github/workflows/checkformat.yml b/.github/workflows/checkformat.yml index 0d1c105..cd226be 100644 --- a/.github/workflows/checkformat.yml +++ b/.github/workflows/checkformat.yml @@ -1,9 +1,6 @@ name: check format -on: - push: - branches: [ master ] - pull_request: - branches: [ master ] + +on: [push, pull_request] jobs: build: diff --git a/.github/workflows/codecov.yml b/.github/workflows/codecov.yml index 459eaac..674435c 100644 --- a/.github/workflows/codecov.yml +++ b/.github/workflows/codecov.yml @@ -1,9 +1,6 @@ name: Codecov -on: - push: - branches: [ master ] - pull_request: - branches: [ master ] + +on: [push, pull_request] jobs: build: From d00196e8d942cb72074b4c1f1e65afb9ddc488fc Mon Sep 17 00:00:00 2001 From: Toon Baeyens Date: Wed, 7 Dec 2022 14:22:08 +0100 Subject: [PATCH 4/5] Add some documentation to EigenvalueSorter class --- .github/workflows/Basic.yml | 6 +- .github/workflows/checkformat.yml | 7 ++- .github/workflows/codecov.yml | 7 ++- include/Spectra/Util/SelectionRule.h | 91 +++++++++++++++++++--------- test/GenEigs.cpp | 6 +- 5 files changed, 84 insertions(+), 33 deletions(-) diff --git a/.github/workflows/Basic.yml b/.github/workflows/Basic.yml index 8f6b343..47355e8 100644 --- a/.github/workflows/Basic.yml +++ b/.github/workflows/Basic.yml @@ -1,6 +1,10 @@ name: Basic CI -on: [push, pull_request] +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] jobs: build: diff --git a/.github/workflows/checkformat.yml b/.github/workflows/checkformat.yml index cd226be..0d1c105 100644 --- a/.github/workflows/checkformat.yml +++ b/.github/workflows/checkformat.yml @@ -1,6 +1,9 @@ name: check format - -on: [push, pull_request] +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] jobs: build: diff --git a/.github/workflows/codecov.yml b/.github/workflows/codecov.yml index 674435c..459eaac 100644 --- a/.github/workflows/codecov.yml +++ b/.github/workflows/codecov.yml @@ -1,6 +1,9 @@ name: Codecov - -on: [push, pull_request] +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] jobs: build: diff --git a/include/Spectra/Util/SelectionRule.h b/include/Spectra/Util/SelectionRule.h index bb67408..604ad7d 100644 --- a/include/Spectra/Util/SelectionRule.h +++ b/include/Spectra/Util/SelectionRule.h @@ -59,71 +59,110 @@ enum class SortRule }; -/// \cond -// When comparing eigenvalues, we first calculate the "target" to sort. -// For example, if we want to choose the eigenvalues with -// largest magnitude, the target will be -abs(x). -// The minus sign is due to the fact that std::sort() sorts in ascending order. +/// +/// Eigenvalue solvers use this class to determine which eigenvalues to select. +/// Values of the type SortRule are automatically cast to an appropriate EigenvalueSorter. +/// +/// You can also provide your own implementation to determine which eigenvalues to select. +/// For symmetric solvers, this should have a real scalar type. +/// \code C++ +/// SymEigsSolver<...> eigsSolver(...); +/// eigsSolver.init(); +/// EigenvalueSorter closestToTarget{[](double eigenvalue) { return std::abs(eigenvalue - 10.); }}; +/// eigsSolver.compute(closestToTarget); +/// \endcode +/// For general solvers, the scalar type should be complex. +/// \code C++ +/// GenEigsSolver<...> eigsSolver(...); +/// eigsSolver.init(); +/// EigenvalueSorter> closestToTarget{[](std::complex eigenvalue) { return std::abs(eigenvalue - 10.); }}; +/// eigsSolver.compute(closestToTarget); +/// \endcode template -struct EigenvalueSorter +class EigenvalueSorter { - bool both_ends; - std::function(Scalar)> get; +private: + bool m_both_ends; + std::function(Scalar)> m_target; +public: using Index = Eigen::Index; using IndexArray = std::vector; + /// + /// \param target Eigenvalues will be sorted according to this target. Only the values with the lowest target will be computed. + /// \param both_ends If both_ends is true, half of the eigenvalues with the lowest target and half of the eigenvalues with the highest target will be computed. + EigenvalueSorter(std::function(Scalar)> target, bool both_ends) : + m_both_ends(both_ends), m_target(target) + { + } + + /// + /// \param target Eigenvalues will be sorted according to this target. Only the values with the lowest target will be computed. + explicit EigenvalueSorter(std::function(Scalar)> target) : + m_both_ends(false), m_target(target) + { + } + + /// This constructor casts a SortRule to an appropriate EigenvalueSorter, for complex scalar types template EigenvalueSorter(SortRule rule, typename std::enable_if::IsComplex>::type* = nullptr) { - both_ends = false; + // The scalar-type is complex + + m_both_ends = false; if (rule == SortRule::LargestMagn) - get = [](Scalar x) { using std::abs; return -std::abs(x); }; + m_target = [](Scalar x) { using std::abs; return -std::abs(x); }; else if (rule == SortRule::LargestReal) - get = [](Scalar x) { + m_target = [](Scalar x) { return -x.real(); }; else if (rule == SortRule::LargestImag) - get = [](Scalar x) { - return -x.imag(); + m_target = [](Scalar x) { + using std::abs; + return -abs(x.imag()); }; else if (rule == SortRule::SmallestMagn) - get = [](Scalar x) { using std::abs; return -std::abs(x); }; + m_target = [](Scalar x) { using std::abs; return -std::abs(x); }; else if (rule == SortRule::SmallestReal) - get = [](Scalar x) { + m_target = [](Scalar x) { return x.real(); }; else if (rule == SortRule::SmallestImag) - get = [](Scalar x) { - return x.imag(); + m_target = [](Scalar x) { + using std::abs; + return abs(x.imag()); }; else throw std::invalid_argument("unsupported selection rule for complex types"); } + /// This constructor casts a SortRule to an appropriate EigenvalueSorter, for real scalar types template EigenvalueSorter(SortRule rule, typename std::enable_if::IsComplex>::type* = nullptr) { - both_ends = rule == SortRule::BothEnds; + // The scalar-type is real + + m_both_ends = rule == SortRule::BothEnds; if (rule == SortRule::LargestMagn) - get = [](Scalar x) { using std::abs; return -std::abs(x); }; + m_target = [](Scalar x) { using std::abs; return -std::abs(x); }; else if (rule == SortRule::LargestReal) - get = [](Scalar x) { + m_target = [](Scalar x) { return -x; }; else if (rule == SortRule::LargestAlge || rule == SortRule::BothEnds) - get = [](Scalar x) { + m_target = [](Scalar x) { return -x; }; else if (rule == SortRule::SmallestMagn) - get = [](Scalar x) { using std::abs; return -std::abs(x); }; + m_target = [](Scalar x) { using std::abs; return -std::abs(x); }; else if (rule == SortRule::SmallestReal) - get = [](Scalar x) { + m_target = [](Scalar x) { return x; }; else if (rule == SortRule::SmallestAlge) - get = [](Scalar x) { + m_target = [](Scalar x) { return x; }; else @@ -145,7 +184,7 @@ struct EigenvalueSorter // We keep this order since the first k values will always be // the wanted collection, no matter k is nev_updated (used in SymEigsBase::restart()) // or is nev (used in SymEigsBase::sort_ritzpair()) - if (both_ends) + if (m_both_ends) { IndexArray index_copy(index); for (Index i = 0; i < size; i++) @@ -167,8 +206,6 @@ struct EigenvalueSorter } }; -/// \endcond - } // namespace Spectra #endif // SPECTRA_SELECTION_RULE_H diff --git a/test/GenEigs.cpp b/test/GenEigs.cpp index 4f932a5..71eddf1 100644 --- a/test/GenEigs.cpp +++ b/test/GenEigs.cpp @@ -37,7 +37,7 @@ SpMatrix gen_sparse_data(int n, double prob = 0.5) } template -void run_test(const MatType& mat, Solver& eigs, SortRule selection, bool allow_fail = false) +void run_test(const MatType& mat, Solver& eigs, const EigenvalueSorter>>& selection, bool allow_fail = false) { eigs.init(); // maxit = 300 to reduce running time for failed cases @@ -106,6 +106,10 @@ void run_test_sets(const MatType& A, int k, int m) { run_test(A, eigs, SortRule::SmallestImag, true); } + SECTION("Custom SortRule: Closest to target") + { + run_test(A, eigs, EigenvalueSorter>{[](std::complex d) { return std::abs(d - 10.); }}); + } } TEST_CASE("Eigensolver of general real matrix [10x10]", "[eigs_gen]") From 13d5773e579910d3ba824d3c642dfee02bf7cdf9 Mon Sep 17 00:00:00 2001 From: Toon Baeyens Date: Wed, 7 Dec 2022 15:05:07 +0100 Subject: [PATCH 5/5] fix rename to m_target issue --- include/Spectra/Util/SelectionRule.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/Spectra/Util/SelectionRule.h b/include/Spectra/Util/SelectionRule.h index 604ad7d..8913fce 100644 --- a/include/Spectra/Util/SelectionRule.h +++ b/include/Spectra/Util/SelectionRule.h @@ -175,7 +175,7 @@ class EigenvalueSorter index.resize(size); for (Index i = 0; i < size; i++) index[i] = i; - std::sort(index.begin(), index.end(), [&](Index i, Index j) { return get(data[i]) < get(data[j]); }); + std::sort(index.begin(), index.end(), [&](Index i, Index j) { return m_target(data[i]) < m_target(data[j]); }); // For SortRule::BothEnds, the eigenvalues are sorted according to the // SortRule::LargestAlge rule, so we need to move those smallest values to the left