-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataloader.py
76 lines (69 loc) · 3.31 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import numpy as np
from torch.utils.data import DataLoader
from expression_dataset import EyeExpressionDataset
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data.sampler import Sampler
from sklearn.model_selection import KFold, RepeatedKFold
from collate import collate_fn
def prepare_kfold_dataloaders(data, opt):
    """Build one (train, test) DataLoader pair per repeated K-fold split.

    Args:
        data: dict holding a 'lang' vocabulary object (with `word2index` /
            `index2word`) plus 'src_insts' and 'trg_insts' instance lists.
        opt: options namespace providing `n_splits`, `epoch`, `batch_size`
            and `num_workers`.

    Returns:
        Tuple `(train_loaders, test_loaders)`: parallel lists of
        DataLoaders, one entry per generated fold.
    """
    train_loaders = []
    test_loaders = []
    # Repeated K-fold cross validation: n_splits folds per repeat, repeated
    # epoch/n_splits times, so roughly opt.epoch folds in total.
    # Keyword arguments are required — scikit-learn (>=1.0) rejects
    # positional constructor parameters for splitters.
    kfold = RepeatedKFold(
        n_splits=opt.n_splits,
        n_repeats=int(opt.epoch / opt.n_splits),
        random_state=45)  # fixed seed so folds are reproducible across runs
    eye_expression_dataset = EyeExpressionDataset(
        word2idx=data['lang'].word2index,
        idx2word=data['lang'].index2word,
        src_insts=data['src_insts'],
        # NOTE(review): 'trg_ints' (not 'trg_insts') matches the dataset's
        # keyword — presumably a typo in EyeExpressionDataset; confirm there.
        trg_ints=data['trg_insts'],)
    for train_indices, test_indices in kfold.split(eye_expression_dataset):
        # Samplers restrict each loader to the indices of its fold; both
        # loaders share the single underlying dataset instance.
        train_sampler = SubsetRandomSampler(train_indices)
        test_sampler = SubsetRandomSampler(test_indices)
        train_loader = DataLoader(
            dataset=eye_expression_dataset,
            batch_size=opt.batch_size,
            num_workers=opt.num_workers,
            sampler=train_sampler,
            collate_fn=collate_fn)
        test_loader = DataLoader(
            dataset=eye_expression_dataset,
            batch_size=opt.batch_size,
            num_workers=opt.num_workers,
            sampler=test_sampler,
            collate_fn=collate_fn)
        train_loaders.append(train_loader)
        test_loaders.append(test_loader)
    return train_loaders, test_loaders
def prepare_dataloaders(data, opt):
    """Build a single (train, validation) DataLoader pair via a random split.

    Args:
        data: dict holding a 'lang' vocabulary object (with `word2index` /
            `index2word`) plus 'src_insts' and 'trg_insts' instance lists.
        opt: options namespace providing `batch_size`, `num_workers` and
            the `is_shuffle` flag.

    Returns:
        Tuple `(train_loader, valid_loader)` sharing one dataset instance,
        with the validation loader covering 20% of the samples.
    """
    validation_split = 0.2  # fraction of the whole dataset held out
    random_seed = 45  # fixed seed so the shuffle (and split) is reproducible
    eye_expression_dataset = EyeExpressionDataset(
        word2idx=data['lang'].word2index,
        idx2word=data['lang'].index2word,
        src_insts=data['src_insts'],
        # NOTE(review): 'trg_ints' (not 'trg_insts') matches the dataset's
        # keyword — presumably a typo in EyeExpressionDataset; confirm there.
        trg_ints=data['trg_insts'],)
    dataset_size = len(eye_expression_dataset)
    dataset_indices = list(range(dataset_size))
    split_index = int(np.floor(validation_split * dataset_size))
    if opt.is_shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(dataset_indices)
    # First `split_index` indices become validation, the rest training.
    train_indices = dataset_indices[split_index:]
    valid_indices = dataset_indices[:split_index]
    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(valid_indices)
    train_loader = DataLoader(
        dataset=eye_expression_dataset,
        batch_size=opt.batch_size,
        num_workers=opt.num_workers,
        sampler=train_sampler,
        collate_fn=collate_fn)
    valid_loader = DataLoader(
        dataset=eye_expression_dataset,
        batch_size=opt.batch_size,
        num_workers=opt.num_workers,
        sampler=valid_sampler,
        collate_fn=collate_fn)
    return train_loader, valid_loader