From ea9123153b712da58a75c064eb75aa5f3806cd90 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 10 Jul 2020 22:54:44 +0900 Subject: [PATCH 1/2] Impl bk, solveh, and invh using LAPACK --- lax/src/solveh.rs | 72 +++++++++++++++++++++++++++++------------------ 1 file changed, 44 insertions(+), 28 deletions(-) diff --git a/lax/src/solveh.rs b/lax/src/solveh.rs index 01e90f13..508b2ec6 100644 --- a/lax/src/solveh.rs +++ b/lax/src/solveh.rs @@ -5,6 +5,7 @@ use super::*; use crate::{error::*, layout::MatrixLayout}; use cauchy::*; +use num_traits::{ToPrimitive, Zero}; pub trait Solveh_: Sized { /// Bunch-Kaufman: wrapper of `*sytrf` and `*hetrf` @@ -28,13 +29,39 @@ macro_rules! impl_solveh { let (n, _) = l.size(); let mut ipiv = vec![0; n as usize]; if n == 0 { - // Work around bug in LAPACKE functions. - Ok(ipiv) - } else { - $trf(l.lapacke_layout(), uplo as u8, n, a, l.lda(), &mut ipiv) - .as_lapack_result()?; - Ok(ipiv) + return Ok(Vec::new()); } + + // calc work size + let mut info = 0; + let mut work_size = [Self::zero()]; + $trf( + uplo as u8, + n, + a, + l.lda(), + &mut ipiv, + &mut work_size, + -1, + &mut info, + ); + info.as_lapack_result()?; + + // actual + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + $trf( + uplo as u8, + n, + a, + l.lda(), + &mut ipiv, + &mut work, + lwork as i32, + &mut info, + ); + info.as_lapack_result()?; + Ok(ipiv) } unsafe fn invh( @@ -44,7 +71,10 @@ macro_rules! impl_solveh { ipiv: &Pivot, ) -> Result<()> { let (n, _) = l.size(); - $tri(l.lapacke_layout(), uplo as u8, n, a, l.lda(), ipiv).as_lapack_result()?; + let mut info = 0; + let mut work = vec![Self::zero(); n as usize]; + $tri(uplo as u8, n, a, l.lda(), ipiv, &mut work, &mut info); + info.as_lapack_result()?; Ok(()) } @@ -56,30 +86,16 @@ macro_rules! impl_solveh { b: &mut [Self], ) -> Result<()> { let (n, _) = l.size(); - let nrhs = 1; - let ldb = match l { - MatrixLayout::C { .. } => 1, - MatrixLayout::F { .. } => n, - }; - $trs( - l.lapacke_layout(), - uplo as u8, - n, - nrhs, - a, - l.lda(), - ipiv, - b, - ldb, - ) - .as_lapack_result()?; + let mut info = 0; + $trs(uplo as u8, n, 1, a, l.lda(), ipiv, b, n, &mut info); + info.as_lapack_result()?; Ok(()) } } }; } // impl_solveh! -impl_solveh!(f64, lapacke::dsytrf, lapacke::dsytri, lapacke::dsytrs); -impl_solveh!(f32, lapacke::ssytrf, lapacke::ssytri, lapacke::ssytrs); -impl_solveh!(c64, lapacke::zhetrf, lapacke::zhetri, lapacke::zhetrs); -impl_solveh!(c32, lapacke::chetrf, lapacke::chetri, lapacke::chetrs); +impl_solveh!(f64, lapack::dsytrf, lapack::dsytri, lapack::dsytrs); +impl_solveh!(f32, lapack::ssytrf, lapack::ssytri, lapack::ssytrs); +impl_solveh!(c64, lapack::zhetrf, lapack::zhetri, lapack::zhetrs); +impl_solveh!(c32, lapack::chetrf, lapack::chetri, lapack::chetrs); From 9ccb6af5487f3eae51592a7d49036e2def232fcd Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 10 Jul 2020 22:55:44 +0900 Subject: [PATCH 2/2] Revise deth --- ndarray-linalg/src/solveh.rs | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/ndarray-linalg/src/solveh.rs b/ndarray-linalg/src/solveh.rs index c05e2bde..da748f77 100644 --- a/ndarray-linalg/src/solveh.rs +++ b/ndarray-linalg/src/solveh.rs @@ -314,6 +314,7 @@ where S: Data, A: Scalar + Lapack, { + let layout = a.layout().unwrap(); let mut sign = A::Real::one(); let mut ln_det = A::Real::zero(); let mut ipiv_enum = ipiv_iter.enumerate(); @@ -337,9 +338,15 @@ where debug_assert_eq!(lower_diag.im(), Zero::zero()); // Off-diagonal elements, can be complex. - let off_diag = match uplo { - UPLO::Upper => unsafe { a.uget((k, k + 1)) }, - UPLO::Lower => unsafe { a.uget((k + 1, k)) }, + let off_diag = match layout { + MatrixLayout::C { .. } => match uplo { + UPLO::Upper => unsafe { a.uget((k + 1, k)) }, + UPLO::Lower => unsafe { a.uget((k, k + 1)) }, + }, + MatrixLayout::F { .. } => match uplo { + UPLO::Upper => unsafe { a.uget((k, k + 1)) }, + UPLO::Lower => unsafe { a.uget((k + 1, k)) }, + }, }; // Determinant of 2x2 block.