Skip to content

Commit

Permalink
Extend BLAS interface (#3173)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkachuma authored and Algiane committed Jul 30, 2024
1 parent 917740c commit 3464c15
Show file tree
Hide file tree
Showing 4 changed files with 717 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,151 @@ void matrixLeastSquaresSolutionSolve( arraySlice2d< real64, USD > const & A,
GEOS_ERROR_IF( INFO != 0, "The algorithm computing matrix linear system failed to converge." );
}

template< int USD1, int USD2 >
GEOS_FORCE_INLINE
void matrixCopy( int const N,
int const M,
arraySlice2d< real64, USD1 > const & A,
arraySlice2d< real64, USD2 > const & B )
{
for( int i = 0; i < N; i++ )
{
for( int j = 0; j < M; j++ )
{
B( i, j ) = A( i, j );
}
}
}

template< int USD >
GEOS_FORCE_INLINE
void matrixTranspose( int const N,
arraySlice2d< real64, USD > const & A )
{
for( int i = 0; i < N; i++ )
{
for( int j = i+1; j < N; j++ )
{
std::swap( A( i, j ), A( j, i ) );
}
}
}

template< typename T, int USD >
void solveLinearSystem( arraySlice2d< T, USD > const & A,
arraySlice2d< real64 const, USD > const & B,
arraySlice2d< real64, USD > const & X )
{
// --- Check that source matrix is square
int const N = LvArray::integerConversion< int >( A.size( 0 ) );
GEOS_ASSERT_MSG( N > 0 &&
N == A.size( 1 ),
"Matrix must be square" );

// --- Check that rhs B has appropriate dimensions
GEOS_ASSERT_MSG( B.size( 0 ) == N,
"right-hand-side matrix has wrong dimensions" );
int const M = LvArray::integerConversion< int >( B.size( 1 ) );

// --- Check that solution X has appropriate dimensions
GEOS_ASSERT_MSG( X.size( 0 ) == N &&
X.size( 1 ) == M,
"solution matrix has wrong dimensions" );

// --- Check that everything is contiguous
GEOS_ASSERT_MSG( A.isContiguous(), "Matrix is not contiguous" );
GEOS_ASSERT_MSG( B.isContiguous(), "right-hand-side matrix is not contiguous" );
GEOS_ASSERT_MSG( X.isContiguous(), "solution matrix is not contiguous" );

real64 * matrixData = nullptr;
array2d< real64 > LU; // Space for LU-factors
if constexpr ( !std::is_const< T >::value )
{
matrixData = A.dataIfContiguous();
}
else
{
LU.resize( N, N );
matrixData = LU.data();
// Direct copy here ignoring permutation
int const INCX = 1;
int const INCY = 1;
int const K = LvArray::integerConversion< int >( A.size( ) );
GEOS_dcopy( &K, A.dataIfContiguous(), &INCX, matrixData, &INCY );
}

array1d< int > IPIV( N );
int INFO;
char const TRANS = (USD == MatrixLayout::ROW_MAJOR) ? 'T' : 'N';

GEOS_dgetrf( &N, &N, matrixData, &N, IPIV.data(), &INFO );

GEOS_ASSERT_MSG( INFO == 0, "LAPACK dgetrf error code: " << INFO );

if constexpr ( std::is_const< T >::value )
{
int const INCX = 1;
int const INCY = 1;
int const K = LvArray::integerConversion< int >( B.size( ) );
GEOS_dcopy( &K, B.dataIfContiguous(), &INCX, X.dataIfContiguous(), &INCY );
}

// For row-major form, we need to reorder into col-major form
// This might require an extra allocation
real64 * solutionData = X.dataIfContiguous();
array2d< real64, MatrixLayout::COL_MAJOR_PERM > X0;
if constexpr ( USD == MatrixLayout::ROW_MAJOR )
{
if( 1 < M && M == N )
{
// Square case: swap in place
matrixTranspose( N, X );
}
else if( 1 < M )
{
X0.resize( N, M );
matrixCopy( N, M, X, X0.toSlice() );
solutionData = X0.data();
}
}

GEOS_dgetrs( &TRANS, &N, &M, matrixData, &N, IPIV.data(), solutionData, &N, &INFO );

GEOS_ASSERT_MSG( INFO == 0, "LAPACK dgetrs error code: " << INFO );

if constexpr ( USD == MatrixLayout::ROW_MAJOR )
{
if( 1 < M && M == N )
{
// Square case: swap in place
matrixTranspose( N, X );
}
else if( 1 < M )
{
matrixCopy( N, M, X0.toSlice(), X );
}
}
}

template< typename T, int USD >
void solveLinearSystem( arraySlice2d< T, USD > const & A,
arraySlice1d< real64 const > const & b,
arraySlice1d< real64 > const & x )
{
// --- Check that b and x have the same size
int const N = LvArray::integerConversion< int >( b.size( 0 ) );
GEOS_ASSERT_MSG( 0 < N && x.size() == N,
"right-hand-side and/or solution has wrong dimensions" );

// Create 2d slices
int const dims[2] = {N, 1};
int const strides[2] = {1, 1};
arraySlice2d< real64 const, USD > B( b.dataIfContiguous(), dims, strides );
arraySlice2d< real64, USD > X( x.dataIfContiguous(), dims, strides );

solveLinearSystem( A, B, X );
}

} // namespace detail

real64 BlasLapackLA::determinant( arraySlice2d< real64 const, MatrixLayout::ROW_MAJOR > const & A )
Expand Down Expand Up @@ -885,60 +1030,56 @@ void BlasLapackLA::matrixEigenvalues( MatRowMajor< real64 const > const & A,
matrixEigenvalues( AT.toSliceConst(), lambda );
}

void BlasLapackLA::solveLinearSystem( MatColMajor< real64 const > const & A,
arraySlice1d< real64 const > const & rhs,
arraySlice1d< real64 > const & solution )
void BlasLapackLA::solveLinearSystem( MatRowMajor< real64 const > const & A,
Vec< real64 const > const & rhs,
Vec< real64 > const & solution )
{
// --- Check that source matrix is square
int const NN = LvArray::integerConversion< int >( A.size( 0 ));
GEOS_ASSERT_MSG( NN > 0 &&
NN == A.size( 1 ),
"Matrix must be square" );

// --- Check that rhs and solution have appropriate dimension
GEOS_ASSERT_MSG( rhs.size( 0 ) == NN,
"right-hand-side vector has wrong dimensions" );

GEOS_ASSERT_MSG( solution.size( 0 ) == NN,
"solution vector has wrong dimensions" );

array1d< int > IPIV;
IPIV.resize( NN );
int const NRHS = 1; // we only allow for 1 rhs vector.
int INFO;

// make a copy of A, since dgeev destroys contents
array2d< real64, MatrixLayout::COL_MAJOR_PERM > ACOPY( A.size( 0 ), A.size( 1 ) );
BlasLapackLA::matrixCopy( A, ACOPY );

// copy the rhs in the solution vector
BlasLapackLA::vectorCopy( rhs, solution );

GEOS_dgetrf( &NN, &NN, ACOPY.data(), &NN, IPIV.data(), &INFO );
detail::solveLinearSystem( A, rhs, solution );
}

GEOS_ASSERT_MSG( INFO == 0, "LAPACK dgetrf error code: " << INFO );
void BlasLapackLA::solveLinearSystem( MatColMajor< real64 const > const & A,
Vec< real64 const > const & rhs,
Vec< real64 > const & solution )
{
detail::solveLinearSystem( A, rhs, solution );
}

GEOS_dgetrs( "N", &NN, &NRHS, ACOPY.data(), &NN, IPIV.data(), solution.dataIfContiguous(), &NN, &INFO );
void BlasLapackLA::solveLinearSystem( MatRowMajor< real64 > const & A,
Vec< real64 > const & rhs )
{
detail::solveLinearSystem( A, rhs.toSliceConst(), rhs );
}

GEOS_ASSERT_MSG( INFO == 0, "LAPACK dgetrs error code: " << INFO );
void BlasLapackLA::solveLinearSystem( MatColMajor< real64 > const & A,
Vec< real64 > const & rhs )
{
detail::solveLinearSystem( A, rhs.toSliceConst(), rhs );
}

void BlasLapackLA::solveLinearSystem( MatRowMajor< real64 const > const & A,
arraySlice1d< real64 const > const & rhs,
arraySlice1d< real64 > const & solution )
MatRowMajor< real64 const > const & rhs,
MatRowMajor< real64 > const & solution )
{
array2d< real64, MatrixLayout::COL_MAJOR_PERM > AT( A.size( 0 ), A.size( 1 ) );
detail::solveLinearSystem( A, rhs, solution );
}

// convert A to a column major format
for( int i = 0; i < A.size( 0 ); ++i )
{
for( int j = 0; j < A.size( 1 ); ++j )
{
AT( i, j ) = A( i, j );
}
}
void BlasLapackLA::solveLinearSystem( MatColMajor< real64 const > const & A,
MatColMajor< real64 const > const & rhs,
MatColMajor< real64 > const & solution )
{
detail::solveLinearSystem( A, rhs, solution );
}

void BlasLapackLA::solveLinearSystem( MatRowMajor< real64 > const & A,
MatRowMajor< real64 > const & rhs )
{
detail::solveLinearSystem( A, rhs.toSliceConst(), rhs );
}

solveLinearSystem( AT.toSliceConst(), rhs, solution );
void BlasLapackLA::solveLinearSystem( MatColMajor< real64 > const & A,
MatColMajor< real64 > const & rhs )
{
detail::solveLinearSystem( A, rhs.toSliceConst(), rhs );
}

void BlasLapackLA::matrixLeastSquaresSolutionSolve( arraySlice2d< real64 const, MatrixLayout::ROW_MAJOR > const & A,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ namespace geos
* \class BlasLapackLA
* \brief This class contains a collection of BLAS and LAPACK linear
* algebra operations for GEOSX array1d and array2d
* \warning These methods are currently not supported on GPUs
*/
struct BlasLapackLA
{
Expand Down Expand Up @@ -454,18 +455,88 @@ struct BlasLapackLA
* @param [in] rhs GEOSX array1d.
* @param [out] solution GEOSX array1d.
*/
static void solveLinearSystem( MatColMajor< real64 const > const & A,
Vec< real64 const > const & rhs,
Vec< real64 > const & solution );

/**
* @copydoc solveLinearSystem( MatColMajor<real64 const> const &, Vec< real64 const > const &, Vec< real64 const > const & )
*/
static void solveLinearSystem( MatRowMajor< real64 const > const & A,
Vec< real64 const > const & rhs,
Vec< real64 > const & solution );

/**
* @copydoc solveLinearSystem( MatRowMajor<real64 const> const &, Vec< real64 const > const &, Vec< real64 const > const & )
* @brief Solves the linear system ;
* \p A \p solution = \p rhs.
*
* @details The method is intended for the solution of a small dense linear system.
* This solves the system in-place without allocating extra memory for the matrix or the solution. This means
* that at on exit the matrix is modified replaced by the LU factors and the right hand side vector is
* replaced by the solution.
* It employs lapack method dgetr.
*
* @param [in/out] A GEOSX array2d. The matrix. On exit this will be replaced by the factorisation of A
* @param [in/out] rhs GEOSX array1d. The right hand side. On exit this will be the solution
*/
static void solveLinearSystem( MatColMajor< real64 > const & A,
Vec< real64 > const & rhs );

/**
* @copydoc solveLinearSystem( MatColMajor< real64 > const &, Vec< real64 > const & )
*/
static void solveLinearSystem( MatRowMajor< real64 > const & A,
Vec< real64 > const & rhs );

/**
* @brief Solves the linear system ;
* \p A \p solution = \p rhs.
*
* @details The method is intended for the solution of a small dense linear system in which A is an NxN matrix, the
* right-hand-side and the solution are matrices of size NxM.
* It employs lapack method dgetr.
*
* @note this function first applies a matrix permutation and then calls the row major version of the function.
* @param [in] A GEOSX array2d.
* @param [in] rhs GEOSX array2d.
* @param [out] solution GEOSX array2d.
*/
static void solveLinearSystem( MatColMajor< real64 const > const & A,
Vec< real64 const > const & rhs,
Vec< real64 > const & solution );
MatColMajor< real64 const > const & rhs,
MatColMajor< real64 > const & solution );

/**
* @copydoc solveLinearSystem( MatColMajor< real64 const > const &, MatColMajor< real64 const > const &, MatColMajor< const > const & )
*
* @note this function will allocate space to reorder the solution into column major form.
*/
static void solveLinearSystem( MatRowMajor< real64 const > const & A,
MatRowMajor< real64 const > const & rhs,
MatRowMajor< real64 > const & solution );

/**
* @brief Solves the linear system ;
* \p A \p solution = \p rhs.
*
* @details The method is intended for the solution of a small dense linear system in which A is an NxN matrix, the
* right-hand-side and the solution are matrices of size NxM.
* This solves the system in-place without allocating extra memory for the matrix or the solution. This means
* that at on exit the matrix is modified replaced by the LU factors and the right hand side vector is
* replaced by the solution.
* It employs lapack method dgetr.
*
* @param [in/out] A GEOSX array2d. The matrix. On exit this will be replaced by the factorisation of A
* @param [in/out] rhs GEOSX array1d. The right hand side. On exit this will be the solution
*/
static void solveLinearSystem( MatColMajor< real64 > const & A,
MatColMajor< real64 > const & rhs );

/**
* @copydoc solveLinearSystem( MatColMajor< real64 > const &, MatRowMajor< real64 > const & )
*
* @note this function will allocate space to reorder the solution into column major form.
*/
static void solveLinearSystem( MatRowMajor< real64 > const & A,
MatRowMajor< real64 > const & rhs );

/**
* @brief Vector copy;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
set( serial_tests
testBlasLapack.cpp )
testBlasLapack.cpp
testSolveLinearSystem.cpp )

set( dependencyList gtest denseLinearAlgebra )

Expand Down
Loading

0 comments on commit 3464c15

Please sign in to comment.