
Fit trait modification and cross validation proposal #122

Merged
merged 11 commits into rust-ml:master on Apr 28, 2021

Conversation

Sauro98
Member

@Sauro98 Sauro98 commented Apr 21, 2021

Changes in the Fit trait

from this:

pub trait Fit<'a, R: Records, T> {
    type Object: 'a;

    fn fit(&self, dataset: &DatasetBase<R, T>) -> Self::Object;
}

to this:

pub trait Fit<R: Records, T, E: std::error::Error + std::convert::From<linfa::Error>> {
    type Object;

    fn fit(&self, dataset: &DatasetBase<R, T>) -> Result<Self::Object, E>;
}

by:

  • removing the 'a lifetime (left over from a previous SVM implementation and no longer used by any algorithm)
  • forcing every implementation to return a Result with an error struct (every implementation except PCA already returned an error; some implementations used a String as the error type, but the transition keeps the same error messages). A sketch of what an implementation looks like under the new trait follows below.
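
As a rough illustration of an implementation against the new trait, here is a minimal sketch with made-up names (MeanParams, MeanModel and MeanError are not real linfa types), assuming the thiserror crate for the error derive:

    use linfa::dataset::DatasetBase;
    use linfa::prelude::*;
    use ndarray::{Array1, ArrayBase, Axis, Data, Ix2};
    use thiserror::Error;

    #[derive(Debug, Error)]
    pub enum MeanError {
        #[error("dataset contains no samples")]
        Empty,
        // Satisfies the From<linfa::Error> bound required by the trait
        #[error(transparent)]
        Linfa(#[from] linfa::Error),
    }

    pub struct MeanParams;
    pub struct MeanModel {
        pub means: Array1<f64>,
    }

    impl<D: Data<Elem = f64>, T> Fit<ArrayBase<D, Ix2>, T, MeanError> for MeanParams {
        type Object = MeanModel;

        // Fitting now reports failure through a Result instead of panicking
        fn fit(
            &self,
            dataset: &DatasetBase<ArrayBase<D, Ix2>, T>,
        ) -> Result<MeanModel, MeanError> {
            let means = dataset
                .records()
                .mean_axis(Axis(0))
                .ok_or(MeanError::Empty)?;
            Ok(MeanModel { means })
        }
    }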

Edit 1:

Added a bound on the fit error type requiring conversion from linfa's error: every sub-crate should be able to handle the errors raised by using the base crate.
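
For example, the bound lets a sub-crate error type absorb base-crate errors with the ? operator. A hedged sketch (SubCrateError and fallible_base_call are illustrative names, again assuming thiserror):

    use thiserror::Error;

    #[derive(Debug, Error)]
    pub enum SubCrateError {
        #[error("invalid hyperparameter: {0}")]
        InvalidParams(String),
        // Provides the required From<linfa::Error> conversion
        #[error(transparent)]
        Linfa(#[from] linfa::Error),
    }

    // Stand-in for any base-crate call returning Result<_, linfa::Error>
    fn fallible_base_call() -> Result<(), linfa::Error> {
        Ok(())
    }

    fn fit_step() -> Result<(), SubCrateError> {
        // linfa::Error converts into SubCrateError automatically via #[from]
        fallible_base_call()?;
        Ok(())
    }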

Cross validation POC

Cross-validation can be defined by exploiting the new Fit definition. This is what it looks like for regression:

    use linfa::prelude::*;
    use linfa_elasticnet::{ElasticNet, Result};

    // load Diabetes dataset (mutable to allow fast k-folding)
    let mut dataset = linfa_datasets::diabetes();

    // parameters to compare
    let ratios = vec![0.1, 0.2, 0.5, 0.7, 1.0];

    // create a model for each parameter
    let models = ratios
        .iter()
        .map(|ratio| ElasticNet::params().penalty(0.3).l1_ratio(*ratio))
        .collect::<Vec<_>>();

    // get the mean r2 validation score across all folds for each model
    let r2_values =
        dataset.cross_validate(5, &models, |prediction, truth| prediction.r2(&truth))?;

    for (ratio, r2) in ratios.iter().zip(r2_values.iter()) {
        println!("L1 ratio: {}, r2 score: {}", ratio, r2);
    }

And this is what it looks like for classification:

    use linfa::prelude::*;
    use linfa_logistic::error::Result;
    use linfa_logistic::LogisticRegression;

    // Load dataset. Mutability is needed for fast cross validation
    let mut dataset =
        linfa_datasets::winequality().map_targets(|x| if *x > 6 { "good" } else { "bad" });

    // define a sequence of models to compare. In this case the
    // models will differ by the amount of l2 regularization
    let alphas = vec![0.1, 1., 10.];
    let models: Vec<_> = alphas
        .iter()
        .map(|alpha| {
            LogisticRegression::default()
                .alpha(*alpha)
                .max_iterations(150)
        })
        .collect();

    // use cross validation to compute the validation accuracy of each model. The
    // accuracy of each model will be averaged across the folds, 5 in this case
    let accuracies = dataset.cross_validate(5, &models, |prediction, truth| {
        Ok(prediction.confusion_matrix(truth)?.accuracy())
    })?;

    // display the accuracy of the models along with their regularization coefficient
    for (alpha, accuracy) in alphas.iter().zip(accuracies.iter()) {
        println!("Alpha: {}, accuracy: {} ", alpha, accuracy);
    }

Possible follow-up

Redefine the Transformer trait to return a Result, as the t-sne implementation already does, in order to avoid panicking as much as possible. It would be better to wait for #121 so we can avoid dealing with the unwrap() calls when performing float conversions.
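
A minimal sketch of the shape this could take (the signature here is assumed for illustration and is not linfa's current Transformer definition):

    // Hypothetical fallible transformer: transform reports failure through
    // a Result instead of panicking, as the t-sne implementation already does
    pub trait Transformer<R, T> {
        fn transform(&self, records: R) -> Result<T, linfa::Error>;
    }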

Notes

  • Currently there are two versions of cross-validation: one for single-target datasets and one for multi-target datasets. The main reasons for this are:
    • An Array1 of evaluation values can be constructed with collect; an Array2 cannot, and must be populated row by row (this could be avoided by stacking, as sketched after this list, or maybe there is a dedicated ndarray method I am not aware of)
    • Regression metrics behave differently when applied Array1-to-Array2 than Array2-to-Array1, which makes writing evaluation closures harder than it should be.
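
Here is a minimal ndarray sketch of the asymmetry in the first point (assuming ndarray 0.15, where stack inserts a new axis; older versions have different stack semantics):

    use ndarray::{stack, Array1, Array2, Axis};

    fn main() {
        // An Array1 of per-fold scores can be collected directly...
        let scores: Array1<f64> = (0..5).map(|fold| fold as f64).collect();
        assert_eq!(scores.len(), 5);

        // ...while an Array2 (folds x targets) cannot: per-fold rows are
        // gathered first and then stacked along a new axis
        let rows: Vec<Array1<f64>> = (0..5).map(|_| Array1::zeros(3)).collect();
        let views: Vec<_> = rows.iter().map(|r| r.view()).collect();
        let per_target: Array2<f64> = stack(Axis(0), &views).expect("equal row lengths");
        assert_eq!(per_target.dim(), (5, 3));
    }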

More than likely there is a solution to both problems, but I was stuck on it for too long, so I would consider it out of scope for this PR unless someone has an easy solution to suggest. Forcing the user to return a single value in the multi-target case (such as a mean across the targets) would make the problem easier, but the per-target evaluations would be lost in the process.

  • Got this error in elasticnet just once:
     Running target/release/deps/linfa_elasticnet-d4c6cf99e56de5dd

running 11 tests
test algorithm::tests::coordinate_descent_lowers_objective ... ok
test algorithm::tests::elastic_net_2d_toy_example_works ... ok
test algorithm::tests::elastic_net_diabetes_1_works_like_sklearn ... ok
test algorithm::tests::diabetes_z_score ... ok
test algorithm::tests::elastic_net_penalty_works ... ok
test algorithm::tests::elastic_net_toy_example_works ... ok
test algorithm::tests::lasso_toy_example_works ... ok
test algorithm::tests::lasso_zero_works ... ok
error: test failed, to rerun pass '-p linfa-elasticnet --lib'

Caused by:
  process didn't exit successfully: `/home/ivano/Scrivania/github_projects/linfa/target/release/deps/linfa_elasticnet-d4c6cf99e56de5dd` (signal: 11, SIGSEGV: invalid memory reference)

@Sauro98
Member Author

Sauro98 commented Apr 21, 2021

@YuhanLiin I got this test error on kmeans just once and I haven't been able to reproduce it:

---- k_means::init::tests::test_sample_subsequent_candidates stdout ----
thread 'k_means::init::tests::test_sample_subsequent_candidates' panicked at 'assertion failed: `(left == right)`
  left: `3`,
 right: `2`', algorithms/linfa-clustering/src/k_means/init.rs:259:9

It seems like the test takes a fixed seed as input, so it should be deterministic, right? Could there be some randomness in the RNG seeds for each thread?
Maybe it was just my machine acting up 😅. The error doesn't show up on the CI system, but I still tagged you in case you have some insight into why it might have happened.

Member

@bytesnake bytesnake left a comment

I really like the proposal, because

  1. it makes the associated type easier to read
  2. it makes error handling more explicit
  3. it allows cross validation to be performed in an easier way

What do you think about renaming the functions to cross_validate_sum and cross_validate_mean?

@YuhanLiin
Collaborator

> @YuhanLiin I got this test error on kmeans just once and I haven't been able to reproduce it:
>
> ---- k_means::init::tests::test_sample_subsequent_candidates stdout ----
> thread 'k_means::init::tests::test_sample_subsequent_candidates' panicked at 'assertion failed: `(left == right)`
>   left: `3`,
>  right: `2`', algorithms/linfa-clustering/src/k_means/init.rs:259:9
>
> It seems like the test takes a fixed seed as input, so it should be deterministic, right? Could there be some randomness in the RNG seeds for each thread?
> Maybe it was just my machine acting up 😅. The error doesn't show up on the CI system, but I still tagged you in case you have some insight into why it might have happened.

That RNG is seeded once per Rayon thread (not once per Rayon task), so if Rayon spawns a different number of threads on your machine then you may get different results. I'm going to look into a more deterministic way to run this test.
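
To make that failure mode concrete, here is a standalone sketch, unrelated to the actual linfa test, assuming the rand (0.8, with the small_rng feature) and rayon crates: each worker thread seeds its own RNG once, so what a task draws depends on which thread picks it up and what already ran there.

    use rand::{rngs::SmallRng, Rng, SeedableRng};
    use rayon::prelude::*;
    use std::cell::RefCell;

    thread_local! {
        // Seeded once per worker thread, not once per task
        static RNG: RefCell<SmallRng> = RefCell::new(SmallRng::seed_from_u64(42));
    }

    fn main() {
        // Tasks are assigned to threads non-deterministically, so the value
        // each task draws can change between runs and between machines with
        // different thread counts
        let draws: Vec<u32> = (0..8u32)
            .into_par_iter()
            .map(|_| RNG.with(|rng| rng.borrow_mut().gen_range(0..10)))
            .collect();
        println!("{:?}", draws);
    }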

@Sauro98 Sauro98 marked this pull request as ready for review April 23, 2021 15:18
@Sauro98
Member Author

Sauro98 commented Apr 23, 2021

I updated the PR description with the changes introduced and some considerations on the single-target/multi-target issue.

@codecov-commenter

codecov-commenter commented Apr 25, 2021

Codecov Report

Merging #122 (2fe5df4) into master (6866450) will increase coverage by 1.91%.
The diff coverage is 67.57%.

❗ Current head 2fe5df4 differs from pull request most recent head e74a93a. Consider uploading reports for the commit e74a93a to get more accurate results

@@            Coverage Diff             @@
##           master     #122      +/-   ##
==========================================
+ Coverage   57.98%   59.90%   +1.91%     
==========================================
  Files          78       74       -4     
  Lines        6819     7015     +196     
==========================================
+ Hits         3954     4202     +248     
+ Misses       2865     2813      -52     
Impacted Files Coverage Δ
...infa-clustering/src/appx_dbscan/hyperparameters.rs 18.51% <0.00%> (ø)
algorithms/linfa-linear/src/glm/link.rs 57.57% <0.00%> (+1.14%) ⬆️
algorithms/linfa-pls/src/lib.rs 90.00% <ø> (ø)
...hms/linfa-preprocessing/src/count_vectorization.rs 78.09% <ø> (+0.95%) ⬆️
algorithms/linfa-trees/src/decision_trees/tikz.rs 0.00% <0.00%> (ø)
src/dataset/impl_targets.rs 18.46% <22.22%> (+11.79%) ⬆️
...rithms/linfa-trees/src/decision_trees/algorithm.rs 52.53% <25.00%> (+1.09%) ⬆️
src/dataset/impl_dataset.rs 42.94% <39.02%> (-0.52%) ⬇️
algorithms/linfa-linear/src/glm/distribution.rs 75.80% <43.75%> (+3.62%) ⬆️
algorithms/linfa-ica/src/fast_ica.rs 50.68% <52.94%> (+1.03%) ⬆️
... and 42 more

Legend
Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@bytesnake
Member

I will review later today :)

cross validation POC
@Sauro98
Member Author

Sauro98 commented Apr 26, 2021

Ok, I re-enabled the tests that were commented out; they just needed a little tweak to work with the non-determinism of the hashset used to construct the confusion matrices. There is still one test that is not passing, for what I believe are good reasons: it tries to compute the confusion matrix between two sets after first calling with_labels on the ground-truth set to exclude one label. The two sets then have different lengths, so the confusion matrix comes out messy, as can be expected. Is this a feature that would be useful, and that I should try to fix, or is it better to just remove that single test?
The code in the test was really outdated, so maybe it reflected a property of some previous implementation. The good thing is that it allowed me to catch a weird bug in the with_labels method.

I'm available for a quick chat on Zulip/Zoom if that's needed.

@bytesnake
Member

I added a PR for this PR here: Sauro98#3

  • uses WithLapack for linfa-pls (function renaming still missing)
  • adds build script for archlinux

@Sauro98
Member Author

Sauro98 commented Apr 27, 2021

I'm writing tests for the with_labels method and I am wondering how the function should behave. Right now it works well for the single-target case, but not for the multi-label case. What I mean is that:

let records = array![[0.,0.],[0.,0.]];
let targets = array![[0,7],[1,8]];
let dataset = DatasetBase::from((records, targets)).with_labels(&[&[0,1],&[8]]);

returns a dataset with both samples inside, while I would expect the first sample to be removed, since its second target does not appear among the allowed ones.

Should it behave like I am expecting, or should the sample be kept since its first target is allowed? In the latter case, we can simplify the method by accepting just a slice instead of a slice of slices. Otherwise, I'll try to modify the method to check that each target is allowed.

@bytesnake
Member

I would expect this to be an any: if any target is contained in one of the labels, then the sample is retained. I think with_labels accepting a simple slice is enough for our use case now. You could create a case where you want to retain a label in one target and not the other, and then the modified with_labels would behave differently, but nobody is using that right now. In the future we could add a with_labels_any and a with_labels_all, which would accept a double-level slice. A sketch of the any rule is below.
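
A plain-Rust sketch of that "any" rule (illustrative only, not the actual with_labels implementation):

    // A sample is retained if any of its targets appears in the allowed labels
    fn retain_any(targets: &[Vec<i32>], allowed: &[i32]) -> Vec<bool> {
        targets
            .iter()
            .map(|sample| sample.iter().any(|t| allowed.contains(t)))
            .collect()
    }

    fn main() {
        // Mirrors the example above: targets [[0, 7], [1, 8]]
        let targets = vec![vec![0, 7], vec![1, 8]];
        // Both samples are kept: each contains at least one allowed label
        assert_eq!(retain_any(&targets, &[0, 1, 8]), vec![true, true]);
    }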

Sauro98 and others added 4 commits April 27, 2021 21:39
* Move linfa-pls to new Lapack bound

* More cleanups

* Playing around with `cross_validation`

* Make generic over dimension

* Run rustfmt

* Add simple test for multi target cv

* Run rustfmt

* Rename cross validation multi target to `cross_validate_multi`

* Run rustfmt
@Sauro98
Member Author

Sauro98 commented Apr 28, 2021

It keeps throwing a segmentation fault when running tests with tarpaulin on linfa-pls

@bytesnake
Member

> It keeps throwing a segmentation fault when running tests with tarpaulin on linfa-pls

mhh I will try to debug this

@bytesnake
Member

bytesnake commented Apr 28, 2021

Taking the eigendecomposition in the sign flippy function by value makes tarpaulin happy. Can you add that? I haven't opened another PR for this:

--- a/algorithms/linfa-pls/src/pls_svd.rs
+++ b/algorithms/linfa-pls/src/pls_svd.rs
@@ -74,7 +74,7 @@ impl<F: Float, D: Data<Elem = F>> Fit<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>, PlsE
         // safe unwraps because both parameters are set to true in above call
         let u = u.unwrap().slice_move(s![.., ..self.n_components]);
         let vt = vt.unwrap().slice_move(s![..self.n_components, ..]);
-        let (u, vt) = utils::svd_flip(&u, &vt);
+        let (u, vt) = utils::svd_flip(u, vt);
         let v = vt.reversed_axes();
 
         let x_weights = u;
diff --git a/algorithms/linfa-pls/src/utils.rs b/algorithms/linfa-pls/src/utils.rs
index fc85677..1b780e0 100644
--- a/algorithms/linfa-pls/src/utils.rs
+++ b/algorithms/linfa-pls/src/utils.rs
@@ -89,8 +89,8 @@ pub fn svd_flip_1d<F: Float>(
 }
 
 pub fn svd_flip<F: Float>(
-    u: &ArrayBase<impl Data<Elem = F>, Ix2>,
-    v: &ArrayBase<impl Data<Elem = F>, Ix2>,
+    u: ArrayBase<impl Data<Elem = F>, Ix2>,
+    v: ArrayBase<impl Data<Elem = F>, Ix2>,
 ) -> (Array2<F>, Array2<F>) {
     // columns of u, rows of v
     let abs_u = u.mapv(|v| v.abs());
@@ -101,7 +101,7 @@ pub fn svd_flip<F: Float>(
         .and(&max_abs_val_indices)
         .and(&range)
         .apply(|s, &i, &j| *s = u[[i, j]].signum());
-    (u * &signs, v * &signs.insert_axis(Axis(1)))
+    (&u * &signs, &v * &signs.insert_axis(Axis(1)))
 }

@Sauro98
Member Author

Sauro98 commented Apr 28, 2021

I will

@Sauro98
Member Author

Sauro98 commented Apr 28, 2021

I have also updated the contribution guide to reflect the changes in the Fit trait, and I removed an outdated section about prediction traits, since they are discussed later in the document in the updated form.
I did not add anything about the new Lapack traits, since I didn't know whether you wanted to do that yourself; otherwise, I'd be happy to try and write something down 👍🏻.

@Sauro98 Sauro98 changed the title [WIP] Fit trait modification and cross validation proposal Fit trait modification and cross validation proposal Apr 28, 2021
Member

@bytesnake bytesnake left a comment

Do you also want to add an example snippet for the website? Otherwise, this is ready to merge 👍

@Sauro98
Member Author

Sauro98 commented Apr 28, 2021

There's already a snippet called cross-validation; maybe I could rename that one to k-folding and add another one for cross-validation? I would basically adapt the elasticnet_cv example.

@bytesnake
Member

sounds good 👍

@bytesnake bytesnake merged commit a5a479f into rust-ml:master Apr 28, 2021