Commit dc56ae2

Support position-dependent weighting with fastText CBOW and negatives
1 parent c0e0169 commit dc56ae2

File tree

3 files changed: +84 -21 lines


gensim/models/fasttext.py

Lines changed: 39 additions & 3 deletions
@@ -312,7 +312,7 @@ def __init__(self, sentences=None, corpus_file=None, sg=0, hs=0, vector_size=100
                  max_vocab_size=None, word_ngrams=1, sample=1e-3, seed=1, workers=3, min_alpha=0.0001,
                  negative=5, ns_exponent=0.75, cbow_mean=1, hashfxn=hash, epochs=5, null_word=0, min_n=3, max_n=6,
                  sorted_vocab=1, bucket=2000000, trim_rule=None, batch_words=MAX_WORDS_IN_BATCH, callbacks=(),
-                 max_final_vocab=None):
+                 max_final_vocab=None, position_dependent_weights=0):
         """Train, use and evaluate word representations learned using the method
         described in `Enriching Word Vectors with Subword Information <https://arxiv.org/abs/1607.04606>`_,
         aka FastText.
@@ -421,6 +421,14 @@ def __init__(self, sentences=None, corpus_file=None, sg=0, hs=0, vector_size=100
             ``min_count``. If the specified ``min_count`` is more than the
             automatically calculated ``min_count``, the former will be used.
             Set to ``None`` if not required.
+        position_dependent_weights : {1,0}, optional
+            If 1, positional vectors are computed in addition to the word and n-gram vectors and are used
+            to weight the context words during training; if 0, all context words are weighted uniformly.
+
+            Notes
+            -----
+            Positional vectors are only implemented for CBOW with negative sampling, not SG or hierarchical softmax.
+            Locking positional vectors is not supported. The implementation does not use BLAS primitives.

         Examples
         --------
@@ -451,6 +459,10 @@ def __init__(self, sentences=None, corpus_file=None, sg=0, hs=0, vector_size=100
         self.callbacks = callbacks
         if word_ngrams != 1:
             raise NotImplementedError("Gensim's FastText implementation does not yet support word_ngrams != 1.")
+        if position_dependent_weights and (sg or hs):
+            raise NotImplementedError("Gensim's FastText implementation does not yet support position-dependent "
+                                      "weighting with SG or hierarchical softmax.")
+        self.position_dependent_weights = position_dependent_weights
         self.word_ngrams = word_ngrams
         if max_n < min_n:
             # with no eligible char-ngram lengths, no buckets need be allocated
@@ -468,7 +480,8 @@ def __init__(self, sentences=None, corpus_file=None, sg=0, hs=0, vector_size=100
             seed=seed, hs=hs, negative=negative, cbow_mean=cbow_mean, min_alpha=min_alpha)

     def prepare_weights(self, update=False):
-        """In addition to superclass allocations, compute ngrams of all words present in vocabulary.
+        """In addition to superclass allocations, compute ngrams of all words present in vocabulary
+        and initialize positional vectors.

         Parameters
         ----------
@@ -479,6 +492,8 @@ def prepare_weights(self, update=False):
         super(FastText, self).prepare_weights(update=update)
         if not update:
             self.wv.init_ngrams_weights(self.seed)
+            if self.position_dependent_weights:
+                self.wv.init_positional_weights(self.seed, self.window)
         # EXPERIMENTAL lockf feature; create minimal no-op lockf arrays (1 element of 1.0)
         # advanced users should directly resize/adjust as necessary
         self.wv.vectors_vocab_lockf = ones(1, dtype=REAL)
@@ -570,6 +585,8 @@ def build_vocab(self, corpus_iterable=None, corpus_file=None, update=False, prog
         """
         if not update:
             self.wv.init_ngrams_weights(self.seed)
+            if self.position_dependent_weights:
+                self.wv.init_positional_weights(self.seed, self.window)
         elif not len(self.wv):
             raise RuntimeError(
                 "You cannot do an online vocabulary-update of a model which has no prior vocabulary. "
@@ -1190,6 +1207,7 @@ def __init__(self, vector_size, min_n, max_n, bucket):
         self.vectors_vocab = None  # fka syn0_vocab
         self.vectors_ngrams = None  # fka syn0_ngrams
         self.buckets_word = None
+        self.vectors_positions = None
         self.min_n = min_n
         self.max_n = max_n
         self.bucket = bucket  # count of buckets, fka num_ngram_vectors
@@ -1329,7 +1347,6 @@ def init_ngrams_weights(self, seed):
         vocab_shape = (len(self), self.vector_size)
         ngrams_shape = (self.bucket, self.vector_size)
         self.vectors_vocab = rand_obj.uniform(lo, hi, vocab_shape).astype(REAL)
-
         #
         # We could have initialized vectors_ngrams at construction time, but we
         # do it here for two reasons:
@@ -1341,6 +1358,25 @@
         #
         self.vectors_ngrams = rand_obj.uniform(lo, hi, ngrams_shape).astype(REAL)

+    def init_positional_weights(self, seed, window):
+        """Initialize the positional weights prior to training.
+
+        Creates the weight matrix and initializes it with uniform random values.
+
+        Parameters
+        ----------
+        seed : float
+            The seed for the PRNG.
+        window : int
+            The size of the window used during training.
+
+        """
+        rand_obj = np.random.default_rng(seed=seed)  # use new instance of numpy's recommended generator/algorithm
+
+        lo, hi = -1.0 / self.vector_size, 1.0 / self.vector_size
+        positional_shape = (2 * window, self.vector_size)
+        self.vectors_positions = rand_obj.uniform(lo, hi, positional_shape).astype(REAL)
+
     def update_ngrams_weights(self, seed, old_vocab_len):
         """Update the vocabulary weights for training continuation.

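For reference, a minimal usage sketch of the new flag, assuming a build of this branch (the corpus and the tiny hyperparameters are arbitrary; common_texts is gensim's bundled toy corpus):

from gensim.models import FastText
from gensim.test.utils import common_texts

# CBOW (sg=0) with negative sampling (hs=0, negative > 0) is the only mode the
# guard above permits together with position-dependent weighting.
model = FastText(
    sentences=common_texts,
    vector_size=12,
    window=3,
    min_count=1,
    sg=0,
    hs=0,
    negative=5,
    position_dependent_weights=1,
)

# init_positional_weights allocates one positional vector per context slot.
assert model.wv.vectors_positions.shape == (2 * model.window, model.vector_size)
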
gensim/models/fasttext_inner.pxd

Lines changed: 6 additions & 5 deletions
@@ -46,17 +46,18 @@ cdef struct FastTextConfig:
     #
     # Model parameters. These get copied as-is from the Python model.
     #
-    int sg, hs, negative, sample, size, window, cbow_mean, workers
+    int sg, hs, pdw, negative, sample, size, window, cbow_mean, workers
     REAL_t alpha

     #
-    # The syn0_vocab and syn0_ngrams arrays store vectors for vocabulary terms
-    # and ngrams, respectively, as 1D arrays in scanline order. For example,
-    # syn0_vocab[i * size : (i + 1) * size] contains the elements for the ith
-    # vocab term.
+    # The syn0_vocab, syn0_ngrams, and syn0_positions arrays store vectors for
+    # vocabulary terms, ngrams, and positions, respectively, as 1D arrays in
+    # scanline order. For example, syn0_vocab[i * size : (i + 1) * size]
+    # contains the elements for the ith vocab term.
     #
     REAL_t *syn0_vocab
     REAL_t *syn0_ngrams
+    REAL_t *syn0_positions

     #
     # EXPERIMENTAL
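The scanline-order comment above is the key to reading the indexing in the Cython code; a small NumPy sketch, with illustrative names:

import numpy as np

size = 4                                 # vector dimensionality, c.size in the struct
matrix = np.arange(3 * size, dtype=np.float32).reshape(3, size)
flat = matrix.ravel()                    # the flat 1D view the C code receives

# Row i of the logical matrix lives at flat[i * size : (i + 1) * size], which is
# exactly how syn0_vocab, syn0_ngrams, and syn0_positions are indexed below.
i = 2
assert np.array_equal(flat[i * size:(i + 1) * size], matrix[i])
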

gensim/models/fasttext_inner.pyx

Lines changed: 39 additions & 13 deletions
@@ -242,22 +242,32 @@ cdef void fasttext_fast_sentence_cbow_neg(FastTextConfig *c, int i, int j, int k

     cdef long long row2
     cdef unsigned long long modulo = 281474976710655ULL
-    cdef REAL_t f, g, count, inv_count = 1.0, label, f_dot
+    cdef REAL_t f, g, count, inv_count = 1.0, label, f_dot, positional_feature
     cdef np.uint32_t target_index, word_index
-    cdef int d, m
+    cdef int d, m, n, o

     word_index = c.indexes[i]

     memset(c.neu1, 0, c.size * cython.sizeof(REAL_t))
     count = <REAL_t>0.0
+    n = j - i + c.window
     for m in range(j, k):
         if m == i:
             continue
         count += ONEF
-        our_saxpy(&c.size, &ONEF, &c.syn0_vocab[c.indexes[m] * c.size], &ONE, c.neu1, &ONE)
-        for d in range(c.subwords_idx_len[m]):
-            count += ONEF
-            our_saxpy(&c.size, &ONEF, &c.syn0_ngrams[c.subwords_idx[m][d] * c.size], &ONE, c.neu1, &ONE)
+        if c.pdw:
+            for d in range(c.size):  # TODO make into a Hadamard product using a BLAS primitive: DSBMV, followed by SAXPY
+                c.neu1[d] += c.syn0_vocab[c.indexes[m] * c.size + d] * c.syn0_positions[n * c.size + d]
+            for o in range(c.subwords_idx_len[m]):
+                count += ONEF
+                for d in range(c.size):  # TODO make into a Hadamard product using a BLAS primitive: DSBMV, followed by SAXPY
+                    c.neu1[d] += c.syn0_ngrams[c.subwords_idx[m][o] * c.size + d] * c.syn0_positions[n * c.size + d]
+        else:
+            our_saxpy(&c.size, &ONEF, &c.syn0_vocab[c.indexes[m] * c.size], &ONE, c.neu1, &ONE)
+            for o in range(c.subwords_idx_len[m]):
+                count += ONEF
+                our_saxpy(&c.size, &ONEF, &c.syn0_ngrams[c.subwords_idx[m][o] * c.size], &ONE, c.neu1, &ONE)
+        n += 1

     if count > (<REAL_t>0.5):
         inv_count = ONEF / count
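In the c.pdw branch above, the word and n-gram vectors of context slot n are multiplied elementwise (a Hadamard product) by that slot's positional vector before being accumulated into the hidden layer neu1. A NumPy sketch of the same accumulation, with illustrative names:

import numpy as np

def pdw_cbow_hidden(word_vecs, ngram_vecs, positions, cbow_mean=True):
    # word_vecs[n]: vector of the word in context slot n
    # ngram_vecs[n]: list of vectors for that word's char n-grams
    # positions[n]: positional vector of slot n (row n of vectors_positions)
    neu1 = np.zeros_like(positions[0])
    count = 0.0
    for n, (wv, ngrams) in enumerate(zip(word_vecs, ngram_vecs)):
        count += 1.0
        neu1 += wv * positions[n]        # Hadamard product, then accumulate
        for gv in ngrams:
            count += 1.0
            neu1 += gv * positions[n]
    if cbow_mean and count > 0.5:
        neu1 /= count                    # the inv_count averaging
    return neu1
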
@@ -293,16 +303,29 @@ cdef void fasttext_fast_sentence_cbow_neg(FastTextConfig *c, int i, int j, int k
     if not c.cbow_mean:  # divide error over summed window vectors
         sscal(&c.size, &inv_count, c.work, &ONE)

-    for m in range(j,k):
+    n = j - i + c.window
+    for m in range(j, k):
         if m == i:
             continue
-        our_saxpy(
-            &c.size, &c.vocab_lockf[c.indexes[m] % c.vocab_lockf_len], c.work, &ONE,
-            &c.syn0_vocab[c.indexes[m]*c.size], &ONE)
-        for d in range(c.subwords_idx_len[m]):
+        if c.pdw:
+            for d in range(c.size):  # TODO make into a Hadamard product using a BLAS primitive: DSBMV, followed by SAXPY
+                positional_feature = c.syn0_positions[n * c.size + d]
+                c.syn0_positions[n * c.size + d] += c.work[d] * c.syn0_vocab[c.indexes[m] * c.size + d]
+                c.syn0_vocab[c.indexes[m] * c.size + d] += c.vocab_lockf[c.indexes[m] % c.vocab_lockf_len] * c.work[d] * positional_feature
+            for o in range(c.subwords_idx_len[m]):
+                for d in range(c.size):  # TODO make into two Hadamard products using a BLAS primitive: DSBMV, followed by SAXPY
+                    positional_feature = c.syn0_positions[n * c.size + d]
+                    c.syn0_positions[n * c.size + d] += c.work[d] * c.syn0_ngrams[c.subwords_idx[m][o] * c.size + d]
+                    c.syn0_ngrams[c.subwords_idx[m][o] * c.size + d] += c.ngrams_lockf[c.subwords_idx[m][o] % c.ngrams_lockf_len] * c.work[d] * positional_feature
+        else:
             our_saxpy(
-                &c.size, &c.ngrams_lockf[c.subwords_idx[m][d] % c.ngrams_lockf_len], c.work, &ONE,
-                &c.syn0_ngrams[c.subwords_idx[m][d]*c.size], &ONE)
+                &c.size, &c.vocab_lockf[c.indexes[m] % c.vocab_lockf_len], c.work, &ONE,
+                &c.syn0_vocab[c.indexes[m] * c.size], &ONE)
+            for o in range(c.subwords_idx_len[m]):
+                our_saxpy(
+                    &c.size, &c.ngrams_lockf[c.subwords_idx[m][o] % c.ngrams_lockf_len], c.work, &ONE,
+                    &c.syn0_ngrams[c.subwords_idx[m][o] * c.size], &ONE)
+        n += 1


 cdef void fasttext_fast_sentence_cbow_hs(FastTextConfig *c, int i, int j, int k) nogil:
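The update step mirrors the forward pass: the scaled error in work flows into both the positional vector of slot n and the word (or n-gram) vector, and the word-side update deliberately uses the pre-update positional values, which is what the positional_feature temporary preserves. A per-slot NumPy sketch, with illustrative names:

import numpy as np

def pdw_update(work, word_vec, position, lockf=1.0):
    # work: error vector for this training example; arrays are updated in place
    old_position = position.copy()             # plays the role of positional_feature
    position += work * word_vec                # gradient step for the positional vector
    word_vec += lockf * work * old_position    # gradient step for the word vector
    return word_vec, position
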
@@ -398,9 +421,12 @@ cdef void init_ft_config(FastTextConfig *c, model, alpha, _work, _neu1):
     c.cbow_mean = model.cbow_mean
     c.window = model.window
     c.workers = model.workers
+    c.pdw = model.position_dependent_weights

     c.syn0_vocab = <REAL_t *>(np.PyArray_DATA(model.wv.vectors_vocab))
     c.syn0_ngrams = <REAL_t *>(np.PyArray_DATA(model.wv.vectors_ngrams))
+    if c.pdw:
+        c.syn0_positions = <REAL_t *>(np.PyArray_DATA(model.wv.vectors_positions))

     # EXPERIMENTAL lockf scaled suppression/enablement of training
     c.vocab_lockf = <REAL_t *>(np.PyArray_DATA(model.wv.vectors_vocab_lockf))
