Commit 8f9110a (parent: fa9dfcf)

Fix initialization for position_dependent_vector_size < vector_size

File tree: 1 file changed (+49, −14 lines)

gensim/models/fasttext.py

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -484,8 +484,8 @@ def __init__(self, sentences=None, corpus_file=None, sg=0, hs=0, vector_size=100
484484
bucket = 0
485485

486486
self.wv = FastTextKeyedVectors(
487-
vector_size, position_dependent_weights, position_dependent_vector_size, min_n, max_n,
488-
bucket)
487+
vector_size, self.position_dependent_weights, position_dependent_vector_size,
488+
min_n, max_n, bucket)
489489
self.wv.bucket = bucket
490490

491491
super(FastText, self).__init__(
@@ -1366,10 +1366,6 @@ def init_ngrams_weights(self, seed, window):
13661366

13671367
rand_obj = np.random.default_rng(seed=seed) # use new instance of numpy's recommended generator/algorithm
13681368

1369-
vocab_shape = (len(self), self.vector_size)
1370-
ngrams_shape = (self.bucket, self.vector_size)
1371-
positions_shape = (2 * window, self.position_dependent_vector_size)
1372-
#
13731369
# We could have initialized vectors_ngrams at construction time, but we
13741370
# do it here for two reasons:
13751371
#
@@ -1379,13 +1375,38 @@ def init_ngrams_weights(self, seed, window):
13791375
# time because the vocab is not initialized at that stage.
13801376
#
13811377
if self.position_dependent_weights:
1378+
vocab_shape = (len(self), self.position_dependent_vector_size)
1379+
ngrams_shape = (self.bucket, self.position_dependent_vector_size)
1380+
positions_shape = (2 * window, self.position_dependent_vector_size)
13821381
hi = sqrt(sqrt(3.0) / self.vector_size)
13831382
lo = -hi
13841383
self.vectors_positions = rand_obj.uniform(lo, hi, positions_shape).astype(REAL)
1384+
self.vectors_vocab = rand_obj.uniform(lo, hi, vocab_shape).astype(REAL)
1385+
self.vectors_ngrams = rand_obj.uniform(lo, hi, ngrams_shape).astype(REAL)
1386+
if self.vector_size > self.position_dependent_vector_size:
1387+
vocab_shape = (len(self), self.vector_size - self.position_dependent_vector_size)
1388+
ngrams_shape = (self.bucket, self.vector_size - self.position_dependent_vector_size)
1389+
lo, hi = -1.0 / self.vector_size, 1.0 / self.vector_size
1390+
self.vectors_vocab = np.concatenate(
1391+
(
1392+
self.vectors_vocab,
1393+
rand_obj.uniform(lo, hi, vocab_shape).astype(REAL),
1394+
),
1395+
axis=-1,
1396+
)
1397+
self.vectors_ngrams = np.concatenate(
1398+
(
1399+
self.vectors_ngrams,
1400+
rand_obj.uniform(lo, hi, ngrams_shape).astype(REAL),
1401+
),
1402+
axis=-1,
1403+
)
13851404
else:
1405+
vocab_shape = (len(self), self.vector_size)
1406+
ngrams_shape = (self.bucket, self.vector_size)
13861407
lo, hi = -1.0 / self.vector_size, 1.0 / self.vector_size
1387-
self.vectors_vocab = rand_obj.uniform(lo, hi, vocab_shape).astype(REAL)
1388-
self.vectors_ngrams = rand_obj.uniform(lo, hi, ngrams_shape).astype(REAL)
1408+
self.vectors_vocab = rand_obj.uniform(lo, hi, vocab_shape).astype(REAL)
1409+
self.vectors_ngrams = rand_obj.uniform(lo, hi, ngrams_shape).astype(REAL)
13891410

13901411
def update_ngrams_weights(self, seed, old_vocab_len):
13911412
"""Update the vocabulary weights for training continuation.
@@ -1408,8 +1429,13 @@ def update_ngrams_weights(self, seed, old_vocab_len):
14081429
rand_obj.seed(seed)
14091430

14101431
new_vocab = len(self) - old_vocab_len
1411-
self.vectors_vocab = _pad_random(self.vectors_vocab, new_vocab, rand_obj,
1412-
squared=self.position_dependent_weights)
1432+
self.vectors_vocab = _pad_random(
1433+
self.vectors_vocab,
1434+
new_vocab,
1435+
rand_obj,
1436+
position_dependent_weights=self.position_dependent_weights,
1437+
position_dependent_vector_size=self.position_dependent_vector_size,
1438+
)
14131439

14141440
def init_post_load(self, fb_vectors):
14151441
"""Perform initialization after loading a native Facebook model.
@@ -1476,16 +1502,25 @@ def recalc_char_ngram_buckets(self):
14761502
)
14771503

14781504

1479-
def _pad_random(m, new_rows, rand, squared=False):
1505+
def _pad_random(m, new_rows, rand, position_dependent_weights=False, position_dependent_vector_size=0):
14801506
"""Pad a matrix with additional rows filled with random values."""
14811507
_, columns = m.shape
1482-
shape = (new_rows, columns)
1483-
if squared:
1508+
if position_dependent_weights:
1509+
shape = (new_rows, position_dependent_vector_size)
14841510
high = sqrt(sqrt(3.0) / columns)
14851511
low = -high
1512+
suffix = rand.uniform(low, high, shape).astype(REAL)
1513+
if columns > position_dependent_vector_size:
1514+
shape = (new_rows, columns - position_dependent_vector_size)
1515+
low, high = -1.0 / columns, 1.0 / columns
1516+
suffix = np.concatenate(
1517+
(suffix, rand.uniform(low, high, shape).astype(REAL)),
1518+
axis=-1,
1519+
)
14861520
else:
1521+
shape = (new_rows, columns)
14871522
low, high = -1.0 / columns, 1.0 / columns
1488-
suffix = rand.uniform(low, high, shape).astype(REAL)
1523+
suffix = rand.uniform(low, high, shape).astype(REAL)
14891524
return vstack([m, suffix])
14901525

14911526

Commit comments: 0