Fit trait modification and cross validation proposal #122
Conversation
@YuhanLiin I got this test error on kmeans just once and I haven't been able to reproduce it:

```
---- k_means::init::tests::test_sample_subsequent_candidates stdout ----
thread 'k_means::init::tests::test_sample_subsequent_candidates' panicked at 'assertion failed: `(left == right)`
  left: `3`,
 right: `2`', algorithms/linfa-clustering/src/k_means/init.rs:259:9
```

It seems like the test takes a fixed seed as input, so it should be deterministic, right? Could there be some randomness in the RNG seeds for each thread?
I really like the proposal, because

- it makes the associated type easier to read
- it makes error handling more explicit
- it allows cross validation to be performed in an easier way

What do you think about renaming the functions to `cross_validate_sum` and `cross_validate_mean`?
That RNG is seeded once per Rayon thread (not once per Rayon task), so if Rayon spawns a different number of threads on your machine then you may get different results. I'm going to look into a more deterministic way to run this test.
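To make the issue above concrete, here is a hypothetical, std-only sketch (a toy LCG stands in for the real RNG; this is not the kmeans code): when each worker thread derives its own RNG from the master seed, the draws depend on how tasks map onto threads, so a different thread count produces different results even with the same seed.

```rust
// Toy linear congruential generator: deterministic, std-only.
fn lcg(seed: &mut u64) -> u64 {
    *seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
    *seed
}

// Simulate a pool of `n_threads` workers, each with an RNG derived from the
// same master seed, with tasks assigned round-robin. One draw per task.
fn run(master_seed: u64, n_threads: usize, n_tasks: usize) -> Vec<u64> {
    let mut thread_seeds: Vec<u64> = (0..n_threads as u64).map(|t| master_seed ^ t).collect();
    (0..n_tasks)
        .map(|task| lcg(&mut thread_seeds[task % n_threads]))
        .collect()
}

fn main() {
    // Same master seed, different thread counts -> different samples.
    let a = run(42, 2, 4);
    let b = run(42, 3, 4);
    assert_ne!(a, b);
    // With a fixed thread count the run is reproducible.
    assert_eq!(run(42, 2, 4), run(42, 2, 4));
}
```

This is why seeding "once per thread" is only deterministic if the number of threads (and the task-to-thread assignment) is also fixed.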
I updated the PR description with the changes it brings and some considerations on the single-target/multi-target issue.
Codecov Report

```
@@            Coverage Diff             @@
##           master     #122      +/-   ##
==========================================
+ Coverage   57.98%   59.90%   +1.91%
==========================================
  Files          78       74       -4
  Lines        6819     7015     +196
==========================================
+ Hits         3954     4202     +248
+ Misses       2865     2813      -52
==========================================
```

Continue to review the full report at Codecov.
I will review later today :)
cross validation POC
Ok, I re-enabled the tests that were commented out; they just needed a little tweak to work with the non-determinism of the hashset used to construct the confusion matrices. There is still one test that is not passing, for what I believe are good reasons. Basically, it tries to compute the confusion matrix between two sets by first calling …

I'm available for a quick chat in Zulip/Zoom if that's needed.
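The hashset non-determinism mentioned above can be illustrated with a small std-only sketch (illustrative only, not the linfa code): labels collected into a `HashSet` come out in an unspecified order, so a confusion matrix built from them can change layout between runs; sorting the labels first fixes it.

```rust
use std::collections::HashSet;

// Build a confusion matrix for the given truth/prediction label slices.
// Collecting labels into a HashSet gives an unspecified order, so we sort
// them before assigning matrix rows/columns to make the layout deterministic.
fn confusion_matrix(truth: &[usize], pred: &[usize]) -> (Vec<usize>, Vec<Vec<usize>>) {
    // Distinct labels, in arbitrary (hash-dependent) order...
    let labels: HashSet<usize> = truth.iter().chain(pred.iter()).copied().collect();
    // ...pinned down by sorting.
    let mut labels: Vec<usize> = labels.into_iter().collect();
    labels.sort_unstable();
    let index = |l: usize| labels.iter().position(|&x| x == l).unwrap();
    let mut matrix = vec![vec![0usize; labels.len()]; labels.len()];
    for (&t, &p) in truth.iter().zip(pred.iter()) {
        matrix[index(t)][index(p)] += 1;
    }
    (labels, matrix)
}

fn main() {
    let (labels, m) = confusion_matrix(&[0, 1, 1, 2], &[0, 1, 2, 2]);
    assert_eq!(labels, vec![0, 1, 2]);
    assert_eq!(m[1][1], 1); // one correctly predicted `1`
    assert_eq!(m[1][2], 1); // one `1` mispredicted as `2`
}
```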
Added a PR for this PR here: Sauro98#3
I'm writing tests for the `with_labels` method. The following:

```rust
let records = array![[0., 0.], [0., 0.]];
let targets = array![[0, 7], [1, 8]];
let dataset = DatasetBase::from((records, targets)).with_labels(&[&[0, 1], &[8]]);
```

returns a dataset with both samples inside, while I would expect the first sample to be removed, since its second target does not figure in the allowed ones. Should it behave like I am expecting, or should the sample be kept, since its first target is allowed? In the latter case, we can simplify the method by accepting just a slice instead of a slice of slices. Otherwise, I'll try to modify the method to check that each target is allowed.
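The two candidate semantics can be sketched side by side with plain `Vec`s (hypothetical helper names, not linfa's implementation): keep a sample only if every target is in its allowed set, versus keep it if any target is.

```rust
// `allowed[j]` lists the labels permitted for target column `j`.
// Returns the indices of the samples that survive filtering.
fn keep_if_all_allowed(targets: &[Vec<i32>], allowed: &[Vec<i32>]) -> Vec<usize> {
    targets.iter().enumerate()
        .filter(|(_, row)| row.iter().zip(allowed).all(|(t, a)| a.contains(t)))
        .map(|(i, _)| i)
        .collect()
}

fn keep_if_any_allowed(targets: &[Vec<i32>], allowed: &[Vec<i32>]) -> Vec<usize> {
    targets.iter().enumerate()
        .filter(|(_, row)| row.iter().zip(allowed).any(|(t, a)| a.contains(t)))
        .map(|(i, _)| i)
        .collect()
}

fn main() {
    // Same data as the example above: targets [[0,7],[1,8]], allowed [[0,1],[8]].
    let targets = vec![vec![0, 7], vec![1, 8]];
    let allowed = vec![vec![0, 1], vec![8]];
    // "Every target allowed": sample 0 is dropped because 7 is not in [8].
    assert_eq!(keep_if_all_allowed(&targets, &allowed), vec![1]);
    // "Any target allowed": both samples survive (0 and 1 are in [0,1]).
    assert_eq!(keep_if_any_allowed(&targets, &allowed), vec![0, 1]);
}
```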
I would expect that this is an …
* Move linfa-pls to new Lapack bound
* More cleanups
* Playing around with `cross_validation`
* Make generic over dimension
* Run rustfmt
* Add simple test for multi target cv
* Run rustfmt
* Rename cross validation multi target to `cross_validate_multi`
* Run rustfmt
It keeps throwing a segmentation fault when running tests with tarpaulin on linfa-pls.

Mhh, I will try to debug this.

This is caused in line https://github.com/rust-ml/linfa/pull/122/files#diff-d93fb3eb4fda95138ffbeb69520f77ce15fc2b6698b5db7eadd8b3c6aa4231d5R97 (linfa-pls/src/utils.rs line 97).
Taking the eigendecomposition in the sign-flip function by value makes tarpaulin happy. Can you add that? I haven't opened another PR for this.

I will.
I have also updated the contribution guide to reflect the changes in the …
do you also want to add an example snippet for the website? Otherwise ready to merge 👍
There's already a snippet called cross-validation; maybe I could rename that to k-folding and add another one for cross-validation? I would basically adapt the elasticnet_cv example.

Sounds good 👍
Changes in the `Fit` trait

from this:

…

to this:

…

by:

- removing the `'a` lifetime (left over from the previous SVM implementation, not actually used by any algorithm anymore)

Edit 1:
Added a conversion from the linfa error as a bound on the fit error type. Every sub-crate should be able to handle the errors caused by using the base crate.
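As a rough, std-only sketch of the shape described here (all names illustrative, not linfa's actual signatures): a `Fit` trait whose `fit` returns a `Result`, with the error type required to absorb base-crate errors via `From`.

```rust
// Hypothetical stand-in for the base crate's error type.
#[derive(Debug)]
struct LinfaError(String);

// Sketch of a Result-returning `Fit` trait; the `From<LinfaError>` bound
// lets sub-crate errors be built from base-crate errors with `?`.
trait Fit<R, T> {
    type Object;
    type Error: From<LinfaError>;
    fn fit(&self, records: &R, targets: &T) -> Result<Self::Object, Self::Error>;
}

// A toy algorithm-specific error type satisfying the bound.
#[derive(Debug)]
enum KMeansError {
    Base(LinfaError),
    EmptyInput,
}

impl From<LinfaError> for KMeansError {
    fn from(e: LinfaError) -> Self {
        KMeansError::Base(e)
    }
}

struct KMeansParams { k: usize }
struct KMeansModel { n_centroids: usize }

impl Fit<Vec<f64>, Vec<usize>> for KMeansParams {
    type Object = KMeansModel;
    type Error = KMeansError;
    fn fit(&self, records: &Vec<f64>, _targets: &Vec<usize>) -> Result<KMeansModel, KMeansError> {
        if records.is_empty() {
            return Err(KMeansError::EmptyInput);
        }
        Ok(KMeansModel { n_centroids: self.k })
    }
}

fn main() {
    let params = KMeansParams { k: 3 };
    let model = params.fit(&vec![1.0, 2.0], &vec![0, 1]).unwrap();
    assert_eq!(model.n_centroids, 3);
    assert!(params.fit(&vec![], &vec![]).is_err());
}
```

Compared with the previous design, fallible fitting becomes explicit at the call site instead of panicking inside the algorithm.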
Cross validation POC

Cross-validation can be defined by exploiting the new `Fit` definition. This is what it looks like for regression:

…

And this is what it looks like for classification:

…
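As a self-contained illustration of the idea, here is a minimal k-fold sketch over plain slices (std only; hypothetical names, not the actual linfa API): each fold's validation split is scored by a user-supplied closure and the scores are collected.

```rust
// Minimal k-fold cross-validation: split `data` into `k` contiguous folds,
// hold each one out in turn, and score it with `eval(train, validation)`.
fn cross_validate<F>(k: usize, data: &[f64], eval: F) -> Vec<f64>
where
    F: Fn(&[f64], &[f64]) -> f64,
{
    let fold_len = data.len() / k;
    (0..k)
        .map(|i| {
            let (start, end) = (i * fold_len, (i + 1) * fold_len);
            let validation = &data[start..end];
            // Training set = everything outside the held-out fold.
            let train: Vec<f64> = data[..start].iter().chain(&data[end..]).copied().collect();
            eval(&train, validation)
        })
        .collect() // one score per fold, collected like an Array1
}

fn main() {
    let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    // Score each fold by the mean of its validation slice.
    let scores = cross_validate(3, &data, |_train, valid| {
        valid.iter().sum::<f64>() / valid.len() as f64
    });
    assert_eq!(scores, vec![1.5, 3.5, 5.5]);
}
```

In the real version the closure would fit a model on the training split (via `Fit`) and return a metric on the validation split.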
Possible follow up

Redefine the `Transformer` trait to return a `Result`, like the t-sne implementation does, in order to avoid panicking as much as possible. It would be better to wait for #121 in order to avoid dealing with the `unwrap()` calls when performing float conversions.

Notes
Currently there are two versions of cross-validation: one for single-target datasets and one for multi-target datasets. The main reasons for this are:

- `Array1`s of evaluation values can be constructed with `collect`; `Array2`s cannot, and must be populated row by row (this could be avoided by stacking, or maybe there is a dedicated ndarray method I am not aware of)
- Regression metrics behave differently when they are applied `Array1`-`Array2` or `Array2`-`Array1`, making writing evaluation closures possibly more difficult than it should be.

More than likely there is a solution to both problems, but I got stuck on it for too long, so I would consider it out of scope for this PR unless someone has an easy solution to suggest. Forcing the user to give a single return value in the multi-target case (like a mean across the targets) could make the problem easier, but the evaluation for each target would be lost in the process.
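The collection asymmetry, and the lossy per-fold reduction mentioned as an alternative, can be sketched with `Vec`s standing in for ndarray types (illustrative only):

```rust
// Multi-target case: one row per fold, one column per target, which has to
// be accumulated row by row (the analogue of populating an Array2).
fn multi_target_matrix(per_fold: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let mut matrix = Vec::new();
    for fold in per_fold {
        matrix.push(fold.clone());
    }
    matrix
}

// Reducing each fold to a single value (here, the mean across targets)
// makes collection as easy as the single-target `collect` case, but the
// per-target breakdown is lost.
fn reduced_scores(per_fold: &[Vec<f64>]) -> Vec<f64> {
    per_fold
        .iter()
        .map(|fold| fold.iter().sum::<f64>() / fold.len() as f64)
        .collect()
}

fn main() {
    // Two folds, three targets each: hypothetical per-target scores.
    let per_fold = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
    assert_eq!(multi_target_matrix(&per_fold).len(), 2);
    assert_eq!(reduced_scores(&per_fold), vec![2.0, 5.0]);
}
```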