diff --git a/lax/src/layout.rs b/lax/src/layout.rs index 9dad70e6..43b7ee87 100644 --- a/lax/src/layout.rs +++ b/lax/src/layout.rs @@ -97,6 +97,37 @@ impl MatrixLayout { MatrixLayout::F { col, lda } => MatrixLayout::C { row: lda, lda: col }, } } + + /// Transpose without changing memory representation + /// + /// C-contigious row=2, lda=3 + /// + /// ```text + /// [[1, 2, 3] + /// [4, 5, 6]] + /// ``` + /// + /// and F-contigious col=2, lda=3 + /// + /// ```text + /// [[1, 4] + /// [2, 5] + /// [3, 6]] + /// ``` + /// + /// have same memory representation `[1, 2, 3, 4, 5, 6]`, and this toggles them. + /// + /// ``` + /// # use lax::layout::*; + /// let layout = MatrixLayout::C { row: 2, lda: 3 }; + /// assert_eq!(layout.t(), MatrixLayout::F { col: 2, lda: 3 }); + /// ``` + pub fn t(&self) -> Self { + match *self { + MatrixLayout::C { row, lda } => MatrixLayout::F { col: row, lda }, + MatrixLayout::F { col, lda } => MatrixLayout::C { row: col, lda }, + } + } } /// In-place transpose of a square matrix by keeping F/C layout @@ -139,3 +170,59 @@ pub fn square_transpose(layout: MatrixLayout, a: &mut [T]) { } } } + +/// Out-place transpose for general matrix +/// +/// Inplace transpose of non-square matrices is hard. +/// See also: https://en.wikipedia.org/wiki/In-place_matrix_transposition +/// +/// ```rust +/// # use lax::layout::*; +/// let layout = MatrixLayout::C { row: 2, lda: 3 }; +/// let a = vec![1., 2., 3., 4., 5., 6.]; +/// let mut b = vec![0.0; a.len()]; +/// let l = transpose(layout, &a, &mut b); +/// assert_eq!(l, MatrixLayout::F { col: 3, lda: 2 }); +/// assert_eq!(b, &[1., 4., 2., 5., 3., 6.]); +/// ``` +/// +/// ```rust +/// # use lax::layout::*; +/// let layout = MatrixLayout::F { col: 2, lda: 3 }; +/// let a = vec![1., 2., 3., 4., 5., 6.]; +/// let mut b = vec![0.0; a.len()]; +/// let l = transpose(layout, &a, &mut b); +/// assert_eq!(l, MatrixLayout::C { row: 3, lda: 2 }); +/// assert_eq!(b, &[1., 4., 2., 5., 3., 6.]); +/// ``` +/// +/// Panics +/// ------ +/// - If size of `a` and `layout` size mismatch +/// +pub fn transpose(layout: MatrixLayout, from: &[T], to: &mut [T]) -> MatrixLayout { + let (m, n) = layout.size(); + let transposed = layout.resized(n, m).t(); + let m = m as usize; + let n = n as usize; + assert_eq!(from.len(), m * n); + assert_eq!(to.len(), m * n); + + match layout { + MatrixLayout::C { .. } => { + for i in 0..m { + for j in 0..n { + to[j * m + i] = from[i * n + j]; + } + } + } + MatrixLayout::F { .. } => { + for i in 0..m { + for j in 0..n { + to[i * n + j] = from[j * m + i]; + } + } + } + } + transposed +} diff --git a/lax/src/least_squares.rs b/lax/src/least_squares.rs index 69553a44..d684c9b8 100644 --- a/lax/src/least_squares.rs +++ b/lax/src/least_squares.rs @@ -1,8 +1,8 @@ //! Least squares -use crate::{error::*, layout::MatrixLayout}; +use crate::{error::*, layout::*}; use cauchy::*; -use num_traits::Zero; +use num_traits::{ToPrimitive, Zero}; /// Result of LeastSquares pub struct LeastSquaresOutput { @@ -14,13 +14,13 @@ pub struct LeastSquaresOutput { /// Wraps `*gelsd` pub trait LeastSquaresSvdDivideConquer_: Scalar { - unsafe fn least_squares( + fn least_squares( a_layout: MatrixLayout, a: &mut [Self], b: &mut [Self], ) -> Result>; - unsafe fn least_squares_nrhs( + fn least_squares_nrhs( a_layout: MatrixLayout, a: &mut [Self], b_layout: MatrixLayout, @@ -29,81 +29,129 @@ pub trait LeastSquaresSvdDivideConquer_: Scalar { } macro_rules! impl_least_squares { - ($scalar:ty, $gelsd:path) => { + (@real, $scalar:ty, $gelsd:path) => { + impl_least_squares!(@body, $scalar, $gelsd, ); + }; + (@complex, $scalar:ty, $gelsd:path) => { + impl_least_squares!(@body, $scalar, $gelsd, rwork); + }; + + (@body, $scalar:ty, $gelsd:path, $($rwork:ident),*) => { impl LeastSquaresSvdDivideConquer_ for $scalar { - unsafe fn least_squares( - a_layout: MatrixLayout, + fn least_squares( + l: MatrixLayout, a: &mut [Self], b: &mut [Self], ) -> Result> { - let (m, n) = a_layout.size(); - if (m as usize) > b.len() || (n as usize) > b.len() { - return Err(Error::InvalidShape); - } - let k = ::std::cmp::min(m, n); - let nrhs = 1; - let ldb = match a_layout { - MatrixLayout::F { .. } => m.max(n), - MatrixLayout::C { .. } => 1, - }; - let rcond: Self::Real = -1.; - let mut singular_values: Vec = vec![Self::Real::zero(); k as usize]; - let mut rank: i32 = 0; - - $gelsd( - a_layout.lapacke_layout(), - m, - n, - nrhs, - a, - a_layout.lda(), - b, - ldb, - &mut singular_values, - rcond, - &mut rank, - ) - .as_lapack_result()?; - - Ok(LeastSquaresOutput { - singular_values, - rank, - }) + let b_layout = l.resized(b.len() as i32, 1); + Self::least_squares_nrhs(l, a, b_layout, b) } - unsafe fn least_squares_nrhs( + fn least_squares_nrhs( a_layout: MatrixLayout, a: &mut [Self], b_layout: MatrixLayout, b: &mut [Self], ) -> Result> { + // Minimize |b - Ax|_2 + // + // where + // A : (m, n) + // b : (max(m, n), nrhs) // `b` has to store `x` on exit + // x : (n, nrhs) let (m, n) = a_layout.size(); - if (m as usize) > b.len() - || (n as usize) > b.len() - || a_layout.lapacke_layout() != b_layout.lapacke_layout() - { - return Err(Error::InvalidShape); - } - let k = ::std::cmp::min(m, n); - let nrhs = b_layout.size().1; + let (m_, nrhs) = b_layout.size(); + let k = m.min(n); + assert!(m_ >= m); + + // Transpose if a is C-continuous + let mut a_t = None; + let a_layout = match a_layout { + MatrixLayout::C { .. } => { + a_t = Some(vec![Self::zero(); a.len()]); + transpose(a_layout, a, a_t.as_mut().unwrap()) + } + MatrixLayout::F { .. } => a_layout, + }; + + // Transpose if b is C-continuous + let mut b_t = None; + let b_layout = match b_layout { + MatrixLayout::C { .. } => { + b_t = Some(vec![Self::zero(); b.len()]); + transpose(b_layout, b, b_t.as_mut().unwrap()) + } + MatrixLayout::F { .. } => b_layout, + }; + let rcond: Self::Real = -1.; let mut singular_values: Vec = vec![Self::Real::zero(); k as usize]; let mut rank: i32 = 0; - $gelsd( - a_layout.lapacke_layout(), - m, - n, - nrhs, - a, - a_layout.lda(), - b, - b_layout.lda(), - &mut singular_values, - rcond, - &mut rank, - ) - .as_lapack_result()?; + // eval work size + let mut info = 0; + let mut work_size = [Self::zero()]; + let mut iwork_size = [0]; + $( + let mut $rwork = [Self::Real::zero()]; + )* + unsafe { + $gelsd( + m, + n, + nrhs, + a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a), + a_layout.lda(), + b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), + b_layout.lda(), + &mut singular_values, + rcond, + &mut rank, + &mut work_size, + -1, + $(&mut $rwork,)* + &mut iwork_size, + &mut info, + ) + }; + info.as_lapack_result()?; + + // calc + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + let liwork = iwork_size[0].to_usize().unwrap(); + let mut iwork = vec![0; liwork]; + $( + let lrwork = $rwork[0].to_usize().unwrap(); + let mut $rwork = vec![Self::Real::zero(); lrwork]; + )* + unsafe { + $gelsd( + m, + n, + nrhs, + a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a), + a_layout.lda(), + b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b), + b_layout.lda(), + &mut singular_values, + rcond, + &mut rank, + &mut work, + lwork as i32, + $(&mut $rwork,)* + &mut iwork, + &mut info, + ); + } + info.as_lapack_result()?; + + // Skip a_t -> a transpose because A has been destroyed + // Re-transpose b + if let Some(b_t) = b_t { + transpose(b_layout, &b_t, b); + } + Ok(LeastSquaresOutput { singular_values, rank, @@ -113,7 +161,7 @@ macro_rules! impl_least_squares { }; } -impl_least_squares!(f64, lapacke::dgelsd); -impl_least_squares!(f32, lapacke::sgelsd); -impl_least_squares!(c64, lapacke::zgelsd); -impl_least_squares!(c32, lapacke::cgelsd); +impl_least_squares!(@real, f64, lapack::dgelsd); +impl_least_squares!(@real, f32, lapack::sgelsd); +impl_least_squares!(@complex, c64, lapack::zgelsd); +impl_least_squares!(@complex, c32, lapack::cgelsd); diff --git a/ndarray-linalg/src/least_squares.rs b/ndarray-linalg/src/least_squares.rs index 18d2033f..0ff518ad 100644 --- a/ndarray-linalg/src/least_squares.rs +++ b/ndarray-linalg/src/least_squares.rs @@ -76,6 +76,7 @@ use crate::types::*; /// is a `m x 1` column vector. If `I` is `Ix2`, the RHS is a `n x k` matrix /// (which can be seen as solving `Ax = b` k times for different b) and /// the solution is a `m x k` matrix. +#[derive(Debug, Clone)] pub struct LeastSquaresResult { /// The singular values of the matrix A in `Ax = b` pub singular_values: Array1, @@ -266,6 +267,9 @@ where &mut self, rhs: &mut ArrayBase, ) -> Result> { + if self.shape()[0] != rhs.shape()[0] { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into()); + } let (m, n) = (self.shape()[0], self.shape()[1]); if n > m { // we need a new rhs b/c it will be overwritten with the solution @@ -284,21 +288,19 @@ fn compute_least_squares_srhs( rhs: &mut ArrayBase, ) -> Result> where - E: Scalar + Lapack + LeastSquaresSvdDivideConquer_, + E: Scalar + Lapack, D1: DataMut, D2: DataMut, { let LeastSquaresOutput:: { singular_values, rank, - } = unsafe { - ::least_squares( - a.layout()?, - a.as_allocated_mut()?, - rhs.as_slice_memory_order_mut() - .ok_or_else(|| LinalgError::MemoryNotCont)?, - )? - }; + } = E::least_squares( + a.layout()?, + a.as_allocated_mut()?, + rhs.as_slice_memory_order_mut() + .ok_or_else(|| LinalgError::MemoryNotCont)?, + )?; let (m, n) = (a.shape()[0], a.shape()[1]); let solution = rhs.slice(s![0..n]).to_owned(); @@ -347,6 +349,9 @@ where &mut self, rhs: &mut ArrayBase, ) -> Result> { + if self.shape()[0] != rhs.shape()[0] { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into()); + } let (m, n) = (self.shape()[0], self.shape()[1]); if n > m { // we need a new rhs b/c it will be overwritten with the solution @@ -378,14 +383,12 @@ where let LeastSquaresOutput:: { singular_values, rank, - } = unsafe { - E::least_squares_nrhs( - a_layout, - a.as_allocated_mut()?, - rhs_layout, - rhs.as_allocated_mut()?, - )? - }; + } = E::least_squares_nrhs( + a_layout, + a.as_allocated_mut()?, + rhs_layout, + rhs.as_allocated_mut()?, + )?; let solution: Array2 = rhs.slice(s![..a.shape()[1], ..]).to_owned(); let singular_values = Array::from_shape_vec((singular_values.len(),), singular_values)?; @@ -549,28 +552,13 @@ mod tests { // // Testing error cases // - #[test] fn incompatible_shape_error_on_mismatching_num_rows() { let a: Array2 = array![[1., 2.], [4., 5.], [3., 4.]]; let b: Array1 = array![1., 2.]; - let res = a.least_squares(&b); - match res { - Err(LinalgError::Lapack(err)) if matches!(err, lax::error::Error::InvalidShape) => {} - _ => panic!("Expected Err()"), - } - } - - #[test] - fn incompatible_shape_error_on_mismatching_layout() { - let a: Array2 = array![[1., 2.], [4., 5.], [3., 4.]]; - let b = array![[1.], [2.]].t().to_owned(); - assert_eq!(b.layout().unwrap(), MatrixLayout::F { col: 2, lda: 1 }); - - let res = a.least_squares(&b); - match res { - Err(LinalgError::Lapack(err)) if matches!(err, lax::error::Error::InvalidShape) => {} - _ => panic!("Expected Err()"), + match a.least_squares(&b) { + Err(LinalgError::Shape(e)) if e.kind() == ErrorKind::IncompatibleShape => {} + _ => panic!("Should be raise IncompatibleShape"), } } } diff --git a/ndarray-linalg/tests/least_squares.rs b/ndarray-linalg/tests/least_squares.rs index c388c9d7..e2df3370 100644 --- a/ndarray-linalg/tests/least_squares.rs +++ b/ndarray-linalg/tests/least_squares.rs @@ -27,13 +27,13 @@ macro_rules! impl_exact { paste::item! { #[test] fn []() { - let a: Array2 = random((3, 3)); + let a: Array2<$scalar> = random((3, 3)); test_exact(a) } #[test] fn []() { - let a: Array2 = random((3, 3).f()); + let a: Array2<$scalar> = random((3, 3).f()); test_exact(a) } } @@ -73,13 +73,13 @@ macro_rules! impl_overdetermined { paste::item! { #[test] fn []() { - let a: Array2 = random((4, 3)); + let a: Array2<$scalar> = random((4, 3)); test_overdetermined(a) } #[test] fn []() { - let a: Array2 = random((4, 3).f()); + let a: Array2<$scalar> = random((4, 3).f()); test_overdetermined(a) } } @@ -110,13 +110,13 @@ macro_rules! impl_underdetermined { paste::item! { #[test] fn []() { - let a: Array2 = random((3, 4)); + let a: Array2<$scalar> = random((3, 4)); test_underdetermined(a) } #[test] fn []() { - let a: Array2 = random((3, 4).f()); + let a: Array2<$scalar> = random((3, 4).f()); test_underdetermined(a) } } diff --git a/ndarray-linalg/tests/least_squares_nrhs.rs b/ndarray-linalg/tests/least_squares_nrhs.rs index 4c964697..dd7d283c 100644 --- a/ndarray-linalg/tests/least_squares_nrhs.rs +++ b/ndarray-linalg/tests/least_squares_nrhs.rs @@ -9,6 +9,7 @@ fn test_exact(a: Array2, b: Array2) { assert_eq!(b.layout().unwrap().size(), (3, 2)); let result = a.least_squares(&b).unwrap(); + dbg!(&result); // unpack result let x: Array2 = result.solution; let residual_l2_square: Array1 = result.residual_sum_of_squares.unwrap(); @@ -31,33 +32,29 @@ macro_rules! impl_exact { paste::item! { #[test] fn []() { - let a: Array2 = random((3, 3)); - let b: Array2 = random((3, 2)); + let a: Array2<$scalar> = random((3, 3)); + let b: Array2<$scalar> = random((3, 2)); test_exact(a, b) } - /* Unsupported currently. See https://github.com/rust-ndarray/ndarray-linalg/issues/234 - #[test] fn []() { - let a: Array2 = random((3, 3)); - let b: Array2 = random((3, 2).f()); + let a: Array2<$scalar> = random((3, 3)); + let b: Array2<$scalar> = random((3, 2).f()); test_exact(a, b) } #[test] fn []() { - let a: Array2 = random((3, 3).f()); - let b: Array2 = random((3, 2)); + let a: Array2<$scalar> = random((3, 3).f()); + let b: Array2<$scalar> = random((3, 2)); test_exact(a, b) } - */ - #[test] fn []() { - let a: Array2 = random((3, 3).f()); - let b: Array2 = random((3, 2).f()); + let a: Array2<$scalar> = random((3, 3).f()); + let b: Array2<$scalar> = random((3, 2).f()); test_exact(a, b) } } @@ -103,33 +100,29 @@ macro_rules! impl_overdetermined { paste::item! { #[test] fn []() { - let a: Array2 = random((4, 3)); - let b: Array2 = random((4, 2)); + let a: Array2<$scalar> = random((4, 3)); + let b: Array2<$scalar> = random((4, 2)); test_overdetermined(a, b) } - /* Unsupported currently. See https://github.com/rust-ndarray/ndarray-linalg/issues/234 - #[test] fn []() { - let a: Array2 = random((4, 3).f()); - let b: Array2 = random((4, 2)); + let a: Array2<$scalar> = random((4, 3).f()); + let b: Array2<$scalar> = random((4, 2)); test_overdetermined(a, b) } #[test] fn []() { - let a: Array2 = random((4, 3)); - let b: Array2 = random((4, 2).f()); + let a: Array2<$scalar> = random((4, 3)); + let b: Array2<$scalar> = random((4, 2).f()); test_overdetermined(a, b) } - */ - #[test] fn []() { - let a: Array2 = random((4, 3).f()); - let b: Array2 = random((4, 2).f()); + let a: Array2<$scalar> = random((4, 3).f()); + let b: Array2<$scalar> = random((4, 2).f()); test_overdetermined(a, b) } } @@ -162,33 +155,29 @@ macro_rules! impl_underdetermined { paste::item! { #[test] fn []() { - let a: Array2 = random((3, 4)); - let b: Array2 = random((3, 2)); + let a: Array2<$scalar> = random((3, 4)); + let b: Array2<$scalar> = random((3, 2)); test_underdetermined(a, b) } - /* Unsupported currently. See https://github.com/rust-ndarray/ndarray-linalg/issues/234 - #[test] fn []() { - let a: Array2 = random((3, 4).f()); - let b: Array2 = random((3, 2)); + let a: Array2<$scalar> = random((3, 4).f()); + let b: Array2<$scalar> = random((3, 2)); test_underdetermined(a, b) } #[test] fn []() { - let a: Array2 = random((3, 4)); - let b: Array2 = random((3, 2).f()); + let a: Array2<$scalar> = random((3, 4)); + let b: Array2<$scalar> = random((3, 2).f()); test_underdetermined(a, b) } - */ - #[test] fn []() { - let a: Array2 = random((3, 4).f()); - let b: Array2 = random((3, 2).f()); + let a: Array2<$scalar> = random((3, 4).f()); + let b: Array2<$scalar> = random((3, 2).f()); test_underdetermined(a, b) } }