forked from kimiyoung/transformer-xl
-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathshard_dataset.py
50 lines (39 loc) · 1.38 KB
/
shard_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#!/usr/bin/env python
# shard dataset for transformer-xl codebase
# TODO(y): smarter sharding to avoid breaking within articles?
import argparse
import math
import os
import sys
# Command-line interface.  Arguments are parsed at import time and the
# result is kept in the module-level ``args`` namespace.
parser = argparse.ArgumentParser(description='shard utility')

# Directory that holds the text files to be split.
parser.add_argument(
    '--datadir',
    help='location of train.txt',
    default='/ncluster/data/transformer-xl-data/wikitext-103',
    type=str,
)

# Number of output shards.
parser.add_argument(
    '--shards',
    help='how many ways to shard',
    default=4,
    type=int,
)

args = parser.parse_args()
def shard(fn, datadir=None, shards=None):
    """Split one text file into consecutive, equally sized character shards.

    Reads ``{datadir}/{fn}`` fully into memory and writes shard ``i`` to
    ``{datadir}-{i:05d}-of-{shards:05d}/{fn}``, creating each shard
    directory when it is missing.  Each shard holds
    ``ceil(len(corpus) / shards)`` characters; trailing shards may be
    shorter (or empty for very small inputs).  Cuts are made purely by
    character offset, so a shard boundary can fall in the middle of a line.

    Args:
        fn: File name inside the data directory, e.g. ``'train.txt'``.
        datadir: Directory containing ``fn``.  Defaults to ``args.datadir``
            from the command line.
        shards: Number of shards to write.  Defaults to ``args.shards``
            from the command line.

    Raises:
        AssertionError: If ``datadir`` does not exist.
        SystemExit: If fewer than two shards are requested.
    """
    # Fall back to the command-line values so existing ``shard(fn)`` callers
    # behave exactly as before.
    if datadir is None:
        datadir = args.datadir
    if shards is None:
        shards = args.shards

    # The message used to reference a non-existent ``args.target`` attribute,
    # which turned a missing directory into an AttributeError.
    assert os.path.exists(datadir), datadir

    if shards < 2:
        print(f'args.shards is {shards}, doing nothing')
        sys.exit()

    # Explicit encoding so the result does not depend on the platform locale;
    # the context manager guarantees the input handle is closed.
    with open(f'{datadir}/{fn}', encoding='utf-8') as f:
        corpus = f.read()

    shard_length = math.ceil(len(corpus) / shards)

    for i in range(shards):
        new_location = f"{datadir}-{i:05d}-of-{shards:05d}"
        # makedirs avoids spawning a shell (safe for paths with spaces or
        # shell metacharacters) and raises on real failures instead of
        # silently continuing.
        os.makedirs(new_location, exist_ok=True)

        start = i * shard_length
        piece = corpus[start:start + shard_length]

        target = f'{new_location}/{fn}'
        with open(target, 'w', encoding='utf-8') as f:
            print(f"{target}: {len(piece)}")
            f.write(piece)
def main():
    """Shard the train, validation and test files, in that order."""
    for split_file in ('train.txt', 'valid.txt', 'test.txt'):
        shard(split_file)


# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    main()