diff --git a/include/ylt/thread/thread_local.hpp b/include/ylt/thread/thread_local.hpp index 33f904446..a4ebbb6a0 100644 --- a/include/ylt/thread/thread_local.hpp +++ b/include/ylt/thread/thread_local.hpp @@ -53,18 +53,12 @@ class thread_local_t { static std::mutex mtx; std::lock_guard lock(mtx); thrd_locals_.push_back(std::move(t)); - if constexpr (is_dynamic) { - mtxs_.push_back(std::make_unique()); - indexs_.push_back(-1); - } } decltype(auto) value() { static thread_local size_t index = get_rr(); if constexpr (is_dynamic) { - indexs_[index] = index; - return std::make_pair(thrd_locals_.at(index).get(), - mtxs_.at(index).get()); + return index; } else { return *thrd_locals_.at(index); @@ -73,19 +67,8 @@ class thread_local_t { template void for_each(F &&fn) { - if constexpr (is_dynamic) { - for (size_t i = 0; i < indexs_.size(); i++) { - int index = indexs_[i]; - if (index >= 0) { - std::lock_guard lock(*mtxs_[index]); - fn(*thrd_locals_[index]); - } - } - } - else { - for (auto &ptr : thrd_locals_) { - fn(*ptr); - } + for (auto &ptr : thrd_locals_) { + fn(*ptr); } } @@ -112,15 +95,15 @@ class thread_local_t { tls_key_t tls_key_; std::vector> thrd_locals_; - std::vector> mtxs_; - std::vector indexs_; }; -inline thread_local_t>, - true> - g_ylt_tls_map{}; +#ifndef YLT_TLS_COUNT +#define YLT_TLS_COUNT std::thread::hardware_concurrency() +#endif + +inline thread_local_t g_ylt_tls_index{}; -inline decltype(auto) get_tls_pair() { return g_ylt_tls_map.value(); } +inline decltype(auto) get_tls_index() { return g_ylt_tls_index.value(); } inline async_simple::coro::Lazy init_dynamic_thread_locals_impl( size_t thd_num) { @@ -129,7 +112,7 @@ inline async_simple::coro::Lazy init_dynamic_thread_locals_impl( auto executor = coro_io::get_global_block_executor(); vec.push_back(coro_io::post( [] { - g_ylt_tls_map.create_tls(); + g_ylt_tls_index.create_tls(); }, executor)); } @@ -137,8 +120,7 @@ inline async_simple::coro::Lazy init_dynamic_thread_locals_impl( co_await async_simple::coro::collectAll(std::move(vec)); } -inline void init_dynamic_thread_locals( - size_t thd_num = std::thread::hardware_concurrency()) { +inline void init_dynamic_thread_locals(size_t thd_num = YLT_TLS_COUNT) { coro_io::g_block_io_context_pool(thd_num); async_simple::coro::syncAwait(init_dynamic_thread_locals_impl(thd_num)); } @@ -172,7 +154,7 @@ constexpr inline size_t tls_keys_max() { template inline void init_static_thread_locals( std::vector *> &tls_list, - size_t thd_num = std::thread::hardware_concurrency()) { + size_t thd_num = YLT_TLS_COUNT) { if (tls_list.size() > tls_keys_max()) { throw std::out_of_range("exceed the max number of tls keys"); } diff --git a/src/metric/tests/test_metric.cpp b/src/metric/tests/test_metric.cpp index f17d7dba7..4f0092d29 100644 --- a/src/metric/tests/test_metric.cpp +++ b/src/metric/tests/test_metric.cpp @@ -19,8 +19,8 @@ TEST_CASE("test thread local") { for (size_t i = 0; i < 10; i++) counters.push_back(std::make_shared()); - std::vector>*> vec; - for (auto& w : counters) { + std::vector> *> vec; + for (auto &w : counters) { vec.push_back(&w->val_); } @@ -42,12 +42,12 @@ TEST_CASE("test thread local") { threads.push_back(std::move(thd)); } - for (auto& thd : threads) { + for (auto &thd : threads) { thd.join(); } int64_t total = 0; - counters[0]->val_.for_each([&](auto& val) { + counters[0]->val_.for_each([&](auto &val) { total += val; }); CHECK(total == 20); @@ -56,48 +56,37 @@ TEST_CASE("test thread local") { template struct test_counter1 { - test_counter1(std::string str) : name(str) {} + test_counter1(std::string str) : name(str), vec_(YLT_TLS_COUNT) {} - ~test_counter1() { - ylt::thread::g_ylt_tls_map.for_each([this](auto& t) { - auto it = t.find(name); - if (it != t.end()) - t.erase(it); - }); - } - - auto& value() { - // map and it's mutex - static thread_local auto pair = ylt::thread::get_tls_pair(); - - std::unique_lock lock(*pair.second); - auto [it, r] = pair.first->try_emplace(name, nullptr); - if (r) { - it->second = std::make_shared>(0); - lock.unlock(); + auto &value() { + static thread_local auto index = ylt::thread::get_tls_index(); - std::unique_lock guard(mtx_); - atomics_.push_back((std::atomic*)it->second.get()); + auto &cur_ptr = atomics_[index]; + if (cur_ptr == nullptr) { + auto ptr = std::make_unique>(0); + cur_ptr = ptr.get(); + std::unique_lock lock(mtx_); + vec_[index] = std::move(ptr); } - else { - lock.unlock(); - } - - return *((std::atomic*)it->second.get()); + return *cur_ptr; } T total() { T val = 0; - std::shared_lock guard(mtx_); - for (auto& t : atomics_) { - val += t->load(); + std::shared_lock lock(mtx_); + for (auto &t : vec_) { + if (t) { + val += t->load(); + } } return val; } std::string name; std::shared_mutex mtx_; - std::vector*> atomics_; + inline static thread_local std::vector *> atomics_{ + YLT_TLS_COUNT}; + std::vector>> vec_; }; TEST_CASE("test thread local") { @@ -119,31 +108,27 @@ TEST_CASE("test thread local") { counters[0]->value()++; } }); + threads.push_back(std::move(thd)); } - for (auto& thd : threads) { + // std::thread thd2([&] { + // while (true) { + // int64_t cur_tls = counters[0]->total(); + + // std::cout << "tls: " << cur_tls << "\n"; + // std::this_thread::sleep_for(std::chrono::seconds(1)); + // } + // }); + + // thd2.join(); + + for (auto &thd : threads) { thd.join(); } int64_t total = counters[0]->total(); CHECK(total == 20); - counters.clear(); - - std::vector test_vec; - for (size_t i = 0; i < N; i++) { - test_vec.push_back(std::thread([] { - static thread_local auto pair = ylt::thread::get_tls_pair(); - CHECK(pair.first->empty()); - for (auto& [k, v] : *pair.first) { - std::cout << k << ", " << *((int64_t*)v.get()) << "\n"; - } - })); - } - - for (auto& thd : test_vec) { - thd.join(); - } } } @@ -399,7 +384,7 @@ TEST_CASE("test register metric") { default_metric_manager::register_metric_static(g); auto map1 = default_metric_manager::metric_map_static(); - for (auto& [k, v] : map1) { + for (auto &[k, v] : map1) { bool r = k == "get_count" || k == "get_guage_count"; break; } @@ -1186,5 +1171,5 @@ TEST_CASE("test remove dynamic metric") { } DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4007) -int main(int argc, char** argv) { return doctest::Context(argc, argv).run(); } +int main(int argc, char **argv) { return doctest::Context(argc, argv).run(); } DOCTEST_MSVC_SUPPRESS_WARNING_POP \ No newline at end of file