Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement SVD by divide-and-conquer #164

Merged
merged 1 commit into from
Jun 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/lapack/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub mod qr;
pub mod solve;
pub mod solveh;
pub mod svd;
pub mod svddc;
pub mod triangular;

pub use self::cholesky::*;
Expand All @@ -16,6 +17,7 @@ pub use self::qr::*;
pub use self::solve::*;
pub use self::solveh::*;
pub use self::svd::*;
pub use self::svddc::*;
pub use self::triangular::*;

use super::error::*;
Expand All @@ -24,7 +26,7 @@ use super::types::*;
pub type Pivot = Vec<i32>;

/// Trait for primitive types which implements LAPACK subroutines
pub trait Lapack: OperatorNorm_ + QR_ + SVD_ + Solve_ + Solveh_ + Cholesky_ + Eigh_ + Triangular_ {}
pub trait Lapack: OperatorNorm_ + QR_ + SVD_ + SVDDC_ + Solve_ + Solveh_ + Cholesky_ + Eigh_ + Triangular_ {}

impl Lapack for f32 {}
impl Lapack for f64 {}
Expand Down
69 changes: 69 additions & 0 deletions src/lapack/svddc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use lapacke;
use num_traits::Zero;

use crate::error::*;
use crate::layout::MatrixLayout;
use crate::types::*;
use crate::svddc::UVTFlag;

use super::{SVDOutput, into_result};

pub trait SVDDC_: Scalar {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could for instance merge this with the existing SVD_ trait, and change the trait to provide another function (svddc, or sdd). This may be preferred.

Or alternatively a more complicated way to invoke the existing svd that switches things around depending.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm planning to refactoring these basis of Lapack traits (SVD_ and others) into single trait, but it should be another PR. Your implementation is enough match to current impls, and I will accept it.

unsafe fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result<SVDOutput<Self>>;
}

macro_rules! impl_svdd {
($scalar:ty, $gesdd:path) => {
impl SVDDC_ for $scalar {
unsafe fn svddc(
l: MatrixLayout,
jobz: UVTFlag,
mut a: &mut [Self],
) -> Result<SVDOutput<Self>> {
let (m, n) = l.size();
let k = m.min(n);
let lda = l.lda();
let (ucol, vtrow) = match jobz {
UVTFlag::Full => (m, n),
UVTFlag::Some => (k, k),
UVTFlag::None => (1, 1),
};
let mut s = vec![Self::Real::zero(); k.max(1) as usize];
let mut u = vec![Self::zero(); (m * ucol).max(1) as usize];
let ldu = l.resized(m, ucol).lda();
let mut vt = vec![Self::zero(); (vtrow * n).max(1) as usize];
let ldvt = l.resized(vtrow, n).lda();
let info = $gesdd(
l.lapacke_layout(),
jobz as u8,
m,
n,
&mut a,
lda,
&mut s,
&mut u,
ldu,
&mut vt,
ldvt,
);
into_result(
info,
SVDOutput {
s: s,
u: if jobz == UVTFlag::None { None } else { Some(u) },
vt: if jobz == UVTFlag::None {
None
} else {
Some(vt)
},
},
)
}
}
};
}

impl_svdd!(f32, lapacke::sgesdd);
impl_svdd!(f64, lapacke::dgesdd);
impl_svdd!(c32, lapacke::cgesdd);
impl_svdd!(c64, lapacke::zgesdd);
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ pub mod qr;
pub mod solve;
pub mod solveh;
pub mod svd;
pub mod svddc;
pub mod trace;
pub mod triangular;
pub mod types;
Expand All @@ -76,6 +77,7 @@ pub use qr::*;
pub use solve::*;
pub use solveh::*;
pub use svd::*;
pub use svddc::*;
pub use trace::*;
pub use triangular::*;
pub use types::*;
110 changes: 110 additions & 0 deletions src/svddc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
//! Singular-value decomposition (SVD) by divide-and-conquer (?gesdd)

use ndarray::*;

use super::convert::*;
use super::error::*;
use super::layout::*;
use super::types::*;

#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum UVTFlag {
Full = b'A',
Some = b'S',
None = b'N',
}

/// Singular-value decomposition of matrix (copying) by divide-and-conquer
pub trait SVDDC {
type U;
type VT;
type Sigma;
fn svddc(&self, uvt_flag: UVTFlag) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::VT>)>;
}

/// Singular-value decomposition of matrix by divide-and-conquer
pub trait SVDDCInto {
type U;
type VT;
type Sigma;
fn svddc_into(
self,
uvt_flag: UVTFlag,
) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::VT>)>;
}

/// Singular-value decomposition of matrix reference by divide-and-conquer
pub trait SVDDCInplace {
type U;
type VT;
type Sigma;
fn svddc_inplace(
&mut self,
uvt_flag: UVTFlag,
) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::VT>)>;
}

impl<A, S> SVDDC for ArrayBase<S, Ix2>
where
A: Scalar + Lapack,
S: DataMut<Elem = A>,
{
type U = Array2<A>;
type VT = Array2<A>;
type Sigma = Array1<A::Real>;

fn svddc(&self, uvt_flag: UVTFlag) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::VT>)> {
self.to_owned().svddc_into(uvt_flag)
}
}

impl<A, S> SVDDCInto for ArrayBase<S, Ix2>
where
A: Scalar + Lapack,
S: DataMut<Elem = A>,
{
type U = Array2<A>;
type VT = Array2<A>;
type Sigma = Array1<A::Real>;

fn svddc_into(
mut self,
uvt_flag: UVTFlag,
) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::VT>)> {
self.svddc_inplace(uvt_flag)
}
}

impl<A, S> SVDDCInplace for ArrayBase<S, Ix2>
where
A: Scalar + Lapack,
S: DataMut<Elem = A>,
{
type U = Array2<A>;
type VT = Array2<A>;
type Sigma = Array1<A::Real>;

fn svddc_inplace(
&mut self,
uvt_flag: UVTFlag,
) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::VT>)> {
let l = self.layout()?;
let svd_res = unsafe { A::svddc(l, uvt_flag, self.as_allocated_mut()?)? };
let (m, n) = l.size();
let k = m.min(n);
let (ldu, tdu, ldvt, tdvt) = match uvt_flag {
UVTFlag::Full => (m, m, n, n),
UVTFlag::Some => (m, k, k, n),
UVTFlag::None => (1, 1, 1, 1),
};
let u = svd_res
.u
.map(|u| into_matrix(l.resized(ldu, tdu), u).expect("Size of U mismatches"));
let vt = svd_res
.vt
.map(|vt| into_matrix(l.resized(ldvt, tdvt), vt).expect("Size of VT mismatches"));
let s = ArrayBase::from_vec(svd_res.s);
Ok((u, s, vt))
}
}
74 changes: 74 additions & 0 deletions tests/svddc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
use ndarray::*;
use ndarray_linalg::*;

fn test(a: &Array2<f64>, flag: UVTFlag) {
let (n, m) = a.dim();
let k = n.min(m);
let answer = a.clone();
println!("a = \n{:?}", a);
let (u, s, vt): (_, Array1<_>, _) = a.svddc(flag).unwrap();
let mut sm = match flag {
UVTFlag::Full => Array::zeros((n, m)),
UVTFlag::Some => Array::zeros((k, k)),
UVTFlag::None => {
assert!(u.is_none());
assert!(vt.is_none());
return;
},
};
let u: Array2<_> = u.unwrap();
let vt: Array2<_> = vt.unwrap();
println!("u = \n{:?}", &u);
println!("s = \n{:?}", &s);
println!("v = \n{:?}", &vt);
for i in 0..k {
sm[(i, i)] = s[i];
}
assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7);
}

macro_rules! test_svd_impl {
($n:expr, $m:expr) => {
paste::item! {
#[test]
fn [<svddc_full_ $n x $m>]() {
let a = random(($n, $m));
test(&a, UVTFlag::Full);
}

#[test]
fn [<svddc_some_ $n x $m>]() {
let a = random(($n, $m));
test(&a, UVTFlag::Some);
}

#[test]
fn [<svddc_none_ $n x $m>]() {
let a = random(($n, $m));
test(&a, UVTFlag::None);
}

#[test]
fn [<svddc_full_ $n x $m _t>]() {
let a = random(($n, $m).f());
test(&a, UVTFlag::Full);
}

#[test]
fn [<svddc_some_ $n x $m _t>]() {
let a = random(($n, $m).f());
test(&a, UVTFlag::Some);
}

#[test]
fn [<svddc_none_ $n x $m _t>]() {
let a = random(($n, $m).f());
test(&a, UVTFlag::None);
}
}
};
}

test_svd_impl!(3, 3);
test_svd_impl!(4, 3);
test_svd_impl!(3, 4);