Skip to content

Commit

Permalink
fix latch + refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
kelbon committed Apr 24, 2024
1 parent 2ad1b8f commit bab4cdc
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 117 deletions.
5 changes: 3 additions & 2 deletions include/kelcoro/executor_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ struct any_executor_ref {
requires(!std::same_as<std::remove_volatile_t<E>, any_executor_ref> && !std::is_const_v<E>)
: attach_node(&do_attach<E>), data(std::addressof(exe)) {
}

any_executor_ref(const any_executor_ref&) = default;
any_executor_ref(any_executor_ref&&) = default;

Expand Down Expand Up @@ -97,15 +98,15 @@ struct KELCORO_CO_AWAIT_REQUIRED create_node_and_attach : task_node {
task = handle;
e.attach(this);
}
schedule_status await_resume() noexcept {
[[nodiscard]] schedule_status await_resume() noexcept {
return schedule_status{status};
}
};

// ADL customization point, may be overloaded for your executor type, should return awaitable which
// schedules execution of coroutine to 'e'
template <executor E>
KELCORO_CO_AWAIT_REQUIRED constexpr auto jump_on(E& e KELCORO_LIFETIMEBOUND) noexcept {
KELCORO_CO_AWAIT_REQUIRED constexpr auto jump_on(E&& e KELCORO_LIFETIMEBOUND) noexcept {
return create_node_and_attach<E>(e);
}

Expand Down
39 changes: 19 additions & 20 deletions include/kelcoro/latch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ namespace dd {
// when completed, starts all waiters on passed executor
struct latch {
private:
alignas(hardware_destructive_interference_size) mutable nonowner_lockfree_stack<task_node> stack;
alignas(hardware_destructive_interference_size) std::atomic_ptrdiff_t counter;
nonowner_lockfree_stack<task_node> stack;
any_executor_ref exe;
std::atomic_ptrdiff_t counter;

struct wait_awaiter {
const latch& l;
latch& l;
task_node node;

bool await_ready() const noexcept {
Expand All @@ -24,6 +24,9 @@ struct latch {
void await_suspend(std::coroutine_handle<> handle) noexcept {
node.task = handle;
l.stack.push(&node);
// in case if 'ready' changed from 'false' to 'true' wakeup all who call 'wait'
if (l.ready())
l.wakeup_all();
}
static void await_resume() noexcept {
}
Expand All @@ -47,13 +50,11 @@ struct latch {
node.task = handle;
l.stack.push(&node);
// copy logic from count down but never resume
assert(n >= 0 && n <= l.counter.load(std::memory_order::relaxed));
ptrdiff_t c = l.counter.fetch_sub(n, std::memory_order::acq_rel) - n;
if (c != 0) [[likely]] {
assert(c > 0 && "precondition violated");
return;
}
l.wakeup_all();
assert(n >= 0 && n <= l.counter.load(std::memory_order::acquire));
ptrdiff_t c = l.counter.fetch_sub(n, std::memory_order::acq_rel);
assert(c >= n);
if (c == n) [[unlikely]]
l.wakeup_all();
}
static void await_resume() noexcept {
}
Expand All @@ -74,13 +75,11 @@ struct latch {
// decrements the internal counter by n without blocking the caller
// precondition: n >= 0 && n <= internal counter
void count_down(std::ptrdiff_t n = 1) noexcept {
assert(n >= 0 && n <= counter.load(std::memory_order::relaxed));
ptrdiff_t c = counter.fetch_sub(n, std::memory_order::acq_rel) - n;
if (c != 0) [[likely]] {
assert(c >= 0 && "precondition violated");
return;
}
wakeup_all();
assert(n >= 0 && n <= counter.load(std::memory_order::acquire));
ptrdiff_t c = counter.fetch_sub(n, std::memory_order::acq_rel);
assert(c >= n && "precondition violated");
if (c == n) [[unlikely]]
wakeup_all();
}

// returns true if the internal counter has reached zero
Expand All @@ -91,14 +90,14 @@ struct latch {

// suspends the calling coroutine until the internal counter reaches ​0​.
// If it is zero already, returns immediately
KELCORO_CO_AWAIT_REQUIRED co_awaiter auto wait() const noexcept {
KELCORO_CO_AWAIT_REQUIRED co_awaiter auto wait() noexcept {
return wait_awaiter{*this};
}

// precondition: n >= 0 && n <= internal counter
// logical equivalent to count_down(n); wait() (but atomicaly, really count down + wait is rata race)
KELCORO_CO_AWAIT_REQUIRED co_awaiter auto arrive_and_wait(std::ptrdiff_t n = 1) noexcept {
assert(n >= 0 && n <= counter.load(std::memory_order::relaxed));
assert(n >= 0 && n <= counter.load(std::memory_order::acquire));
return arrive_and_wait_awaiter{*this, n};
}

Expand All @@ -108,7 +107,7 @@ struct latch {

private:
void wakeup_all() noexcept {
task_node* top = stack.try_pop_all(std::memory_order::relaxed);
task_node* top = stack.try_pop_all();
attach_list(exe, top);
}
};
Expand Down
2 changes: 2 additions & 0 deletions include/kelcoro/noexport/macro.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include <cassert>

#define KELCORO_CO_AWAIT_REQUIRED [[nodiscard("forget co_await?")]]

#if defined(__GNUC__) || defined(__clang__)
Expand Down
10 changes: 5 additions & 5 deletions include/kelcoro/noexport/thread_pool_monitoring.hpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#pragma once

#if !defined(NDEBUG) && !defined(KELCORO_DISABLE_MONITORING)
#define KELCORO_ENABLE_THREADPOOL_MONITORING
#if !defined(NDEBUG) && !defined(KELCORO_DISABLE_MONITORING) && defined(KELCORO_ENABLE_THREADPOOL_MONITORING)
#define KELCORO_THREADPOOL_MONITORING_IS_ENABLED
#endif

#ifdef KELCORO_ENABLE_THREADPOOL_MONITORING
#ifdef KELCORO_THREADPOOL_MONITORING_IS_ENABLED

#include <vector>
#include <atomic>
Expand Down Expand Up @@ -66,8 +66,8 @@ struct monitoring_t {
}
};

// only for debug
thread_local static inline std::vector<monitoring_t> monitorings;
// only for debug with access from one thread
static inline std::vector<monitoring_t> monitorings;

#define KELCORO_MONITORING(...) __VA_ARGS__
#define KELCORO_MONITORING_INC(x) KELCORO_MONITORING(x.fetch_add(1, ::std::memory_order::relaxed))
Expand Down
4 changes: 2 additions & 2 deletions include/kelcoro/nonowner_lockfree_stack.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ struct nonowner_lockfree_stack {
}

// returns top of the stack
[[nodiscard]] node_type* try_pop_all(std::memory_order exchange_order = acq_rel) noexcept {
return top.exchange(nullptr, exchange_order);
[[nodiscard]] node_type* try_pop_all() noexcept {
return top.exchange(nullptr, acq_rel);
}
// other_top must be a top of other stack ( for example from try_pop_all() )
void push_stack(node_type* other_top) noexcept {
Expand Down
11 changes: 11 additions & 0 deletions include/kelcoro/thread_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,17 @@ struct thread_pool {
void stop(worker* w, size_t count) noexcept;
};

// specialization for thread pool uses hash to maximize parallel execution
inline void attach_list(thread_pool& e, task_node* top) {
operation_hash_t hash = 0;
while (top) {
task_node* next = top->next;
e.select_worker(hash).attach(top);
++hash;
top = next;
}
}

struct jump_on_thread_pool : private create_node_and_attach<thread_pool> {
using base_t = create_node_and_attach<thread_pool>;

Expand Down
57 changes: 16 additions & 41 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,43 +1,18 @@
cmake_minimum_required(VERSION 3.05)

add_executable(kelcorotest ${CMAKE_CURRENT_SOURCE_DIR}/test_coroutines.cpp)
target_link_libraries(kelcorotest PUBLIC kelcorolib)
set_target_properties(kelcorotest PROPERTIES
CMAKE_CXX_EXTENSIONS OFF
LINKER_LANGUAGE CXX
CMAKE_CXX_STANDARD_REQUIRED ON
CXX_STANDARD 20
)

# sanitizers do not work on clang with coroutines... Rly...
if(UNIX)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
target_compile_options(kelcorotest PUBLIC "-fsanitize=undefined" "-fsanitize=address")
target_link_options(kelcorotest PUBLIC "-fsanitize=undefined" "-fsanitize=address")
endif()
endif()
add_executable(generator_tests ${CMAKE_CURRENT_SOURCE_DIR}/tests_generator.cpp)

target_link_libraries(generator_tests PUBLIC kelcorolib)
set_target_properties(generator_tests PROPERTIES
CMAKE_CXX_EXTENSIONS OFF
LINKER_LANGUAGE CXX
CMAKE_CXX_STANDARD_REQUIRED ON
CXX_STANDARD 20
)
if (UNIX)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
target_compile_options(generator_tests PUBLIC "-fsanitize=undefined" "-fsanitize=address")
target_link_options(generator_tests PUBLIC "-fsanitize=undefined" "-fsanitize=address")
endif()
endif()

add_test(NAME generator_tests COMMAND generator_tests)

add_test(NAME test_kelcorotest COMMAND kelcorotest)

add_executable(thread_pool_tests ${CMAKE_CURRENT_SOURCE_DIR}/test_thread_pool.cpp)

target_link_libraries(thread_pool_tests PUBLIC kelcorolib)

add_test(NAME test_thread_pool COMMAND thread_pool_tests)
set(KELCORO_PARTS coroutines generator thread_pool)
foreach(PART ${KELCORO_PARTS})
add_executable(${PART} "${CMAKE_CURRENT_SOURCE_DIR}/test_${PART}.cpp")
target_link_libraries(${PART} PUBLIC kelcorolib)
set_target_properties(${PART} PROPERTIES
CMAKE_CXX_EXTENSIONS OFF
LINKER_LANGUAGE CXX
if (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
target_compile_options(${PART} -fsanitize=thread)
endif()
CMAKE_CXX_STANDARD_REQUIRED ON
CXX_STANDARD 20
)

add_test(NAME test_${PART} COMMAND ${PART})
endforeach()
56 changes: 11 additions & 45 deletions tests/test_coroutines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ inline dd::logical_thread multithread(std::atomic<int32_t>& value) {
(void)handle;
auto token = co_await dd::this_coro::stop_token;
(void)token.stop_requested();
co_await dd::jump_on(dd::new_thread_executor);
(void)co_await dd::jump_on(dd::new_thread_executor);
for (auto i : std::views::iota(0, 100))
++value, (void)i;
}
Expand All @@ -237,7 +237,7 @@ TEST(logical_thread) {
dd::logical_thread bar(bool& requested) {
auto handle = co_await dd::this_coro::handle;
(void)handle;
co_await dd::jump_on(dd::new_thread_executor);
(void)co_await dd::jump_on(dd::new_thread_executor);
auto token = co_await dd::this_coro::stop_token;
while (true) {
std::this_thread::sleep_for(std::chrono::microseconds(5));
Expand Down Expand Up @@ -305,7 +305,7 @@ TEST(job_mm) {
std::atomic<size_t> err_c = 0;
auto job_creator = [&](std::atomic<int32_t>& value) -> dd::job {
auto th_id = std::this_thread::get_id();
co_await dd::jump_on(dd::new_thread_executor);
(void)co_await dd::jump_on(dd::new_thread_executor);
if (th_id == std::this_thread::get_id())
++err_c;
value.fetch_add(1, std::memory_order::release);
Expand Down Expand Up @@ -337,7 +337,7 @@ dd::job sub(std::atomic<int>& count) {
}

dd::logical_thread writer(std::atomic<int>& count) {
co_await dd::jump_on(dd::new_thread_executor);
(void)co_await dd::jump_on(dd::new_thread_executor);
dd::stop_token tok = co_await dd::this_coro::stop_token;
for (auto i : std::views::iota(0, 1000)) {
(void)i;
Expand All @@ -348,7 +348,7 @@ dd::logical_thread writer(std::atomic<int>& count) {
}

dd::logical_thread reader() {
co_await dd::jump_on(dd::new_thread_executor);
(void)co_await dd::jump_on(dd::new_thread_executor);
dd::stop_token tok = co_await dd::this_coro::stop_token;
for (;;) {
e1.notify_all(dd::this_thread_executor, 1);
Expand Down Expand Up @@ -376,7 +376,7 @@ inline dd::event<std::vector<std::string>> three;
inline dd::event<void> four;

dd::async_task<void> waiter_any(uint32_t& count) {
co_await dd::jump_on(dd::new_thread_executor);
(void)co_await dd::jump_on(dd::new_thread_executor);
std::mutex m;
int32_t i = 0;
while (i < 100000) {
Expand All @@ -387,7 +387,7 @@ dd::async_task<void> waiter_any(uint32_t& count) {
}
}
dd::async_task<void> waiter_all(uint32_t& count) {
co_await dd::jump_on(dd::new_thread_executor);
(void)co_await dd::jump_on(dd::new_thread_executor);
for (int32_t i : std::views::iota(0, 100000)) {
(void)i;
auto tuple = co_await dd::when_all(one, two, three, four);
Expand Down Expand Up @@ -418,41 +418,9 @@ dd::logical_thread notifier(auto& pool, auto input) {
co_return;
}
}
// TEST(when_any) {
// auto _1 = notifier(one);
// auto _2 = notifier(two, 5);
// auto _3 = notifier(three, std::vector<std::string>(3, "hello world"));
// auto _4 = notifier(four);
// uint32_t count = 0;
// auto anyx = waiter_any(count);
// anyx.wait();
// stop(_1, _2, _3, _4);
// one.notify_all(dd::this_thread_executor);
// two.notify_all(dd::this_thread_executor, 5);
// three.notify_all(dd::this_thread_executor, std::vector<std::string>(3, "hello world"));
// four.notify_all(dd::this_thread_executor);
// error_if(count != 100000);
// return error_count;
// }
// TEST(when_all) {
// auto _1 = notifier(one);
// auto _2 = notifier(two, 5);
// auto _3 = notifier(three, std::vector<std::string>(3, "hello world"));
// auto _4 = notifier(four);
// uint32_t count = 0;
// auto allx = waiter_all(count);
// allx.wait();
// stop(_1, _2, _3, _4);
// one.notify_all(dd::this_thread_executor);
// two.notify_all(dd::this_thread_executor, 5);
// three.notify_all(dd::this_thread_executor, std::vector<std::string>(3, "hello world"));
// four.notify_all(dd::this_thread_executor);
// error_if(count != 100000);
// return error_count;
//}

dd::async_task<std::string> afoo() {
co_await dd::jump_on(dd::new_thread_executor);
(void)co_await dd::jump_on(dd::new_thread_executor);
co_return "hello world";
}

Expand All @@ -475,7 +443,7 @@ TEST(void_async_task) {
}

dd::task<std::string> do_smth() {
co_await dd::jump_on(dd::new_thread_executor);
(void)co_await dd::jump_on(dd::new_thread_executor);
co_return "hello from task";
}

Expand All @@ -493,7 +461,7 @@ dd::async_task<void> tasks_user() {

dd::channel<std::tuple<int, double, float>> creator() {
for (int i = 0; i < 100; ++i) {
co_await dd::jump_on(dd::new_thread_executor);
(void)co_await dd::jump_on(dd::new_thread_executor);
std::this_thread::sleep_for(std::chrono::microseconds(3));
co_yield std::tuple{i, static_cast<double>(i), static_cast<float>(i)};
}
Expand All @@ -517,7 +485,7 @@ TEST(channel) {
}

dd::async_task<int> small_task() {
co_await dd::jump_on(dd::new_thread_executor);
(void)co_await dd::jump_on(dd::new_thread_executor);
co_return 1;
}

Expand Down Expand Up @@ -545,8 +513,6 @@ int main() {
ec += TESTlogical_thread_mm();
ec += TESTgen_mm();
ec += TESTjob_mm();
// TODOec += TESTthread_safety();
/*TESTwhen_any() + TESTwhen_all() +*/
ec += TESTasync_tasks();
ec += TESTvoid_async_task();
ec += TESTchannel();
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_generator.cpp → tests/test_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ using dd::generator;
#define RANDOM_CONTROL_FLOW \
if constexpr (std::is_same_v<G<int>, dd::channel<int>>) \
if (flip()) \
co_await dd::jump_on(dd::new_thread_executor)
(void)co_await dd::jump_on(dd::new_thread_executor)

static bool flip() {
static thread_local std::mt19937 rng = [] {
Expand Down
Loading

0 comments on commit bab4cdc

Please sign in to comment.