Commit 064829c

Fix initialization for position_dependent_vector_size < vector_size
1 parent fa9dfcf · commit 064829c

File tree

1 file changed (+40, -14 lines)

gensim/models/fasttext.py

Lines changed: 40 additions & 14 deletions
@@ -484,8 +484,8 @@ def __init__(self, sentences=None, corpus_file=None, sg=0, hs=0, vector_size=100
         bucket = 0
 
         self.wv = FastTextKeyedVectors(
-            vector_size, position_dependent_weights, position_dependent_vector_size, min_n, max_n,
-            bucket)
+            vector_size, self.position_dependent_weights, self.position_dependent_vector_size,
+            min_n, max_n, bucket)
         self.wv.bucket = bucket
 
         super(FastText, self).__init__(
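
Why pass the instance attributes rather than the raw arguments? A minimal, hypothetical sketch of the pattern (the normalization below is an assumption for illustration, not gensim's actual __init__ logic):

class Model:
    # Illustrative stand-in for FastText's parameter handling; the names
    # and the clamping rule are assumptions, not the library's code.
    def __init__(self, vector_size=100, position_dependent_vector_size=None):
        self.vector_size = vector_size
        if position_dependent_vector_size is None:
            # assumed default: the position-dependent part spans the full vector
            position_dependent_vector_size = vector_size
        # assumed clamping: the sub-vector never exceeds the full vector
        self.position_dependent_vector_size = min(position_dependent_vector_size, vector_size)

model = Model(vector_size=100, position_dependent_vector_size=300)
assert model.position_dependent_vector_size == 100  # the raw argument (300) would be stale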
@@ -1366,10 +1366,6 @@ def init_ngrams_weights(self, seed, window):
 
         rand_obj = np.random.default_rng(seed=seed)  # use new instance of numpy's recommended generator/algorithm
 
-        vocab_shape = (len(self), self.vector_size)
-        ngrams_shape = (self.bucket, self.vector_size)
-        positions_shape = (2 * window, self.position_dependent_vector_size)
-        #
         # We could have initialized vectors_ngrams at construction time, but we
         # do it here for two reasons:
         #
@@ -1379,13 +1375,32 @@ def init_ngrams_weights(self, seed, window):
         # time because the vocab is not initialized at that stage.
         #
         if self.position_dependent_weights:
+            vocab_shape = (len(self), self.position_dependent_vector_size)
+            ngrams_shape = (self.bucket, self.position_dependent_vector_size)
+            positions_shape = (2 * window, self.position_dependent_vector_size)
             hi = sqrt(sqrt(3.0) / self.vector_size)
             lo = -hi
             self.vectors_positions = rand_obj.uniform(lo, hi, positions_shape).astype(REAL)
+            self.vectors_vocab = rand_obj.uniform(lo, hi, vocab_shape).astype(REAL)
+            self.vectors_ngrams = rand_obj.uniform(lo, hi, ngrams_shape).astype(REAL)
+            if self.vector_size > self.position_dependent_vector_size:
+                vocab_shape = (len(self), self.vector_size - self.position_dependent_vector_size)
+                ngrams_shape = (self.bucket, self.vector_size - self.position_dependent_vector_size)
+                lo, hi = -1.0 / self.vector_size, 1.0 / self.vector_size
+                self.vectors_vocab = np.concatenate((
+                    self.vectors_vocab,
+                    rand_obj.uniform(lo, hi, vocab_shape).astype(REAL),
+                ), axis=1)
+                self.vectors_ngrams = np.concatenate((
+                    self.vectors_ngrams,
+                    rand_obj.uniform(lo, hi, ngrams_shape).astype(REAL),
+                ), axis=1)
         else:
+            vocab_shape = (len(self), self.vector_size)
+            ngrams_shape = (self.bucket, self.vector_size)
             lo, hi = -1.0 / self.vector_size, 1.0 / self.vector_size
-        self.vectors_vocab = rand_obj.uniform(lo, hi, vocab_shape).astype(REAL)
-        self.vectors_ngrams = rand_obj.uniform(lo, hi, ngrams_shape).astype(REAL)
+            self.vectors_vocab = rand_obj.uniform(lo, hi, vocab_shape).astype(REAL)
+            self.vectors_ngrams = rand_obj.uniform(lo, hi, ngrams_shape).astype(REAL)
 
     def update_ngrams_weights(self, seed, old_vocab_len):
         """Update the vocabulary weights for training continuation.
@@ -1408,8 +1423,13 @@ def update_ngrams_weights(self, seed, old_vocab_len):
         rand_obj.seed(seed)
 
         new_vocab = len(self) - old_vocab_len
-        self.vectors_vocab = _pad_random(self.vectors_vocab, new_vocab, rand_obj,
-                                         squared=self.position_dependent_weights)
+        self.vectors_vocab = _pad_random(
+            self.vectors_vocab,
+            new_vocab,
+            rand_obj,
+            position_dependent_weights=self.position_dependent_weights,
+            position_dependent_vector_size=self.position_dependent_vector_size,
+        )
 
     def init_post_load(self, fb_vectors):
         """Perform initialization after loading a native Facebook model.
@@ -1476,16 +1496,22 @@ def recalc_char_ngram_buckets(self):
         )
 
 
-def _pad_random(m, new_rows, rand, squared=False):
+def _pad_random(m, new_rows, rand, position_dependent_weights=False, position_dependent_vector_size=0):
     """Pad a matrix with additional rows filled with random values."""
     _, columns = m.shape
-    shape = (new_rows, columns)
-    if squared:
+    if position_dependent_weights:
+        shape = (new_rows, position_dependent_vector_size)
         high = sqrt(sqrt(3.0) / columns)
         low = -high
+        suffix = rand.uniform(low, high, shape).astype(REAL)
+        if columns > position_dependent_vector_size:
+            shape = (new_rows, columns - position_dependent_vector_size)
+            low, high = -1.0 / columns, 1.0 / columns
+            suffix = np.concatenate((suffix, rand.uniform(low, high, shape).astype(REAL)), axis=1)
     else:
+        shape = (new_rows, columns)
         low, high = -1.0 / columns, 1.0 / columns
-    suffix = rand.uniform(low, high, shape).astype(REAL)
+        suffix = rand.uniform(low, high, shape).astype(REAL)
     return vstack([m, suffix])
 
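As a final sanity check on the two regimes in _pad_random, a short illustrative snippet in plain numpy (independent of gensim) confirming that the padded columns respect their respective bounds:

import numpy as np

rng = np.random.RandomState(0)
columns, pd_size, rows = 100, 60, 10000
wide = np.sqrt(np.sqrt(3.0) / columns)  # ~= 0.1316, position-dependent bound
narrow = 1.0 / columns                  # 0.01, the standard FastText bound
head = rng.uniform(-wide, wide, (rows, pd_size))
tail = rng.uniform(-narrow, narrow, (rows, columns - pd_size))
suffix = np.concatenate((head, tail), axis=1)  # same column-wise join as _pad_random
assert suffix.shape == (rows, columns)
assert np.abs(suffix[:, pd_size:]).max() <= narrow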