Skip to content

Commit

Permalink
add QBG C APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
masajiro committed Jun 26, 2023
1 parent d9b5a91 commit 86c41c3
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 4 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.0.12
2.0.13
51 changes: 51 additions & 0 deletions lib/NGT/NGTQ/Capi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,26 @@ ObjectID qbg_append_object(QBGIndex index, float *obj, uint32_t obj_dim, QBGErro
}
}

ObjectID qbg_append_object_as_uint8(QBGIndex index, uint8_t *obj, uint32_t obj_dim, QBGError error) {
if (index == NULL || obj == NULL || obj_dim == 0){
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index << " obj = " << obj << " obj_dim = " << obj_dim;
operate_error_string_(ss, error);
return 0;
}

try {
auto *pindex = static_cast<QBG::Index*>(index);
std::vector<uint8_t> vobj(&obj[0], &obj[obj_dim]);
return pindex->append(vobj);
} catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return 0;
}
}

void qbg_initialize_build_parameters(QBGBuildParameters *parameters) {
parameters->hierarchical_clustering_init_mode = static_cast<int>(NGT::Clustering::InitializationModeKmeansPlusPlus);
parameters->number_of_first_objects = 0;
Expand Down Expand Up @@ -444,6 +464,37 @@ float* qbg_get_object(QBGIndex index, ObjectID id, QBGError error) {
}
}

uint8_t* qbg_get_object_as_uint8(QBGIndex index, ObjectID id, QBGError error) {
if (index == NULL) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : parametor error: index = " << index;
operate_error_string_(ss, error);
return 0;
}

auto *pindex = static_cast<QBG::Index*>(index);

try {
auto o = pindex->getObject(id);
std::vector<uint8_t> object(o.begin(), o.end());
size_t size = sizeof(uint8_t) * object.size();
auto obj = malloc(size);
if (obj == 0) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: Cannot allocate memory.";
operate_error_string_(ss, error);
return 0;
}
memcpy(obj, object.data(), size);
return static_cast<uint8_t*>(obj);
} catch(std::exception &err) {
std::stringstream ss;
ss << "Capi : " << __FUNCTION__ << "() : Error: " << err.what();
operate_error_string_(ss, error);
return 0;
}
}

size_t qbg_get_dimension(QBGIndex index, QBGError error) {
if (index == NULL) {
std::stringstream ss;
Expand Down
4 changes: 4 additions & 0 deletions lib/NGT/NGTQ/Capi.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ extern "C" {

ObjectID qbg_append_object(QBGIndex index, float *obj, uint32_t obj_dim, QBGError error);

ObjectID qbg_append_object_as_uint8(QBGIndex index, uint8_t *obj, uint32_t obj_dim, QBGError error);

void qbg_initialize_build_parameters(QBGBuildParameters *parameters);

bool qbg_build_index(const char *index_path, QBGBuildParameters *parameters, QBGError error);
Expand All @@ -129,6 +131,8 @@ extern "C" {

float* qbg_get_object(QBGIndex index, ObjectID id, QBGError error);

uint8_t* qbg_get_object_as_uint8(QBGIndex index, ObjectID id, QBGError error);

size_t qbg_get_dimension(QBGIndex index, QBGError error);

#ifdef __cplusplus
Expand Down
12 changes: 9 additions & 3 deletions lib/NGT/NGTQ/QuantizedBlobGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -370,11 +370,17 @@ namespace QBG {
void insert(const size_t id, std::vector<float> &object) {
getQuantizer().objectList.put(id, object, &getQuantizer().globalCodebookIndex.getObjectSpace());
}

NGT::ObjectID append(std::vector<float> &object) {
template<typename T>
NGT::ObjectID append(std::vector<T> &object) {
NGT::ObjectID id = getQuantizer().objectList.size();
id = id == 0 ? 1 : id;
getQuantizer().objectList.put(id, object, &getQuantizer().globalCodebookIndex.getObjectSpace());
if (typeid(T) == typeid(float)) {
auto &obj = *reinterpret_cast<std::vector<float>*>(&object);
getQuantizer().objectList.put(id, obj, &getQuantizer().globalCodebookIndex.getObjectSpace());
} else {
std::vector<float> obj(object.begin(), object.end());
getQuantizer().objectList.put(id, obj, &getQuantizer().globalCodebookIndex.getObjectSpace());
}
return id;
}

Expand Down

0 comments on commit 86c41c3

Please sign in to comment.