hotfix PLSDA predictions
- predictions now correctly assigned based on max probability or max yhat
- tests updated
- documentation updated
- news updated
- version bump
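
A rough usage sketch of the new option (assuming the iris_DatasetExperiment() example data shipped with structToolbox; any DatasetExperiment with a factor column should behave the same way):

library(structToolbox)

# assumed example data: iris as a DatasetExperiment with a 'Species' factor column
D = iris_DatasetExperiment()

# 'max_prob' (the new default) assigns the group with the largest estimated
# membership probability; 'max_yhat' uses the largest predicted y-value instead
M = PLSDA(number_components = 2, factor_name = 'Species', pred_method = 'max_prob')
M = model_train(M, D)
M = model_predict(M, D)
head(M$pred)   # data.frame of predicted group labels
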
grlloyd committed Sep 5, 2023
1 parent 8662767 commit 9b1b307
Showing 8 changed files with 183 additions and 152 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -1,7 +1,7 @@
Package: structToolbox
Type: Package
Title: Data processing & analysis tools for Metabolomics and other omics
Version: 1.12.2
Version: 1.12.3
Authors@R: c(
person(
c("Gavin","Rhys"),
5 changes: 5 additions & 0 deletions NEWS
@@ -1,3 +1,8 @@
Changes 1.12.3
+ fix PLSDA predicted groups
+ add option to use probability or yhat for PLSDA predictions
+ update PLSDA tests

Changes 1.12.2
+ dratio equations changed to match Broadhurst et al (2018)
+ added more tests for dratio_filter
186 changes: 105 additions & 81 deletions R/PLSDA_class.R
@@ -3,11 +3,12 @@
#' @include PLSR_class.R
#' @examples
#' M = PLSDA('number_components'=2,factor_name='Species')
PLSDA = function(number_components=2,factor_name,...) {
PLSDA = function(number_components=2,factor_name,pred_method='max_prob',...) {
out=struct::new_struct('PLSDA',
number_components=number_components,
factor_name=factor_name,
...)
number_components=number_components,
factor_name=factor_name,
pred_method=pred_method,
...)
return(out)
}

@@ -29,19 +30,24 @@ PLSDA = function(number_components=2,factor_name,...) {
pred='data.frame',
threshold='numeric',
sr = 'entity',
sr_pvalue='entity'
sr_pvalue='entity',
pred_method='entity'

),
prototype = list(name='Partial least squares discriminant analysis',
prototype = list(
name='Partial least squares discriminant analysis',
type="classification",
predicted='pred',
libraries='pls',
description=paste0('PLS is a multivariate regression technique that ',
description=paste0(
'PLS is a multivariate regression technique that ',
'extracts latent variables maximising covariance between the input ',
'data and the response. The Discriminant Analysis variant uses group ',
'labels in the response variable and applies a threshold to the ',
'predicted values in order to predict group membership for new samples.'),
.params=c('number_components','factor_name'),
'labels in the response variable. For >2 groups a 1-vs-all ',
'approach is used. Group membership can be predicted for test ',
'samples based on a probability estimate of group membership, ',
'or the estimated y-value.'),
.params=c('number_components','factor_name','pred_method'),
.outputs=c(
'scores',
'loadings',
@@ -57,12 +63,28 @@ PLSDA = function(number_components=2,factor_name,...) {
'sr',
'sr_pvalue'),

number_components=entity(value = 2,
number_components=entity(
value = 2,
name = 'Number of components',
description = 'The number of PLS components',
type = c('numeric','integer')
),
factor_name=ents$factor_name,
pred_method=enum(
name='Prediction method',
description=c(
'max_yhat'=
paste0('The predicted group is selected based on the ',
'largest value of y_hat.'),
'max_prob'=
paste0('The predicted group is selected based on the ',
'largest probability of group membership.')
),
value='max_prob',
allowed=c('max_yhat','max_prob'),
type='character',
max_length=1
),
sr = entity(
name = 'Selectivity ratio',
description = paste0(
@@ -92,8 +114,8 @@ PLSDA = function(number_components=2,factor_name,...) {
pages = '122-128',
author = as.person("Nestor F. Perez and Joan Ferre and Ricard Boque"),
title = paste0('Calculation of the reliability of ',
'classification in discriminant partial least-squares ',
'binary classification'),
'classification in discriminant partial least-squares ',
'binary classification'),
journal = "Chemometrics and Intelligent Laboratory Systems"
),
bibentry(
@@ -113,80 +135,83 @@ PLSDA = function(number_components=2,factor_name,...) {
#' @export
#' @template model_train
setMethod(f="model_train",
signature=c("PLSDA",'DatasetExperiment'),
definition=function(M,D)
{
SM=D$sample_meta
y=SM[[M$factor_name]]
# convert the factor to a design matrix
z=model.matrix(~y+0)
z[z==0]=-1 # +/-1 for PLS

X=as.matrix(D$data) # convert X to matrix

Z=as.data.frame(z)
colnames(Z)=as.character(interaction('PLSDA',1:ncol(Z),sep='_'))

D$sample_meta=cbind(D$sample_meta,Z)

# PLSR model
N = PLSR(number_components=M$number_components,factor_name=colnames(Z))
N = model_apply(N,D)

# copy outputs across
output_list(M) = output_list(N)

# some specific outputs for PLSDA
output_value(M,'design_matrix')=Z
output_value(M,'y')=D$sample_meta[,M$factor_name,drop=FALSE]

# for PLSDA compute probabilities
probs=prob(as.matrix(M$yhat),as.matrix(M$yhat),D$sample_meta[[M$factor_name]])
output_value(M,'probability')=as.data.frame(probs$ingroup)
output_value(M,'threshold')=probs$threshold

# update column names for outputs
colnames(M$reg_coeff)=levels(y)
colnames(M$sr)=levels(y)
colnames(M$vip)=levels(y)
colnames(M$yhat)=levels(y)
colnames(M$design_matrix)=levels(y)
colnames(M$probability)=levels(y)
names(M$threshold)=levels(y)
colnames(M$sr_pvalue)=levels(y)

return(M)
}
signature=c("PLSDA",'DatasetExperiment'),
definition=function(M,D)
{
SM=D$sample_meta
y=SM[[M$factor_name]]
# convert the factor to a design matrix
z=model.matrix(~y+0)
z[z==0]=-1 # +/-1 for PLS
X=as.matrix(D$data) # convert X to matrix
Z=as.data.frame(z)
colnames(Z)=as.character(interaction('PLSDA',1:ncol(Z),sep='_'))
D$sample_meta=cbind(D$sample_meta,Z)
# PLSR model
N = PLSR(number_components=M$number_components,factor_name=colnames(Z))
N = model_apply(N,D)
# copy outputs across
output_list(M) = output_list(N)
# some specific outputs for PLSDA
output_value(M,'design_matrix')=Z
output_value(M,'y')=D$sample_meta[,M$factor_name,drop=FALSE]
# for PLSDA compute probabilities
probs=prob(as.matrix(M$yhat),as.matrix(M$yhat),D$sample_meta[[M$factor_name]])
output_value(M,'probability')=as.data.frame(probs$ingroup)
output_value(M,'threshold')=probs$threshold
# update column names for outputs
colnames(M$reg_coeff)=levels(y)
colnames(M$sr)=levels(y)
colnames(M$vip)=levels(y)
colnames(M$yhat)=levels(y)
colnames(M$design_matrix)=levels(y)
colnames(M$probability)=levels(y)
names(M$threshold)=levels(y)
colnames(M$sr_pvalue)=levels(y)
return(M)
}
)

#' @export
#' @template model_predict
setMethod(f="model_predict",
signature=c("PLSDA",'DatasetExperiment'),
definition=function(M,D)
{
# call PLSR predict
N=callNextMethod(M,D)
SM=N$y

## probability estimate
# http://www.eigenvector.com/faq/index.php?id=38%7C
p=as.matrix(N$pred)
d=prob(x=p,yhat=as.matrix(N$yhat),ytrue=SM[[M$factor_name]])
pred=(p>d$threshold)*1
pred=apply(pred,MARGIN=1,FUN=which.max)
hi=apply(d$ingroup,MARGIN=1,FUN=which.max) # max probability
if (sum(is.na(pred)>0)) {
pred[is.na(pred)]=hi[is.na(pred)] # if none above threshold, use group with highest probability
}
pred=factor(pred,levels=1:nlevels(SM[[M$factor_name]]),labels=levels(SM[[M$factor_name]])) # make sure pred has all the levels of y
q=data.frame("pred"=pred)
output_value(M,'pred')=q
return(M)
}
signature=c("PLSDA",'DatasetExperiment'),
definition=function(M,D)
{
# call PLSR predict
N=callNextMethod(M,D)
SM=N$y

## probability estimate
# http://www.eigenvector.com/faq/index.php?id=38%7C
p=as.matrix(N$pred)
d=prob(x=p,yhat=as.matrix(N$yhat),ytrue=M$y[[M$factor_name]])

# predictions
if (M$pred_method=='max_yhat') {
pred=apply(p,MARGIN=1,FUN=which.max)
} else if (M$pred_method=='max_prob') {
pred=apply(d$ingroup,MARGIN=1,FUN=which.max)
}
pred=factor(pred,levels=1:nlevels(SM[[M$factor_name]]),labels=levels(SM[[M$factor_name]])) # make sure pred has all the levels of y
q=data.frame("pred"=pred)
output_value(M,'pred')=q
return(M)
}
)




prob=function(x,yhat,ytrue)
{
# x is predicted values
@@ -250,8 +275,7 @@
}


gauss_intersect=function(m1,m2,s1,s2)
{
gauss_intersect=function(m1,m2,s1,s2) {
#https://stackoverflow.com/questions/22579434/python-finding-the-intersection-point-of-two-gaussian-curves
a=(1/(2*s1*s1))-(1/(2*s2*s2))
b=(m2/(s2*s2)) - (m1/(s1*s1))
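The diff also shows the start of gauss_intersect(), which, following the linked Stack Overflow answer, solves for the point where two Gaussian curves cross. A standalone sketch of that calculation (an illustration only, not the package's implementation):

# illustration only: crossing point(s) of two normal densities N(m1, s1) and N(m2, s2);
# equating the densities and taking logs gives the quadratic a*x^2 + b*x + cc = 0
gaussian_intersection = function(m1, m2, s1, s2) {
    a  = 1/(2*s1^2) - 1/(2*s2^2)
    b  = m2/s2^2 - m1/s1^2
    cc = m1^2/(2*s1^2) - m2^2/(2*s2^2) - log(s2/s1)
    if (isTRUE(all.equal(a, 0))) {
        return(-cc/b)                                    # equal variances: one crossing point
    }
    roots = (-b + c(-1, 1)*sqrt(b^2 - 4*a*cc)) / (2*a)
    roots[roots > min(m1, m2) & roots < max(m1, m2)]     # keep the crossing between the means
}

# e.g. in-group responses centred near +1, out-of-group near -1
gaussian_intersection(m1 = 1, m2 = -1, s1 = 0.4, s2 = 0.5) # ~0.09
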
10 changes: 6 additions & 4 deletions man/PLSDA.Rd


4 changes: 2 additions & 2 deletions tests/testthat/test-gridsearch1d.R
@@ -16,7 +16,7 @@ test_that('grid_search iterator',{
# run
I=run(I,D,B)
# calculate metric
expect_equal(I$metric$value,0.3,tolerance=0.05)
expect_equal(I$metric$value,0.045,tolerance=0.0005)
})

# test grid search
@@ -36,7 +36,7 @@ test_that('grid_search wf',{
# run
I=run(I,D,B)
# calculate metric
expect_equal(I$metric$value[1],0.3,tolerance=0.05)
expect_equal(I$metric$value[1],0.04,tolerance=0.005)
})

# test grid search