Skip to content

Commit

Permalink
temp commit
Browse files Browse the repository at this point in the history
  • Loading branch information
yczhang-nv committed Jun 18, 2024
1 parent 7ac9ad9 commit f9dbdd7
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 11 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ enable_language(CUDA)
set(MRC_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR})

# Set a default build type if none was specified
rapids_cmake_build_type(Release)
rapids_cmake_build_type(Debug)

set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "mrc/node/rx_subscribable.hpp"
#include "mrc/runnable/runnable.hpp"
#include "mrc/utils/type_utils.hpp"
#include "rxcpp/operators/rx-observe_on.hpp"

#include <glog/logging.h>
#include <rxcpp/operators/rx-buffer_time_count.hpp>
Expand All @@ -39,6 +40,7 @@
#include <memory>
#include <mutex>

namespace mrc::node {
template <typename T, typename ContextT>
class DynamicBatcher : public mrc::node::WritableProvider<T>,
public mrc::node::ReadableAcceptor<T>,
Expand Down Expand Up @@ -81,9 +83,11 @@ class DynamicBatcher : public mrc::node::WritableProvider<T>,
s.on_completed();
});

// DVLOG(1) << "DynamicBatcher: m_duration: " << m_duration.count() << std::endl;

// Buffer the items from the input observable
auto buffered_observable = input_observable.buffer_with_time_or_count(
m_duration, m_max_count, rxcpp::observe_on_new_thread());
m_duration, m_max_count, rxcpp::observe_on_event_loop());

// Subscribe to the buffered observable
buffered_observable.subscribe(
Expand Down Expand Up @@ -122,6 +126,7 @@ class DynamicBatcher : public mrc::node::WritableProvider<T>,
}

std::stop_source m_stop_source;
size_t m_max_count;
int m_max_count;
std::chrono::milliseconds m_duration;
};
} // namespace mrc::node
43 changes: 43 additions & 0 deletions cpp/mrc/tests/test_segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mrc/benchmarking/trace_statistics.hpp"
#include "mrc/exceptions/runtime_error.hpp"
#include "mrc/node/operators/broadcast.hpp"
#include "mrc/node/operators/dynamic_batcher.hpp"
#include "mrc/node/rx_node.hpp"
#include "mrc/node/rx_sink.hpp"
#include "mrc/node/rx_source.hpp"
Expand Down Expand Up @@ -1122,4 +1123,46 @@ TEST_F(TestSegment, SegmentGetEgressNotEgressError)
*/
}

TEST_F(TestSegment, SegmentDynamicBatcher)
{
unsigned int iterations{3};
std::atomic<unsigned int> sink1_results{0};
float sink2_results{0};
std::mutex mux;

auto init = [&](segment::IBuilder& segment) {
auto src = segment.make_source<int>("src", [&](rxcpp::subscriber<int>& s) {
for (size_t i = 0; i < iterations && s.is_subscribed(); i++)
{
s.on_next(1);
s.on_next(2);
s.on_next(3);
}

s.on_completed();
});

auto dynamic_batcher = segment.construct_object<node::DynamicBatcher<int, runnable::Context>>("dynamic_batcher", 2, std::chrono::milliseconds(100));

segment.make_edge(src, dynamic_batcher);

auto sink = segment.make_sink<std::vector<int>>("sink", [&](std::vector<int> x) {
DVLOG(1) << "Sink got vector" << std::endl;
for (auto i : x)
{
DVLOG(1) << "Sink got value: " << i << std::endl;
// sink1_results.fetch_add(i, std::memory_order_relaxed);
}
});

segment.make_edge(dynamic_batcher, sink);
};

auto segdef = Segment::create("dynamic_batcher_test", init);

auto pipeline = mrc::make_pipeline();
pipeline->register_segment(std::move(segdef));
execute_pipeline(std::move(pipeline));
}

} // namespace mrc
40 changes: 32 additions & 8 deletions mrc.code-workspace
Original file line number Diff line number Diff line change
Expand Up @@ -86,25 +86,49 @@
"type": "cppdbg"
},
{
"MIMode": "gdb",
"MIMode": "lldb",
"args": [],
"cwd": "${workspaceFolder}",
"environment": [],
"externalConsole": false,
"miDebuggerPath": "gdb",
"name": "debug bench_mrc.x",
"preLaunchTask": "C/C++: g++ build active file",
"program": "${workspaceFolder}/build/benchmarks/bench_mrc",
"miDebuggerPath": "lldb",
"name": "debug test_mrc.x with lldb",
// "preLaunchTask": "C/C++: g++ build active file",
"program": "${workspaceFolder}/build/cpp/mrc/tests/test_mrc.x",
"request": "launch",
"setupCommands": [
{
"description": "Enable pretty-printing for gdb",
"description": "Enable pretty-printing for lldb",
"ignoreFailures": true,
"text": "-enable-pretty-printing"
"text": "command script import pretty_printers.py"
}
],
"justMyCode": true,
"stopAtEntry": false,
"type": "cppdbg"
"type": "lldb"
},
{
"MIMode": "lldb",
"args": [],
"cwd": "${workspaceFolder}",
"environment": [],
"externalConsole": false,
"miDebuggerPath": "lldb",
"name": "debug TestSegment.SegmentDynamicBatcher with lldb",
// "preLaunchTask": "C/C++: g++ build active file",
"program": "${workspaceFolder}/build/cpp/mrc/tests/test_mrc.x",
"request": "launch",
"setupCommands": [
{
"description": "Enable pretty-printing for lldb",
"ignoreFailures": true,
"text": "command script import pretty_printers.py"
}
],
"args": ["--gtest_filter=TestSegment.SegmentDynamicBatcher"],
"justMyCode": true,
"stopAtEntry": false,
"type": "lldb"
},
{
"MIMode": "gdb",
Expand Down

0 comments on commit f9dbdd7

Please sign in to comment.