Skip to content

Commit

Permalink
improve
Browse files Browse the repository at this point in the history
  • Loading branch information
qicosmos committed Jul 27, 2024
1 parent 247584d commit e61d6ae
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 82 deletions.
42 changes: 12 additions & 30 deletions include/ylt/thread/thread_local.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,12 @@ class thread_local_t {
static std::mutex mtx;
std::lock_guard lock(mtx);
thrd_locals_.push_back(std::move(t));
if constexpr (is_dynamic) {
mtxs_.push_back(std::make_unique<std::mutex>());
indexs_.push_back(-1);
}
}

decltype(auto) value() {
static thread_local size_t index = get_rr();
if constexpr (is_dynamic) {
indexs_[index] = index;
return std::make_pair(thrd_locals_.at(index).get(),
mtxs_.at(index).get());
return index;
}
else {
return *thrd_locals_.at(index);
Expand All @@ -73,19 +67,8 @@ class thread_local_t {

template <typename F>
void for_each(F &&fn) {
if constexpr (is_dynamic) {
for (size_t i = 0; i < indexs_.size(); i++) {
int index = indexs_[i];
if (index >= 0) {
std::lock_guard lock(*mtxs_[index]);
fn(*thrd_locals_[index]);
}
}
}
else {
for (auto &ptr : thrd_locals_) {
fn(*ptr);
}
for (auto &ptr : thrd_locals_) {
fn(*ptr);
}
}

Expand All @@ -112,15 +95,15 @@ class thread_local_t {

tls_key_t tls_key_;
std::vector<std::unique_ptr<T>> thrd_locals_;
std::vector<std::unique_ptr<std::mutex>> mtxs_;
std::vector<int> indexs_;
};

inline thread_local_t<std::unordered_map<std::string, std::shared_ptr<void>>,
true>
g_ylt_tls_map{};
#ifndef YLT_TLS_COUNT
#define YLT_TLS_COUNT std::thread::hardware_concurrency()
#endif

inline thread_local_t<size_t, true> g_ylt_tls_index{};

inline decltype(auto) get_tls_pair() { return g_ylt_tls_map.value(); }
inline decltype(auto) get_tls_index() { return g_ylt_tls_index.value(); }

inline async_simple::coro::Lazy<void> init_dynamic_thread_locals_impl(
size_t thd_num) {
Expand All @@ -129,16 +112,15 @@ inline async_simple::coro::Lazy<void> init_dynamic_thread_locals_impl(
auto executor = coro_io::get_global_block_executor();
vec.push_back(coro_io::post(
[] {
g_ylt_tls_map.create_tls();
g_ylt_tls_index.create_tls();
},
executor));
}

co_await async_simple::coro::collectAll(std::move(vec));
}

inline void init_dynamic_thread_locals(
size_t thd_num = std::thread::hardware_concurrency()) {
inline void init_dynamic_thread_locals(size_t thd_num = YLT_TLS_COUNT) {
coro_io::g_block_io_context_pool(thd_num);
async_simple::coro::syncAwait(init_dynamic_thread_locals_impl(thd_num));
}
Expand Down Expand Up @@ -172,7 +154,7 @@ constexpr inline size_t tls_keys_max() {
template <typename T>
inline void init_static_thread_locals(
std::vector<thread_local_t<T> *> &tls_list,
size_t thd_num = std::thread::hardware_concurrency()) {
size_t thd_num = YLT_TLS_COUNT) {
if (tls_list.size() > tls_keys_max()) {
throw std::out_of_range("exceed the max number of tls keys");
}
Expand Down
89 changes: 37 additions & 52 deletions src/metric/tests/test_metric.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ TEST_CASE("test thread local") {
for (size_t i = 0; i < 10; i++)
counters.push_back(std::make_shared<test_counter>());

std::vector<ylt::thread::thread_local_t<std::atomic<int64_t>>*> vec;
for (auto& w : counters) {
std::vector<ylt::thread::thread_local_t<std::atomic<int64_t>> *> vec;
for (auto &w : counters) {
vec.push_back(&w->val_);
}

Expand All @@ -42,12 +42,12 @@ TEST_CASE("test thread local") {
threads.push_back(std::move(thd));
}

for (auto& thd : threads) {
for (auto &thd : threads) {
thd.join();
}

int64_t total = 0;
counters[0]->val_.for_each([&](auto& val) {
counters[0]->val_.for_each([&](auto &val) {
total += val;
});
CHECK(total == 20);
Expand All @@ -56,48 +56,37 @@ TEST_CASE("test thread local") {

template <typename T>
struct test_counter1 {
test_counter1(std::string str) : name(str) {}
test_counter1(std::string str) : name(str), vec_(YLT_TLS_COUNT) {}

~test_counter1() {
ylt::thread::g_ylt_tls_map.for_each([this](auto& t) {
auto it = t.find(name);
if (it != t.end())
t.erase(it);
});
}

auto& value() {
// map and it's mutex
static thread_local auto pair = ylt::thread::get_tls_pair();

std::unique_lock lock(*pair.second);
auto [it, r] = pair.first->try_emplace(name, nullptr);
if (r) {
it->second = std::make_shared<std::atomic<T>>(0);
lock.unlock();
auto &value() {
static thread_local auto index = ylt::thread::get_tls_index();

std::unique_lock guard(mtx_);
atomics_.push_back((std::atomic<T>*)it->second.get());
auto &cur_ptr = atomics_[index];
if (cur_ptr == nullptr) {
auto ptr = std::make_unique<std::atomic<T>>(0);
cur_ptr = ptr.get();
std::unique_lock lock(mtx_);
vec_[index] = std::move(ptr);
}
else {
lock.unlock();
}

return *((std::atomic<T>*)it->second.get());
return *cur_ptr;
}

T total() {
T val = 0;
std::shared_lock guard(mtx_);
for (auto& t : atomics_) {
val += t->load();
std::shared_lock lock(mtx_);
for (auto &t : vec_) {
if (t) {
val += t->load();
}
}
return val;
}

std::string name;
std::shared_mutex mtx_;
std::vector<std::atomic<T>*> atomics_;
inline static thread_local std::vector<std::atomic<T> *> atomics_{
YLT_TLS_COUNT};
std::vector<std::unique_ptr<std::atomic<T>>> vec_;
};

TEST_CASE("test thread local") {
Expand All @@ -119,31 +108,27 @@ TEST_CASE("test thread local") {
counters[0]->value()++;
}
});

threads.push_back(std::move(thd));
}

for (auto& thd : threads) {
// std::thread thd2([&] {
// while (true) {
// int64_t cur_tls = counters[0]->total();

// std::cout << "tls: " << cur_tls << "\n";
// std::this_thread::sleep_for(std::chrono::seconds(1));
// }
// });

// thd2.join();

for (auto &thd : threads) {
thd.join();
}

int64_t total = counters[0]->total();
CHECK(total == 20);
counters.clear();

std::vector<std::thread> test_vec;
for (size_t i = 0; i < N; i++) {
test_vec.push_back(std::thread([] {
static thread_local auto pair = ylt::thread::get_tls_pair();
CHECK(pair.first->empty());
for (auto& [k, v] : *pair.first) {
std::cout << k << ", " << *((int64_t*)v.get()) << "\n";
}
}));
}

for (auto& thd : test_vec) {
thd.join();
}
}
}

Expand Down Expand Up @@ -399,7 +384,7 @@ TEST_CASE("test register metric") {
default_metric_manager::register_metric_static(g);

auto map1 = default_metric_manager::metric_map_static();
for (auto& [k, v] : map1) {
for (auto &[k, v] : map1) {
bool r = k == "get_count" || k == "get_guage_count";
break;
}
Expand Down Expand Up @@ -1186,5 +1171,5 @@ TEST_CASE("test remove dynamic metric") {
}

DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4007)
int main(int argc, char** argv) { return doctest::Context(argc, argv).run(); }
int main(int argc, char **argv) { return doctest::Context(argc, argv).run(); }
DOCTEST_MSVC_SUPPRESS_WARNING_POP

0 comments on commit e61d6ae

Please sign in to comment.