Skip to content

Commit a21f738

Browse files
authored
Merge pull request #221 from rust-ndarray/lapack-merge-macros
Cleanup macro definitions
2 parents 8c1b069 + e3a7767 commit a21f738

File tree

2 files changed

+42
-180
lines changed

2 files changed

+42
-180
lines changed

lax/src/eigh.rs

Lines changed: 25 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,13 @@ pub trait Eigh_: Scalar {
2525
}
2626

2727
macro_rules! impl_eigh {
28-
($scalar:ty, $ev:path, $evg:path) => {
28+
(@real, $scalar:ty, $ev:path, $evg:path) => {
29+
impl_eigh!(@body, $scalar, $ev, $evg, );
30+
};
31+
(@complex, $scalar:ty, $ev:path, $evg:path) => {
32+
impl_eigh!(@body, $scalar, $ev, $evg, rwork);
33+
};
34+
(@body, $scalar:ty, $ev:path, $evg:path, $($rwork_ident:ident),*) => {
2935
impl Eigh_ for $scalar {
3036
fn eigh(
3137
calc_v: bool,
@@ -37,11 +43,14 @@ macro_rules! impl_eigh {
3743
let n = layout.len();
3844
let jobz = if calc_v { b'V' } else { b'N' };
3945
let mut eigs = vec![Self::Real::zero(); n as usize];
40-
let n = n as i32;
46+
47+
$(
48+
let mut $rwork_ident = vec![Self::Real::zero(); 3 * n as usize - 2];
49+
)*
4150

4251
// calc work size
4352
let mut info = 0;
44-
let mut work_size = [0.0];
53+
let mut work_size = [Self::zero()];
4554
unsafe {
4655
$ev(
4756
jobz,
@@ -52,6 +61,7 @@ macro_rules! impl_eigh {
5261
&mut eigs,
5362
&mut work_size,
5463
-1,
64+
$(&mut $rwork_ident,)*
5565
&mut info,
5666
);
5767
}
@@ -70,6 +80,7 @@ macro_rules! impl_eigh {
7080
&mut eigs,
7181
&mut work,
7282
lwork as i32,
83+
$(&mut $rwork_ident,)*
7384
&mut info,
7485
);
7586
}
@@ -88,11 +99,14 @@ macro_rules! impl_eigh {
8899
let n = layout.len();
89100
let jobz = if calc_v { b'V' } else { b'N' };
90101
let mut eigs = vec![Self::Real::zero(); n as usize];
91-
let n = n as i32;
102+
103+
$(
104+
let mut $rwork_ident = vec![Self::Real::zero(); 3 * n as usize - 2];
105+
)*
92106

93107
// calc work size
94108
let mut info = 0;
95-
let mut work_size = [0.0];
109+
let mut work_size = [Self::zero()];
96110
unsafe {
97111
$evg(
98112
&[1],
@@ -106,6 +120,7 @@ macro_rules! impl_eigh {
106120
&mut eigs,
107121
&mut work_size,
108122
-1,
123+
$(&mut $rwork_ident,)*
109124
&mut info,
110125
);
111126
}
@@ -127,6 +142,7 @@ macro_rules! impl_eigh {
127142
&mut eigs,
128143
&mut work,
129144
lwork as i32,
145+
$(&mut $rwork_ident,)*
130146
&mut info,
131147
);
132148
}
@@ -137,85 +153,7 @@ macro_rules! impl_eigh {
137153
};
138154
} // impl_eigh!
139155

140-
impl_eigh!(f64, lapack::dsyev, lapack::dsygv);
141-
impl_eigh!(f32, lapack::ssyev, lapack::ssygv);
142-
143-
// splitted for RWORK
144-
macro_rules! impl_eighc {
145-
($scalar:ty, $ev:path, $evg:path) => {
146-
impl Eigh_ for $scalar {
147-
fn eigh(
148-
calc_v: bool,
149-
layout: MatrixLayout,
150-
uplo: UPLO,
151-
mut a: &mut [Self],
152-
) -> Result<Vec<Self::Real>> {
153-
assert_eq!(layout.len(), layout.lda());
154-
let n = layout.len();
155-
let jobz = if calc_v { b'V' } else { b'N' };
156-
let mut eigs = vec![Self::Real::zero(); n as usize];
157-
let mut work = vec![Self::zero(); 2 * n as usize - 1];
158-
let mut rwork = vec![Self::Real::zero(); 3 * n as usize - 2];
159-
let mut info = 0;
160-
let n = n as i32;
161-
162-
unsafe {
163-
$ev(
164-
jobz,
165-
uplo as u8,
166-
n,
167-
&mut a,
168-
n,
169-
&mut eigs,
170-
&mut work,
171-
2 * n - 1,
172-
&mut rwork,
173-
&mut info,
174-
)
175-
};
176-
info.as_lapack_result()?;
177-
Ok(eigs)
178-
}
179-
180-
fn eigh_generalized(
181-
calc_v: bool,
182-
layout: MatrixLayout,
183-
uplo: UPLO,
184-
mut a: &mut [Self],
185-
mut b: &mut [Self],
186-
) -> Result<Vec<Self::Real>> {
187-
assert_eq!(layout.len(), layout.lda());
188-
let n = layout.len();
189-
let jobz = if calc_v { b'V' } else { b'N' };
190-
let mut eigs = vec![Self::Real::zero(); n as usize];
191-
let mut work = vec![Self::zero(); 2 * n as usize - 1];
192-
let mut rwork = vec![Self::Real::zero(); 3 * n as usize - 2];
193-
let n = n as i32;
194-
let mut info = 0;
195-
196-
unsafe {
197-
$evg(
198-
&[1],
199-
jobz,
200-
uplo as u8,
201-
n,
202-
&mut a,
203-
n,
204-
&mut b,
205-
n,
206-
&mut eigs,
207-
&mut work,
208-
2 * n - 1,
209-
&mut rwork,
210-
&mut info,
211-
)
212-
};
213-
info.as_lapack_result()?;
214-
Ok(eigs)
215-
}
216-
}
217-
};
218-
} // impl_eigh!
219-
220-
impl_eighc!(c64, lapack::zheev, lapack::zhegv);
221-
impl_eighc!(c32, lapack::cheev, lapack::chegv);
156+
impl_eigh!(@real, f64, lapack::dsyev, lapack::dsygv);
157+
impl_eigh!(@real, f32, lapack::ssyev, lapack::ssygv);
158+
impl_eigh!(@complex, c64, lapack::zheev, lapack::zhegv);
159+
impl_eigh!(@complex, c32, lapack::cheev, lapack::chegv);

lax/src/svd.rs

Lines changed: 17 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -44,94 +44,14 @@ pub trait SVD_: Scalar {
4444
) -> Result<SVDOutput<Self>>;
4545
}
4646

47-
macro_rules! impl_svd_real {
48-
($scalar:ty, $gesvd:path) => {
49-
impl SVD_ for $scalar {
50-
unsafe fn svd(
51-
l: MatrixLayout,
52-
calc_u: bool,
53-
calc_vt: bool,
54-
mut a: &mut [Self],
55-
) -> Result<SVDOutput<Self>> {
56-
let ju = match l {
57-
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u),
58-
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt),
59-
};
60-
let jvt = match l {
61-
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_vt),
62-
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_u),
63-
};
64-
65-
let m = l.lda();
66-
let mut u = match ju {
67-
FlagSVD::All => Some(vec![Self::zero(); (m * m) as usize]),
68-
FlagSVD::No => None,
69-
};
70-
71-
let n = l.len();
72-
let mut vt = match jvt {
73-
FlagSVD::All => Some(vec![Self::zero(); (n * n) as usize]),
74-
FlagSVD::No => None,
75-
};
76-
77-
let k = std::cmp::min(m, n);
78-
let mut s = vec![Self::Real::zero(); k as usize];
79-
80-
// eval work size
81-
let mut info = 0;
82-
let mut work_size = [Self::zero()];
83-
$gesvd(
84-
ju as u8,
85-
jvt as u8,
86-
m,
87-
n,
88-
&mut a,
89-
m,
90-
&mut s,
91-
u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
92-
m,
93-
vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
94-
n,
95-
&mut work_size,
96-
-1,
97-
&mut info,
98-
);
99-
info.as_lapack_result()?;
100-
101-
// calc
102-
let lwork = work_size[0].to_usize().unwrap();
103-
let mut work = vec![Self::zero(); lwork];
104-
$gesvd(
105-
ju as u8,
106-
jvt as u8,
107-
m,
108-
n,
109-
&mut a,
110-
m,
111-
&mut s,
112-
u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
113-
m,
114-
vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
115-
n,
116-
&mut work,
117-
lwork as i32,
118-
&mut info,
119-
);
120-
info.as_lapack_result()?;
121-
match l {
122-
MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }),
123-
MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }),
124-
}
125-
}
126-
}
47+
macro_rules! impl_svd {
48+
(@real, $scalar:ty, $gesvd:path) => {
49+
impl_svd!(@body, $scalar, $gesvd, );
12750
};
128-
} // impl_svd_real!
129-
130-
impl_svd_real!(f64, lapack::dgesvd);
131-
impl_svd_real!(f32, lapack::sgesvd);
132-
133-
macro_rules! impl_svd_complex {
134-
($scalar:ty, $gesvd:path) => {
51+
(@complex, $scalar:ty, $gesvd:path) => {
52+
impl_svd!(@body, $scalar, $gesvd, rwork);
53+
};
54+
(@body, $scalar:ty, $gesvd:path, $($rwork_ident:ident),*) => {
13555
impl SVD_ for $scalar {
13656
unsafe fn svd(
13757
l: MatrixLayout,
@@ -163,7 +83,9 @@ macro_rules! impl_svd_complex {
16383
let k = std::cmp::min(m, n);
16484
let mut s = vec![Self::Real::zero(); k as usize];
16585

166-
let mut rwork = vec![Self::Real::zero(); 5 * k as usize];
86+
$(
87+
let mut $rwork_ident = vec![Self::Real::zero(); 5 * k as usize];
88+
)*
16789

16890
// eval work size
16991
let mut info = 0;
@@ -182,7 +104,7 @@ macro_rules! impl_svd_complex {
182104
n,
183105
&mut work_size,
184106
-1,
185-
&mut rwork,
107+
$(&mut $rwork_ident,)*
186108
&mut info,
187109
);
188110
info.as_lapack_result()?;
@@ -204,7 +126,7 @@ macro_rules! impl_svd_complex {
204126
n,
205127
&mut work,
206128
lwork as i32,
207-
&mut rwork,
129+
$(&mut $rwork_ident,)*
208130
&mut info,
209131
);
210132
info.as_lapack_result()?;
@@ -215,7 +137,9 @@ macro_rules! impl_svd_complex {
215137
}
216138
}
217139
};
218-
} // impl_svd_real!
140+
} // impl_svd!
219141

220-
impl_svd_complex!(c64, lapack::zgesvd);
221-
impl_svd_complex!(c32, lapack::cgesvd);
142+
impl_svd!(@real, f64, lapack::dgesvd);
143+
impl_svd!(@real, f32, lapack::sgesvd);
144+
impl_svd!(@complex, c64, lapack::zgesvd);
145+
impl_svd!(@complex, c32, lapack::cgesvd);

0 commit comments

Comments
 (0)