Skip to content

Commit

Permalink
Revise and simplify Vectors class
Browse files Browse the repository at this point in the history
  • Loading branch information
honnibal committed Oct 31, 2017
1 parent cb52170 commit 77d8f5d
Show file tree
Hide file tree
Showing 7 changed files with 233 additions and 145 deletions.
4 changes: 2 additions & 2 deletions spacy/tests/doc/test_doc_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ def test_doc_api_right_edge(en_tokenizer):

def test_doc_api_has_vector():
vocab = Vocab()
vocab.clear_vectors(2)
vocab.vectors.add('kitten', vector=numpy.asarray([0., 2.], dtype='f'))
vocab.reset_vectors(width=2)
vocab.set_vector('kitten', vector=numpy.asarray([0., 2.], dtype='f'))
doc = Doc(vocab, words=['kitten'])
assert doc.has_vector

Expand Down
6 changes: 3 additions & 3 deletions spacy/tests/doc/test_token_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ def test_doc_token_api_is_properties(en_vocab):

def test_doc_token_api_vectors():
vocab = Vocab()
vocab.clear_vectors(2)
vocab.vectors.add('apples', vector=numpy.asarray([0., 2.], dtype='f'))
vocab.vectors.add('oranges', vector=numpy.asarray([0., 1.], dtype='f'))
vocab.reset_vectors(width=2)
vocab.set_vector('apples', vector=numpy.asarray([0., 2.], dtype='f'))
vocab.set_vector('oranges', vector=numpy.asarray([0., 1.], dtype='f'))
doc = Doc(vocab, words=['apples', 'oranges', 'oov'])
assert doc.has_vector

Expand Down
4 changes: 2 additions & 2 deletions spacy/tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ def add_vecs_to_vocab(vocab, vectors):
"""Add list of vector tuples to given vocab. All vectors need to have the
same length. Format: [("text", [1, 2, 3])]"""
length = len(vectors[0][1])
vocab.clear_vectors(length)
vocab.reset_vectors(width=length)
for word, vec in vectors:
vocab.set_vector(word, vec)
vocab.set_vector(word, vector=vec)
return vocab


Expand Down
21 changes: 9 additions & 12 deletions spacy/tests/vectors/test_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,38 +35,35 @@ def vocab(en_vocab, vectors):


def test_init_vectors_with_data(strings, data):
v = Vectors(strings, data=data)
v = Vectors(data=data)
assert v.shape == data.shape

def test_init_vectors_with_width(strings):
v = Vectors(strings, width=3)
for string in strings:
v.add(string)
def test_init_vectors_with_shape(strings):
v = Vectors(shape=(len(strings), 3))
assert v.shape == (len(strings), 3)


def test_get_vector(strings, data):
v = Vectors(strings, data=data)
for string in strings:
v.add(string)
v = Vectors(data=data)
for i, string in enumerate(strings):
v.add(string, row=i)
assert list(v[strings[0]]) == list(data[0])
assert list(v[strings[0]]) != list(data[1])
assert list(v[strings[1]]) != list(data[0])


def test_set_vector(strings, data):
orig = data.copy()
v = Vectors(strings, data=data)
for string in strings:
v.add(string)
v = Vectors(data=data)
for i, string in enumerate(strings):
v.add(string, row=i)
assert list(v[strings[0]]) == list(orig[0])
assert list(v[strings[0]]) != list(orig[1])
v[strings[0]] = data[1]
assert list(v[strings[0]]) == list(orig[1])
assert list(v[strings[0]]) != list(orig[0])



@pytest.fixture()
def tokenizer_v(vocab):
return Tokenizer(vocab, {}, None, None, None)
Expand Down
2 changes: 1 addition & 1 deletion spacy/tests/vocab/test_add_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,5 @@ def test_vocab_prune_vectors():
remap = vocab.prune_vectors(2)
assert list(remap.keys()) == [u'kitten']
neighbour, similarity = remap.values()[0]
assert neighbour == u'cat'
assert neighbour == u'cat', remap
assert_allclose(similarity, cosine(data[0], data[2]), atol=1e-6)
Loading

0 comments on commit 77d8f5d

Please sign in to comment.