diff --git a/DESCRIPTION b/DESCRIPTION index a856f43..9bf7b31 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,7 +1,7 @@ Package: structToolbox Type: Package Title: Data processing & analysis tools for Metabolomics and other omics -Version: 1.12.2 +Version: 1.12.3 Authors@R: c( person( c("Gavin","Rhys"), diff --git a/NEWS b/NEWS index 4404949..4ba434c 100644 --- a/NEWS +++ b/NEWS @@ -1,3 +1,8 @@ +Changes 1.12.3 ++ fix PLSDA predicted groups ++ add option to use probability or yhat for PLSDA predictions ++ update PLSDA tests + Changes 1.12.2 + dratio equations changed to match Broadhurst et al (2018) + added more tests for dratio_filter diff --git a/R/PLSDA_class.R b/R/PLSDA_class.R index cdebd6d..64b170e 100644 --- a/R/PLSDA_class.R +++ b/R/PLSDA_class.R @@ -3,11 +3,12 @@ #' @include PLSR_class.R #' @examples #' M = PLSDA('number_components'=2,factor_name='Species') -PLSDA = function(number_components=2,factor_name,...) { +PLSDA = function(number_components=2,factor_name,pred_method='max_prob',...) { out=struct::new_struct('PLSDA', - number_components=number_components, - factor_name=factor_name, - ...) + number_components=number_components, + factor_name=factor_name, + pred_method=pred_method, + ...) return(out) } @@ -29,19 +30,24 @@ PLSDA = function(number_components=2,factor_name,...) { pred='data.frame', threshold='numeric', sr = 'entity', - sr_pvalue='entity' + sr_pvalue='entity', + pred_method='entity' ), - prototype = list(name='Partial least squares discriminant analysis', + prototype = list( + name='Partial least squares discriminant analysis', type="classification", predicted='pred', libraries='pls', - description=paste0('PLS is a multivariate regression technique that ', + description=paste0( + 'PLS is a multivariate regression technique that ', 'extracts latent variables maximising covariance between the input ', 'data and the response. The Discriminant Analysis variant uses group ', - 'labels in the response variable and applies a threshold to the ', - 'predicted values in order to predict group membership for new samples.'), - .params=c('number_components','factor_name'), + 'labels in the response variable. For >2 groups a 1-vs-all ', + 'approach is used. Group membership can be predicted for test ', + 'samples based on a probability estimate of group membership, ', + 'or the estimated y-value.'), + .params=c('number_components','factor_name','pred_method'), .outputs=c( 'scores', 'loadings', @@ -57,12 +63,28 @@ PLSDA = function(number_components=2,factor_name,...) { 'sr', 'sr_pvalue'), - number_components=entity(value = 2, + number_components=entity( + value = 2, name = 'Number of components', description = 'The number of PLS components', type = c('numeric','integer') ), factor_name=ents$factor_name, + pred_method=enum( + name='Prediction method', + description=c( + 'max_yhat'= + paste0('The predicted group is selected based on the ', + 'largest value of y_hat.'), + 'max_prob'= + paste0('The predicted group is selected based on the ', + 'largest probability of group membership.') + ), + value='max_prob', + allowed=c('max_yhat','max_prob'), + type='character', + max_length=1 + ), sr = entity( name = 'Selectivity ratio', description = paste0( @@ -92,8 +114,8 @@ PLSDA = function(number_components=2,factor_name,...) { pages = '122-128', author = as.person("Nestor F. Perez and Joan Ferre and Ricard Boque"), title = paste0('Calculation of the reliability of ', - 'classification in discriminant partial least-squares ', - 'binary classification'), + 'classification in discriminant partial least-squares ', + 'binary classification'), journal = "Chemometrics and Intelligent Laboratory Systems" ), bibentry( @@ -113,80 +135,83 @@ PLSDA = function(number_components=2,factor_name,...) { #' @export #' @template model_train setMethod(f="model_train", - signature=c("PLSDA",'DatasetExperiment'), - definition=function(M,D) - { - SM=D$sample_meta - y=SM[[M$factor_name]] - # convert the factor to a design matrix - z=model.matrix(~y+0) - z[z==0]=-1 # +/-1 for PLS - - X=as.matrix(D$data) # convert X to matrix - - Z=as.data.frame(z) - colnames(Z)=as.character(interaction('PLSDA',1:ncol(Z),sep='_')) - - D$sample_meta=cbind(D$sample_meta,Z) - - # PLSR model - N = PLSR(number_components=M$number_components,factor_name=colnames(Z)) - N = model_apply(N,D) - - # copy outputs across - output_list(M) = output_list(N) - - # some specific outputs for PLSDA - output_value(M,'design_matrix')=Z - output_value(M,'y')=D$sample_meta[,M$factor_name,drop=FALSE] - - # for PLSDA compute probabilities - probs=prob(as.matrix(M$yhat),as.matrix(M$yhat),D$sample_meta[[M$factor_name]]) - output_value(M,'probability')=as.data.frame(probs$ingroup) - output_value(M,'threshold')=probs$threshold - - # update column names for outputs - colnames(M$reg_coeff)=levels(y) - colnames(M$sr)=levels(y) - colnames(M$vip)=levels(y) - colnames(M$yhat)=levels(y) - colnames(M$design_matrix)=levels(y) - colnames(M$probability)=levels(y) - names(M$threshold)=levels(y) - colnames(M$sr_pvalue)=levels(y) - - return(M) - } + signature=c("PLSDA",'DatasetExperiment'), + definition=function(M,D) + { + SM=D$sample_meta + y=SM[[M$factor_name]] + # convert the factor to a design matrix + z=model.matrix(~y+0) + z[z==0]=-1 # +/-1 for PLS + + X=as.matrix(D$data) # convert X to matrix + + Z=as.data.frame(z) + colnames(Z)=as.character(interaction('PLSDA',1:ncol(Z),sep='_')) + + D$sample_meta=cbind(D$sample_meta,Z) + + # PLSR model + N = PLSR(number_components=M$number_components,factor_name=colnames(Z)) + N = model_apply(N,D) + + # copy outputs across + output_list(M) = output_list(N) + + # some specific outputs for PLSDA + output_value(M,'design_matrix')=Z + output_value(M,'y')=D$sample_meta[,M$factor_name,drop=FALSE] + + # for PLSDA compute probabilities + probs=prob(as.matrix(M$yhat),as.matrix(M$yhat),D$sample_meta[[M$factor_name]]) + output_value(M,'probability')=as.data.frame(probs$ingroup) + output_value(M,'threshold')=probs$threshold + + # update column names for outputs + colnames(M$reg_coeff)=levels(y) + colnames(M$sr)=levels(y) + colnames(M$vip)=levels(y) + colnames(M$yhat)=levels(y) + colnames(M$design_matrix)=levels(y) + colnames(M$probability)=levels(y) + names(M$threshold)=levels(y) + colnames(M$sr_pvalue)=levels(y) + + return(M) + } ) #' @export #' @template model_predict setMethod(f="model_predict", - signature=c("PLSDA",'DatasetExperiment'), - definition=function(M,D) - { - # call PLSR predict - N=callNextMethod(M,D) - SM=N$y - - ## probability estimate - # http://www.eigenvector.com/faq/index.php?id=38%7C - p=as.matrix(N$pred) - d=prob(x=p,yhat=as.matrix(N$yhat),ytrue=SM[[M$factor_name]]) - pred=(p>d$threshold)*1 - pred=apply(pred,MARGIN=1,FUN=which.max) - hi=apply(d$ingroup,MARGIN=1,FUN=which.max) # max probability - if (sum(is.na(pred)>0)) { - pred[is.na(pred)]=hi[is.na(pred)] # if none above threshold, use group with highest probability - } - pred=factor(pred,levels=1:nlevels(SM[[M$factor_name]]),labels=levels(SM[[M$factor_name]])) # make sure pred has all the levels of y - q=data.frame("pred"=pred) - output_value(M,'pred')=q - return(M) - } + signature=c("PLSDA",'DatasetExperiment'), + definition=function(M,D) + { + # call PLSR predict + N=callNextMethod(M,D) + SM=N$y + + ## probability estimate + # http://www.eigenvector.com/faq/index.php?id=38%7C + p=as.matrix(N$pred) + d=prob(x=p,yhat=as.matrix(N$yhat),ytrue=M$y[[M$factor_name]]) + + # predictions + if (M$pred_method=='max_yhat') { + pred=apply(p,MARGIN=1,FUN=which.max) + } else if (M$pred_method=='max_prob') { + pred=apply(d$ingroup,MARGIN=1,FUN=which.max) + } + pred=factor(pred,levels=1:nlevels(SM[[M$factor_name]]),labels=levels(SM[[M$factor_name]])) # make sure pred has all the levels of y + q=data.frame("pred"=pred) + output_value(M,'pred')=q + return(M) + } ) + + prob=function(x,yhat,ytrue) { # x is predicted values @@ -250,8 +275,7 @@ prob=function(x,yhat,ytrue) } -gauss_intersect=function(m1,m2,s1,s2) -{ +gauss_intersect=function(m1,m2,s1,s2) { #https://stackoverflow.com/questions/22579434/python-finding-the-intersection-point-of-two-gaussian-curves a=(1/(2*s1*s1))-(1/(2*s2*s2)) b=(m2/(s2*s2)) - (m1/(s1*s1)) diff --git a/man/PLSDA.Rd b/man/PLSDA.Rd index 6f1ccf3..4c65386 100644 --- a/man/PLSDA.Rd +++ b/man/PLSDA.Rd @@ -4,13 +4,15 @@ \alias{PLSDA} \title{Partial least squares discriminant analysis} \usage{ -PLSDA(number_components = 2, factor_name, ...) +PLSDA(number_components = 2, factor_name, pred_method = "max_prob", ...) } \arguments{ \item{number_components}{(numeric, integer) The number of PLS components. The default is \code{2}.} \item{factor_name}{(character) The name of a sample-meta column to use.} +\item{pred_method}{(character) Prediction method. Allowed values are limited to the following: \itemize{\item{\code{"max_yhat"}: The predicted group is selected based on the largest value of y_hat.}\item{\code{"max_prob"}: The predicted group is selected based on the largest probability of group membership.}} The default is \code{"max_prob"}.} + \item{...}{Additional slots and values passed to \code{struct_class}.} } \value{ @@ -32,7 +34,7 @@ A \code{PLSDA} object with the following \code{output} slots: } } \description{ -PLS is a multivariate regression technique that extracts latent variables maximising covariance between the input data and the response. The Discriminant Analysis variant uses group labels in the response variable and applies a threshold to the predicted values in order to predict group membership for new samples. +PLS is a multivariate regression technique that extracts latent variables maximising covariance between the input data and the response. The Discriminant Analysis variant uses group labels in the response variable. For >2 groups a 1-vs-all approach is used. Group membership can be predicted for test samples based on a probability estimate of group membership, or the estimated y-value. } \details{ This object makes use of functionality from the following packages:\itemize{\item{\code{pls}}} @@ -41,9 +43,9 @@ This object makes use of functionality from the following packages:\itemize{\ite M = PLSDA('number_components'=2,factor_name='Species') } \references{ -Liland K, Mevik B, Wehrens R (2021). +Liland K, Mevik B, Wehrens R (2023). \emph{pls: Partial Least Squares and Principal Component Regression}. -R package version 2.8-0, \url{https://CRAN.R-project.org/package=pls}. +R package version 2.8-2, \url{https://CRAN.R-project.org/package=pls}. Perez NF, Ferre J, Boque R (2009). ``Calculation of the reliability of classification in discriminant partial least-squares binary classification.'' diff --git a/tests/testthat/test-gridsearch1d.R b/tests/testthat/test-gridsearch1d.R index 96ea05b..a4a5b19 100644 --- a/tests/testthat/test-gridsearch1d.R +++ b/tests/testthat/test-gridsearch1d.R @@ -16,7 +16,7 @@ test_that('grid_search iterator',{ # run I=run(I,D,B) # calculate metric - expect_equal(I$metric$value,0.3,tolerance=0.05) + expect_equal(I$metric$value,0.045,tolerance=0.0005) }) # test grid search @@ -36,7 +36,7 @@ test_that('grid_search wf',{ # run I=run(I,D,B) # calculate metric - expect_equal(I$metric$value[1],0.3,tolerance=0.05) + expect_equal(I$metric$value[1],0.04,tolerance=0.005) }) # test grid search diff --git a/tests/testthat/test-kfold-xval.R b/tests/testthat/test-kfold-xval.R index a4c8c2e..964a01b 100644 --- a/tests/testthat/test-kfold-xval.R +++ b/tests/testthat/test-kfold-xval.R @@ -1,76 +1,76 @@ # test kfold_xval class test_that('kfold xval venetian',{ - set.seed('57475') - # DatasetExperiment - D=iris_DatasetExperiment() - # iterator - I = kfold_xval(folds=5,method='venetian',factor_name='Species')*(mean_centre()+PLSDA(factor_name='Species')) - # metric - B=balanced_accuracy() - # run - I=run(I,D,B) - # calculate metric - expect_equal(I$metric$mean,0.23,tolerance=0.05) + set.seed('57475') + # DatasetExperiment + D=iris_DatasetExperiment() + # iterator + I = kfold_xval(folds=5,method='venetian',factor_name='Species')*(mean_centre()+PLSDA(factor_name='Species')) + # metric + B=balanced_accuracy() + # run + I=run(I,D,B) + # calculate metric + expect_equal(I$metric$mean,0.11,tolerance=0.005) }) test_that('kfold xval blocks',{ - set.seed('57475') - # DatasetExperiment - D=iris_DatasetExperiment() - # iterator - I = permute_sample_order(number_of_permutations=1)* # needs to be permuted or all groups not in training set - kfold_xval(folds=5,method='blocks',factor_name='Species')* - (mean_centre()+PLSDA(factor_name='Species')) - # metric - B=balanced_accuracy() - # run - I=run(I,D,B) - # calculate metric - expect_equal(I$metric$mean,0.23,tolerance=0.05) + set.seed('57475') + # DatasetExperiment + D=iris_DatasetExperiment() + # iterator + I = permute_sample_order(number_of_permutations=1)* # needs to be permuted or all groups not in training set + kfold_xval(folds=5,method='blocks',factor_name='Species')* + (mean_centre()+PLSDA(factor_name='Species')) + # metric + B=balanced_accuracy() + # run + I=run(I,D,B) + # calculate metric + expect_equal(I$metric$mean,0.115,tolerance=0.005) }) test_that('kfold xval random',{ - set.seed(57475) - # DatasetExperiment - D=iris_DatasetExperiment() - # iterator - I = kfold_xval(folds=5,method='random',factor_name='Species')*(mean_centre()+PLSDA(factor_name='Species')) - # metric - B=balanced_accuracy() - # run - I=run(I,D,B) - # calculate metric - expect_equal(I$metric$mean,0.23,tolerance=0.05) + set.seed(57475) + # DatasetExperiment + D=iris_DatasetExperiment() + # iterator + I = kfold_xval(folds=5,method='random',factor_name='Species')*(mean_centre()+PLSDA(factor_name='Species')) + # metric + B=balanced_accuracy() + # run + I=run(I,D,B) + # calculate metric + expect_equal(I$metric$mean,0.105,tolerance=0.0005) }) test_that('kfold xval metric plot',{ - set.seed('57475') - # DatasetExperiment - D=iris_DatasetExperiment() - # iterator - I = kfold_xval(folds=5,method='venetian',factor_name='Species')*(mean_centre()+PLSDA(factor_name='Species')) - # metric - B=balanced_accuracy() - # run - I=run(I,D,B) - # chart - C = kfoldxcv_metric() - gg=chart_plot(C,I) - expect_true(is(gg,'ggplot')) + set.seed('57475') + # DatasetExperiment + D=iris_DatasetExperiment() + # iterator + I = kfold_xval(folds=5,method='venetian',factor_name='Species')*(mean_centre()+PLSDA(factor_name='Species')) + # metric + B=balanced_accuracy() + # run + I=run(I,D,B) + # chart + C = kfoldxcv_metric() + gg=chart_plot(C,I) + expect_true(is(gg,'ggplot')) }) test_that('kfold xval grid plot',{ - set.seed('57475') - # DatasetExperiment - D=iris_DatasetExperiment() - # iterator - I = kfold_xval(folds=5,method='venetian',factor_name='Species')*(mean_centre()+PLSDA(factor_name='Species')) - # metric - B=balanced_accuracy() - # run - I=run(I,D,B) - # chart - C = kfoldxcv_grid(factor_name='Species',level='setosa') - gg=chart_plot(C,I) - expect_true(is(gg,'ggplot')) + set.seed('57475') + # DatasetExperiment + D=iris_DatasetExperiment() + # iterator + I = kfold_xval(folds=5,method='venetian',factor_name='Species')*(mean_centre()+PLSDA(factor_name='Species')) + # metric + B=balanced_accuracy() + # run + I=run(I,D,B) + # chart + C = kfoldxcv_grid(factor_name='Species',level='setosa') + gg=chart_plot(C,I) + expect_true(is(gg,'ggplot')) }) diff --git a/tests/testthat/test-permutation_test.R b/tests/testthat/test-permutation_test.R index 62c62a4..2d65d13 100644 --- a/tests/testthat/test-permutation_test.R +++ b/tests/testthat/test-permutation_test.R @@ -13,7 +13,7 @@ test_that('permutation test',{ # calculate metric B=calculate(B,Yhat=output_value(I,'results.unpermuted')$predicted, Y=output_value(I,'results.unpermuted')$actual) - expect_equal(value(B),expected=0.211,tolerance=0.004) + expect_equal(value(B),expected=0.105,tolerance=0.0005) }) # permutation test box plot diff --git a/tests/testthat/test-permute-sample-order.R b/tests/testthat/test-permute-sample-order.R index cd6138e..90d64cc 100644 --- a/tests/testthat/test-permute-sample-order.R +++ b/tests/testthat/test-permute-sample-order.R @@ -9,7 +9,7 @@ test_that('permute sample order model_seq',{ B=balanced_accuracy() # run I=run(I,D,B) - expect_equal(I$metric$mean,expected=0.335,tolerance=0.05) + expect_equal(I$metric$mean,expected=0.04,tolerance=0.005) }) # permute sample order @@ -23,5 +23,5 @@ test_that('permute sample order iterator',{ B=balanced_accuracy() # run I=run(I,D,B) - expect_equal(I$metric$mean,expected=0.339,tolerance=0.05) + expect_equal(I$metric$mean,expected=0.048,tolerance=0.0005) })