-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconsolidate.py
65 lines (52 loc) · 1.95 KB
/
consolidate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import pickle
from typing import List, Tuple

from tqdm.contrib import tenumerate
def char_gen(string):
    """Lazily yield the items of ``string`` one at a time, in order."""
    yield from string
def time_decoder(
    decoded: str, batch_decoded: List
) -> Tuple[List[int], List[int]]:
    """Locate the start and end token index of every word in ``decoded``.

    The characters of ``decoded`` are consumed left to right while walking
    ``batch_decoded``: a token equal to the next expected character advances
    the cursor, any other token is skipped. A space in ``decoded`` is not
    matched against a token when it directly follows a matched character; it
    is consumed immediately and closes the current word.

    Args:
        decoded: The text whose space-separated words are to be located.
        batch_decoded: Sequence of tokens, each compared with ``==`` against
            a single character of ``decoded``.
            NOTE(review): presumably one token per model output step,
            including padding/blank tokens -- confirm against the caller.

    Returns:
        A ``(start_list, end_list)`` tuple of index lists into
        ``batch_decoded``:

        * ``start_list[i]`` is the index of the token matched to the first
          character of word ``i``.
        * ``end_list[i]`` is, for a word followed by a space, one past the
          index of the token matched to the word's last character. For a
          final word that was started but not closed by a space, it is one
          past the *last* occurrence of ``decoded[-1]`` in ``batch_decoded``.

        Both lists are empty when ``decoded`` is empty or nothing matches.

    Raises:
        ValueError: If a word was started but never closed and
            ``decoded[-1]`` does not occur anywhere in ``batch_decoded``.

    Example:
        ``time_decoder("hi yo", ["h", "i", "", "y", "o", "o"])`` returns
        ``([0, 3], [2, 6])``.
    """
    start_list: List[int] = []
    end_list: List[int] = []

    # Nothing to align; the original code crashed here with StopIteration.
    if not decoded:
        return start_list, end_list

    total = len(decoded)
    position = 0  # index of the next character of `decoded` to match
    in_word = False  # True once the current word's start has been recorded

    for index, token in tenumerate(batch_decoded):
        if token != decoded[position]:
            continue

        if not in_word:
            start_list.append(index)
            in_word = True

        position += 1
        if position == total:
            break

        # A space right after a matched character ends the word; it is
        # consumed here rather than being matched against a token.
        if decoded[position] == " ":
            end_list.append(index + 1)
            position += 1
            if position == total:
                break
            in_word = False

    # The last word has a start but no end when `decoded` does not finish
    # with a space (or the tokens ran out): close it just past the final
    # occurrence of the last character of `decoded`.
    if len(end_list) != len(start_list):
        offset_from_end = batch_decoded[::-1].index(decoded[-1])
        end_list.append(len(batch_decoded) - offset_from_end)

    return start_list, end_list