-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
368 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
--- | ||
title: "nnet benchmark fit" | ||
author: "George Fisher" | ||
date: "June 16, 2015" | ||
output: | ||
pdf_document: | ||
toc: yes | ||
toc_depth: 1 | ||
--- | ||
|
||
#SETUP | ||
|
||
```{r setup, message=FALSE} | ||
library(nnet) | ||
library(plyr) | ||
library(caret) | ||
library(psych) | ||
library(pryr) | ||
library(ggplot2) | ||
library(foreach) | ||
library(doParallel) | ||
library(readr) | ||
library(data.table) | ||
rm(list = setdiff(ls(), lsf.str())) # clear variables, leave functions | ||
ptm <- proc.time() # start timer | ||
opar = par(no.readonly=TRUE) | ||
# ############################ PARAMETER SETUP ################################## | ||
# =============================================================================== | ||
deskewed = FALSE # deskewed (TRUE) or original (FALSE) | ||
source("load_TrainTest.R") # load the data | ||
best_M = 120 | ||
best_maxit = 250 | ||
best_decay = 0.5 | ||
# =============================================================================== | ||
# ############################ PARAMETER SETUP ################################## | ||
# ################################## END ######################################## | ||
# calculate the length of the Wts vector | ||
# ====================================== | ||
num.Wts = function(p, M, K) { | ||
# returns the length of the Wts vector | ||
# | ||
# p = ncol(X) # number of predictor variables | ||
# M = size # number of hidden units | ||
# K = 1 # number of classes | ||
return ((p + 1) * M + K * (M + 1)) | ||
} | ||
p = ncol(trainX) # number of predictor variables | ||
K = 10 # x, y input the number of output classes | ||
``` | ||
|
||
#TRAIN WITH THE FULL TRAINING DATASET | ||
|
||
```{r fit_best, message=FALSE} | ||
nnet.fit = nnet(x=trainX, y=class.ind(trainY), | ||
softmax=TRUE, | ||
size=best_M, decay=best_decay, maxit=best_maxit, | ||
bag=TRUE, MaxNWts=num.Wts(p, best_M, K), | ||
Wts=runif (num.Wts(p, best_M, K), -0.5, 0.5), | ||
trace=TRUE) | ||
``` | ||
|
||
#FIT THE TEST DATASET | ||
|
||
```{r pred_test,message=FALSE} | ||
nnet.pred = predict(nnet.fit, newdata=testX, type="class") | ||
(matrix = table(actual = as.character(testY), | ||
predicted = nnet.pred)) | ||
# tr() expects a square matrix | ||
# if predict() does not produce all values 0...9 an error is thrown | ||
correct.entries = 0 | ||
tr_error = tryCatch({ | ||
correct.entries = tr(matrix) | ||
}, warning = function(w) { | ||
#warning-handler-code | ||
}, error = function(e) { | ||
print(e) | ||
}, finally = { | ||
#cleanup code | ||
}) | ||
(model.accuracy = correct.entries / sum(matrix)) | ||
(model.misclass = 1 - model.accuracy) | ||
# which were the hardest to detect? | ||
# ================================= | ||
if (correct.entries > 0) { | ||
results = data.frame(number=numeric(0), percent=numeric(0)) | ||
for (i in seq(from=0,to=9)){ | ||
results[nrow(results)+1,] = c(i, round(prop.table(matrix,1),digits=3)[i+1,i+1]) | ||
} | ||
results[nrow(results)+1,] = c(100,model.accuracy) | ||
print(arrange(results,percent)) | ||
} | ||
# run time | ||
run_time = proc.time() - ptm | ||
print(paste("elapsed minutes",round(run_time[3]/60,digits=2), | ||
"; elapsed hours",round(run_time[3]/(60*60),digits=2))) | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,259 @@ | ||
--- | ||
title: "nnet cross validation" | ||
author: "George Fisher" | ||
date: "June 17, 2015" | ||
output: | ||
pdf_document: | ||
toc: yes | ||
toc_depth: 1 | ||
--- | ||
|
||
#SETUP | ||
|
||
```{r setup, message=FALSE} | ||
library(nnet) | ||
library(psych) | ||
library(plyr) | ||
library(caret) | ||
library(pryr) | ||
library(ggplot2) | ||
library(foreach) | ||
library(doParallel) | ||
library(readr) | ||
library(data.table) | ||
library(NMF) | ||
library(RColorBrewer) | ||
rm(list = setdiff(ls(), lsf.str())) # clear variables, leave functions | ||
ptm <- proc.time() # start timer | ||
opar = par(no.readonly=TRUE) | ||
# ############################ PARAMETER SETUP ################################## | ||
# =============================================================================== | ||
deskewed = TRUE # deskewed (TRUE) or original (FALSE) | ||
source("../load_TrainTest.R") # load the data | ||
k = 5 # k=5 fold CV | ||
fold.idx = createFolds(trainY, k = k) # the indexes of the hold-out folds | ||
registerDoParallel(cores=k) # each fold gets its own core | ||
if (deskewed) { | ||
cv_filename = "nnet_cv_deskewed.txt" | ||
} else { | ||
cv_filename = "nnet_cv_original.txt" | ||
} | ||
# start a new file | ||
# ------> comment out to append to an existing file <------ | ||
# ======================================================================== | ||
header="fold,M,maxit,decay,avg.misclass,avg.mem_used,parms,date,elapsed" | ||
write.table(header, file=cv_filename, quote=FALSE, sep="", append=FALSE, | ||
row.names = FALSE, col.names = FALSE) | ||
# ======================================================================== | ||
params = expand.grid(M = 40, maxit = c(250,300,400), decay = 1) | ||
#params = expand.grid(M = 5, maxit = 50, decay = 0.5) # for script testing only | ||
# =============================================================================== | ||
# ############################ PARAMETER SETUP ################################## | ||
# ################################## END ######################################## | ||
# calculate the length of the Wts vector | ||
# ====================================== | ||
num.Wts = function(p, M, K) { | ||
# returns the length of the Wts vector | ||
# | ||
# p = ncol(X) # number of predictor variables | ||
# M = size # number of hidden units | ||
# K = 1 # number of classes | ||
return ((p + 1) * M + K * (M + 1)) | ||
} | ||
p = ncol(trainX) # number of predictor variables | ||
K = 10 # x, y input the number of output classes | ||
``` | ||
|
||
#CROSS VALIDATION | ||
|
||
```{r cv, message=FALSE} | ||
# results of each k-fold CV loop | ||
# ============================== | ||
cv_results = data.frame(M=numeric(0), | ||
maxit=numeric(0), | ||
decay=numeric(0), | ||
avg.misclass=numeric(0), | ||
avg.mem_used=numeric(0), | ||
parms=character(0)) | ||
# ================================= k-fold CV loop ================================= | ||
for (i in 1:nrow(params)) { | ||
# for each set of parameters ... | ||
M = params[i,"M"] | ||
maxit = params[i,"maxit"] | ||
decay = params[i,"decay"] | ||
# ... run k-fold CV, each fold on a separate processor | ||
results = foreach (fold = seq(k), .combine = 'rbind', .inorder = FALSE) %dopar% { | ||
fold_start = proc.time() | ||
idx = fold.idx[[fold]] | ||
# fit 4/5 of the training data | ||
nnet.fit = nnet(x = trainX[-idx,], y = class.ind(trainY[-idx]), | ||
softmax = TRUE, | ||
size = M, decay = decay, maxit = maxit, | ||
bag = TRUE, MaxNWts = num.Wts(p, M, K),# #MaxNWts=1000000, | ||
Wts = runif (num.Wts(p, M, K),-0.5, 0.5), | ||
trace = FALSE) | ||
# predict 1/5 of the training data | ||
nnet.pred = predict(nnet.fit, newdata = trainX[idx,],type = "class") | ||
matrix = table(actual = as.character(trainY[idx]), | ||
predicted = nnet.pred) | ||
# tr() expects a square matrix | ||
# if predict() does not produce all values 0...9 an error is thrown | ||
model.misclass = tryCatch({ | ||
# code to try | ||
# 'tryCatch()' will return the last evaluated expression | ||
# in case the "try" part was completed successfully | ||
1 - (tr(matrix) / sum(matrix)) | ||
}, | ||
warning = function(w) { | ||
# warning-handler-code | ||
print(w) | ||
print(matrix) | ||
return(1.0) | ||
}, | ||
error = function(e) { | ||
# error-handler-code | ||
print(e) | ||
print(matrix) | ||
return(1.0) | ||
}, | ||
finally = { | ||
# NOTE: | ||
# Here goes everything that should be executed at the end, | ||
# regardless of success or error. | ||
} ) | ||
# write data to disk for later evaluation | ||
fold_row = c(fold, M, maxit, decay, model.misclass, mem_used()) | ||
frt = proc.time() - fold_start # fold run time | ||
frte = as.numeric(frt["elapsed"]) # fold run time elapsed | ||
write.table(matrix(c(fold_row, | ||
paste0("M",M,"it",maxit,"d",decay), | ||
date(), frte), ncol=9, nrow=1), | ||
file=cv_filename, | ||
append=TRUE, quote=TRUE, sep=",", | ||
row.names = FALSE, col.names = FALSE) | ||
# each fold outputs one row of "results" | ||
fold_row | ||
} | ||
# after each k-fold CV save the information in a data.frame | ||
cv_results[nrow(cv_results)+1,] = c(M, maxit, decay, | ||
mean(as.numeric(results[,5])), # model.misclass | ||
mean(as.numeric(results[,6])), # mem_used() | ||
paste0("M",M,"it",maxit,"d",decay)) | ||
} | ||
# ================================= k-fold CV loop ================================= | ||
cv_results[,1:5] = sapply(cv_results[,1:5],as.numeric) | ||
print(cv_results) | ||
# find the parameters that produced the lowest average misclassification rate | ||
# =========================================================================== | ||
best_cv_row = which.min(cv_results$avg.misclass) | ||
best_results = cv_results[best_cv_row,] | ||
best_M = best_results$M | ||
best_maxit = best_results$maxit | ||
best_decay = best_results$decay | ||
best_misclass = best_results$avg.misclass | ||
``` | ||
|
||
|
||
```{r cv_runtime} | ||
# run time | ||
run_time = proc.time() - ptm | ||
print(paste( | ||
"elapsed minutes",round(run_time[3] / 60,digits = 2), | ||
"; elapsed hours",round(run_time[3] / (60 * 60),digits = 2), | ||
"; user/elapsed",round((run_time[1]+run_time[4])/run_time[3],digits=0) | ||
)) | ||
mem_range = prettyNum(range(cv_results$avg.mem_used),big.mark=",",scientific=FALSE) | ||
print(paste("Range of R memory usage",mem_range[1],":",mem_range[2])) | ||
``` | ||
|
||
#TRAIN WITH THE FULL TRAINING DATASET | ||
using the best parameters found in the CV steps | ||
|
||
```{r fit_best, message=FALSE} | ||
# fit the full training dataset | ||
# with the best parameters found by CV | ||
nnet.fit = nnet(x=trainX, y=class.ind(trainY), | ||
softmax=TRUE, | ||
size=best_M, decay=best_decay, maxit=best_maxit, | ||
bag=TRUE, MaxNWts=num.Wts(p, best_M, K), | ||
Wts=runif (num.Wts(p, best_M, K), -0.5, 0.5), | ||
trace=TRUE) | ||
``` | ||
|
||
#FIT THE TEST DATASET | ||
|
||
```{r pred_test,message=FALSE} | ||
# get the specific class predictions | ||
nnet.pred = predict(nnet.fit, newdata=testX, type="class") | ||
matrix = table(actual = as.character(testY), | ||
predicted = nnet.pred) | ||
# heatmap of the range of probabilities | ||
aheatmap(prop.table(matrix,margin=1), Rowv=NA, Colv=NA) | ||
confusionMatrix(data=nnet.pred, reference=as.character(testY)) | ||
# tr() expects a square matrix | ||
# if predict() does not produce all values 0...9 an error is thrown | ||
correct.entries = tryCatch({ | ||
# code to try | ||
# 'tryCatch()' will return the last evaluated expression | ||
# in case the "try" part was completed successfully | ||
tr(matrix) | ||
}, | ||
warning = function(w) { | ||
# warning-handler-code | ||
print(w) | ||
return(0.0) | ||
}, | ||
error = function(e) { | ||
# error-handler-code | ||
print(e) | ||
return(0.0) | ||
}, | ||
finally = { | ||
# NOTE: | ||
# Here goes everything that should be executed at the end, | ||
# regardless of success or error. | ||
} ) | ||
(model.accuracy = correct.entries / sum(matrix)) | ||
(model.misclass = 1 - model.accuracy) | ||
# which were the hardest to detect? | ||
# ================================= | ||
if (correct.entries > 0) { | ||
results = data.frame(number=numeric(0), percent=numeric(0)) | ||
for (i in seq(from=0,to=9)){ | ||
results[nrow(results)+1,] = c(i, round(prop.table(matrix,1),digits=3)[i+1,i+1]) | ||
} | ||
results[nrow(results)+1,] = c(100,model.accuracy) | ||
print(arrange(results,percent)) | ||
} | ||
# run time | ||
run_time = proc.time() - ptm | ||
print(paste("elapsed minutes",round(run_time[3]/60,digits=2), | ||
"; elapsed hours",round(run_time[3]/(60*60),digits=2))) | ||
``` | ||
|