diff --git a/include/seastar/coroutine/try_future.hh b/include/seastar/coroutine/try_future.hh new file mode 100644 index 00000000000..fd3b2ab3b0f --- /dev/null +++ b/include/seastar/coroutine/try_future.hh @@ -0,0 +1,145 @@ +/* + * This file is open source software, licensed to you under the terms + * of the Apache License, Version 2.0 (the "License"). See the NOTICE file + * distributed with this work for additional information regarding copyright + * ownership. You may not use this file except in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * Copyright (C) 2025-present ScyllaDB + */ + +#pragma once + +#include + +namespace seastar::internal { + +template +void try_future_resume_or_destroy_coroutine(seastar::future& fut, seastar::task& coroutine_task) { + auto promise_ptr = static_cast(&coroutine_task); + auto hndl = std::coroutine_handle::from_promise(*promise_ptr); + + if (fut.failed()) { + hndl.promise().set_exception(std::move(fut).get_exception()); + hndl.destroy(); + } else { + hndl.resume(); + } +} + +template +class [[nodiscard]] try_future_awaiter : public seastar::task { + seastar::future _future; + void (*_resume_or_destroy)(seastar::future&, seastar::task&){}; + seastar::task* _coroutine_task{}; + seastar::task* _waiting_task{}; + +public: + explicit try_future_awaiter(seastar::future&& f) noexcept : _future(std::move(f)) {} + + try_future_awaiter(const try_future_awaiter&) = delete; + try_future_awaiter(try_future_awaiter&&) = delete; + + bool await_ready() const noexcept { + // Will suspend+schedule for ready failed futures too. + return _future.available() && !_future.failed() && (!CheckPreempt || !need_preempt()); + } + + template + void await_suspend(std::coroutine_handle hndl) noexcept { + _resume_or_destroy = try_future_resume_or_destroy_coroutine; + _coroutine_task = &hndl.promise(); + _waiting_task = hndl.promise().waiting_task(); + + if (!_future.available()) { + _future.set_coroutine(*this); + } else { + schedule(this); + } + } + + T await_resume() { + if constexpr (std::is_void_v) { + _future.get(); + } else { + return std::move(_future).get(); + } + } + + virtual void run_and_dispose() noexcept override { + _resume_or_destroy(_future, *_coroutine_task); + } + + virtual task* waiting_task() noexcept override { + return _waiting_task; + } +}; + +} // namespace seastar::internal + +namespace seastar::coroutine { + +/// \brief co_await:s a \ref future and returns the wrapped result if successful, +/// terminates the coroutine otherwise, propagating the exception directly to the +/// waiter. +/// +/// If the future was successful, this is identical to co_await-ing the future +/// directly. If the future failed, the coroutine is not resumed and instead the +/// exception from the future is forwarded to the waiter directly and the +/// coroutine is destroyed. +/// +/// For example: +/// ``` +/// // Function careful to not throw exceptions, instead returning failed futures. +/// future bar() { +/// if (something_bad_happened) { +/// return make_exception_future<>(std::runtime_error("error")); +/// } +/// return result; +/// } +/// +/// future<> foo() { +/// auto result = co_await coroutine::try_future(bar()); +/// // This code is only executed if bar() returned a successful future. +/// // Otherwise the exception is forwarded to the waiter future directly +/// // and the coroutine is destroyed. +/// check_result(result); +/// } +/// ``` +/// +/// Note that by default, `try_future` checks for if the task quota is depleted, +/// which means that it will yield if the future is ready and \ref seastar::need_preempt() +/// returns true. Use \ref coroutine::try_future_without_preemption_check +/// to disable preemption checking. +template +class [[nodiscard]] try_future : public seastar::internal::try_future_awaiter { +public: + explicit try_future(seastar::future&& f) noexcept + : seastar::internal::try_future_awaiter(std::move(f)) + {} +}; + +/// \brief co_await:s a \ref future, returns the wrapped result if successful, +/// terminates the coroutine otherwise, propagating the exception to the waiter. +/// +/// Same as \ref coroutine::try_future, but does not check for preemption. +template +class [[nodiscard]] try_future_without_preemption_check : public seastar::internal::try_future_awaiter { +public: + explicit try_future_without_preemption_check(seastar::future&& f) noexcept + : seastar::internal::try_future_awaiter(std::move(f)) + {} +}; + +} // namespace seastar::coroutine diff --git a/tests/unit/coroutines_test.cc b/tests/unit/coroutines_test.cc index 0b1d3e79c85..a0d86ba3b32 100644 --- a/tests/unit/coroutines_test.cc +++ b/tests/unit/coroutines_test.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -35,8 +36,10 @@ #include #include #include +#include #include #include +#include #include using seastar::broken_promise; @@ -1018,3 +1021,161 @@ SEASTAR_TEST_CASE(test_lambda_coroutine_in_continuation) { })); BOOST_REQUIRE_EQUAL(sin1, sin2); } + +class test_exception : std::exception { }; + +future<> throw_void() { + fmt::print("throw_void()\n"); + co_await sleep(1ms); + throw test_exception{}; +} + +future<> return_ex_void() { + fmt::print("return_ex_void()\n"); + return make_exception_future<>(test_exception{}); +} + +future<> return_void() { + fmt::print("return_void()\n"); + co_await sleep(1ms); +} + +future throw_int() { + fmt::print("throw_int()\n"); + co_await sleep(1ms); + throw test_exception{}; +} + +future return_ex_int() { + fmt::print("return_ex_int()\n"); + return make_exception_future(test_exception{}); +} + +future return_int() { + fmt::print("return_int()\n"); + co_await sleep(1ms); + co_return 128; +} + +class dummy { + int& _c; + +public: + explicit dummy(int& c) : _c(c) { ++_c; } + dummy(const dummy& o) : _c(o._c) { ++_c; } + dummy(dummy&&) = delete; + ~dummy() { --_c; } +}; + +template +struct result_wrapper { + T value; + explicit result_wrapper(T v) : value(v) {} +}; + +template <> +struct result_wrapper { +}; + +template F> +// Use result_wrapper to create a mismatch between the return type of +// the coroutine and that of the underlying function, to ensure that +// try_future handles this case correctly. +future::value_type>> +do_run_try_future_test(F underlying_func, int& ctor_dtor_counter, bool& run_past) { + const auto check_cxx_exceptions_on_exit = seastar::defer([cxx_exception_before = seastar::engine().cxx_exceptions()] () noexcept { + if (seastar::engine().cxx_exceptions() != cxx_exception_before) { + // We are in a destructor, cannot throw + std::abort(); + } + }); + + dummy d1{ctor_dtor_counter}; + dummy d2{ctor_dtor_counter}; + + std::vector dummies; + for (unsigned i = 0; i < 10; ++i) { + dummies.emplace_back(ctor_dtor_counter); + } + + BOOST_REQUIRE_GT(ctor_dtor_counter, 0); + + using return_future_type = std::invoke_result_t; + using return_type = typename return_future_type::value_type; + constexpr bool is_void = std::is_same_v>; + + std::any ret; + + try { + if constexpr (is_void) { + if constexpr (CheckPreempt) { + co_await seastar::coroutine::try_future(underlying_func()); + } else { + co_await seastar::coroutine::try_future_without_preemption_check(underlying_func()); + } + } else { + if constexpr (CheckPreempt) { + ret = co_await seastar::coroutine::try_future(underlying_func()); + } else { + ret = co_await seastar::coroutine::try_future_without_preemption_check(underlying_func()); + } + } + run_past = true; + } catch (...) { + BOOST_FAIL(fmt::format("Exception should be handled in try_future, bug caught: {}", std::current_exception())); + } + + if constexpr (is_void) { + co_return result_wrapper{}; + } else { + co_return result_wrapper(std::any_cast(ret)); + } +} + +template F> +future<> run_try_future_test(F underlying_func, std::optional expected_value, std::source_location sl = std::source_location::current()) { + fmt::print("running test case at {}:{}\n", sl.file_name(), sl.line()); + + int ctor_dtor_counter{0}; + bool run_past{false}; + + const bool throws = !expected_value.has_value(); + + using return_future_type = std::invoke_result_t; + using return_type = typename return_future_type::value_type; + constexpr bool is_void = std::is_same_v>; + + try { + if constexpr (is_void) { + co_await do_run_try_future_test(std::move(underlying_func), ctor_dtor_counter, run_past); + BOOST_REQUIRE(expected_value); + } else { + auto res = co_await do_run_try_future_test(std::move(underlying_func), ctor_dtor_counter, run_past); + BOOST_REQUIRE(expected_value); + BOOST_REQUIRE_EQUAL(res.value, std::any_cast(*expected_value)); + } + } catch (test_exception&) { + BOOST_REQUIRE(throws); + } catch (...) { + BOOST_FAIL(fmt::format("Unexpected exception {}", std::current_exception())); + } + + BOOST_REQUIRE_EQUAL(run_past, !throws); + BOOST_REQUIRE_EQUAL(ctor_dtor_counter, 0); +} + +SEASTAR_TEST_CASE(test_try_future) { + co_await run_try_future_test(return_void, std::any{}); + co_await run_try_future_test(return_void, std::any{}); + co_await run_try_future_test(return_ex_void, std::nullopt); + co_await run_try_future_test(return_ex_void, std::nullopt); + co_await run_try_future_test(throw_void, std::nullopt); + co_await run_try_future_test(throw_void, std::nullopt); + + co_await run_try_future_test(return_int, 128); + co_await run_try_future_test(return_int, 128); + co_await run_try_future_test(return_ex_int, std::nullopt); + co_await run_try_future_test(return_ex_int, std::nullopt); + co_await run_try_future_test(throw_int, std::nullopt); + co_await run_try_future_test(throw_int, std::nullopt); +}