-
Notifications
You must be signed in to change notification settings - Fork 8
/
shuffle.py
47 lines (33 loc) · 1.15 KB
/
shuffle.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import random
from argparse import ArgumentParser
from tqdm import tqdm
from krnnt.structure import Paragraph
from krnnt.serial_pickle import SerialPickler, SerialUnpickler
# Usage text handed to ArgumentParser(usage=...). argparse renders it with
# ``usage % {'prog': ...}``, so the program name must be written as
# ``%(prog)s``; the optparse-style ``%prog`` raises ValueError as soon as
# help or an argument error is printed.
usage = """%(prog)s CORPUS SAVE_PATH
Shuffle training data.
E.g. %(prog)s train-merged.spickle train-merged.shuf.spickle
"""
if __name__ == '__main__':
    # Script entry point: read every paragraph from the input corpus, shuffle
    # them with a seeded RNG, and write them to the output path.
    parser = ArgumentParser(usage=usage)
    parser.add_argument('file_path', type=str, help='paths to corpus')
    parser.add_argument('output_path', type=str, help='save path')
    parser.add_argument('--seed', '-s', type=int, default=1337, help='seed')
    # Previously a hard-coded constant; it only sizes the loading progress bar
    # and does not limit how many paragraphs are read.
    parser.add_argument('--total', type=int, default=18484,
                        help='expected number of paragraphs '
                             '(used only for the loading progress bar)')
    args = parser.parse_args()

    # NOTE(review): SerialUnpickler presumably unpickles the file contents;
    # if so, only run this on corpora from a trusted source - confirm.
    # The whole corpus is held in memory because random.shuffle works in place
    # on a list.
    with open(args.file_path, 'rb') as corpus_file:
        paragraphs = list(tqdm(SerialUnpickler(corpus_file),
                               desc='Loading', total=args.total))

    # Seed the RNG so the same seed and input give the same order.
    random.seed(args.seed)
    random.shuffle(paragraphs)

    # ``with`` closes the output file even if pickling a paragraph fails.
    paragraph: Paragraph
    with open(args.output_path, 'wb') as output_file:
        pickler = SerialPickler(output_file)
        for paragraph in tqdm(paragraphs, desc='Saving'):
            pickler.add(paragraph)