Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom SortRule #147

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/Spectra/DavidsonSymEigsSolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ class DavidsonSymEigsSolver : public JDSymEigsBase<DavidsonSymEigsSolver<OpType>
///
/// \param selection Spectrum section to target (e.g. lowest, etc.)
/// \return Matrix with the initial orthonormal basis
Matrix setup_initial_search_space(SortRule selection) const
Matrix setup_initial_search_space(const EigenvalueSorter<Scalar> &selection) const
{
std::vector<Eigen::Index> indices_sorted = argsort(selection, m_diagonal);
std::vector<Eigen::Index> indices_sorted = selection.argsort(m_diagonal);

Matrix initial_basis = Matrix::Zero(this->m_matrix_operator.rows(), this->m_initial_search_space_size);

Expand Down
99 changes: 8 additions & 91 deletions include/Spectra/GenEigsBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class GenEigsBase
static bool is_conj(const Complex& v1, const Complex& v2) { return v1 == Eigen::numext::conj(v2); }

// Implicitly restarted Arnoldi factorization
void restart(Index k, SortRule selection)
void restart(Index k, const EigenvalueSorter<Complex>& selection)
{
using std::norm;

Expand All @@ -96,7 +96,7 @@ class GenEigsBase

for (Index i = k; i < m_ncv; i++)
{
if (is_complex(m_ritz_val[i]) && is_conj(m_ritz_val[i], m_ritz_val[i + 1]))
if (i + 1 < m_ncv && is_complex(m_ritz_val[i]) && is_conj(m_ritz_val[i], m_ritz_val[i + 1]))
{
// H - mu * I = Q1 * R1
// H <- R1 * Q1 + mu * I = Q1' * H * Q1
Expand Down Expand Up @@ -193,55 +193,14 @@ class GenEigsBase
}

// Retrieves and sorts Ritz values and Ritz vectors
void retrieve_ritzpair(SortRule selection)
void retrieve_ritzpair(const EigenvalueSorter<Complex> & selection)
{
UpperHessenbergEigen<Scalar> decomp(m_fac.matrix_H());
const ComplexVector& evals = decomp.eigenvalues();
ComplexMatrix evecs = decomp.eigenvectors();

// Sort Ritz values and put the wanted ones at the beginning
std::vector<Index> ind;
switch (selection)
{
case SortRule::LargestMagn:
{
SortEigenvalue<Complex, SortRule::LargestMagn> sorting(evals.data(), m_ncv);
sorting.swap(ind);
break;
}
case SortRule::LargestReal:
{
SortEigenvalue<Complex, SortRule::LargestReal> sorting(evals.data(), m_ncv);
sorting.swap(ind);
break;
}
case SortRule::LargestImag:
{
SortEigenvalue<Complex, SortRule::LargestImag> sorting(evals.data(), m_ncv);
sorting.swap(ind);
break;
}
case SortRule::SmallestMagn:
{
SortEigenvalue<Complex, SortRule::SmallestMagn> sorting(evals.data(), m_ncv);
sorting.swap(ind);
break;
}
case SortRule::SmallestReal:
{
SortEigenvalue<Complex, SortRule::SmallestReal> sorting(evals.data(), m_ncv);
sorting.swap(ind);
break;
}
case SortRule::SmallestImag:
{
SortEigenvalue<Complex, SortRule::SmallestImag> sorting(evals.data(), m_ncv);
sorting.swap(ind);
break;
}
default:
throw std::invalid_argument("unsupported selection rule");
}
std::vector<Index> ind = selection.argsort(evals.data(), m_ncv);

// Copy the Ritz values and vectors to m_ritz_val and m_ritz_vec, respectively
for (Index i = 0; i < m_ncv; i++)
Expand All @@ -258,51 +217,9 @@ class GenEigsBase
protected:
// Sorts the first nev Ritz pairs in the specified order
// This is used to return the final results
virtual void sort_ritzpair(SortRule sort_rule)
virtual void sort_ritzpair(const EigenvalueSorter<Complex> & sort_rule)
{
std::vector<Index> ind;
switch (sort_rule)
{
case SortRule::LargestMagn:
{
SortEigenvalue<Complex, SortRule::LargestMagn> sorting(m_ritz_val.data(), m_nev);
sorting.swap(ind);
break;
}
case SortRule::LargestReal:
{
SortEigenvalue<Complex, SortRule::LargestReal> sorting(m_ritz_val.data(), m_nev);
sorting.swap(ind);
break;
}
case SortRule::LargestImag:
{
SortEigenvalue<Complex, SortRule::LargestImag> sorting(m_ritz_val.data(), m_nev);
sorting.swap(ind);
break;
}
case SortRule::SmallestMagn:
{
SortEigenvalue<Complex, SortRule::SmallestMagn> sorting(m_ritz_val.data(), m_nev);
sorting.swap(ind);
break;
}
case SortRule::SmallestReal:
{
SortEigenvalue<Complex, SortRule::SmallestReal> sorting(m_ritz_val.data(), m_nev);
sorting.swap(ind);
break;
}
case SortRule::SmallestImag:
{
SortEigenvalue<Complex, SortRule::SmallestImag> sorting(m_ritz_val.data(), m_nev);
sorting.swap(ind);
break;
}
default:
throw std::invalid_argument("unsupported sorting rule");
}

std::vector<Index> ind = sort_rule.argsort(m_ritz_val.data(), m_nev);
ComplexVector new_ritz_val(m_ncv);
ComplexMatrix new_ritz_vec(m_ncv, m_nev);
BoolArray new_ritz_conv(m_nev);
Expand Down Expand Up @@ -414,8 +331,8 @@ class GenEigsBase
///
/// \return Number of converged eigenvalues.
///
Index compute(SortRule selection = SortRule::LargestMagn, Index maxit = 1000,
Scalar tol = 1e-10, SortRule sorting = SortRule::LargestMagn)
Index compute(const EigenvalueSorter<Complex>& selection = SortRule::LargestMagn, Index maxit = 1000,
Scalar tol = 1e-10, const EigenvalueSorter<Complex>& sorting = SortRule::LargestMagn)
{
// The m-step Arnoldi factorization
m_fac.factorize_from(1, m_ncv, m_nmatop);
Expand Down
2 changes: 1 addition & 1 deletion include/Spectra/GenEigsComplexShiftSolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class GenEigsComplexShiftSolver : public GenEigsBase<OpType, IdentityBOp>
const Scalar m_sigmai;

// First transform back the Ritz values, and then sort
void sort_ritzpair(SortRule sort_rule) override
void sort_ritzpair(const EigenvalueSorter<Complex> &sort_rule) override
{
using std::abs;
using std::sqrt;
Expand Down
2 changes: 1 addition & 1 deletion include/Spectra/GenEigsRealShiftSolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class GenEigsRealShiftSolver : public GenEigsBase<OpType, IdentityBOp>
const Scalar m_sigma;

// First transform back the Ritz values, and then sort
void sort_ritzpair(SortRule sort_rule) override
void sort_ritzpair(const EigenvalueSorter<Complex> &sort_rule) override
{
// The eigenvalues we get from the iteration is nu = 1 / (lambda - sigma)
// So the eigenvalues of the original problem is lambda = 1 / nu + sigma
Expand Down
4 changes: 2 additions & 2 deletions include/Spectra/LinAlg/RitzPairs.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ class RitzPairs
/// Sort the eigen pairs according to the selection rule
///
/// \param selection Sorting rule
void sort(SortRule selection)
void sort(const EigenvalueSorter<Scalar> &selection)
{
std::vector<Index> ind = argsort(selection, m_values);
std::vector<Index> ind = selection.argsort(m_values);
RitzPairs<Scalar> temp = *this;
for (Index i = 0; i < size(); i++)
{
Expand Down
12 changes: 4 additions & 8 deletions include/Spectra/SymEigsBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,14 @@ class SymEigsBase
}

// Retrieves and sorts Ritz values and Ritz vectors
void retrieve_ritzpair(SortRule selection)
void retrieve_ritzpair(const EigenvalueSorter<Scalar> &selection)
{
TridiagEigen<Scalar> decomp(m_fac.matrix_H());
const Vector& evals = decomp.eigenvalues();
const Matrix& evecs = decomp.eigenvectors();

// Sort Ritz values and put the wanted ones at the beginning
std::vector<Index> ind = argsort(selection, evals, m_ncv);
std::vector<Index> ind = selection.argsort(evals.data(), m_ncv);

// Copy the Ritz values and vectors to m_ritz_val and m_ritz_vec, respectively
for (Index i = 0; i < m_ncv; i++)
Expand All @@ -198,13 +198,9 @@ class SymEigsBase
protected:
// Sorts the first nev Ritz pairs in the specified order
// This is used to return the final results
virtual void sort_ritzpair(SortRule sort_rule)
virtual void sort_ritzpair(const EigenvalueSorter<Scalar> &sort_rule)
{
if ((sort_rule != SortRule::LargestAlge) && (sort_rule != SortRule::LargestMagn) &&
(sort_rule != SortRule::SmallestAlge) && (sort_rule != SortRule::SmallestMagn))
throw std::invalid_argument("unsupported sorting rule");

std::vector<Index> ind = argsort(sort_rule, m_ritz_val, m_nev);
std::vector<Index> ind = sort_rule.argsort(m_ritz_val.data(), m_nev);

Vector new_ritz_val(m_ncv);
Matrix new_ritz_vec(m_ncv, m_nev);
Expand Down
2 changes: 1 addition & 1 deletion include/Spectra/SymEigsShiftSolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class SymEigsShiftSolver : public SymEigsBase<OpType, IdentityBOp>
const Scalar m_sigma;

// First transform back the Ritz values, and then sort
void sort_ritzpair(SortRule sort_rule) override
void sort_ritzpair(const EigenvalueSorter<Scalar> &sort_rule) override
{
// The eigenvalues we get from the iteration is nu = 1 / (lambda - sigma)
// So the eigenvalues of the original problem is lambda = 1 / nu + sigma
Expand Down
6 changes: 3 additions & 3 deletions include/Spectra/SymGEigsShiftSolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ class SymGEigsShiftSolver<OpType, BOpType, GEigsMode::ShiftInvert> :
}

// First transform back the Ritz values, and then sort
void sort_ritzpair(SortRule sort_rule) override
void sort_ritzpair(const EigenvalueSorter<Scalar> &sort_rule) override
{
// The eigenvalues we get from the iteration is nu = 1 / (lambda - sigma)
// So the eigenvalues of the original problem is lambda = 1 / nu + sigma
Expand Down Expand Up @@ -329,7 +329,7 @@ class SymGEigsShiftSolver<OpType, BOpType, GEigsMode::Buckling> :
}

// First transform back the Ritz values, and then sort
void sort_ritzpair(SortRule sort_rule) override
void sort_ritzpair(const EigenvalueSorter<Scalar> &sort_rule) override
{
// The eigenvalues we get from the iteration is nu = lambda / (lambda - sigma)
// So the eigenvalues of the original problem is lambda = sigma * nu / (nu - 1)
Expand Down Expand Up @@ -421,7 +421,7 @@ class SymGEigsShiftSolver<OpType, BOpType, GEigsMode::Cayley> :
}

// First transform back the Ritz values, and then sort
void sort_ritzpair(SortRule sort_rule) override
void sort_ritzpair(const EigenvalueSorter<Scalar> &sort_rule) override
{
// The eigenvalues we get from the iteration is nu = (lambda + sigma) / (lambda - sigma)
// So the eigenvalues of the original problem is lambda = sigma * (nu + 1) / (nu - 1)
Expand Down
Loading