Skip to content

Commit

Permalink
[2032] - Changed python set to cpp stl set (explosion#2170)
Browse files Browse the repository at this point in the history
Changed python set to cpp stl set explosion#2032 

## Description

Changed the Python set to a C++ STL set. The C++ STL set works better due to the logarithmic run time of its methods. Finding the minimum in the C++ set is done in constant time, as opposed to the worst-case linear run time of the Python set. Operations such as find, count, insert, and delete are also done in either constant or logarithmic time, thus making the C++ set a better option for managing vectors.
Reference : http://www.cplusplus.com/reference/set/set/

### Types of change
Enhancement to `Vectors` for faster initialisation of word vectors (fastText)
  • Loading branch information
skrcode authored and honnibal committed Mar 31, 2018
1 parent 6f84e32 commit 1cdbb7c
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 17 deletions.
8 changes: 4 additions & 4 deletions .github/CONTRIBUTOR_AGREEMENT.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,16 @@ mark both statements:
or entity, including my employer, has or will have rights with respect to my
contributions.

* [x] I am signing on behalf of my employer or a legal entity and I have the
* [ ] I am signing on behalf of my employer or a legal entity and I have the
actual authority to contractually bind that entity.

## Contributor Details

| Field | Entry |
|------------------------------- | -------------------- |
| Name | |
| Name | Suraj Rajan |
| Company name (if applicable) | |
| Title or role (if applicable) | |
| Date | |
| GitHub username | |
| Date | 31/Mar/2018 |
| GitHub username | skrcode |
| Website (optional) | |
26 changes: 26 additions & 0 deletions spacy/tests/vectors/test_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,38 @@ def vectors():
def data():
return numpy.asarray([[0.0, 1.0, 2.0], [3.0, -2.0, 4.0]], dtype='f')

@pytest.fixture
def resize_data():
    """Provide a 2x2 float array; the resize tests use its shape as the target."""
    rows = [[0.0, 1.0], [2.0, 3.0]]
    return numpy.asarray(rows, dtype='f')

@pytest.fixture()
def vocab(en_vocab, vectors):
    """Return `en_vocab` after handing it and `vectors` to `add_vecs_to_vocab`."""
    # The helper is called for its effect on `en_vocab`; the same object is
    # returned to the test.
    add_vecs_to_vocab(en_vocab, vectors)
    return en_vocab

def test_init_vectors_with_resize_shape(strings, resize_data):
    """A table created from a shape reports the new shape after `resize`."""
    initial_shape = (len(strings), 3)
    table = Vectors(shape=initial_shape)
    table.resize(shape=resize_data.shape)
    # The table must have taken on the requested shape and left the old one.
    assert table.shape == resize_data.shape
    assert table.shape != initial_shape

def test_init_vectors_with_resize_data(data, resize_data):
    """A table created from an array reports the new shape after `resize`."""
    table = Vectors(data=data)
    table.resize(shape=resize_data.shape)
    # The table must have taken on the requested shape and left the old one.
    assert table.shape == resize_data.shape
    assert table.shape != data.shape

def test_get_vector_resize(strings, data, resize_data):
    """Keys added after a resize look up the matching rows of `resize_data`."""
    table = Vectors(data=data)
    table.resize(shape=resize_data.shape)

    # Hash each string and register it against the row with the same index.
    keys = [hash_string(text) for text in strings]
    for row, key in enumerate(keys):
        table.add(key, row=row)

    first_key = keys[0]
    assert list(table[first_key]) == list(resize_data[0])
    assert list(table[first_key]) != list(resize_data[1])
    second_key = keys[1]
    assert list(table[second_key]) != list(resize_data[0])
    assert list(table[second_key]) == list(resize_data[1])

def test_init_vectors_with_data(strings, data):
v = Vectors(data=data)
Expand Down
28 changes: 15 additions & 13 deletions spacy/vectors.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ from .strings cimport StringStore, hash_string
from .compat import basestring_, path2str
from . import util

from cython.operator cimport dereference as deref
from libcpp.set cimport set as cppset

def unpickle_vectors(bytes_data):
return Vectors().from_bytes(bytes_data)
Expand Down Expand Up @@ -50,7 +52,7 @@ cdef class Vectors:
cdef public object name
cdef public object data
cdef public object key2row
cdef public object _unset
cdef cppset[int] _unset

def __init__(self, *, shape=None, data=None, keys=None, name=None):
"""Create a new vector store.
Expand All @@ -69,9 +71,9 @@ cdef class Vectors:
self.data = data
self.key2row = OrderedDict()
if self.data is not None:
self._unset = set(range(self.data.shape[0]))
self._unset = cppset[int]({i for i in range(self.data.shape[0])})
else:
self._unset = set()
self._unset = cppset[int]()
if keys is not None:
for i, key in enumerate(keys):
self.add(key, row=i)
Expand All @@ -93,7 +95,7 @@ cdef class Vectors:
@property
def is_full(self):
"""RETURNS (bool): `True` if no slots are available for new keys."""
return len(self._unset) == 0
return self._unset.size() == 0

@property
def n_keys(self):
Expand Down Expand Up @@ -124,8 +126,8 @@ cdef class Vectors:
"""
i = self.key2row[key]
self.data[i] = vector
if i in self._unset:
self._unset.remove(i)
if self._unset.count(i):
self._unset.erase(self._unset.find(i))

def __iter__(self):
"""Iterate over the keys in the table.
Expand Down Expand Up @@ -164,7 +166,7 @@ cdef class Vectors:
xp = get_array_module(self.data)
self.data = xp.resize(self.data, shape)
filled = {row for row in self.key2row.values()}
self._unset = {row for row in range(shape[0]) if row not in filled}
self._unset = cppset[int]({row for row in range(shape[0]) if row not in filled})
removed_items = []
for key, row in list(self.key2row.items()):
if row >= shape[0]:
Expand All @@ -188,7 +190,7 @@ cdef class Vectors:
YIELDS (ndarray): A vector in the table.
"""
for row, vector in enumerate(range(self.data.shape[0])):
if row not in self._unset:
if not self._unset.count(row):
yield vector

def items(self):
Expand Down Expand Up @@ -253,13 +255,13 @@ cdef class Vectors:
elif row is None:
if self.is_full:
raise ValueError("Cannot add new key to vectors -- full")
row = min(self._unset)
row = deref(self._unset.begin())

self.key2row[key] = row
if vector is not None:
self.data[row] = vector
if row in self._unset:
self._unset.remove(row)
if self._unset.count(row):
self._unset.erase(self._unset.find(row))
return row

def most_similar(self, queries, *, batch_size=1024):
Expand Down Expand Up @@ -365,8 +367,8 @@ cdef class Vectors:
with path.open('rb') as file_:
self.key2row = msgpack.load(file_)
for key, row in self.key2row.items():
if row in self._unset:
self._unset.remove(row)
if self._unset.count(row):
self._unset.erase(self._unset.find(row))

def load_keys(path):
if path.exists():
Expand Down

0 comments on commit 1cdbb7c

Please sign in to comment.