boundary_creator.py
import torch
import sentencepiece as spm


class BoundaryCreator():
    def __init__(
        self,
        boundaries_type,
        fixed_sf,
        whitespace_id,
        **kwargs,
    ):
        self.boundaries_type = boundaries_type
        self.whitespace_id = whitespace_id

        # The fixed stride is only needed for the 'fixed' boundary type.
        if boundaries_type == 'fixed':
            assert fixed_sf > 0
            self.fixed_sf = fixed_sf

    def get_boundaries(self, txt=None, tensor=None):
        """
        Generates boundaries for a given tensor of data.

        Args:
            tensor - (torch.LongTensor) - [seq_len x batch_size]
        Returns:
            boundaries - (torch.BoolTensor) - [batch_size x seq_len]
        """
        assert tensor is not None
        data = tensor

        data = data.transpose(0, 1)  # batch_size x seq_len
        boundaries = torch.zeros_like(data, dtype=torch.bool)

        if self.boundaries_type == 'whitespaces':
            # A boundary at every whitespace token
            boundaries |= (data == self.whitespace_id)
        elif self.boundaries_type == 'fixed':
            # A boundary every fixed_sf positions
            boundaries[:, ::self.fixed_sf] = 1
        else:
            return None

        return boundaries
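
# A minimal usage sketch for BoundaryCreator (an illustrative addition, not
# part of the original file; whitespace_id=3 is a hypothetical vocabulary id
# for the space token):
#
#   bc = BoundaryCreator('whitespaces', fixed_sf=None, whitespace_id=3)
#   data = torch.randint(0, 100, (16, 2))    # [seq_len x batch_size]
#   bounds = bc.get_boundaries(tensor=data)  # [batch_size x seq_len], bool
#   # bounds is True wherever the input token equals whitespace_id.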


class SPMBoundaries():
    def __init__(self, tokenizer_path, **kwargs):
        self.tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path)

    def get_boundaries(self, txt=None, tensor=None):
        """
        Generates boundaries from a SentencePiece segmentation of raw text.

        These boundaries are compatible with the boundary-ends grouping.
        For each modality/dataset it is worth inspecting the produced
        boundaries with a debugger.

        Args:
            txt - (list of str) - batch of strings of equal character length
        Returns:
            boundaries - (torch.Tensor) - [batch_size x seq_len]
        """
        assert txt is not None
        data = txt

        # Collect the unique words across the batch so that each word is
        # tokenised only once.
        words_set = set()
        batch_size = len(data)
        for i in range(batch_size):
            words_set.update(data[i].split(' '))
        words_list = list(words_set)

        # Map each word to the character lengths of its SentencePiece pieces.
        words_segmentation = {}
        for word, segmentation in zip(words_list,
                                      self.tokenizer.encode(words_list,
                                                            out_type=str)):
            if word == '':
                words_segmentation[''] = [0]
                continue
            else:
                assert len(segmentation)
                assert len(segmentation[0])
                # SentencePiece prefixes the first piece with the meta symbol
                # '▁', which stands for the word-initial space; strip it so
                # that the piece lengths add up to the word length.
                assert segmentation[0].startswith('▁')
                if segmentation[0] == '▁':
                    segmentation = segmentation[1:]
                else:
                    segmentation[0] = segmentation[0][1:]

                words_segmentation[word] = [len(x) for x in segmentation]
                assert len(word) == sum(words_segmentation[word])
        # Flatten the per-word piece lengths into one list per sample. The
        # last piece of every word after the first is extended by one
        # character so that the lengths sum to the full string length,
        # spaces included.
        sample_lengths = []
        for i in range(batch_size):
            words_lengths = [words_segmentation[word] for word in data[i].split(" ")]
            pieces_lengths = [
                ((y + 1) if (i > 0 and j == (len(sublengths) - 1)) else y)
                for i, sublengths in enumerate(words_lengths)
                for j, y in enumerate(sublengths)
            ]
            sample_lengths.append(torch.tensor(pieces_lengths))

        # All samples in the batch must cover the same number of characters.
        total_lengths = [x.sum().item() for x in sample_lengths]
        assert len(set(total_lengths)) == 1
        assert total_lengths[0] == len(data[0])

        # Mark a boundary at the first character of every piece except the
        # first one in each sample.
        boundaries = torch.zeros(batch_size, total_lengths[0])
        for i in range(batch_size):
            boundaries[i, sample_lengths[i].cumsum(dim=0)[:-1]] = 1

        return boundaries
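
# A sketch of how SPMBoundaries is meant to be called (an illustrative
# addition; 'unigram.model' is a hypothetical path to a trained unigram
# SentencePiece model):
#
#   sb = SPMBoundaries(tokenizer_path='unigram.model')
#   bounds = sb.get_boundaries(txt=['the cat sat', 'a dog ran !'])
#   # bounds[i, j] == 1 at the start of every piece except the first;
#   # every string in the batch must have the same number of characters.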


def get_boundary_creator(boundaries_type, **kwargs):
    # 'unigram' boundaries come from a SentencePiece model; every other type
    # is handled by the rule-based BoundaryCreator.
    if boundaries_type == 'unigram':
        return SPMBoundaries(**kwargs)
    else:
        return BoundaryCreator(boundaries_type, **kwargs)
if __name__ == '__main__':
pass
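    # A minimal smoke test for the fixed-stride path (an illustrative
    # addition; the original file left this block empty).
    creator = get_boundary_creator('fixed', fixed_sf=4, whitespace_id=None)
    sample = torch.randint(0, 100, (16, 2))      # [seq_len x batch_size]
    bounds = creator.get_boundaries(tensor=sample)
    print(bounds.shape)                          # torch.Size([2, 16])
    assert bounds[:, ::4].all()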