Skip to content

Commit

Permalink
fix io_scheduler/task_container circular reference
Browse files Browse the repository at this point in the history
  • Loading branch information
cwharris committed Feb 28, 2024
1 parent 77bd0b1 commit 9f7ffc8
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 4 deletions.
10 changes: 7 additions & 3 deletions cpp/mrc/include/mrc/coroutines/io_scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#pragma once

#include "mrc/coroutines/detail/poll_info.hpp"
#include "mrc/coroutines/fd.hpp"
Expand All @@ -60,11 +60,15 @@
#include <vector>

namespace mrc::coroutines {

class IoScheduler : public Scheduler
{
private:
using timed_events_t = detail::PollInfo::timed_events_t;

public:
static std::shared_ptr<IoScheduler> get_instance();

class schedule_operation;
friend schedule_operation;

Expand Down Expand Up @@ -124,7 +128,7 @@ class IoScheduler : public Scheduler
auto operator=(const IoScheduler&) -> IoScheduler& = delete;
auto operator=(IoScheduler&&) -> IoScheduler& = delete;

~IoScheduler();
~IoScheduler() override;

/**
* Given a ThreadStrategy::manual this function should be called at regular intervals to
Expand Down
5 changes: 5 additions & 0 deletions cpp/mrc/include/mrc/coroutines/scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#pragma once

#include "mrc/coroutines/task.hpp"
#include "mrc/coroutines/time.hpp"

#include <coroutine>
#include <cstddef>
Expand All @@ -44,6 +45,10 @@ class Scheduler : public std::enable_shared_from_this<Scheduler>
* @brief Suspends the current function and resumes it according to the scheduler's implementation.
*/
[[nodiscard]] virtual Task<> yield() = 0;

[[nodiscard]] virtual Task<> yield_for(std::chrono::milliseconds amount) = 0;

[[nodiscard]] virtual Task<> yield_until(time_point_t time) = 0;
};

} // namespace mrc::coroutines
21 changes: 20 additions & 1 deletion cpp/mrc/src/public/coroutines/io_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,32 @@
using namespace std::chrono_literals;

namespace mrc::coroutines {

std::shared_ptr<IoScheduler> IoScheduler::get_instance()
{
static std::shared_ptr<IoScheduler> instance;
static std::mutex instance_mutex{};

if (instance == nullptr)
{
auto lock = std::lock_guard(instance_mutex);

if (instance == nullptr)
{
instance = std::make_shared<IoScheduler>();
}
}

return instance;
}

IoScheduler::IoScheduler(Options opts) :
m_opts(std::move(opts)),
m_epoll_fd(epoll_create1(EPOLL_CLOEXEC)),
m_shutdown_fd(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK)),
m_timer_fd(timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK | TFD_CLOEXEC)),
m_schedule_fd(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK)),
m_owned_tasks(new mrc::coroutines::TaskContainer(this->shared_from_this()))
m_owned_tasks(new mrc::coroutines::TaskContainer(std::shared_ptr<IoScheduler>(this, [](auto _) {})))
{
if (opts.execution_strategy == ExecutionStrategy::process_tasks_on_thread_pool)
{
Expand Down
1 change: 1 addition & 0 deletions cpp/mrc/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
add_executable(test_mrc
coroutines/test_async_generator.cpp
coroutines/test_event.cpp
coroutines/test_io_scheduler.cpp
coroutines/test_latch.cpp
coroutines/test_ring_buffer.cpp
coroutines/test_task_container.cpp
Expand Down
43 changes: 43 additions & 0 deletions cpp/mrc/tests/coroutines/test_io_scheduler.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "mrc/coroutines/async_generator.hpp"
#include "mrc/coroutines/io_scheduler.hpp"
#include "mrc/coroutines/sync_wait.hpp"
#include "mrc/coroutines/task.hpp"

#include <gtest/gtest.h>

#include <coroutine>

using namespace mrc;
using namespace std::chrono_literals;

class TestCoroIoScheduler : public ::testing::Test
{};

TEST_F(TestCoroIoScheduler, YieldFor)
{
auto scheduler = coroutines::IoScheduler::get_instance();

auto task = [scheduler]() -> coroutines::Task<> {
// co_await scheduler->yield_for(1000ms);
co_return;
};

coroutines::sync_wait(task());
}
14 changes: 14 additions & 0 deletions python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
#include "pymrc/utilities/object_wrappers.hpp"

#include <boost/fiber/future/async.hpp>
#include <mrc/coroutines/io_scheduler.hpp>
#include <mrc/coroutines/scheduler.hpp>
#include <mrc/coroutines/task.hpp>
#include <mrc/coroutines/task_container.hpp>
#include <mrc/coroutines/time.hpp>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>

Expand Down Expand Up @@ -90,6 +92,18 @@ class AsyncioScheduler : public mrc::coroutines::Scheduler
co_await ContinueOnLoopOperation(m_loop);
}

[[nodiscard]] coroutines::Task<> yield_for(std::chrono::milliseconds amount) override

Check warning on line 95 in python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp#L95

Added line #L95 was not covered by tests
{
co_await coroutines::IoScheduler::get_instance()->yield_for(amount);
co_await ContinueOnLoopOperation(m_loop);
};

Check warning on line 99 in python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp#L99

Added line #L99 was not covered by tests

[[nodiscard]] coroutines::Task<> yield_until(mrc::coroutines::time_point_t time) override

Check warning on line 101 in python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp#L101

Added line #L101 was not covered by tests
{
co_await coroutines::IoScheduler::get_instance()->yield_until(time);
co_await ContinueOnLoopOperation(m_loop);
};

Check warning on line 105 in python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp

View check run for this annotation

Codecov / codecov/patch

python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp#L105

Added line #L105 was not covered by tests

private:
mrc::pymrc::PyHolder m_loop;
};
Expand Down

0 comments on commit 9f7ffc8

Please sign in to comment.