-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_val_split_txt.py
59 lines (44 loc) · 1.72 KB
/
train_val_split_txt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import argparse
import os
import numpy as np
def write_lines(file, lines):
    """Write labeled entries to a text file, one 'path label' pair per line.

    Args:
        file: Path of the file to create (an existing file is overwritten).
        lines: Iterable of indexable records; element 0 and element 1 of
            each record are written separated by a single space.
    """
    with open(file, "w") as handle:
        # Generator keeps memory flat; writelines emits the same bytes as
        # one write() call per record.
        handle.writelines(f'{record[0]} {record[1]}\n' for record in lines)
# Command-line script: read a list of "file_path label" lines, shuffle it and
# write <name>_train.txt and <name>_val.txt next to the input file.
parser = argparse.ArgumentParser("Script to split list of labeled files in train and validation sets")
parser.add_argument('input', help="Path to input file, which contains lines with format 'file_path lbl'")
# type=float is required: without it a value passed on the command line stays
# a string, and "0.3" * n repeats the string instead of scaling the row count.
parser.add_argument('--val_size', type=float, default=0.2,
                    help="Proportion of data in validation set, default=0.2")
parser.add_argument('--per_class', action='store_true', help="Keep the same class ratios")
args = parser.parse_args()

input_file = args.input
# Explicit checks instead of `assert`, which is stripped under `python -O`.
if not os.path.isfile(input_file):
    parser.error(f"input file not found: {input_file}")
# A negative proportion would turn the slice below into "all but the last k"
# rows, so reject anything outside [0, 1] up front (this also rejects NaN).
if not 0.0 <= args.val_size <= 1.0:
    parser.error(f"--val_size must be between 0 and 1, got {args.val_size}")

np_lines = np.genfromtxt(input_file, dtype='unicode')
# genfromtxt returns a 1-D array of fields for a single-line file; restore the
# (rows, fields) layout so that shuffling and indexing operate on whole rows.
# An empty result is left alone so that empty input still yields empty outputs.
if np_lines.ndim == 1 and np_lines.size:
    np_lines = np_lines.reshape(1, -1)

if args.per_class:
    # Group rows by label (second field) so every class is split on its own
    # and keeps roughly the same train/validation ratio.
    class_dict = {}
    for el in np_lines:
        class_dict.setdefault(el[1], []).append(el)
    train_out = []
    val_out = []
    for rows in class_dict.values():
        np_cls = np.array(rows)
        np.random.shuffle(np_cls)
        # int() truncates, so small classes may contribute no validation rows.
        val_size = int(args.val_size * len(np_cls))
        val_out.extend(np_cls[:val_size])
        train_out.extend(np_cls[val_size:])
else:
    # Shuffle rows in place, then cut off the first val_size rows as validation.
    np.random.shuffle(np_lines)
    val_size = int(args.val_size * len(np_lines))
    val_out = np_lines[:val_size]
    train_out = np_lines[val_size:]

# Outputs are written alongside the input; fall back to the current directory
# when the input path has no directory component.
dirname = os.path.dirname(input_file) or '.'
# Everything after the first dot of the base name is dropped.
file_name = os.path.basename(input_file).split('.')[0]
out_train_file = os.path.join(dirname, file_name + '_train.txt')
out_val_file = os.path.join(dirname, file_name + '_val.txt')
write_lines(out_train_file, train_out)
write_lines(out_val_file, val_out)