Skip to content

Commit

Permalink
Adding RoundRobinRouter node type for distributing values to downstre…
Browse files Browse the repository at this point in the history
…am nodes (#449)

1. Adds a new C++ type `RoundRobinRouterTypeless`, which is very similar to `BroadcastTypeless` except that it pushes each value to only one of the downstream connections instead of copying it to all of them
2. Adds a new Python type `RoundRobinRouter`, which allows `RoundRobinRouterTypeless` to be used from Python
3. Adds a C++ test to confirm connectivity
4. Adds Python tests to verify output

Authors:
  - Michael Demoret (https://github.com/mdemoret-nv)

Approvers:
  - Devin Robison (https://github.com/drobison00)

URL: #449
  • Loading branch information
mdemoret-nv authored Mar 7, 2024
1 parent 3010601 commit 2dbd985
Show file tree
Hide file tree
Showing 4 changed files with 287 additions and 35 deletions.
144 changes: 144 additions & 0 deletions cpp/mrc/include/mrc/node/operators/round_robin_router_typeless.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "mrc/edge/deferred_edge.hpp"

#include <algorithm>
#include <atomic>
#include <cstddef>
#include <memory>
#include <mutex>
#include <vector>

namespace mrc::node {

/**
 * @brief Typeless router that sends each value to only one of its downstream connections, cycling through the
 * connections in round-robin order instead of copying the value to every downstream.
 *
 * The router is both a writable provider (upstream objects obtain an edge handle from it) and a writable acceptor
 * (downstream edge handles are registered with it). Upstream and downstream connections may be made in any order:
 * whichever side arrives later is wired to everything already registered on the other side.
 *
 * A single counter is shared by every upstream edge created by this router, so the rotation advances across all
 * upstreams together rather than per upstream.
 *
 * NOTE(review): the callbacks installed on the deferred edges capture a raw pointer to this router, so the router
 * presumably has to outlive every edge created through `get_writable_edge_handle()` -- confirm against the owners.
 */
class RoundRobinRouterTypeless : public edge::IWritableProviderBase, public edge::IWritableAcceptorBase
{
  public:
    /**
     * @brief Creates a deferred edge handle for a new upstream connection.
     *
     * When the deferred edge is resolved, it is given an index-selection function that picks one connection key per
     * call, a connector that attaches it to every downstream handle registered so far, and a disconnector that
     * removes it from the router's bookkeeping.
     *
     * @return A new `edge::DeferredWritableEdgeHandle` wrapped as a `WritableEdgeHandle`.
     */
    std::shared_ptr<edge::WritableEdgeHandle> get_writable_edge_handle() const override
    {
        // The interface method is const, but the callbacks below must mutate the router's bookkeeping
        auto* self = const_cast<RoundRobinRouterTypeless*>(this);

        // Create a new upstream edge. On connection, have it attach to any downstreams
        return std::make_shared<edge::DeferredWritableEdgeHandle>(
            [self](std::shared_ptr<edge::DeferredWritableMultiEdgeBase> deferred_edge) {
                // Set the round-robin indices function: exactly one connection key is selected per call
                deferred_edge->set_indices_fn([self](edge::DeferredWritableMultiEdgeBase& multi_edge) {
                    // Atomically claim the next slot in the rotation
                    const auto next_idx = self->m_current_idx++;

                    const auto current_keys = multi_edge.edge_connection_keys();

                    // Without this guard the modulo below would divide by zero (undefined behavior)
                    CHECK(!current_keys.empty())
                        << "RoundRobinRouter cannot route a value because it has no downstream connections";

                    return std::vector<size_t>{current_keys[next_idx % current_keys.size()]};
                });

                // Need to work with weak ptr here otherwise we will keep it from closing
                std::weak_ptr<edge::DeferredWritableMultiEdgeBase> weak_deferred_edge = deferred_edge;

                // Use a connector here in case the object never gets set to an edge
                deferred_edge->add_connector([self, weak_deferred_edge]() {
                    // Lock whenever working on the handles
                    std::lock_guard<std::mutex> lock(self->m_mutex);

                    // Save to the upstream handles
                    self->m_upstream_handles.emplace_back(weak_deferred_edge);

                    auto locked_edge = weak_deferred_edge.lock();

                    CHECK(locked_edge) << "Edge was destroyed before making connection.";

                    // Attach every downstream that was registered before this upstream connected. Each one is
                    // added under the key equal to the edge's connection count at that moment.
                    for (const auto& downstream : self->m_downstream_handles)
                    {
                        const auto key = locked_edge->edge_connection_count();

                        locked_edge->set_writable_edge_handle(key, downstream);
                    }

                    // Now add a disconnector that will remove it from the list
                    locked_edge->add_disconnector([self]() {
                        // Need to lock here since this could be driven by different progress engines
                        std::lock_guard<std::mutex> lock(self->m_mutex);

                        // Cull all expired ptrs from the list
                        auto& upstreams = self->m_upstream_handles;

                        upstreams.erase(std::remove_if(upstreams.begin(),
                                                       upstreams.end(),
                                                       [](const auto& handle) {
                                                           return handle.expired();
                                                       }),
                                        upstreams.end());

                        // If there are no more upstream handles, then delete the downstream
                        if (upstreams.empty())
                        {
                            self->m_downstream_handles.clear();
                        }
                    });
                });
            });
    }

    /// @brief Reports the provider side of this router as a deferred edge type.
    edge::EdgeTypeInfo writable_provider_type() const override
    {
        return edge::EdgeTypeInfo::create_deferred();
    }

    /**
     * @brief Registers a new downstream edge handle and attaches it to every upstream edge that is already connected.
     *
     * @param ingress Downstream handle to add to the rotation. It is retained by the router until the last upstream
     * edge has been removed.
     */
    void set_writable_edge_handle(std::shared_ptr<edge::WritableEdgeHandle> ingress) override
    {
        // Lock whenever working on the handles
        std::lock_guard<std::mutex> lock(m_mutex);

        // We have a new downstream object. Hold onto it
        m_downstream_handles.push_back(ingress);

        // If we have an upstream object, try to make a connection now
        for (const auto& upstream_weak : m_upstream_handles)
        {
            auto upstream = upstream_weak.lock();

            CHECK(upstream) << "Upstream edge went out of scope before downstream edges were connected";

            const auto key = upstream->edge_connection_count();

            // Connect
            upstream->set_writable_edge_handle(key, ingress);
        }
    }

    /// @brief Reports the acceptor side of this router as a deferred edge type.
    edge::EdgeTypeInfo writable_acceptor_type() const override
    {
        return edge::EdgeTypeInfo::create_deferred();
    }

  private:
    // Guards m_upstream_handles and m_downstream_handles
    std::mutex m_mutex;
    // Monotonically increasing counter used to pick the next connection; shared by all upstream edges
    std::atomic_size_t m_current_idx{0};
    // Upstream edges are held weakly so the router does not keep them from closing
    std::vector<std::weak_ptr<edge::DeferredWritableMultiEdgeBase>> m_upstream_handles;
    // Downstream handles, kept so that upstreams which connect later can still be attached to them
    std::vector<std::shared_ptr<edge::WritableEdgeHandle>> m_downstream_handles;
};

} // namespace mrc::node
18 changes: 17 additions & 1 deletion cpp/mrc/tests/test_edges.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -29,6 +29,7 @@
#include "mrc/node/operators/broadcast.hpp"
#include "mrc/node/operators/combine_latest.hpp"
#include "mrc/node/operators/node_component.hpp"
#include "mrc/node/operators/round_robin_router_typeless.hpp"
#include "mrc/node/operators/router.hpp"
#include "mrc/node/rx_node.hpp"
#include "mrc/node/sink_channel_owner.hpp"
Expand Down Expand Up @@ -666,6 +667,21 @@ TEST_F(TestEdges, SourceToRouterToDifferentSinks)
sink1->run();
}

// Connectivity check: a single source feeds a RoundRobinRouterTypeless which is connected to two sinks of different
// kinds (a TestSink and a TestSinkComponent).
TEST_F(TestEdges, SourceToRoundRobinRouterTypelessToDifferentSinks)
{
    auto upstream       = std::make_shared<node::TestSource<int>>();
    auto round_robin    = std::make_shared<node::RoundRobinRouterTypeless>();
    auto runnable_sink  = std::make_shared<node::TestSink<int>>();
    auto component_sink = std::make_shared<node::TestSinkComponent<int>>();

    // The upstream edge is made first, so both sinks are attached to an already-connected router
    mrc::make_edge(*upstream, *round_robin);
    mrc::make_edge(*round_robin, *runnable_sink);
    mrc::make_edge(*round_robin, *component_sink);

    // Only the source and the TestSink are run explicitly; the component sink has no run() call here
    upstream->run();
    runnable_sink->run();
}

TEST_F(TestEdges, SourceToBroadcastToSink)
{
auto source = std::make_shared<node::TestSource<int>>();
Expand Down
12 changes: 11 additions & 1 deletion python/mrc/core/node.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -20,6 +20,7 @@
#include "pymrc/utils.hpp"

#include "mrc/node/operators/broadcast.hpp"
#include "mrc/node/operators/round_robin_router_typeless.hpp"
#include "mrc/segment/builder.hpp"
#include "mrc/segment/object.hpp"
#include "mrc/utils/string_utils.hpp"
Expand Down Expand Up @@ -58,6 +59,15 @@ PYBIND11_MODULE(node, py_mod)
return node;
}));

py::class_<mrc::segment::Object<node::RoundRobinRouterTypeless>,
mrc::segment::ObjectProperties,
std::shared_ptr<mrc::segment::Object<node::RoundRobinRouterTypeless>>>(py_mod, "RoundRobinRouter")
.def(py::init<>([](mrc::segment::IBuilder& builder, std::string name) {
auto node = builder.construct_object<node::RoundRobinRouterTypeless>(name);

return node;
}));

py_mod.attr("__version__") = MRC_CONCAT_STR(mrc_VERSION_MAJOR << "." << mrc_VERSION_MINOR << "."
<< mrc_VERSION_PATCH);
}
Expand Down
148 changes: 115 additions & 33 deletions python/tests/test_edges.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -252,6 +252,16 @@ def add_broadcast(seg: mrc.Builder, *upstream: mrc.SegmentObject):
return node


def add_round_robin_router(seg: mrc.Builder, *upstream: mrc.SegmentObject):
    """
    Create a `RoundRobinRouter` node named "RoundRobinRouter" in `seg` and connect every object in `upstream` to it.

    Returns the newly created router node so that callers can attach downstream objects.
    """
    router = mrc.core.node.RoundRobinRouter(seg, "RoundRobinRouter")

    # Edges are made in the order the upstream objects were passed in
    for parent in upstream:
        seg.make_edge(parent, router)

    return router


# THIS TEST IS CAUSING ISSUES WHEN RUNNING ALL TESTS TOGETHER

# @dataclasses.dataclass
Expand Down Expand Up @@ -431,14 +441,15 @@ def fail_if_more_derived_type(combo: typing.Tuple):
@pytest.mark.parametrize("source_cpp", [True, False], ids=["source_cpp", "source_py"])
@pytest.mark.parametrize("sink1_cpp", [True, False], ids=["sink1_cpp", "sink2_py"])
@pytest.mark.parametrize("sink2_cpp", [True, False], ids=["sink2_cpp", "sink2_py"])
@pytest.mark.parametrize("source_type,sink1_type,sink2_type",
gen_parameters("source",
"sink1",
"sink2",
is_fail_fn=fail_if_more_derived_type,
values={
"base": m.Base, "derived": m.DerivedA
}))
@pytest.mark.parametrize(
"source_type,sink1_type,sink2_type",
gen_parameters("source",
"sink1",
"sink2",
is_fail_fn=fail_if_more_derived_type,
values={
"base": m.Base, "derived": m.DerivedA
}))
def test_source_to_broadcast_to_sinks(run_segment,
sink1_component: bool,
sink2_component: bool,
Expand Down Expand Up @@ -503,13 +514,84 @@ def segment_init(seg: mrc.Builder):
assert results == expected_node_counts


@pytest.mark.parametrize("sink1_component,sink2_component",
                         gen_parameters("sink1", "sink2", is_fail_fn=lambda x: False))
@pytest.mark.parametrize("source_cpp", [True, False], ids=["source_cpp", "source_py"])
@pytest.mark.parametrize("sink1_cpp", [True, False], ids=["sink1_cpp", "sink1_py"])
@pytest.mark.parametrize("sink2_cpp", [True, False], ids=["sink2_cpp", "sink2_py"])
@pytest.mark.parametrize(
    "source_type,sink1_type,sink2_type",
    gen_parameters("source",
                   "sink1",
                   "sink2",
                   is_fail_fn=fail_if_more_derived_type,
                   values={
                       "base": m.Base, "derived": m.DerivedA
                   }))
def test_source_to_round_robin_router_to_sinks(run_segment,
                                               sink1_component: bool,
                                               sink2_component: bool,
                                               source_cpp: bool,
                                               sink1_cpp: bool,
                                               sink2_cpp: bool,
                                               source_type: type,
                                               sink1_type: type,
                                               sink2_type: type):
    """
    Run a segment with one source feeding a round-robin router that is connected to two sinks.

    The test is parametrized over C++/Python implementations of the source and both sinks, over whether each sink is
    a component, and over the base/derived data types of each node. It asserts that the results returned by
    `run_segment` equal `expected_node_counts`.
    """

    def segment_init(seg: mrc.Builder):

        source = add_source(seg, is_cpp=source_cpp, data_type=source_type, is_component=False)
        router = add_round_robin_router(seg, source)

        # The sinks are given different counts (3 and 2) -- presumably the number of values each sink is expected to
        # receive from the router; confirm against add_sink
        add_sink(seg,
                 router,
                 is_cpp=sink1_cpp,
                 data_type=sink1_type,
                 is_component=sink1_component,
                 suffix="1",
                 count=3)
        add_sink(seg,
                 router,
                 is_cpp=sink2_cpp,
                 data_type=sink2_type,
                 is_component=sink2_component,
                 suffix="2",
                 count=2)

    results = run_segment(segment_init)

    assert results == expected_node_counts


@pytest.mark.parametrize(
    "sink1_component,sink2_component",
    gen_parameters("sink1", "sink2", is_fail_fn=lambda x: False),
)
@pytest.mark.parametrize("source_cpp", [True, False], ids=["source_cpp", "source_py"])
@pytest.mark.parametrize("sink1_cpp", [True, False], ids=["sink1_cpp", "sink1_py"])
@pytest.mark.parametrize("sink2_cpp", [True, False], ids=["sink2_cpp", "sink2_py"])
def test_multi_source_to_round_robin_router_to_multi_sink(run_segment,
                                                          sink1_component: bool,
                                                          sink2_component: bool,
                                                          source_cpp: bool,
                                                          sink1_cpp: bool,
                                                          sink2_cpp: bool):
    """
    Run a segment in which two `m.Base` sources share one round-robin router that is connected to two `m.Base` sinks.

    Parametrized over C++/Python implementations of the sources and sinks and over whether each sink is a component.
    Asserts that the results returned by `run_segment` equal `expected_node_counts`.
    """

    def segment_init(seg: mrc.Builder):

        # Both sources use the same implementation flag and differ only in their suffix
        first_source = add_source(seg, is_cpp=source_cpp, data_type=m.Base, is_component=False, suffix="1")
        second_source = add_source(seg, is_cpp=source_cpp, data_type=m.Base, is_component=False, suffix="2")

        router = add_round_robin_router(seg, first_source, second_source)

        add_sink(seg,
                 router,
                 is_cpp=sink1_cpp,
                 data_type=m.Base,
                 is_component=sink1_component,
                 suffix="1")
        add_sink(seg,
                 router,
                 is_cpp=sink2_cpp,
                 data_type=m.Base,
                 is_component=sink2_component,
                 suffix="2")

    results = run_segment(segment_init)

    assert results == expected_node_counts


@pytest.mark.parametrize("source_cpp", [True, False], ids=["source_cpp", "source_py"])
@pytest.mark.parametrize("source_type",
gen_parameters("source",
is_fail_fn=lambda _: False,
values={
"base": m.Base, "derived": m.DerivedA
}))
@pytest.mark.parametrize(
"source_type", gen_parameters("source", is_fail_fn=lambda _: False, values={
"base": m.Base, "derived": m.DerivedA
}))
def test_source_to_null(run_segment, source_cpp: bool, source_type: type):

def segment_init(seg: mrc.Builder):
Expand All @@ -522,24 +604,24 @@ def segment_init(seg: mrc.Builder):
assert results == expected_node_counts


@pytest.mark.parametrize("source_cpp,node_cpp",
gen_parameters("source", "node", is_fail_fn=lambda _: False, values={
"cpp": True, "py": False
}))
@pytest.mark.parametrize("source_type,node_type",
gen_parameters("source",
"node",
is_fail_fn=fail_if_more_derived_type,
values={
"base": m.Base, "derived": m.DerivedA
}))
@pytest.mark.parametrize("source_component,node_component",
gen_parameters("source",
"node",
is_fail_fn=lambda x: x[0] and x[1],
values={
"run": False, "com": True
}))
@pytest.mark.parametrize(
"source_cpp,node_cpp",
gen_parameters("source", "node", is_fail_fn=lambda _: False, values={
"cpp": True, "py": False
}))
@pytest.mark.parametrize(
"source_type,node_type",
gen_parameters("source",
"node",
is_fail_fn=fail_if_more_derived_type,
values={
"base": m.Base, "derived": m.DerivedA
}))
@pytest.mark.parametrize(
"source_component,node_component",
gen_parameters("source", "node", is_fail_fn=lambda x: x[0] and x[1], values={
"run": False, "com": True
}))
def test_source_to_node_to_null(run_segment,
source_cpp: bool,
node_cpp: bool,
Expand Down

0 comments on commit 2dbd985

Please sign in to comment.