Skip to content

Commit d6a7a7e

Browse files
authored
Merge pull request #25 from maciejkula/parallel_predict
Parallel model fitting and prediction
2 parents 130d00a + 5c36a54 commit d6a7a7e

File tree

15 files changed

+623
-149
lines changed

15 files changed

+623
-149
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ build = "build.rs"
2020
[dependencies]
2121
rand = "0.3"
2222
rustc-serialize = "0.3"
23+
crossbeam = "0.2.9"
2324

2425
[build-dependencies]
2526
gcc = "0.3"

changelog.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
11
# Changelog
2+
3+
## [unreleased][unreleased]
4+
### Added
5+
- factorization machines
6+
- parallel fitting and prediction for one-vs-rest models
7+
28
## [0.3.1][2016-03-01]
39
### Changed
410
- NonzeroIterable now takes &self

readme.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ For full usage details, see the [API documentation](https://maciejkula.github.io
99

1010
## Introduction
1111

12-
This crate is mostly an excuse for me to learn Rust. Nevertheless, it contains reasonably effective
12+
This crate contains reasonably effective
1313
implementations of a number of common machine learning algorithms.
1414

1515
At the moment, `rustlearn` uses its own basic dense and sparse array types, but I will be happy
@@ -43,6 +43,10 @@ should be roughly competitive with Python `sklearn` implementations, both in acc
4343
- [accuracy](https://maciejkula.github.io/rustlearn/doc/rustlearn/metrics/fn.accuracy_score.html)
4444
- [ROC AUC score](https://maciejkula.github.io/rustlearn/doc/rustlearn/metrics/ranking/fn.roc_auc_score.html)
4545

46+
## Parallelization
47+
48+
A number of models support both parallel model fitting and prediction.
49+
4650
### Model serialization
4751

4852
Model serialization is supported via `rustc_serialize`. This will probably change to `serde` once compiler plugins land in stable.

src/array/dense.rs

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@
111111
112112

113113
use std::iter::Iterator;
114+
use std::ops::Range;
114115

115116
use array::traits::*;
116117

@@ -144,6 +145,7 @@ pub struct ArrayView<'a> {
144145

145146
/// Iterator over row or column views of a dense matrix.
146147
pub struct ArrayIterator<'a> {
148+
stop: usize,
147149
idx: usize,
148150
axis: ArrayIteratorAxis,
149151
array: &'a Array,
@@ -155,12 +157,7 @@ impl<'a> Iterator for ArrayIterator<'a> {
155157

156158
fn next(&mut self) -> Option<ArrayView<'a>> {
157159

158-
let bound = match self.axis {
159-
ArrayIteratorAxis::Row => self.array.rows,
160-
ArrayIteratorAxis::Column => self.array.cols,
161-
};
162-
163-
let result = if self.idx < bound {
160+
let result = if self.idx < self.stop {
164161
Some(ArrayView {
165162
idx: self.idx,
166163
axis: self.axis,
@@ -182,6 +179,7 @@ impl<'a> RowIterable for &'a Array {
182179
type Output = ArrayIterator<'a>;
183180
fn iter_rows(self) -> ArrayIterator<'a> {
184181
ArrayIterator {
182+
stop: self.rows(),
185183
idx: 0,
186184
axis: ArrayIteratorAxis::Row,
187185
array: self,
@@ -196,6 +194,21 @@ impl<'a> RowIterable for &'a Array {
196194
array: self,
197195
}
198196
}
197+
198+
fn iter_rows_range(self, range: Range<usize>) -> ArrayIterator<'a> {
199+
let stop = if range.end > self.rows {
200+
self.rows
201+
} else {
202+
range.end
203+
};
204+
205+
ArrayIterator {
206+
stop: stop,
207+
idx: range.start,
208+
axis: ArrayIteratorAxis::Row,
209+
array: self,
210+
}
211+
}
199212
}
200213

201214

@@ -204,6 +217,7 @@ impl<'a> ColumnIterable for &'a Array {
204217
type Output = ArrayIterator<'a>;
205218
fn iter_columns(self) -> ArrayIterator<'a> {
206219
ArrayIterator {
220+
stop: self.cols(),
207221
idx: 0,
208222
axis: ArrayIteratorAxis::Column,
209223
array: self,
@@ -218,6 +232,21 @@ impl<'a> ColumnIterable for &'a Array {
218232
array: self,
219233
}
220234
}
235+
236+
fn iter_columns_range(self, range: Range<usize>) -> ArrayIterator<'a> {
237+
let stop = if range.end > self.cols {
238+
self.cols
239+
} else {
240+
range.end
241+
};
242+
243+
ArrayIterator {
244+
stop: stop,
245+
idx: range.start,
246+
axis: ArrayIteratorAxis::Column,
247+
array: self,
248+
}
249+
}
221250
}
222251

223252

@@ -1018,4 +1047,21 @@ mod tests {
10181047
}
10191048
}
10201049
}
1050+
1051+
use datasets::iris;
1052+
1053+
#[test]
1054+
fn range_iteration() {
1055+
let (data, _) = iris::load_data();
1056+
1057+
let (start, stop) = (5, 10);
1058+
1059+
for (row_num, row) in data.iter_rows_range(start..stop).enumerate() {
1060+
for (col_idx, value) in row.iter_nonzero() {
1061+
assert!(value == data.get(start + row_num, col_idx));
1062+
}
1063+
1064+
assert!(row_num < (stop - start));
1065+
}
1066+
}
10211067
}

src/array/sparse.rs

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
//!
3838
//! ```
3939
use std::iter::Iterator;
40+
use std::ops::Range;
4041

4142
use array::dense::*;
4243
use array::traits::*;
@@ -79,8 +80,8 @@ pub struct SparseArrayViewIterator<'a> {
7980

8081
/// Iterator over row or column views of a sparse matrix.
8182
pub struct SparseArrayIterator<'a> {
83+
stop: usize,
8284
idx: usize,
83-
dim: usize,
8485
indices: &'a Vec<Vec<usize>>,
8586
data: &'a Vec<Vec<f32>>,
8687
}
@@ -290,7 +291,22 @@ impl<'a> RowIterable for &'a SparseRowArray {
290291
fn iter_rows(self) -> SparseArrayIterator<'a> {
291292
SparseArrayIterator {
292293
idx: 0,
293-
dim: self.rows,
294+
stop: self.rows,
295+
indices: &self.indices,
296+
data: &self.data,
297+
}
298+
}
299+
300+
fn iter_rows_range(self, range: Range<usize>) -> SparseArrayIterator<'a> {
301+
let stop = if range.end > self.rows {
302+
self.rows
303+
} else {
304+
range.end
305+
};
306+
307+
SparseArrayIterator {
308+
stop: stop,
309+
idx: range.start,
294310
indices: &self.indices,
295311
data: &self.data,
296312
}
@@ -388,11 +404,27 @@ impl<'a> ColumnIterable for &'a SparseColumnArray {
388404
fn iter_columns(self) -> SparseArrayIterator<'a> {
389405
SparseArrayIterator {
390406
idx: 0,
391-
dim: self.cols,
407+
stop: self.cols,
408+
indices: &self.indices,
409+
data: &self.data,
410+
}
411+
}
412+
413+
fn iter_columns_range(self, range: Range<usize>) -> SparseArrayIterator<'a> {
414+
let stop = if range.end > self.cols {
415+
self.cols
416+
} else {
417+
range.end
418+
};
419+
420+
SparseArrayIterator {
421+
stop: stop,
422+
idx: range.start,
392423
indices: &self.indices,
393424
data: &self.data,
394425
}
395426
}
427+
396428
fn view_column(self, idx: usize) -> SparseArrayView<'a> {
397429
SparseArrayView {
398430
indices: &self.indices[idx],
@@ -457,7 +489,7 @@ impl<'a> Iterator for SparseArrayIterator<'a> {
457489

458490
fn next(&mut self) -> Option<SparseArrayView<'a>> {
459491

460-
let result = if self.idx < self.dim {
492+
let result = if self.idx < self.stop {
461493
Some(SparseArrayView {
462494
indices: &self.indices[self.idx][..],
463495
data: &self.data[self.idx][..],
@@ -595,4 +627,33 @@ mod tests {
595627
&dense_arr.get_rows(&vec![1, 0])));
596628
assert!(allclose(&arr.get_rows(&(..)).todense(), &dense_arr.get_rows(&(..))));
597629
}
630+
631+
use datasets::iris;
632+
633+
#[test]
634+
fn range_iteration() {
635+
let (data, _) = iris::load_data();
636+
637+
let (start, stop) = (5, 10);
638+
639+
let data = SparseRowArray::from(&data);
640+
641+
for (row_num, row) in data.iter_rows_range(start..stop).enumerate() {
642+
for (col_idx, value) in row.iter_nonzero() {
643+
assert!(value == data.get(start + row_num, col_idx));
644+
}
645+
646+
assert!(row_num < (stop - start));
647+
}
648+
649+
let (start, stop) = (1, 3);
650+
651+
let data = SparseColumnArray::from(&data);
652+
653+
for (col_num, col) in data.iter_columns_range(start..stop).enumerate() {
654+
for (row_idx, value) in col.iter_nonzero() {
655+
assert!(value == data.get(row_idx, start + col_num));
656+
}
657+
}
658+
}
598659
}

src/array/traits.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ pub trait RowIterable {
7676
type Output: Iterator<Item = Self::Item>;
7777
/// Iterate over rows of the matrix.
7878
fn iter_rows(self) -> Self::Output;
79+
/// Iterate over a subset of rows of the matrix.
80+
fn iter_rows_range(self, range: Range<usize>) -> Self::Output;
7981
/// View a row of the matrix.
8082
fn view_row(self, idx: usize) -> Self::Item;
8183
}
@@ -88,6 +90,8 @@ pub trait ColumnIterable {
8890
type Output: Iterator<Item = Self::Item>;
8991
/// Iterate over columns of a the matrix.
9092
fn iter_columns(self) -> Self::Output;
93+
/// Iterate over a subset of columns of the matrix.
94+
fn iter_columns_range(self, range: Range<usize>) -> Self::Output;
9195
/// View a column of the matrix.
9296
fn view_column(self, idx: usize) -> Self::Item;
9397
}

src/datasets/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33
pub mod iris;
44

55
#[cfg(test)]
6-
#[cfg(feature = "all_tests")]
6+
#[cfg(any(feature = "all_tests", feature = "bench"))]
77
pub mod newsgroups;

src/ensemble/random_forest.rs

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,14 @@ impl Hyperparameters {
107107
}
108108

109109

110-
#[derive(RustcEncodable, RustcDecodable)]
111-
#[derive(Clone)]
110+
#[derive(RustcEncodable, RustcDecodable, Clone)]
112111
pub struct RandomForest {
113112
trees: Vec<decision_tree::DecisionTree>,
114113
rng: EncodableRng,
115114
}
116115

117116

118-
impl SupervisedModel<Array> for RandomForest {
117+
impl<'a> SupervisedModel<&'a Array> for RandomForest {
119118
fn fit(&mut self, X: &Array, y: &Array) -> Result<(), &'static str> {
120119

121120
let mut rng = self.rng.clone();
@@ -145,7 +144,7 @@ impl SupervisedModel<Array> for RandomForest {
145144
}
146145

147146

148-
impl SupervisedModel<SparseRowArray> for RandomForest {
147+
impl<'a> SupervisedModel<&'a SparseRowArray> for RandomForest {
149148
fn fit(&mut self, X: &SparseRowArray, y: &Array) -> Result<(), &'static str> {
150149

151150
let mut rng = self.rng.clone();
@@ -253,6 +252,47 @@ mod tests {
253252
assert!(test_accuracy > 0.96);
254253
}
255254

255+
#[test]
256+
fn test_random_forest_iris_parallel() {
257+
let (data, target) = load_data();
258+
259+
let mut test_accuracy = 0.0;
260+
261+
let no_splits = 10;
262+
263+
let mut cv = CrossValidation::new(data.rows(), no_splits);
264+
cv.set_rng(StdRng::from_seed(&[100]));
265+
266+
for (train_idx, test_idx) in cv {
267+
268+
let x_train = data.get_rows(&train_idx);
269+
let x_test = data.get_rows(&test_idx);
270+
271+
let y_train = target.get_rows(&train_idx);
272+
273+
let mut tree_params = decision_tree::Hyperparameters::new(data.cols());
274+
tree_params.min_samples_split(10)
275+
.max_features(4)
276+
.rng(StdRng::from_seed(&[100]));
277+
278+
let mut model = Hyperparameters::new(tree_params, 10)
279+
.rng(StdRng::from_seed(&[100]))
280+
.one_vs_rest();
281+
282+
model.fit_parallel(&x_train, &y_train, 2).unwrap();
283+
284+
let test_prediction = model.predict_parallel(&x_test, 2).unwrap();
285+
286+
test_accuracy += accuracy_score(&target.get_rows(&test_idx), &test_prediction);
287+
}
288+
289+
test_accuracy /= no_splits as f32;
290+
291+
println!("Accuracy {}", test_accuracy);
292+
293+
assert!(test_accuracy > 0.96);
294+
}
295+
256296
#[test]
257297
fn test_random_forest_iris_sparse() {
258298
let (data, target) = load_data();

0 commit comments

Comments
 (0)