Skip to content

Commit e289247

Browse files
committed
Merge branch 'fix-eig' #298
2 parents 685fc3c + 419ce95 commit e289247

File tree

2 files changed

+78
-38
lines changed

2 files changed

+78
-38
lines changed

lax/src/eig.rs

Lines changed: 47 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,16 @@ macro_rules! impl_eig_complex {
2323
mut a: &mut [Self],
2424
) -> Result<(Vec<Self::Complex>, Vec<Self::Complex>)> {
2525
let (n, _) = l.size();
26-
// Because LAPACK assumes F-continious array, C-continious array should be taken Hermitian conjugate.
27-
// However, we utilize a fact that left eigenvector of A^H corresponds to the right eigenvector of A
26+
// LAPACK assumes a column-major input. A row-major input can
27+
// be interpreted as the transpose of a column-major input. So,
28+
// for row-major inputs, we we want to solve the following,
29+
// given the column-major input `A`:
30+
//
31+
// A^T V = V Λ ⟺ V^T A = Λ V^T ⟺ conj(V)^H A = Λ conj(V)^H
32+
//
33+
// So, in this case, the right eigenvectors are the conjugates
34+
// of the left eigenvectors computed with `A`, and the
35+
// eigenvalues are the eigenvalues computed with `A`.
2836
let (jobvl, jobvr) = if calc_v {
2937
match l {
3038
MatrixLayout::C { .. } => (b'V', b'N'),
@@ -118,8 +126,22 @@ macro_rules! impl_eig_real {
118126
mut a: &mut [Self],
119127
) -> Result<(Vec<Self::Complex>, Vec<Self::Complex>)> {
120128
let (n, _) = l.size();
121-
// Because LAPACK assumes F-continious array, C-continious array should be taken Hermitian conjugate.
122-
// However, we utilize a fact that left eigenvector of A^H corresponds to the right eigenvector of A
129+
// LAPACK assumes a column-major input. A row-major input can
130+
// be interpreted as the transpose of a column-major input. So,
131+
// for row-major inputs, we we want to solve the following,
132+
// given the column-major input `A`:
133+
//
134+
// A^T V = V Λ ⟺ V^T A = Λ V^T ⟺ conj(V)^H A = Λ conj(V)^H
135+
//
136+
// So, in this case, the right eigenvectors are the conjugates
137+
// of the left eigenvectors computed with `A`, and the
138+
// eigenvalues are the eigenvalues computed with `A`.
139+
//
140+
// We could conjugate the eigenvalues instead of the
141+
// eigenvectors, but we have to reconstruct the eigenvectors
142+
// into new matrices anyway, and by not modifying the
143+
// eigenvalues, we preserve the nice ordering specified by
144+
// `sgeev`/`dgeev`.
123145
let (jobvl, jobvr) = if calc_v {
124146
match l {
125147
MatrixLayout::C { .. } => (b'V', b'N'),
@@ -211,40 +233,34 @@ macro_rules! impl_eig_real {
211233
// - v(j) = VR(:,j) + i*VR(:,j+1)
212234
// - v(j+1) = VR(:,j) - i*VR(:,j+1).
213235
//
214-
// ```
215-
// j -> <----pair----> <----pair---->
216-
// [ ... (real), (imag), (imag), (imag), (imag), ... ] : eigs
217-
// ^ ^ ^ ^ ^
218-
// false false true false true : is_conjugate_pair
219-
// ```
236+
// In the C-layout case, we need the conjugates of the left
237+
// eigenvectors, so the signs should be reversed.
238+
220239
let n = n as usize;
221240
let v = vr.or(vl).unwrap();
222241
let mut eigvecs = unsafe { vec_uninit(n * n) };
223-
let mut is_conjugate_pair = false; // flag for check `j` is complex conjugate
224-
for j in 0..n {
225-
if eig_im[j] == 0.0 {
226-
// j-th eigenvalue is real
227-
for i in 0..n {
228-
eigvecs[i + j * n] = Self::complex(v[i + j * n], 0.0);
242+
let mut col = 0;
243+
while col < n {
244+
if eig_im[col] == 0. {
245+
// The corresponding eigenvalue is real.
246+
for row in 0..n {
247+
let re = v[row + col * n];
248+
eigvecs[row + col * n] = Self::complex(re, 0.);
229249
}
250+
col += 1;
230251
} else {
231-
// j-th eigenvalue is complex
232-
// complex conjugated pair can be `j-1` or `j+1`
233-
if is_conjugate_pair {
234-
let j_pair = j - 1;
235-
assert!(j_pair < n);
236-
for i in 0..n {
237-
eigvecs[i + j * n] = Self::complex(v[i + j_pair * n], v[i + j * n]);
238-
}
239-
} else {
240-
let j_pair = j + 1;
241-
assert!(j_pair < n);
242-
for i in 0..n {
243-
eigvecs[i + j * n] =
244-
Self::complex(v[i + j * n], -v[i + j_pair * n]);
252+
// This is a complex conjugate pair.
253+
assert!(col + 1 < n);
254+
for row in 0..n {
255+
let re = v[row + col * n];
256+
let mut im = v[row + (col + 1) * n];
257+
if jobvl == b'V' {
258+
im = -im;
245259
}
260+
eigvecs[row + col * n] = Self::complex(re, im);
261+
eigvecs[row + (col + 1) * n] = Self::complex(re, -im);
246262
}
247-
is_conjugate_pair = !is_conjugate_pair;
263+
col += 2;
248264
}
249265
}
250266

ndarray-linalg/tests/eig.rs

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,19 @@
11
use ndarray::*;
22
use ndarray_linalg::*;
33

4+
fn sorted_eigvals<T: Scalar>(eigvals: ArrayView1<'_, T>) -> Array1<T> {
5+
let mut indices: Vec<usize> = (0..eigvals.len()).collect();
6+
indices.sort_by(|&ind1, &ind2| {
7+
let e1 = eigvals[ind1];
8+
let e2 = eigvals[ind2];
9+
e1.re()
10+
.partial_cmp(&e2.re())
11+
.unwrap()
12+
.then(e1.im().partial_cmp(&e2.im()).unwrap())
13+
});
14+
indices.iter().map(|&ind| eigvals[ind]).collect()
15+
}
16+
417
// Test Av_i = e_i v_i for i = 0..n
518
fn test_eig<T: Scalar>(
619
a: ArrayView2<'_, T>,
@@ -90,7 +103,10 @@ fn test_matrix_real<T: Scalar>() -> Array2<T::Real> {
90103
}
91104

92105
fn test_matrix_real_t<T: Scalar>() -> Array2<T::Real> {
93-
test_matrix_real::<T>().t().permuted_axes([1, 0]).to_owned()
106+
let orig = test_matrix_real::<T>();
107+
let mut out = Array2::zeros(orig.raw_dim().f());
108+
out.assign(&orig);
109+
out
94110
}
95111

96112
fn answer_eig_real<T: Scalar>() -> Array1<T::Complex> {
@@ -157,10 +173,10 @@ fn test_matrix_complex<T: Scalar>() -> Array2<T::Complex> {
157173
}
158174

159175
fn test_matrix_complex_t<T: Scalar>() -> Array2<T::Complex> {
160-
test_matrix_complex::<T>()
161-
.t()
162-
.permuted_axes([1, 0])
163-
.to_owned()
176+
let orig = test_matrix_complex::<T>();
177+
let mut out = Array2::zeros(orig.raw_dim().f());
178+
out.assign(&orig);
179+
out
164180
}
165181

166182
fn answer_eig_complex<T: Scalar>() -> Array1<T::Complex> {
@@ -218,9 +234,17 @@ macro_rules! impl_test_real {
218234
fn [<$real _eigvals_t>]() {
219235
let a = test_matrix_real_t::<$real>();
220236
let (e1, _vecs) = a.eig().unwrap();
237+
assert_close_l2!(
238+
&sorted_eigvals(e1.view()),
239+
&sorted_eigvals(answer_eig_real::<$real>().view()),
240+
1.0e-3
241+
);
221242
let e2 = a.eigvals().unwrap();
222-
assert_close_l2!(&e1, &answer_eig_real::<$real>(), 1.0e-3);
223-
assert_close_l2!(&e2, &answer_eig_real::<$real>(), 1.0e-3);
243+
assert_close_l2!(
244+
&sorted_eigvals(e2.view()),
245+
&sorted_eigvals(answer_eig_real::<$real>().view()),
246+
1.0e-3
247+
);
224248
}
225249

226250
#[test]

0 commit comments

Comments
 (0)