|
| 1 | +#ifndef BATCH_RUNNER_BATCH_FINDER_HPP |
| 2 | +#define BATCH_RUNNER_BATCH_FINDER_HPP |
| 3 | + |
| 4 | +#include <functional> |
| 5 | +#include <span> |
| 6 | + |
| 7 | +#include "constants.hpp" |
| 8 | +#include "hashtables/base_kht.hpp" |
| 9 | +#include "types.hpp" |
| 10 | + |
| 11 | +namespace kmercounter { |
| 12 | +template <size_t N = HT_TESTS_BATCH_LENGTH> |
| 13 | +class HTBatchFinder { |
| 14 | + public: |
| 15 | + using FindCallback = std::function<void(const FindResult&)>; |
| 16 | + |
| 17 | + HTBatchFinder() : HTBatchFinder(nullptr) {} |
| 18 | + HTBatchFinder(BaseHashTable* ht) : HTBatchFinder(ht, nullptr) {} |
| 19 | + HTBatchFinder(BaseHashTable* ht, FindCallback callback_fn) |
| 20 | + : ht_(ht), |
| 21 | + buffer_size_(0), |
| 22 | + results_(0, result_buffer_), |
| 23 | + callback_fn_(callback_fn) {} |
| 24 | + ~HTBatchFinder() { flush(); } |
| 25 | + |
| 26 | + /// Find a key. `id` is used to track the find operation. |
| 27 | + /// Set `parition_id` to the actual partition if you have more than one |
| 28 | + /// partition when using PartitionedHT. |
| 29 | + void find(const uint64_t key, const uint64_t id, |
| 30 | + const uint64_t partition_id = 0) { |
| 31 | + // Append kv to `buffer_` |
| 32 | + buffer_[buffer_size_].key = key; |
| 33 | + buffer_[buffer_size_].id = id; |
| 34 | + buffer_[buffer_size_].part_id = partition_id; |
| 35 | + buffer_size_++; |
| 36 | + |
| 37 | + // Flush if `buffer_` is full. |
| 38 | + if (buffer_size_ >= N) { |
| 39 | + flush_buffer(); |
| 40 | + } |
| 41 | + } |
| 42 | + |
| 43 | + /// Flush everything to the hashtable and flush the hashtable find queue. |
| 44 | + void flush() { |
| 45 | + if (buffer_size_ > 0) { |
| 46 | + flush_buffer(); |
| 47 | + } |
| 48 | + flush_ht(); |
| 49 | + } |
| 50 | + |
| 51 | + // Returns the number of elements flushed. |
| 52 | + size_t num_flushed() { return num_flushed_; } |
| 53 | + |
| 54 | + // Set the callback function. |
| 55 | + void set_callback(FindCallback callback_fn) { callback_fn_ = callback_fn; } |
| 56 | + |
| 57 | + private: |
| 58 | + // Flush the insertion buffer without checking `buffer_size_`. |
| 59 | + void flush_buffer() { |
| 60 | + ht_->find_batch(InsertFindArguments(buffer_, buffer_size_), results_); |
| 61 | + num_flushed_ += buffer_size_; |
| 62 | + buffer_size_ = 0; |
| 63 | + process_results(); |
| 64 | + } |
| 65 | + |
| 66 | + // Issue a flush to the hashtable. |
| 67 | + void flush_ht() { |
| 68 | + ht_->flush_find_queue(results_); |
| 69 | + process_results(); |
| 70 | + } |
| 71 | + |
| 72 | + /// Process each result, if there's any. |
| 73 | + void process_results() { |
| 74 | + for (const auto& result : std::span(results_.second, results_.first)) { |
| 75 | + callback_fn_(result); |
| 76 | + } |
| 77 | + results_.first = 0; |
| 78 | + } |
| 79 | + |
| 80 | + // Target hashtable. |
| 81 | + BaseHashTable* ht_; |
| 82 | + // Buffer to hold the arguments for batch insertion. |
| 83 | + __attribute__((aligned(64))) InsertFindArgument buffer_[N]; |
| 84 | + // Current size of the buffer. |
| 85 | + size_t buffer_size_; |
| 86 | + // Total number of elements flushed. |
| 87 | + size_t num_flushed_; |
| 88 | + // The buffer for storing the results. |
| 89 | + __attribute__((alignas(64))) FindResult result_buffer_[N]; |
| 90 | + // The results of finds. |
| 91 | + ValuePairs results_; |
| 92 | + // A user provided function for processing a result |
| 93 | + FindCallback callback_fn_; |
| 94 | + |
| 95 | + // Sanity checks |
| 96 | + static_assert(N > 0); |
| 97 | +}; |
| 98 | +} // namespace kmercounter |
| 99 | +#endif // BATCH_RUNNER_BATCH_FINDER_HPP |
0 commit comments