Skip to content

Commit 6065062

Browse files
authored
Add Barnes-Hut implementation of t-SNE (#101)
* Initial commit * Add generic type definition to wrapper * Add example * Improve error handling and documentation * Add example to website and run rustfmt * Fix example in documentation * Make t-SNE tests deterministic random * Address review
1 parent f597c04 commit 6065062

File tree

13 files changed

+437
-7
lines changed

13 files changed

+437
-7
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ members = [
7676
"algorithms/linfa-bayes",
7777
"algorithms/linfa-elasticnet",
7878
"algorithms/linfa-pls",
79+
"algorithms/linfa-tsne",
7980
"datasets",
8081
]
8182

README.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,14 @@ Where does `linfa` stand right now? [Are we learning yet?](http://www.arewelearn
3535
| [linear](algorithms/linfa-linear/) | Linear regression | Tested | Partial fit | Contains Ordinary Least Squares (OLS), Generalized Linear Models (GLM) |
3636
| [elasticnet](algorithms/linfa-elasticnet/) | Elastic Net | Tested | Supervised learning | Linear regression with elastic net constraints |
3737
| [logistic](algorithms/linfa-logistic/) | Logistic regression | Tested | Partial fit | Builds two-class logistic regression models
38-
| [reduction](algorithms/linfa-reduction/) | Dimensionality reduction | Tested | Pre-processing | Diffusion mapping and Principal Component Analysis (PCA) |
38+
| [reduction](algorithms/linfa-reduction/) | Dimensionality reduction | Tested | Pre-processing | Diffusion mapping and Principal Component Analysis (PCA) |
3939
| [trees](algorithms/linfa-trees/) | Decision trees | Experimental | Supervised learning | Linear decision trees
4040
| [svm](algorithms/linfa-svm/) | Support Vector Machines | Tested | Supervised learning | Classification or regression analysis of labeled datasets |
4141
| [hierarchical](algorithms/linfa-hierarchical/) | Agglomerative hierarchical clustering | Tested | Unsupervised learning | Cluster and build hierarchy of clusters |
4242
| [bayes](algorithms/linfa-bayes/) | Naive Bayes | Tested | Supervised learning | Contains Gaussian Naive Bayes |
4343
| [ica](algorithms/linfa-ica/) | Independent component analysis | Tested | Unsupervised learning | Contains FastICA implementation |
4444
| [pls](algorithms/linfa-pls/) | Partial Least Squares | Tested | Supervised learning | Contains PLS estimators for dimensionality reduction and regression |
45+
| [tsne](algorithms/linfa-tsne/) | Dimensionality reduction | Tested | Unsupervised learning | Contains exact solution and Barnes-Hut approximation t-SNE |
4546

4647
We believe that only a significant community effort can nurture, build, and sustain a machine learning ecosystem in Rust - there is no other way forward.
4748

algorithms/linfa-logistic/src/lib.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -384,12 +384,12 @@ fn convert_params<F: Float>(n_features: usize, w: &Array1<F>) -> (Array1<F>, F)
384384
} else if n_features + 1 == w.len() {
385385
(w.slice(s![..w.len() - 1]).to_owned(), w[w.len() - 1])
386386
} else {
387-
panic!(format!(
387+
panic!(
388388
"Unexpected length of parameter vector `w`, exected {} or {}, found {}",
389389
n_features,
390390
n_features + 1,
391391
w.len()
392-
));
392+
);
393393
}
394394
}
395395

algorithms/linfa-reduction/src/pca.rs

+17-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use serde_crate::{Deserialize, Serialize};
2828

2929
use linfa::{
3030
dataset::Records,
31-
traits::{Fit, PredictRef},
31+
traits::{Fit, PredictRef, Transformer},
3232
DatasetBase, Float,
3333
};
3434

@@ -173,6 +173,22 @@ impl<F: Float, D: Data<Elem = F>> PredictRef<ArrayBase<D, Ix2>, Array2<F>> for P
173173
}
174174
}
175175

176+
/// `Transformer` implementation for `Pca`: takes a whole `DatasetBase` by value and
/// returns a `DatasetBase` whose records are an owned `Array2<F>` produced by
/// `predict_ref`, with the target type `T` unchanged.
impl<F: Float, D: Data<Elem = F>, T>
177+
Transformer<DatasetBase<ArrayBase<D, Ix2>, T>, DatasetBase<Array2<F>, T>> for Pca<F>
178+
{
179+
/// Consumes `ds`, replaces its records with `self.predict_ref(&records)` and
/// rebuilds the dataset from the new records, the original targets and the
/// original weights.
///
/// Any fields of `ds` matched by `..` below are not carried over to the result.
fn transform(&self, ds: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
180+
let DatasetBase {
181+
records,
182+
targets,
183+
weights,
184+
.. // remaining fields are intentionally ignored here
185+
} = ds;
186+
187+
let new_records = self.predict_ref(&records); // borrow the old records; result is a fresh array
188+
189+
DatasetBase::new(new_records, targets).with_weights(weights) // targets and weights are moved, not cloned
190+
}
191+
}
176192
#[cfg(test)]
177193
mod tests {
178194
use super::*;

algorithms/linfa-tsne/Cargo.toml

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
[package]
2+
name = "linfa-tsne"
3+
version = "0.3.1"
4+
authors = ["Lorenz Schmidt <[email protected]>"]
5+
edition = "2018"
6+
7+
description = "Barnes-Hut t-distributed stochastic neighbor embedding"
8+
license = "MIT/Apache-2.0"
9+
10+
repository = "https://github.com/rust-ml/linfa"
11+
readme = "README.md"
12+
13+
keywords = ["tsne", "data visualization", "clustering", "machine-learning", "linfa"]
14+
categories = ["algorithms", "mathematics", "science"]
15+
16+
[dependencies]
17+
thiserror = "1"
18+
ndarray = { version = "0.13", default-features = false }
19+
ndarray-rand = "0.11"
20+
bhtsne = "0.4.0"
21+
22+
linfa = { version = "0.3.1", path = "../.." }
23+
24+
[dev-dependencies]
25+
rand = "0.7"
26+
approx = "0.3"
27+
28+
linfa-datasets = { version = "0.3.1", path = "../../datasets", features = ["iris"] }
29+
linfa-reduction = { version = "0.3.1", path = "../linfa-reduction" }

algorithms/linfa-tsne/README.md

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# t-SNE
2+
3+
`linfa-tsne` provides a pure Rust implementation of exact and Barnes-Hut t-SNE.
4+
5+
## The Big Picture
6+
7+
`linfa-tsne` is a crate in the [`linfa`](https://crates.io/crates/linfa) ecosystem, an effort to create a toolkit for classical Machine Learning implemented in pure Rust, akin to Python's `scikit-learn`.
8+
9+
## Current state
10+
11+
`linfa-tsne` currently provides an implementation of the following methods:
12+
13+
- exact solution t-SNE
14+
- Barnes-Hut t-SNE
15+
16+
It wraps the [bhtsne](https://github.com/frjnn/bhtsne) crate, all kudos to them.
17+
18+
## Examples
19+
20+
There is a usage example in the `examples/` directory. The example uses a BLAS backend; to run it with the `intel-mkl` library, do:
21+
22+
```bash
23+
$ cargo run --example tsne --features linfa/intel-mkl-system
24+
```
25+
26+
You have to install the `gnuplot` library for plotting. Also take a look at the [README](https://github.com/rust-ml/linfa#blaslapack-backend) to see what BLAS/LAPACK backends are possible.
27+
28+
## License
29+
Dual-licensed to be compatible with the Rust project.
30+
31+
Licensed under the Apache License, Version 2.0 <http://www.apache.org/licenses/LICENSE-2.0> or the MIT license <http://opensource.org/licenses/MIT>, at your option. This file may not be copied, modified, or distributed except according to those terms.
32+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
set style increment user
2+
set style line 1 lc rgb 'red'
3+
set style line 2 lc rgb 'blue'
4+
set style line 3 lc rgb 'green'
5+
6+
set style data points
7+
plot 'iris.dat' using 1:2:3 linecolor variable pt 7 ps 2 t ''
+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
use linfa::traits::{Fit, Transformer};
2+
use linfa_reduction::Pca;
3+
use linfa_tsne::{Result, TSne};
4+
use std::{io::Write, process::Command};
5+
6+
/// Example: embeds the iris dataset with t-SNE and plots the result with gnuplot.
///
/// Steps, in order:
/// 1. load `linfa_datasets::iris()` and pass it through `Pca::params(3)` with whitening enabled,
/// 2. run `TSne::embedding_size(2)` (perplexity 10.0, approximation threshold 0.1),
/// 3. write one `x[0] x[1] y[0]` line per sample to `examples/iris.dat`,
/// 4. spawn `gnuplot -p examples/iris_plot.plt`.
///
/// # Errors
/// Returns the `TSneError` produced by the t-SNE `transform` call.
///
/// # Panics
/// Panics if the data file cannot be created or written, or if `gnuplot` cannot be spawned.
fn main() -> Result<()> {
7+
let ds = linfa_datasets::iris();
8+
let ds = Pca::params(3).whiten(true).fit(&ds).transform(ds);
9+
10+
let ds = TSne::embedding_size(2)
11+
.perplexity(10.0)
12+
.approx_threshold(0.1)
13+
.transform(ds)?;
14+
15+
// NOTE(review): `examples/iris_plot.plt` reads 'iris.dat' while this writes
// `examples/iris.dat` — confirm gnuplot resolves the path from the working directory used here.
let mut f = std::fs::File::create("examples/iris.dat").unwrap();
16+
17+
for (x, y) in ds.sample_iter() {
18+
// `write_all` instead of `write`: `write` may perform a short write and its
// returned byte count was being discarded, which could silently truncate a row.
f.write_all(format!("{} {} {}\n", x[0], x[1], y[0]).as_bytes())
19+
.unwrap();
20+
}
21+
22+
// The child process is spawned and not waited on; `-p` is passed to gnuplot.
Command::new("gnuplot")
23+
.arg("-p")
24+
.arg("examples/iris_plot.plt")
25+
.spawn()
26+
.expect(
27+
"Failed to launch gnuplot. Please ensure that gnuplot is installed and on the $PATH.",
28+
);
29+
30+
Ok(())
31+
}

algorithms/linfa-tsne/src/error.rs

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
use thiserror::Error;
2+
3+
/// Shorthand for `std::result::Result` with [`TSneError`] as the error type.
pub type Result<T> = std::result::Result<T, TSneError>;
4+
5+
/// Error type of this crate.
///
/// `Display` and `std::error::Error` are derived through `thiserror`; the
/// `#[error(..)]` string on each variant is its display message.
#[derive(Error, Debug)]
6+
pub enum TSneError {
7+
/// The perplexity is negative.
#[error("negative perplexity")]
8+
NegativePerplexity,
9+
/// The perplexity is too large for the number of samples.
#[error("perplexity too large for number of samples")]
10+
PerplexityTooLarge,
11+
/// The approximation threshold is negative.
#[error("negative approximation threshold")]
12+
NegativeApproximationThreshold,
13+
/// The embedding size is larger than the original dimensionality.
#[error("embedding size larger than original dimensionality")]
14+
EmbeddingSizeTooLarge,
15+
/// The number of preliminary iterations is larger than the total number of iterations.
#[error("number of preliminary iterations larger than total iterations")]
16+
PreliminaryIterationsTooLarge,
17+
/// Wraps an `ndarray::ShapeError`; `#[from]` lets `?` convert it automatically.
#[error("invalid shaped array {0}")]
18+
InvalidShape(#[from] ndarray::ShapeError),
19+
/// Wraps a `linfa::Error`; `transparent` forwards its display message and source unchanged.
#[error(transparent)]
20+
BaseCrate(#[from] linfa::Error),
21+
}

0 commit comments

Comments
 (0)