Skip to content

Commit

Permalink
Revert changes made to CMake and add thorough commenting
Browse files Browse the repository at this point in the history
  • Loading branch information
jadu-nv committed Aug 18, 2024
1 parent 3147b96 commit 19451ab
Show file tree
Hide file tree
Showing 12 changed files with 180 additions and 148 deletions.
42 changes: 42 additions & 0 deletions cpp/mrc/include/mrc/runtime/remote_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,21 +418,58 @@ std::unique_ptr<TypedValueDescriptor<T>> TypedValueDescriptor<T>::from_local(
new TypedValueDescriptor<T>(mrc::codable::decode2<T>(local_descriptor->encoded_object())));
}

/**
* @brief Descriptor2 class used to faciliate communication between any arbitrary pair of machines. Supports multi-node,
* multi-gpu communication, and asynchronous data transfer.
*/
class Descriptor2
{
public:
/**
* @brief Gets the protobuf object associated with this descriptor instance
*
* @return codable::protos::DescriptorObject&
*/
virtual codable::protos::DescriptorObject& encoded_object();

/**
* @brief Serialize the encoded object stored by this descriptor into a byte stream for remote communication
*
* @param mr Instance of memory_resource for allocating a memory_buffer to return
* @return memory::buffer
*/
memory::buffer serialize(std::shared_ptr<memory::memory_resource> mr);

/**
* @brief Deserialize the encoded object stored by this descriptor into a class T instance
*
* @return T
*/
template <typename T>
[[nodiscard]] const T deserialize();

/**
* @brief Creates a Descriptor2 instance from a class T value
*
* @param value class T instance
* @param data_plane_resources reference to DataPlaneResources2 for remote communication
* @return std::shared_ptr<Descriptor2>
*/
template <typename T>
static std::shared_ptr<Descriptor2> create_from_value(T value, data_plane::DataPlaneResources2& data_plane_resources);

/**
* @brief Creates a Descriptor2 instance from a byte stream
*
* @param view byte stream
* @param data_plane_resources reference to DataPlaneResources2 for remote communication
* @return std::shared_ptr<Descriptor2>
*/
static std::shared_ptr<Descriptor2> create_from_bytes(memory::buffer_view&& view, data_plane::DataPlaneResources2& data_plane_resources);

/**
* @brief Fetches all deferred payloads from the sending remote machine
*/
void fetch_remote_payloads();

protected:
Expand All @@ -452,12 +489,16 @@ class Descriptor2
data_plane::DataPlaneResources2& m_data_plane_resources;
};

/**
* @brief Class used for type erasure of Descriptor2 when serialized with class T instance
*/
template <typename T>
class TypedDescriptor : public Descriptor2
{
public:
codable::protos::DescriptorObject& encoded_object()
{
// If the encoded object does not exist yet, lazily create it
if (!m_encoded_object)
{
m_encoded_object = std::move(mrc::codable::encode2<T>(std::any_cast<const T&>(m_value)));
Expand All @@ -470,6 +511,7 @@ class TypedDescriptor : public Descriptor2
template <typename U>
friend std::shared_ptr<Descriptor2> Descriptor2::create_from_value(U value, data_plane::DataPlaneResources2& data_plane_resources);

// Private constructor to prohibit instantiation of this class outside of use in create_from_value
TypedDescriptor(T value, data_plane::DataPlaneResources2& data_plane_resources):
Descriptor2(std::move(value), data_plane_resources) {}
};
Expand Down
37 changes: 25 additions & 12 deletions cpp/mrc/src/internal/data_plane/data_plane_resources.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,10 @@ DataPlaneResources2::DataPlaneResources2()
m_registration_cache3 = std::make_shared<ucx::RegistrationCache3>(m_context);

// When DataPlanResources2 initializes, m_max_remote_descriptors is initialized to std::numeric_limits<uint64_t>::max()
// By default, m_remote_descriptors_semaphore should have capacity = practical limit, 100000
// By default, the following should have capacity = practical limit, 100000
m_recv_descriptors = std::unique_ptr<coroutines::ClosableRingBuffer<std::shared_ptr<runtime::Descriptor2>>>(
new coroutines::ClosableRingBuffer<std::shared_ptr<runtime::Descriptor2>>({.capacity = uint64_t(100000)}));

m_remote_descriptors_semaphore = std::unique_ptr<coroutines::Semaphore>(
new coroutines::Semaphore{{.capacity = static_cast<uint64_t>(100000)}});

Expand All @@ -180,6 +183,7 @@ DataPlaneResources2::DataPlaneResources2()
auto* message = reinterpret_cast<remote_descriptor::DescriptorPullCompletionMessage*>(
req->getRecvBuffer()->data());

// Await for callback function to decrement shared_ptr reference count or signal end-of-life of a descriptor
coroutines::sync_wait(complete_remote_pull(message));
});
m_worker->registerAmReceiverCallback(
Expand All @@ -199,12 +203,16 @@ DataPlaneResources2::DataPlaneResources2()
req->getRecvBuffer()->getSize(),
mrc::memory::memory_kind::host);

std::shared_ptr<runtime::Descriptor2> recv_descriptor = runtime::Descriptor2::create_from_bytes(std::move(buffer_view), *this);
// Create the descriptor object from received data
// Note that we do not immediately call fetch_remote_payloads. This callback function runs on a UCXX thread and
// should freed ASAP. Defer remote payload pulling when descriptor object is actually consumed
std::shared_ptr<runtime::Descriptor2> recv_descriptor =
runtime::Descriptor2::create_from_bytes(std::move(buffer_view), *this);

// Although ClosableRingBuffer::write is a coroutine, write always completes instantaneously without awaiting.
// ClosablRingBuffer size is always >= m_max_remote_descriptors, so there is always an empty slot.
auto write_descriptor = [this, recv_descriptor]() -> coroutines::Task<void> {
co_await m_recv_descriptors.write(recv_descriptor);
co_await m_recv_descriptors->write(recv_descriptor);
co_return;
};

Expand Down Expand Up @@ -340,7 +348,6 @@ std::shared_ptr<ucxx::Request> DataPlaneResources2::memory_recv_async(std::share
uintptr_t remote_addr,
const std::string& serialized_rkey)
{
// Const cast away because UCXX only accepts void*
auto rkey = ucxx::createRemoteKeyFromSerialized(endpoint, serialized_rkey);
auto request = endpoint->memGet(addr, bytes, rkey);

Expand Down Expand Up @@ -408,8 +415,9 @@ coroutines::Task<std::shared_ptr<ucxx::Request>> DataPlaneResources2::await_am_s
{
coroutines::Event event{};

// Const cast away because UCXX only accepts void*
auto request = endpoint->amSend(const_cast<void*>(buffer_view.data()),
// Use AmReceiverCallbackInfo to handle receiving/processing message downstream and lambda callback function
// to signal send request completion
auto request = endpoint->amSend(const_cast<void*>(buffer_view.data()), // Const cast away, UCXX only accepts void*
buffer_view.bytes(),
ucx::to_ucs_memory_type(buffer_view.kind()),
ucxx::AmReceiverCallbackInfo("MRC", 1),
Expand Down Expand Up @@ -440,7 +448,6 @@ coroutines::Task<std::shared_ptr<ucxx::Request>> DataPlaneResources2::await_send
std::shared_ptr<ucxx::Endpoint> endpoint)
{
// Wait until there is an empty slot to register remote descriptor
// Require register_remote_descriptor before descriptor serialize as serialize requires object_id to be in the protobuf
uint64_t object_id = co_await this->register_remote_descriptor(send_descriptor);

// Serialize the descriptor's protobuf into a byte stream for remote communication
Expand All @@ -452,17 +459,19 @@ coroutines::Task<std::shared_ptr<ucxx::Request>> DataPlaneResources2::await_send

coroutines::Task<std::shared_ptr<runtime::Descriptor2>> DataPlaneResources2::await_recv_descriptor()
{
auto read_element = co_await m_recv_descriptors.read();
// Await and get descriptor object from shared buffer
auto read_element = co_await m_recv_descriptors->read();
std::shared_ptr<runtime::Descriptor2> recv_descriptor = std::move(*read_element);

// Now that user is consuming the descriptor object, pull deferred payloads from remote machine
recv_descriptor->fetch_remote_payloads();

co_return recv_descriptor;
}

coroutines::Task<uint64_t> DataPlaneResources2::register_remote_descriptor(std::shared_ptr<runtime::Descriptor2> descriptor)
{
// If the descriptor has an object_id > 0, the descriptor has already been registered and should not be re-registered
// If the descriptor has an object_id > 0, descriptor has already been registered and should not be re-registered
auto object_id = descriptor->encoded_object().object_id();
if (object_id > 0)
{
Expand All @@ -477,8 +486,7 @@ coroutines::Task<uint64_t> DataPlaneResources2::register_remote_descriptor(std::
descriptor->encoded_object().set_object_id(object_id);

// Wait for semaphore to ensure that we have an empty slot to register the current descriptor
co_await m_remote_descriptors_semaphore->acquire(); // Directly await the semaphore

co_await m_remote_descriptors_semaphore->acquire();
{
std::unique_lock lock(m_remote_descriptors_mutex);
m_descriptor_by_id[object_id].push_back(descriptor);
Expand All @@ -494,7 +502,6 @@ coroutines::Task<void> DataPlaneResources2::complete_remote_pull(remote_descript
// Once we've completed pulling of a descriptor, we remove a descriptor shared ptr from the vector
// When the vector becomes empty, there will be no more shared ptrs pointing to the descriptor object,
// it will be destructed accordingly.
// We should also remove that mapping as the object_id corresponding to that mapping will not be reused.
auto& descriptors = m_descriptor_by_id[message->object_id];
descriptors.pop_back();
if (descriptors.size() == 0)
Expand Down Expand Up @@ -523,6 +530,12 @@ uint64_t DataPlaneResources2::registered_remote_descriptor_ptr_count(uint64_t ob
void DataPlaneResources2::set_max_remote_descriptors(uint64_t max_remote_descriptors)
{
m_max_remote_descriptors = max_remote_descriptors;

// Update the remote descriptor ClosableRingBuffer and Semaphore capacity
m_recv_descriptors = std::unique_ptr<coroutines::ClosableRingBuffer<std::shared_ptr<runtime::Descriptor2>>>(
new coroutines::ClosableRingBuffer<std::shared_ptr<runtime::Descriptor2>>(
{.capacity = std::min(m_max_remote_descriptors, uint64_t(100000))}));

m_remote_descriptors_semaphore = std::unique_ptr<coroutines::Semaphore>(
new coroutines::Semaphore{{.capacity = std::min(m_max_remote_descriptors, static_cast<uint64_t>(100000))}});
}
Expand Down
13 changes: 11 additions & 2 deletions cpp/mrc/src/internal/data_plane/data_plane_resources.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class DataPlaneResources2
bool has_instance_id() const;
uint64_t get_instance_id() const;

// Should only be called when there are no in-flight messages as m_recv_descriptors will be reset
void set_max_remote_descriptors(uint64_t max_remote_descriptors);

ucxx::Context& context() const;
Expand Down Expand Up @@ -202,14 +203,19 @@ class DataPlaneResources2
std::size_t bytes,
ucs_memory_type_t mem_type);

// Coroutine to asynchronously send message to remote machine
coroutines::Task<std::shared_ptr<ucxx::Request>> await_am_send(std::shared_ptr<ucxx::Endpoint> endpoint,
memory::const_buffer_view buffer_view);

std::shared_ptr<ucxx::Request> am_recv_async(std::shared_ptr<ucxx::Endpoint> endpoint);

// Coroutine to async register, serialize, and send a descriptor to the specified endpoint
// Relies on callback to receive the message. Must be used in tandem with await_recv_descriptor
coroutines::Task<std::shared_ptr<ucxx::Request>> await_send_descriptor(
std::shared_ptr<runtime::Descriptor2> send_descriptor,
std::shared_ptr<ucxx::Endpoint> endpoint);

// Coroutine to async await on new descriptor object in shared buffer, fetch deferred payloads from remote machine
coroutines::Task<std::shared_ptr<runtime::Descriptor2>> await_recv_descriptor();

coroutines::Task<uint64_t> register_remote_descriptor(std::shared_ptr<runtime::Descriptor2> descriptor);
Expand Down Expand Up @@ -238,6 +244,8 @@ class DataPlaneResources2

uint64_t get_next_object_id();

// Callback function to decrement shared_ptr reference count or signal end-of-life of a descriptor object
// Requires awaiting on the release of coroutines::Semaphore
coroutines::Task<void> complete_remote_pull(remote_descriptor::DescriptorPullCompletionMessage* message);

uint64_t m_max_remote_descriptors{std::numeric_limits<uint64_t>::max()};
Expand All @@ -249,10 +257,11 @@ class DataPlaneResources2
boost::fibers::mutex m_remote_descriptors_mutex{};

// ClosableRingBuffer uses 100000 as a "practical" limit where the capacity is the minimum of the two values.
coroutines::ClosableRingBuffer<std::shared_ptr<runtime::Descriptor2>> m_recv_descriptors{
{.capacity = std::min(m_max_remote_descriptors, static_cast<uint64_t>(100000))}};
std::unique_ptr<coroutines::ClosableRingBuffer<std::shared_ptr<runtime::Descriptor2>>> m_recv_descriptors;

protected:
// Maps descriptor id to a vector of shared_ptr instances
// Uses std::shared_ptr reference counting for maintaining the lifetime of a descriptor object
std::map<uint64_t, std::vector<std::shared_ptr<runtime::Descriptor2>>> m_descriptor_by_id;
};

Expand Down
1 change: 1 addition & 0 deletions cpp/mrc/src/internal/ucx/registration_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ std::optional<std::shared_ptr<ucxx::MemoryHandle>> RegistrationCache3::lookup(ui
{
std::lock_guard<decltype(m_mutex)> lock(m_mutex);

// The descriptor obj_id and memory block addr must both be valid
if (m_memory_handle_by_address.find(obj_id) != m_memory_handle_by_address.end())
{
auto descriptor_handles = m_memory_handle_by_address.at(obj_id);
Expand Down
14 changes: 13 additions & 1 deletion cpp/mrc/src/internal/ucx/registration_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ class RegistrationCache2 final
* @brief UCX Registration Cache
*
* UCX memory registration object that will both register/deregister memory. The cache can be queried for the original
* memory block by providing the starting address of the contiguous block.
* memory block by providing the id of the descriptor object and the starting address of the contiguous block.
*/
class RegistrationCache3 final
{
Expand All @@ -194,8 +194,11 @@ class RegistrationCache3 final
* For each block of memory registered with the RegistrationCache, an entry containing the block information is
* storage and can be queried.
*
* @param obj_id ID of the descriptor object that owns the memory block being registered
* @param addr
* @param bytes
* @param memory_type
* @return std::shared_ptr<ucxx::MemoryHandle>
*/
std::shared_ptr<ucxx::MemoryHandle> add_block(uint64_t obj_id, void* addr, std::size_t bytes, memory::memory_kind memory_type);

Expand All @@ -207,13 +210,22 @@ class RegistrationCache3 final
* This method queries the registration cache to find the MemoryHanlde containing the original address and size as
* well as the serialized remote keys associated with the memory block.
*
* @param obj_id ID of the descriptor object that owns the memory block being registered
* @param addr
* @return std::shared_ptr<ucxx::MemoryHandle>
*/
std::optional<std::shared_ptr<ucxx::MemoryHandle>> lookup(uint64_t obj_id, const void* addr) const noexcept;

std::optional<std::shared_ptr<ucxx::MemoryHandle>> lookup(uint64_t obj_id, uintptr_t addr) const noexcept;

/**
* @brief Deregistration of all memory blocks owned by the descriptor object with id obj_id
*
* This method deregisters all memory blocks owned by the descriptor object at the end of the descriptor's lifetime.
* Required so the system does not run into memory insufficiency errors.
*
* @param obj_id ID of the descriptor object that owns the memory block being registered
*/
void remove_descriptor(uint64_t obj_id);

private:
Expand Down
2 changes: 2 additions & 0 deletions cpp/mrc/src/public/codable/decode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ void DecoderBase::read_descriptor(memory::buffer_view dst_view) const
else
{
const auto& deferred_msg = payload.deferred_msg();

// Depending on the message memory type, we will use a different memcpy method to properly copy the data
switch (payload.memory_kind())
{
case protos::MemoryKind::Host:
Expand Down
5 changes: 5 additions & 0 deletions cpp/mrc/src/public/codable/encode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ EncoderBase::EncoderBase(DescriptorObjectHandler& encoded_object) :

void EncoderBase::write_descriptor(memory::const_buffer_view view)
{
// Static check with arbitrary memory size to determine whether we should use eager or deferred protocol
// Thorough benchmarking and analysis should be done to derive protocol selection heuristic
MessageKind kind = (view.bytes() < 64_KiB) ? MessageKind::Eager : MessageKind::Deferred;

protos::Payload* payload = m_encoded_object.proto().add_payloads();
Expand All @@ -37,10 +39,12 @@ void EncoderBase::write_descriptor(memory::const_buffer_view view)
switch (kind)
{
case MessageKind::Eager: {
// If the message is allocated on device memory, we should fall through and default to using deferred protocol
if (view.kind() == memory::memory_kind::host)
{
auto* eager_msg = payload->mutable_eager_msg();

// Directly set the data for eager payload
eager_msg->set_data(view.data(), view.bytes());

return;
Expand All @@ -49,6 +53,7 @@ void EncoderBase::write_descriptor(memory::const_buffer_view view)
case MessageKind::Deferred: {
auto* deferred_msg = payload->mutable_deferred_msg();

// Set the payload address and number of bytes for later RDMA operation
deferred_msg->set_address(reinterpret_cast<uintptr_t>(view.data()));
deferred_msg->set_bytes(view.bytes());

Expand Down
5 changes: 3 additions & 2 deletions cpp/mrc/src/public/runtime/remote_descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,11 +383,12 @@ void Descriptor2::setup_remote_payloads()

auto* deferred_msg = payload.mutable_deferred_msg();

// Look for the memory block in the registration cache
auto ucx_block = m_data_plane_resources.registration_cache3().lookup(remote_object.object_id(), deferred_msg->address());

if (!ucx_block.has_value())
{
// Need to register the memory
// Given that the memory block is not registered, we must register the memory
ucx_block = m_data_plane_resources.registration_cache3().add_block(remote_object.object_id(),
deferred_msg->address(),
deferred_msg->bytes(),
Expand All @@ -409,7 +410,7 @@ void Descriptor2::fetch_remote_payloads()
// Loop over all remote payloads and convert them to local payloads
for (auto& remote_payload : *m_encoded_object->proto().mutable_payloads())
{
// If payload is an EagerMessage, we do not need to do any pulling
// If payload is an EagerMessage, we do not need to do RDMA operations on remote sending machine
if (remote_payload.has_eager_msg())
{
continue;
Expand Down
18 changes: 18 additions & 0 deletions cpp/mrc/src/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,28 @@ add_executable(test_mrc_private
pipelines/multi_segment.cpp
pipelines/single_segment.cpp
segments/common_segments.cpp
test_codable.cpp
# test_control_plane_components.cpp
# test_control_plane.cpp
test_expected.cpp
test_grpc.cpp
test_main.cpp
test_memory.cpp
test_network.cpp
test_next.cpp
# test_partition_manager.cpp
# test_partitions.cpp
test_pipeline.cpp
test_ranges.cpp
# test_remote_descriptor.cpp
# test_resources.cpp
test_reusable_pool.cpp
# test_runnable.cpp
# test_runtime.cpp
test_service.cpp
test_system.cpp
test_topology.cpp
test_ucx.cpp
)

target_link_libraries(test_mrc_private
Expand Down
Loading

0 comments on commit 19451ab

Please sign in to comment.