diff --git a/lax/src/lib.rs b/lax/src/lib.rs index 769910db..be88410e 100644 --- a/lax/src/lib.rs +++ b/lax/src/lib.rs @@ -70,6 +70,7 @@ pub mod layout; pub mod least_squares; pub mod opnorm; pub mod qr; +pub mod rcond; pub mod solve; pub mod solveh; pub mod svd; @@ -83,6 +84,7 @@ pub use self::eigh::*; pub use self::least_squares::*; pub use self::opnorm::*; pub use self::qr::*; +pub use self::rcond::*; pub use self::solve::*; pub use self::solveh::*; pub use self::svd::*; @@ -107,6 +109,7 @@ pub trait Lapack: + Eigh_ + Triangular_ + Tridiagonal_ + + Rcond_ { } diff --git a/lax/src/rcond.rs b/lax/src/rcond.rs new file mode 100644 index 00000000..135c4a12 --- /dev/null +++ b/lax/src/rcond.rs @@ -0,0 +1,86 @@ +use super::*; +use crate::{error::*, layout::MatrixLayout}; +use cauchy::*; +use num_traits::Zero; + +pub trait Rcond_: Scalar + Sized { + /// Estimates the the reciprocal of the condition number of the matrix in 1-norm. + /// + /// `anorm` should be the 1-norm of the matrix `a`. + fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result; +} + +macro_rules! impl_rcond_real { + ($scalar:ty, $gecon:path) => { + impl Rcond_ for $scalar { + fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result { + let (n, _) = l.size(); + let mut rcond = Self::Real::zero(); + let mut info = 0; + + let mut work = vec![Self::zero(); 4 * n as usize]; + let mut iwork = vec![0; n as usize]; + let norm_type = match l { + MatrixLayout::C { .. } => NormType::Infinity, + MatrixLayout::F { .. } => NormType::One, + } as u8; + unsafe { + $gecon( + norm_type, + n, + a, + l.lda(), + anorm, + &mut rcond, + &mut work, + &mut iwork, + &mut info, + ) + }; + info.as_lapack_result()?; + + Ok(rcond) + } + } + }; +} + +impl_rcond_real!(f32, lapack::sgecon); +impl_rcond_real!(f64, lapack::dgecon); + +macro_rules! impl_rcond_complex { + ($scalar:ty, $gecon:path) => { + impl Rcond_ for $scalar { + fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result { + let (n, _) = l.size(); + let mut rcond = Self::Real::zero(); + let mut info = 0; + let mut work = vec![Self::zero(); 2 * n as usize]; + let mut rwork = vec![Self::Real::zero(); 2 * n as usize]; + let norm_type = match l { + MatrixLayout::C { .. } => NormType::Infinity, + MatrixLayout::F { .. } => NormType::One, + } as u8; + unsafe { + $gecon( + norm_type, + n, + a, + l.lda(), + anorm, + &mut rcond, + &mut work, + &mut rwork, + &mut info, + ) + }; + info.as_lapack_result()?; + + Ok(rcond) + } + } + }; +} + +impl_rcond_complex!(c32, lapack::cgecon); +impl_rcond_complex!(c64, lapack::zgecon); diff --git a/lax/src/solve.rs b/lax/src/solve.rs index 67af6409..93aa4722 100644 --- a/lax/src/solve.rs +++ b/lax/src/solve.rs @@ -3,119 +3,99 @@ use super::*; use crate::{error::*, layout::MatrixLayout}; use cauchy::*; -use num_traits::Zero; +use num_traits::{ToPrimitive, Zero}; -/// Wraps `*getrf`, `*getri`, and `*getrs` pub trait Solve_: Scalar + Sized { /// Computes the LU factorization of a general `m x n` matrix `a` using /// partial pivoting with row interchanges. /// - /// If the result matches `Err(LinalgError::Lapack(LapackError { - /// return_code )) if return_code > 0`, then `U[(return_code-1, - /// return_code-1)]` is exactly zero. The factorization has been completed, - /// but the factor `U` is exactly singular, and division by zero will occur - /// if it is used to solve a system of equations. - unsafe fn lu(l: MatrixLayout, a: &mut [Self]) -> Result; - unsafe fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>; - /// Estimates the the reciprocal of the condition number of the matrix in 1-norm. + /// $ PA = LU $ /// - /// `anorm` should be the 1-norm of the matrix `a`. - unsafe fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result; - unsafe fn solve( - l: MatrixLayout, - t: Transpose, - a: &[Self], - p: &Pivot, - b: &mut [Self], - ) -> Result<()>; + /// Error + /// ------ + /// - `LapackComputationalFailure { return_code }` when the matrix is singular + /// - Division by zero will occur if it is used to solve a system of equations + /// because `U[(return_code-1, return_code-1)]` is exactly zero. + fn lu(l: MatrixLayout, a: &mut [Self]) -> Result; + + fn inv(l: MatrixLayout, a: &mut [Self], p: &Pivot) -> Result<()>; + + fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>; } macro_rules! impl_solve { - ($scalar:ty, $getrf:path, $getri:path, $gecon:path, $getrs:path) => { + ($scalar:ty, $getrf:path, $getri:path, $getrs:path) => { impl Solve_ for $scalar { - unsafe fn lu(l: MatrixLayout, a: &mut [Self]) -> Result { + fn lu(l: MatrixLayout, a: &mut [Self]) -> Result { let (row, col) = l.size(); + assert_eq!(a.len() as i32, row * col); + if row == 0 || col == 0 { + // Do nothing for empty matrix + return Ok(Vec::new()); + } let k = ::std::cmp::min(row, col); let mut ipiv = vec![0; k as usize]; - $getrf(l.lapacke_layout(), row, col, a, l.lda(), &mut ipiv).as_lapack_result()?; + let mut info = 0; + unsafe { $getrf(l.lda(), l.len(), a, l.lda(), &mut ipiv, &mut info) }; + info.as_lapack_result()?; Ok(ipiv) } - unsafe fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> { + fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> { let (n, _) = l.size(); - $getri(l.lapacke_layout(), n, a, l.lda(), ipiv).as_lapack_result()?; - Ok(()) - } - unsafe fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result { - let (n, _) = l.size(); - let mut rcond = Self::Real::zero(); - $gecon( - l.lapacke_layout(), - NormType::One as u8, - n, - a, - l.lda(), - anorm, - &mut rcond, - ) - .as_lapack_result()?; - Ok(rcond) + // calc work size + let mut info = 0; + let mut work_size = [Self::zero()]; + unsafe { $getri(n, a, l.lda(), ipiv, &mut work_size, -1, &mut info) }; + info.as_lapack_result()?; + + // actual + let lwork = work_size[0].to_usize().unwrap(); + let mut work = vec![Self::zero(); lwork]; + unsafe { + $getri( + l.len(), + a, + l.lda(), + ipiv, + &mut work, + lwork as i32, + &mut info, + ) + }; + info.as_lapack_result()?; + + Ok(()) } - unsafe fn solve( + fn solve( l: MatrixLayout, t: Transpose, a: &[Self], ipiv: &Pivot, b: &mut [Self], ) -> Result<()> { + let t = match l { + MatrixLayout::C { .. } => match t { + Transpose::No => Transpose::Transpose, + Transpose::Transpose | Transpose::Hermite => Transpose::No, + }, + _ => t, + }; let (n, _) = l.size(); let nrhs = 1; - let ldb = 1; - $getrs( - l.lapacke_layout(), - t as u8, - n, - nrhs, - a, - l.lda(), - ipiv, - b, - ldb, - ) - .as_lapack_result()?; + let ldb = l.lda(); + let mut info = 0; + unsafe { $getrs(t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb, &mut info) }; + info.as_lapack_result()?; Ok(()) } } }; } // impl_solve! -impl_solve!( - f64, - lapacke::dgetrf, - lapacke::dgetri, - lapacke::dgecon, - lapacke::dgetrs -); -impl_solve!( - f32, - lapacke::sgetrf, - lapacke::sgetri, - lapacke::sgecon, - lapacke::sgetrs -); -impl_solve!( - c64, - lapacke::zgetrf, - lapacke::zgetri, - lapacke::zgecon, - lapacke::zgetrs -); -impl_solve!( - c32, - lapacke::cgetrf, - lapacke::cgetri, - lapacke::cgecon, - lapacke::cgetrs -); +impl_solve!(f64, lapack::dgetrf, lapack::dgetri, lapack::dgetrs); +impl_solve!(f32, lapack::sgetrf, lapack::sgetri, lapack::sgetrs); +impl_solve!(c64, lapack::zgetrf, lapack::zgetri, lapack::zgetrs); +impl_solve!(c32, lapack::cgetrf, lapack::cgetri, lapack::cgetrs); diff --git a/ndarray-linalg/src/solve.rs b/ndarray-linalg/src/solve.rs index 566511f3..fd4b3017 100644 --- a/ndarray-linalg/src/solve.rs +++ b/ndarray-linalg/src/solve.rs @@ -167,15 +167,13 @@ where where Sb: DataMut, { - unsafe { - A::solve( - self.a.square_layout()?, - Transpose::No, - self.a.as_allocated()?, - &self.ipiv, - rhs.as_slice_mut().unwrap(), - )? - }; + A::solve( + self.a.square_layout()?, + Transpose::No, + self.a.as_allocated()?, + &self.ipiv, + rhs.as_slice_mut().unwrap(), + )?; Ok(rhs) } fn solve_t_inplace<'a, Sb>( @@ -185,15 +183,13 @@ where where Sb: DataMut, { - unsafe { - A::solve( - self.a.square_layout()?, - Transpose::Transpose, - self.a.as_allocated()?, - &self.ipiv, - rhs.as_slice_mut().unwrap(), - )? - }; + A::solve( + self.a.square_layout()?, + Transpose::Transpose, + self.a.as_allocated()?, + &self.ipiv, + rhs.as_slice_mut().unwrap(), + )?; Ok(rhs) } fn solve_h_inplace<'a, Sb>( @@ -203,15 +199,13 @@ where where Sb: DataMut, { - unsafe { - A::solve( - self.a.square_layout()?, - Transpose::Hermite, - self.a.as_allocated()?, - &self.ipiv, - rhs.as_slice_mut().unwrap(), - )? - }; + A::solve( + self.a.square_layout()?, + Transpose::Hermite, + self.a.as_allocated()?, + &self.ipiv, + rhs.as_slice_mut().unwrap(), + )?; Ok(rhs) } } @@ -273,7 +267,7 @@ where S: DataMut + RawDataClone, { fn factorize_into(mut self) -> Result> { - let ipiv = unsafe { A::lu(self.layout()?, self.as_allocated_mut()?)? }; + let ipiv = A::lu(self.layout()?, self.as_allocated_mut()?)?; Ok(LUFactorized { a: self, ipiv }) } } @@ -285,7 +279,7 @@ where { fn factorize(&self) -> Result>> { let mut a: Array2 = replicate(self); - let ipiv = unsafe { A::lu(a.layout()?, a.as_allocated_mut()?)? }; + let ipiv = A::lu(a.layout()?, a.as_allocated_mut()?)?; Ok(LUFactorized { a, ipiv }) } } @@ -312,13 +306,11 @@ where type Output = ArrayBase; fn inv_into(mut self) -> Result> { - unsafe { - A::inv( - self.a.square_layout()?, - self.a.as_allocated_mut()?, - &self.ipiv, - )? - }; + A::inv( + self.a.square_layout()?, + self.a.as_allocated_mut()?, + &self.ipiv, + )?; Ok(self.a) } } @@ -539,13 +531,11 @@ where S: Data + RawDataClone, { fn rcond(&self) -> Result { - unsafe { - Ok(A::rcond( - self.a.layout()?, - self.a.as_allocated()?, - self.a.opnorm_one()?, - )?) - } + Ok(A::rcond( + self.a.layout()?, + self.a.as_allocated()?, + self.a.opnorm_one()?, + )?) } }