diff --git a/CMakeLists.txt b/CMakeLists.txt index 6151e0eed72..ddf3e52258e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -280,6 +280,14 @@ add_ccf_static_library( LINK_LIBS qcbor t_cose http_parser ccfcrypto ccf_kv ) +# CCF task system library +add_ccf_static_library( + ccf_tasks + SRCS ${CCF_DIR}/src/tasks/task_system.cpp ${CCF_DIR}/src/tasks/job_board.cpp + ${CCF_DIR}/src/tasks/ordered_tasks.cpp + ${CCF_DIR}/src/tasks/fan_in_tasks.cpp +) + # Common test args for Python scripts starting up CCF networks set(WORKER_THREADS 0 @@ -529,6 +537,20 @@ if(BUILD_TESTS) ) target_link_libraries(ds_test PRIVATE ${CMAKE_THREAD_LIBS_INIT}) + add_unit_test( + task_system_test + ${CMAKE_CURRENT_SOURCE_DIR}/src/tasks/test/basic_tasks.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/tasks/test/ordered_tasks.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/tasks/test/delayed_tasks.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/tasks/test/fan_in_tasks.cpp + ) + target_link_libraries(task_system_test PRIVATE ccf_tasks) + + add_unit_test( + task_system_demo ${CMAKE_CURRENT_SOURCE_DIR}/src/tasks/test/demo/main.cpp + ) + target_link_libraries(task_system_demo PRIVATE ccf_tasks) + add_unit_test( ledger_test ${CMAKE_CURRENT_SOURCE_DIR}/src/host/test/ledger.cpp ) @@ -775,6 +797,14 @@ if(BUILD_TESTS) ) add_picobench(merkle_bench SRCS src/node/test/merkle_bench.cpp) add_picobench(hash_bench SRCS src/ds/test/hash_bench.cpp) + + add_picobench( + task_bench + SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/tasks/test/bench/merge_bench.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/tasks/test/bench/sleep_bench.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/tasks/test/bench/contention_bench.cpp + LINK_LIBS ccf_tasks + ) endif() if(LONG_TESTS) diff --git a/src/tasks/basic_task.h b/src/tasks/basic_task.h new file mode 100644 index 00000000000..7a15f4c5395 --- /dev/null +++ b/src/tasks/basic_task.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. +#pragma once + +#include "tasks/task.h" + +namespace ccf::tasks +{ + struct BasicTask : public BaseTask + { + using Fn = std::function; + + Fn fn; + const std::string name; + + BasicTask(const Fn& _fn, const std::string& s = "[Anon]") : fn(_fn), name(s) + {} + + void do_task_implementation() override + { + fn(); + } + + std::string_view get_name() const override + { + return name; + } + }; + + template + Task make_basic_task(Ts&&... ts) + { + return std::make_shared(std::forward(ts)...); + } +} \ No newline at end of file diff --git a/src/tasks/fan_in_tasks.cpp b/src/tasks/fan_in_tasks.cpp new file mode 100644 index 00000000000..8c27f9cd28b --- /dev/null +++ b/src/tasks/fan_in_tasks.cpp @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. 
+ +#include "tasks/fan_in_tasks.h" + +#include +#include +#include + +#define FMT_HEADER_ONLY +#include + +namespace ccf::tasks +{ + struct FanInTasks::PImpl + { + std::string name; + IJobBoard& job_board; + + // Synchronise access to pending_tasks and next_expected_task_index + std::mutex pending_tasks_mutex; + std::map pending_tasks; + size_t next_expected_task_index = 0; + + std::atomic active = false; + }; + + void FanInTasks::enqueue_on_board() + { + pimpl->job_board.add_task(shared_from_this()); + } + + void FanInTasks::do_task_implementation() + { + std::vector current_batch; + + { + std::lock_guard lock(pimpl->pending_tasks_mutex); + pimpl->active.store(true); + + auto it = pimpl->pending_tasks.find(pimpl->next_expected_task_index); + while (it != pimpl->pending_tasks.end()) + { + current_batch.push_back(it->second); + pimpl->pending_tasks.erase(it); + + ++pimpl->next_expected_task_index; + it = pimpl->pending_tasks.find(pimpl->next_expected_task_index); + } + } + + for (auto& task : current_batch) + { + task->do_task(); + } + + { + std::lock_guard lock(pimpl->pending_tasks_mutex); + pimpl->active.store(false); + + auto it = pimpl->pending_tasks.find(pimpl->next_expected_task_index); + if (it != pimpl->pending_tasks.end()) + { + // While we were executing the previous batch, a call to fan_in_task + // provided the _next_ contiguous task. We're now responsible for + // re-enqueuing this task + enqueue_on_board(); + } + } + } + + FanInTasks::FanInTasks( + [[maybe_unused]] FanInTasks::Private force_private_constructor, + IJobBoard& job_board_, + const std::string& name_) : + pimpl(std::make_unique(name_, job_board_)) + {} + + FanInTasks::~FanInTasks() = default; + + std::string_view FanInTasks::get_name() const + { + return pimpl->name; + } + + void FanInTasks::add_task(size_t task_index, Task task) + { + { + std::lock_guard lock(pimpl->pending_tasks_mutex); + + if (task_index < pimpl->next_expected_task_index) + { + throw std::runtime_error(fmt::format( + "[{}] Received task {} ({}) out-of-order - already advanced next " + "expected " + "to {}", + get_name(), + task_index, + task->get_name(), + pimpl->next_expected_task_index)); + } + + auto it = pimpl->pending_tasks.find(task_index); + if (it != pimpl->pending_tasks.end()) + { + throw std::runtime_error(fmt::format( + "[{}] Received duplicate task {} ({}) - already have pending task {}", + get_name(), + task_index, + task->get_name(), + it->second == nullptr ? std::string("nullptr") : + it->second->get_name())); + } + + pimpl->pending_tasks.emplace(task_index, task); + + if (!pimpl->active.load()) + { + if (task_index == pimpl->next_expected_task_index) + { + enqueue_on_board(); + } + } + } + } + + std::shared_ptr FanInTasks::create( + IJobBoard& job_board_, const std::string& name_) + { + return std::make_shared(Private{}, job_board_, name_); + } +} diff --git a/src/tasks/fan_in_tasks.h b/src/tasks/fan_in_tasks.h new file mode 100644 index 00000000000..69bf4c4bd83 --- /dev/null +++ b/src/tasks/fan_in_tasks.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. 
+#pragma once + +#include "tasks/job_board_interface.h" +#include "tasks/task.h" + +#include + +namespace ccf::tasks +{ + class FanInTasks : public BaseTask, + public std::enable_shared_from_this + { + protected: + struct PImpl; + std::unique_ptr pimpl = nullptr; + + void enqueue_on_board(); + void do_task_implementation() override; + + // Non-public constructor argument type, so this can only be constructed by + // this class (ensuring shared ptr ownership) + struct Private + { + explicit Private() = default; + }; + + public: + FanInTasks(Private, IJobBoard& job_board_, const std::string& name_); + ~FanInTasks(); + + static std::shared_ptr create( + IJobBoard& job_board_, const std::string& name_ = "[FanIn]"); + + std::string_view get_name() const override; + + void add_task(size_t task_index, Task task); + }; +} diff --git a/src/tasks/job_board.cpp b/src/tasks/job_board.cpp new file mode 100644 index 00000000000..945f97eb73d --- /dev/null +++ b/src/tasks/job_board.cpp @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. +#include "tasks/job_board.h" + +namespace ccf::tasks +{ + void JobBoard::add_task(Task&& task) + { + { + std::lock_guard lock(mutex); + queue.emplace(std::move(task)); + } + work_beacon.notify_work_available(); + } + + Task JobBoard::get_task() + { + std::lock_guard lock(mutex); + if (queue.empty()) + { + return nullptr; + } + + Task task = queue.front(); + queue.pop(); + return task; + } + + bool JobBoard::empty() + { + std::lock_guard lock(mutex); + return queue.empty(); + } + + Task JobBoard::wait_for_task(const std::chrono::milliseconds& timeout) + { + using TClock = std::chrono::system_clock; + + const auto start = TClock::now(); + const auto until = start + timeout; + + while (true) + { + auto task = get_task(); + if (task != nullptr || TClock::now() >= until) + { + return task; + } + + work_beacon.wait_for_work_with_timeout(timeout); + } + } +} diff --git a/src/tasks/job_board.h b/src/tasks/job_board.h new file mode 100644 index 00000000000..9409bd399bf --- /dev/null +++ b/src/tasks/job_board.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. +#pragma once + +#include "ds/work_beacon.h" +#include "tasks/job_board_interface.h" + +#include +#include + +namespace ccf::tasks +{ + struct JobBoard : public IJobBoard + { + std::mutex mutex; + std::queue queue; + ccf::ds::WorkBeacon work_beacon; + + void add_task(Task&& t) override; + Task get_task() override; + bool empty() override; + + Task wait_for_task(const std::chrono::milliseconds& timeout) override; + }; +} diff --git a/src/tasks/job_board_interface.h b/src/tasks/job_board_interface.h new file mode 100644 index 00000000000..251b272f53c --- /dev/null +++ b/src/tasks/job_board_interface.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. +#pragma once + +#include "tasks/task.h" + +#include +#include +#include + +namespace ccf::tasks +{ + struct IJobBoard + { + virtual void add_task(Task&& t) = 0; + virtual Task get_task() = 0; + virtual bool empty() = 0; + + virtual Task wait_for_task(const std::chrono::milliseconds& timeout) = 0; + }; +} diff --git a/src/tasks/ordered_tasks.cpp b/src/tasks/ordered_tasks.cpp new file mode 100644 index 00000000000..05399043dca --- /dev/null +++ b/src/tasks/ordered_tasks.cpp @@ -0,0 +1,87 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the Apache 2.0 License. + +#include "tasks/ordered_tasks.h" + +#include "tasks/sub_task_queue.h" + +namespace ccf::tasks +{ + struct OrderedTasks::PImpl + { + std::string name; + IJobBoard& job_board; + SubTaskQueue actions; + }; + + struct OrderedTasks::ResumeOrderedTasks : public ccf::tasks::IResumable + { + std::shared_ptr tasks; + + ResumeOrderedTasks(std::shared_ptr tasks_) : + tasks(std::move(tasks_)) + {} + + void resume() override + { + if (tasks->pimpl->actions.unpause()) + { + tasks->enqueue_on_board(); + } + } + }; + + void OrderedTasks::enqueue_on_board() + { + pimpl->job_board.add_task(shared_from_this()); + } + + OrderedTasks::~OrderedTasks() = default; + + OrderedTasks::OrderedTasks( + [[maybe_unused]] OrderedTasks::Private force_private_constructor, + IJobBoard& job_board_, + const std::string& name_) : + pimpl(std::make_unique(name_, job_board_)) + {} + + void OrderedTasks::do_task_implementation() + { + if (pimpl->actions.pop_and_visit( + [this](TaskAction&& action) { action->do_action(); })) + { + enqueue_on_board(); + } + } + + ccf::tasks::Resumable OrderedTasks::pause() + { + pimpl->actions.pause(); + + return std::make_unique(shared_from_this()); + } + + std::string_view OrderedTasks::get_name() const + { + return pimpl->name; + } + + void OrderedTasks::add_action(TaskAction&& action) + { + if (pimpl->actions.push(std::move(action))) + { + enqueue_on_board(); + } + } + + void OrderedTasks::get_queue_summary(size_t& num_pending, bool& is_active) + { + pimpl->actions.get_queue_summary(num_pending, is_active); + } + + std::shared_ptr OrderedTasks::create( + IJobBoard& job_board_, const std::string& name_) + { + return std::make_shared(Private{}, job_board_, name_); + } +} diff --git a/src/tasks/ordered_tasks.h b/src/tasks/ordered_tasks.h new file mode 100644 index 00000000000..77b4763bee9 --- /dev/null +++ b/src/tasks/ordered_tasks.h @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. +#pragma once + +#include "tasks/job_board.h" + +#include +#include + +namespace ccf::tasks +{ + struct ITaskAction + { + virtual ~ITaskAction() = default; + + virtual void do_action() = 0; + + virtual std::string_view get_name() const + { + return "[Anon]"; + } + }; + + using TaskAction = std::shared_ptr; + + struct BasicTaskAction : public ITaskAction + { + using Fn = std::function; + + Fn fn; + const std::string name; + + BasicTaskAction(const Fn& fn_, const std::string& name_ = "[Anon]") : + fn(fn_), + name(name_) + {} + + void do_action() override + { + fn(); + } + + std::string_view get_name() const override + { + return name; + } + }; + + template + TaskAction make_basic_action(Ts&&... ts) + { + return std::make_shared(std::forward(ts)...); + } + + // Self-scheduling collection of in-order tasks. Tasks + // will be executed in the order they are added. To self-schedule, this + // instance will ensure that it is posted to the given JobBoard whenever more + // sub-tasks are available for execution. 
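+  // Illustrative usage sketch (assumed names, not part of this change):
+  //
+  //   auto tasks = ccf::tasks::OrderedTasks::create(board, "[Demo]");
+  //   tasks->add_action(ccf::tasks::make_basic_action([] { step_one(); }));
+  //   tasks->add_action(ccf::tasks::make_basic_action([] { step_two(); }));
+  //   // A worker that pops this from the board runs step_one then step_two,
+  //   // regardless of how many workers are draining the same board.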
+ class OrderedTasks : public BaseTask, + public std::enable_shared_from_this + { + protected: + struct PImpl; + std::unique_ptr pimpl = nullptr; + + struct ResumeOrderedTasks; + + void enqueue_on_board(); + void do_task_implementation() override; + + // Non-public constructor argument type, so this can only be constructed by + // this class (ensuring shared ptr ownership) + struct Private + { + explicit Private() = default; + }; + + public: + OrderedTasks(Private, IJobBoard& job_board, const std::string& name); + ~OrderedTasks(); + + static std::shared_ptr create( + IJobBoard& job_board_, const std::string& name_ = "[Ordered]"); + + ccf::tasks::Resumable pause() override; + std::string_view get_name() const override; + + void add_action(TaskAction&& action); + + void get_queue_summary(size_t& num_pending, bool& is_active); + }; +} diff --git a/src/tasks/resumable.h b/src/tasks/resumable.h new file mode 100644 index 00000000000..071c598124b --- /dev/null +++ b/src/tasks/resumable.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. +#pragma once + +#include + +namespace ccf::tasks +{ + struct IResumable; + void resume_task(std::unique_ptr&& resumable); + + struct IResumable + { + private: + virtual void resume() = 0; + + public: + virtual ~IResumable() = default; + + friend void ccf::tasks::resume_task( + std::unique_ptr&& resumable); + }; + + using Resumable = std::unique_ptr; + + Resumable pause_current_task(); + void resume_task(Resumable&& resumable); +} \ No newline at end of file diff --git a/src/tasks/sub_task_queue.h b/src/tasks/sub_task_queue.h new file mode 100644 index 00000000000..ad76c014ae4 --- /dev/null +++ b/src/tasks/sub_task_queue.h @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. +#pragma once + +#include +#include +#include + +namespace ccf::tasks +{ + // Helper type for OrderedTasks, containing a list of sub-tasks to be + // performed in-order. Modifiers return bools indicating whether the caller + // is responsible for scheduling a future flush of this queue. 
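+  // Sketch of the intended calling pattern (illustrative; schedule_flush is
+  // an assumed caller-side helper, not part of this change):
+  //
+  //   SubTaskQueue<TaskAction> queue;
+  //   if (queue.push(std::move(action)))
+  //   {
+  //     schedule_flush(); // first pending item - caller schedules a flush
+  //   }
+  //   ...
+  //   if (queue.pop_and_visit([](TaskAction&& a) { a->do_action(); }))
+  //   {
+  //     schedule_flush(); // more work arrived while visiting - reschedule
+  //   }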
+ template + class SubTaskQueue + { + protected: + std::mutex pending_mutex; + std::deque pending; + std::atomic active; + std::atomic paused; + + public: + bool push(T&& t) + { + std::lock_guard lock(pending_mutex); + const bool ret = pending.empty() && !active.load(); + pending.emplace_back(std::forward(t)); + return ret; + } + + using Visitor = std::function; + bool pop_and_visit(Visitor&& visitor) + { + decltype(pending) local; + { + std::lock_guard lock(pending_mutex); + active.store(true); + + std::swap(local, pending); + } + + auto it = local.begin(); + while (!paused.load() && it != local.end()) + { + visitor(std::forward(*it)); + ++it; + } + + { + std::lock_guard lock(pending_mutex); + if (it != local.end()) + { + // Paused mid-execution - some actions remain that need to be + // spliced back onto the front of the pending pending + pending.insert(pending.begin(), it, local.end()); + } + + active.store(false); + return !pending.empty() && !paused.load(); + } + } + + void pause() + { + std::lock_guard lock(pending_mutex); + paused.store(true); + } + + bool unpause() + { + std::lock_guard lock(pending_mutex); + paused.store(false); + return !pending.empty() && !active.load(); + } + + void get_queue_summary(size_t& num_pending, bool& is_active) + { + std::lock_guard lock(pending_mutex); + num_pending = pending.size(); + is_active = active.load(); + } + }; +} diff --git a/src/tasks/task.h b/src/tasks/task.h new file mode 100644 index 00000000000..042a569f651 --- /dev/null +++ b/src/tasks/task.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. +#pragma once + +#include "tasks/resumable.h" + +#include +#include +#include +#include + +namespace ccf::tasks +{ + struct BaseTask + { + private: + std::atomic cancelled = false; + + friend Resumable ccf::tasks::pause_current_task(); + virtual ccf::tasks::Resumable pause(); + + protected: + virtual void do_task_implementation() = 0; + + public: + virtual ~BaseTask() = default; + + void do_task(); + + virtual std::string_view get_name() const + { + return "[Anon]"; + } + + void cancel_task(); + bool is_cancelled(); + }; + + using Task = std::shared_ptr; +} \ No newline at end of file diff --git a/src/tasks/task_system.cpp b/src/tasks/task_system.cpp new file mode 100644 index 00000000000..bdb05214573 --- /dev/null +++ b/src/tasks/task_system.cpp @@ -0,0 +1,174 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. 
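+// This translation unit implements the free functions declared in
+// task_system.h: a process-wide job board accessor (get_main_job_board),
+// add_task, and delayed/periodic tasks whose time base advances only via
+// explicit tick() calls.
+//
+// Illustrative sketch of the tick-driven flow (assumed names):
+//
+//   ccf::tasks::add_periodic_task(heartbeat, 10ms, 10ms);
+//   ccf::tasks::tick(10ms); // heartbeat is now enqueued on the job board
+//   // ...a worker then pops it from get_main_job_board() and runs it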
+ +#include "tasks/task_system.h" + +#include "ds/internal_logger.h" +#include "tasks/job_board.h" +#include "tasks/resumable.h" +#include "tasks/task.h" + +#include +#include + +namespace ccf::tasks +{ + // Implementation of BaseTask + namespace + { + thread_local BaseTask* current_task = nullptr; + } + + void BaseTask::do_task() + { + if (cancelled.load()) + { + return; + } + + ccf::tasks::current_task = this; + + do_task_implementation(); + + ccf::tasks::current_task = nullptr; + } + + ccf::tasks::Resumable BaseTask::pause() + { + return nullptr; + } + + void BaseTask::cancel_task() + { + cancelled.store(true); + } + + bool BaseTask::is_cancelled() + { + return cancelled.load(); + } + + // Implementation of ccf::tasks namespace static functions + IJobBoard& get_main_job_board() + { + static JobBoard main_job_board; + return main_job_board; + } + + void add_task(Task task) + { + get_main_job_board().add_task(std::move(task)); + } + + struct DelayedTask + { + Task task; + std::optional repeat = std::nullopt; + }; + + using DelayedTasks = std::vector; + + using DelayedTasksByTime = std::map; + + using namespace std::chrono_literals; + + namespace + { + std::atomic total_elapsed = 0ms; + + DelayedTasksByTime delayed_tasks; + std::mutex delayed_tasks_mutex; + } + + void add_delayed_task( + Task task, + std::chrono::milliseconds initial_delay, + std::optional periodic_delay) + { + std::lock_guard lock(delayed_tasks_mutex); + + const auto trigger_time = total_elapsed.load() + initial_delay; + delayed_tasks[trigger_time].emplace_back(task, periodic_delay); + } + + void add_delayed_task(Task task, std::chrono::milliseconds delay) + { + add_delayed_task(task, delay, std::nullopt); + } + + void add_periodic_task( + Task task, + std::chrono::milliseconds initial_delay, + std::chrono::milliseconds repeat_period) + { + add_delayed_task(task, initial_delay, repeat_period); + } + + void tick(std::chrono::milliseconds elapsed) + { + elapsed += total_elapsed.load(); + + { + std::lock_guard lock(delayed_tasks_mutex); + auto end_it = delayed_tasks.upper_bound(elapsed); + + DelayedTasksByTime repeats; + + for (auto it = delayed_tasks.begin(); it != end_it; ++it) + { + DelayedTasks& ready = it->second; + + for (DelayedTask& delayed_task : ready) + { + // Don't schedule (or repeat) cancelled tasks + if (delayed_task.task->is_cancelled()) + { + continue; + } + + add_task(delayed_task.task); + if (delayed_task.repeat.has_value()) + { + repeats[elapsed + delayed_task.repeat.value()].emplace_back( + delayed_task); + } + } + } + + delayed_tasks.erase(delayed_tasks.begin(), end_it); + + for (auto&& [repeat_time, repeated_tasks] : repeats) + { + DelayedTasks& delayed_tasks_at_time = delayed_tasks[repeat_time]; + delayed_tasks_at_time.insert( + delayed_tasks_at_time.end(), + repeated_tasks.begin(), + repeated_tasks.end()); + } + } + + total_elapsed.store(elapsed); + } + + // From resumable.h + Resumable pause_current_task() + { + if (current_task == nullptr) + { + throw std::logic_error("Cannot pause: No task currently running"); + } + + auto handle = current_task->pause(); + if (handle == nullptr) + { + throw std::logic_error("Cannot pause: Current task is not pausable"); + } + + return handle; + } + + void resume_task(Resumable&& resumable) + { + resumable->resume(); + } +} \ No newline at end of file diff --git a/src/tasks/task_system.h b/src/tasks/task_system.h new file mode 100644 index 00000000000..3a215b2adde --- /dev/null +++ b/src/tasks/task_system.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the Apache 2.0 License. +#pragma once + +#include "tasks/job_board_interface.h" +#include "tasks/resumable.h" +#include "tasks/task.h" + +namespace ccf::tasks +{ + IJobBoard& get_main_job_board(); + + void add_task(Task task); + + void add_delayed_task(Task task, std::chrono::milliseconds delay); + + void add_periodic_task( + Task task, + std::chrono::milliseconds initial_delay, + std::chrono::milliseconds repeat_period); + + void tick(std::chrono::milliseconds elapsed); +} \ No newline at end of file diff --git a/src/tasks/test/basic_tasks.cpp b/src/tasks/test/basic_tasks.cpp new file mode 100644 index 00000000000..6a10b3d0125 --- /dev/null +++ b/src/tasks/test/basic_tasks.cpp @@ -0,0 +1,241 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. + +#include "tasks/basic_task.h" +#include "tasks/task_system.h" + +#include +#include +#include +#include + +#define FMT_HEADER_ONLY +#include +#include +#include + +TEST_CASE("TaskSystem" * doctest::test_suite("basic_tasks")) +{ + constexpr auto short_wait = std::chrono::milliseconds(10); + + // There's a global singleton job board, initially empty + auto& job_board = ccf::tasks::get_main_job_board(); + + REQUIRE(job_board.empty()); + REQUIRE(job_board.get_task() == nullptr); + REQUIRE(job_board.wait_for_task(short_wait) == nullptr); + + // Encapsulate the work to be done in Tasks + // Either as a lambda passed to make_basic_task + std::atomic a = false; + ccf::tasks::Task toggle_a = + ccf::tasks::make_basic_task([&a]() { a.store(true); }); + + // Or by extending BaseTask + struct SetAtomic : public ccf::tasks::BaseTask + { + std::atomic& my_var; + + SetAtomic(std::atomic& v) : my_var(v) {} + + void do_task_implementation() override + { + my_var.store(true); + } + + std::string_view get_name() const override + { + return "SetAtomic Task"; + } + }; + + std::atomic b = false; + ccf::tasks::Task toggle_b = std::make_shared(b); + + // These tasks aren't scheduled yet, and can't have been executed! + REQUIRE(job_board.empty()); + REQUIRE_FALSE(a.load()); + REQUIRE_FALSE(b.load()); + + // Queue them on a job board, where a worker can find them + ccf::tasks::add_task(toggle_a); + ccf::tasks::add_task(toggle_b); + + // Now there's something scheduled + REQUIRE_FALSE(job_board.empty()); + + // But it's not _executed_ yet + REQUIRE_FALSE(a.load()); + REQUIRE_FALSE(b.load()); + + // Eventually something like a dedicated worker thread arrives, and asks for a + // task + auto first_task = job_board.get_task(); + + // They likely take things one-at-a-time, so there's still something scheduled + REQUIRE_FALSE(job_board.empty()); + + // Not a critical guarantee, but for now the job board is FIFO, so in this + // constrained example we know exactly what the task is + REQUIRE(first_task == toggle_a); + + // This caller has taken ownership of this task, and is now responsible for + // executing it + REQUIRE_FALSE(a.load()); + first_task->do_task(); + REQUIRE(a.load()); + + // Then someone, maybe the same worker, arrives and takes the second task + auto second_task = job_board.get_task(); + REQUIRE(second_task == toggle_b); + REQUIRE(job_board.empty()); + + REQUIRE_FALSE(b.load()); + second_task->do_task(); + REQUIRE(b.load()); +} + +TEST_CASE("Cancellation" * doctest::test_suite("basic_tasks")) +{ + // If you keep a handle to a task, you can cancel it... 
+ std::atomic a = false; + ccf::tasks::Task toggle_a = + ccf::tasks::make_basic_task([&a]() { a.store(true); }); + + // ... even after it has been scheduled + ccf::tasks::add_task(toggle_a); + + // ... at any point until some worker calls do_task + auto first_task = ccf::tasks::get_main_job_board().get_task(); + REQUIRE(first_task != nullptr); + + REQUIRE_FALSE(a.load()); + toggle_a->cancel_task(); + first_task->do_task(); + REQUIRE_FALSE(a.load()); +} + +TEST_CASE("Scheduling" * doctest::test_suite("basic_tasks")) +{ + // Tasks can be scheduled from anywhere, including during execution of + // other tasks + struct WaitPoint + { + std::atomic passed{false}; + + void wait() + { + while (!passed.load()) + { + std::this_thread::yield(); + } + } + + void notify() + { + passed.store(true); + } + }; + + WaitPoint a_started; + WaitPoint b_started; + WaitPoint task_0_started; + WaitPoint task_1_started; + WaitPoint task_2_started; + WaitPoint task_3_started; + WaitPoint task_4_started; + WaitPoint task_5_started; + + std::atomic stop_signal = false; + std::vector count_with_me; + + std::thread thread_a([&]() { + count_with_me.push_back(0); + a_started.notify(); + + ccf::tasks::add_task(ccf::tasks::make_basic_task([&]() { + task_1_started.wait(); + count_with_me.push_back(2); + task_2_started.notify(); + + ccf::tasks::add_task(ccf::tasks::make_basic_task([&]() { + task_3_started.wait(); + count_with_me.push_back(4); + task_4_started.notify(); + + ccf::tasks::add_task(ccf::tasks::make_basic_task([&]() { + task_5_started.wait(); + count_with_me.push_back(6); + stop_signal.store(true); + })); + })); + })); + }); + + std::thread thread_b([&]() { + a_started.wait(); + + ccf::tasks::add_task(ccf::tasks::make_basic_task([&]() { + count_with_me.push_back(1); + task_1_started.notify(); + + ccf::tasks::add_task(ccf::tasks::make_basic_task([&]() { + task_4_started.wait(); + count_with_me.push_back(5); + task_5_started.notify(); + })); + + ccf::tasks::add_task(ccf::tasks::make_basic_task([&]() { + task_2_started.wait(); + count_with_me.push_back(3); + task_3_started.notify(); + })); + })); + }); + + auto worker_fn = [&]() { + while (!stop_signal.load()) + { + auto task = ccf::tasks::get_main_job_board().wait_for_task( + std::chrono::milliseconds(100)); + if (task != nullptr) + { + task->do_task(); + } + } + }; + + std::vector workers; + + // Potentially 3 parallel jobs => need at least 3 workers + for (size_t i = 0; i < 3; ++i) + { + workers.emplace_back(worker_fn); + } + + std::thread watchdog([&]() { + using Clock = std::chrono::steady_clock; + auto start = Clock::now(); + while (!stop_signal.load()) + { + auto now = Clock::now(); + auto elapsed = now - start; + REQUIRE(elapsed < std::chrono::seconds(1)); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + }); + + for (auto& worker : workers) + { + worker.join(); + } + + thread_a.join(); + thread_b.join(); + + watchdog.join(); + + decltype(count_with_me) target(7); + std::iota(target.begin(), target.end(), 0); + REQUIRE(count_with_me == target); +} diff --git a/src/tasks/test/bench/contention_bench.cpp b/src/tasks/test/bench/contention_bench.cpp new file mode 100644 index 00000000000..82315a8b26e --- /dev/null +++ b/src/tasks/test/bench/contention_bench.cpp @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. 
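+// Benchmarks contended access to the global job board: the "contended
+// enqueue" suite has N threads adding no-op tasks concurrently, while
+// "contended dequeue" pre-fills the board with increment tasks and spawns N
+// threads to pop from it.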
+ +#include "tasks/basic_task.h" +#include "tasks/task_system.h" + +#include +#include + +#define PICOBENCH_DONT_BIND_TO_ONE_CORE +#include + +struct NopTask : public ccf::tasks::BaseTask +{ + void do_task_implementation() override {} + + std::string_view get_name() const override + { + return "NopTask"; + } +}; + +void enqueue_many(picobench::state& s, size_t thread_count, size_t task_count) +{ + s.start_timer(); + std::vector threads; + for (auto i = 0; i < thread_count; ++i) + { + threads.emplace_back([task_count]() { + for (auto j = 0; j < task_count; ++j) + { + ccf::tasks::add_task(std::make_shared()); + std::this_thread::yield(); + } + }); + } + + for (auto& thread : threads) + { + thread.join(); + } + s.stop_timer(); +} + +template +static void benchmark_enqueue(picobench::state& s) +{ + enqueue_many(s, num_threads, s.iterations()); +} + +struct IncTask : public ccf::tasks::BaseTask +{ + std::atomic& value; + + IncTask(std::atomic& v) : value(v) {} + + void do_task_implementation() override + { + ++value; + } + + std::string_view get_name() const override + { + return "IncTask"; + } +}; + +void dequeue_many(picobench::state& s, size_t thread_count, size_t task_count) +{ + std::atomic tasks_done = 0; + for (auto j = 0; j < task_count; ++j) + { + ccf::tasks::add_task(std::make_shared(tasks_done)); + std::this_thread::yield(); + } + + s.start_timer(); + std::vector threads; + for (auto i = 0; i < thread_count; ++i) + { + threads.emplace_back([task_count, &tasks_done]() { + if (tasks_done.load() < task_count) + { + auto task = ccf::tasks::get_main_job_board().get_task(); + if (task != nullptr) + { + task->do_task(); + } + std::this_thread::yield(); + } + }); + } + + for (auto& thread : threads) + { + thread.join(); + } + s.stop_timer(); +} + +template +static void benchmark_dequeue(picobench::state& s) +{ + dequeue_many(s, num_threads, s.iterations()); +} + +const std::vector task_counts{32'000, 64'000}; + +namespace +{ + auto enq_1 = benchmark_enqueue<1>; + auto enq_2 = benchmark_enqueue<2>; + auto enq_4 = benchmark_enqueue<4>; + auto enq_8 = benchmark_enqueue<8>; + auto enq_16 = benchmark_enqueue<16>; + auto enq_32 = benchmark_enqueue<32>; + + PICOBENCH_SUITE("contended enqueue"); + PICOBENCH(enq_1).iterations(task_counts).baseline(); + PICOBENCH(enq_2).iterations(task_counts); + PICOBENCH(enq_4).iterations(task_counts); + PICOBENCH(enq_8).iterations(task_counts); + PICOBENCH(enq_16).iterations(task_counts); + PICOBENCH(enq_32).iterations(task_counts); +} + +namespace +{ + auto deq_1 = benchmark_dequeue<1>; + auto deq_2 = benchmark_dequeue<2>; + auto deq_4 = benchmark_dequeue<4>; + auto deq_8 = benchmark_dequeue<8>; + auto deq_16 = benchmark_dequeue<16>; + auto deq_32 = benchmark_dequeue<32>; + + PICOBENCH_SUITE("contended dequeue"); + PICOBENCH(deq_1).iterations(task_counts).baseline(); + PICOBENCH(deq_2).iterations(task_counts); + PICOBENCH(deq_4).iterations(task_counts); + PICOBENCH(deq_8).iterations(task_counts); + PICOBENCH(deq_16).iterations(task_counts); + PICOBENCH(deq_32).iterations(task_counts); +} diff --git a/src/tasks/test/bench/flush_all_tasks.h b/src/tasks/test/bench/flush_all_tasks.h new file mode 100644 index 00000000000..647676a38b5 --- /dev/null +++ b/src/tasks/test/bench/flush_all_tasks.h @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. 
+#pragma once + +#include "tasks/task_system.h" + +static inline void flush_all_tasks( + std::atomic& stop_signal, + size_t worker_count, + std::chrono::seconds kill_after = std::chrono::seconds(5)) +{ + std::vector workers; + for (size_t i = 0; i < worker_count; ++i) + { + workers.emplace_back([&stop_signal]() { + while (!stop_signal.load()) + { + auto task = ccf::tasks::get_main_job_board().get_task(); + if (task != nullptr) + { + task->do_task(); + } + std::this_thread::yield(); + } + }); + } + + using TClock = std::chrono::steady_clock; + auto now = TClock::now(); + + const auto hard_end = now + kill_after; + + while (true) + { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + now = TClock::now(); + if (now > hard_end) + { + break; + } + + if (stop_signal.load()) + { + break; + } + } + + stop_signal.store(true); + + for (auto& worker : workers) + { + worker.join(); + } +} \ No newline at end of file diff --git a/src/tasks/test/bench/merge_bench.cpp b/src/tasks/test/bench/merge_bench.cpp new file mode 100644 index 00000000000..129906db2f6 --- /dev/null +++ b/src/tasks/test/bench/merge_bench.cpp @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. + +#include "./flush_all_tasks.h" +#include "./merge_sort.h" + +#include + +#define PICOBENCH_DONT_BIND_TO_ONE_CORE +#define PICOBENCH_IMPLEMENT_WITH_MAIN +#include + +#define FMT_HEADER_ONLY +#include +#include + +static inline std::vector get_merge_sort_data(size_t n) +{ + static std::random_device rd; + static std::mt19937 g(rd()); + + std::vector data(n); + for (auto& x : data) + { + x = rand(); + } + + return data; +} + +void do_merge_sort(picobench::state& s, size_t worker_count, size_t data_size) +{ + auto ns = get_merge_sort_data(data_size); + if (std::is_sorted(ns.begin(), ns.end())) + { + throw std::logic_error("Initial data already sorted"); + } + + std::atomic stop_signal{false}; + + ccf::tasks::add_task( + std::make_shared(ns.begin(), ns.end(), stop_signal)); + + s.start_timer(); + flush_all_tasks(stop_signal, worker_count); + s.stop_timer(); + + if (!std::is_sorted(ns.begin(), ns.end())) + { + throw std::logic_error("Final data not sorted"); + } +} + +template +static void benchmark_mergesort(picobench::state& s) +{ + do_merge_sort(s, num_threads, s.iterations()); +} + +namespace +{ + const std::vector data_sizes{1'000, 1'000'000}; + + auto threads_1 = benchmark_mergesort<1>; + auto threads_2 = benchmark_mergesort<2>; + auto threads_3 = benchmark_mergesort<3>; + auto threads_4 = benchmark_mergesort<4>; + auto threads_5 = benchmark_mergesort<5>; + auto threads_6 = benchmark_mergesort<6>; + auto threads_7 = benchmark_mergesort<7>; + auto threads_8 = benchmark_mergesort<8>; + + PICOBENCH_SUITE("merge sort"); + PICOBENCH(threads_1).iterations(data_sizes).baseline(); + PICOBENCH(threads_2).iterations(data_sizes); + PICOBENCH(threads_3).iterations(data_sizes); + PICOBENCH(threads_4).iterations(data_sizes); + PICOBENCH(threads_5).iterations(data_sizes); + PICOBENCH(threads_6).iterations(data_sizes); + PICOBENCH(threads_7).iterations(data_sizes); + PICOBENCH(threads_8).iterations(data_sizes); +} diff --git a/src/tasks/test/bench/merge_sort.h b/src/tasks/test/bench/merge_sort.h new file mode 100644 index 00000000000..b635ed19033 --- /dev/null +++ b/src/tasks/test/bench/merge_sort.h @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. 
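+// Benchmark helper: a fork/join merge sort built from tasks. Ranges at or
+// above sort_threshold fork two child MergeSortTasks; smaller ranges are
+// sorted directly. When both children of a parent complete, the parent merges
+// (via std::sort over its range); the root task sets stop_signal when done.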
+#pragma once + +#include "tasks/task_system.h" + +#include + +struct MergeSortTask : public ccf::tasks::BaseTask, + public std::enable_shared_from_this +{ + // How many items will we actually directly sort, vs forking 2 new tasks to + // sub-sort + static constexpr size_t sort_threshold = 50; + + using Iterator = std::vector::iterator; + + Iterator begin; + Iterator end; + std::atomic& stop_signal; + std::shared_ptr parent; + std::atomic sub_tasks; + + MergeSortTask( + Iterator b, + Iterator e, + std::atomic& ss, + std::shared_ptr p = nullptr) : + begin(b), + end(e), + parent(p), + stop_signal(ss) + {} + + void merge() + { + std::sort(begin, end); + + if (parent != nullptr) + { + if (--parent->sub_tasks == 0) + { + parent->merge(); + } + } + else + { + stop_signal.store(true); + } + } + + void do_task_implementation() override + { + const auto dist = std::distance(begin, end); + if (dist >= sort_threshold) + { + sub_tasks.store(2); + + auto self = shared_from_this(); + + auto mid_point = begin + (dist / 2); + + ccf::tasks::add_task( + std::make_shared(begin, mid_point, stop_signal, self)); + ccf::tasks::add_task( + std::make_shared(mid_point, end, stop_signal, self)); + } + else + { + merge(); + } + } +}; diff --git a/src/tasks/test/bench/sleep_bench.cpp b/src/tasks/test/bench/sleep_bench.cpp new file mode 100644 index 00000000000..966f7100cbc --- /dev/null +++ b/src/tasks/test/bench/sleep_bench.cpp @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. + +#include "./flush_all_tasks.h" +#include "tasks/basic_task.h" + +#include + +#define PICOBENCH_DONT_BIND_TO_ONE_CORE +#include + +#define FMT_HEADER_ONLY +#include +#include + +struct TrueSleep +{ + static void sleep_for(std::chrono::milliseconds duration) + { + std::this_thread::sleep_for(duration); + } +}; + +struct SpinLoop +{ + static void sleep_for(std::chrono::milliseconds duration) + { + std::chrono::steady_clock clock; + auto start = clock.now(); + auto end = start + duration; + + while (clock.now() < end) + { + std::this_thread::yield(); + } + } +}; + +template +void sleep_with_many_workers( + picobench::state& s, size_t worker_count, size_t num_sleeps) +{ + std::atomic stop_signal{false}; + + for (auto i = 0; i < num_sleeps; ++i) + { + ccf::tasks::add_task(ccf::tasks::make_basic_task( + []() { SleepImpl::sleep_for(std::chrono::milliseconds(1)); })); + } + + ccf::tasks::add_task( + ccf::tasks::make_basic_task([&]() { stop_signal.store(true); })); + + s.start_timer(); + flush_all_tasks(stop_signal, worker_count); + s.stop_timer(); +} + +template +static void benchmark_sleeps(picobench::state& s) +{ + sleep_with_many_workers(s, num_threads, s.iterations()); +} + +namespace +{ + const std::vector num_sleeps{100, 1000}; + + auto threads_1 = benchmark_sleeps; + auto threads_2 = benchmark_sleeps; + auto threads_3 = benchmark_sleeps; + auto threads_4 = benchmark_sleeps; + auto threads_5 = benchmark_sleeps; + auto threads_6 = benchmark_sleeps; + auto threads_7 = benchmark_sleeps; + auto threads_8 = benchmark_sleeps; + + PICOBENCH_SUITE("sleeps"); + PICOBENCH(threads_1).iterations(num_sleeps).baseline(); + PICOBENCH(threads_2).iterations(num_sleeps); + PICOBENCH(threads_3).iterations(num_sleeps); + PICOBENCH(threads_4).iterations(num_sleeps); + PICOBENCH(threads_5).iterations(num_sleeps); + PICOBENCH(threads_6).iterations(num_sleeps); + PICOBENCH(threads_7).iterations(num_sleeps); + PICOBENCH(threads_8).iterations(num_sleeps); + + auto threads_1_spin = benchmark_sleeps; 
+ auto threads_2_spin = benchmark_sleeps; + auto threads_3_spin = benchmark_sleeps; + auto threads_4_spin = benchmark_sleeps; + auto threads_5_spin = benchmark_sleeps; + auto threads_6_spin = benchmark_sleeps; + auto threads_7_spin = benchmark_sleeps; + auto threads_8_spin = benchmark_sleeps; + + PICOBENCH_SUITE("spins"); + PICOBENCH(threads_1_spin).iterations(num_sleeps).baseline(); + PICOBENCH(threads_2_spin).iterations(num_sleeps); + PICOBENCH(threads_3_spin).iterations(num_sleeps); + PICOBENCH(threads_4_spin).iterations(num_sleeps); + PICOBENCH(threads_5_spin).iterations(num_sleeps); + PICOBENCH(threads_6_spin).iterations(num_sleeps); + PICOBENCH(threads_7_spin).iterations(num_sleeps); + PICOBENCH(threads_8_spin).iterations(num_sleeps); +} diff --git a/src/tasks/test/delayed_tasks.cpp b/src/tasks/test/delayed_tasks.cpp new file mode 100644 index 00000000000..31b7d6dbf6f --- /dev/null +++ b/src/tasks/test/delayed_tasks.cpp @@ -0,0 +1,229 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. + +#include "ds/internal_logger.h" +#include "tasks/basic_task.h" +#include "tasks/task_system.h" + +#include + +namespace +{ + struct FakeTime + { + const std::chrono::milliseconds polling_period{1}; + + void sleep_for(size_t workers, std::chrono::milliseconds duration) + { + std::chrono::milliseconds elapsed{0}; + + auto& job_board = ccf::tasks::get_main_job_board(); + + while (elapsed < duration) + { + ccf::tasks::tick(polling_period); + + size_t worker_idx = 0; + while (worker_idx < workers) + { + auto task = job_board.get_task(); + if (task != nullptr) + { + task->do_task(); + ++worker_idx; + } + else + { + break; + } + } + + elapsed += polling_period; + } + } + }; +} + +TEST_CASE("DelayedTasks" * doctest::test_suite("delayed_tasks")) +{ + FakeTime fake_time; + + std::atomic n = 0; + ccf::tasks::Task incrementer = + ccf::tasks::make_basic_task([&n]() { ++n; }, "incrementer"); + + ccf::tasks::add_task(incrementer); + // Task is not done when no workers are present + REQUIRE(n.load() == 0); + + { + fake_time.sleep_for(1, fake_time.polling_period * 2); + REQUIRE(n.load() == 1); + } + + std::chrono::milliseconds delay = std::chrono::milliseconds(50); + ccf::tasks::add_delayed_task(incrementer, delay); + // Delayed task is not done when no workers are present + REQUIRE(n.load() == 1); + // Even after waiting for delay + fake_time.sleep_for(0, delay * 2); + REQUIRE(n.load() == 1); + + { + // Delayed task is executed when worker thread arrives + fake_time.sleep_for(1, delay * 2); + REQUIRE(n.load() == 2); + // Task is only executed once + fake_time.sleep_for(1, delay * 2); + REQUIRE(n.load() == 2); + } + + ccf::tasks::add_periodic_task(incrementer, delay, delay); + // Periodic task is not done when no workers are present + REQUIRE(n.load() == 2); + // Even after waiting for delay + fake_time.sleep_for(0, delay * 2); + REQUIRE(n.load() == 2); + + { + // Periodic task is executed when worker thread arrives + fake_time.sleep_for(1, delay * 2); + const auto a = n.load(); + REQUIRE(a > 2); + + // Periodic task is executed multiple times + fake_time.sleep_for(1, delay * 2); + const auto b = n.load(); + REQUIRE(b > a); + + // Periodic task is cancellable + incrementer->cancel_task(); + + fake_time.sleep_for(1, delay * 2); + const auto c = n.load(); + REQUIRE(c >= b); + + fake_time.sleep_for(1, delay * 2); + const auto d = n.load(); + REQUIRE(d == c); + } +} + +void do_all_tasks() +{ + auto& job_board = ccf::tasks::get_main_job_board(); + auto task = 
job_board.get_task(); + while (task != nullptr) + { + task->do_task(); + task = job_board.get_task(); + } +} + +TEST_CASE("ExplicitTicks" * doctest::test_suite("delayed_tasks")) +{ + std::atomic a = false; + std::atomic b = false; + std::atomic c = false; + + auto set_a = ccf::tasks::make_basic_task([&a]() { a.store(true); }); + auto set_b = ccf::tasks::make_basic_task([&b]() { b.store(true); }); + auto set_c = ccf::tasks::make_basic_task([&c]() { c.store(true); }); + + using namespace std::chrono_literals; + ccf::tasks::add_periodic_task(set_a, 5ms, 5ms); + ccf::tasks::add_periodic_task(set_b, 7ms, 8ms); + ccf::tasks::add_delayed_task(set_c, 20ms); + auto do_all_check_and_reset = [&a, &b, &c]( + std::string_view label, + bool expected_a, + bool expected_b, + bool expected_c) { + DOCTEST_INFO(label); + do_all_tasks(); + + REQUIRE(a == expected_a); + REQUIRE(b == expected_b); + REQUIRE(c == expected_c); + + a.store(false); + b.store(false); + c.store(false); + }; + + do_all_check_and_reset("0ms", false, false, false); + + ccf::tasks::tick(1ms); + do_all_check_and_reset("1ms", false, false, false); + + ccf::tasks::tick(3ms); + do_all_check_and_reset("4ms", false, false, false); + + ccf::tasks::tick(1ms); + // First set_a is enqueued, but not yet run + REQUIRE(a == false); + do_all_check_and_reset("5ms", true, false, false); // First set_a + do_all_check_and_reset("5ms (after reset)", false, false, false); + + ccf::tasks::tick(1ms); + do_all_check_and_reset("6ms", false, false, false); + + ccf::tasks::tick(1ms); + do_all_check_and_reset("7ms", false, true, false); // First set_b + + ccf::tasks::tick(2ms); + do_all_check_and_reset("9ms", false, false, false); + + ccf::tasks::tick(1ms); + do_all_check_and_reset("10ms", true, false, false); // Second set_a + + ccf::tasks::tick(4ms); + do_all_check_and_reset("14ms", false, false, false); // Second set_a + + ccf::tasks::tick(1ms); + do_all_check_and_reset("15ms", true, true, false); // set_a and set_b + + ccf::tasks::tick(4ms); + do_all_check_and_reset("19ms", false, false, false); + + ccf::tasks::tick(1ms); + do_all_check_and_reset("20ms", true, false, true); // set_a and set_c + + ccf::tasks::tick(6ms); + do_all_check_and_reset("26ms", true, true, false); // set_a@25, set_b@23 + + // Repeats do not correct for large ticks, they just add the repeat value to + // the current elapsed. + // Next set_a is now at 26 + 5 = 31 (NOT 25 + 5 = 30) + // Next set_b is now at 26 + 8 = 34 (NOT 23 + 8 = 31) + + ccf::tasks::tick(4ms); + do_all_check_and_reset("30ms", false, false, false); + + ccf::tasks::tick(1ms); + do_all_check_and_reset("31ms", true, false, false); + + ccf::tasks::tick(3ms); + do_all_check_and_reset("34ms", false, true, false); +} + +TEST_CASE("TickEnqueue" * doctest::test_suite("delayed_tasks")) +{ + INFO( + "Each tick will only trigger a single instance of a task, even if multiple " + "periods have elapsed"); + + std::atomic n = 0; + + auto incrementer = ccf::tasks::make_basic_task([&n]() { ++n; }); + + using namespace std::chrono_literals; + ccf::tasks::add_periodic_task(incrementer, 1ms, 1ms); + + REQUIRE(n.load() == 0); + ccf::tasks::tick(100ms); + do_all_tasks(); + REQUIRE(n.load() == 1); + do_all_tasks(); + REQUIRE(n.load() == 1); +} diff --git a/src/tasks/test/demo/actions.h b/src/tasks/test/demo/actions.h new file mode 100644 index 00000000000..c185093ff7e --- /dev/null +++ b/src/tasks/test/demo/actions.h @@ -0,0 +1,168 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. 
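+// Demo actions exchanged between clients and the node as '|'-delimited
+// strings: requests look like "<id>|SIGN|<hex-to-be-signed>", responses like
+// "<id>|<public-key-pem>|<hex-signature>" or "<id>|FAILED|<reason>".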
+#pragma once + +#include "ccf/crypto/key_pair.h" +#include "ccf/ds/hex.h" +#include "ccf/ds/nonstd.h" +#include "ds/internal_logger.h" + +#include +#include +#define DOCTEST_CONFIG_IMPLEMENT +#include +#define FMT_HEADER_ONLY +#include +#include +#include +#include +#include + +using SerialisedAction = std::string; +using SerialisedResponse = std::string; + +size_t id_from_string(const std::string_view& sv) +{ + size_t n_id; + const auto [p, ec] = std::from_chars(sv.begin(), sv.end(), n_id); + REQUIRE(ec == std::errc()); + return n_id; +} + +struct IAction +{ + virtual ~IAction() = default; + + virtual SerialisedAction serialise() const = 0; + virtual void verify_serialised_response( + SerialisedResponse& response) const = 0; + + virtual SerialisedResponse do_action() const = 0; +}; + +using ActionPtr = std::unique_ptr; + +static std::atomic action_id_generator = 0; +struct OrderedAction : public IAction +{ + const size_t id; + + OrderedAction() : id(++action_id_generator) {} + OrderedAction(size_t _id) : id(_id) {} + + SerialisedAction serialise() const override + { + return fmt::format("{}|", id); + } + + void verify_serialised_response(SerialisedResponse& response) const override + { + auto [s_id, remainder] = ccf::nonstd::split_1(response, "|"); + size_t n_id = id_from_string(s_id); + REQUIRE(id == n_id); + + response = remainder; + } + + SerialisedResponse do_action() const override + { + return fmt::format("{}|", id); + } +}; + +struct SignAction : public OrderedAction +{ + const std::vector tbs; + + static std::vector generate_random_data() + { + auto len = rand() % 100; + std::vector data(len); + for (auto& n : data) + { + n = rand(); + } + return data; + } + + SignAction() : OrderedAction(), tbs(generate_random_data()) + { + LOG_DEBUG_FMT("Created a new SignAction id={}", id); + } + SignAction(size_t _id, const std::vector& _tbs) : + OrderedAction(_id), + tbs(_tbs) + {} + + SerialisedAction serialise() const override + { + return fmt::format( + "{}SIGN|{}", OrderedAction::serialise(), ccf::ds::to_hex(tbs)); + } + + void verify_serialised_response(SerialisedResponse& response) const override + { + LOG_DEBUG_FMT("Verifying a signature, for action id={}", id); + OrderedAction::verify_serialised_response(response); + + auto [a, b] = ccf::nonstd::split_1(response, "|"); + + if (a == "FAILED") + { + // auto reason = b; + } + else + { + auto key_s = a; + auto signature_s = b; + + ccf::crypto::Pem pem{std::string(key_s)}; + auto pubk = ccf::crypto::make_public_key(pem); + + auto signature = ccf::ds::from_hex(std::string(signature_s)); + REQUIRE(pubk->verify(tbs, signature)); + } + } + + SerialisedResponse do_action() const override + { + LOG_DEBUG_FMT("Signing something a client gave me, id={}", id); + + // Randomly fail some small fraction of requests + if (rand() % 50 == 0) + { + return fmt::format( + "{}FAILED|Randomly unlucky", OrderedAction::do_action()); + } + else + { + auto key_pair = ccf::crypto::make_key_pair(); + auto signature = key_pair->sign(tbs); + return fmt::format( + "{}{}|{}", + OrderedAction::do_action(), + key_pair->public_key_pem().str(), + ccf::ds::to_hex(signature)); + } + } +}; + +ActionPtr deserialise_action(const SerialisedAction& ser) +{ + const auto components = ccf::nonstd::split(ser, "|"); + + REQUIRE(components.size() >= 1); + + const auto id = id_from_string(components[0]); + + if (components.size() == 3) + { + if (components[1] == "SIGN") + { + const auto tbs = ccf::ds::from_hex(std::string(components[2])); + return std::make_unique(id, tbs); + } + } + + 
throw std::logic_error(fmt::format("Unknown action: {}", ser)); +} \ No newline at end of file diff --git a/src/tasks/test/demo/clients.h b/src/tasks/test/demo/clients.h new file mode 100644 index 00000000000..619ebf34fe6 --- /dev/null +++ b/src/tasks/test/demo/clients.h @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. +#pragma once + +#include "./actions.h" +#include "./looping_thread.h" +#include "./session.h" + +#include +#include +#include +#include +#include +#include + +struct ClientParams +{ + std::chrono::milliseconds submission_duration = + std::chrono::milliseconds(1000); + + std::function submission_delay = []() { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + }; + + std::function generate_next_action = []() { + return std::make_unique(); + }; +}; + +struct ClientState +{ + Session& session; + const ClientParams& params; + + std::queue pending_actions; + + using TClock = std::chrono::system_clock; + TClock::time_point submission_end; + + size_t requests_sent; + size_t responses_seen; + + bool terminated_early = false; +}; + +struct Client : public LoopingThread +{ + Client(Session& _session, const ClientParams& _params, size_t idx) : + LoopingThread(fmt::format("c{}", idx), _session, _params) + {} + + ~Client() override + { + shutdown(); + + LOG_INFO_FMT( + "Shutting down {}, sent {} requests and saw {} responses", + name, + state.requests_sent, + state.responses_seen); + + if (!state.terminated_early) + { + REQUIRE(state.requests_sent == state.responses_seen); + } + } + + void init_behaviour() override + { + const auto start = State::TClock::now(); + state.submission_end = start + state.params.submission_duration; + } + + Stage loop_behaviour() override + { + if (rand() % 4000 == 0) + { + state.terminated_early = true; + state.session.abandoned.store(true); + return Stage::Terminated; + } + + const bool still_submitting = State::TClock::now() < state.submission_end; + if (still_submitting) + { + // Generate and submit new work + auto action = state.params.generate_next_action(); + state.session.to_node.emplace_back(action->serialise()); + state.pending_actions.push(std::move(action)); + ++state.requests_sent; + LOG_DEBUG_FMT("Pushed a pending action"); + } + + // If we have any responses + auto response = state.session.from_node.try_pop(); + while (response.has_value()) + { + // Verification is expensive, so we end up spending a long tail time in + // this test verifying every response (longer than we spent doing real + // work). Mitigate this by only checking some responses, randomly + // determined, estimating how far 'behind' we are (and thus how likely we + // should be to skip verification) by the length of pending messages. 
+ const auto n = rand() % 100; + if (n >= state.pending_actions.size() || n == 0) + { + // Verify (check that the first response matches the first pending + // action) + REQUIRE(!state.pending_actions.empty()); + state.pending_actions.front()->verify_serialised_response( + response.value()); + } + + state.pending_actions.pop(); + ++state.responses_seen; + + // ...and check for further responses + response = state.session.from_node.try_pop(); + } + + // End loop if this client has submitted and verified everything + if (still_submitting) + { + return Stage::Running; + } + else + { + if (state.pending_actions.empty()) + { + return Stage::Terminated; + } + else + { + return Stage::ShuttingDown; + } + } + } + + Stage idle_behaviour() override + { + state.params.submission_delay(); + return lifetime_stage.load(); + } +}; diff --git a/src/tasks/test/demo/dispatcher.h b/src/tasks/test/demo/dispatcher.h new file mode 100644 index 00000000000..11c8a8528c4 --- /dev/null +++ b/src/tasks/test/demo/dispatcher.h @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. +#pragma once + +#include "./actions.h" +#include "./looping_thread.h" +#include "tasks/job_board.h" +#include "tasks/ordered_tasks.h" +#include "tasks/task_system.h" + +#include + +struct Action_ProcessClientAction : public ccf::tasks::ITaskAction +{ + const SerialisedAction input_action; + Session& client_session; + std::atomic& responses_sent; + const std::string name; + + Action_ProcessClientAction( + const SerialisedAction& action, Session& cs, std::atomic& rs) : + input_action(action), + client_session(cs), + responses_sent(rs), + name(fmt::format( + "Processing action '{}' from session {}", + input_action, + (void*)&client_session)) + {} + + void do_action() override + { + auto received_action = deserialise_action(input_action); + auto result = received_action->do_action(); + + if (rand() % 50 == 0) + { + auto paused_task = ccf::tasks::pause_current_task(); + + // Rough hack to simulate "something async" happening + auto _ = std::async( + std::launch::async, + [paused_task = std::move(paused_task), + result = std::move(result), + &client_session = client_session, + &responses_sent = responses_sent]() mutable { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + client_session.from_node.push_back(std::move(result)); + ++responses_sent; + ccf::tasks::resume_task(std::move(paused_task)); + }); + } + else + { + client_session.from_node.push_back(std::move(result)); + ++responses_sent; + } + } + + std::string_view get_name() const override + { + return name; + } +}; + +struct DispatcherState +{ + ccf::tasks::IJobBoard& job_board; + SessionManager& session_manager; + std::atomic& responses_sent; + + std::unordered_map> + ordered_tasks_per_client; + + std::atomic consider_termination = false; +}; + +struct Dispatcher : public LoopingThread +{ + Dispatcher( + ccf::tasks::IJobBoard& jb, + SessionManager& sm, + std::atomic& response_count) : + LoopingThread(fmt::format("dsp"), jb, sm, response_count) + {} + + ~Dispatcher() override + { + shutdown(); + } + + Stage loop_behaviour() override + { + // Handle incoming IO, producing tasks to process each item + + // Produce a return value of Terminated if consider_termination has been + // set, and we pop nothing off incoming in this iteration + Stage ret_val = + state.consider_termination.load() ? 
Stage::Terminated : Stage::Running; + + std::lock_guard lock(state.session_manager.sessions_mutex); + for (auto& session : state.session_manager.all_sessions) + { + auto it = state.ordered_tasks_per_client.find(session.get()); + if (it == state.ordered_tasks_per_client.end()) + { + it = state.ordered_tasks_per_client.emplace_hint( + it, + session.get(), + ccf::tasks::OrderedTasks::create( + state.job_board, fmt::format("Tasks for {}", session->name))); + } + + auto& tasks = *it->second; + + // If the client has abandoned this session, cancel all corresponding + // tasks + if (session->abandoned.load()) + { + if (!tasks.is_cancelled()) + { + tasks.cancel_task(); + } + } + else + { + // Otherwise, produce a task to process this client request + auto incoming = session->to_node.try_pop(); + while (incoming.has_value()) + { + ret_val = Stage::Running; + + tasks.add_action(std::make_shared( + incoming.value(), *session, state.responses_sent)); + + incoming = session->to_node.try_pop(); + } + } + } + + return ret_val; + } +}; diff --git a/src/tasks/test/demo/locking_mpmc_queue.h b/src/tasks/test/demo/locking_mpmc_queue.h new file mode 100644 index 00000000000..c0b0b93cccb --- /dev/null +++ b/src/tasks/test/demo/locking_mpmc_queue.h @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. +#pragma once + +#include +#include + +namespace ccf::tasks +{ + // A very simple (slow) MPMPC queue, implemented by a std container guarded by + // a mutex + template + class LockingMPMCQueue + { + protected: + std::mutex mutex; + std::deque deque; + + public: + bool empty() + { + std::lock_guard lock(mutex); + return deque.empty(); + } + + size_t size() + { + std::lock_guard lock(mutex); + return deque.size(); + } + + void push_back(const T& t) + { + std::lock_guard lock(mutex); + deque.push_back(t); + } + + void emplace_back(T&& t) + { + std::lock_guard lock(mutex); + deque.emplace_back(std::move(t)); + } + + std::optional try_pop() + { + std::lock_guard lock(mutex); + + if (deque.empty()) + { + return std::nullopt; + } + + std::optional val = deque.front(); + deque.pop_front(); + return val; + } + }; +} diff --git a/src/tasks/test/demo/looping_thread.h b/src/tasks/test/demo/looping_thread.h new file mode 100644 index 00000000000..5045ced49c9 --- /dev/null +++ b/src/tasks/test/demo/looping_thread.h @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. +#pragma once + +#include "ccf/threading/thread_ids.h" + +#include +#include +#include + +enum class Stage +{ + PreInit, + Running, + ShuttingDown, + Terminated, +}; +template +struct LoopingThread +{ + using State = TState; + + // Derived instances will likely access state inside their loop_behaviour, + // which should be destroyed _after_ the loop ends. That means (because of C++ + // destructor order) it needs to be defined as a member here, so that it is + // destructed _after_ the destructor runs + TState state; + + std::atomic stop_signal = false; + std::thread thread; + + const std::string name; + + std::atomic lifetime_stage; + + template + LoopingThread(const std::string& _name, Ts&&... 
args) : + state(std::forward(args)...), + name(_name), + lifetime_stage(Stage::PreInit) + {} + + virtual ~LoopingThread() = 0; + + virtual void shutdown() + { + LOG_DEBUG_FMT("Stopping {}", name); + stop_signal.store(true); + + if (thread.joinable()) + { + thread.join(); + } + + lifetime_stage.store(Stage::Terminated); + } + + virtual void start() + { + thread = std::thread([this]() { + lifetime_stage.store(Stage::PreInit); + + this->init_behaviour(); + + lifetime_stage.store(Stage::Running); + + while (!stop_signal) + { + auto loop_behaviour_target_stage = this->loop_behaviour(); + REQUIRE(loop_behaviour_target_stage >= lifetime_stage); + lifetime_stage.store(loop_behaviour_target_stage); + if (lifetime_stage.load() == Stage::Terminated) + { + break; + } + + auto idle_behaviour_target_stage = this->idle_behaviour(); + REQUIRE(idle_behaviour_target_stage >= lifetime_stage); + lifetime_stage.store(idle_behaviour_target_stage); + if (lifetime_stage.load() == Stage::Terminated) + { + break; + } + } + + LOG_DEBUG_FMT("Terminating thread"); + }); + } + + virtual void init_behaviour() {} + + virtual Stage loop_behaviour() + { + // Base loop_behaviour is to terminate immediately + return Stage::Terminated; + } + + virtual Stage idle_behaviour() + { + std::this_thread::yield(); + return lifetime_stage.load(); + } +}; + +template +inline LoopingThread::~LoopingThread() +{} diff --git a/src/tasks/test/demo/main.cpp b/src/tasks/test/demo/main.cpp new file mode 100644 index 00000000000..4f47435bcdd --- /dev/null +++ b/src/tasks/test/demo/main.cpp @@ -0,0 +1,379 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. + +#include "./actions.h" +#include "./clients.h" +#include "./node.h" +#include "tasks/basic_task.h" + +#define DOCTEST_CONFIG_IMPLEMENT +#include + +// A few simple sanity checks that the basic operations do what we expect +TEST_CASE("SignAction") +{ + for (size_t i = 0; i < 100; ++i) + { + auto orig = std::make_unique(); + auto ser = orig->serialise(); + + auto received = deserialise_action(ser); + auto result = received->do_action(); + + orig->verify_serialised_response(result); + } +} + +TEST_CASE("SessionOrdering") +{ + // With more sessions than workers, and tasks concurrently added to these + // sessions, each task is still executed in-order for that session + static constexpr auto num_sessions = 5; + static constexpr auto num_workers = 2; + + // Record last x seen for each session + using Result = std::atomic; + std::vector results(num_sessions); + + ccf::tasks::JobBoard job_board; + { + // Record next x to send for each session + std::vector, size_t>> + all_tasks; + for (auto i = 0; i < num_sessions; ++i) + { + all_tasks.emplace_back( + ccf::tasks::OrderedTasks::create(job_board, std::to_string(i)), 0); + } + + auto add_action = [&](size_t idx, size_t sleep_time_ms) { + auto& [tasks, n] = all_tasks[idx]; + tasks->add_action(ccf::tasks::make_basic_action([=, &n, &results]() { + std::this_thread::sleep_for(std::chrono::milliseconds(sleep_time_ms)); + const auto x = ++n; + LOG_TRACE_FMT("{} {}", tasks->get_name(), x); + REQUIRE(++results[idx] == x); + })); + }; + + // Add some initial tasks on each session + const auto spacing = 3; + const auto period = spacing * num_sessions + 1; + for (auto i = 0; i < num_sessions; ++i) + { + add_action(i, spacing * i); + add_action(i, period); + add_action(i, period); + } + + { + std::vector> workers; + for (auto i = 0; i < num_workers; ++i) + { + workers.emplace_back(std::make_unique(job_board, 
i)); + } + + // Start processing those tasks on worker threads + for (auto& worker : workers) + { + worker->start(); + } + + // Continually add tasks, while the workers are running + for (auto i = 0; i < num_workers * num_sessions * 10; ++i) + { + add_action(i % all_tasks.size(), period); + // Try to produce an interesting interleaving of tasks across sessions + if (i % ((num_workers * num_sessions) - 1) == 0) + { + std::this_thread::sleep_for(std::chrono::milliseconds(period / 2)); + } + } + + while (!job_board.empty()) + { + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + } + } + } +} + +TEST_CASE("PauseAndResume") +{ + ccf::tasks::JobBoard job_board; + { + std::atomic x = 0; + std::atomic y = 0; + + auto increment = [](std::atomic& n) { + return ccf::tasks::make_basic_action([&n]() { ++n; }); + }; + + std::shared_ptr x_tasks = + ccf::tasks::OrderedTasks::create(job_board, "x"); + std::shared_ptr y_tasks = + ccf::tasks::OrderedTasks::create(job_board, "y"); + + x_tasks->add_action(increment(x)); + y_tasks->add_action(increment(y)); + y_tasks->add_action(increment(y)); + + { + Worker worker(job_board, 0); + + // Worker exists but hasn't started yet - no increments have occurred + REQUIRE(x.load() == 0); + REQUIRE(y.load() == 0); + + // Even if we wait + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + REQUIRE(x.load() == 0); + REQUIRE(y.load() == 0); + + // If we start the worker (and wait), it will execute the pending tasks + worker.start(); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + REQUIRE(x.load() == 1); + REQUIRE(y.load() == 2); + + // We can concurrently queue many more tasks, which will be executed + // immediately + for (auto i = 0; i < 100; ++i) + { + x_tasks->add_action(increment(x)); + y_tasks->add_action(increment(y)); + } + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + REQUIRE(x.load() == 101); + REQUIRE(y.load() == 102); + } + + { + // Terminating previous worker, creating a new one (not yet running) + Worker worker(job_board, 1); + + // If we need to block, we can ask for a task to be paused. Note that the + // current action will still complete + std::atomic happened = false; + ccf::tasks::Resumable resumable; + ccf::ds::WorkBeacon beacon; + + x_tasks->add_action(increment(x)); + x_tasks->add_action(ccf::tasks::make_basic_action([&]() { + // NB: This doesn't need to _know_ the current task, just that it is + // executed _as part of a task_. This means it could occur deep within a + // call-stack. + resumable = ccf::tasks::pause_current_task(); + // NB: The current _action_ will still complete execution! + happened = true; + beacon.notify_work_available(); + })); + x_tasks->add_action(increment(x)); + + worker.start(); + beacon.wait_for_work_with_timeout(std::chrono::milliseconds(100)); + REQUIRE(x.load() == 102); // One increment action happened + REQUIRE(happened == true); // Then the pause action ran to completion + REQUIRE( + resumable != nullptr); // We got a handle to later resume this task + + // Other actions can be scheduled, including on the paused task. + // Unpaused tasks will complete as normal. 
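+      // Concretely: the x_tasks increments queued below sit behind the
+      // still-paused task, so x stays at 102 until resume_task is called,
+      // while y_tasks is an independent OrderedTasks stream whose 100
+      // increments run immediately, taking y to 202.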
+ for (auto i = 0; i < 100; ++i) + { + x_tasks->add_action(increment(x)); + y_tasks->add_action(increment(y)); + } + + beacon.wait_for_work_with_timeout(std::chrono::milliseconds(100)); + REQUIRE(x.load() == 102); + REQUIRE(y.load() == 202); + + // After resume, all queued actions will (be able to) execute, in-order + ccf::tasks::resume_task(std::move(resumable)); + + beacon.wait_for_work_with_timeout(std::chrono::milliseconds(100)); + REQUIRE(x.load() == 203); + REQUIRE(y.load() == 202); + + // A task might be paused multiple times during its life + resumable = nullptr; + x_tasks->add_action(increment(x)); + x_tasks->add_action(ccf::tasks::make_basic_action([&]() { + resumable = ccf::tasks::pause_current_task(); + beacon.notify_work_available(); + })); + x_tasks->add_action(increment(x)); + + beacon.wait_for_work_with_timeout(std::chrono::milliseconds(100)); + REQUIRE(x.load() == 204); + REQUIRE(resumable != nullptr); + + // A paused task can be cancelled + x_tasks->cancel_task(); + + // Cancellation supercedes resumption - nothing more happens on this task + ccf::tasks::resume_task(std::move(resumable)); + + beacon.wait_for_work_with_timeout(std::chrono::milliseconds(100)); + REQUIRE(x.load() == 204); + } + } + + // Trying to pause outside of a task will throw an error + REQUIRE_THROWS(ccf::tasks::pause_current_task()); +} + +void describe_session_manager(SessionManager& sm) +{ + std::lock_guard lock(sm.sessions_mutex); + fmt::print("SessionManager contains {} sessions\n", sm.all_sessions.size()); + for (auto& session : sm.all_sessions) + { + fmt::print( + " {}: {} to_node, {} from_node\n", + session->name, + session->to_node.size(), + session->from_node.size()); + } +} + +void describe_job_board(ccf::tasks::JobBoard& jb) +{ + std::lock_guard lock(jb.mutex); + fmt::print("JobBoard contains {} tasks\n", jb.queue.size()); + // for (auto& task : jb.queue) + // { + // fmt::print(" {}\n", task->get_name()); + // } +} + +void describe_dispatcher(Dispatcher& d) +{ + describe_session_manager(d.state.session_manager); + describe_job_board((ccf::tasks::JobBoard&)d.state.job_board); + + fmt::print( + "Dispatcher is tracking {} sessions\n", + d.state.ordered_tasks_per_client.size()); + + for (auto& [session, tasks] : d.state.ordered_tasks_per_client) + { + size_t pending; + bool active; + tasks->get_queue_summary(pending, active); + fmt::print( + " {}: {} (active: {}, queue.size: {})\n", + session->name, + tasks->get_name(), + active, + pending); + } +} + +TEST_CASE("Run") +{ + size_t total_requests_sent = 0; + std::atomic total_responses_sent = 0; + size_t total_responses_seen = 0; + + { + // Create a node + ccf::tasks::JobBoard job_board; + Node node(4, job_board, total_responses_sent); + node.start(); + + { + // Create some clients + ClientParams client_params; + std::vector> clients; + for (auto i = 0u; i < 12; ++i) + { + clients.push_back(std::make_unique( + node.new_session(std::to_string(i)), client_params, i)); + clients.back()->start(); + } + + LOG_INFO_FMT("Leaving to run"); + + // Run everything, checking if all clients are done + const auto n_clients = clients.size(); + while (true) + { + size_t running = 0; + size_t shutting_down = 0; + for (auto& client : clients) + { + const auto stage = client->lifetime_stage.load(); + if (stage <= Stage::Running) + { + ++running; + } + else if (stage == Stage::ShuttingDown) + { + ++shutting_down; + } + } + LOG_INFO_FMT( + "{} clients running (submitting), {} shutting down (checking " + "responses), ({} total) ...", + running, + shutting_down, + 
n_clients); + if (running + shutting_down == 0) + { + break; + } + else + { + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + } + } + + for (auto& client : clients) + { + total_requests_sent += client->state.requests_sent; + total_responses_seen += client->state.responses_seen; + } + + LOG_INFO_FMT( + "Shutting down clients, total sent: {}, total seen: {}", + total_requests_sent, + total_responses_seen); + } + + node.dispatcher.state.consider_termination.store(true); + node.dispatcher.shutdown(); + + describe_dispatcher(node.dispatcher); + + for (auto& worker : node.workers) + { + worker->state.consider_termination.store(true); + worker->shutdown(); + } + } + + LOG_INFO_FMT( + "{} vs {} vs {}", + total_requests_sent, + total_responses_sent, + total_responses_seen); + + REQUIRE(total_requests_sent >= total_responses_sent); + REQUIRE(total_responses_sent >= total_responses_seen); +} + +int main(int argc, char** argv) +{ + ccf::logger::config::default_init(); + ccf::logger::config::level() = ccf::LoggerLevel::INFO; + + doctest::Context context; + context.applyCommandLine(argc, argv); + int res = context.run(); + if (context.shouldExit()) + return res; + return res; +} diff --git a/src/tasks/test/demo/node.h b/src/tasks/test/demo/node.h new file mode 100644 index 00000000000..10c94528b72 --- /dev/null +++ b/src/tasks/test/demo/node.h @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. +#pragma once + +#include "./actions.h" +#include "./dispatcher.h" +#include "./session.h" +#include "./worker.h" +#include "tasks/job_board.h" +#include "tasks/ordered_tasks.h" + +#include +#include +#include +#include +#include + +struct Node +{ + SessionManager session_manager; + + ccf::tasks::IJobBoard& job_board; + + Dispatcher dispatcher; + std::vector> workers; + + Node( + size_t num_workers, + ccf::tasks::IJobBoard& jb, + std::atomic& response_count) : + job_board(jb), + dispatcher(jb, session_manager, response_count) + { + for (size_t i = 0; i < num_workers; ++i) + { + workers.push_back(std::make_unique(job_board, i)); + } + } + + void start() + { + dispatcher.start(); + for (auto& worker : workers) + { + worker->start(); + } + } + + Session& new_session(const std::string& s) + { + return session_manager.new_session(s); + } +}; \ No newline at end of file diff --git a/src/tasks/test/demo/session.h b/src/tasks/test/demo/session.h new file mode 100644 index 00000000000..25211a97184 --- /dev/null +++ b/src/tasks/test/demo/session.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. 
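+
+// A Session models a single client connection in this demo: a pair of simple
+// MPMC queues carrying serialised requests (to_node) and responses
+// (from_node), plus an `abandoned` flag which the dispatcher checks to decide
+// when a session's remaining tasks should be cancelled.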
+#pragma once + +#include "./locking_mpmc_queue.h" +#include "ds/internal_logger.h" + +#include +#include +#include + +struct Session +{ + const std::string name; + + ccf::tasks::LockingMPMCQueue to_node; + ccf::tasks::LockingMPMCQueue from_node; + + std::atomic abandoned = false; + + Session(const std::string& s) : name(s) {} +}; + +struct SessionManager +{ + using SessionPtr = std::unique_ptr; + + std::mutex sessions_mutex; + std::vector all_sessions; + + ~SessionManager() + { + LOG_DEBUG_FMT("Destroying SessionManager"); + } + + Session& new_session(const std::string& s) + { + std::lock_guard lock(sessions_mutex); + return *all_sessions.emplace_back(std::make_unique(s)); + } +}; \ No newline at end of file diff --git a/src/tasks/test/demo/worker.h b/src/tasks/test/demo/worker.h new file mode 100644 index 00000000000..d05c5c8df26 --- /dev/null +++ b/src/tasks/test/demo/worker.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. +#pragma once + +#include "./looping_thread.h" + +struct WorkerState +{ + ccf::tasks::IJobBoard& job_board; + + size_t work_completed; + + std::atomic consider_termination = false; +}; + +struct Worker : public LoopingThread +{ + Worker(ccf::tasks::IJobBoard& jb, size_t idx) : + LoopingThread(fmt::format("w{}", idx), jb) + {} + + ~Worker() override + { + shutdown(); + + LOG_INFO_FMT( + "Shutting down {}, processed {} tasks", name, state.work_completed); + } + + Stage loop_behaviour() override + { + // Wait (with timeout) for a task + auto task = state.job_board.wait_for_task(std::chrono::milliseconds(10)); + if (task != nullptr) + { + task->do_task(); + state.work_completed += 1; + } + else if (state.consider_termination.load()) + { + return Stage::Terminated; + } + + return Stage::Running; + } +}; diff --git a/src/tasks/test/fan_in_tasks.cpp b/src/tasks/test/fan_in_tasks.cpp new file mode 100644 index 00000000000..863c69b04b9 --- /dev/null +++ b/src/tasks/test/fan_in_tasks.cpp @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. 
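+
+// Tests for FanInTasks: indexed tasks are executed in contiguous index order,
+// out-of-order arrivals are buffered until their predecessors complete, and
+// duplicate or already-executed indices are rejected with an exception.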
+ +#include "tasks/fan_in_tasks.h" + +#include "./utils.h" +#include "tasks/basic_task.h" +#include "tasks/job_board.h" + +#include +#include + +#define FMT_HEADER_ONLY +#include +#include + +TEST_CASE("ContiguousQueuing" * doctest::test_suite("fan_in_tasks")) +{ + ccf::tasks::JobBoard jb; + ccf::tasks::Task task; + + auto collection = ccf::tasks::FanInTasks::create(jb); + + REQUIRE(jb.empty()); + + std::atomic done_0{false}; + std::atomic done_1{false}; + std::atomic done_2{false}; + + auto set_0 = ccf::tasks::make_basic_task([&]() { done_0.store(true); }); + auto set_1 = ccf::tasks::make_basic_task([&]() { done_1.store(true); }); + auto set_2 = ccf::tasks::make_basic_task([&]() { done_2.store(true); }); + + // Adding the next-contiguous task instantly enqueues this collection + collection->add_task(0, set_0); + REQUIRE_FALSE(jb.empty()); + + // Non-contiguous tasks can be stored + collection->add_task(2, set_2); + + task = jb.get_task(); + REQUIRE(task != nullptr); + + // Only a contiguous block is executed + REQUIRE_FALSE(done_0.load()); + REQUIRE_FALSE(done_2.load()); + task->do_task(); + REQUIRE(done_0.load()); + REQUIRE_FALSE(done_2.load()); + + // Enqueuing invalid indices results in an error + auto never_execd = ccf::tasks::make_basic_task([]() { REQUIRE(false); }); + REQUIRE_THROWS(collection->add_task(0, never_execd)); + REQUIRE_THROWS(collection->add_task(2, never_execd)); + + // Contiguous next task may arrive out-of-order, queuing a batch of tasks + REQUIRE(jb.empty()); + collection->add_task(1, set_1); + REQUIRE_FALSE(jb.empty()); + + task = jb.get_task(); + REQUIRE(task != nullptr); + + REQUIRE_FALSE(done_1.load()); + REQUIRE_FALSE(done_2.load()); + task->do_task(); + REQUIRE(done_1.load()); + REQUIRE(done_2.load()); +} + +TEST_CASE("InterleavedCompletions" * doctest::test_suite("fan_in_tasks")) +{ + // Testing mutexes + re-enqueuing logic of FanInTasks, where tasks are added + // to the collection _while the collection is being executed_ + ccf::tasks::JobBoard jb; + ccf::tasks::Task task; + + auto collection = ccf::tasks::FanInTasks::create(jb); + + std::atomic all_done{false}; + collection->add_task( + 0, ccf::tasks::make_basic_task([&]() { + collection->add_task( + 1, ccf::tasks::make_basic_task([&]() { all_done.store(true); })); + })); + + REQUIRE_FALSE(jb.empty()); + task = jb.get_task(); + REQUIRE(task != nullptr); + + REQUIRE_FALSE(all_done.load()); + task->do_task(); + // setter task was _enqueued_, but not _executed_ yet + REQUIRE_FALSE(all_done.load()); + + REQUIRE_FALSE(jb.empty()); + task = jb.get_task(); + REQUIRE(task != nullptr); + task->do_task(); + REQUIRE(all_done.load()); + REQUIRE(jb.empty()); + + { + // Reset, and try a more complex example + all_done.store(false); + + collection->add_task( + 2, ccf::tasks::make_basic_task([&]() { + collection->add_task( + 5, ccf::tasks::make_basic_task([&]() { all_done.store(true); })); + })); + + REQUIRE_FALSE(jb.empty()); + task = jb.get_task(); + REQUIRE(task != nullptr); + task->do_task(); + REQUIRE_FALSE(all_done.load()); + + collection->add_task( + 3, ccf::tasks::make_basic_task([&]() { + collection->add_task( + 4, ccf::tasks::make_basic_task([&]() { all_done.store(true); })); + })); + + REQUIRE_FALSE(jb.empty()); + task = jb.get_task(); + REQUIRE(task != nullptr); + task->do_task(); + REQUIRE_FALSE(all_done.load()); + + REQUIRE_FALSE(jb.empty()); + task = jb.get_task(); + REQUIRE(task != nullptr); + task->do_task(); + REQUIRE(all_done.load()); + REQUIRE(jb.empty()); + } +} + +TEST_CASE("DelayedCompletions" * 
doctest::test_suite("fan_in_tasks")) +{ + ccf::tasks::JobBoard jb; + + static constexpr size_t num_tasks = 100; + + struct CalledInOrder : public ccf::tasks::BaseTask + { + std::atomic& counter; + const size_t expected_value; + const std::string name; + + CalledInOrder(std::atomic& c, size_t ev) : + counter(c), + expected_value(ev), + name(fmt::format("CalledInOrder {}", expected_value)) + {} + + void do_task_implementation() override + { + REQUIRE(counter.load() == expected_value); + ++counter; + } + + std::string_view get_name() const override + { + return name; + } + }; + + auto completions = ccf::tasks::FanInTasks::create(jb); + std::atomic counter; + + for (auto i = 0; i < num_tasks; ++i) + { + jb.add_task(ccf::tasks::make_basic_task([&, i]() { + const std::chrono::milliseconds sleep_time(rand() % 100); + std::this_thread::sleep_for(sleep_time); + + completions->add_task(i, std::make_shared(counter, i)); + })); + } + + test::utils::flush_board(jb, num_tasks); + + // Each task asserted that it executed in-order, and this confirms that all + // tasks executed + REQUIRE(counter.load() == num_tasks); +} \ No newline at end of file diff --git a/src/tasks/test/ordered_tasks.cpp b/src/tasks/test/ordered_tasks.cpp new file mode 100644 index 00000000000..8218f599723 --- /dev/null +++ b/src/tasks/test/ordered_tasks.cpp @@ -0,0 +1,173 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. + +#include "tasks/ordered_tasks.h" + +#include "./utils.h" +#include "tasks/basic_task.h" +#include "tasks/sub_task_queue.h" + +#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN +#include +#define FMT_HEADER_ONLY +#include +#include +#include +#include +#include +#include +#include +#include + +uint8_t thread_name() +{ + return std::hash{}(std::this_thread::get_id()); +} + +void thread_print(const std::string& s) +{ +#if false + static std::mutex logging_mutex; + std::lock_guard guard(logging_mutex); + fmt::print("[{:0x}] {}\n", thread_name(), s); +#endif +} + +// Confirm expected semantics of SubTaskQueue type +TEST_CASE("SubTaskQueue" * doctest::test_suite("ordered_tasks")) +{ + ccf::tasks::SubTaskQueue fq; + + // push returns true iff queue was previously empty and inactive + REQUIRE(fq.push(1)); + REQUIRE_FALSE(fq.push(2)); + REQUIRE_FALSE(fq.push(3)); + REQUIRE_FALSE(fq.push(4)); + + // pop returns true iff queue is non-empty when it completes + REQUIRE_FALSE(fq.pop_and_visit([](size_t&& n) {})); + + // Visits an empty queue, leaves an empty queue + REQUIRE_FALSE(fq.pop_and_visit([](size_t&& n) {})); + + // Not the first push _ever_, but the first on an empty queue, so gets a true + // response + REQUIRE(fq.push(5)); + + // If the visitor (or anything concurrent with it) pushes a new element, then + // the pop returns true to indicate that queue is now non-empty + REQUIRE(fq.pop_and_visit([&](size_t&& n) { + // While popping/visiting, the queue is active + REQUIRE_FALSE(fq.push(6)); + })); + + REQUIRE(fq.pop_and_visit([&](size_t&& n) { + REQUIRE_FALSE(fq.push(7)); + REQUIRE_FALSE(fq.push(8)); + REQUIRE_FALSE(fq.push(9)); + })); + + REQUIRE_FALSE(fq.pop_and_visit([&](size_t&& n) {})); +} + +TEST_CASE("OrderedTasks" * doctest::test_suite("ordered_tasks")) +{ + ccf::tasks::JobBoard jb; + + auto p_a = ccf::tasks::OrderedTasks::create(jb); + auto p_b = ccf::tasks::OrderedTasks::create(jb); + auto p_c = ccf::tasks::OrderedTasks::create(jb); + + std::atomic executed[14] = {0}; + + ccf::tasks::OrderedTasks& tasks_a = *p_a; + 
tasks_a.add_action(ccf::tasks::make_basic_action([&]() { + thread_print("A (no dependencies)"); + executed[0].store(true); + })); + tasks_a.add_action(ccf::tasks::make_basic_action([&]() { + thread_print("B (after A)"); + REQUIRE(executed[0].load()); + executed[1].store(true); + })); + tasks_a.add_action(ccf::tasks::make_basic_action([&]() { + thread_print("C (after B)"); + REQUIRE(executed[1].load()); + executed[2].store(true); + })); + + ccf::tasks::OrderedTasks& tasks_b = *p_b; + tasks_b.add_action(ccf::tasks::make_basic_action([&]() { + thread_print("D (no dependencies)"); + executed[3].store(true); + + tasks_b.add_action(ccf::tasks::make_basic_action([&]() { + thread_print("E (after D)"); + REQUIRE(executed[3].load()); + executed[4].store(true); + + tasks_b.add_action(ccf::tasks::make_basic_action([&]() { + thread_print("F (after E)"); + REQUIRE(executed[4].load()); + executed[5].store(true); + + tasks_b.add_action(ccf::tasks::make_basic_action([&]() { + thread_print("G (after F)"); + REQUIRE(executed[5].load()); + executed[6].store(true); + })); + })); + })); + })); + + ccf::tasks::OrderedTasks& tasks_c = *p_c; + tasks_c.add_action(ccf::tasks::make_basic_action([&]() { + thread_print("I (no dependencies)"); + executed[7].store(true); + + tasks_a.add_action(ccf::tasks::make_basic_action([&]() { + thread_print("J (after I and C)"); + REQUIRE(executed[2].load()); + REQUIRE(executed[7].load()); + executed[8].store(true); + + tasks_a.add_action(ccf::tasks::make_basic_action([&]() { + thread_print("K (after J)"); + REQUIRE(executed[8].load()); + executed[9].store(true); + + tasks_c.add_action(ccf::tasks::make_basic_action([&]() { + thread_print("L (after K)"); + REQUIRE(executed[9].load()); + executed[10].store(true); + })); + })); + })); + + tasks_b.add_action(ccf::tasks::make_basic_action([&]() { + thread_print("M (after I and D)"); + REQUIRE(executed[3].load()); + REQUIRE(executed[7].load()); + executed[11].store(true); + + tasks_a.add_action(ccf::tasks::make_basic_action([&]() { + thread_print("N (after M and C)"); + REQUIRE(executed[2].load()); + REQUIRE(executed[11].load()); + executed[12].store(true); + + tasks_c.add_action(ccf::tasks::make_basic_action([&]() { + thread_print("O (after N)"); + REQUIRE(executed[12].load()); + executed[13].store(true); + })); + })); + })); + })); + + test::utils::flush_board(jb, 8, [&]() { + return std::all_of(std::begin(executed), std::end(executed), [](auto&& e) { + return e.load(); + }); + }); +} diff --git a/src/tasks/test/utils.h b/src/tasks/test/utils.h new file mode 100644 index 00000000000..3bc0f9369f2 --- /dev/null +++ b/src/tasks/test/utils.h @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. 
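+
+// Shared helpers for the task system tests: spin up a pool of temporary
+// worker threads which drain a job board until it is empty (or until a
+// caller-supplied predicate reports it is safe to stop), bounded by a hard
+// kill_after timeout.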
+#pragma once + +#include "tasks/job_board.h" + +#include + +namespace test::utils +{ + static inline void worker_loop_func( + ccf::tasks::IJobBoard& job_board, std::atomic& stop) + { + while (!stop.load()) + { + auto task = job_board.get_task(); + if (task != nullptr) + { + task->do_task(); + } + std::this_thread::yield(); + } + } + + static inline void flush_board( + ccf::tasks::IJobBoard& job_board, + size_t max_workers = 8, + std::function safe_to_end = nullptr, + std::chrono::seconds kill_after = std::chrono::seconds(5)) + { + std::atomic stop_signal{false}; + + std::vector workers; + for (size_t i = 0; i < max_workers; ++i) + { + workers.emplace_back( + worker_loop_func, std::ref(job_board), std::ref(stop_signal)); + } + + using TClock = std::chrono::steady_clock; + auto now = TClock::now(); + const auto end_time = now + std::chrono::seconds(1); + const auto hard_end = now + kill_after; + + if (safe_to_end == nullptr) + { + safe_to_end = [&]() { return now > end_time && job_board.empty(); }; + } + + while (true) + { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + now = TClock::now(); + if (now > hard_end) + { + break; + } + + if (safe_to_end()) + { + break; + } + } + + stop_signal.store(true); + + for (auto& worker : workers) + { + worker.join(); + } + } +}
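+
+// Example usage (an illustrative sketch, mirroring the pattern used by the
+// ordered_tasks and fan_in_tasks tests; assumes tasks/basic_task.h is also
+// included for make_basic_task):
+//
+//   ccf::tasks::JobBoard jb;
+//   std::atomic<size_t> count{0};
+//   jb.add_task(ccf::tasks::make_basic_task([&]() { ++count; }));
+//   test::utils::flush_board(jb); // drains the board with temporary workers
+//   REQUIRE(count.load() == 1);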