Skip to content

Commit

Permalink
Update interface for PaStiX 6.4.0 (#2)
Browse files Browse the repository at this point in the history
* Always set runtime path for Pastix

* Use spmScalMat and spmScal

* Pass pointer instead of object

* Use pastix_task_solve_and_refine

* pass spm.nexp to spmScalMat

* Pass lda = b.rows()

* Add new function pastix_save

* Maximum two arguments allowed

* Print the matrix

* Use CSR format if the lower triangular part is provided

* Use CSR format for lower triangular matrix

* Switch rows and columns

* Force upper triangular matrix

* Do not transpose the matrix

* Default to MAT_SYM_UPPER instead of MAT_SYM_LOWER

* Call pastix_task_refine for each right hand side.
  • Loading branch information
octave-user authored Aug 24, 2024
1 parent 00f321b commit 73d462d
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 16 deletions.
8 changes: 4 additions & 4 deletions src/config.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,11 @@
/* Define to 1 if you have the `spmNorm' function. */
#undef HAVE_SPMNORM

/* Define to 1 if you have the `spmScalMatrix' function. */
#undef HAVE_SPMSCALMATRIX
/* Define to 1 if you have the `spmScal' function. */
#undef HAVE_SPMSCAL

/* Define to 1 if you have the `spmScalVector' function. */
#undef HAVE_SPMSCALVECTOR
/* Define to 1 if you have the `spmScalMat' function. */
#undef HAVE_SPMSCALMAT

/* Define to 1 if you have the <stdexcept> header file. */
#undef HAVE_STDEXCEPT
Expand Down
4 changes: 2 additions & 2 deletions src/configure.ac
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ AC_CHECK_FUNCS([pastixInit \
pastix_task_refine \
pastixFinalize \
spmNorm \
spmScalMatrix \
spmScalVector \
spmScalMat \
spmScal \
spmExit \
spmCheckAxb],[have_pastix=yes],[have_pastix=no])

Expand Down
79 changes: 69 additions & 10 deletions src/pastix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,14 @@ class PastixObject : public octave_base_value {
virtual size_t byte_size() const;
virtual dim_vector dims() const;
bool solve(DenseMatrixType& b, DenseMatrixType& x, pastix_trans_t sys) const;
int save(const std::string& filename) const;
static bool get_options(const octave_value& ovOptions, PastixObject::Options& options);
virtual bool is_constant(void) const{ return true; }
virtual bool is_defined(void) const{ return true; }
virtual bool isreal() const { return PastixTraits<T>::isreal; }
virtual bool iscomplex() const { return PastixTraits<T>::iscomplex; }
static octave_value_list eval(const octave_value_list& args, int nargout);

static octave_value_list save(const octave_value_list& args, int nargout);
private:
void cleanup();

Expand Down Expand Up @@ -168,9 +169,9 @@ PastixObject<T>::PastixObject(const SparseMatrixType& A, const Options& options)
{
std::memset(&spm, 0, sizeof(spm));

const octave_idx_type* const cidx = A.cidx();
const octave_idx_type* const ridx = A.ridx();
const T* const data = A.data();
const octave_idx_type* cidx = A.cidx();
const octave_idx_type* ridx = A.ridx();
const T* data = A.data();

enum MatrixPattern { MAT_SYM_UPPER,
MAT_SYM_LOWER,
Expand All @@ -183,10 +184,10 @@ PastixObject<T>::PastixObject(const SparseMatrixType& A, const Options& options)
for (octave_idx_type i = cidx[j]; i < cidx[j + 1]; ++i) {
switch (eMatPattern) {
case MAT_DIAG:
if (ridx[i] > j) {
eMatPattern = MAT_SYM_LOWER;
} else if (ridx[i] < j) {
if (ridx[i] < j) {
eMatPattern = MAT_SYM_UPPER;
} else if (ridx[i] > j) {
eMatPattern = MAT_SYM_LOWER;
}
break;
case MAT_SYM_UPPER:
Expand Down Expand Up @@ -298,7 +299,7 @@ PastixObject<T>::PastixObject(const SparseMatrixType& A, const Options& options)
colptr[ncols] = idx;
} break;
default:
// Copy the full matrix because it has been declared as unsymmetrical
// Copy the full matrix because it has been declared as unsymmetric
for (octave_idx_type i = 0; i < nnz; ++i) {
rows[i] = ridx[i];
}
Expand Down Expand Up @@ -326,13 +327,14 @@ PastixObject<T>::PastixObject(const SparseMatrixType& A, const Options& options)
}

spm.flttype = PastixTraits<T>::flttype;

spm.fmttype = SpmCSC;
spm.rowptr = rows;
spm.colptr = colptr;
spm.nnz = nnz;
spm.n = ncols;
spm.dof = 1;
spm.values = avals;
spm.rowptr = rows;
spm.colptr = colptr;

spmUpdateComputedFields(&spm);

Expand All @@ -345,6 +347,10 @@ PastixObject<T>::PastixObject(const SparseMatrixType& A, const Options& options)
spm = spm2;
}

if (options.verbose >= PastixVerboseYes) {
spmPrintInfo(&spm, stdout);
}

pastixInitParam(iparm, dparm);

iparm[IPARM_VERBOSE] = options.verbose;
Expand Down Expand Up @@ -517,6 +523,11 @@ bool PastixObject<T>::solve(DenseMatrixType& b, DenseMatrixType& x, pastix_trans
return true;
}

template <typename T>
int PastixObject<T>::save(const std::string& filename) const {
return spmSave(&spm, filename.c_str());
}

template <typename T>
bool PastixObject<T>::get_options(const octave_value& ovOptions, PastixObject::Options& options)
{
Expand Down Expand Up @@ -834,8 +845,33 @@ octave_value_list PastixObject<T>::eval(const octave_value_list& args, int nargo
return retval;
}

template<typename T>
octave_value_list PastixObject<T>::save(const octave_value_list& args, int nargout)
{
octave_value_list retval;
octave_idx_type iarg = 0;
PastixObject<T>* pPastix = nullptr;

octave_base_value& oOctaveObj = const_cast<octave_base_value&>(args(iarg++).get_rep());

pPastix = dynamic_cast<PastixObject<T>*>(&oOctaveObj);

if (!pPastix) {
error_with_id("pastix:input", "pastix: class(pastix_obj) must be equal to \"pastix\"");
return retval;
}

const std::string filename = args(iarg++).string_value();

retval.append(pPastix->save(filename));

return retval;
}

// PKG_ADD: autoload ("pastix", "__mboct_numerical__.oct");
// PKG_DEL: autoload ("pastix", "__mboct_numerical__.oct", "remove");
// PKG_ADD: autoload ("pastix_save", "__mboct_numerical__.oct");
// PKG_DEL: autoload ("pastix_save", "__mboct_numerical__.oct", "remove");

// PKG_ADD: autoload ("PASTIX_API_SYM_YES", "__mboct_numerical__.oct");
// PKG_ADD: autoload ("PASTIX_API_SYM_NO", "__mboct_numerical__.oct");
Expand Down Expand Up @@ -903,6 +939,29 @@ DEFUN_DLD (pastix, args, nargout,
return retval;
}

DEFUN_DLD (pastix_save, args, nargout,
"-*- texinfo -*-\n"
"@deftypefn {} @var{status} = pastix_save(@var{A}, @var{filename})\n\n"
"@end deftypefn\n")
{
octave_value_list retval;

if (args.length() != 2 || nargout > 1) {
print_usage();
return retval;
}

bool bcomplex = args(0).iscomplex();

if (bcomplex) {
retval = PastixObject<std::complex<double> >::save(args, nargout);
} else {
retval = PastixObject<double>::save(args, nargout);
}

return retval;
}

#define DEFINE_GLOBAL_CONSTANT(CONST,VALUE) \
DEFUN_DLD(PASTIX_##CONST, args, nargout, "id = PASTIX_" #CONST "()\n") \
{ \
Expand Down

0 comments on commit 73d462d

Please sign in to comment.