-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathchanges in torchtext.txt
81 lines (75 loc) · 3.99 KB
/
changes in torchtext.txt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#####################################################################################
################################# dataset.py ########################################
#####################################################################################
def check_split_ratio(split_ratio):
    """Check that the split ratio argument is not malformed.

    Arguments:
        split_ratio (float or list): either a single float strictly between
            0 and 1 giving the train fraction, or a list of 2 or 3 relative
            sizes for the (train, test[, valid]) splits.

    Returns:
        tuple: ``(train_ratio, test_ratio, valid_ratio)``.

        * float input: the remainder ``1 - split_ratio`` is divided equally
          between the test and validation splits.
        * list input: the ratios are normalized to sum to 1 when they do not
          already; a 2-element list gets a validation ratio of 0.

    Raises:
        AssertionError: if a float is not strictly between 0 and 1, if a list
            does not have 2 or 3 elements, or if the list sums to zero.
            (Raised via ``assert``, so these checks vanish under ``python -O``.)
        ValueError: if ``split_ratio`` is neither a float nor a list.
    """
    valid_ratio = 0.
    if isinstance(split_ratio, float):
        # Only the train set relative ratio is provided. Assert it is in
        # bounds, then share the remainder evenly between test and validation.
        assert 0. < split_ratio < 1., (
            "Split ratio {} not between 0 and 1".format(split_ratio))
        test_ratio = (1. - split_ratio) / 2
        valid_ratio = (1. - split_ratio) / 2
        return (split_ratio, test_ratio, valid_ratio)
    elif isinstance(split_ratio, list):
        # A list of relative ratios is provided
        length = len(split_ratio)
        assert length == 2 or length == 3, (
            "Length of split ratio list should be 2 or 3, got {}".format(split_ratio))
        ratio_sum = sum(split_ratio)
        # Without this guard the normalization below fails with an
        # uninformative ZeroDivisionError.
        assert ratio_sum != 0, (
            "Split ratios must not sum to zero, got {}".format(split_ratio))
        # Normalize if necessary (builds a new list; the caller's list is untouched)
        if not ratio_sum == 1.:
            split_ratio = [float(ratio) / ratio_sum for ratio in split_ratio]
        if length == 2:
            # No validation split requested: append a zero validation ratio.
            return tuple(split_ratio + [valid_ratio])
        return tuple(split_ratio)
    else:
        raise ValueError('Split ratio must be float or a list, got {}'
                         .format(type(split_ratio)))
#####################################################################################
################################# vocab.py ##########################################
#####################################################################################
class Vectors(object):
....
def __getitem__(self, token):
if token in self.stoi:
return self.vectors[self.stoi[token]]
else:
return self.unk_init(torch.Tensor(self.dim), token)
.....
#####################################################################################
################################# vocab.py ##########################################
#####################################################################################
# Alias string -> zero-argument factory (functools.partial) for a pre-trained
# embedding. Each row is (alias, embedding class, keyword arguments bound into
# the partial). A generator expression is used so no loop variable leaks into
# module scope.
pretrained_aliases = dict(
    (alias, partial(factory, **bound_kwargs))
    for alias, factory, bound_kwargs in (
        ("charngram.100d", CharNGram, {}),
        ("fasttext.en.300d", FastText, {"language": "en"}),
        ("fasttext.tr.300d", FastText, {"language": "tr"}),  # Added Turkish Fasttext
        ("fasttext.simple.300d", FastText, {"language": "simple"}),
        ("glove.42B.300d", GloVe, {"name": "42B", "dim": "300"}),
        ("glove.840B.300d", GloVe, {"name": "840B", "dim": "300"}),
        ("glove.twitter.27B.25d", GloVe, {"name": "twitter.27B", "dim": "25"}),
        ("glove.twitter.27B.50d", GloVe, {"name": "twitter.27B", "dim": "50"}),
        ("glove.twitter.27B.100d", GloVe, {"name": "twitter.27B", "dim": "100"}),
        ("glove.twitter.27B.200d", GloVe, {"name": "twitter.27B", "dim": "200"}),
        ("glove.6B.50d", GloVe, {"name": "6B", "dim": "50"}),
        ("glove.6B.100d", GloVe, {"name": "6B", "dim": "100"}),
        ("glove.6B.200d", GloVe, {"name": "6B", "dim": "200"}),
        ("glove.6B.300d", GloVe, {"name": "6B", "dim": "300"}),
    )
)
#####################################################################################
################################# utils.py ##########################################
#####################################################################################
def unicode_csv_reader(unicode_csv_data, **kwargs):
    """Yield CSV rows from ``unicode_csv_data`` as lists of unicode strings.

    Since the standard csv library does not handle unicode in Python 2, we need a wrapper.
    Borrowed and slightly modified from the Python docs:
    https://docs.python.org/2/library/csv.html#csv-examples

    Arguments:
        unicode_csv_data: iterable of text lines handed to ``csv.reader``.
        **kwargs: forwarded to ``csv.reader`` (delimiter, quoting, ...).

    On Python 3, ``quotechar`` defaults to ``None`` unless the caller passes
    one. When ``quoting`` is not also given, the csv module then treats quote
    characters as ordinary field text. The Python 2 branch does not apply this
    default and uses the csv module's standard quoting.
    """
    if six.PY2:
        # csv.py doesn't do Unicode; encode temporarily as UTF-8:
        csv_reader = csv.reader(utf_8_encoder(unicode_csv_data), **kwargs)
        for row in csv_reader:
            # decode UTF-8 back to Unicode, cell by cell:
            yield [cell.decode('utf-8') for cell in row]
    else:
        # Use a default instead of a hard-coded keyword. Passing quotechar=None
        # alongside **kwargs raised "got multiple values for keyword argument
        # 'quotechar'" whenever a caller supplied its own quotechar.
        kwargs.setdefault("quotechar", None)
        for line in csv.reader(unicode_csv_data, **kwargs):
            yield line