-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathaltid.swig
111 lines (77 loc) · 2.49 KB
/
altid.swig
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
// Name of the Python module generated from this interface file.
%module altid
// to get uint32_t and friends
%include <stdint.i>
// FAISS_API is defined to nothing before the Faiss headers below are read
// (presumably it is an export annotation in those headers that SWIG does not
// need to see -- TODO confirm).
#define FAISS_API
// This means: assume what's declared in these .h files is provided
// by the Faiss module, so the types are known here but come from the
// "faiss" Python module.
%import(module="faiss") "faiss/MetricType.h"
%import(module="faiss") "faiss/Index.h"
%import(module="faiss") "faiss/impl/NSG.h"
%import(module="faiss") "faiss/IndexNSG.h"
// Expose the int32_t instantiation of faiss::nsg::Graph under the Python
// name FinalNSGGraph.
%template(FinalNSGGraph) faiss::nsg::Graph< int32_t >;
// Headers needed by the generated wrapper code itself.
%{
#include <faiss/impl/FaissAssert.h>
#include "altid_impl.h"
%}
// This is important to release the GIL and do Faiss exception handling.
// The wrapped call runs between Py_BEGIN_ALLOW_THREADS and
// Py_END_ALLOW_THREADS. Every handler first re-acquires the GIL with
// PyEval_RestoreThread and only then sets the Python error and bails out.
// The last handler makes sure that no other standard C++ exception can
// escape from the wrapper while the GIL is still released.
%exception {
    Py_BEGIN_ALLOW_THREADS
    try {
        $action
    } catch (const faiss::FaissException & e) {
        PyEval_RestoreThread(_save);
        // keep an error type that some previous code already set
        if (!PyErr_Occurred()) {
            PyErr_SetString(PyExc_RuntimeError, e.what());
        }
        SWIG_fail;
    } catch (const std::bad_alloc &) {
        PyEval_RestoreThread(_save);
        PyErr_SetString(PyExc_MemoryError, "std::bad_alloc");
        SWIG_fail;
    } catch (const std::exception & e) {
        PyEval_RestoreThread(_save);
        // any other standard exception is reported as a RuntimeError
        if (!PyErr_Occurred()) {
            PyErr_SetString(PyExc_RuntimeError, e.what());
        }
        SWIG_fail;
    }
    Py_END_ALLOW_THREADS
}
%include "altid_impl.h"
%inline %{
// Install graph as the final graph of nsg.
//
// The pointer is handed to nsg.final_graph through reset(), so after the call
// the index holds it and the caller must no longer release it.
void NSG_replace_final_graph(faiss::NSG & nsg, faiss::nsg::Graph<int32_t> *graph) {
    // Re-installing the graph that is already held would make reset() release
    // the object and then keep a dangling pointer to it.
    if (nsg.final_graph.get() == graph) {
        return;
    }
    nsg.final_graph.reset(graph);
}
// Untyped version of search_NSG_and_trace, needed because there is a big
// mess-up with the SWIG types: labels and visited_nodes are received as
// void* and cast back to faiss::idx_t* and std::vector<faiss::idx_t>* before
// the call is forwarded.
//
// Throws faiss::FaissException when visited_nodes is null, because the
// pointer is dereferenced here to form the reference argument.
void search_NSG_and_trace_untyped(
        const faiss::IndexNSG & index,
        faiss::idx_t n,
        const float *x,
        int k,
        void *labels,
        float *distances,
        void * visited_nodes)
{
    if (visited_nodes == nullptr) {
        throw faiss::FaissException("visited_nodes must not be null");
    }
    auto *typed_labels = static_cast<faiss::idx_t *>(labels);
    auto &trace = *static_cast<std::vector<faiss::idx_t> *>(visited_nodes);
    search_NSG_and_trace(index, n, x, k, typed_labels, distances, trace);
}
%}
%pythoncode %{
import faiss
import numpy as np
def replace_final_graph(self, graph):
    """Install `graph` as the final graph of this NSG object.

    The C++ helper is called first. Only once it has returned does the
    Python proxy of `graph` give up ownership via `disown()`.
    """
    install = _altid.NSG_replace_final_graph
    install(self, graph)
    graph.this.disown()

faiss.NSG.replace_final_graph = replace_final_graph
def search_and_trace(self, x, k):
    """Search the index for `x` and also report the visited nodes.

    `x` is a 2-D array of queries, one per row. It is converted to a
    C-contiguous float32 array before its pointer is handed to C++.
    `k` is the number of results per query.

    Returns `(D, I, visited)`: `D` is an (n, k) float32 array, `I` is an
    (n, k) int64 array, and `visited` is the array built by
    `faiss.vector_to_array` from the vector filled in by the C++ call.

    Raises ValueError when the number of columns of `x` differs from the
    dimension of the index.
    """
    x = np.ascontiguousarray(x, dtype='float32')
    n, d = x.shape
    if d != self.d:
        raise ValueError(
            "query dimension %d does not match index dimension %d"
            % (d, self.d))
    I = np.empty((n, k), dtype='int64')
    D = np.empty((n, k), dtype='float32')
    visited_nodes = faiss.Int64Vector()
    search_NSG_and_trace_untyped(
        self, n, faiss.swig_ptr(x), k,
        faiss.swig_ptr(I), faiss.swig_ptr(D),
        visited_nodes)
    return D, I, faiss.vector_to_array(visited_nodes)

faiss.IndexNSG.search_and_trace = search_and_trace
%}