Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fit trait modification and cross validation proposal #122

Merged
merged 11 commits into from
Apr 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 5 additions & 24 deletions CONTRIBUTE.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,19 @@ This document should be used as a reference when contributing to Linfa. It descr

An important part of the Linfa ecosystem is how to organize data for the training and estimation process. A [Dataset](src/dataset/mod.rs) serves this purpose. It is a small wrapper of data and targets types and should be used as argument for the [Fit](src/traits.rs) trait. Its parametrization is generic, with [Records](src/dataset/mod.rs) representing input data (atm only implemented for `ndarray::ArrayBase`) and [Targets](src/dataset/mod.rs) for targets.

You can find traits for different classes of algorithms [here](src/traits.rs). For example, to implement a fittable algorithm, which takes an `Array2` as input data and boolean array as targets:
You can find traits for different classes of algorithms [here](src/traits.rs). For example, to implement a fittable algorithm, which takes an `Array2` as input data and boolean array as targets and could fail with an `Error` struct:
```rust
impl<'a, F: Float> Fit<'a, Array2<F>, Array1<bool>> for SvmParams<F, Pr> {
impl<F: Float> Fit<Array2<F>, Array1<bool>, Error> for SvmParams<F, Pr> {
type Object = Svm<F, Pr>;

fn fit(&self, dataset: &Dataset<Array2<F>, Array1<bool>>) -> Self::Object {
fn fit(&self, dataset: &Dataset<Array2<F>, Array1<bool>>) -> Result<Self::Object, Error> {
...
}
}
```
the type of the dataset is `&Dataset<Kernel<F>, Array1<bool>>`, and lifetime `'a` is the required lifetime for the fitted state. It produces a fitted state, called `Svm<F, Pr>` with probability type `Pr`.
where the type of the input dataset is `&Dataset<Kernel<F>, Array1<bool>>`. It produces a result with a fitted state, called `Svm<F, Pr>` with probability type `Pr`, or an error of type `Error` in case of failure.

The [Predict](src/traits.rs) should be implemented with dataset arguments, as well as arrays. If a dataset is provided, then predict takes its ownership and returns a new dataset with predicted targets. For an array, predict takes a reference and returns predicted targets. In the same context, SVM implemented predict like this:
```rust
impl<F: Float, T: Targets> Predict<Dataset<Array2<F>, T>, Dataset<Array2<F>, Vec<Pr>>>
for Svm<F, Pr>
{
fn predict(&self, data: Dataset<Array2<F>, T>) -> Dataset<Array2<F>, Vec<Pr>> {
...
}
}
```
and
```rust
impl<F: Float, D: Data<Elem = F>> Predict<ArrayBase<D, Ix2>, Vec<Pr>> for Svm<F, Pr> {
fn predict(&self, data: ArrayBase<D, Ix2>) -> Vec<Pr> {
...
}
}
```

For an example of a `Transformer` please look into the [linfa-kernel](linfa-kernel/src/lib.rs) implementation.
The [Predict](src/traits.rs) trait has its own section later in this document, while for an example of a `Transformer` please look into the [linfa-kernel](linfa-kernel/src/lib.rs) implementation.

## Parameters and builder

Expand Down
3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,13 @@ features = ["cblas"]
default-features = false

[dependencies.openblas-src]
version = "0.9.0"
version = "0.10.4"
optional = true
default-features = false
features = ["cblas"]

[dev-dependencies]
ndarray-rand = "0.13"

linfa-datasets = { path = "datasets", features = ["winequality", "iris", "diabetes"] }

[workspace]
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Where does `linfa` stand right now? [Are we learning yet?](http://www.arewelearn
| [ica](algorithms/linfa-ica/) | Independent component analysis | Tested | Unsupervised learning | Contains FastICA implementation |
| [pls](algorithms/linfa-pls/) | Partial Least Squares | Tested | Supervised learning | Contains PLS estimators for dimensionality reduction and regression |
| [tsne](algorithms/linfa-tsne/) | Dimensionality reduction| Tested | Unsupervised learning | Contains exact solution and Barnes-Hut approximation t-SNE |
| [preprocessing](algorithms/linfa-preprocessing/) |Normalization & Vectorization| Tested | Pre-processing | Contains data normalization/whitening and count vectorization/tf-idf |

We believe that only a significant community effort can nurture, build, and sustain a machine learning ecosystem in Rust - there is no other way forward.

Expand Down
10 changes: 5 additions & 5 deletions algorithms/linfa-bayes/src/gaussian_nb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use ndarray::{s, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Axis, Data,
use ndarray_stats::QuantileExt;
use std::collections::HashMap;

use crate::error::Result;
use crate::error::{BayesError, Result};
use linfa::dataset::{AsTargets, DatasetBase, Labels};
use linfa::traits::{Fit, IncrementalFit, PredictRef};
use linfa::Float;
Expand Down Expand Up @@ -40,13 +40,13 @@ impl GaussianNbParams {
}
}

impl<F, D, L> Fit<'_, ArrayBase<D, Ix2>, L> for GaussianNbParams
impl<F, D, L> Fit<ArrayBase<D, Ix2>, L, BayesError> for GaussianNbParams
where
F: Float,
D: Data<Elem = F>,
L: AsTargets<Elem = usize> + Labels<Elem = usize>,
{
type Object = Result<GaussianNb<F>>;
type Object = GaussianNb<F>;

/// Fit the model
///
Expand Down Expand Up @@ -77,7 +77,7 @@ where
/// # Ok(())
/// # }
/// ```
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, L>) -> Self::Object {
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, L>) -> Result<Self::Object> {
// We extract the unique classes in sorted order
let mut unique_classes = dataset.targets.labels();
unique_classes.sort_unstable();
Expand Down Expand Up @@ -303,7 +303,7 @@ where
///
/// __Panics__ if the input is empty or if pairwise orderings are undefined
/// (this occurs in presence of NaN values)
fn predict_ref<'a>(&'a self, x: &ArrayBase<D, Ix2>) -> Array1<usize> {
fn predict_ref(&self, x: &ArrayBase<D, Ix2>) -> Array1<usize> {
let joint_log_likelihood = self.joint_log_likelihood(x.view());

// We store the classes and likelihood info in an vec and matrix
Expand Down
2 changes: 1 addition & 1 deletion algorithms/linfa-clustering/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ ndarray-rand = "0.13"
ndarray-stats = "0.4"
num-traits = "0.2"
rand_isaac = "0.3"
thiserror = "1"
partitions = "0.2.4"

linfa = { version = "0.3.1", path = "../..", features = ["ndarray-linalg"] }

[dev-dependencies]
Expand Down
10 changes: 5 additions & 5 deletions algorithms/linfa-clustering/src/appx_dbscan/hyperparameters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,21 +99,21 @@ impl<F: Float> AppxDbscanHyperParams<F> {
}

fn build(tolerance: F, min_points: usize, slack: F) -> Self {
if tolerance <= F::cast(0.) {
if tolerance <= F::zero() {
panic!("`tolerance` must be greater than 0!");
}
// There is always at least one neighbor to a point (itself)
if min_points <= 1 {
panic!("`min_points` must be greater than 1!");
}

if slack <= F::cast(0.) {
if slack <= F::zero() {
panic!("`slack` must be greater than 0!");
}
Self {
tolerance: tolerance,
min_points: min_points,
slack: slack,
tolerance,
min_points,
slack,
appx_tolerance: tolerance * (F::one() + slack),
}
}
Expand Down
14 changes: 7 additions & 7 deletions algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,10 @@ impl<F: Float> GaussianMixtureModel<F> {
reg_covar: F,
) -> Result<(Array1<F>, Array2<F>, Array3<F>)> {
let nk = resp.sum_axis(Axis(0));
if nk.min().unwrap() < &(F::cast(10.) * F::epsilon()) {
if nk.min()? < &(F::cast(10.) * F::epsilon()) {
return Err(GmmError::EmptyCluster(format!(
"Cluster #{} has no more point. Consider decreasing number of clusters or change initialization.",
nk.argmin().unwrap() + 1
nk.argmin()? + 1
)));
}

Expand Down Expand Up @@ -400,12 +400,12 @@ impl<F: Float> GaussianMixtureModel<F> {
}
}

impl<'a, F: Float, R: Rng + SeedableRng + Clone, D: Data<Elem = F>, T> Fit<'a, ArrayBase<D, Ix2>, T>
for GmmHyperParams<F, R>
impl<F: Float, R: Rng + SeedableRng + Clone, D: Data<Elem = F>, T>
Fit<ArrayBase<D, Ix2>, T, GmmError> for GmmHyperParams<F, R>
{
type Object = Result<GaussianMixtureModel<F>>;
type Object = GaussianMixtureModel<F>;

fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Self::Object {
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
self.validate()?;
let observations = dataset.records().view();
let mut gmm = GaussianMixtureModel::<F>::new(self, dataset, self.rng())?;
Expand Down Expand Up @@ -488,7 +488,7 @@ mod tests {
}
impl MultivariateNormal {
pub fn new(mean: &ArrayView1<f64>, covariance: &ArrayView2<f64>) -> LAResult<Self> {
let lower = covariance.cholesky(UPLO::Lower).unwrap();
let lower = covariance.cholesky(UPLO::Lower)?;
Ok(MultivariateNormal {
mean: mean.to_owned(),
covariance: covariance.to_owned(),
Expand Down
61 changes: 20 additions & 41 deletions algorithms/linfa-clustering/src/gaussian_mixture/errors.rs
Original file line number Diff line number Diff line change
@@ -1,58 +1,37 @@
use crate::k_means::KMeansError;
use ndarray_linalg::error::LinalgError;
use std::error::Error;
use std::fmt::{self, Display};

use thiserror::Error;
pub type Result<T> = std::result::Result<T, GmmError>;

/// An error when modeling a GMM algorithm
#[derive(Debug)]
#[derive(Error, Debug)]
pub enum GmmError {
/// When any of the hyperparameters are set the wrong value
#[error("Invalid value encountered: {0}")]
InvalidValue(String),
/// Errors encountered during linear algebra operations
LinalgError(LinalgError),
#[error(
"Linalg Error: \
Fitting the mixture model failed because some components have \
ill-defined empirical covariance (for instance caused by singleton \
or collapsed samples). Try to decrease the number of components, \
or increase reg_covar. Error: {0}"
)]
LinalgError(#[from] LinalgError),
/// When a cluster has no more data point while fitting GMM
#[error("Fitting failed: {0}")]
EmptyCluster(String),
/// When lower bound computation fails
#[error("Fitting failed: {0}")]
LowerBoundError(String),
/// When fitting EM algorithm does not converge
#[error("Fitting failed: {0}")]
NotConverged(String),
/// When initial KMeans fails
KMeansError(String),
}

impl Display for GmmError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::InvalidValue(message) => write!(f, "Invalid value encountered: {}", message),
Self::LinalgError(error) => write!(
f,
"Linalg Error: \
Fitting the mixture model failed because some components have \
ill-defined empirical covariance (for instance caused by singleton \
or collapsed samples). Try to decrease the number of components, \
or increase reg_covar. Error: {}",
error
),
Self::EmptyCluster(message) => write!(f, "Fitting failed: {}", message),
Self::LowerBoundError(message) => write!(f, "Fitting failed: {}", message),
Self::NotConverged(message) => write!(f, "Fitting failed: {}", message),
Self::KMeansError(message) => write!(f, "Initial KMeans failed: {}", message),
}
}
}

impl Error for GmmError {}

impl From<LinalgError> for GmmError {
fn from(error: LinalgError) -> GmmError {
GmmError::LinalgError(error)
}
}

impl From<KMeansError> for GmmError {
fn from(error: KMeansError) -> GmmError {
GmmError::KMeansError(error.to_string())
}
#[error("Initial KMeans failed: {0}")]
KMeansError(#[from] KMeansError),
#[error(transparent)]
LinfaError(#[from] linfa::error::Error),
#[error(transparent)]
MinMaxError(#[from] ndarray_stats::errors::MinMaxError),
}
8 changes: 4 additions & 4 deletions algorithms/linfa-clustering/src/k_means/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,17 +215,17 @@ impl<F: Float> KMeans<F> {
}
}

impl<'a, F: Float, R: Rng + Clone + SeedableRng, D: Data<Elem = F>, T> Fit<'a, ArrayBase<D, Ix2>, T>
for KMeansHyperParams<F, R>
impl<F: Float, R: Rng + Clone + SeedableRng, D: Data<Elem = F>, T>
Fit<ArrayBase<D, Ix2>, T, KMeansError> for KMeansHyperParams<F, R>
{
type Object = Result<KMeans<F>>;
type Object = KMeans<F>;

/// Given an input matrix `observations`, with shape `(n_observations, n_features)`,
/// `fit` identifies `n_clusters` centroids based on the training data distribution.
///
/// An instance of `KMeans` is returned.
///
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Self::Object {
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
let mut rng = self.rng();
let observations = dataset.records().view();
let n_samples = dataset.nsamples();
Expand Down
22 changes: 7 additions & 15 deletions algorithms/linfa-clustering/src/k_means/errors.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,19 @@
use std::error::Error;
use std::fmt::{self, Display};
use thiserror::Error;

pub type Result<T> = std::result::Result<T, KMeansError>;

/// An error when modeling a KMeans algorithm
#[derive(Debug)]
#[derive(Error, Debug)]
pub enum KMeansError {
/// When any of the hyperparameters are set the wrong value
#[error("Invalid value encountered: {0}")]
InvalidValue(String),
/// When inertia computation fails
#[error("Fitting failed: {0}")]
InertiaError(String),
/// When fitting algorithm does not converge
#[error("Fitting failed: {0}")]
NotConverged(String),
#[error(transparent)]
LinfaError(#[from] linfa::error::Error),
}

impl Display for KMeansError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::InvalidValue(message) => write!(f, "Invalid value encountered: {}", message),
Self::InertiaError(message) => write!(f, "Fitting failed: {}", message),
Self::NotConverged(message) => write!(f, "Fitting failed: {}", message),
}
}
}

impl Error for KMeansError {}
26 changes: 26 additions & 0 deletions algorithms/linfa-elasticnet/examples/elasticnet_cv.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use linfa::prelude::*;
use linfa_elasticnet::{ElasticNet, Result};

fn main() -> Result<()> {
// load Diabetes dataset (mutable to allow fast k-folding)
let mut dataset = linfa_datasets::diabetes();

// parameters to compare
let ratios = vec![0.1, 0.2, 0.5, 0.7, 1.0];

// create a model for each parameter
let models = ratios
.iter()
.map(|ratio| ElasticNet::params().penalty(0.3).l1_ratio(*ratio))
.collect::<Vec<_>>();

// get the mean r2 validation score across all folds for each model
let r2_values =
dataset.cross_validate(5, &models, |prediction, truth| prediction.r2(&truth))?;

for (ratio, r2) in ratios.iter().zip(r2_values.iter()) {
println!("L1 ratio: {}, r2 score: {}", ratio, r2);
}

Ok(())
}
6 changes: 3 additions & 3 deletions algorithms/linfa-elasticnet/src/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ use linfa::{

use super::{ElasticNet, ElasticNetParams, Error, Result};

impl<'a, F, D, T> Fit<'a, ArrayBase<D, Ix2>, T> for ElasticNetParams<F>
impl<F, D, T> Fit<ArrayBase<D, Ix2>, T, crate::error::Error> for ElasticNetParams<F>
where
F: Float + Lapack,
D: Data<Elem = F>,
T: AsTargets<Elem = F>,
{
type Object = Result<ElasticNet<F>>;
type Object = ElasticNet<F>;

/// Fit an elastic net model given a feature matrix `x` and a target
/// variable `y`.
Expand All @@ -28,7 +28,7 @@ where
/// Returns a `FittedElasticNet` object which contains the fitted
/// parameters and can be used to `predict` values of the target variable
/// for new feature values.
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<ElasticNet<F>> {
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
self.validate_params()?;
let target = dataset.try_single_target()?;

Expand Down
3 changes: 2 additions & 1 deletion algorithms/linfa-ica/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ ndarray-rand = "0.13"
ndarray-stats = "0.4"
num-traits = "0.2"
rand_isaac = "0.3"
thiserror = "1"

linfa = { version = "0.3.1", path = "../.." }
linfa = { version = "0.3.1", path = "../..", features = ["ndarray-linalg"] }

[dev-dependencies]
ndarray-npy = { version = "0.7", default-features = false }
Expand Down
Loading