From 46c79734ff35d81471564e29372e6d460de720a0 Mon Sep 17 00:00:00 2001
From: Gavin Rhys Lloyd
Date: Tue, 5 Sep 2023 16:04:56 +0100
Subject: [PATCH] hotfix PLSDA predicted groups

- predicted group now correctly assigned based on ingroup probability or
  yhat value
---
 R/PLSDA_class.R                            | 186 ++++++++++++---
 man/PLSDA.Rd                               |   6 +-
 tests/testthat/test-gridsearch1d.R         |   4 +-
 tests/testthat/test-kfold-xval.R           |   6 +-
 tests/testthat/test-permutation_test.R     |   2 +-
 tests/testthat/test-permute-sample-order.R |   4 +-
 6 files changed, 117 insertions(+), 91 deletions(-)

diff --git a/R/PLSDA_class.R b/R/PLSDA_class.R
index cdebd6d..64b170e 100644
--- a/R/PLSDA_class.R
+++ b/R/PLSDA_class.R
@@ -3,11 +3,12 @@
 #' @include PLSR_class.R
 #' @examples
 #' M = PLSDA('number_components'=2,factor_name='Species')
-PLSDA = function(number_components=2,factor_name,...) {
+PLSDA = function(number_components=2,factor_name,pred_method='max_prob',...) {
     out=struct::new_struct('PLSDA',
-        number_components=number_components,
-        factor_name=factor_name,
-        ...)
+        number_components=number_components,
+        factor_name=factor_name,
+        pred_method=pred_method,
+        ...)
     return(out)
 }
 
@@ -29,19 +30,24 @@
         pred='data.frame',
         threshold='numeric',
         sr = 'entity',
-        sr_pvalue='entity'
+        sr_pvalue='entity',
+        pred_method='entity'
     ),
-    prototype = list(name='Partial least squares discriminant analysis',
+    prototype = list(
+        name='Partial least squares discriminant analysis',
         type="classification",
         predicted='pred',
         libraries='pls',
-        description=paste0('PLS is a multivariate regression technique that ',
+        description=paste0(
+            'PLS is a multivariate regression technique that ',
             'extracts latent variables maximising covariance between the input ',
             'data and the response. The Discriminant Analysis variant uses group ',
-            'labels in the response variable and applies a threshold to the ',
-            'predicted values in order to predict group membership for new samples.'),
-        .params=c('number_components','factor_name'),
+            'labels in the response variable. For >2 groups a 1-vs-all ',
+            'approach is used. Group membership can be predicted for test ',
+            'samples based on a probability estimate of group membership, ',
+            'or the estimated y-value.'),
+        .params=c('number_components','factor_name','pred_method'),
         .outputs=c(
             'scores',
             'loadings',
@@ -57,12 +63,28 @@
             'sr',
             'sr_pvalue'),
-        number_components=entity(value = 2,
+        number_components=entity(
+            value = 2,
             name = 'Number of components',
             description = 'The number of PLS components',
             type = c('numeric','integer')
         ),
         factor_name=ents$factor_name,
+        pred_method=enum(
+            name='Prediction method',
+            description=c(
+                'max_yhat'=
+                    paste0('The predicted group is selected based on the ',
+                        'largest value of y_hat.'),
+                'max_prob'=
+                    paste0('The predicted group is selected based on the ',
+                        'largest probability of group membership.')
+            ),
+            value='max_prob',
+            allowed=c('max_yhat','max_prob'),
+            type='character',
+            max_length=1
+        ),
         sr = entity(
             name = 'Selectivity ratio',
             description = paste0(
@@ -92,8 +114,8 @@
                 pages = '122-128',
                 author = as.person("Nestor F. Perez and Joan Ferre and Ricard Boque"),
                 title = paste0('Calculation of the reliability of ',
-                               'classification in discriminant partial least-squares ',
-                               'binary classification'),
+                    'classification in discriminant partial least-squares ',
+                    'binary classification'),
                 journal = "Chemometrics and Intelligent Laboratory Systems"
             ),
             bibentry(
@@ -113,80 +135,83 @@
 #' @export
 #' @template model_train
 setMethod(f="model_train",
-    signature=c("PLSDA",'DatasetExperiment'),
-    definition=function(M,D)
-    {
-        SM=D$sample_meta
-        y=SM[[M$factor_name]]
-        # convert the factor to a design matrix
-        z=model.matrix(~y+0)
-        z[z==0]=-1 # +/-1 for PLS
-
-        X=as.matrix(D$data) # convert X to matrix
-
-        Z=as.data.frame(z)
-        colnames(Z)=as.character(interaction('PLSDA',1:ncol(Z),sep='_'))
-
-        D$sample_meta=cbind(D$sample_meta,Z)
-
-        # PLSR model
-        N = PLSR(number_components=M$number_components,factor_name=colnames(Z))
-        N = model_apply(N,D)
-
-        # copy outputs across
-        output_list(M) = output_list(N)
-
-        # some specific outputs for PLSDA
-        output_value(M,'design_matrix')=Z
-        output_value(M,'y')=D$sample_meta[,M$factor_name,drop=FALSE]
-
-        # for PLSDA compute probabilities
-        probs=prob(as.matrix(M$yhat),as.matrix(M$yhat),D$sample_meta[[M$factor_name]])
-        output_value(M,'probability')=as.data.frame(probs$ingroup)
-        output_value(M,'threshold')=probs$threshold
-
-        # update column names for outputs
-        colnames(M$reg_coeff)=levels(y)
-        colnames(M$sr)=levels(y)
-        colnames(M$vip)=levels(y)
-        colnames(M$yhat)=levels(y)
-        colnames(M$design_matrix)=levels(y)
-        colnames(M$probability)=levels(y)
-        names(M$threshold)=levels(y)
-        colnames(M$sr_pvalue)=levels(y)
-
-        return(M)
-    }
+        signature=c("PLSDA",'DatasetExperiment'),
+        definition=function(M,D)
+        {
+            SM=D$sample_meta
+            y=SM[[M$factor_name]]
+            # convert the factor to a design matrix
+            z=model.matrix(~y+0)
+            z[z==0]=-1 # +/-1 for PLS
+
+            X=as.matrix(D$data) # convert X to matrix
+
+            Z=as.data.frame(z)
+            colnames(Z)=as.character(interaction('PLSDA',1:ncol(Z),sep='_'))
+
+            D$sample_meta=cbind(D$sample_meta,Z)
+
+            # PLSR model
+            N = PLSR(number_components=M$number_components,factor_name=colnames(Z))
+            N = model_apply(N,D)
+
+            # copy outputs across
+            output_list(M) = output_list(N)
+
+            # some specific outputs for PLSDA
+            output_value(M,'design_matrix')=Z
+            output_value(M,'y')=D$sample_meta[,M$factor_name,drop=FALSE]
+
+            # for PLSDA compute probabilities
+            probs=prob(as.matrix(M$yhat),as.matrix(M$yhat),D$sample_meta[[M$factor_name]])
+            output_value(M,'probability')=as.data.frame(probs$ingroup)
+            output_value(M,'threshold')=probs$threshold
+
+            # update column names for outputs
+            colnames(M$reg_coeff)=levels(y)
+            colnames(M$sr)=levels(y)
+            colnames(M$vip)=levels(y)
+            colnames(M$yhat)=levels(y)
+            colnames(M$design_matrix)=levels(y)
+            colnames(M$probability)=levels(y)
+            names(M$threshold)=levels(y)
+            colnames(M$sr_pvalue)=levels(y)
+
+            return(M)
+        }
 )
 
 #' @export
 #' @template model_predict
 setMethod(f="model_predict",
-    signature=c("PLSDA",'DatasetExperiment'),
-    definition=function(M,D)
-    {
-        # call PLSR predict
-        N=callNextMethod(M,D)
-        SM=N$y
-
-        ## probability estimate
-        # http://www.eigenvector.com/faq/index.php?id=38%7C
-        p=as.matrix(N$pred)
-        d=prob(x=p,yhat=as.matrix(N$yhat),ytrue=SM[[M$factor_name]])
-        pred=(p>d$threshold)*1
-        pred=apply(pred,MARGIN=1,FUN=which.max)
-        hi=apply(d$ingroup,MARGIN=1,FUN=which.max) # max probability
-        if (sum(is.na(pred)>0)) {
-            pred[is.na(pred)]=hi[is.na(pred)] # if none above threshold, use group with highest probability
-        }
-        pred=factor(pred,levels=1:nlevels(SM[[M$factor_name]]),labels=levels(SM[[M$factor_name]])) # make sure pred has all the levels of y
-        q=data.frame("pred"=pred)
-        output_value(M,'pred')=q
-        return(M)
-    }
+        signature=c("PLSDA",'DatasetExperiment'),
+        definition=function(M,D)
+        {
+            # call PLSR predict
+            N=callNextMethod(M,D)
+            SM=N$y
+
+            ## probability estimate
+            # http://www.eigenvector.com/faq/index.php?id=38%7C
+            p=as.matrix(N$pred)
+            d=prob(x=p,yhat=as.matrix(N$yhat),ytrue=M$y[[M$factor_name]])
+
+            # predictions
+            if (M$pred_method=='max_yhat') {
+                pred=apply(p,MARGIN=1,FUN=which.max)
+            } else if (M$pred_method=='max_prob') {
+                pred=apply(d$ingroup,MARGIN=1,FUN=which.max)
+            }
+            pred=factor(pred,levels=1:nlevels(SM[[M$factor_name]]),labels=levels(SM[[M$factor_name]])) # make sure pred has all the levels of y
+            q=data.frame("pred"=pred)
+            output_value(M,'pred')=q
+            return(M)
+        }
 )
+
+
 prob=function(x,yhat,ytrue)
 {
     # x is predicted values
@@ -250,8 +275,7 @@
 }
 
-gauss_intersect=function(m1,m2,s1,s2)
-{
+gauss_intersect=function(m1,m2,s1,s2) {
     #https://stackoverflow.com/questions/22579434/python-finding-the-intersection-point-of-two-gaussian-curves
     a=(1/(2*s1*s1))-(1/(2*s2*s2))
     b=(m2/(s2*s2)) - (m1/(s1*s1))
diff --git a/man/PLSDA.Rd b/man/PLSDA.Rd
index 2e9a92f..724ae1d 100644
--- a/man/PLSDA.Rd
+++ b/man/PLSDA.Rd
@@ -4,13 +4,15 @@
 \alias{PLSDA}
 \title{Partial least squares discriminant analysis}
 \usage{
-PLSDA(number_components = 2, factor_name, ...)
+PLSDA(number_components = 2, factor_name, pred_method = "max_prob", ...)
 }
 \arguments{
 \item{number_components}{(numeric, integer) The number of PLS components. The default is \code{2}.}
 
 \item{factor_name}{(character) The name of a sample-meta column to use.}
 
+\item{pred_method}{(character) Prediction method. Allowed values are limited to the following: \itemize{\item{\code{"max_yhat"}: The predicted group is selected based on the largest value of y_hat.}\item{\code{"max_prob"}: The predicted group is selected based on the largest probability of group membership.}} The default is \code{"max_prob"}.}
+
 \item{...}{Additional slots and values passed to \code{struct_class}.}
 }
 \value{
 A \code{PLSDA} object with the following \code{output} slots:
@@ -32,7 +34,7 @@
 }
 }
 \description{
-PLS is a multivariate regression technique that extracts latent variables maximising covariance between the input data and the response. The Discriminant Analysis variant uses group labels in the response variable and applies a threshold to the predicted values in order to predict group membership for new samples.
+PLS is a multivariate regression technique that extracts latent variables maximising covariance between the input data and the response. The Discriminant Analysis variant uses group labels in the response variable. For >2 groups a 1-vs-all approach is used. Group membership can be predicted for test samples based on a probability estimate of group membership, or the estimated y-value.
 }
 \details{
 This object makes use of functionality from the following packages:\itemize{\item{\code{pls}}}
diff --git a/tests/testthat/test-gridsearch1d.R b/tests/testthat/test-gridsearch1d.R
index 96ea05b..a4a5b19 100644
--- a/tests/testthat/test-gridsearch1d.R
+++ b/tests/testthat/test-gridsearch1d.R
@@ -16,7 +16,7 @@ test_that('grid_search iterator',{
     # run
     I=run(I,D,B)
     # calculate metric
-    expect_equal(I$metric$value,0.3,tolerance=0.05)
+    expect_equal(I$metric$value,0.045,tolerance=0.0005)
 })
 
 # test grid search
@@ -36,7 +36,7 @@ test_that('grid_search wf',{
     # run
     I=run(I,D,B)
     # calculate metric
-    expect_equal(I$metric$value[1],0.3,tolerance=0.05)
+    expect_equal(I$metric$value[1],0.04,tolerance=0.005)
 })
 
 # test grid search
diff --git a/tests/testthat/test-kfold-xval.R b/tests/testthat/test-kfold-xval.R
index a4c8c2e..9db22ca 100644
--- a/tests/testthat/test-kfold-xval.R
+++ b/tests/testthat/test-kfold-xval.R
@@ -10,7 +10,7 @@ test_that('kfold xval venetian',{
     # run
     I=run(I,D,B)
     # calculate metric
-    expect_equal(I$metric$mean,0.23,tolerance=0.05)
+    expect_equal(I$metric$mean,0.11,tolerance=0.005)
 })
 
 test_that('kfold xval blocks',{
@@ -26,7 +26,7 @@ test_that('kfold xval blocks',{
     # run
     I=run(I,D,B)
     # calculate metric
-    expect_equal(I$metric$mean,0.23,tolerance=0.05)
+    expect_equal(I$metric$mean,0.115,tolerance=0.005)
 })
 
 test_that('kfold xval random',{
@@ -40,7 +40,7 @@ test_that('kfold xval random',{
     # run
     I=run(I,D,B)
     # calculate metric
-    expect_equal(I$metric$mean,0.23,tolerance=0.05)
+    expect_equal(I$metric$mean,0.105,tolerance=0.0005)
 })
 
 test_that('kfold xval metric plot',{
diff --git a/tests/testthat/test-permutation_test.R b/tests/testthat/test-permutation_test.R
index 62c62a4..2d65d13 100644
--- a/tests/testthat/test-permutation_test.R
+++ b/tests/testthat/test-permutation_test.R
@@ -13,7 +13,7 @@ test_that('permutation test',{
     # calculate metric
     B=calculate(B,Yhat=output_value(I,'results.unpermuted')$predicted,
         Y=output_value(I,'results.unpermuted')$actual)
-    expect_equal(value(B),expected=0.211,tolerance=0.004)
+    expect_equal(value(B),expected=0.105,tolerance=0.0005)
 })
 
 # permutation test box plot
diff --git a/tests/testthat/test-permute-sample-order.R b/tests/testthat/test-permute-sample-order.R
index cd6138e..90d64cc 100644
--- a/tests/testthat/test-permute-sample-order.R
+++ b/tests/testthat/test-permute-sample-order.R
@@ -9,7 +9,7 @@ test_that('permute sample order model_seq',{
     B=balanced_accuracy()
     # run
     I=run(I,D,B)
-    expect_equal(I$metric$mean,expected=0.335,tolerance=0.05)
+    expect_equal(I$metric$mean,expected=0.04,tolerance=0.005)
 })
 
 # permute sample order
@@ -23,5 +23,5 @@ test_that('permute sample order iterator',{
     B=balanced_accuracy()
     # run
     I=run(I,D,B)
-    expect_equal(I$metric$mean,expected=0.339,tolerance=0.05)
+    expect_equal(I$metric$mean,expected=0.048,tolerance=0.0005)
 })
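
Usage sketch (illustrative, not part of the commit). It assumes the
structToolbox iris_DatasetExperiment() example data and the predicted()
accessor from struct; model_apply() here trains and predicts on the same
data purely to show the two prediction rules side by side.

    library(structToolbox)

    D = iris_DatasetExperiment()

    # default: assign each sample to the group with the largest estimated
    # probability of group membership
    M1 = PLSDA(number_components=2, factor_name='Species', pred_method='max_prob')
    M1 = model_apply(M1, D)

    # alternative: assign each sample to the group with the largest
    # predicted y-value across the 1-vs-all response columns
    M2 = PLSDA(number_components=2, factor_name='Species', pred_method='max_yhat')
    M2 = model_apply(M2, D)

    # the two rules can disagree for borderline samples
    table(max_prob=predicted(M1)$pred, max_yhat=predicted(M2)$pred)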