Skip to content

Commit

Permalink
initial impl dynamic_batcher
Browse files Browse the repository at this point in the history
  • Loading branch information
yczhang-nv committed Jun 10, 2024
1 parent bf9b553 commit 7ac9ad9
Showing 1 changed file with 43 additions and 7 deletions.
50 changes: 43 additions & 7 deletions cpp/mrc/include/mrc/node/dynamic_batcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "mrc/utils/type_utils.hpp"

#include <glog/logging.h>
#include <rxcpp/operators/rx-buffer_time_count.hpp>
#include <rxcpp/rx.hpp>

#include <exception>
Expand All @@ -51,7 +52,8 @@ class DynamicBatcher : public mrc::node::WritableProvider<T>,
using output_t = std::vector<T>;

public:
DynamicBatcher(size_t max_count) {
DynamicBatcher(size_t max_count, std::chrono::milliseconds duration)
: m_max_count(max_count), m_duration(duration) {
// Set the default channel
mrc::node::SinkChannelOwner<input_t>::set_channel(
std::make_unique<mrc::channel::BufferedChannel<input_t>>());
Expand All @@ -65,14 +67,32 @@ class DynamicBatcher : public mrc::node::WritableProvider<T>,
* @brief Runnable's entrypoint.
*/
void run(mrc::runnable::Context &ctx) override {
T input_data;
auto status = this->get_readable_edge()->await_read(input_data);

// TODO(Yuchen): fill out the implementation here

// T input_data;
// auto status = this->get_readable_edge()->await_read(input_data);

// Create an observable from the input channel
auto input_observable =
rxcpp::observable<>::create<T>([this](rxcpp::subscriber<T> s) {
T input_data;
while (this->get_readable_edge()->await_read(input_data) ==
mrc::channel::Status::success) {
s.on_next(input_data);
}
s.on_completed();
});

// Buffer the items from the input observable
auto buffered_observable = input_observable.buffer_with_time_or_count(
m_duration, m_max_count, rxcpp::observe_on_new_thread());

// Subscribe to the buffered observable
buffered_observable.subscribe(
[this](const std::vector<T> &buffer) {
this->get_writable_edge()->await_write(buffer);
},
[]() {
// Handle completion
});

// Only drop the output edges if we are rank 0
if (ctx.rank() == 0) {
Expand All @@ -85,7 +105,23 @@ class DynamicBatcher : public mrc::node::WritableProvider<T>,
/**
* @brief Runnable's state control, for stopping from MRC.
*/
void on_state_update(const state_t &state) final;
void on_state_update(const state_t &state) final {
switch (state) {
case state_t::Stop:
// Do nothing, we wait for the upstream channel to return closed
// m_stop_source.request_stop();
break;

case state_t::Kill:
m_stop_source.request_stop();
break;

default:
break;
}
}

std::stop_source m_stop_source;
size_t m_max_count;
std::chrono::milliseconds m_duration;
};

0 comments on commit 7ac9ad9

Please sign in to comment.