Skip to content

Commit cd92891

Browse files
authored
Merge pull request #297 from jturner314/fix-factorize_into_inv
Fix Inverse for LUFactorized
2 parents 261fdd6 + 6766c31 commit cd92891

File tree

3 files changed

+107
-13
lines changed

3 files changed

+107
-13
lines changed

lax/src/solve.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ macro_rules! impl_solve {
4242

4343
fn inv(l: MatrixLayout, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
4444
let (n, _) = l.size();
45+
if n == 0 {
46+
// Do nothing for empty matrices.
47+
return Ok(());
48+
}
4549

4650
// calc work size
4751
let mut info = 0;

ndarray-linalg/src/solve.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,9 @@ pub trait Solve<A: Scalar> {
199199
pub struct LUFactorized<S: Data + RawDataClone> {
200200
/// The factors `L` and `U`; the unit diagonal elements of `L` are not
201201
/// stored.
202-
pub a: ArrayBase<S, Ix2>,
202+
a: ArrayBase<S, Ix2>,
203203
/// The pivot indices that define the permutation matrix `P`.
204-
pub ipiv: Pivot,
204+
ipiv: Pivot,
205205
}
206206

207207
impl<A, S> Solve<A> for LUFactorized<S>
@@ -387,8 +387,15 @@ where
387387
type Output = Array2<A>;
388388

389389
fn inv(&self) -> Result<Array2<A>> {
390+
// Preserve the existing layout. This is required to obtain the correct
391+
// result, because the result of `A::inv` is layout-dependent.
392+
let a = if self.a.is_standard_layout() {
393+
replicate(&self.a)
394+
} else {
395+
replicate(&self.a.t()).reversed_axes()
396+
};
390397
let f = LUFactorized {
391-
a: replicate(&self.a),
398+
a,
392399
ipiv: self.ipiv.clone(),
393400
};
394401
f.inv_into()

ndarray-linalg/tests/inv.rs

Lines changed: 93 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,103 @@
11
use ndarray::*;
22
use ndarray_linalg::*;
33

4+
fn test_inv_random<A>(n: usize, set_f: bool, rtol: A::Real)
5+
where
6+
A: Scalar + Lapack,
7+
{
8+
let a: Array2<A> = random([n; 2].set_f(set_f));
9+
let identity = Array2::eye(n);
10+
assert_close_l2!(&a.inv().unwrap().dot(&a), &identity, rtol);
11+
assert_close_l2!(
12+
&a.factorize().unwrap().inv().unwrap().dot(&a),
13+
&identity,
14+
rtol
15+
);
16+
assert_close_l2!(
17+
&a.clone().factorize_into().unwrap().inv().unwrap().dot(&a),
18+
&identity,
19+
rtol
20+
);
21+
}
22+
23+
fn test_inv_into_random<A>(n: usize, set_f: bool, rtol: A::Real)
24+
where
25+
A: Scalar + Lapack,
26+
{
27+
let a: Array2<A> = random([n; 2].set_f(set_f));
28+
let identity = Array2::eye(n);
29+
assert_close_l2!(&a.clone().inv_into().unwrap().dot(&a), &identity, rtol);
30+
assert_close_l2!(
31+
&a.factorize().unwrap().inv_into().unwrap().dot(&a),
32+
&identity,
33+
rtol
34+
);
35+
assert_close_l2!(
36+
&a.clone()
37+
.factorize_into()
38+
.unwrap()
39+
.inv_into()
40+
.unwrap()
41+
.dot(&a),
42+
&identity,
43+
rtol
44+
);
45+
}
46+
47+
#[test]
48+
fn inv_empty() {
49+
test_inv_random::<f32>(0, false, 0.);
50+
test_inv_random::<f64>(0, false, 0.);
51+
test_inv_random::<c32>(0, false, 0.);
52+
test_inv_random::<c64>(0, false, 0.);
53+
}
54+
55+
#[test]
56+
fn inv_random_float() {
57+
for n in 1..=8 {
58+
for &set_f in &[false, true] {
59+
test_inv_random::<f32>(n, set_f, 1e-3);
60+
test_inv_random::<f64>(n, set_f, 1e-9);
61+
}
62+
}
63+
}
64+
65+
#[test]
66+
fn inv_random_complex() {
67+
for n in 1..=8 {
68+
for &set_f in &[false, true] {
69+
test_inv_random::<c32>(n, set_f, 1e-3);
70+
test_inv_random::<c64>(n, set_f, 1e-9);
71+
}
72+
}
73+
}
74+
75+
#[test]
76+
fn inv_into_empty() {
77+
test_inv_into_random::<f32>(0, false, 0.);
78+
test_inv_into_random::<f64>(0, false, 0.);
79+
test_inv_into_random::<c32>(0, false, 0.);
80+
test_inv_into_random::<c64>(0, false, 0.);
81+
}
82+
483
#[test]
5-
fn inv_random() {
6-
let a: Array2<f64> = random((3, 3));
7-
let ai: Array2<_> = (&a).inv().unwrap();
8-
let id = Array::eye(3);
9-
assert_close_l2!(&ai.dot(&a), &id, 1e-7);
84+
fn inv_into_random_float() {
85+
for n in 1..=8 {
86+
for &set_f in &[false, true] {
87+
test_inv_into_random::<f32>(n, set_f, 1e-3);
88+
test_inv_into_random::<f64>(n, set_f, 1e-9);
89+
}
90+
}
1091
}
1192

1293
#[test]
13-
fn inv_random_t() {
14-
let a: Array2<f64> = random((3, 3).f());
15-
let ai: Array2<_> = (&a).inv().unwrap();
16-
let id = Array::eye(3);
17-
assert_close_l2!(&ai.dot(&a), &id, 1e-7);
94+
fn inv_into_random_complex() {
95+
for n in 1..=8 {
96+
for &set_f in &[false, true] {
97+
test_inv_into_random::<c32>(n, set_f, 1e-3);
98+
test_inv_into_random::<c64>(n, set_f, 1e-9);
99+
}
100+
}
18101
}
19102

20103
#[test]

0 commit comments

Comments
 (0)