diff --git a/lax/src/svddc.rs b/lax/src/svddc.rs index 84f8394b..3e50d7bb 100644 --- a/lax/src/svddc.rs +++ b/lax/src/svddc.rs @@ -1,7 +1,7 @@ use super::*; use crate::{error::*, layout::MatrixLayout}; use cauchy::*; -use num_traits::Zero; +use num_traits::{ToPrimitive, Zero}; /// Specifies how many of the columns of *U* and rows of *V*ᵀ are computed and returned. /// @@ -21,56 +21,106 @@ pub trait SVDDC_: Scalar { unsafe fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result>; } -macro_rules! impl_svdd { - ($scalar:ty, $gesdd:path) => { +macro_rules! impl_svddc { + (@real, $scalar:ty, $gesdd:path) => { + impl_svddc!(@body, $scalar, $gesdd, ); + }; + (@complex, $scalar:ty, $gesdd:path) => { + impl_svddc!(@body, $scalar, $gesdd, rwork); + }; + (@body, $scalar:ty, $gesdd:path, $($rwork_ident:ident),*) => { impl SVDDC_ for $scalar { unsafe fn svddc( l: MatrixLayout, jobz: UVTFlag, mut a: &mut [Self], ) -> Result> { - let (m, n) = l.size(); + let m = l.lda(); + let n = l.len(); let k = m.min(n); - let lda = l.lda(); - let (ucol, vtrow) = match jobz { - UVTFlag::Full => (m, n), + let mut s = vec![Self::Real::zero(); k as usize]; + + let (u_col, vt_row) = match jobz { + UVTFlag::Full | UVTFlag::None => (m, n), UVTFlag::Some => (k, k), - UVTFlag::None => (1, 1), }; - let mut s = vec![Self::Real::zero(); k.max(1) as usize]; - let mut u = vec![Self::zero(); (m * ucol).max(1) as usize]; - let ldu = l.resized(m, ucol).lda(); - let mut vt = vec![Self::zero(); (vtrow * n).max(1) as usize]; - let ldvt = l.resized(vtrow, n).lda(); + let (mut u, mut vt) = match jobz { + UVTFlag::Full => ( + Some(vec![Self::zero(); (m * m) as usize]), + Some(vec![Self::zero(); (n * n) as usize]), + ), + UVTFlag::Some => ( + Some(vec![Self::zero(); (m * u_col) as usize]), + Some(vec![Self::zero(); (n * vt_row) as usize]), + ), + UVTFlag::None => (None, None), + }; + + $( // for complex only + let mx = n.max(m) as usize; + let mn = n.min(m) as usize; + let lrwork = match jobz { + UVTFlag::None => 7 * mn, + _ => std::cmp::max(5*mn*mn + 5*mn, 2*mx*mn + 2*mn*mn + mn), + }; + let mut $rwork_ident = vec![Self::Real::zero(); lrwork]; + )* + + // eval work size + let mut info = 0; + let mut iwork = vec![0; 8 * k as usize]; + let mut work_size = [Self::zero()]; $gesdd( - l.lapacke_layout(), jobz as u8, m, n, &mut a, - lda, + m, &mut s, - &mut u, - ldu, - &mut vt, - ldvt, - ) - .as_lapack_result()?; - Ok(SVDOutput { - s, - u: if jobz == UVTFlag::None { None } else { Some(u) }, - vt: if jobz == UVTFlag::None { - None - } else { - Some(vt) - }, - }) + u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + m, + vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + vt_row, + &mut work_size, + -1, + $(&mut $rwork_ident,)* + &mut iwork, + &mut info, + ); + info.as_lapack_result()?; + + // do svd + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + $gesdd( + jobz as u8, + m, + n, + &mut a, + m, + &mut s, + u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + m, + vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []), + vt_row, + &mut work, + lwork as i32, + $(&mut $rwork_ident,)* + &mut iwork, + &mut info, + ); + info.as_lapack_result()?; + + match l { + MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }), + MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }), + } } } }; } -impl_svdd!(f32, lapacke::sgesdd); -impl_svdd!(f64, lapacke::dgesdd); -impl_svdd!(c32, lapacke::cgesdd); -impl_svdd!(c64, lapacke::zgesdd); +impl_svddc!(@real, f32, lapack::sgesdd); +impl_svddc!(@real, f64, lapack::dgesdd); +impl_svddc!(@complex, c32, lapack::cgesdd); +impl_svddc!(@complex, c64, lapack::zgesdd); diff --git a/ndarray-linalg/src/svddc.rs b/ndarray-linalg/src/svddc.rs index 22f3ae0c..b27212ec 100644 --- a/ndarray-linalg/src/svddc.rs +++ b/ndarray-linalg/src/svddc.rs @@ -1,12 +1,8 @@ //! Singular-value decomposition (SVD) by divide-and-conquer (?gesdd) +use super::{convert::*, error::*, layout::*, types::*}; use ndarray::*; -use super::convert::*; -use super::error::*; -use super::layout::*; -use super::types::*; - pub use lapack::svddc::UVTFlag; /// Singular-value decomposition of matrix (copying) by divide-and-conquer @@ -87,17 +83,21 @@ where let svd_res = unsafe { A::svddc(l, uvt_flag, self.as_allocated_mut()?)? }; let (m, n) = l.size(); let k = m.min(n); - let (ldu, tdu, ldvt, tdvt) = match uvt_flag { - UVTFlag::Full => (m, m, n, n), - UVTFlag::Some => (m, k, k, n), - UVTFlag::None => (1, 1, 1, 1), + + let (u_col, vt_row) = match uvt_flag { + UVTFlag::Full => (m, n), + UVTFlag::Some => (k, k), + UVTFlag::None => (0, 0), }; + let u = svd_res .u - .map(|u| into_matrix(l.resized(ldu, tdu), u).expect("Size of U mismatches")); + .map(|u| into_matrix(l.resized(m, u_col), u).unwrap()); + let vt = svd_res .vt - .map(|vt| into_matrix(l.resized(ldvt, tdvt), vt).expect("Size of VT mismatches")); + .map(|vt| into_matrix(l.resized(vt_row, n), vt).unwrap()); + let s = ArrayBase::from(svd_res.s); Ok((u, s, vt)) } diff --git a/ndarray-linalg/tests/svddc.rs b/ndarray-linalg/tests/svddc.rs index 2c9204c8..fb26c8d5 100644 --- a/ndarray-linalg/tests/svddc.rs +++ b/ndarray-linalg/tests/svddc.rs @@ -1,13 +1,13 @@ use ndarray::*; use ndarray_linalg::*; -fn test(a: &Array2, flag: UVTFlag) { +fn test(a: &Array2, flag: UVTFlag) { let (n, m) = a.dim(); let k = n.min(m); let answer = a.clone(); println!("a = \n{:?}", a); let (u, s, vt): (_, Array1<_>, _) = a.svddc(flag).unwrap(); - let mut sm = match flag { + let mut sm: Array2 = match flag { UVTFlag::Full => Array::zeros((n, m)), UVTFlag::Some => Array::zeros((k, k)), UVTFlag::None => { @@ -22,53 +22,56 @@ fn test(a: &Array2, flag: UVTFlag) { println!("s = \n{:?}", &s); println!("v = \n{:?}", &vt); for i in 0..k { - sm[(i, i)] = s[i]; + sm[(i, i)] = T::from_real(s[i]); } - assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7); + assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, T::real(1e-7)); } macro_rules! test_svd_impl { - ($n:expr, $m:expr) => { + ($scalar:ty, $n:expr, $m:expr) => { paste::item! { #[test] - fn []() { + fn []() { let a = random(($n, $m)); - test(&a, UVTFlag::Full); + test::<$scalar>(&a, UVTFlag::Full); } #[test] - fn []() { + fn []() { let a = random(($n, $m)); - test(&a, UVTFlag::Some); + test::<$scalar>(&a, UVTFlag::Some); } #[test] - fn []() { + fn []() { let a = random(($n, $m)); - test(&a, UVTFlag::None); + test::<$scalar>(&a, UVTFlag::None); } #[test] - fn []() { + fn []() { let a = random(($n, $m).f()); - test(&a, UVTFlag::Full); + test::<$scalar>(&a, UVTFlag::Full); } #[test] - fn []() { + fn []() { let a = random(($n, $m).f()); - test(&a, UVTFlag::Some); + test::<$scalar>(&a, UVTFlag::Some); } #[test] - fn []() { + fn []() { let a = random(($n, $m).f()); - test(&a, UVTFlag::None); + test::<$scalar>(&a, UVTFlag::None); } } }; } -test_svd_impl!(3, 3); -test_svd_impl!(4, 3); -test_svd_impl!(3, 4); +test_svd_impl!(f64, 3, 3); +test_svd_impl!(f64, 4, 3); +test_svd_impl!(f64, 3, 4); +test_svd_impl!(c64, 3, 3); +test_svd_impl!(c64, 4, 3); +test_svd_impl!(c64, 3, 4);