From 9fc501c35c7e65ff2a469bc0a40bcb095ca1e533 Mon Sep 17 00:00:00 2001 From: Aaron Silvas Date: Sun, 1 Oct 2023 20:50:11 -0700 Subject: [PATCH] Support metrics & codes --- lib/index.d.ts | 26 ++++++++++++++++++++++++-- lib/index.js | 24 ++++++++++++++++++++++++ src/faiss.cc | 34 ++++++++++++++++++++++++++++++++++ test/Index.test.js | 13 +++++++++++++ test/IndexFlatL2.test.js | 17 +++++++++++++++++ 5 files changed, 112 insertions(+), 2 deletions(-) diff --git a/lib/index.d.ts b/lib/index.d.ts index 0bca395..38164f5 100644 --- a/lib/index.d.ts +++ b/lib/index.d.ts @@ -45,6 +45,14 @@ export class Index { * @return {number} Whether training is required. */ isTrained(): boolean; + /** + * @return {MetricType} The metric of the index. + */ + get metricType(): MetricType; + /** + * @return {number} Argument of the metric type. + */ + get metricArg(): number; /** * Add n vectors of dimension d to the index. * Vectors are implicitly assigned labels ntotal .. ntotal + n - 1 @@ -110,12 +118,26 @@ export class Index { } +/** + * IndexFlat Abstract Index. + */ +export abstract class IndexFlat extends Index { + /** + * Byte size of each encoded vector. + */ + get codeSize(): number; + /** + * Encoded dataset, size ntotal * codeSize. + */ + get codes(): Buffer; +} + /** * IndexFlatL2 Index. * IndexFlatL2 that stores the full vectors and performs `squared L2` search. * @param {number} d The dimensionality of index. */ -export class IndexFlatL2 extends Index { +export class IndexFlatL2 extends IndexFlat { /** * Read index from a file. * @param {string} fname File path to read. @@ -140,7 +162,7 @@ export class IndexFlatL2 extends Index { * Index that stores the full vectors and performs `maximum inner product` search. * @param {number} d The dimensionality of index. */ -export class IndexFlatIP extends Index { +export class IndexFlatIP extends IndexFlat { /** * Read index from a file. * @param {string} fname File path to read. diff --git a/lib/index.js b/lib/index.js index ced9f1f..434591b 100644 --- a/lib/index.js +++ b/lib/index.js @@ -14,4 +14,28 @@ var MetricType; MetricType[MetricType["METRIC_Jaccard"] = 23] = "METRIC_Jaccard"; })(MetricType || (faiss.MetricType = MetricType = {})); +function wireupGetterSetters(propName, indexes, getter, setter) { + for (let Index of indexes) { + if (!(propName in Index.prototype)) { // prevents redefinition in jest + const args = {}; + if (getter) { + args['get'] = function () { + return this[getter](); + } + } + if (setter) { + args['set'] = function (v) { + this[setter](v); + } + } + Object.defineProperty(Index.prototype, propName, args); + } + } +} + +wireupGetterSetters('codeSize', [faiss.IndexFlatL2], 'getCodeSize'); +wireupGetterSetters('codes', [faiss.IndexFlatL2], 'getCodesUInt8'); +wireupGetterSetters('metricType', [faiss.Index, faiss.IndexFlatL2, faiss.IndexFlatIP], 'getMetricType'); +wireupGetterSetters('metricArg', [faiss.Index, faiss.IndexFlatL2, faiss.IndexFlatIP], 'getMetricArg'); + module.exports = faiss; \ No newline at end of file diff --git a/src/faiss.cc b/src/faiss.cc index f32b6b7..f0459af 100644 --- a/src/faiss.cc +++ b/src/faiss.cc @@ -440,6 +440,30 @@ class IndexBase : public Napi::ObjectWrap return Napi::Buffer::Copy(env, writer->data.data(), writer->data.size()); } + Napi::Value getMetricType(const Napi::CallbackInfo &info) + { + return Napi::Number::New(info.Env(), index_->metric_type); + } + + Napi::Value getMetricArg(const Napi::CallbackInfo &info) + { + return Napi::Number::New(info.Env(), index_->metric_arg); + } + + Napi::Value getCodeSize(const Napi::CallbackInfo &info) + { + auto index = dynamic_cast(index_.get()); + return Napi::Number::New(info.Env(), index->code_size); + } + + Napi::Value getCodesUInt8(const Napi::CallbackInfo &info) + { + Napi::Env env = info.Env(); + + auto index = dynamic_cast(index_.get()); + return Napi::Buffer::Copy(env, index->codes.data(), index->codes.size()); + } + protected: std::unique_ptr index_; inline static Napi::FunctionReference *constructor; @@ -467,6 +491,8 @@ class Index : public IndexBase InstanceMethod("mergeFrom", &Index::mergeFrom), InstanceMethod("removeIds", &Index::removeIds), InstanceMethod("toBuffer", &Index::toBuffer), + InstanceMethod("getMetricType", &Index::getMetricType), + InstanceMethod("getMetricArg", &Index::getMetricArg), StaticMethod("read", &Index::read), StaticMethod("fromBuffer", &Index::fromBuffer), StaticMethod("fromFactory", &Index::fromFactory), @@ -502,6 +528,10 @@ class IndexFlatL2 : public IndexBase InstanceMethod("mergeFrom", &IndexFlatL2::mergeFrom), InstanceMethod("removeIds", &IndexFlatL2::removeIds), InstanceMethod("toBuffer", &IndexFlatL2::toBuffer), + InstanceMethod("getMetricType", &IndexFlatL2::getMetricType), + InstanceMethod("getMetricArg", &IndexFlatL2::getMetricArg), + InstanceMethod("getCodeSize", &IndexFlatL2::getCodeSize), + InstanceMethod("getCodesUInt8", &IndexFlatL2::getCodesUInt8), StaticMethod("read", &IndexFlatL2::read), StaticMethod("fromBuffer", &IndexFlatL2::fromBuffer), }); @@ -536,6 +566,10 @@ class IndexFlatIP : public IndexBase InstanceMethod("mergeFrom", &IndexFlatIP::mergeFrom), InstanceMethod("removeIds", &IndexFlatIP::removeIds), InstanceMethod("toBuffer", &IndexFlatIP::toBuffer), + InstanceMethod("getMetricType", &IndexFlatIP::getMetricType), + InstanceMethod("getMetricArg", &IndexFlatIP::getMetricArg), + InstanceMethod("getCodeSize", &IndexFlatIP::getCodeSize), + InstanceMethod("getCodesUInt8", &IndexFlatIP::getCodesUInt8), StaticMethod("read", &IndexFlatIP::read), StaticMethod("fromBuffer", &IndexFlatIP::fromBuffer), }); diff --git a/test/Index.test.js b/test/Index.test.js index 00b0cd4..c6c44e9 100644 --- a/test/Index.test.js +++ b/test/Index.test.js @@ -43,4 +43,17 @@ describe('Index', () => { expect(index.ntotal()).toBe(newIndex.ntotal()); }); }); + + describe('#metricType', () => { + it('metric adheres to default', () => { + const index = Index.fromFactory(2, 'Flat'); + expect(index.metricType).toBe(MetricType.METRIC_L2); + expect(index.metricArg).toBe(0); + }); + + it('metric adheres to initialized value', () => { + const index = Index.fromFactory(2, 'Flat', MetricType.METRIC_INNER_PRODUCT); + expect(index.metricType).toBe(MetricType.METRIC_INNER_PRODUCT); + }); + }); }); diff --git a/test/IndexFlatL2.test.js b/test/IndexFlatL2.test.js index fc70c53..89ae6df 100644 --- a/test/IndexFlatL2.test.js +++ b/test/IndexFlatL2.test.js @@ -218,4 +218,21 @@ describe('IndexFlatL2', () => { expect(index.search([1, 3], 1)).toMatchObject({ distances: [0], labels: [0] }); }); }); + + describe('#codes', () => { + it("returns codeSize", () => { + const index = new IndexFlatL2(2); + expect(index.codeSize).toBe(8); + }); + + it("returns codes", () => { + const index = new IndexFlatL2(2); + const arr = [1, 1, 255, 255]; + index.add(arr.slice(0, 2)); + index.add(arr.slice(2, 4)); + expect(index.codes).toStrictEqual(Buffer.from(Float32Array.from(arr).buffer)); + index.add([99, 99]); + expect(index.codes).toStrictEqual(Buffer.from(Float32Array.from(arr.concat([99, 99])).buffer)); + }); + }); }); \ No newline at end of file