diff --git a/cpp/mrc/include/mrc/runnable/context.hpp b/cpp/mrc/include/mrc/runnable/context.hpp index 090804972..b26e6f8ed 100644 --- a/cpp/mrc/include/mrc/runnable/context.hpp +++ b/cpp/mrc/include/mrc/runnable/context.hpp @@ -17,15 +17,19 @@ #pragma once -#include +#include "mrc/types.hpp" // for Future -#include -#include -#include -#include +#include // for CHECK, COMPACT_GOOGLE_LOG_FATAL, LogMessag... + +#include // for size_t +#include // for exception_ptr +#include // for function +#include // for stringstream +#include // for allocator, string namespace mrc::runnable { +class IEngine; class Runner; enum class EngineType; @@ -41,7 +45,7 @@ class Context { public: Context() = delete; - Context(std::size_t rank, std::size_t size); + Context(const Runner& runner, IEngine& engine, std::size_t rank, std::size_t size); virtual ~Context() = default; EngineType execution_context() const; @@ -54,6 +58,8 @@ class Context void barrier(); void yield(); + Future launch_fiber(std::function task); + const std::string& info() const; template @@ -69,7 +75,7 @@ class Context void set_exception(std::exception_ptr exception_ptr); protected: - void init(const Runner& runner); + void start(); bool status() const; void finish(); virtual void init_info(std::stringstream& ss); @@ -79,7 +85,8 @@ class Context std::size_t m_size; std::string m_info{"Uninitialized Context"}; std::exception_ptr m_exception_ptr{nullptr}; - const Runner* m_runner{nullptr}; + const Runner& m_runner; + IEngine& m_engine; virtual void do_lock() = 0; virtual void do_unlock() = 0; diff --git a/cpp/mrc/include/mrc/runnable/detail/type_traits.hpp b/cpp/mrc/include/mrc/runnable/detail/type_traits.hpp index d33d7be84..57444696f 100644 --- a/cpp/mrc/include/mrc/runnable/detail/type_traits.hpp +++ b/cpp/mrc/include/mrc/runnable/detail/type_traits.hpp @@ -85,16 +85,10 @@ static auto unwrap_context(l4_concept c, ThreadContext& t) return std::make_pair(self{}, self{}); } -static auto unwrap_context(l2_concept c, Context& t) -{ - return std::make_pair(self{}, self{}); -} - template -static error unwrap_context(error e, T& t) +static auto unwrap_context(l2_concept c, T& t) { - static_assert(invalid_concept::error, "object is not a Context"); - return {}; + return std::make_pair(self{}, self{}); } template diff --git a/cpp/mrc/include/mrc/runnable/engine.hpp b/cpp/mrc/include/mrc/runnable/engine.hpp index 1c8f95212..73852fc18 100644 --- a/cpp/mrc/include/mrc/runnable/engine.hpp +++ b/cpp/mrc/include/mrc/runnable/engine.hpp @@ -22,6 +22,7 @@ #include "mrc/core/fiber_meta_data.hpp" #include "mrc/core/fiber_pool.hpp" #include "mrc/core/task_queue.hpp" +#include "mrc/runnable/context.hpp" #include "mrc/runnable/launch_options.hpp" #include "mrc/types.hpp" @@ -54,6 +55,7 @@ class IEngine virtual Future launch_task(std::function task) = 0; friend Runner; + friend Context; }; /** diff --git a/cpp/mrc/include/mrc/runnable/launch_control.hpp b/cpp/mrc/include/mrc/runnable/launch_control.hpp index 67371a218..2fddf3b13 100644 --- a/cpp/mrc/include/mrc/runnable/launch_control.hpp +++ b/cpp/mrc/include/mrc/runnable/launch_control.hpp @@ -103,6 +103,9 @@ class LaunchControl final // engines are out way of running some task on the specified backend std::shared_ptr engines = build_engines(options); + // create runner + auto runner = runnable::make_runner(std::move(runnable)); + // make contexts std::vector> contexts; if constexpr (is_fiber_runnable_v) @@ -113,6 +116,7 @@ class LaunchControl final "ThreadEngine"; contexts = make_contexts>>( + *runner, *engines, std::forward(context_args)...); } @@ -123,6 +127,7 @@ class LaunchControl final "to be run on a " "FiberEngine"; contexts = make_contexts>>( + *runner, *engines, std::forward(context_args)...); } @@ -132,12 +137,14 @@ class LaunchControl final if (backend == EngineType::Fiber) { contexts = make_contexts>>( + *runner, *engines, std::forward(context_args)...); } else if (backend == EngineType::Thread) { contexts = make_contexts>>( + *runner, *engines, std::forward(context_args)...); } @@ -147,9 +154,6 @@ class LaunchControl final } } - // create runner - auto runner = runnable::make_runner(std::move(runnable)); - // construct the launcher return std::make_unique(std::move(runner), std::move(contexts), std::move(engines)); } @@ -204,6 +208,9 @@ class LaunchControl final // engines are out way of running some task on the specified backend std::shared_ptr engines = build_engines(options); + // create runner + auto runner = runnable::make_runner(std::move(runnable)); + // make contexts std::vector> contexts; if constexpr (is_fiber_runnable_v) @@ -212,7 +219,7 @@ class LaunchControl final "FiberRunnable to " "be run on a " "ThreadEngine"; - contexts = make_contexts(*engines, std::forward(context_args)...); + contexts = make_contexts(*runner, *engines, std::forward(context_args)...); } else if constexpr (is_thread_context_v) { @@ -220,19 +227,21 @@ class LaunchControl final "ThreadRunnable " "to be run on a " "FiberEngine"; - contexts = make_contexts(*engines, std::forward(context_args)...); + contexts = make_contexts(*runner, *engines, std::forward(context_args)...); } else { auto backend = get_engine_factory(options.engine_factory_name).backend(); if (backend == EngineType::Fiber) { - contexts = make_contexts>(*engines, + contexts = make_contexts>(*runner, + *engines, std::forward(context_args)...); } else if (backend == EngineType::Thread) { - contexts = make_contexts>(*engines, + contexts = make_contexts>(*runner, + *engines, std::forward(context_args)...); } else @@ -241,9 +250,6 @@ class LaunchControl final } } - // create runner - auto runner = runnable::make_runner(std::move(runnable)); - // construct the launcher return std::make_unique(std::move(runner), std::move(contexts), std::move(engines)); } @@ -325,14 +331,15 @@ class LaunchControl final * @return auto */ template - auto make_contexts(const IEngines& engines, ArgsT&&... args) + auto make_contexts(const Runner& runner, const IEngines& engines, ArgsT&&... args) { const auto size = engines.size(); std::vector> contexts; auto resources = std::make_shared(size); for (std::size_t i = 0; i < size; ++i) { - contexts.push_back(std::make_shared(resources, i, size, args...)); + contexts.push_back( + std::make_shared(resources, runner, *engines.launchers()[i], i, size, args...)); } return std::move(contexts); } diff --git a/cpp/mrc/include/mrc/runnable/runner.hpp b/cpp/mrc/include/mrc/runnable/runner.hpp index c2539b601..c3cf12f9b 100644 --- a/cpp/mrc/include/mrc/runnable/runner.hpp +++ b/cpp/mrc/include/mrc/runnable/runner.hpp @@ -248,7 +248,7 @@ class SpecializedRunner : public Runner auto resources = std::make_shared(size); for (std::size_t i = 0; i < size; ++i) { - contexts.push_back(std::make_shared(resources, i, size, std::forward(args)...)); + contexts.push_back(std::make_shared(resources, i, std::forward(args)...)); } return std::move(contexts); } diff --git a/cpp/mrc/include/mrc/segment/builder.hpp b/cpp/mrc/include/mrc/segment/builder.hpp index a35f571c9..b1db230e9 100644 --- a/cpp/mrc/include/mrc/segment/builder.hpp +++ b/cpp/mrc/include/mrc/segment/builder.hpp @@ -278,6 +278,9 @@ class IBuilder typename... ArgsT> auto make_node(std::string name, ArgsT&&... ops); + template + auto make_node_explicit(std::string name, ArgsT&&... ops); + /** * Creates and returns an instance of a node component with the specified type, name and arguments. * @tparam SinkTypeT The sink type of the node component to be created. @@ -436,6 +439,12 @@ auto IBuilder::make_node(std::string name, ArgsT&&... ops) return construct_object>(name, std::forward(ops)...); } +template +auto IBuilder::make_node_explicit(std::string name, ArgsT&&... ops) +{ + return construct_object(name, std::forward(ops)...); +} + template class NodeTypeT, typename... ArgsT> auto IBuilder::make_node_component(std::string name, ArgsT&&... ops) { diff --git a/cpp/mrc/include/mrc/segment/context.hpp b/cpp/mrc/include/mrc/segment/context.hpp index 485f05d20..627cd793d 100644 --- a/cpp/mrc/include/mrc/segment/context.hpp +++ b/cpp/mrc/include/mrc/segment/context.hpp @@ -28,8 +28,13 @@ class Context : public ContextT { public: template - Context(std::size_t rank, std::size_t size, std::string name, ArgsT&&... args) : - ContextT(std::move(rank), std::move(size), std::forward(args)...), + Context(const mrc::runnable::Runner& runner, + mrc::runnable::IEngine& engine, + std::size_t rank, + std::size_t size, + std::string name, + ArgsT&&... args) : + ContextT(runner, engine, std::move(rank), std::move(size), std::forward(args)...), m_name(std::move(name)) { static_assert(std::is_base_of_v, "ContextT must derive from Context"); diff --git a/cpp/mrc/src/internal/runnable/engine.cpp b/cpp/mrc/src/internal/runnable/engine.cpp index 4ebe57491..fbd6c6e31 100644 --- a/cpp/mrc/src/internal/runnable/engine.cpp +++ b/cpp/mrc/src/internal/runnable/engine.cpp @@ -17,23 +17,16 @@ #include "internal/runnable/engine.hpp" -#include "mrc/types.hpp" +#include "mrc/types.hpp" // for Future -#include - -#include -#include -#include +#include // for mutex, lock_guard +#include // for move namespace mrc::runnable { Future Engine::launch_task(std::function task) { std::lock_guard lock(m_mutex); - if (m_launched) - { - LOG(FATAL) << "detected attempted reuse of a runnable::Engine; this is a fatal error"; - } m_launched = true; return do_launch_task(std::move(task)); } diff --git a/cpp/mrc/src/public/runnable/context.cpp b/cpp/mrc/src/public/runnable/context.cpp index 8e1aa510e..a1edfd3cc 100644 --- a/cpp/mrc/src/public/runnable/context.cpp +++ b/cpp/mrc/src/public/runnable/context.cpp @@ -17,16 +17,17 @@ #include "mrc/runnable/context.hpp" -#include "mrc/runnable/runner.hpp" +#include "mrc/runnable/runner.hpp" // for Runner -#include -#include +#include // for fiber_specific_ptr +#include // for async +#include // for COMPACT_GOOGLE_LOG_FATAL -#include -#include -#include -#include -#include +#include // for size_t +#include // for exception_ptr, current_excep... +#include // for operator<<, basic_ostream +#include // for char_traits, operator<<, string +#include // for move namespace mrc::runnable { @@ -47,7 +48,12 @@ struct FiberLocalContext } // namespace -Context::Context(std::size_t rank, std::size_t size) : m_rank(rank), m_size(size) {} +Context::Context(const Runner& runner, IEngine& engine, std::size_t rank, std::size_t size) : + m_runner(runner), + m_engine(engine), + m_rank(rank), + m_size(size) +{} EngineType Context::execution_context() const { @@ -93,7 +99,23 @@ void Context::yield() do_yield(); } -void Context::init(const Runner& runner) +Future Context::launch_fiber(std::function task) +{ + return boost::fibers::async([this, task]() { + auto& fiber_local = FiberLocalContext::get(); + fiber_local.reset(new FiberLocalContext()); + fiber_local->m_context = this; + try + { + task(); + } catch (...) + { + set_exception(std::current_exception()); + } + }); +} + +void Context::start() { auto& fiber_local = FiberLocalContext::get(); fiber_local.reset(new FiberLocalContext()); @@ -102,8 +124,6 @@ void Context::init(const Runner& runner) std::stringstream ss; this->init_info(ss); m_info = ss.str(); - - m_runner = &runner; } void Context::finish() @@ -127,7 +147,7 @@ void Context::set_exception(std::exception_ptr exception_ptr) if (m_exception_ptr == nullptr) { m_exception_ptr = std::move(std::current_exception()); - m_runner->kill(); + m_runner.kill(); } } } diff --git a/cpp/mrc/src/public/runnable/runner.cpp b/cpp/mrc/src/public/runnable/runner.cpp index 5545fa628..bd8aa5a74 100644 --- a/cpp/mrc/src/public/runnable/runner.cpp +++ b/cpp/mrc/src/public/runnable/runner.cpp @@ -130,7 +130,7 @@ void Runner::enqueue(std::shared_ptr launcher, std::vectorlaunch_task([this, context, &instance] { - context->init(*this); + context->start(); update_state(context->rank(), State::Running); instance.m_live_promise.set_value(); m_runnable->main(*context); diff --git a/dependencies.yaml b/dependencies.yaml index 966608a19..0ebfd8495 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -8,6 +8,7 @@ files: includes: - empty - build_cpp + - dev_cpp - cudatoolkit channels: @@ -46,6 +47,13 @@ dependencies: - scikit-build>=0.17 - pybind11-stubgen=0.10 - python=3.10 + + dev_cpp: + common: + - output_types: [conda] + packages: + - clangdev=16 + cudatoolkit: specific: - output_types: [conda] diff --git a/python/mrc/_pymrc/include/pymrc/executor.hpp b/python/mrc/_pymrc/include/pymrc/executor.hpp index c339d925a..cec4c43ca 100644 --- a/python/mrc/_pymrc/include/pymrc/executor.hpp +++ b/python/mrc/_pymrc/include/pymrc/executor.hpp @@ -32,6 +32,9 @@ class IExecutor; } namespace mrc::pymrc { + +std::function create_gil_initializer(); + class Pipeline; // Export everything in the mrc::pymrc namespace by default since we compile with -fvisibility=hidden diff --git a/python/mrc/_pymrc/include/pymrc/node.hpp b/python/mrc/_pymrc/include/pymrc/node.hpp index f5d72e7c3..5b590b1f0 100644 --- a/python/mrc/_pymrc/include/pymrc/node.hpp +++ b/python/mrc/_pymrc/include/pymrc/node.hpp @@ -33,14 +33,19 @@ #include "mrc/node/rx_sink.hpp" #include "mrc/node/rx_source.hpp" #include "mrc/runnable/context.hpp" +#include "mrc/runnable/runner.hpp" // for Runner #include #include #include #include +#include // for uint32_t +#include +#include #include #include +#include // for thread #include // Avoid forward declaring template specialization base classes @@ -49,6 +54,12 @@ namespace mrc { +namespace runnable { + +class IEngine; + +} + namespace edge { template @@ -310,14 +321,48 @@ class PythonSinkComponent : public node::RxSinkComponent, using base_t::base_t; }; -template -class PythonNode : public node::RxNode, +class PythonNodeLoopHandle +{ + public: + PythonNodeLoopHandle(); + ~PythonNodeLoopHandle(); + + uint32_t inc_ref(); + uint32_t dec_ref(); + + PyHolder get_asyncio_event_loop(); + + private: + uint32_t m_references = 0; + PyHolder m_loop; + std::atomic m_loop_ct = false; + std::thread m_loop_thread; +}; + +class PythonNodeContext : public mrc::runnable::Context +{ + public: + PythonNodeContext(const mrc::runnable::Runner& runner, + mrc::runnable::IEngine& engine, + std::size_t rank, + std::size_t size); + ~PythonNodeContext() override; + + PyHolder get_asyncio_event_loop(); + + private: + // TODO(cwharris): this should be a thread-specific pointer, + std::unique_ptr m_loop_handle; +}; + +template +class PythonNode : public node::RxNode, public pymrc::AutoRegSourceAdapter, public pymrc::AutoRegSinkAdapter, public pymrc::AutoRegIngressPort, public pymrc::AutoRegEgressPort { - using base_t = node::RxNode; + using base_t = node::RxNode; public: using typename base_t::stream_fn_t; diff --git a/python/mrc/_pymrc/include/pymrc/operators.hpp b/python/mrc/_pymrc/include/pymrc/operators.hpp index 3a7b788d4..50d87297d 100644 --- a/python/mrc/_pymrc/include/pymrc/operators.hpp +++ b/python/mrc/_pymrc/include/pymrc/operators.hpp @@ -19,6 +19,9 @@ #include "pymrc/types.hpp" +#include // for module_ + +#include // for uint32_t #include #include @@ -54,12 +57,33 @@ class OperatorProxy static std::string get_name(PythonOperator& self); }; +class AsyncOperatorHandler +{ + public: + AsyncOperatorHandler(); + ~AsyncOperatorHandler() = default; + + void process_async_task(PyObjectHolder task, PyObjectSubscriber sink); + void process_async_generator(PyObjectHolder asyncgen, PyObjectSubscriber sink); + + void wait_completed() const; + void wait_error(); + + private: + boost::fibers::future future_from_async_task(PyObjectHolder task); + pybind11::module_ m_asyncio; + uint32_t m_outstanding = 0; + bool m_cancelled = false; +}; + class OperatorsProxy { public: static PythonOperator build(PyFuncHolder build_fn); static PythonOperator filter(PyFuncHolder filter_fn); static PythonOperator flatten(); + static PythonOperator flat_map_async(PyFuncHolder flatmap_fn); + static PythonOperator map_async(PyFuncHolder flatmap_fn); static PythonOperator map(OnDataFunction map_fn); static PythonOperator on_completed(PyFuncHolder()> finally_fn); static PythonOperator pairwise(); diff --git a/python/mrc/_pymrc/src/node.cpp b/python/mrc/_pymrc/src/node.cpp index c3a1563e5..210293016 100644 --- a/python/mrc/_pymrc/src/node.cpp +++ b/python/mrc/_pymrc/src/node.cpp @@ -17,4 +17,107 @@ #include "pymrc/node.hpp" -namespace mrc::pymrc {} // namespace mrc::pymrc +#include "pymrc/executor.hpp" + +#include +#include // for module_ + +#include +#include + +namespace mrc { + +namespace runnable { + +class IEngine; + +} + +namespace pymrc { + +PythonNodeLoopHandle::PythonNodeLoopHandle() +{ + pybind11::gil_scoped_acquire acquire; + + auto asyncio = pybind11::module_::import("asyncio"); + + auto setup_debugging = create_gil_initializer(); + + m_loop = asyncio.attr("new_event_loop")(); + m_loop_ct = false; + m_loop_thread = std::thread([loop = m_loop, &ct = m_loop_ct, setup_debugging = std::move(setup_debugging)]() { + setup_debugging(); + + while (not ct) + { + { + // run event loop once + pybind11::gil_scoped_acquire acquire; + loop.attr("stop")(); + loop.attr("run_forever")(); + } + + std::this_thread::yield(); + } + + pybind11::gil_scoped_acquire acquire; + auto shutdown = loop.attr("shutdown_asyncgens")(); + loop.attr("run_until_complete")(shutdown); + loop.attr("close")(); + }); +} + +PythonNodeLoopHandle::~PythonNodeLoopHandle() +{ + if (m_loop_thread.joinable()) + { + m_loop_ct = true; + m_loop_thread.join(); + } +} + +uint32_t PythonNodeLoopHandle::inc_ref() +{ + return ++m_references; +} + +uint32_t PythonNodeLoopHandle::dec_ref() +{ + return --m_references; +} + +PyHolder PythonNodeLoopHandle::get_asyncio_event_loop() +{ + return m_loop; +} + +PythonNodeContext::PythonNodeContext(const mrc::runnable::Runner& runner, + mrc::runnable::IEngine& engine, + std::size_t rank, + std::size_t size) : + mrc::runnable::Context(runner, engine, rank, size) +{ + if (m_loop_handle == nullptr) + { + m_loop_handle = std::make_unique(); + } + + m_loop_handle->inc_ref(); +} + +PythonNodeContext::~PythonNodeContext() +{ + if (m_loop_handle != nullptr and m_loop_handle->dec_ref() == 0) + { + m_loop_handle.reset(); + } +} + +PyHolder PythonNodeContext::get_asyncio_event_loop() +{ + return m_loop_handle->get_asyncio_event_loop(); +} + +} // namespace pymrc + +} // namespace mrc diff --git a/python/mrc/_pymrc/src/operators.cpp b/python/mrc/_pymrc/src/operators.cpp index cb0b679e1..52054ed45 100644 --- a/python/mrc/_pymrc/src/operators.cpp +++ b/python/mrc/_pymrc/src/operators.cpp @@ -17,19 +17,30 @@ #include "pymrc/operators.hpp" +#include "pymrc/node.hpp" #include "pymrc/types.hpp" #include "pymrc/utilities/acquire_gil.hpp" #include "pymrc/utilities/function_wrappers.hpp" +#include "mrc/core/utils.hpp" +#include "mrc/runnable/context.hpp" + +#include +#include +#include +#include #include #include #include // IWYU pragma: keep #include #include #include +#include #include +#include #include +#include #include #include #include @@ -169,6 +180,192 @@ PythonOperator OperatorsProxy::flatten() }}; } +AsyncOperatorHandler::AsyncOperatorHandler() +{ + pybind11::gil_scoped_acquire acquire; + m_asyncio = py::module_::import("asyncio"); +} + +void AsyncOperatorHandler::wait_completed() const +{ + while (m_outstanding > 0) + { + boost::this_fiber::yield(); + } +} + +void AsyncOperatorHandler::wait_error() +{ + m_cancelled = true; + wait_completed(); +} + +void AsyncOperatorHandler::process_async_task(PyObjectHolder task, PyObjectSubscriber sink) +{ + if (m_cancelled) + { + return; + } + + ++m_outstanding; + + runnable::Context::get_runtime_context().launch_fiber([this, sink, task]() { + auto unwinder = Unwinder([this]() { + --m_outstanding; + }); + + using namespace std::chrono_literals; + + auto yielded = this->future_from_async_task(task); + + while (yielded.wait_for(0s) != boost::fibers::future_status::ready) + { + boost::this_fiber::yield(); + } + + if (not sink.is_subscribed()) + { + return; + } + + try + { + sink.on_next(yielded.get()); + } catch (pybind11::error_already_set& ex) + { + sink.on_error(std::current_exception()); + } + }); +} + +void AsyncOperatorHandler::process_async_generator(PyObjectHolder asyncgen, PyObjectSubscriber sink) +{ + if (m_cancelled) + { + return; + } + + ++m_outstanding; + + runnable::Context::get_runtime_context().launch_fiber([this, sink, asyncgen]() { + auto unwinder = Unwinder([this]() { + --m_outstanding; + }); + while (sink.is_subscribed()) + { + using namespace std::chrono_literals; + + auto gil = std::make_unique(); + PyObjectHolder task = asyncgen.attr("__anext__")(); + gil.reset(); + auto yielded = this->future_from_async_task(task); + + while (yielded.wait_for(0s) != boost::fibers::future_status::ready) + { + boost::this_fiber::yield(); + + if (not sink.is_subscribed()) + { + return; + } + } + + try + { + sink.on_next(yielded.get()); + } catch (pybind11::error_already_set& ex) + { + if (ex.matches(PyExc_StopAsyncIteration)) + { + return; + } + sink.on_error(std::current_exception()); + } + } + }); +} + +boost::fibers::future AsyncOperatorHandler::future_from_async_task(PyObjectHolder task) +{ + py::gil_scoped_acquire acquire; + + auto& ctx = runnable::Context::get_runtime_context().as(); + auto loop = ctx.get_asyncio_event_loop(); + auto future = m_asyncio.attr("run_coroutine_threadsafe")(task, loop); + auto promise = std::make_unique>(); + + auto result_future = promise->get_future(); + + future.attr("add_done_callback")(py::cpp_function([result = std::move(promise)](py::object future) { + try + { + auto acquire = std::make_unique(); + auto value = future.attr("result")(); + acquire.reset(); + result->set_value(std::move(py::reinterpret_borrow(value))); + } catch (std::exception& ex) + { + result->set_exception(std::current_exception()); + } + })); + + return result_future; +} + +PythonOperator OperatorsProxy::flat_map_async(PyFuncHolder flatmap_fn) +{ + return {"flat_map_async", [=](PyObjectObservable source) { + return rxcpp::observable<>::create([=](PyObjectSubscriber sink) { + auto async_handler = std::make_unique(); + source.subscribe( + sink, + [sink, flatmap_fn, &async_handler = *async_handler](PyHolder value) { + auto acquire = std::make_unique(); + auto asyncgen = flatmap_fn(std::move(value)); + acquire.reset(); + async_handler.process_async_generator(asyncgen, sink); + }, + [sink, &async_handler = *async_handler](std::exception_ptr ex) { + // Forward + async_handler.wait_error(); + sink.on_error(std::current_exception()); + }, + [sink, &async_handler = *async_handler]() { + // Forward + async_handler.wait_completed(); + sink.on_completed(); + }); + }); + }}; +} + +PythonOperator OperatorsProxy::map_async(PyFuncHolder flatmap_fn) +{ + return {"map_async", [=](PyObjectObservable source) { + return rxcpp::observable<>::create([=](PyObjectSubscriber sink) { + auto async_handler = std::make_unique(); + source.subscribe( + sink, + [sink, flatmap_fn, &async_handler = *async_handler](PyHolder value) { + auto acquire = std::make_unique(); + auto task = flatmap_fn(std::move(value)); + acquire.reset(); + async_handler.process_async_task(task, sink); + }, + [sink, &async_handler = *async_handler](std::exception_ptr ex) { + // Forward + async_handler.wait_error(); + sink.on_error(std::current_exception()); + }, + [sink, &async_handler = *async_handler]() { + // Forward + async_handler.wait_completed(); + sink.on_completed(); + }); + }); + }}; +} + PythonOperator OperatorsProxy::map(OnDataFunction map_fn) { // Build and return the map operator diff --git a/python/mrc/_pymrc/src/segment.cpp b/python/mrc/_pymrc/src/segment.cpp index 4e60e63e4..9c614d85b 100644 --- a/python/mrc/_pymrc/src/segment.cpp +++ b/python/mrc/_pymrc/src/segment.cpp @@ -28,12 +28,9 @@ #include "mrc/channel/status.hpp" #include "mrc/edge/edge_builder.hpp" #include "mrc/node/port_registry.hpp" -#include "mrc/node/rx_sink_base.hpp" -#include "mrc/node/rx_source_base.hpp" #include "mrc/runnable/context.hpp" #include "mrc/segment/builder.hpp" #include "mrc/segment/object.hpp" -#include "mrc/types.hpp" #include #include @@ -44,7 +41,6 @@ #include #include #include -#include #include #include #include @@ -52,7 +48,6 @@ #include #include #include -#include // IWYU thinks we need array for py::print // IWYU pragma: no_include @@ -390,7 +385,8 @@ std::shared_ptr BuilderProxy::make_node(mrc::seg const std::string& name, pybind11::args operators) { - auto node = self.make_node(name); + // auto node = self.make_node(name); + auto node = self.make_node_explicit>(name); node->object().make_stream( [operators = PyObjectHolder(std::move(operators))](const PyObjectObservable& input) -> PyObjectObservable { @@ -412,7 +408,8 @@ std::shared_ptr BuilderProxy::make_node_full( "make_node_full(name, sub_fn) is deprecated and will be removed in a future version. Use " "make_node(name, mrc.core.operators.build(sub_fn)) instead."); - auto node = self.make_node(name); + // auto node = self.make_node(name); + auto node = self.make_node_explicit>(name); node->object().make_stream([sub_fn](const PyObjectObservable& input) -> PyObjectObservable { return rxcpp::observable<>::create([input, sub_fn](pymrc::PyObjectSubscriber output) { diff --git a/python/mrc/core/operators.cpp b/python/mrc/core/operators.cpp index b74ff96ec..4d41269da 100644 --- a/python/mrc/core/operators.cpp +++ b/python/mrc/core/operators.cpp @@ -28,7 +28,6 @@ #include #include // IWYU pragma: keep -#include #include namespace mrc::pymrc { @@ -55,7 +54,9 @@ PYBIND11_MODULE(operators, py_mod) py_mod.def("build", &OperatorsProxy::build); py_mod.def("filter", &OperatorsProxy::filter); py_mod.def("flatten", &OperatorsProxy::flatten); + py_mod.def("flat_map_async", &OperatorsProxy::flat_map_async); py_mod.def("map", &OperatorsProxy::map); + py_mod.def("map_async", &OperatorsProxy::map_async); py_mod.def("on_completed", &OperatorsProxy::on_completed); py_mod.def("pairwise", &OperatorsProxy::pairwise); py_mod.def("to_list", &OperatorsProxy::to_list); diff --git a/python/tests/test_operators.py b/python/tests/test_operators.py index 9fa939b3a..95367ecdb 100644 --- a/python/tests/test_operators.py +++ b/python/tests/test_operators.py @@ -130,6 +130,53 @@ def node_fn(input: mrc.Observable, output: mrc.Subscriber): assert actual == expected +def test_map_async(run_segment): + + input_data = [0, 1, 2, 3] + expected = [0, 1, 4, 9] + + async def square_async(value): + import asyncio + await asyncio.sleep(0) + return value * value + + def node_fn(input: mrc.Observable, output: mrc.Subscriber): + input.pipe(ops.map_async(square_async)).subscribe(output) + + actual, raised_error = run_segment(input_data, node_fn) + + assert set(actual) == set(expected) + +def test_flat_map_async(run_segment): + + input_data = [('a', 5), ('b', 1), ('c', 3)] + expected = [('a', 0), ('a', 1), ('a', 2), ('a', 3), ('a', 4), ('b', 0), ('c', 0), ('c', 1), ('c', 2)] + + import random + random.shuffle(input_data) # the output order is not dictated by the input order + + async def generate(value): + name, count = value + for i in range(count): + yield (name, i) + + def node_fn(input: mrc.Observable, output: mrc.Subscriber): + input.pipe(ops.flat_map_async(generate)).subscribe(output) + + actual, raised_error = run_segment(input_data, node_fn) + + assert set(actual) == set(expected) + + def assert_sequential(name, actual): + # the output of individual generators must be sequential + import itertools + for i, (name, value) in zip(itertools.count(), filter(lambda pair: pair[0] == name, actual)): + assert i == value + + assert_sequential('a', actual) + assert_sequential('b', actual) + assert_sequential('c', actual) + def test_filter(run_segment):