Skip to content

Commit 676e4f3

Browse files
committed
add batch runner
1 parent 621139f commit 676e4f3

File tree

4 files changed

+156
-8
lines changed

4 files changed

+156
-8
lines changed

format.sh

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# Format selected folders
22

33
# Format folders recursively
4-
for folder in examples unittests include/input_reader src/input_reader; do
4+
for folder in examples unittests include/input_reader src/input_reader include/hashtables/batch_runner; do
55
find $folder -regex '.*\.\(cpp\|hpp\|cc\|cxx\|c\|h\)' -exec clang-format -style=file -i {} \;
66
done
77

88
# Format files
9-
for file in include/tests/HashjoinTest.hpp src/tests/hashjoin_test.cpp include/hashtables/batch_inserter.hpp; do
9+
for file in include/tests/HashjoinTest.hpp src/tests/hashjoin_test.cpp; do
1010
clang-format -style=file -i $file
1111
done
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#ifndef BATCH_RUNNER_BATCH_FINDER_HPP
2+
#define BATCH_RUNNER_BATCH_FINDER_HPP
3+
4+
#include <functional>
5+
#include <span>
6+
7+
#include "constants.hpp"
8+
#include "hashtables/base_kht.hpp"
9+
#include "types.hpp"
10+
11+
namespace kmercounter {
12+
template <size_t N = HT_TESTS_BATCH_LENGTH>
13+
class HTBatchFinder {
14+
public:
15+
using FindCallback = std::function<void(const FindResult&)>;
16+
17+
HTBatchFinder() : HTBatchFinder(nullptr) {}
18+
HTBatchFinder(BaseHashTable* ht) : HTBatchFinder(ht, nullptr) {}
19+
HTBatchFinder(BaseHashTable* ht, FindCallback callback_fn)
20+
: ht_(ht),
21+
buffer_size_(0),
22+
results_(0, result_buffer_),
23+
callback_fn_(callback_fn) {}
24+
~HTBatchFinder() { flush(); }
25+
26+
/// Find a key. `id` is used to track the find operation.
27+
/// Set `parition_id` to the actual partition if you have more than one
28+
/// partition when using PartitionedHT.
29+
void find(const uint64_t key, const uint64_t id,
30+
const uint64_t partition_id = 0) {
31+
// Append kv to `buffer_`
32+
buffer_[buffer_size_].key = key;
33+
buffer_[buffer_size_].id = id;
34+
buffer_[buffer_size_].part_id = partition_id;
35+
buffer_size_++;
36+
37+
// Flush if `buffer_` is full.
38+
if (buffer_size_ >= N) {
39+
flush_buffer();
40+
}
41+
}
42+
43+
/// Flush everything to the hashtable and flush the hashtable find queue.
44+
void flush() {
45+
if (buffer_size_ > 0) {
46+
flush_buffer();
47+
}
48+
flush_ht();
49+
}
50+
51+
// Returns the number of elements flushed.
52+
size_t num_flushed() { return num_flushed_; }
53+
54+
// Set the callback function.
55+
void set_callback(FindCallback callback_fn) { callback_fn_ = callback_fn; }
56+
57+
private:
58+
// Flush the insertion buffer without checking `buffer_size_`.
59+
void flush_buffer() {
60+
ht_->find_batch(InsertFindArguments(buffer_, buffer_size_), results_);
61+
num_flushed_ += buffer_size_;
62+
buffer_size_ = 0;
63+
process_results();
64+
}
65+
66+
// Issue a flush to the hashtable.
67+
void flush_ht() {
68+
ht_->flush_find_queue(results_);
69+
process_results();
70+
}
71+
72+
/// Process each result, if there's any.
73+
void process_results() {
74+
for (const auto& result : std::span(results_.second, results_.first)) {
75+
callback_fn_(result);
76+
}
77+
results_.first = 0;
78+
}
79+
80+
// Target hashtable.
81+
BaseHashTable* ht_;
82+
// Buffer to hold the arguments for batch insertion.
83+
__attribute__((aligned(64))) InsertFindArgument buffer_[N];
84+
// Current size of the buffer.
85+
size_t buffer_size_;
86+
// Total number of elements flushed.
87+
size_t num_flushed_;
88+
// The buffer for storing the results.
89+
__attribute__((alignas(64))) FindResult result_buffer_[N];
90+
// The results of finds.
91+
ValuePairs results_;
92+
// A user provided function for processing a result
93+
FindCallback callback_fn_;
94+
95+
// Sanity checks
96+
static_assert(N > 0);
97+
};
98+
} // namespace kmercounter
99+
#endif // BATCH_RUNNER_BATCH_FINDER_HPP

include/hashtables/batch_inserter.hpp include/hashtables/batch_runner/batch_inserter.hpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#ifndef HASHTABLES_BATCH_INSERTER_HPP
22
#define HASHTABLES_BATCH_INSERTER_HPP
33

4-
#include "base_kht.hpp"
54
#include "constants.hpp"
5+
#include "hashtables/base_kht.hpp"
66
#include "types.hpp"
77

88
namespace kmercounter {
@@ -22,31 +22,31 @@ class HTBatchInserter {
2222

2323
// Flush if `buffer_` is full.
2424
if (buffer_size_ >= N) {
25-
flush_buffer_();
25+
flush_buffer();
2626
}
2727
}
2828

2929
// Flush everything to the hashtable and flush the hashtable insert queue.
3030
void flush() {
3131
if (buffer_size_ > 0) {
32-
flush_buffer_();
32+
flush_buffer();
3333
}
34-
flush_ht_();
34+
flush_ht();
3535
}
3636

3737
// Returns the number of elements flushed.
3838
size_t num_flushed() { return num_flushed_; }
3939

4040
private:
4141
// Flush the insertion buffer without checking `buffer_size_`.
42-
void flush_buffer_() {
42+
void flush_buffer() {
4343
ht_->insert_batch(InsertFindArguments(buffer_, buffer_size_));
4444
num_flushed_ += buffer_size_;
4545
buffer_size_ = 0;
4646
}
4747

4848
// Issue a flush to the hashtable.
49-
void flush_ht_() { ht_->flush_insert_queue(); }
49+
void flush_ht() { ht_->flush_insert_queue(); }
5050

5151
// Target hashtable.
5252
BaseHashTable* ht_;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#ifndef BATCH_RUNNER_BATCH_RUNNER_HPP
2+
#define BATCH_RUNNER_BATCH_RUNNER_HPP
3+
4+
#include "batch_finder.hpp"
5+
#include "batch_inserter.hpp"
6+
#include "hashtables/base_kht.hpp"
7+
8+
namespace kmercounter {
9+
/// A wrapper around `HTBatchInserter` and `HTBatchFinder`.
10+
template <size_t N = HT_TESTS_BATCH_LENGTH>
11+
class HTBatchRunner : public HTBatchInserter<N>, public HTBatchFinder<N> {
12+
public:
13+
using FindCallback = HTBatchFinder<N>::FindCallback;
14+
15+
HTBatchRunner() = default;
16+
HTBatchRunner(BaseHashTable* ht) : HTBatchRunner(ht, nullptr) {}
17+
HTBatchRunner(BaseHashTable* ht, FindCallback find_callback)
18+
: HTBatchInserter<N>(ht), HTBatchFinder<N>(ht, find_callback) {}
19+
~HTBatchRunner() { flush(); }
20+
21+
/// Insert one kv pair.
22+
void insert(const uint64_t key, const uint64_t value) {
23+
HTBatchInserter<N>::insert(key, value);
24+
}
25+
26+
/// Flush both insert and find queue.
27+
void flush() {
28+
flush_insert();
29+
flush_find();
30+
}
31+
32+
/// Flush insert queue.
33+
void flush_insert() { HTBatchInserter<N>::flush(); }
34+
35+
/// Flush find queue.
36+
void flush_find() { HTBatchFinder<N>::flush(); }
37+
38+
/// Returns the number of inserts flushed.
39+
size_t num_insert_flushed() { return HTBatchInserter<N>::num_flushed(); }
40+
41+
/// Returns the number of inserts flushed.
42+
size_t num_find_flushed() { return HTBatchFinder<N>::num_flushed(); }
43+
44+
// Sanity checks
45+
static_assert(N > 0);
46+
};
47+
} // namespace kmercounter
48+
49+
#endif // BATCH_RUNNER_BATCH_RUNNER_HPP

0 commit comments

Comments
 (0)