diff --git a/include/ylt/coro_io/io_context_pool.hpp b/include/ylt/coro_io/io_context_pool.hpp index 1f5caab61..e9f38493b 100644 --- a/include/ylt/coro_io/io_context_pool.hpp +++ b/include/ylt/coro_io/io_context_pool.hpp @@ -31,6 +31,11 @@ namespace coro_io { +inline asio::io_context **get_current() { + static thread_local asio::io_context *current = nullptr; + return ¤t; +} + template class ExecutorWrapper : public async_simple::Executor { private: @@ -71,6 +76,17 @@ class ExecutorWrapper : public async_simple::Executor { operator ExecutorImpl() { return executor_; } + bool currentThreadInExecutor() const override { + auto ctx = get_current(); + return *ctx == &executor_.context(); + } + + size_t currentContextId() const override { + auto ctx = get_current(); + auto ptr = *ctx; + return ptr ? (size_t)ptr : 0; + } + private: void schedule(Func func, Duration dur) override { auto timer = std::make_unique(executor_, dur); @@ -120,6 +136,8 @@ class io_context_pool { for (std::size_t i = 0; i < io_contexts_.size(); ++i) { threads.emplace_back(std::make_shared( [](io_context_ptr svr) { + auto ctx = get_current(); + *ctx = svr.get(); svr->run(); }, io_contexts_[i])); diff --git a/src/coro_io/tests/test_corofile.cpp b/src/coro_io/tests/test_corofile.cpp index 74202ea9a..9921c92d9 100644 --- a/src/coro_io/tests/test_corofile.cpp +++ b/src/coro_io/tests/test_corofile.cpp @@ -156,6 +156,23 @@ void create_file(std::string filename, size_t file_size, // } // } +async_simple::coro::Lazy foo() { co_return; } + +TEST_CASE("test currentThreadInExecutor") { + CHECK(*coro_io::get_current() == nullptr); + CHECK(coro_io::get_global_executor()->currentContextId() == 0); + CHECK_NOTHROW( + async_simple::coro::syncAwait(foo().via(coro_io::get_global_executor()))); + auto executor = coro_io::get_global_executor(); + + foo().via(executor).start([executor](auto&&) { + auto ptr = &executor->get_asio_executor().context(); + CHECK(ptr == *coro_io::get_current()); + size_t id = executor->currentContextId(); + CHECK(id > 0); + }); +} + TEST_CASE("read write 100 small files") { size_t total = 100; std::vector filenames;