|  | 
| 2 | 2 | 
 | 
| 3 | 3 | #include "visited_list_pool.h" | 
| 4 | 4 | #include "hnswlib.h" | 
| 5 |  | -#include <atomic> | 
| 6 |  | -#include <random> | 
| 7 |  | -#include <stdlib.h> | 
|  | 5 | + | 
| 8 | 6 | #include <assert.h> | 
| 9 |  | -#include <unordered_set> | 
|  | 7 | +#include <stdlib.h> | 
|  | 8 | + | 
|  | 9 | +#include <atomic> | 
|  | 10 | +#include <limits> | 
| 10 | 11 | #include <list> | 
| 11 | 12 | #include <memory> | 
|  | 13 | +#include <mutex> | 
|  | 14 | +#include <random> | 
|  | 15 | +#include <unordered_set> | 
| 12 | 16 | 
 | 
| 13 | 17 | namespace hnswlib { | 
| 14 | 18 | typedef unsigned int tableint; | 
|  | 19 | +constexpr tableint kInvalidInternalId = std::numeric_limits<tableint>::max(); | 
| 15 | 20 | typedef unsigned int linklistsizeint; | 
| 16 | 21 | 
 | 
| 17 | 22 | template<typename dist_t> | 
| @@ -195,6 +200,17 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> { | 
| 195 | 200 |     } | 
| 196 | 201 | 
 | 
| 197 | 202 | 
 | 
|  | 203 | +    tableint getInternalIdByLabel(labeltype label) const { | 
|  | 204 | +        std::lock_guard<std::mutex> lock_table(label_lookup_lock); | 
|  | 205 | +        auto label_lookup_result = label_lookup_.find(label); | 
|  | 206 | +        if (label_lookup_result == label_lookup_.end() || | 
|  | 207 | +            isMarkedDeleted(label_lookup_result->second)) { | 
|  | 208 | +            return kInvalidInternalId; | 
|  | 209 | +        } | 
|  | 210 | +        return label_lookup_result->second; | 
|  | 211 | +    } | 
|  | 212 | + | 
|  | 213 | + | 
| 198 | 214 |     inline void setExternalLabel(tableint internal_id, labeltype label) const { | 
| 199 | 215 |         memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype)); | 
| 200 | 216 |     } | 
| @@ -870,13 +886,10 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> { | 
| 870 | 886 |         // lock all operations with element by label | 
| 871 | 887 |         std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label)); | 
| 872 | 888 | 
 | 
| 873 |  | -        std::unique_lock <std::mutex> lock_table(label_lookup_lock); | 
| 874 |  | -        auto search = label_lookup_.find(label); | 
| 875 |  | -        if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { | 
|  | 889 | +        tableint internalId = getInternalIdByLabel(label); | 
|  | 890 | +        if (internalId == kInvalidInternalId) { | 
| 876 | 891 |             return Status("Label not found"); | 
| 877 | 892 |         } | 
| 878 |  | -        tableint internalId = search->second; | 
| 879 |  | -        lock_table.unlock(); | 
| 880 | 893 | 
 | 
| 881 | 894 |         char* data_ptrv = getDataByInternalId(internalId); | 
| 882 | 895 |         size_t dim = *((size_t *) dist_func_param_); | 
| @@ -1190,7 +1203,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> { | 
| 1190 | 1203 |     } | 
| 1191 | 1204 | 
 | 
| 1192 | 1205 | 
 | 
| 1193 |  | -    // This internal function adds a point at a specific level. If level is | 
|  | 1206 | +    // This internal function adds a point at a specific level. | 
| 1194 | 1207 |     StatusOr<tableint> addPointWithLevel(const void *data_point, labeltype label, int level) { | 
| 1195 | 1208 |         tableint cur_c = 0; | 
| 1196 | 1209 |         { | 
|  | 
0 commit comments