diff --git a/MODULE.bazel.lock b/MODULE.bazel.lock index 23e63bec..c620aa0b 100644 --- a/MODULE.bazel.lock +++ b/MODULE.bazel.lock @@ -106,13 +106,14 @@ "https://bcr.bazel.build/modules/openssl/3.3.2/MODULE.bazel": "not found", "https://bcr.bazel.build/modules/platforms/0.0.10/MODULE.bazel": "8cb8efaf200bdeb2150d93e162c40f388529a25852b332cec879373771e48ed5", "https://bcr.bazel.build/modules/platforms/0.0.11/MODULE.bazel": "0daefc49732e227caa8bfa834d65dc52e8cc18a2faf80df25e8caea151a9413f", - "https://bcr.bazel.build/modules/platforms/0.0.11/source.json": "f7e188b79ebedebfe75e9e1d098b8845226c7992b307e28e1496f23112e8fc29", "https://bcr.bazel.build/modules/platforms/0.0.4/MODULE.bazel": "9b328e31ee156f53f3c416a64f8491f7eb731742655a47c9eec4703a71644aee", "https://bcr.bazel.build/modules/platforms/0.0.5/MODULE.bazel": "5733b54ea419d5eaf7997054bb55f6a1d0b5ff8aedf0176fef9eea44f3acda37", "https://bcr.bazel.build/modules/platforms/0.0.6/MODULE.bazel": "ad6eeef431dc52aefd2d77ed20a4b353f8ebf0f4ecdd26a807d2da5aa8cd0615", "https://bcr.bazel.build/modules/platforms/0.0.7/MODULE.bazel": "72fd4a0ede9ee5c021f6a8dd92b503e089f46c227ba2813ff183b71616034814", "https://bcr.bazel.build/modules/platforms/0.0.8/MODULE.bazel": "9f142c03e348f6d263719f5074b21ef3adf0b139ee4c5133e2aa35664da9eb2d", "https://bcr.bazel.build/modules/platforms/0.0.9/MODULE.bazel": "4a87a60c927b56ddd67db50c89acaa62f4ce2a1d2149ccb63ffd871d5ce29ebc", + "https://bcr.bazel.build/modules/platforms/1.0.0/MODULE.bazel": "f05feb42b48f1b3c225e4ccf351f367be0371411a803198ec34a389fb22aa580", + "https://bcr.bazel.build/modules/platforms/1.0.0/source.json": "f4ff1fd412e0246fd38c82328eb209130ead81d62dcd5a9e40910f867f733d96", "https://bcr.bazel.build/modules/protobuf/21.7/MODULE.bazel": "a5a29bb89544f9b97edce05642fac225a808b5b7be74038ea3640fae2f8e66a7", "https://bcr.bazel.build/modules/protobuf/27.0/MODULE.bazel": "7873b60be88844a0a1d8f80b9d5d20cfbd8495a689b8763e76c6372998d3f64c", "https://bcr.bazel.build/modules/protobuf/27.1/MODULE.bazel": "703a7b614728bb06647f965264967a8ef1c39e09e8f167b3ca0bb1fd80449c0d", @@ -316,6 +317,7 @@ "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/platforms/0.0.7/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/platforms/0.0.8/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/platforms/0.0.9/MODULE.bazel": "not found", + "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/platforms/1.0.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/protobuf/21.7/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/protobuf/27.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/protobuf/27.1/MODULE.bazel": "not found", diff --git a/yacl/link/BUILD.bazel b/yacl/link/BUILD.bazel index 361b0442..0a0e71d1 100644 --- a/yacl/link/BUILD.bazel +++ b/yacl/link/BUILD.bazel @@ -143,3 +143,29 @@ cc_proto_library( ":link_proto", ], ) + +yacl_cc_library( + name = "mbox_capi", + srcs = ["mbox_capi.cc"], + hdrs = ["mbox_capi.h"], +) + +yacl_cc_library( + name = "mbox_wrapper", + srcs = ["mbox_wrapper.cc"], + hdrs = ["mbox_wrapper.h"], + deps = [ + ":mbox_capi", + "@abseil-cpp//absl/types:span", + ], +) + +yacl_cc_library( + name = "link_bridge", + srcs = ["link_bridge.cc"], + hdrs = ["link_bridge.h"], + deps = [ + ":mbox_capi", + "//yacl/link/transport:channel", + ], +) diff --git a/yacl/link/link_bridge.cc b/yacl/link/link_bridge.cc new file mode 100644 index 00000000..96f56057 --- /dev/null +++ b/yacl/link/link_bridge.cc @@ -0,0 +1,151 @@ +// Copyright 2025 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/link/link_bridge.h" + +#include +#include +#include + +#include "yacl/base/buffer.h" +#include "yacl/base/byte_container_view.h" +#include "yacl/base/exception.h" +#include "yacl/link/mbox_capi.h" +#include "yacl/link/transport/channel.h" + +namespace yacl::link { + +// Structure to hold the channel and msgloop data +struct ChannelMboxData { + std::vector> channels; + std::shared_ptr msg_loop; +}; + +// Send function implementation using channels +static mbox_error_t channel_send_fn(void* user_data, size_t dst, + const char* key, const uint8_t* data, + size_t data_len) { + if (user_data == nullptr || key == nullptr || + (data == nullptr && data_len > 0)) { + return MBOX_ERROR_INVALID_ARGUMENT; + } + + auto* channel_data = static_cast(user_data); + if (dst >= channel_data->channels.size() || + channel_data->channels[dst] == nullptr) { + return MBOX_ERROR_INVALID_ARGUMENT; + } + + try { + channel_data->channels[dst]->Send(key, ByteContainerView(data, data_len)); + return MBOX_SUCCESS; + } catch (const std::bad_alloc&) { + return MBOX_ERROR_MEMORY; + } catch (const std::exception& e) { + // Handle other exceptions as network errors + return MBOX_ERROR_NETWORK; + } +} + +// Receive function implementation using channels +static mbox_error_t channel_recv_fn(void* user_data, size_t src, + const char* key, int64_t timeout_ms, + uint8_t** buffer, size_t* buffer_len) { + if (user_data == nullptr || key == nullptr || buffer == nullptr || + buffer_len == nullptr) { + return MBOX_ERROR_INVALID_ARGUMENT; + } + + auto* channel_data = static_cast(user_data); + if (src >= channel_data->channels.size() || + channel_data->channels[src] == nullptr) { + return MBOX_ERROR_INVALID_ARGUMENT; + } + + try { + // Set timeout if provided + if (timeout_ms >= 0) { + channel_data->channels[src]->SetRecvTimeout(timeout_ms); + } + + Buffer received_data = channel_data->channels[src]->Recv(key); + + if (received_data.size() == 0) { + *buffer = nullptr; + *buffer_len = 0; + return MBOX_ERROR_NOT_FOUND; + } + + // TODO: zero-copy optimization + // Allocate buffer for received data + *buffer = static_cast(malloc(received_data.size())); + if (*buffer == nullptr) { + return MBOX_ERROR_MEMORY; + } + + std::memcpy(*buffer, received_data.data(), received_data.size()); + *buffer_len = received_data.size(); + + return MBOX_SUCCESS; + } catch (const std::bad_alloc&) { + return MBOX_ERROR_MEMORY; + } catch (const IoError& e) { + // TODO: refine exception handling + return MBOX_ERROR_NOT_FOUND; + } catch (const std::exception& e) { + return MBOX_ERROR_NETWORK; + } +} + +// Free function for user data +static void channel_free_user_data_fn(void* user_data) { + if (user_data != nullptr) { + delete static_cast(user_data); + } +} + +// Bridge function to create a mbox instance from channels and receiver loop +mbox_t* CreateMbox(std::vector> channels, + std::shared_ptr msg_loop) { + if (channels.empty() || msg_loop == nullptr) { + return nullptr; + } + + try { + // Create user data structure + auto* channel_data = new (std::nothrow) ChannelMboxData(); + if (channel_data == nullptr) { + return nullptr; + } + + channel_data->channels = std::move(channels); + channel_data->msg_loop = std::move(msg_loop); + + // Create vtable with channel-based functions + mbox_vtable_t vtable; + vtable.user_data = channel_data; + vtable.send_fn = channel_send_fn; + vtable.recv_fn = channel_recv_fn; + vtable.free_user_data_fn = channel_free_user_data_fn; + + // Create mbox instance using the vtable + return mbox_create(vtable); + } catch (const std::bad_alloc&) { + return nullptr; + } catch (...) { + return nullptr; + } +} + +} // namespace yacl::link \ No newline at end of file diff --git a/yacl/link/link_bridge.h b/yacl/link/link_bridge.h new file mode 100644 index 00000000..6a093cc2 --- /dev/null +++ b/yacl/link/link_bridge.h @@ -0,0 +1,27 @@ +// Copyright 2025 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "yacl/link/mbox_capi.h" +#include "yacl/link/transport/channel.h" + +namespace yacl::link { + +// Bridge function to create a mbox instance from channels and receiver loop +mbox_t* CreateMbox(std::vector> channels, + std::shared_ptr msg_loop); + +} // namespace yacl::link \ No newline at end of file diff --git a/yacl/link/mbox_capi.cc b/yacl/link/mbox_capi.cc new file mode 100644 index 00000000..ec0c5658 --- /dev/null +++ b/yacl/link/mbox_capi.cc @@ -0,0 +1,99 @@ +// Copyright 2025 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/link/mbox_capi.h" + +#include +#include +#include + +// mbox_t is defined in the header as an opaque type +struct mbox_s { + mbox_vtable_t impl; +}; + +extern "C" { + +mbox_t* mbox_create(mbox_vtable_t vtable) { + if (vtable.send_fn == nullptr || vtable.recv_fn == nullptr) { + return nullptr; + } + + try { + mbox_t* mbox = new (std::nothrow) mbox_t(); + if (mbox == nullptr) { + return nullptr; + } + mbox->impl = vtable; + return mbox; + } catch (...) { + return nullptr; + } +} + +void mbox_destroy(mbox_t* mbox) { + if (mbox == nullptr) { + return; + } + + // Call the free function if provided + if (mbox->impl.free_user_data_fn != nullptr && + mbox->impl.user_data != nullptr) { + mbox->impl.free_user_data_fn(mbox->impl.user_data); + } + + delete mbox; +} + +mbox_error_t mbox_send(mbox_t* mbox, size_t dst, const char* key, + const uint8_t* data, size_t data_len) { + if (mbox == nullptr) { + return MBOX_ERROR_NOT_INITIALIZED; + } + + if (key == nullptr) { + return MBOX_ERROR_INVALID_ARGUMENT; + } + + if (data == nullptr && data_len > 0) { + return MBOX_ERROR_INVALID_ARGUMENT; + } + + if (mbox->impl.send_fn == nullptr) { + return MBOX_ERROR_INTERNAL; + } + + return mbox->impl.send_fn(mbox->impl.user_data, dst, key, data, data_len); +} + +mbox_error_t mbox_recv(mbox_t* mbox, size_t src, const char* key, + int64_t timeout_ms, uint8_t** buffer, + size_t* buffer_len) { + if (mbox == nullptr) { + return MBOX_ERROR_NOT_INITIALIZED; + } + + if (key == nullptr || buffer == nullptr || buffer_len == nullptr) { + return MBOX_ERROR_INVALID_ARGUMENT; + } + + if (mbox->impl.recv_fn == nullptr) { + return MBOX_ERROR_INTERNAL; + } + + return mbox->impl.recv_fn(mbox->impl.user_data, src, key, timeout_ms, buffer, + buffer_len); +} + +} // extern "C" \ No newline at end of file diff --git a/yacl/link/mbox_capi.h b/yacl/link/mbox_capi.h new file mode 100644 index 00000000..6524a6b5 --- /dev/null +++ b/yacl/link/mbox_capi.h @@ -0,0 +1,94 @@ +// Copyright 2025 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include + +/// Error codes for cross-language error handling. +typedef enum { + MBOX_SUCCESS = 0, ///< Operation completed successfully. + MBOX_ERROR_INVALID_ARGUMENT = -1, ///< Invalid argument provided. + MBOX_ERROR_NOT_FOUND = -2, ///< Message not found or timeout. + MBOX_ERROR_MEMORY = -3, ///< Memory allocation failed. + MBOX_ERROR_NETWORK = -4, ///< Network communication error. + MBOX_ERROR_INTERNAL = -5, ///< Internal error in implementation. + MBOX_ERROR_NOT_INITIALIZED = -6 ///< Mbox instance not properly initialized. +} mbox_error_t; + +// Opaque handle for mbox instance +typedef struct mbox_s mbox_t; +typedef struct mbox_vtable_s { + void* user_data; // user implementation pointer. + // Send function pointer + mbox_error_t (*send_fn)(void* user_data, size_t dst, const char* key, + const uint8_t* data, size_t data_len); + // Recv function pointer + mbox_error_t (*recv_fn)(void* user_data, size_t src, const char* key, + int64_t timeout_ms, uint8_t** buffer, + size_t* buffer_len); + // Free user data function pointer + void (*free_user_data_fn)(void* user_data); +} mbox_vtable_t; + +/// Creates a new mbox instance using the default C++ implementation. +/// +/// @return A new mbox instance, or nullptr on failure. +mbox_t* mbox_create(mbox_vtable_t vtable); + +/// Destroys a mbox instance created by mbox_create(). +/// +/// @param mbox The mbox instance to destroy. If nullptr, the function does +/// nothing. After destruction, the pointer becomes invalid. +void mbox_destroy(mbox_t* mbox); + +/// Sends a message to a specific destination. +/// +/// @param mbox The mbox instance. +/// @param dst Destination party ID (0-based index). +/// @param key Message identifier (null-terminated string). +/// @param data Raw message data to send. +/// @param data_len Length of data in bytes. +/// +/// @return MBOX_SUCCESS on success, appropriate error code on failure. +mbox_error_t mbox_send(mbox_t* mbox, size_t dst, const char* key, + const uint8_t* data, size_t data_len); + +/// Receives a message from a specific source. +/// +/// @param mbox The mbox instance. +/// @param src Source party ID to receive from (0-based index). +/// @param key Message identifier to receive (null-terminated string). +/// @param timeout_ms Timeout in milliseconds (-1 for infinite wait). +/// @param buffer Output parameter set to a newly allocated buffer +/// containing the received data. The caller must free this +/// buffer with free(). +/// @param buffer_len Output parameter set to the length of the received data. +/// +/// @return MBOX_SUCCESS on success, appropriate error code on failure. +/// +/// @note On success, `*buffer_len` contains the number of bytes received. +/// @note The returned buffer must be freed by the caller using free(). +mbox_error_t mbox_recv(mbox_t* mbox, size_t src, const char* key, + int64_t timeout_ms, uint8_t** buffer, + size_t* buffer_len); + +#ifdef __cplusplus +} +#endif diff --git a/yacl/link/mbox_wrapper.cc b/yacl/link/mbox_wrapper.cc new file mode 100644 index 00000000..392320a9 --- /dev/null +++ b/yacl/link/mbox_wrapper.cc @@ -0,0 +1,127 @@ +// Copyright 2025 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/link/mbox_wrapper.h" + +#include + +namespace yacl::link { + +MboxWrapper::MboxWrapper(mbox_t* mbox, size_t rank, size_t world_size, + bool take_ownership) + : mbox_(mbox), + owns_mbox_(take_ownership), + rank_(rank), + world_size_(world_size) { + if (mbox_ == nullptr) { + throw std::invalid_argument("Cannot wrap nullptr mbox"); + } +} + +MboxWrapper::~MboxWrapper() { + if (mbox_ != nullptr && owns_mbox_) { + mbox_destroy(mbox_); + mbox_ = nullptr; + } +} + +MboxWrapper::MboxWrapper(MboxWrapper&& other) noexcept + : mbox_(other.mbox_), + owns_mbox_(other.owns_mbox_), + rank_(other.rank_), + world_size_(other.world_size_) { + other.mbox_ = nullptr; + other.owns_mbox_ = false; +} + +MboxWrapper& MboxWrapper::operator=(MboxWrapper&& other) noexcept { + if (this != &other) { + if (mbox_ != nullptr && owns_mbox_) { + mbox_destroy(mbox_); + } + + mbox_ = other.mbox_; + owns_mbox_ = other.owns_mbox_; + rank_ = other.rank_; + world_size_ = other.world_size_; + + other.mbox_ = nullptr; + other.owns_mbox_ = false; + other.rank_ = 0; + other.world_size_ = 1; + } + return *this; +} + +void MboxWrapper::Send(size_t dst, std::string_view key, + absl::Span data) { + mbox_error_t result = + mbox_send(mbox_, dst, key.data(), data.data(), data.size()); + + switch (result) { + case MBOX_SUCCESS: + return; + case MBOX_ERROR_INVALID_ARGUMENT: + throw std::invalid_argument("Invalid arguments provided to Send"); + case MBOX_ERROR_MEMORY: + throw std::bad_alloc(); + case MBOX_ERROR_NETWORK: + throw std::runtime_error("Network error during Send"); + case MBOX_ERROR_INTERNAL: + throw std::runtime_error("Internal error during Send"); + default: + throw std::runtime_error("Unknown error during Send"); + } +} + +std::vector MboxWrapper::Recv(size_t src, std::string_view key, + int64_t timeout_ms) { + uint8_t* buffer = nullptr; + size_t buffer_len = 0; + + mbox_error_t result = + mbox_recv(mbox_, src, key.data(), timeout_ms, &buffer, &buffer_len); + + // Always create a cleanup guard for the buffer + struct BufferGuard { + uint8_t* ptr; + ~BufferGuard() { + if (ptr) free(ptr); + } + } guard{buffer}; + + switch (result) { + case MBOX_SUCCESS: { + if (buffer == nullptr || buffer_len == 0) { + return {}; + } + std::vector data(buffer, buffer + buffer_len); + return data; + } + case MBOX_ERROR_INVALID_ARGUMENT: + throw std::invalid_argument("Invalid arguments provided to Recv"); + case MBOX_ERROR_NOT_FOUND: + return {}; // Return empty vector for timeout/not found + case MBOX_ERROR_MEMORY: + throw std::bad_alloc(); + case MBOX_ERROR_NETWORK: + throw std::runtime_error("Network error during Recv"); + case MBOX_ERROR_INTERNAL: + throw std::runtime_error("Internal error during Recv"); + default: + throw std::runtime_error("Unknown error during Recv"); + } +} + +} // namespace yacl::link \ No newline at end of file diff --git a/yacl/link/mbox_wrapper.h b/yacl/link/mbox_wrapper.h new file mode 100644 index 00000000..76aaf957 --- /dev/null +++ b/yacl/link/mbox_wrapper.h @@ -0,0 +1,86 @@ +// Copyright 2025 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "absl/types/span.h" + +#include "yacl/link/mbox_capi.h" + +namespace yacl::link { + +/// MboxWrapper provides a C++ wrapper around the C API mbox interface. +/// This replaces the previous abstract Mbox class with a concrete +/// implementation that uses the C API internally. +class MboxWrapper { + public: + /// Constructor - wraps an existing mbox_t instance. + /// + /// @param mbox The mbox instance to wrap. + /// @param take_ownership If true, the wrapper will take ownership and destroy + /// the mbox when destroyed. If false, the wrapper will + /// not destroy the mbox. + MboxWrapper(mbox_t* mbox, size_t rank, size_t world_size, + bool take_ownership = false); + + /// Destructor - cleans up the mbox instance if owned. + ~MboxWrapper(); + + /// Move constructor. + MboxWrapper(MboxWrapper&& other) noexcept; + + /// Move assignment operator. + MboxWrapper& operator=(MboxWrapper&& other) noexcept; + + // Disable copy constructor and copy assignment + MboxWrapper(const MboxWrapper&) = delete; + MboxWrapper& operator=(const MboxWrapper&) = delete; + + /// Send a message to the specified destination. + /// + /// @param dst The destination rank. + /// @param key The message key. + /// @param data The message data. + void Send(size_t dst, std::string_view key, absl::Span data); + + /// Receive a message from the specified source. + /// + /// @param src The source rank. + /// @param key The message key. + /// @param timeout_ms Timeout in milliseconds (-1 for infinite wait). + /// @return The received message data, or empty vector if timeout/error. + std::vector Recv(size_t src, std::string_view key, + int64_t timeout_ms); + + /// Get the rank of this mbox instance. + /// @return The rank (0-based index). + size_t Rank() const { return rank_; } + + /// Get the world size (total number of parties). + /// @return The world size. + size_t WorldSize() const { return world_size_; } + + private: + mbox_t* mbox_ = nullptr; + bool owns_mbox_ = false; + size_t rank_; + size_t world_size_; +}; + +} // namespace yacl::link \ No newline at end of file