Skip to content

Commit 4f0ef55

Browse files
authored
Merge pull request #155 from rust-ndarray/arnoldi2
Arnoldi iterator
2 parents 7bb4dd4 + 7f7b806 commit 4f0ef55

File tree

5 files changed

+400
-98
lines changed

5 files changed

+400
-98
lines changed

src/krylov/arnoldi.rs

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
//! Arnoldi iteration
2+
3+
use super::*;
4+
use crate::norm::Norm;
5+
use num_traits::One;
6+
use std::iter::*;
7+
8+
/// Execute Arnoldi iteration as Rust iterator
9+
///
10+
/// - [Arnoldi iteration - Wikipedia](https://en.wikipedia.org/wiki/Arnoldi_iteration)
11+
///
12+
pub struct Arnoldi<A, S, F, Ortho>
13+
where
14+
A: Scalar,
15+
S: DataMut<Elem = A>,
16+
F: Fn(&mut ArrayBase<S, Ix1>),
17+
Ortho: Orthogonalizer<Elem = A>,
18+
{
19+
a: F,
20+
/// Next vector (normalized `|v|=1`)
21+
v: ArrayBase<S, Ix1>,
22+
/// Orthogonalizer
23+
ortho: Ortho,
24+
/// Coefficients to be composed into H-matrix
25+
h: Vec<Array1<A>>,
26+
}
27+
28+
impl<A, S, F, Ortho> Arnoldi<A, S, F, Ortho>
29+
where
30+
A: Scalar + Lapack,
31+
S: DataMut<Elem = A>,
32+
F: Fn(&mut ArrayBase<S, Ix1>),
33+
Ortho: Orthogonalizer<Elem = A>,
34+
{
35+
/// Create an Arnoldi iterator from any linear operator `a`
36+
pub fn new(a: F, mut v: ArrayBase<S, Ix1>, mut ortho: Ortho) -> Self {
37+
assert_eq!(ortho.len(), 0);
38+
assert!(ortho.tolerance() < One::one());
39+
// normalize before append because |v| may be smaller than ortho.tolerance()
40+
let norm = v.norm_l2();
41+
azip!(mut v(&mut v) in { *v = v.div_real(norm) });
42+
ortho.append(v.view());
43+
Arnoldi {
44+
a,
45+
v,
46+
ortho,
47+
h: Vec::new(),
48+
}
49+
}
50+
51+
/// Dimension of Krylov subspace
52+
pub fn dim(&self) -> usize {
53+
self.ortho.len()
54+
}
55+
56+
/// Iterate until convergent
57+
pub fn complete(mut self) -> (Q<A>, H<A>) {
58+
for _ in &mut self {} // execute iteration until convergent
59+
let q = self.ortho.get_q();
60+
let n = self.h.len();
61+
let mut h = Array2::zeros((n, n).f());
62+
for (i, hc) in self.h.iter().enumerate() {
63+
let m = std::cmp::min(n, i + 2);
64+
for j in 0..m {
65+
h[(j, i)] = hc[j];
66+
}
67+
}
68+
(q, h)
69+
}
70+
}
71+
72+
impl<A, S, F, Ortho> Iterator for Arnoldi<A, S, F, Ortho>
73+
where
74+
A: Scalar + Lapack,
75+
S: DataMut<Elem = A>,
76+
F: Fn(&mut ArrayBase<S, Ix1>),
77+
Ortho: Orthogonalizer<Elem = A>,
78+
{
79+
type Item = Array1<A>;
80+
81+
fn next(&mut self) -> Option<Self::Item> {
82+
(self.a)(&mut self.v);
83+
let result = self.ortho.div_append(&mut self.v);
84+
let norm = self.v.norm_l2();
85+
azip!(mut v(&mut self.v) in { *v = v.div_real(norm) });
86+
match result {
87+
AppendResult::Added(coef) => {
88+
self.h.push(coef.clone());
89+
Some(coef)
90+
}
91+
AppendResult::Dependent(coef) => {
92+
self.h.push(coef);
93+
None
94+
}
95+
}
96+
}
97+
}
98+
99+
/// Interpret a matrix as a linear operator
100+
pub fn mul_mat<A, S1, S2>(a: ArrayBase<S1, Ix2>) -> impl Fn(&mut ArrayBase<S2, Ix1>)
101+
where
102+
A: Scalar,
103+
S1: Data<Elem = A>,
104+
S2: DataMut<Elem = A>,
105+
{
106+
let (n, m) = a.dim();
107+
assert_eq!(n, m, "Input matrix must be square");
108+
move |x| {
109+
assert_eq!(m, x.len(), "Input matrix and vector sizes mismatch");
110+
let ax = a.dot(x);
111+
azip!(mut x(x), ax in { *x = ax });
112+
}
113+
}
114+
115+
/// Utility to execute Arnoldi iteration with Householder reflection
116+
pub fn arnoldi_householder<A, S1, S2>(a: ArrayBase<S1, Ix2>, v: ArrayBase<S2, Ix1>, tol: A::Real) -> (Q<A>, H<A>)
117+
where
118+
A: Scalar + Lapack,
119+
S1: Data<Elem = A>,
120+
S2: DataMut<Elem = A>,
121+
{
122+
let householder = Householder::new(v.len(), tol);
123+
Arnoldi::new(mul_mat(a), v, householder).complete()
124+
}
125+
126+
/// Utility to execute Arnoldi iteration with modified Gram-Schmit orthogonalizer
127+
pub fn arnoldi_mgs<A, S1, S2>(a: ArrayBase<S1, Ix2>, v: ArrayBase<S2, Ix1>, tol: A::Real) -> (Q<A>, H<A>)
128+
where
129+
A: Scalar + Lapack,
130+
S1: Data<Elem = A>,
131+
S2: DataMut<Elem = A>,
132+
{
133+
let mgs = MGS::new(v.len(), tol);
134+
Arnoldi::new(mul_mat(a), v, mgs).complete()
135+
}

src/krylov/householder.rs

Lines changed: 75 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@ use crate::{inner::*, norm::*};
88
use num_traits::One;
99

1010
/// Calc a reflactor `w` from a vector `x`
11-
pub fn calc_reflector<A, S>(x: &mut ArrayBase<S, Ix1>) -> A
11+
pub fn calc_reflector<A, S>(x: &mut ArrayBase<S, Ix1>)
1212
where
1313
A: Scalar + Lapack,
1414
S: DataMut<Elem = A>,
1515
{
16+
assert!(x.len() > 0);
1617
let norm = x.norm_l2();
1718
let alpha = -x[0].mul_real(norm / x[0].abs());
1819
x[0] -= alpha;
1920
let inv_rev_norm = A::Real::one() / x.norm_l2();
2021
azip!(mut a(x) in { *a = a.mul_real(inv_rev_norm)});
21-
alpha
2222
}
2323

2424
/// Take a reflection `P = I - 2ww^T`
@@ -50,12 +50,19 @@ pub struct Householder<A: Scalar> {
5050
///
5151
/// The coefficient is copied into another array, and this does not contain
5252
v: Vec<Array1<A>>,
53+
54+
/// Tolerance
55+
tol: A::Real,
5356
}
5457

5558
impl<A: Scalar + Lapack> Householder<A> {
5659
/// Create a new orthogonalizer
57-
pub fn new(dim: usize) -> Self {
58-
Householder { dim, v: Vec::new() }
60+
pub fn new(dim: usize, tol: A::Real) -> Self {
61+
Householder {
62+
dim,
63+
v: Vec::new(),
64+
tol,
65+
}
5966
}
6067

6168
/// Take a Reflection `P = I - 2ww^T`
@@ -92,12 +99,32 @@ impl<A: Scalar + Lapack> Householder<A> {
9299
}
93100
}
94101

95-
fn eval_residual<S>(&self, a: &ArrayBase<S, Ix1>) -> A::Real
102+
/// Compose coefficients array using reflected vector
103+
fn compose_coefficients<S>(&self, a: &ArrayBase<S, Ix1>) -> Coefficients<A>
96104
where
97105
S: Data<Elem = A>,
98106
{
99-
let l = self.v.len();
100-
a.slice(s![l..]).norm_l2()
107+
let k = self.len();
108+
let res = a.slice(s![k..]).norm_l2();
109+
let mut c = Array1::zeros(k + 1);
110+
azip!(mut c(c.slice_mut(s![..k])), a(a.slice(s![..k])) in { *c = a });
111+
if k < a.len() {
112+
let ak = a[k];
113+
c[k] = -ak.mul_real(res / ak.abs());
114+
} else {
115+
c[k] = A::from_real(res);
116+
}
117+
c
118+
}
119+
120+
/// Construct the residual vector from reflected vector
121+
fn construct_residual<S>(&self, a: &mut ArrayBase<S, Ix1>)
122+
where
123+
S: DataMut<Elem = A>,
124+
{
125+
let k = self.len();
126+
azip!(mut a( a.slice_mut(s![..k])) in { *a = A::zero() });
127+
self.backward_reflection(a);
101128
}
102129
}
103130

@@ -112,45 +139,61 @@ impl<A: Scalar + Lapack> Orthogonalizer for Householder<A> {
112139
self.v.len()
113140
}
114141

142+
fn tolerance(&self) -> A::Real {
143+
self.tol
144+
}
145+
146+
fn decompose<S>(&self, a: &mut ArrayBase<S, Ix1>) -> Array1<A>
147+
where
148+
S: DataMut<Elem = A>,
149+
{
150+
self.forward_reflection(a);
151+
let coef = self.compose_coefficients(a);
152+
self.construct_residual(a);
153+
coef
154+
}
155+
115156
fn coeff<S>(&self, a: ArrayBase<S, Ix1>) -> Array1<A>
116157
where
117158
S: Data<Elem = A>,
118159
{
119160
let mut a = a.into_owned();
120161
self.forward_reflection(&mut a);
121-
let res = self.eval_residual(&a);
122-
let k = self.len();
123-
let mut c = Array1::zeros(k + 1);
124-
azip!(mut c(c.slice_mut(s![..k])), a(a.slice(s![..k])) in { *c = a });
125-
c[k] = A::from_real(res);
126-
c
162+
self.compose_coefficients(&a)
127163
}
128164

129-
fn append<S>(&mut self, mut a: ArrayBase<S, Ix1>, rtol: A::Real) -> Result<Array1<A>, Array1<A>>
165+
fn div_append<S>(&mut self, a: &mut ArrayBase<S, Ix1>) -> AppendResult<A>
130166
where
131167
S: DataMut<Elem = A>,
132168
{
133169
assert_eq!(a.len(), self.dim);
134170
let k = self.len();
135-
136-
self.forward_reflection(&mut a);
137-
let mut coef = Array::zeros(k + 1);
138-
for i in 0..k {
139-
coef[i] = a[i];
140-
}
141-
if self.is_full() {
142-
return Err(coef); // coef[k] must be zero in this case
171+
self.forward_reflection(a);
172+
let coef = self.compose_coefficients(a);
173+
if coef[k].abs() < self.tol {
174+
return AppendResult::Dependent(coef);
143175
}
176+
calc_reflector(&mut a.slice_mut(s![k..]));
177+
self.v.push(a.to_owned());
178+
self.construct_residual(a);
179+
AppendResult::Added(coef)
180+
}
144181

145-
let alpha = calc_reflector(&mut a.slice_mut(s![k..]));
146-
coef[k] = alpha;
147-
148-
if alpha.abs() < rtol {
149-
// linearly dependent
150-
return Err(coef);
182+
fn append<S>(&mut self, a: ArrayBase<S, Ix1>) -> AppendResult<A>
183+
where
184+
S: Data<Elem = A>,
185+
{
186+
assert_eq!(a.len(), self.dim);
187+
let mut a = a.into_owned();
188+
let k = self.len();
189+
self.forward_reflection(&mut a);
190+
let coef = self.compose_coefficients(&a);
191+
if coef[k].abs() < self.tol {
192+
return AppendResult::Dependent(coef);
151193
}
152-
self.v.push(a.into_owned());
153-
Ok(coef)
194+
calc_reflector(&mut a.slice_mut(s![k..]));
195+
self.v.push(a.to_owned());
196+
AppendResult::Added(coef)
154197
}
155198

156199
fn get_q(&self) -> Q<A> {
@@ -175,8 +218,8 @@ where
175218
A: Scalar + Lapack,
176219
S: Data<Elem = A>,
177220
{
178-
let h = Householder::new(dim);
179-
qr(iter, h, rtol, strategy)
221+
let h = Householder::new(dim, rtol);
222+
qr(iter, h, strategy)
180223
}
181224

182225
#[cfg(test)]

0 commit comments

Comments
 (0)