Skip to content

Commit

Permalink
Fixing rendezvous payload types
Browse files Browse the repository at this point in the history
  • Loading branch information
mdemoret-nv committed Feb 15, 2024
1 parent 5adfb6d commit f37268b
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 17 deletions.
6 changes: 3 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,15 @@ set(CMAKE_CXX_EXTENSIONS ON)
set(CMAKE_POSITION_INDEPENDENT_CODE TRUE)
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)

add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:-fsanitize=address>")
add_link_options("$<$<COMPILE_LANGUAGE:CXX>:-fsanitize=address>")
# add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:-fsanitize=address>")
# add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:-fsanitize=address>")

# Setup cache before dependencies
# Configure CCache if requested
include(environment/init_ccache)

# Disable exporting compile commands for dependencies
set(CMAKE_EXPORT_COMPILE_COMMANDS OFF)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# Create a custom target to allow preparing for style checks
add_custom_target(${PROJECT_NAME}_style_checks
Expand Down
10 changes: 8 additions & 2 deletions cpp/mrc/include/mrc/runtime/remote_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ class TypedValueDescriptor : public ValueDescriptor
public:
~TypedValueDescriptor() override
{
LOG(INFO) << "TypedValueDescriptor::~TypedValueDescriptor()";
DVLOG(20) << "TypedValueDescriptor::~TypedValueDescriptor()";
}

const T& value() const
Expand Down Expand Up @@ -350,7 +350,10 @@ class LocalDescriptor2 : public ValueDescriptor

private:
LocalDescriptor2(std::unique_ptr<codable::LocalSerializedWrapper> encoded_object,
std::unique_ptr<ValueDescriptor> value_descriptor = nullptr);
std::unique_ptr<ValueDescriptor> value_descriptor);

LocalDescriptor2(std::unique_ptr<codable::LocalSerializedWrapper> encoded_object,
std::vector<memory::buffer> payload_buffers);

// TODO(MDD): Quick hack to get this working. Need to restructure the objects a bit
std::unique_ptr<codable::LocalSerializedWrapper> encode(
Expand All @@ -362,6 +365,9 @@ class LocalDescriptor2 : public ValueDescriptor
std::unique_ptr<codable::LocalSerializedWrapper> m_encoded_object;

std::unique_ptr<ValueDescriptor> m_value_descriptor; // Necessary to keep the value alive when serializing

// Holds onto the payloads when deserializing the object. Must have the same lifetime as the encoded object
std::vector<memory::buffer> m_payload_buffers;
};

template <typename T>
Expand Down
2 changes: 1 addition & 1 deletion cpp/mrc/src/internal/memory/block_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class BlockManager final
auto key = reinterpret_cast<std::uintptr_t>(block.data()) + block.bytes();
DCHECK(!owns(block.data()) && !owns(reinterpret_cast<void*>(key - 1))) << "block manager already owns a block "
"with an overlapping address";
DVLOG(10) << "adding block: " << key << " - " << block.data() << "; " << block.bytes();
DVLOG(20) << "adding block: " << key << " - " << block.data() << "; " << block.bytes();
m_block_map[key] = std::move(block);
return m_block_map[key];
}
Expand Down
19 changes: 11 additions & 8 deletions cpp/mrc/src/public/runtime/remote_descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,12 @@ std::unique_ptr<codable::IDecodableStorage> RemoteDescriptor::release_storage()

ValueDescriptor::~ValueDescriptor()
{
LOG(INFO) << "ValueDescriptor::~ValueDescriptor()";
DVLOG(20) << "ValueDescriptor::~ValueDescriptor()";
}

LocalDescriptor2::~LocalDescriptor2()
{
LOG(INFO) << "LocalDescriptor2::~LocalDescriptor2()";
DVLOG(20) << "LocalDescriptor2::~LocalDescriptor2()";

m_value_descriptor.reset();
}
Expand Down Expand Up @@ -215,9 +215,6 @@ std::unique_ptr<LocalDescriptor2> LocalDescriptor2::from_remote(std::unique_ptr<
f.get();
}

// Clear the memory out now that the requests have finished
buffers.clear();

PortAddress2 port_address(remote_descriptor->encoded_object().source_address());

// For the remote descriptor message, send decrement to the remote resources
Expand All @@ -234,7 +231,7 @@ std::unique_ptr<LocalDescriptor2> LocalDescriptor2::from_remote(std::unique_ptr<
UCS_MEMORY_TYPE_HOST,
ucxx::AmReceiverCallbackInfo("MRC", 0));

return std::unique_ptr<LocalDescriptor2>(new LocalDescriptor2(std::move(local_obj)));
return std::unique_ptr<LocalDescriptor2>(new LocalDescriptor2(std::move(local_obj), std::move(buffers)));
}

LocalDescriptor2::LocalDescriptor2(std::unique_ptr<codable::LocalSerializedWrapper> encoded_object,
Expand All @@ -243,6 +240,12 @@ LocalDescriptor2::LocalDescriptor2(std::unique_ptr<codable::LocalSerializedWrapp
m_value_descriptor(std::move(value_descriptor))
{}

LocalDescriptor2::LocalDescriptor2(std::unique_ptr<codable::LocalSerializedWrapper> encoded_object,
std::vector<memory::buffer> payload_buffers) :
m_encoded_object(std::move(encoded_object)),
m_payload_buffers(std::move(payload_buffers))
{}

RemoteDescriptorImpl2::RemoteDescriptorImpl2(std::unique_ptr<codable::protos::RemoteSerializedObject> encoded_object,
std::unique_ptr<LocalDescriptor2> local_descriptor) :
m_serialized_object(std::move(encoded_object)),
Expand All @@ -251,7 +254,7 @@ RemoteDescriptorImpl2::RemoteDescriptorImpl2(std::unique_ptr<codable::protos::Re

RemoteDescriptorImpl2::~RemoteDescriptorImpl2()
{
LOG(INFO) << "RemoteDescriptorImpl2::~RemoteDescriptorImpl2()";
DVLOG(20) << "RemoteDescriptorImpl2::~RemoteDescriptorImpl2()";

m_local_descriptor.reset();
}
Expand Down Expand Up @@ -346,7 +349,7 @@ RemoteDescriptor2::RemoteDescriptor2(std::shared_ptr<RemoteDescriptorImpl2> impl

RemoteDescriptor2::~RemoteDescriptor2()
{
LOG(INFO) << "RemoteDescriptor2::~RemoteDescriptor2()";
DVLOG(20) << "RemoteDescriptor2::~RemoteDescriptor2()";

m_impl.reset();
}
Expand Down
54 changes: 51 additions & 3 deletions cpp/mrc/tests/test_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class TestExecutor : public ::testing::Test

static std::unique_ptr<pipeline::IPipeline> make_pipeline()
{
using transfer_t = std::vector<uint8_t>;
using transfer_t = std::vector<int>;

auto pipeline = mrc::make_pipeline();

Expand Down Expand Up @@ -107,9 +107,9 @@ class TestExecutor : public ::testing::Test
auto src = s.make_source<transfer_t>("rx_source", [](rxcpp::subscriber<transfer_t> s) {
using namespace mrc::memory::literals;

for (int i = 0; i < 2; i++)
for (int i = 0; i < 100; i++)
{
s.on_next(transfer_t(10, i));
s.on_next(transfer_t(1_MiB / sizeof(transfer_t::value_type), i));

// #ifndef NDEBUG
// boost::this_fiber::sleep_for(std::chrono::milliseconds(100));
Expand Down Expand Up @@ -491,6 +491,54 @@ TEST_F(TestExecutor, MultiNode)
machine_1.join();
}

TEST_F(TestExecutor, MultiNodeA)
{
auto options_1 = make_options();

options_1->architect_url("127.0.0.1:13337");
options_1->enable_server(true);

Executor machine_1(std::move(options_1));

auto pipeline_1 = make_pipeline();

auto& mapping_1 = machine_1.register_pipeline(std::move(pipeline_1));

mapping_1.get_segment("seg_4").set_enabled(false);

auto start_1 = boost::fibers::async([&] {
machine_1.start();
});

start_1.get();

machine_1.join();
}

TEST_F(TestExecutor, MultiNodeB)
{
auto options_2 = make_options();

options_2->architect_url("127.0.0.1:13337");
options_2->topology().user_cpuset("1");

Executor machine_2(std::move(options_2));

auto pipeline_2 = make_pipeline();

auto& mapping_2 = machine_2.register_pipeline(std::move(pipeline_2));

mapping_2.get_segment("seg_1").set_enabled(false);

auto start_2 = boost::fibers::async([&] {
machine_2.start();
});

start_2.get();

machine_2.join();
}

// TEST_F(TestExecutor, MultiNodeTwoSegmentExample)
// {
// GTEST_SKIP();
Expand Down

0 comments on commit f37268b

Please sign in to comment.