Skip to content

Commit bfecc99

Browse files
authored
Merge pull request #154 from rust-ndarray/householder
Householder reflection
2 parents 021fcee + 840bf26 commit bfecc99

File tree

4 files changed

+355
-38
lines changed

4 files changed

+355
-38
lines changed

src/krylov/householder.rs

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
//! Householder reflection
2+
//!
3+
//! - [Householder transformation - Wikipedia](https://en.wikipedia.org/wiki/Householder_transformation)
4+
//!
5+
6+
use super::*;
7+
use crate::{inner::*, norm::*};
8+
use num_traits::One;
9+
10+
/// Calc a reflactor `w` from a vector `x`
11+
pub fn calc_reflector<A, S>(x: &mut ArrayBase<S, Ix1>) -> A
12+
where
13+
A: Scalar + Lapack,
14+
S: DataMut<Elem = A>,
15+
{
16+
let norm = x.norm_l2();
17+
let alpha = -x[0].mul_real(norm / x[0].abs());
18+
x[0] -= alpha;
19+
let inv_rev_norm = A::Real::one() / x.norm_l2();
20+
azip!(mut a(x) in { *a = a.mul_real(inv_rev_norm)});
21+
alpha
22+
}
23+
24+
/// Take a reflection `P = I - 2ww^T`
25+
///
26+
/// Panic
27+
/// ------
28+
/// - if the size of `w` and `a` mismaches
29+
pub fn reflect<A, S1, S2>(w: &ArrayBase<S1, Ix1>, a: &mut ArrayBase<S2, Ix1>)
30+
where
31+
A: Scalar + Lapack,
32+
S1: Data<Elem = A>,
33+
S2: DataMut<Elem = A>,
34+
{
35+
assert_eq!(w.len(), a.len());
36+
let n = a.len();
37+
let c = A::from(2.0).unwrap() * w.inner(&a);
38+
for l in 0..n {
39+
a[l] -= c * w[l];
40+
}
41+
}
42+
43+
/// Iterative orthogonalizer using Householder reflection
44+
#[derive(Debug, Clone)]
45+
pub struct Householder<A: Scalar> {
46+
/// Dimension of orthogonalizer
47+
dim: usize,
48+
49+
/// Store Householder reflector.
50+
///
51+
/// The coefficient is copied into another array, and this does not contain
52+
v: Vec<Array1<A>>,
53+
}
54+
55+
impl<A: Scalar + Lapack> Householder<A> {
56+
/// Create a new orthogonalizer
57+
pub fn new(dim: usize) -> Self {
58+
Householder { dim, v: Vec::new() }
59+
}
60+
61+
/// Take a Reflection `P = I - 2ww^T`
62+
fn fundamental_reflection<S>(&self, k: usize, a: &mut ArrayBase<S, Ix1>)
63+
where
64+
S: DataMut<Elem = A>,
65+
{
66+
assert!(k < self.v.len());
67+
assert_eq!(a.len(), self.dim, "Input array size mismaches to the dimension");
68+
reflect(&self.v[k].slice(s![k..]), &mut a.slice_mut(s![k..]));
69+
}
70+
71+
/// Take forward reflection `P = P_l ... P_1`
72+
pub fn forward_reflection<S>(&self, a: &mut ArrayBase<S, Ix1>)
73+
where
74+
S: DataMut<Elem = A>,
75+
{
76+
assert!(a.len() == self.dim);
77+
let l = self.v.len();
78+
for k in 0..l {
79+
self.fundamental_reflection(k, a);
80+
}
81+
}
82+
83+
/// Take backward reflection `P = P_1 ... P_l`
84+
pub fn backward_reflection<S>(&self, a: &mut ArrayBase<S, Ix1>)
85+
where
86+
S: DataMut<Elem = A>,
87+
{
88+
assert!(a.len() == self.dim);
89+
let l = self.v.len();
90+
for k in (0..l).rev() {
91+
self.fundamental_reflection(k, a);
92+
}
93+
}
94+
95+
fn eval_residual<S>(&self, a: &ArrayBase<S, Ix1>) -> A::Real
96+
where
97+
S: Data<Elem = A>,
98+
{
99+
let l = self.v.len();
100+
a.slice(s![l..]).norm_l2()
101+
}
102+
}
103+
104+
impl<A: Scalar + Lapack> Orthogonalizer for Householder<A> {
105+
type Elem = A;
106+
107+
fn dim(&self) -> usize {
108+
self.dim
109+
}
110+
111+
fn len(&self) -> usize {
112+
self.v.len()
113+
}
114+
115+
fn coeff<S>(&self, a: ArrayBase<S, Ix1>) -> Array1<A>
116+
where
117+
S: Data<Elem = A>,
118+
{
119+
let mut a = a.into_owned();
120+
self.forward_reflection(&mut a);
121+
let res = self.eval_residual(&a);
122+
let k = self.len();
123+
let mut c = Array1::zeros(k + 1);
124+
azip!(mut c(c.slice_mut(s![..k])), a(a.slice(s![..k])) in { *c = a });
125+
c[k] = A::from_real(res);
126+
c
127+
}
128+
129+
fn append<S>(&mut self, mut a: ArrayBase<S, Ix1>, rtol: A::Real) -> Result<Array1<A>, Array1<A>>
130+
where
131+
S: DataMut<Elem = A>,
132+
{
133+
assert_eq!(a.len(), self.dim);
134+
let k = self.len();
135+
136+
self.forward_reflection(&mut a);
137+
let mut coef = Array::zeros(k + 1);
138+
for i in 0..k {
139+
coef[i] = a[i];
140+
}
141+
if self.is_full() {
142+
return Err(coef); // coef[k] must be zero in this case
143+
}
144+
145+
let alpha = calc_reflector(&mut a.slice_mut(s![k..]));
146+
coef[k] = alpha;
147+
148+
if alpha.abs() < rtol {
149+
// linearly dependent
150+
return Err(coef);
151+
}
152+
self.v.push(a.into_owned());
153+
Ok(coef)
154+
}
155+
156+
fn get_q(&self) -> Q<A> {
157+
assert!(self.len() > 0);
158+
let mut a = Array::zeros((self.dim(), self.len()));
159+
for (i, mut col) in a.axis_iter_mut(Axis(1)).enumerate() {
160+
col[i] = A::one();
161+
self.backward_reflection(&mut col);
162+
}
163+
a
164+
}
165+
}
166+
167+
/// Online QR decomposition using Householder reflection
168+
pub fn householder<A, S>(
169+
iter: impl Iterator<Item = ArrayBase<S, Ix1>>,
170+
dim: usize,
171+
rtol: A::Real,
172+
strategy: Strategy,
173+
) -> (Q<A>, R<A>)
174+
where
175+
A: Scalar + Lapack,
176+
S: Data<Elem = A>,
177+
{
178+
let h = Householder::new(dim);
179+
qr(iter, h, rtol, strategy)
180+
}
181+
182+
#[cfg(test)]
183+
mod tests {
184+
use super::*;
185+
use crate::assert::*;
186+
use num_traits::Zero;
187+
188+
#[test]
189+
fn check_reflector() {
190+
let mut a = array![c64::new(1.0, 1.0), c64::new(1.0, 0.0), c64::new(0.0, 1.0)];
191+
let mut w = a.clone();
192+
calc_reflector(&mut w);
193+
reflect(&w, &mut a);
194+
close_l2(
195+
&a,
196+
&array![-c64::new(2.0.sqrt(), 2.0.sqrt()), c64::zero(), c64::zero()],
197+
1e-9,
198+
);
199+
}
200+
}

src/krylov/mgs.rs

Lines changed: 28 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,6 @@ use super::*;
44
use crate::{generate::*, inner::*, norm::Norm};
55

66
/// Iterative orthogonalizer using modified Gram-Schmit procedure
7-
///
8-
/// ```rust
9-
/// # use ndarray::*;
10-
/// # use ndarray_linalg::{krylov::*, *};
11-
/// let mut mgs = MGS::new(3);
12-
/// let coef = mgs.append(array![0.0, 1.0, 0.0], 1e-9).unwrap();
13-
/// close_l2(&coef, &array![1.0], 1e-9);
14-
///
15-
/// let coef = mgs.append(array![1.0, 1.0, 0.0], 1e-9).unwrap();
16-
/// close_l2(&coef, &array![1.0, 1.0], 1e-9);
17-
///
18-
/// // Fail if the vector is linearly dependent
19-
/// assert!(mgs.append(array![1.0, 2.0, 0.0], 1e-9).is_err());
20-
///
21-
/// // You can get coefficients of dependent vector
22-
/// if let Err(coef) = mgs.append(array![1.0, 2.0, 0.0], 1e-9) {
23-
/// close_l2(&coef, &array![2.0, 1.0, 0.0], 1e-9);
24-
/// }
25-
/// ```
267
#[derive(Debug, Clone)]
278
pub struct MGS<A> {
289
/// Dimension of base space
@@ -31,14 +12,36 @@ pub struct MGS<A> {
3112
q: Vec<Array1<A>>,
3213
}
3314

34-
impl<A: Scalar> MGS<A> {
15+
impl<A: Scalar + Lapack> MGS<A> {
3516
/// Create an empty orthogonalizer
3617
pub fn new(dimension: usize) -> Self {
3718
Self {
3819
dimension,
3920
q: Vec::new(),
4021
}
4122
}
23+
24+
/// Orthogonalize given vector against to the current basis
25+
///
26+
/// - Returned array is coefficients and residual norm
27+
/// - `a` will contain the residual vector
28+
///
29+
pub fn orthogonalize<S>(&self, a: &mut ArrayBase<S, Ix1>) -> Array1<A>
30+
where
31+
S: DataMut<Elem = A>,
32+
{
33+
assert_eq!(a.len(), self.dim());
34+
let mut coef = Array1::zeros(self.len() + 1);
35+
for i in 0..self.len() {
36+
let q = &self.q[i];
37+
let c = q.inner(&a);
38+
azip!(mut a (&mut *a), q (q) in { *a = *a - c * q } );
39+
coef[i] = c;
40+
}
41+
let nrm = a.norm_l2();
42+
coef[self.len()] = A::from_real(nrm);
43+
coef
44+
}
4245
}
4346

4447
impl<A: Scalar + Lapack> Orthogonalizer for MGS<A> {
@@ -52,22 +55,13 @@ impl<A: Scalar + Lapack> Orthogonalizer for MGS<A> {
5255
self.q.len()
5356
}
5457

55-
fn orthogonalize<S>(&self, a: &mut ArrayBase<S, Ix1>) -> Array1<A>
58+
fn coeff<S>(&self, a: ArrayBase<S, Ix1>) -> Array1<A>
5659
where
5760
A: Lapack,
58-
S: DataMut<Elem = A>,
61+
S: Data<Elem = A>,
5962
{
60-
assert_eq!(a.len(), self.dim());
61-
let mut coef = Array1::zeros(self.len() + 1);
62-
for i in 0..self.len() {
63-
let q = &self.q[i];
64-
let c = q.inner(&a);
65-
azip!(mut a (&mut *a), q (q) in { *a = *a - c * q } );
66-
coef[i] = c;
67-
}
68-
let nrm = a.norm_l2();
69-
coef[self.len()] = A::from_real(nrm);
70-
coef
63+
let mut a = a.into_owned();
64+
self.orthogonalize(&mut a)
7165
}
7266

7367
fn append<S>(&mut self, a: ArrayBase<S, Ix1>, rtol: A::Real) -> Result<Array1<A>, Array1<A>>
@@ -92,7 +86,7 @@ impl<A: Scalar + Lapack> Orthogonalizer for MGS<A> {
9286
}
9387
}
9488

95-
/// Online QR decomposition of vectors using modified Gram-Schmit algorithm
89+
/// Online QR decomposition using modified Gram-Schmit algorithm
9690
pub fn mgs<A, S>(
9791
iter: impl Iterator<Item = ArrayBase<S, Ix1>>,
9892
dim: usize,

src/krylov/mod.rs

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
use crate::types::*;
44
use ndarray::*;
55

6-
mod mgs;
6+
pub mod householder;
7+
pub mod mgs;
78

9+
pub use householder::{householder, Householder};
810
pub use mgs::{mgs, MGS};
911

1012
/// Q-matrix
@@ -22,6 +24,28 @@ pub type Q<A> = Array2<A>;
2224
pub type R<A> = Array2<A>;
2325

2426
/// Trait for creating orthogonal basis from iterator of arrays
27+
///
28+
/// Example
29+
/// -------
30+
///
31+
/// ```rust
32+
/// # use ndarray::*;
33+
/// # use ndarray_linalg::{krylov::*, *};
34+
/// let mut mgs = MGS::new(3);
35+
/// let coef = mgs.append(array![0.0, 1.0, 0.0], 1e-9).unwrap();
36+
/// close_l2(&coef, &array![1.0], 1e-9);
37+
///
38+
/// let coef = mgs.append(array![1.0, 1.0, 0.0], 1e-9).unwrap();
39+
/// close_l2(&coef, &array![1.0, 1.0], 1e-9);
40+
///
41+
/// // Fail if the vector is linearly dependent
42+
/// assert!(mgs.append(array![1.0, 2.0, 0.0], 1e-9).is_err());
43+
///
44+
/// // You can get coefficients of dependent vector
45+
/// if let Err(coef) = mgs.append(array![1.0, 2.0, 0.0], 1e-9) {
46+
/// close_l2(&coef, &array![2.0, 1.0, 0.0], 1e-9);
47+
/// }
48+
/// ```
2549
pub trait Orthogonalizer {
2650
type Elem: Scalar;
2751

@@ -40,15 +64,18 @@ pub trait Orthogonalizer {
4064
self.len() == 0
4165
}
4266

43-
/// Orthogonalize given vector using current basis
67+
/// Calculate the coefficient to the given basis and residual norm
68+
///
69+
/// - The length of the returned array must be `self.len() + 1`
70+
/// - Last component is the residual norm
4471
///
4572
/// Panic
4673
/// -------
4774
/// - if the size of the input array mismatches to the dimension
4875
///
49-
fn orthogonalize<S>(&self, a: &mut ArrayBase<S, Ix1>) -> Array1<Self::Elem>
76+
fn coeff<S>(&self, a: ArrayBase<S, Ix1>) -> Array1<Self::Elem>
5077
where
51-
S: DataMut<Elem = Self::Elem>;
78+
S: Data<Elem = Self::Elem>;
5279

5380
/// Add new vector if the residual is larger than relative tolerance
5481
///

0 commit comments

Comments
 (0)