-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathKFoldValidation.py
More file actions
executable file
·25 lines (21 loc) · 1 KB
/
KFoldValidation.py
File metadata and controls
executable file
·25 lines (21 loc) · 1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import numpy as np
class KFold(object):
    """Deterministic, unshuffled k-fold splitter over inputs ``X`` and labels ``Y``.

    Samples are assigned to folds in their original order. Fold ``i`` is the
    contiguous slice ``[i * spiltLength, (i + 1) * spiltLength)``. The last fold
    also absorbs the ``len(Y) % foldTotal`` leftover samples, so across all
    folds every sample is used for validation exactly once.

    ``spilt`` keeps its original (misspelled) name for backward compatibility;
    ``split`` is a correctly spelled alias for the same method.
    """

    def __init__(self, X, Y, foldTotal=10):
        """
        :param X: input samples; any sequence/array supporting ``len`` and slicing.
        :param Y: labels aligned with ``X`` (must have the same length).
        :param foldTotal: number of folds, between 1 and ``len(Y)`` inclusive.
        :raises ValueError: if ``X`` and ``Y`` differ in length, or if
            ``foldTotal`` is outside ``[1, len(Y)]``.
        """
        if len(X) != len(Y):
            raise ValueError(
                "X and Y must have the same length (got %d and %d)"
                % (len(X), len(Y))
            )
        if not 1 <= foldTotal <= len(Y):
            # foldTotal == 0 would divide by zero below; foldTotal > len(Y)
            # would make every fold but the last one empty.
            raise ValueError(
                "foldTotal must be between 1 and len(Y)=%d, got %r"
                % (len(Y), foldTotal)
            )
        self.X = X
        self.Y = Y
        self.foldTotal = foldTotal
        # Base size of each fold; the last fold additionally receives the
        # remainder len(Y) % foldTotal (see spilt()).
        self.spiltLength = len(self.Y) // foldTotal

    def spilt(self, foldTime):
        """Return the train/validation split for fold number ``foldTime``.

        The validation set is the contiguous block belonging to that fold. The
        training set is everything before and after it, joined in original
        order with ``np.concatenate``. The final fold runs to the end of the
        data, so the leftover ``len(Y) % foldTotal`` samples are validated once
        instead of always being kept in training.

        :param foldTime: zero-based fold index, ``0 <= foldTime < foldTotal``.
        :return: ``(trainX, trainY, validateX, validateY)``. The training parts
            are NumPy arrays; the validation parts are plain slices of the
            original ``X`` / ``Y``, so their type follows the input type.
        :raises ValueError: if ``foldTime`` is outside ``[0, foldTotal)``.
        """
        if not 0 <= foldTime < self.foldTotal:
            # Negative indices would wrap around in slicing, and indices past
            # the end would silently yield an empty validation set.
            raise ValueError(
                "foldTime must be in [0, %d), got %r" % (self.foldTotal, foldTime)
            )
        validateStart = foldTime * self.spiltLength
        if foldTime == self.foldTotal - 1:
            # The last fold absorbs the remainder so no sample is skipped.
            validateEnd = len(self.Y)
        else:
            validateEnd = validateStart + self.spiltLength
        trainX = np.concatenate((self.X[:validateStart], self.X[validateEnd:]))
        trainY = np.concatenate((self.Y[:validateStart], self.Y[validateEnd:]))
        validateX = self.X[validateStart:validateEnd]
        validateY = self.Y[validateStart:validateEnd]
        return trainX, trainY, validateX, validateY

    # Correctly spelled alias; ``spilt`` is kept so existing callers still work.
    split = spilt