From bcf48eda0a3844e52c5f06fb9056eb3a1adde64f Mon Sep 17 00:00:00 2001 From: Reinhard Date: Sat, 27 Apr 2024 17:08:17 +0200 Subject: [PATCH 01/10] Add option for transposed solve --- src/umfpack.cc | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/umfpack.cc b/src/umfpack.cc index ec5ff78..680c638 100644 --- a/src/umfpack.cc +++ b/src/umfpack.cc @@ -1,4 +1,4 @@ -// Copyright (C) 2019(-2023) Reinhard +// Copyright (C) 2019(-2024) Reinhard // This program is free software; you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by @@ -150,7 +150,7 @@ class UmfpackObject : public octave_base_value { UmfpackObject(); explicit UmfpackObject(const SparseMatrixType& A, const Options& options); virtual ~UmfpackObject(void); - DenseMatrixType solve(const DenseMatrixType& b); + DenseMatrixType solve(const DenseMatrixType& b, SuiteSparse_long sys); virtual bool is_constant(void) const{ return true; } virtual bool is_defined(void) const{ return true; } virtual dim_vector dims (void) const { return oMat.dims(); } @@ -245,7 +245,7 @@ UmfpackObject::~UmfpackObject() } template -typename UmfpackObject::DenseMatrixType UmfpackObject::solve(const DenseMatrixType& b) +typename UmfpackObject::DenseMatrixType UmfpackObject::solve(const DenseMatrixType& b, SuiteSparse_long sys) { DenseMatrixType x(b.rows(), b.columns()); @@ -255,7 +255,7 @@ typename UmfpackObject::DenseMatrixType UmfpackObject::solve(const DenseMa const T* const bp = b.data(); for (octave_idx_type j = 0; j < b.columns(); ++j) { - auto status = oMat.umfpack_solve(UMFPACK_A, + auto status = oMat.umfpack_solve(sys, xp + j * n, bp + j * n, Numeric, @@ -411,6 +411,12 @@ octave_value_list UmfpackObject::eval(const octave_value_list& args, int narg } } + SuiteSparse_long sys = UMFPACK_A; + + if (args.length() > iarg && args(++iarg).bool_value()) { + sys = UMFPACK_At; + } + try { if (bHaveMatrix) { pUmfpack = new UmfpackObjectType{A, options}; @@ -419,7 +425,7 @@ octave_value_list UmfpackObject::eval(const octave_value_list& args, int narg } if (bHaveRightHandSide) { - retval.append(pUmfpack->solve(b)); + retval.append(pUmfpack->solve(b, sys)); if (bOwnUmfpack) { delete pUmfpack; From fcc0304da754d0d2fe347b4aff85f94c57424e18 Mon Sep 17 00:00:00 2001 From: Reinhard Date: Sat, 27 Apr 2024 17:17:27 +0200 Subject: [PATCH 02/10] Fix increment operator --- src/umfpack.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/umfpack.cc b/src/umfpack.cc index 680c638..d183e55 100644 --- a/src/umfpack.cc +++ b/src/umfpack.cc @@ -413,7 +413,7 @@ octave_value_list UmfpackObject::eval(const octave_value_list& args, int narg SuiteSparse_long sys = UMFPACK_A; - if (args.length() > iarg && args(++iarg).bool_value()) { + if (args.length() > iarg && args(iarg++).bool_value()) { sys = UMFPACK_At; } From 4ca4b7469df0be6b476a873f600cf3eacf9a4f7a Mon Sep 17 00:00:00 2001 From: Reinhard Date: Sat, 27 Apr 2024 17:35:06 +0200 Subject: [PATCH 03/10] Add transposed solve to pardiso --- src/pardiso.cc | 18 +++++++++++++++--- src/umfpack.cc | 4 ++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/pardiso.cc b/src/pardiso.cc index c7af113..4425b06 100644 --- a/src/pardiso.cc +++ b/src/pardiso.cc @@ -113,7 +113,7 @@ class PardisoObject : public octave_base_value { virtual ~PardisoObject(void); virtual size_t byte_size() const; virtual dim_vector dims() const; - bool solve(DenseMatrixType& b, DenseMatrixType& x) const; + bool solve(DenseMatrixType& b, DenseMatrixType& x, long long sys) const; static bool get_options(const octave_value& ovOptions, PardisoObject::Options& options); virtual bool is_constant(void) const{ return true; } virtual bool is_defined(void) const{ return true; } @@ -447,7 +447,7 @@ PardisoObject::~PardisoObject() } template -bool PardisoObject::solve(DenseMatrixType& b, DenseMatrixType& x) const { +bool PardisoObject::solve(DenseMatrixType& b, DenseMatrixType& x, long long sys) const { if (b.rows() != n) { error_with_id("pardiso:solve", "pardiso: rows(b)=%Ld must be equal to rows(A)=%Ld", static_cast(b.rows()), n); return false; @@ -456,8 +456,14 @@ bool PardisoObject::solve(DenseMatrixType& b, DenseMatrixType& x) const { assert(b.rows() == x.rows()); assert(b.columns() == x.columns()); + const auto save_sys = iparm[11]; + + iparm[11] = sys; + long long ierror = pardiso(b.fortran_vec(), x.fortran_vec(), b.columns()); + iparm[11] = save_sys; + if (ierror != 0LL) { error_with_id("pardiso:solve", "pardiso solve failed with status %Ld", ierror); return false; @@ -623,8 +629,14 @@ octave_value_list PardisoObject::eval(const octave_value_list& args, int narg bOwnPardiso = true; } + long long sys = 0LL; + + if (args.length() > iarg) { + sys = args(iarg++).long_value(); + } + if (bHaveRightHandSide) { - if (pPardiso->solve(b, x)) { + if (pPardiso->solve(b, x, sys)) { retval.append(x); } diff --git a/src/umfpack.cc b/src/umfpack.cc index d183e55..38efcd0 100644 --- a/src/umfpack.cc +++ b/src/umfpack.cc @@ -413,8 +413,8 @@ octave_value_list UmfpackObject::eval(const octave_value_list& args, int narg SuiteSparse_long sys = UMFPACK_A; - if (args.length() > iarg && args(iarg++).bool_value()) { - sys = UMFPACK_At; + if (args.length() > iarg) { + sys = args(iarg++).long_value(); } try { From d5c65f308bbdaf74b90ef6788242309338ef89b4 Mon Sep 17 00:00:00 2001 From: Reinhard Date: Sat, 27 Apr 2024 17:52:50 +0200 Subject: [PATCH 04/10] Add transposed solve to pastix --- src/pastix.cc | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/src/pastix.cc b/src/pastix.cc index 5dd7dc8..e3371a4 100644 --- a/src/pastix.cc +++ b/src/pastix.cc @@ -97,7 +97,7 @@ class PastixObject : public octave_base_value { virtual ~PastixObject(void); virtual size_t byte_size() const; virtual dim_vector dims() const; - bool solve(DenseMatrixType& b, DenseMatrixType& x) const; + bool solve(DenseMatrixType& b, DenseMatrixType& x, pastix_trans_t sys) const; static bool get_options(const octave_value& ovOptions, PastixObject::Options& options); virtual bool is_constant(void) const{ return true; } virtual bool is_defined(void) const{ return true; } @@ -423,7 +423,7 @@ void PastixObject::cleanup() } template -bool PastixObject::solve(DenseMatrixType& b, DenseMatrixType& x) const { +bool PastixObject::solve(DenseMatrixType& b, DenseMatrixType& x, pastix_trans_t sys) const { if (b.rows() != ncols) { error_with_id("pastix:solve", "pastix: rows(b)=%ld must be equal to rows(A)=%ld", long(b.rows()), long(ncols)); return false; @@ -433,12 +433,18 @@ bool PastixObject::solve(DenseMatrixType& b, DenseMatrixType& x) const { x = b; + const auto save_sys = iparm[IPARM_TRANSPOSE_SOLVE]; + + iparm[IPARM_TRANSPOSE_SOLVE] = sys; + int rc = pastix_task_solve(pastix_data, x.rows(), x.columns(), x.fortran_vec(), x.rows()); + iparm[IPARM_TRANSPOSE_SOLVE] = save_sys; + if (PASTIX_SUCCESS != rc) { error_with_id("pastix:solve", "pastix_task_solve failed with status %d", rc); return false; @@ -792,8 +798,26 @@ octave_value_list PastixObject::eval(const octave_value_list& args, int nargo #endif } + pastix_trans_t sys = PastixNoTrans; + + if (args.length() > iarg) { + switch (args(iarg++).long_value()) { + case 0: + sys = PastixNoTrans; + break; + case 1: + sys = PastixConjTrans; + break; + case 2: + sys = PastixTrans; + break; + default: + sys = static_cast(-1); + } + } + if (bHaveRightHandSide) { - if (pPastix->solve(b, x)) { + if (pPastix->solve(b, x, sys)) { retval.append(x); } From 74c91cb3ad08ab61dbd239bfe0ed0cfa0a4b0d87 Mon Sep 17 00:00:00 2001 From: Reinhard Date: Sat, 27 Apr 2024 18:12:32 +0200 Subject: [PATCH 05/10] Use transposed version for refinement --- src/pastix.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/pastix.cc b/src/pastix.cc index e3371a4..7872b8d 100644 --- a/src/pastix.cc +++ b/src/pastix.cc @@ -445,7 +445,7 @@ bool PastixObject::solve(DenseMatrixType& b, DenseMatrixType& x, pastix_trans iparm[IPARM_TRANSPOSE_SOLVE] = save_sys; - if (PASTIX_SUCCESS != rc) { + if (PASTIX_SUCCESS != rc) { error_with_id("pastix:solve", "pastix_task_solve failed with status %d", rc); return false; } @@ -464,6 +464,8 @@ bool PastixObject::solve(DenseMatrixType& b, DenseMatrixType& x, pastix_trans } if (!bZeroVec) { + iparm[IPARM_TRANSPOSE_SOLVE] = sys; + // Avoid division zero by zero in PaStiX rc = pastix_task_refine(pastix_data, spm.n, @@ -472,7 +474,9 @@ bool PastixObject::solve(DenseMatrixType& b, DenseMatrixType& x, pastix_trans b.rows(), x.fortran_vec() + j * x.rows(), x.rows()); - + + iparm[IPARM_TRANSPOSE_SOLVE] = save_sys; + if (PASTIX_SUCCESS != rc) { error_with_id("pastix:solve", "pastix_task_refine failed with status %d", rc); return false; From 83094073d02b12b1a2bcb4c2eb3fe15f89102e31 Mon Sep 17 00:00:00 2001 From: Reinhard Date: Sat, 27 Apr 2024 18:18:47 +0200 Subject: [PATCH 06/10] Do not use PastixConjTrans for real matrices --- src/pastix.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/pastix.cc b/src/pastix.cc index 7872b8d..48c806b 100644 --- a/src/pastix.cc +++ b/src/pastix.cc @@ -810,7 +810,11 @@ octave_value_list PastixObject::eval(const octave_value_list& args, int nargo sys = PastixNoTrans; break; case 1: - sys = PastixConjTrans; + if constexpr(std::is_same::value) { + sys = PastixTrans; + } else { + sys = PastixConjTrans; + } break; case 2: sys = PastixTrans; From 02e8087949e0dcc8d81d7ac683ca47384ba57160 Mon Sep 17 00:00:00 2001 From: Reinhard Date: Sat, 27 Apr 2024 18:24:32 +0200 Subject: [PATCH 07/10] Disable complex conjugate solve --- src/pastix.cc | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/pastix.cc b/src/pastix.cc index 48c806b..889947f 100644 --- a/src/pastix.cc +++ b/src/pastix.cc @@ -809,13 +809,6 @@ octave_value_list PastixObject::eval(const octave_value_list& args, int nargo case 0: sys = PastixNoTrans; break; - case 1: - if constexpr(std::is_same::value) { - sys = PastixTrans; - } else { - sys = PastixConjTrans; - } - break; case 2: sys = PastixTrans; break; From 887b4e4094456517739beb08b8eb276f18152ecd Mon Sep 17 00:00:00 2001 From: Reinhard Date: Sat, 27 Apr 2024 18:29:03 +0200 Subject: [PATCH 08/10] Cleanup whitespace --- src/pardiso.cc | 4 ++-- src/pastix.cc | 16 ++++++++-------- src/umfpack.cc | 4 ++-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/pardiso.cc b/src/pardiso.cc index 4425b06..a242ac0 100644 --- a/src/pardiso.cc +++ b/src/pardiso.cc @@ -457,7 +457,7 @@ bool PardisoObject::solve(DenseMatrixType& b, DenseMatrixType& x, long long s assert(b.columns() == x.columns()); const auto save_sys = iparm[11]; - + iparm[11] = sys; long long ierror = pardiso(b.fortran_vec(), x.fortran_vec(), b.columns()); @@ -634,7 +634,7 @@ octave_value_list PardisoObject::eval(const octave_value_list& args, int narg if (args.length() > iarg) { sys = args(iarg++).long_value(); } - + if (bHaveRightHandSide) { if (pPardiso->solve(b, x, sys)) { retval.append(x); diff --git a/src/pastix.cc b/src/pastix.cc index 889947f..31cbd2e 100644 --- a/src/pastix.cc +++ b/src/pastix.cc @@ -434,9 +434,9 @@ bool PastixObject::solve(DenseMatrixType& b, DenseMatrixType& x, pastix_trans x = b; const auto save_sys = iparm[IPARM_TRANSPOSE_SOLVE]; - + iparm[IPARM_TRANSPOSE_SOLVE] = sys; - + int rc = pastix_task_solve(pastix_data, x.rows(), x.columns(), @@ -444,8 +444,8 @@ bool PastixObject::solve(DenseMatrixType& b, DenseMatrixType& x, pastix_trans x.rows()); iparm[IPARM_TRANSPOSE_SOLVE] = save_sys; - - if (PASTIX_SUCCESS != rc) { + + if (PASTIX_SUCCESS != rc) { error_with_id("pastix:solve", "pastix_task_solve failed with status %d", rc); return false; } @@ -465,7 +465,7 @@ bool PastixObject::solve(DenseMatrixType& b, DenseMatrixType& x, pastix_trans if (!bZeroVec) { iparm[IPARM_TRANSPOSE_SOLVE] = sys; - + // Avoid division zero by zero in PaStiX rc = pastix_task_refine(pastix_data, spm.n, @@ -474,9 +474,9 @@ bool PastixObject::solve(DenseMatrixType& b, DenseMatrixType& x, pastix_trans b.rows(), x.fortran_vec() + j * x.rows(), x.rows()); - + iparm[IPARM_TRANSPOSE_SOLVE] = save_sys; - + if (PASTIX_SUCCESS != rc) { error_with_id("pastix:solve", "pastix_task_refine failed with status %d", rc); return false; @@ -816,7 +816,7 @@ octave_value_list PastixObject::eval(const octave_value_list& args, int nargo sys = static_cast(-1); } } - + if (bHaveRightHandSide) { if (pPastix->solve(b, x, sys)) { retval.append(x); diff --git a/src/umfpack.cc b/src/umfpack.cc index 38efcd0..9e35c4f 100644 --- a/src/umfpack.cc +++ b/src/umfpack.cc @@ -412,11 +412,11 @@ octave_value_list UmfpackObject::eval(const octave_value_list& args, int narg } SuiteSparse_long sys = UMFPACK_A; - + if (args.length() > iarg) { sys = args(iarg++).long_value(); } - + try { if (bHaveMatrix) { pUmfpack = new UmfpackObjectType{A, options}; From 4479bbc5b6f579f3d42d97ded1819da31757beff Mon Sep 17 00:00:00 2001 From: Reinhard Date: Sat, 27 Apr 2024 18:39:59 +0200 Subject: [PATCH 09/10] Disable spmCheckAxb for transposed solve --- src/pastix.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pastix.cc b/src/pastix.cc index 31cbd2e..1b17384 100644 --- a/src/pastix.cc +++ b/src/pastix.cc @@ -486,7 +486,8 @@ bool PastixObject::solve(DenseMatrixType& b, DenseMatrixType& x, pastix_trans OCTAVE_QUIT; } - if (options.check_solution) { + if (options.check_solution && sys == PastixNoTrans) { + // FIXME: Is there any transposed version of spmCheckAxb? rc = spmCheckAxb(dparm[DPARM_EPSILON_REFINEMENT], b.columns(), &spm, From 39a0623396c23550ae5cbbbd75c1c13bca880d3b Mon Sep 17 00:00:00 2001 From: Reinhard Date: Sat, 27 Apr 2024 18:44:36 +0200 Subject: [PATCH 10/10] Add test for transposed solve --- inst/numerical_tests_01.tst | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/inst/numerical_tests_01.tst b/inst/numerical_tests_01.tst index 246f00c..8f07010 100644 --- a/inst/numerical_tests_01.tst +++ b/inst/numerical_tests_01.tst @@ -1,7 +1,7 @@ ## numerical_tests.tst:01 %!test %! if (~isempty(which("pastix"))) -%! for i=1:2 +%! for i=1:4 %! for j=1:2 %! A = [1 0 0 0 0 %! 0 3 0 0 0 @@ -26,12 +26,23 @@ %! opts.number_of_threads = int32(4); %! opts.check_solution = true; %! switch i -%! case 1 -%! x = pastix(A, b, opts); -%! case 2 -%! x = pastix(pastix(A, opts), b); +%! case {1, 2} +%! trans = 0; +%! case {3, 4} +%! trans = 2; +%! endswitch +%! switch i +%! case {1, 3} +%! x = pastix(A, b, opts, trans); +%! case {2, 4} +%! x = pastix(pastix(A, opts), b, trans); +%! endswitch +%! switch (trans) +%! case 0 +%! f = max(norm(A * x - b, "cols") ./ norm(A * x + b, "cols")); +%! case 2 +%! f = max(norm(A.' * x - b, "cols") ./ norm(A.' * x + b, "cols")); %! endswitch -%! f = max(norm(A * x - b, "cols") ./ norm(A * x + b, "cols")); %! assert(f <= eps^0.8); %! endfor %! endfor