forked from NVIDIA/Megatron-LM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
filter_ngrams.py
479 lines (394 loc) · 17.9 KB
/
filter_ngrams.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""
Remove n-grams that overlap downstream-task evaluation data from a training
dataset. By default 13-grams are used; split pieces of 200 characters or fewer
are dropped, and any document split into more than 10 pieces is discarded.
"""
import argparse
from functools import partial
import json
import multiprocessing
import nltk
import pickle
import re
import string
import sys
import time
def get_words(text):
    """Split ``text`` into lowercase word-character runs and their offsets.

    Returns ``(words, positions)`` where ``words[i]`` is the i-th run of
    word characters, lowercased, and ``positions[i]`` is the index in the
    ORIGINAL ``text`` at which that run starts.
    """
    # Match on the original text and lowercase each word afterwards. Running
    # the regex on ``text.lower()`` yields offsets into the lowered string,
    # which can be longer than ``text`` (e.g. 'İ'.lower() is two code
    # points), while callers use these offsets to slice ``text`` itself.
    words, positions = [], []
    for match in re.finditer(r'\w+', text):
        words.append(match.group(0).lower())
        positions.append(match.start())
    return words, positions
# splits the text around a matched n-gram
def split_text(text, start_position, remove_char_each_side, seq):
    """Cut ``text`` around the n-gram ``seq`` that starts at ``start_position``.

    The head ends at the nearest '.', '!' or '?' found by walking left from
    ``remove_char_each_side`` characters before the match (inclusive of that
    punctuation mark). The tail starts just after the nearest such mark found
    by walking right from ``remove_char_each_side`` characters after the end
    of ``seq``. Either piece is "" when no boundary is found.

    Returns:
        ``(text_first, text_second)``.
    """
    sentence_enders = ".!?"
    text_len = len(text)

    # Walk left from the padded match start to a sentence boundary.
    left = start_position - remove_char_each_side
    while left > 0 and text[left] not in sentence_enders:
        left -= 1
    head = text[:left + 1] if left > 0 else ""

    # Walk right from the padded match end to a sentence boundary.
    right = start_position + len(seq) + remove_char_each_side
    while right < text_len and text[right] not in sentence_enders:
        right += 1
    tail = text[right + 1:] if right + 1 < text_len else ""

    return head, tail
def check_and_clean_text(args, words, ngrams, text, start_position,
                         text_buf_ngram_free, text_buf, local_ngram):
    """Test one candidate n-gram and, on a hit, requeue the remaining text.

    Args:
        args: namespace; reads ``get_ngram_freq_only`` and, in cleaning mode,
            ``remove_char_each_side`` and ``filter_text_char_len``.
        words: the candidate n-gram as a list of words.
        ngrams: dict keyed by space-joined n-grams.
        text: the text the candidate was taken from.
        start_position: index in ``text`` where the candidate starts.
        text_buf_ngram_free: list collecting pieces known to be n-gram free.
        text_buf: work list of pieces still to be scanned.
        local_ngram: per-document match counter (frequency-only mode).

    Returns:
        True when the candidate is not in ``ngrams``; False otherwise.
    """
    seq = " ".join(words)
    if seq not in ngrams:
        # ngram free
        return True

    print(" [matched]: {}".format(seq), flush=True)

    if args.get_ngram_freq_only:
        # Count the hit, then keep scanning only the text after the match.
        local_ngram[seq] = local_ngram.get(seq, 0) + 1
        tail_start = start_position + len(seq) + 1
        if tail_start < len(text):
            text_buf.append(text[tail_start:])
        return False

    # Cleaning mode: cut the text around the match at sentence boundaries.
    text_first, text_second = split_text(text, start_position,
                                         args.remove_char_each_side, seq)
    # The head lies before the match, so it is already n-gram free.
    if len(text_first) > args.filter_text_char_len:
        text_buf_ngram_free.append(text_first)
    # The tail still has to be scanned.
    if len(text_second) > args.filter_text_char_len:
        text_buf.append(text_second)
    return False
def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
    """Remove every n-gram in ``ngrams`` from one JSON-lines document.

    Args:
        line: one raw line of the dedup file, parsed with ``json.loads``.
        args: namespace; reads ``max_ngram_size`` and ``get_ngram_freq_only``
            here (helpers read more fields).
        key: JSON key holding the document text.
        ngrams: dict keyed by space-joined lowercase n-grams to look for.
        ngrams_freq_sorted: ``(ngram_len, count)`` pairs sorted by length,
            as built by ``compute_ngram_freq_sorted``.

    Returns:
        ``(text_buf_ngram_free, trimmed, myjson, local_ngram)``:
        ``text_buf_ngram_free`` holds the n-gram-free pieces (empty in
        frequency-only mode); ``trimmed`` is 1 when exactly one piece survived
        and it is shorter than the original text; ``myjson`` is the parsed
        document (``{}`` if the line was not valid JSON); ``local_ngram``
        counts the matches found in this document.
    """
    # Bind myjson before parsing. Otherwise a json.loads failure leaves it
    # unbound, the return below raises UnboundLocalError, and pool.imap
    # re-raises it in the parent, aborting the whole run.
    myjson = {}
    try:
        myjson = json.loads(line)
        text_buf = [myjson[key]]
    except Exception as e:
        print("Error: {}".format(e), flush=True)
        text_buf = []

    text_buf_ngram_free = []
    local_ngram = {}
    while text_buf:
        # get the first one from the buffer
        text = text_buf.pop(0)
        words, positions = get_words(text)
        ngram_free = True

        # Slide a max_ngram_size window over the words. At each start,
        # check the full window and then every shorter n-gram length starting
        # at the same word.
        for i in range(len(words) - args.max_ngram_size + 1):
            if not check_and_clean_text(
                    args, words[i:i + args.max_ngram_size], ngrams, text,
                    positions[i], text_buf_ngram_free, text_buf, local_ngram):
                ngram_free = False
                break
            for ngram_len, _ in ngrams_freq_sorted:
                # The max-size window was just checked above.
                if ngram_len == args.max_ngram_size:
                    continue
                if not check_and_clean_text(
                        args, words[i:i + ngram_len], ngrams, text,
                        positions[i], text_buf_ngram_free, text_buf,
                        local_ngram):
                    ngram_free = False
                    break
            if not ngram_free:
                break

        # The loop above only tests shorter n-grams at each window START, so
        # also test every shorter n-gram inside the last max-size window.
        # When the text has at most max_ngram_size words, that window is the
        # whole text. The previous guard (len(words) - max_ngram_size > 0)
        # skipped such texts, so short documents were never compared against
        # shorter task n-grams.
        if ngram_free and words:
            last_seq_start_position = max(len(words) - args.max_ngram_size, 0)
            last_seq_words = words[last_seq_start_position:]
            for ngram_len, _ in ngrams_freq_sorted:
                # ignore the max ngram as it has been considered already
                if ngram_len == args.max_ngram_size:
                    continue
                for i in range(len(last_seq_words) - ngram_len + 1):
                    if not check_and_clean_text(
                            args, last_seq_words[i:i + ngram_len], ngrams,
                            text, positions[last_seq_start_position + i],
                            text_buf_ngram_free, text_buf, local_ngram):
                        ngram_free = False
                        break
                if not ngram_free:
                    break

        # texts are ngram free
        if ngram_free and not args.get_ngram_freq_only:
            text_buf_ngram_free.append(text)

    # check if the text has only been trimmed
    trimmed = 0
    if not args.get_ngram_freq_only and len(text_buf_ngram_free) == 1 and \
            len(text_buf_ngram_free[0]) < len(myjson[key]):
        trimmed = 1

    return text_buf_ngram_free, trimmed, myjson, local_ngram
# insert word sequence into dictionary
def insert_dict(words, ngrams, pos):
    """Register the n-gram ``words`` in ``ngrams`` with a count of 0.

    An n-gram that is already present keeps its current value. ``pos`` is
    accepted for call-site compatibility but is not stored.
    """
    ngrams.setdefault(" ".join(words), 0)
# insert each ngram from text into the ngrams dictionary
def compute_ngrams_insert_dict(args, text, ngrams):
    """Add the n-grams of ``text`` to ``ngrams``.

    A text with fewer than ``args.min_ngram_size`` words adds nothing. One with
    fewer than ``args.max_ngram_size`` words is added whole as a single
    n-gram. Otherwise every window of ``args.max_ngram_size`` consecutive
    words is added.
    """
    words, positions = get_words(text)
    n_words = len(words)
    if n_words < args.min_ngram_size:
        return

    window = args.max_ngram_size
    if n_words < window:
        # Too short for a full window: the whole text is one n-gram.
        insert_dict(words, ngrams, positions[0])
    for start in range(n_words - window + 1):
        insert_dict(words[start:start + window], ngrams, positions[start])
# Build ngrams for the lambada dataset
def process_task_lambda(args, task_file, ngrams):
    """Add the n-grams of the ``text`` field of every line of ``task_file``.

    ``task_file`` is read as JSON lines. A line that fails to parse or lacks
    ``text`` is reported and skipped.
    """
    print(' reading from {} and computing ngrams'.format(task_file))
    # Decode explicitly as UTF-8. The platform default encoding may differ,
    # and a decode error raised by the file iterator would escape the
    # per-line try/except below and abort the run.
    with open(task_file, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                myjson = json.loads(line)
                text = myjson['text']
                compute_ngrams_insert_dict(args, text, ngrams)
            except Exception as e:
                print('Error:', e)
    print(" Entities in ngrams {}".format(len(ngrams)), flush=True)
# Build ngrams for the dataset of the given task
def process_task(args, task_name, ngrams):
    """Add the question n-grams of one Hugging Face evaluation set to ``ngrams``.

    ``task_name`` selects the dataset and split (see ``dataset_specs``). An
    unsupported name is reported and ignored. A failure on a single example
    is printed and that example is skipped.
    """
    print(' reading from {} and computing ngrams'.format('import datasets'))
    print(" Current entities in ngrams {}".format(len(ngrams)), flush=True)
    # using validation/test data from datasets
    from datasets import load_dataset
    entities_in_ngrams = len(ngrams)

    # task -> (positional arguments for load_dataset, split)
    dataset_specs = {
        'squad': (('squad_v2',), 'validation'),
        'natural_questions': (('natural_questions',), 'validation'),
        'triviaqa': (('trivia_qa', 'unfiltered'), 'test'),
        'webqa': (('web_questions',), 'test'),
        'race': (('race', 'all'), 'test'),
        'drop': (('drop',), 'validation'),
        'coqa': (('coqa',), 'validation'),
        'piqa': (('piqa',), 'test'),
    }
    spec = dataset_specs.get(task_name)
    if spec is None:
        print("Invalid task name: {}".format(task_name), flush=True)
        return
    load_args, split_name = spec
    dataset = load_dataset(*load_args, split=split_name)

    # read the dataset and add to ngrams
    for example in dataset:
        try:
            if task_name == 'natural_questions':
                texts = [example['question']['text']]
            elif task_name == 'coqa':
                texts = example['questions']
            elif task_name == 'piqa':
                texts = [example['goal']]
            else:
                # squad, triviaqa, webqa, race, drop
                texts = [example['question']]
            for text in texts:
                compute_ngrams_insert_dict(args, text, ngrams)
        except Exception as e:
            print('Error:', e)

    print(" After task {} entities in ngrams {}, added {}".format(task_name, \
        len(ngrams), len(ngrams) - entities_in_ngrams), flush=True)
def compute_tasks_ngrams(args, ngrams):
    """Populate ``ngrams`` with the n-grams of every task in ``args.tasks``.

    ``lambada`` is read from the local file ``args.lambada_path``. Every other
    task name is passed to ``process_task``.

    Raises:
        ValueError: if ``lambada`` is requested but ``args.lambada_path`` is
            None.
    """
    start_time = time.time()
    for task_name in args.tasks:
        print('Task: {}'.format(task_name), flush=True)
        if task_name == 'lambada':
            # Explicit check instead of assert: asserts are stripped under
            # ``python -O`` and the run would then fail later in open(None).
            if args.lambada_path is None:
                raise ValueError(
                    '--lambada-path is required for the lambada task')
            process_task_lambda(args, args.lambada_path, ngrams)
        else:
            process_task(args, task_name, ngrams)
    print(" Taken time to compute ngrams {:.2f}".format(time.time() - \
        start_time), flush=True)
def compute_ngram_freq_sorted(args, ngrams):
    """Count the keys of ``ngrams`` by n-gram length (number of words).

    Args:
        args: unused; kept for call-site compatibility.
        ngrams: dict keyed by space-joined n-grams.

    Returns:
        A list of ``(ngram_len, count)`` pairs sorted by ``ngram_len``, or
        ``[]`` when ``ngrams`` is empty. An empty dict used to crash with
        IndexError while printing min/max sizes, for example when no n-gram
        fell below the threshold.
    """
    ngrams_freq = {}
    for ngram_key in ngrams:
        length = len(ngram_key.split())
        ngrams_freq[length] = ngrams_freq.get(length, 0) + 1
    ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[0])

    print(" Ngram frequencies: {}".format(ngrams_freq_sorted), flush=True)
    min_size = ngrams_freq_sorted[0][0] if ngrams_freq_sorted else None
    max_size = ngrams_freq_sorted[-1][0] if ngrams_freq_sorted else None
    print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(
        len(ngrams), min_size, max_size), flush=True)
    return ngrams_freq_sorted
def get_ngrams_below_threshold(args, ngrams, ngrams_below_threshold,
                               dedup_file, dedup_key, ngrams_freq_sorted):
    """Count how many documents contain each task n-gram; keep the rare ones.

    Every line of ``dedup_file`` is run through ``free_ngram`` in
    frequency-only mode on ``args.num_threads`` worker processes. Each
    document adds at most 1 to the value of each n-gram it contains
    (``ngrams`` is updated in place). N-grams whose final value is below
    ``args.key_threshold`` are added to ``ngrams_below_threshold`` with
    value 1.

    Side effect: sets ``args.get_ngram_freq_only = True``.
    """
    start_time = time.time()
    # Workers receive ``args`` through the partial below, so the mode flag
    # must be set before any task is dispatched.
    args.get_ngram_freq_only = True
    num_workers = args.num_threads
    free_ngram_abt_partial = partial(free_ngram, args=args, key=dedup_key,
                                     ngrams=ngrams,
                                     ngrams_freq_sorted=ngrams_freq_sorted)

    # The context managers close the input file and tear the pool down even
    # if an exception escapes the loop (e.g. a worker error re-raised by
    # imap). The explicit close()/join() keep the graceful shutdown on the
    # success path.
    with open(dedup_file, 'r', encoding='utf-8') as fin, \
            multiprocessing.Pool(num_workers) as pool:
        free_ngrams_abt = pool.imap(free_ngram_abt_partial, fin, 500)
        for counter, (_, _, _, local_ngram) in enumerate(free_ngrams_abt, 1):
            if counter % 1000 == 0:
                print(' [compute_stat]> processed {} documents in {:.2f} seconds ...'.
                      format(counter, time.time() - start_time), flush=True)
            # One increment per document that contains the n-gram.
            for local_key in local_ngram:
                if local_key in ngrams:
                    ngrams[local_key] += 1
        print(' Time taken to compute statistics {:.2f} seconds'.format(time.time() - \
            start_time), flush=True)
        pool.close()
        pool.join()

    counter_threshold = 0
    # Get ngram below threshold
    for local_key, local_val in ngrams.items():
        if local_val < args.key_threshold:
            print(" [threshold] {} {}".format(local_key, local_val), flush=True)
            counter_threshold += 1
            ngrams_below_threshold[local_key] = 1
    print(' Ngrams below threshold {}'.format(counter_threshold), flush=True)
def clean_ngrams_below_threshold(args, ngrams_below_threshold, dedup_file,
                                 dedup_key):
    """Write ``dedup_file`` to ``args.output`` with the given n-grams removed.

    Each input line is processed by ``free_ngram`` in cleaning mode on
    ``args.num_threads`` worker processes. Every surviving piece is written
    as its own JSON line: the piece replaces ``myjson[dedup_key]`` and
    ``split_id`` is set to ``[<old split_id>-]<tasks>-<doc:010d>-<piece:04d>``.
    Documents with more than ``args.splits_count`` pieces are dropped.
    Errors while handling one document are printed and that document is
    skipped.

    Side effect: sets ``args.get_ngram_freq_only = False``.
    """
    start_time = time.time()
    # Workers receive ``args`` through the partial below, so switch to
    # cleaning mode before any task is dispatched.
    args.get_ngram_freq_only = False
    id_prefix = '-'.join(args.tasks)
    # get the range of the size of the ngrams
    ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams_below_threshold)

    counter = splitted = ignored = split_mt_thld = trimmed_count = 0
    num_workers = args.num_threads
    free_ngram_clean_partial = partial(free_ngram, args=args, key=dedup_key,
                                       ngrams=ngrams_below_threshold,
                                       ngrams_freq_sorted=ngrams_freq_sorted)

    # The context managers close both files and tear the pool down even if an
    # exception escapes the loop (e.g. a worker error re-raised by imap).
    # The explicit close()/join() keep the graceful shutdown on success.
    with open(dedup_file, 'r', encoding='utf-8') as fin, \
            open(args.output, 'wb') as out_f, \
            multiprocessing.Pool(num_workers) as pool:
        free_ngrams_clean = pool.imap(free_ngram_clean_partial, fin, 500)
        for text_buf_ngram_free, trimmed, myjson, _ in free_ngrams_clean:
            counter += 1
            try:
                trimmed_count += trimmed
                if len(text_buf_ngram_free) > 1:
                    splitted += 1
                if len(text_buf_ngram_free) == 0:
                    ignored += 1
                # Documents split into too many pieces are dropped entirely.
                if len(text_buf_ngram_free) > args.splits_count:
                    text_buf_ngram_free = []
                    split_mt_thld += 1

                # Keep an existing split_id as a prefix of the new one.
                if "split_id" in myjson:
                    use_prefix = myjson["split_id"] + "-"
                else:
                    use_prefix = ""
                for i, piece in enumerate(text_buf_ngram_free):
                    split_id_string = '{}-{:010d}-{:04d}'.format(
                        id_prefix, counter, i)
                    myjson[dedup_key] = piece
                    myjson["split_id"] = use_prefix + split_id_string
                    outjson = json.dumps(myjson, ensure_ascii=False)
                    out_f.write((outjson + '\n').encode('utf-8'))
                if counter % 1000 == 0:
                    print(' [final]> processed {} documents in {:.2f} seconds ...'.
                          format(counter, time.time() - start_time), flush=True)
            except Exception as e:
                print('Error:', e)

        print(' [final]> processed {} documents in {:.2f} seconds ...'.
              format(counter, time.time() - start_time), flush=True)
        print(' Total docs {} splitted {} ignored {} splits > theshold {} trimmed'\
            ' {}'.format(counter, splitted, ignored, split_mt_thld, trimmed_count)\
            , flush=True)
        pool.close()
        pool.join()
if __name__ == '__main__':
    # Defaults: 13-gram matching; split pieces of at most 200 characters are
    # dropped, and documents split into more than 10 pieces are dropped.
    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--tasks', nargs = '*', required=True, default=None, \
                        help = 'Tasks to use for deduplication: currently '
                        ' support [lambada, squad, natural_questions,'
                        ' triviaqa, webqa, race, drop, coqa, and piqa]')
    parser.add_argument('--lambada-path', type=str, default=None,
                        help='Only Lambada task needs the path')
    parser.add_argument('--dedup-dataset', nargs = '*', default=None,
                        help='Dataset to deduplicate with the key to use'
                        ' e.g. cc.json text')
    parser.add_argument('--output', type=str, default=None,
                        help='Output file name to save dedup dataset')
    parser.add_argument('--num-threads', type=int, default=40,
                        help='Number of threads to use')
    # Default dedup values
    parser.add_argument('--max-ngram-size', type=int, default=13,
                        help='Maximum size of ngram to use.')
    parser.add_argument('--min-ngram-size', type=int, default=8,
                        help='Minimum size of ngram to use.')
    parser.add_argument('--filter-text-char-len', type=int, default=200,
                        help='Remove any text below this length.')
    parser.add_argument('--key-threshold', type=int, default=10,
                        help='Number of keys to consider as threshold')
    parser.add_argument('--save-dictionary', type=str, default=None,
                        help='Save the dictionary')
    parser.add_argument('--load-dictionary', type=str, default=None,
                        help='Load the dictionary')
    parser.add_argument('--splits-count', type=int, default=10,
                        help='Remove any documents more than this many splits')
    parser.add_argument('--remove-char-each-side', type=int, default=200,
                        help='Number of characters to skip on each side of a '
                        'matched n-gram before searching for the nearest '
                        'sentence boundary.')
    args = parser.parse_args()

    # --dedup-dataset must be exactly "<file> <key>". parser.error prints the
    # usage and exits. The previous assert crashed with TypeError when the
    # flag was omitted and is stripped entirely under ``python -O``.
    if args.dedup_dataset is None or len(args.dedup_dataset) != 2:
        parser.error('--dedup-dataset expects exactly two values: the '
                     'dataset file and the JSON key, e.g. cc.json text')
    dedup_file, dedup_key = args.dedup_dataset

    if args.load_dictionary is None:
        # Build ngrams
        ngrams = {}
        compute_tasks_ngrams(args, ngrams)

        # get the range of the size of the ngrams
        ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams)

        # Count, in parallel over the large file, how many documents contain
        # each ngram, and keep those below --key-threshold.
        ngrams_below_threshold = {}
        get_ngrams_below_threshold(args, ngrams, ngrams_below_threshold,
                                   dedup_file, dedup_key, ngrams_freq_sorted)

        # save the dictionary if needed
        if args.save_dictionary is not None:
            with open(args.save_dictionary, 'wb') as save_dict_handle:
                pickle.dump(ngrams_below_threshold, save_dict_handle)
    else:
        # NOTE(security): pickle.load can execute arbitrary code. Only load
        # dictionaries written by --save-dictionary from a trusted run.
        with open(args.load_dictionary, 'rb') as load_dict_handle:
            ngrams_below_threshold = pickle.load(load_dict_handle)

    # filter the large file
    if args.output is not None:
        clean_ngrams_below_threshold(args, ngrams_below_threshold,
                                     dedup_file, dedup_key)

    print('done :-)')