Skip to content

least square by LAPACK #220

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jul 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions lax/src/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,37 @@ impl MatrixLayout {
MatrixLayout::F { col, lda } => MatrixLayout::C { row: lda, lda: col },
}
}

/// Transpose without changing memory representation
///
/// C-contigious row=2, lda=3
///
/// ```text
/// [[1, 2, 3]
/// [4, 5, 6]]
/// ```
///
/// and F-contigious col=2, lda=3
///
/// ```text
/// [[1, 4]
/// [2, 5]
/// [3, 6]]
/// ```
///
/// have same memory representation `[1, 2, 3, 4, 5, 6]`, and this toggles them.
///
/// ```
/// # use lax::layout::*;
/// let layout = MatrixLayout::C { row: 2, lda: 3 };
/// assert_eq!(layout.t(), MatrixLayout::F { col: 2, lda: 3 });
/// ```
pub fn t(&self) -> Self {
match *self {
MatrixLayout::C { row, lda } => MatrixLayout::F { col: row, lda },
MatrixLayout::F { col, lda } => MatrixLayout::C { row: col, lda },
}
}
}

/// In-place transpose of a square matrix by keeping F/C layout
Expand Down Expand Up @@ -139,3 +170,59 @@ pub fn square_transpose<T: Scalar>(layout: MatrixLayout, a: &mut [T]) {
}
}
}

/// Out-of-place transpose for a general (possibly non-square) matrix
///
/// Copies every element of `from` into `to`, swapping between the
/// `i * n + j` and `j * m + i` index orderings, where `(m, n)` is
/// `layout.size()`. The returned layout is `layout.resized(n, m).t()`.
///
/// A separate output buffer is used because in-place transposition of
/// non-square matrices is hard.
/// See also: https://en.wikipedia.org/wiki/In-place_matrix_transposition
///
/// ```rust
/// # use lax::layout::*;
/// let layout = MatrixLayout::C { row: 2, lda: 3 };
/// let a = vec![1., 2., 3., 4., 5., 6.];
/// let mut b = vec![0.0; a.len()];
/// let l = transpose(layout, &a, &mut b);
/// assert_eq!(l, MatrixLayout::F { col: 3, lda: 2 });
/// assert_eq!(b, &[1., 4., 2., 5., 3., 6.]);
/// ```
///
/// ```rust
/// # use lax::layout::*;
/// let layout = MatrixLayout::F { col: 2, lda: 3 };
/// let a = vec![1., 2., 3., 4., 5., 6.];
/// let mut b = vec![0.0; a.len()];
/// let l = transpose(layout, &a, &mut b);
/// assert_eq!(l, MatrixLayout::C { row: 3, lda: 2 });
/// assert_eq!(b, &[1., 4., 2., 5., 3., 6.]);
/// ```
///
/// Panics
/// ------
/// - If the length of `from` or of `to` differs from `m * n`,
///   where `(m, n)` is `layout.size()`
///
pub fn transpose<T: Scalar>(layout: MatrixLayout, from: &[T], to: &mut [T]) -> MatrixLayout {
    let (rows, cols) = layout.size();
    // Layout describing `to` once the copy below has finished.
    let out_layout = layout.resized(cols, rows).t();
    let (rows, cols) = (rows as usize, cols as usize);
    let len = rows * cols;
    assert_eq!(from.len(), len);
    assert_eq!(to.len(), len);

    // The source ordering decides which of the two index formulas reads
    // and which one writes.
    let from_c = match layout {
        MatrixLayout::C { .. } => true,
        MatrixLayout::F { .. } => false,
    };
    for i in 0..rows {
        for j in 0..cols {
            let row_major = i * cols + j;
            let col_major = j * rows + i;
            let (src, dst) = if from_c {
                (row_major, col_major)
            } else {
                (col_major, row_major)
            };
            to[dst] = from[src];
        }
    }
    out_layout
}
182 changes: 115 additions & 67 deletions lax/src/least_squares.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
//! Least squares

use crate::{error::*, layout::MatrixLayout};
use crate::{error::*, layout::*};
use cauchy::*;
use num_traits::Zero;
use num_traits::{ToPrimitive, Zero};

/// Result of LeastSquares
pub struct LeastSquaresOutput<A: Scalar> {
Expand All @@ -14,13 +14,13 @@ pub struct LeastSquaresOutput<A: Scalar> {

/// Wraps `*gelsd`
pub trait LeastSquaresSvdDivideConquer_: Scalar {
unsafe fn least_squares(
fn least_squares(
a_layout: MatrixLayout,
a: &mut [Self],
b: &mut [Self],
) -> Result<LeastSquaresOutput<Self>>;

unsafe fn least_squares_nrhs(
fn least_squares_nrhs(
a_layout: MatrixLayout,
a: &mut [Self],
b_layout: MatrixLayout,
Expand All @@ -29,81 +29,129 @@ pub trait LeastSquaresSvdDivideConquer_: Scalar {
}

macro_rules! impl_least_squares {
($scalar:ty, $gelsd:path) => {
(@real, $scalar:ty, $gelsd:path) => {
impl_least_squares!(@body, $scalar, $gelsd, );
};
(@complex, $scalar:ty, $gelsd:path) => {
impl_least_squares!(@body, $scalar, $gelsd, rwork);
};

(@body, $scalar:ty, $gelsd:path, $($rwork:ident),*) => {
impl LeastSquaresSvdDivideConquer_ for $scalar {
unsafe fn least_squares(
a_layout: MatrixLayout,
fn least_squares(
l: MatrixLayout,
a: &mut [Self],
b: &mut [Self],
) -> Result<LeastSquaresOutput<Self>> {
let (m, n) = a_layout.size();
if (m as usize) > b.len() || (n as usize) > b.len() {
return Err(Error::InvalidShape);
}
let k = ::std::cmp::min(m, n);
let nrhs = 1;
let ldb = match a_layout {
MatrixLayout::F { .. } => m.max(n),
MatrixLayout::C { .. } => 1,
};
let rcond: Self::Real = -1.;
let mut singular_values: Vec<Self::Real> = vec![Self::Real::zero(); k as usize];
let mut rank: i32 = 0;

$gelsd(
a_layout.lapacke_layout(),
m,
n,
nrhs,
a,
a_layout.lda(),
b,
ldb,
&mut singular_values,
rcond,
&mut rank,
)
.as_lapack_result()?;

Ok(LeastSquaresOutput {
singular_values,
rank,
})
let b_layout = l.resized(b.len() as i32, 1);
Self::least_squares_nrhs(l, a, b_layout, b)
}

unsafe fn least_squares_nrhs(
fn least_squares_nrhs(
a_layout: MatrixLayout,
a: &mut [Self],
b_layout: MatrixLayout,
b: &mut [Self],
) -> Result<LeastSquaresOutput<Self>> {
// Minimize |b - Ax|_2
//
// where
// A : (m, n)
// b : (max(m, n), nrhs) // `b` has to store `x` on exit
// x : (n, nrhs)
let (m, n) = a_layout.size();
if (m as usize) > b.len()
|| (n as usize) > b.len()
|| a_layout.lapacke_layout() != b_layout.lapacke_layout()
{
return Err(Error::InvalidShape);
}
let k = ::std::cmp::min(m, n);
let nrhs = b_layout.size().1;
let (m_, nrhs) = b_layout.size();
let k = m.min(n);
assert!(m_ >= m);

// Transpose if a is C-contiguous
let mut a_t = None;
let a_layout = match a_layout {
MatrixLayout::C { .. } => {
a_t = Some(vec![Self::zero(); a.len()]);
transpose(a_layout, a, a_t.as_mut().unwrap())
}
MatrixLayout::F { .. } => a_layout,
};

// Transpose if b is C-contiguous
let mut b_t = None;
let b_layout = match b_layout {
MatrixLayout::C { .. } => {
b_t = Some(vec![Self::zero(); b.len()]);
transpose(b_layout, b, b_t.as_mut().unwrap())
}
MatrixLayout::F { .. } => b_layout,
};

let rcond: Self::Real = -1.;
let mut singular_values: Vec<Self::Real> = vec![Self::Real::zero(); k as usize];
let mut rank: i32 = 0;

$gelsd(
a_layout.lapacke_layout(),
m,
n,
nrhs,
a,
a_layout.lda(),
b,
b_layout.lda(),
&mut singular_values,
rcond,
&mut rank,
)
.as_lapack_result()?;
// eval work size
let mut info = 0;
let mut work_size = [Self::zero()];
let mut iwork_size = [0];
$(
let mut $rwork = [Self::Real::zero()];
)*
unsafe {
$gelsd(
m,
n,
nrhs,
a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a),
a_layout.lda(),
b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b),
b_layout.lda(),
&mut singular_values,
rcond,
&mut rank,
&mut work_size,
-1,
$(&mut $rwork,)*
&mut iwork_size,
&mut info,
)
};
info.as_lapack_result()?;

// calc
let lwork = work_size[0].to_usize().unwrap();
let mut work = vec![Self::zero(); lwork];
let liwork = iwork_size[0].to_usize().unwrap();
let mut iwork = vec![0; liwork];
$(
let lrwork = $rwork[0].to_usize().unwrap();
let mut $rwork = vec![Self::Real::zero(); lrwork];
)*
unsafe {
$gelsd(
m,
n,
nrhs,
a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a),
a_layout.lda(),
b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b),
b_layout.lda(),
&mut singular_values,
rcond,
&mut rank,
&mut work,
lwork as i32,
$(&mut $rwork,)*
&mut iwork,
&mut info,
);
}
info.as_lapack_result()?;

// Skip a_t -> a transpose because A has been destroyed
// Re-transpose b
if let Some(b_t) = b_t {
transpose(b_layout, &b_t, b);
}

Ok(LeastSquaresOutput {
singular_values,
rank,
Expand All @@ -113,7 +161,7 @@ macro_rules! impl_least_squares {
};
}

impl_least_squares!(f64, lapacke::dgelsd);
impl_least_squares!(f32, lapacke::sgelsd);
impl_least_squares!(c64, lapacke::zgelsd);
impl_least_squares!(c32, lapacke::cgelsd);
impl_least_squares!(@real, f64, lapack::dgelsd);
impl_least_squares!(@real, f32, lapack::sgelsd);
impl_least_squares!(@complex, c64, lapack::zgelsd);
impl_least_squares!(@complex, c32, lapack::cgelsd);
Loading