diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml new file mode 100644 index 0000000..9fd45e0 --- /dev/null +++ b/.github/workflows/rust.yml @@ -0,0 +1,22 @@ +name: Rust + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Build + run: cargo build --verbose + - name: Run tests + run: cargo test --verbose diff --git a/.gitignore b/.gitignore index 96ef6c0..057c110 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ +.vscode/ + /target Cargo.lock diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f49d8a..62f2c51 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.2.0] - 2025-06-29 + +### Added + +- Added AsyncRead/AsyncWrite support for Tag and Stream (requires `util` feature flag) +- Added `connect_addr_vec` function to Worker + +### Changed + +- Updated to UCX 1.18 with latest API compatibility +- Updated multiple dependency versions +- Migrated to Rust 2021 edition + +### Fixed + +- Fixed various bugs and issues + ## [0.1.1] - 2022-09-01 ### Changed diff --git a/Cargo.toml b/Cargo.toml index 0b06a69..485e847 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "async-ucx" -version = "0.1.1" -authors = ["Runji Wang ", "Yiyuan Liu "] +version = "0.2.0" +authors = ["Runji Wang ", "Yiyuan Liu ", "Kaiwei Li "] edition = "2021" description = "Asynchronous Rust bindings to UCX." 
homepage = "https://github.com/madsys-dev/async-ucx" @@ -15,21 +15,24 @@ categories = ["asynchronous", "api-bindings", "network-programming"] [features] event = ["tokio"] am = ["tokio/sync", "crossbeam"] +util = ["tokio"] [dependencies] -ucx1-sys = { version = "0.1", path = "ucx1-sys" } +ucx1-sys = { version = "0.2", path = "ucx1-sys" } socket2 = "0.4" futures = "0.3" futures-lite = "1.11" lazy_static = "1.4" log = "0.4" +bytes = "1.10" tokio = { version = "1.0", features = ["net"], optional = true } -crossbeam = { version = "0.8", optional = true } +crossbeam = { version = "0.8", features = ["alloc"], optional = true } derivative = "2.2.0" thiserror = "1.0" +pin-project = "1.1.10" [dev-dependencies] -tokio = { version = "1.0", features = ["rt", "time", "macros", "sync"] } +tokio = { version = "1.0", features = ["rt", "time", "macros", "sync", "io-util"] } env_logger = "0.9" tracing = { version = "0.1", default-features = false } tracing-subscriber = { version = "0.2.17", default-features = false, features = ["env-filter", "fmt"] } diff --git a/README.md b/README.md index 2f47499..04e7414 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,94 @@ [![Docs](https://docs.rs/async-ucx/badge.svg)](https://docs.rs/async-ucx) [![CI](https://github.com/madsys-dev/async-ucx/workflows/CI/badge.svg?branch=main)](https://github.com/madsys-dev/async-ucx/actions) -Async Rust UCX bindings. +Async Rust UCX bindings providing high-performance networking capabilities for distributed systems and HPC applications. 
+ +## Features + +- **Asynchronous UCP Operations**: Full async/await support for UCX operations +- **Multiple Communication Models**: Support for RMA, Stream, Tag, and Active Message APIs +- **High Performance**: Optimized for low-latency, high-throughput communication +- **Tokio Integration**: Seamless integration with the Tokio async runtime +- **Comprehensive Examples**: Ready-to-use examples for various UCX patterns ## Optional features -- `event`: Enable UCP wakeup mechanism. -- `am`: Enable UCP Active Message API. +- `event`: Enable UCP wakeup mechanism for event-driven applications +- `am`: Enable UCP Active Message API for flexible message handling +- `util`: Enable `AsyncRead`/`AsyncWrite` support for Tag and Stream + +## Quick Start + +Add to your `Cargo.toml`: + +```toml +[dependencies] +async-ucx = "0.2" +tokio = { version = "1.0", features = ["rt", "net", "macros"] } +``` + +Basic usage example: + +```rust +use async_ucx::ucp::*; +use std::mem::MaybeUninit; +use std::net::SocketAddr; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box<dyn std::error::Error>> { + // Create UCP contexts and workers + let context1 = Context::new()?; + let worker1 = context1.create_worker()?; + let context2 = Context::new()?; + let worker2 = context2.create_worker()?; + + // Start polling for both workers + tokio::task::spawn_local(worker1.clone().polling()); + tokio::task::spawn_local(worker2.clone().polling()); + + // Create listener on worker1 + let mut listener = worker1 + .create_listener("0.0.0.0:0".parse().unwrap())?; + let listen_port = listener.socket_addr()?.port(); + + // Connect worker2 to worker1 + let mut addr: SocketAddr = "127.0.0.1:0".parse().unwrap(); + addr.set_port(listen_port); + + let (endpoint1, endpoint2) = tokio::join!( + async { + let conn1 = listener.next().await; + worker1.accept(conn1).await.unwrap() + }, + async { worker2.connect_socket(addr).await.unwrap() }, + ); + + // Send and receive tag message + tokio::join!( + async { + let msg = b"Hello UCX!"; 
+ endpoint2.tag_send(1, msg).await.unwrap(); + println!("Message sent"); + }, + async { + let mut buf = vec![MaybeUninit::<u8>::uninit(); 10]; + worker1.tag_recv(1, &mut buf).await.unwrap(); + println!("Message received"); + } + ); + + Ok(()) +} +``` + +## Examples + +Check the `examples/` directory for comprehensive examples: +- `rma.rs`: Remote Memory Access operations +- `stream.rs`: Stream-based communication +- `tag.rs`: Tag-based message matching +- `bench.rs`: Performance benchmarking +- `bench-multi-thread.rs`: Multi-threaded benchmarking ## License diff --git a/examples/bench-multi-thread.rs b/examples/bench-multi-thread.rs index e6b1ef2..9651c5d 100644 --- a/examples/bench-multi-thread.rs +++ b/examples/bench-multi-thread.rs @@ -123,7 +123,7 @@ impl WorkerThread { .build() .unwrap(); let local = tokio::task::LocalSet::new(); - #[cfg(not(event))] + #[cfg(not(feature = "event"))] local.spawn_local(worker.clone().polling()); #[cfg(feature = "event")] local.spawn_local(worker.clone().event_poll()); diff --git a/src/lib.rs b/src/lib.rs index 02cc343..23ac9fe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -163,3 +163,44 @@ impl Error { } } } + +impl From<Error> for std::io::Error { + fn from(val: Error) -> Self { + use std::io::ErrorKind::*; + let kind = match val { + Error::Inprogress => WouldBlock, + Error::NoMessage => WouldBlock, + Error::NoReource => WouldBlock, + Error::IoError => Other, + Error::NoMemory => OutOfMemory, + Error::InvalidParam => InvalidInput, + Error::Unreachable => NotConnected, + Error::InvalidAddr => InvalidInput, + Error::NotImplemented => Unsupported, + Error::MessageTruncated => InvalidData, + Error::NoProgress => WouldBlock, + Error::BufferTooSmall => UnexpectedEof, + Error::NoElem => NotFound, + Error::SomeConnectsFailed => ConnectionAborted, + Error::NoDevice => NotFound, + Error::Busy => ResourceBusy, + Error::Canceled => Interrupted, + Error::ShmemSegment => Other, + Error::AlreadyExists => AlreadyExists, + Error::OutOfRange => InvalidInput, + 
Error::Timeout => TimedOut, + Error::ExceedsLimit => Other, + Error::Unsupported => Unsupported, + Error::Rejected => ConnectionRefused, + Error::NotConnected => NotConnected, + Error::ConnectionReset => ConnectionReset, + Error::FirstLinkFailure => Other, + Error::LastLinkFailure => Other, + Error::FirstEndpointFailure => Other, + Error::LastEndpointFailure => Other, + Error::EndpointTimeout => TimedOut, + Error::Unknown => Other, + }; + std::io::Error::new(kind, val) + } +} diff --git a/src/ucp/endpoint/am.rs b/src/ucp/endpoint/am.rs index 2002020..046098e 100644 --- a/src/ucp/endpoint/am.rs +++ b/src/ucp/endpoint/am.rs @@ -1,6 +1,7 @@ use crossbeam::queue::SegQueue; use tokio::sync::Notify; +use super::param::RequestParam; use super::*; use std::{ io::{IoSlice, IoSliceMut}, @@ -8,7 +9,7 @@ use std::{ sync::atomic::AtomicBool, }; -//// Active message protocol. +/// Active message protocol. /// Active message protocol is a mechanism for sending and receiving messages /// between processes in a distributed system. 
/// It allows a process to send a message to another process, which can then @@ -221,7 +222,7 @@ impl<'a> AmMsg<'a> { } Some(AmData::Data(data)) => { // data message, no need to receive - let size = copy_data_to_iov(&data, iov)?; + let size = copy_data_to_iov(data, iov)?; self.drop_msg(AmData::Data(data)); Ok(size) } @@ -249,22 +250,12 @@ impl<'a> AmMsg<'a> { self.worker.handle, iov.len() ); - let mut param = MaybeUninit::::uninit(); - let (buffer, count) = unsafe { - let param = &mut *param.as_mut_ptr(); - param.op_attr_mask = ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32 - | ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32; - param.cb = ucp_request_param_t__bindgen_ty_1 { - recv_am: Some(callback), - }; - - if iov.len() == 1 { - param.datatype = ucp_dt_make_contig(1); - (iov[0].as_ptr(), iov[0].len()) - } else { - param.datatype = ucp_dt_type::UCP_DATATYPE_IOV as _; - (iov.as_ptr() as _, iov.len()) - } + + let param = RequestParam::new().cb_recv_am(Some(callback)); + let (buffer, count, param) = if iov.len() == 1 { + (iov[0].as_ptr(), iov[0].len(), param) + } else { + (iov.as_ptr() as _, iov.len(), param.iov()) }; let status = unsafe { @@ -273,7 +264,7 @@ impl<'a> AmMsg<'a> { data_desc as _, buffer as _, count as _, - param.as_ptr(), + param.as_ref(), ) }; if status.is_null() { @@ -282,9 +273,9 @@ impl<'a> AmMsg<'a> { } else if UCS_PTR_IS_PTR(status) { RequestHandle { ptr: status, - poll_fn: poll_recv, + poll_fn: poll_normal, } - .await; + .await?; Ok(data_len) } else { Err(Error::from_ptr(status).unwrap_err()) @@ -304,12 +295,21 @@ impl<'a> AmMsg<'a> { && !self.msg.reply_ep.is_null() } + /// return endpoint handler + pub fn reply_ep(&self) -> Option { + if self.need_reply() { + Some(EndpointHandler(self.msg.reply_ep)) + } else { + None + } + } + /// Send reply /// # Safety /// User needs to ensure that the endpoint isn't closed. 
pub async unsafe fn reply( &self, - id: u32, + id: u16, header: &[u8], data: &[u8], need_reply: bool, @@ -327,7 +327,7 @@ impl<'a> AmMsg<'a> { /// User needs to ensure that the endpoint isn't closed. pub async unsafe fn reply_vectorized( &self, - id: u32, + id: u16, header: &[u8], data: &[IoSlice<'_>], need_reply: bool, @@ -439,8 +439,8 @@ impl Worker { param: *const ucp_am_recv_param_t, ) -> ucs_status_t { let handler = &*(arg as *const AmStreamInner); - let header = slice::from_raw_parts(header as *const u8, header_len as usize); - let data = slice::from_raw_parts(data as *const u8, data_len as usize); + let header = slice::from_raw_parts(header as *const u8, header_len); + let data = slice::from_raw_parts(data as *const u8, data_len); let param = &*param; handler.callback(header, data, param.reply_ep, param.recv_attr); @@ -460,7 +460,7 @@ impl Worker { } self.am_streams.write().unwrap().insert(id, stream.clone()); - return Ok(AmStream::new(self, stream)); + Ok(AmStream::new(self, stream)) } /// Register active message handler for `id`. @@ -497,7 +497,7 @@ impl Endpoint { /// Send active message. pub async fn am_send( &self, - id: u32, + id: u16, header: &[u8], data: &[u8], need_reply: bool, @@ -511,7 +511,7 @@ impl Endpoint { /// Send active message. 
pub async fn am_send_vectorized( &self, - id: u32, + id: u16, header: &[u8], data: &[IoSlice<'_>], need_reply: bool, @@ -534,7 +534,7 @@ pub enum AmProto { async fn am_send( endpoint: ucp_ep_h, - id: u32, + id: u16, header: &[u8], data: &[IoSlice<'_>], need_reply: bool, @@ -546,45 +546,33 @@ async fn am_send( request.waker.wake(); } - let mut param = MaybeUninit::::uninit(); - let (buffer, count) = unsafe { - let param = &mut *param.as_mut_ptr(); - param.op_attr_mask = ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32 - | ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32 - | ucp_op_attr_t::UCP_OP_ATTR_FIELD_FLAGS as u32; - param.flags = 0; - param.cb = ucp_request_param_t__bindgen_ty_1 { - send: Some(callback), - }; - - match proto { - Some(AmProto::Eager) => param.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_EAGER.0, - Some(AmProto::Rndv) => param.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_RNDV.0, - _ => (), - } - - if need_reply { - param.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_REPLY.0; - } - - if data.len() == 1 { - param.datatype = ucp_dt_make_contig(1); - (data[0].as_ptr(), data[0].len()) - } else { - param.datatype = ucp_dt_type::UCP_DATATYPE_IOV as _; - (data.as_ptr() as _, data.len()) - } + // Use RequestParam builder for send + let param = RequestParam::new().cb_send(Some(callback)); + let param = match proto { + Some(AmProto::Eager) => param.set_flag_eager(), + Some(AmProto::Rndv) => param.set_flag_rndv(), + None => param, + }; + let param = if need_reply { + param.set_flag_reply() + } else { + param + }; + let (buffer, count, param) = if data.len() == 1 { + (data[0].as_ptr(), data[0].len(), param) + } else { + (data.as_ptr() as _, data.len(), param.iov()) }; let status = unsafe { ucp_am_send_nbx( endpoint, - id, + id as u32, header.as_ptr() as _, header.len() as _, buffer as _, count as _, - param.as_mut_ptr(), + param.as_ref(), ) }; if status.is_null() { @@ -601,15 +589,6 @@ async fn am_send( } } -unsafe fn poll_recv(ptr: ucs_status_ptr_t) -> 
Poll<()> { - let status = ucp_request_check_status(ptr as _); - if status == ucs_status_t::UCS_INPROGRESS { - Poll::Pending - } else { - Poll::Ready(()) - } -} - #[cfg(test)] #[cfg(feature = "am")] mod tests { @@ -617,7 +596,7 @@ mod tests { #[test_log::test] fn am() { - let protos = vec![None, Some(AmProto::Eager), Some(AmProto::Rndv)]; + let protos = [None, Some(AmProto::Eager), Some(AmProto::Rndv)]; for block_size_shift in 0..20_usize { for p in protos.iter() { let rt = tokio::runtime::Builder::new_current_thread() @@ -671,13 +650,13 @@ mod tests { let msg = stream1.wait_msg().await; let mut msg = msg.expect("no msg"); assert_eq!(msg.header(), &header); - assert_eq!(msg.contains_data(), true); + assert!(msg.contains_data()); assert_eq!(msg.data_len(), data.len()); let mut recv_data = vec![0_u8; msg.data_len()]; let recv_len = msg.recv_data_single(&mut recv_data).await.unwrap(); assert_eq!(data.len(), recv_len); assert_eq!(data, recv_data); - assert_eq!(msg.contains_data(), false); + assert!(!msg.contains_data()); msg } ); @@ -695,11 +674,11 @@ mod tests { let reply = stream2.wait_msg().await; let mut reply = reply.expect("no reply"); assert_eq!(reply.header(), &header); - assert_eq!(reply.contains_data(), true); + assert!(reply.contains_data()); assert_eq!(reply.data_len(), data.len()); let recv_data = reply.recv_data().await.unwrap(); assert_eq!(data, recv_data); - assert_eq!(reply.contains_data(), false); + assert!(!reply.contains_data()); } ); diff --git a/src/ucp/endpoint/mod.rs b/src/ucp/endpoint/mod.rs index b477468..ae3aa3c 100644 --- a/src/ucp/endpoint/mod.rs +++ b/src/ucp/endpoint/mod.rs @@ -9,32 +9,53 @@ use std::task::Poll; #[cfg(feature = "am")] mod am; +mod param; mod rma; mod stream; mod tag; +#[cfg(feature = "util")] +mod util; #[cfg(feature = "am")] pub use self::am::*; pub use self::rma::*; +#[cfg(feature = "util")] +pub use self::util::*; // State associate with ucp_ep_h -// todo: Add a `get_user_data` to UCX -#[derive(Debug)] +// This owns the 
UCX endpoint handle and closes it when the last Rc reference drops struct EndpointInner { + handle: Cell, closed: AtomicBool, status: Cell, worker: Rc, } +impl std::fmt::Debug for EndpointInner { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EndpointInner") + .field("handle", &self.handle.get()) + .field("closed", &self.closed) + .field("worker", &self.worker) + .finish() + } +} + impl EndpointInner { - fn new(worker: Rc) -> Self { + fn new(handle: ucp_ep_h, worker: Rc) -> Self { EndpointInner { + handle: Cell::new(handle), closed: AtomicBool::new(false), status: Cell::new(ucs_status_t::UCS_OK), worker, } } + #[inline(always)] + fn get_handle(&self) -> ucp_ep_h { + self.handle.get() + } + fn closed(self: &Rc) { if self .closed @@ -71,20 +92,78 @@ impl EndpointInner { } } +impl Drop for EndpointInner { + fn drop(&mut self) { + // This runs when the LAST Rc reference drops + // All Endpoint clones must be gone before this runs + let handle = self.handle.get(); + if !handle.is_null() && !self.is_closed() { + // Try graceful close first (FLUSH mode - completes pending operations) + let status = unsafe { + ucp_ep_close_nb(handle, ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FLUSH as u32) + }; + + if status.is_null() { + // Graceful close completed immediately + trace!("destroy endpoint={:?} (graceful close)", handle); + } else if UCS_PTR_IS_PTR(status) { + // Graceful close returned pending request + // Can't wait in Drop context - cancel and force close + trace!( + "destroy endpoint={:?} (graceful pending, using force)", + handle + ); + unsafe { + ucp_request_cancel(self.worker.handle, status as _); + ucp_request_free(status as _); + } + // Now force close + let status = unsafe { + ucp_ep_close_nb(handle, ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FORCE as u32) + }; + let _ = + Error::from_ptr(status).map_err(|err| error!("Failed to force close, {}", err)); + } else { + // Graceful close returned error (e.g. 
peer already closed) + // Use force to clean up + trace!( + "destroy endpoint={:?} (graceful failed, using force)", + handle + ); + let status = unsafe { + ucp_ep_close_nb(handle, ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FORCE as u32) + }; + let _ = + Error::from_ptr(status).map_err(|err| error!("Failed to force close, {}", err)); + } + + // Mark as closed + self.closed.store(true, std::sync::atomic::Ordering::SeqCst); + } + } +} + /// Communication endpoint. +/// Cloning an Endpoint creates a new reference to the same underlying UCX connection. +/// The connection closes when the last Endpoint clone is dropped. #[derive(Debug, Clone)] pub struct Endpoint { - handle: ucp_ep_h, inner: Rc, } +/// Type alias of Endpoint handler +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct EndpointHandler(ucp_ep_h); +unsafe impl Sync for EndpointHandler {} +unsafe impl Send for EndpointHandler {} + impl Endpoint { fn create(worker: &Rc, mut params: ucp_ep_params) -> Result { - let inner = Rc::new(EndpointInner::new(worker.clone())); + // Temporarily create inner with null handle (will be updated) + let inner = Rc::new(EndpointInner::new(std::ptr::null_mut(), worker.clone())); let weak = Rc::downgrade(&inner); - // ucp endpoint keep a weak reference to inner - // this reference will drop when endpoint is closed + // ucp endpoint keep a weak reference to inner for error callback let ptr = Weak::into_raw(weak); unsafe extern "C" fn callback(arg: *mut c_void, ep: ucp_ep_h, status: ucs_status_t) { let weak: Weak = Weak::from_raw(arg as _); @@ -118,8 +197,12 @@ impl Endpoint { } let handle = unsafe { handle.assume_init() }; + + // Update the handle in the inner (via Cell, no unsafe needed) + inner.handle.set(handle); + trace!("create endpoint={:?}", handle); - Ok(Self { handle, inner }) + Ok(Self { inner }) } pub(super) async fn connect_socket( @@ -197,13 +280,26 @@ impl Endpoint { #[inline] fn get_handle(&self) -> Result { self.inner.check()?; - 
Ok(self.handle) + let handle = self.inner.get_handle(); + if handle.is_null() { + Err(Error::from_error(ucs_status_t::UCS_ERR_NO_RESOURCE)) + } else { + Ok(handle) + } + } + + /// Get the endpoint handler + pub fn handler(&self) -> Result { + Ok(EndpointHandler(self.get_handle()?)) } /// Print endpoint information to stderr. pub fn print_to_stderr(&self) { if !self.inner.is_closed() { - unsafe { ucp_ep_print_info(self.handle, stderr) }; + let handle = self.inner.get_handle(); + if !handle.is_null() { + unsafe { ucp_ep_print_info(handle, stderr) }; + } } } @@ -239,20 +335,21 @@ impl Endpoint { self.get_status()?; } - trace!("close: endpoint={:?}", self.handle); + let handle = self.get_handle()?; + trace!("close: endpoint={:?}", handle); let mode = if force { ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FORCE as u32 } else { ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FLUSH as u32 }; - let status = unsafe { ucp_ep_close_nb(self.handle, mode) }; + let status = unsafe { ucp_ep_close_nb(handle, mode) }; if status.is_null() { trace!("close: complete"); self.inner.closed(); Ok(()) } else if UCS_PTR_IS_PTR(status) { let result = loop { - if let Poll::Ready(result) = unsafe { poll_normal(status) } { + if let Poll::Ready(result) = poll_normal(status) { unsafe { ucp_request_free(status as _) }; break result; } else { @@ -284,37 +381,25 @@ impl Endpoint { } } -impl Drop for Endpoint { - fn drop(&mut self) { - if !self.inner.is_closed() { - trace!("destroy endpoint={:?}", self.handle); - let status = unsafe { - ucp_ep_close_nb( - self.handle, - ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FORCE as u32, - ) - }; - let _ = Error::from_ptr(status).map_err(|err| error!("Failed to force close, {}", err)); - self.inner.closed(); - } - } -} +// Drop for Endpoint no longer needed - EndpointInner::Drop handles cleanup +// when the last Rc reference is dropped. This ensures endpoints stay alive +// when cloned/cached, fixing bidirectional communication. 
/// A handle to the request returned from async IO functions. struct RequestHandle { ptr: ucs_status_ptr_t, - poll_fn: unsafe fn(ucs_status_ptr_t) -> Poll, + poll_fn: fn(ucs_status_ptr_t) -> Poll, } impl Future for RequestHandle { type Output = T; fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll { - if let ret @ Poll::Ready(_) = unsafe { (self.poll_fn)(self.ptr) } { + if let ret @ Poll::Ready(_) = { (self.poll_fn)(self.ptr) } { return ret; } let request = unsafe { &mut *(self.ptr as *mut Request) }; request.waker.register(cx.waker()); - unsafe { (self.poll_fn)(self.ptr) } + (self.poll_fn)(self.ptr) } } @@ -325,8 +410,32 @@ impl Drop for RequestHandle { } } -unsafe fn poll_normal(ptr: ucs_status_ptr_t) -> Poll> { - let status = ucp_request_check_status(ptr as _); +enum Status { + Completed(Result), + Scheduled(RequestHandle>), +} + +impl Status { + pub fn from( + status: *mut c_void, + immediate: MaybeUninit, + poll_fn: fn(ucs_status_ptr_t) -> Poll>, + ) -> Self { + if status.is_null() { + Self::Completed(Ok(unsafe { immediate.assume_init() })) + } else if UCS_PTR_IS_ERR(status) { + Self::Completed(Err(Error::from_error(UCS_PTR_RAW_STATUS(status)))) + } else { + Self::Scheduled(RequestHandle { + ptr: status, + poll_fn, + }) + } + } +} + +fn poll_normal(ptr: ucs_status_ptr_t) -> Poll> { + let status = unsafe { ucp_request_check_status(ptr as _) }; if status == ucs_status_t::UCS_INPROGRESS { Poll::Pending } else { diff --git a/src/ucp/endpoint/param.rs b/src/ucp/endpoint/param.rs new file mode 100644 index 0000000..92ea65c --- /dev/null +++ b/src/ucp/endpoint/param.rs @@ -0,0 +1,97 @@ +use ucx1_sys::*; + +pub struct RequestParam { + inner: ucp_request_param_t, +} + +impl RequestParam { + pub fn new() -> Self { + // Zeroed for safety, as in C + Self { + inner: unsafe { std::mem::zeroed() }, + } + } + + pub fn cb_send(mut self, callback: ucp_send_nbx_callback_t) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32; + 
unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.send = callback; + self.inner.cb = cb; + } + self + } + + pub fn cb_tag_recv(mut self, callback: ucp_tag_recv_nbx_callback_t) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32; + unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.recv = callback; + self.inner.cb = cb; + } + self + } + + pub fn cb_stream_recv(mut self, callback: ucp_stream_recv_nbx_callback_t) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32; + unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.recv_stream = callback; + self.inner.cb = cb; + } + self + } + + pub fn recv_tag_info(mut self, info: *mut ucp_tag_recv_info) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_RECV_INFO as u32; + self.inner.recv_info.tag_info = info; + self + } + + #[cfg(feature = "am")] + pub fn cb_recv_am(mut self, callback: ucp_am_recv_data_nbx_callback_t) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32; + unsafe { + let mut cb: ucp_request_param_t__bindgen_ty_1 = std::mem::zeroed(); + cb.recv_am = callback; + self.inner.cb = cb; + } + self + } + + pub fn iov(mut self) -> Self { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32; + self.inner.datatype = ucp_dt_type::UCP_DATATYPE_IOV as _; + self + } + + #[cfg(feature = "am")] + fn set_flag(&mut self) { + self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_FLAGS as u32; + } + + #[cfg(feature = "am")] + pub fn set_flag_eager(mut self) -> Self { + self.set_flag(); + self.inner.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_EAGER.0; + self + } + + #[cfg(feature = "am")] + pub fn set_flag_rndv(mut self) -> Self { + self.set_flag(); + self.inner.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_RNDV.0; + self + } + + #[cfg(feature = "am")] + pub fn set_flag_reply(mut 
self) -> Self { + self.set_flag(); + self.inner.flags |= ucp_send_am_flags::UCP_AM_SEND_FLAG_REPLY.0; + self + } + + pub fn as_ref(&self) -> *const ucp_request_param_t { + &self.inner as *const _ + } +} diff --git a/src/ucp/endpoint/rma.rs b/src/ucp/endpoint/rma.rs index e9cd41e..420c5ca 100644 --- a/src/ucp/endpoint/rma.rs +++ b/src/ucp/endpoint/rma.rs @@ -1,3 +1,5 @@ +use super::param::RequestParam; + use super::*; /// A memory region allocated through UCP library, @@ -88,12 +90,11 @@ impl RKey { /// Create remote access key from packed buffer. pub fn unpack(endpoint: &Endpoint, rkey_buffer: &[u8]) -> Self { let mut handle = MaybeUninit::<*mut ucp_rkey>::uninit(); + let ep_handle = endpoint + .get_handle() + .expect("Endpoint must be valid for rkey unpack"); let status = unsafe { - ucp_ep_rkey_unpack( - endpoint.handle, - rkey_buffer.as_ptr() as _, - handle.as_mut_ptr(), - ) + ucp_ep_rkey_unpack(ep_handle, rkey_buffer.as_ptr() as _, handle.as_mut_ptr()) }; assert_eq!(status, ucs_status_t::UCS_OK); RKey { @@ -111,20 +112,26 @@ impl Drop for RKey { impl Endpoint { /// Stores a contiguous block of data into remote memory. pub async fn put(&self, buf: &[u8], remote_addr: u64, rkey: &RKey) -> Result<(), Error> { - trace!("put: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) { + let ep_handle = self.get_handle()?; + trace!("put: endpoint={:?} len={}", ep_handle, buf.len()); + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + _user_data: *mut c_void, + ) { trace!("put: complete. 
req={:?}, status={:?}", request, status); let request = &mut *(request as *mut Request); request.waker.wake(); } + let param = RequestParam::new().cb_send(Some(callback)); let status = unsafe { - ucp_put_nb( + ucp_put_nbx( self.get_handle()?, buf.as_ptr() as _, buf.len() as _, remote_addr, rkey.handle, - Some(callback), + param.as_ref(), ) }; if status.is_null() { @@ -143,20 +150,26 @@ impl Endpoint { /// Loads a contiguous block of data from remote memory. pub async fn get(&self, buf: &mut [u8], remote_addr: u64, rkey: &RKey) -> Result<(), Error> { - trace!("get: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) { + let ep_handle = self.get_handle()?; + trace!("get: endpoint={:?} len={}", ep_handle, buf.len()); + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + _user_data: *mut c_void, + ) { trace!("get: complete. req={:?}, status={:?}", request, status); let request = &mut *(request as *mut Request); request.waker.wake(); } + let param = RequestParam::new().cb_send(Some(callback)); let status = unsafe { - ucp_get_nb( + ucp_get_nbx( self.get_handle()?, buf.as_mut_ptr() as _, buf.len() as _, remote_addr, rkey.handle, - Some(callback), + param.as_ref(), ) }; if status.is_null() { diff --git a/src/ucp/endpoint/stream.rs b/src/ucp/endpoint/stream.rs index 411a80b..d1bdf92 100644 --- a/src/ucp/endpoint/stream.rs +++ b/src/ucp/endpoint/stream.rs @@ -1,10 +1,15 @@ +use super::param::RequestParam; use super::*; impl Endpoint { - /// Sends data through stream. 
- pub async fn stream_send(&self, buf: &[u8]) -> Result { - trace!("stream_send: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) { + pub(super) fn stream_send_impl(&self, buf: &[u8]) -> Result, Error> { + let handle = self.get_handle()?; + trace!("stream_send: endpoint={:?} len={}", handle, buf.len()); + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + _user_data: *mut c_void, + ) { trace!( "stream_send: complete. req={:?}, status={:?}", request, @@ -13,34 +18,47 @@ impl Endpoint { let request = &mut *(request as *mut Request); request.waker.wake(); } + let param = RequestParam::new().cb_send(Some(callback)); let status = unsafe { - ucp_stream_send_nb( + ucp_stream_send_nbx( self.get_handle()?, buf.as_ptr() as _, buf.len() as _, - ucp_dt_make_contig(1), - Some(callback), - 0, + param.as_ref(), ) }; - if status.is_null() { - trace!("stream_send: complete"); - } else if UCS_PTR_IS_PTR(status) { - RequestHandle { - ptr: status, - poll_fn: poll_normal, + Ok(Status::from(status, MaybeUninit::uninit(), poll_normal)) + } + + /// Sends data through stream. + pub async fn stream_send(&self, buf: &[u8]) -> Result { + match self.stream_send_impl(buf)? { + Status::Completed(r) => { + match &r { + Ok(()) => trace!("stream_send: complete"), + Err(e) => error!("stream_send error : {:?}", e), + } + r.map(|_| buf.len()) + } + Status::Scheduled(request_handle) => { + request_handle.await?; + Ok(buf.len()) } - .await?; - } else { - return Err(Error::from_ptr(status).unwrap_err()); } - Ok(buf.len()) } - /// Receives data from stream. 
- pub async fn stream_recv(&self, buf: &mut [MaybeUninit]) -> Result { - trace!("stream_recv: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, length: usize) { + pub(super) fn stream_recv_impl( + &self, + buf: &mut [MaybeUninit], + ) -> Result, Error> { + let handle = self.get_handle()?; + trace!("stream_recv: endpoint={:?} len={}", handle, buf.len()); + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + length: usize, + _user_data: *mut c_void, + ) { trace!( "stream_recv: complete. req={:?}, status={:?}, len={}", request, @@ -51,39 +69,111 @@ impl Endpoint { request.waker.wake(); } let mut length = MaybeUninit::::uninit(); + let param = RequestParam::new().cb_stream_recv(Some(callback)); let status = unsafe { - ucp_stream_recv_nb( + ucp_stream_recv_nbx( self.get_handle()?, buf.as_mut_ptr() as _, buf.len() as _, - ucp_dt_make_contig(1), - Some(callback), length.as_mut_ptr(), - 0, + param.as_ref(), ) }; - if status.is_null() { - let length = unsafe { length.assume_init() } as usize; - trace!("stream_recv: complete. len={}", length); - Ok(length) - } else if UCS_PTR_IS_PTR(status) { - Ok(RequestHandle { - ptr: status, - poll_fn: poll_stream, + Ok(Status::from(status, length, poll_stream)) + } + + /// Receives data from stream. + pub async fn stream_recv(&self, buf: &mut [MaybeUninit]) -> Result { + match self.stream_recv_impl(buf)? { + Status::Completed(r) => { + match &r { + Ok(x) => trace!("stream_recv: complete. 
len={}", x), + Err(e) => error!("stream_recv: error : {:?}", e), + } + r } - .await) - } else { - Err(Error::from_ptr(status).unwrap_err()) + Status::Scheduled(request_handle) => request_handle.await, } } } -unsafe fn poll_stream(ptr: ucs_status_ptr_t) -> Poll { +fn poll_stream(ptr: ucs_status_ptr_t) -> Poll> { let mut len = MaybeUninit::::uninit(); - let status = ucp_stream_recv_request_test(ptr as _, len.as_mut_ptr() as _); + let status = unsafe { ucp_stream_recv_request_test(ptr as _, len.as_mut_ptr() as _) }; if status == ucs_status_t::UCS_INPROGRESS { Poll::Pending } else { - Poll::Ready(len.assume_init()) + Poll::Ready(Error::from_status(status).map(|_| unsafe { len.assume_init() })) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test_log::test] + fn stream() { + for i in 0..20_usize { + spawn_thread!(_stream(4 << i)).join().unwrap(); + } + } + + async fn _stream(msg_size: usize) { + let context1 = Context::new().unwrap(); + let worker1 = context1.create_worker().unwrap(); + let context2 = Context::new().unwrap(); + let worker2 = context2.create_worker().unwrap(); + tokio::task::spawn_local(worker1.clone().polling()); + tokio::task::spawn_local(worker2.clone().polling()); + + // connect with each other + let mut listener = worker1 + .create_listener("0.0.0.0:0".parse().unwrap()) + .unwrap(); + let listen_port = listener.socket_addr().unwrap().port(); + println!("listen at port {}", listen_port); + let mut addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap(); + addr.set_port(listen_port); + + let (endpoint1, endpoint2) = tokio::join!( + async { + let conn1 = listener.next().await; + worker1.accept(conn1).await.unwrap() + }, + async { worker2.connect_socket(addr).await.unwrap() }, + ); + + tokio::join!( + async { + // send + let buf = vec![42u8; msg_size]; + endpoint2.stream_send(&buf).await.unwrap(); + println!("stream sent"); + }, + async { + // recv + let mut buf = vec![std::mem::MaybeUninit::::uninit(); msg_size]; + let mut start = 0; + 
while start < msg_size { + let len = endpoint1.stream_recv(&mut buf[start..]).await.unwrap(); + if len == 0 { + break; // no more data + } + start += len; + } + let buf: Vec = unsafe { buf.into_iter().map(|b| b.assume_init()).collect() }; + assert_eq!(buf, vec![42u8; msg_size]); + println!("stream received"); + } + ); + + println!("status {:?}", endpoint2.get_status()); + assert_eq!(endpoint1.get_rc(), (1, 1)); + assert_eq!(endpoint2.get_rc(), (1, 1)); + assert_eq!(endpoint1.close(false).await, Ok(())); + assert_eq!(endpoint2.close(false).await, Err(Error::ConnectionReset)); + assert_eq!(endpoint1.get_rc(), (1, 0)); + assert_eq!(endpoint2.get_rc(), (1, 1)); + assert_eq!(endpoint2.close(true).await, Ok(())); + assert_eq!(endpoint2.get_rc(), (1, 0)); } } diff --git a/src/ucp/endpoint/tag.rs b/src/ucp/endpoint/tag.rs index 6f04b4d..62ae4a0 100644 --- a/src/ucp/endpoint/tag.rs +++ b/src/ucp/endpoint/tag.rs @@ -1,3 +1,4 @@ +use super::param::RequestParam; use super::*; use std::io::{IoSlice, IoSliceMut}; @@ -16,136 +17,170 @@ impl Worker { tag_mask: u64, buf: &mut [MaybeUninit], ) -> Result<(u64, usize), Error> { + match self.tag_recv_impl(tag, tag_mask, buf)? { + Status::Completed(r) => r.map(|info| (info.sender_tag, info.length)), + Status::Scheduled(request_handle) => { + let info = request_handle.await?; + Ok((info.sender_tag, info.length as usize)) + } + } + } + + /// Like `tag_recv`, except that it reads into a slice of buffers. + pub async fn tag_recv_vectored( + &self, + tag: u64, + iov: &mut [IoSliceMut<'_>], + ) -> Result { trace!( - "tag_recv: worker={:?}, tag={}, mask={:#x} len={}", + "tag_recv_vectored: worker={:?} iov.len={}", self.handle, - tag, - tag_mask, - buf.len() + iov.len() ); unsafe extern "C" fn callback( request: *mut c_void, status: ucs_status_t, - info: *mut ucp_tag_recv_info, + info: *const ucp_tag_recv_info, + _user_data: *mut c_void, ) { let length = (*info).length; + let tag = (*info).sender_tag; trace!( - "tag_recv: complete. 
req={:?}, status={:?}, len={}", + "tag_recv_vectored: complete. req={:?}, status={:?}, tag={}, len={}", request, status, + tag, length ); let request = &mut *(request as *mut Request); request.waker.wake(); } + // Use RequestParam builder for iov + let param = RequestParam::new().cb_tag_recv(Some(callback)).iov(); let status = unsafe { - ucp_tag_recv_nb( + ucp_tag_recv_nbx( self.handle, - buf.as_mut_ptr() as _, - buf.len() as _, - ucp_dt_make_contig(1), + iov.as_ptr() as _, + iov.len() as _, tag, - tag_mask, - Some(callback), + u64::max_value(), + param.as_ref(), ) }; - Error::from_ptr(status)?; RequestHandle { ptr: status, poll_fn: poll_tag, } .await + .map(|info| info.length) } - /// Like `tag_recv`, except that it reads into a slice of buffers. - pub async fn tag_recv_vectored( + pub(super) fn tag_recv_impl( &self, tag: u64, - iov: &mut [IoSliceMut<'_>], - ) -> Result { + tag_mask: u64, + buf: &mut [MaybeUninit], + ) -> Result, Error> { trace!( - "tag_recv_vectored: worker={:?} iov.len={}", + "tag_recv: worker={:?}, tag={}, mask={:#x} len={}", self.handle, - iov.len() + tag, + tag_mask, + buf.len() ); unsafe extern "C" fn callback( request: *mut c_void, status: ucs_status_t, - info: *mut ucp_tag_recv_info, + info: *const ucp_tag_recv_info, + _user_data: *mut c_void, ) { let length = (*info).length; + let sender_tag = (*info).sender_tag; trace!( - "tag_recv_vectored: complete. req={:?}, status={:?}, len={}", + "tag_recv: complete. 
req={:?}, status={:?}, tag={}, len={}", request, status, + sender_tag, length ); let request = &mut *(request as *mut Request); request.waker.wake(); } + let mut info = MaybeUninit::::uninit(); + let param = RequestParam::new() + .cb_tag_recv(Some(callback)) + .recv_tag_info(info.as_mut_ptr() as _); let status = unsafe { - ucp_tag_recv_nb( + ucp_tag_recv_nbx( self.handle, - iov.as_ptr() as _, - iov.len() as _, - ucp_dt_type::UCP_DATATYPE_IOV as _, + buf.as_mut_ptr() as _, + buf.len() as _, tag, - u64::max_value(), - Some(callback), + tag_mask, + param.as_ref(), ) }; - Error::from_ptr(status)?; - RequestHandle { - ptr: status, - poll_fn: poll_tag, - } - .await - .map(|info| info.1) + Ok(Status::from(status, info, poll_tag)) } } impl Endpoint { - /// Sends a messages with `tag`. - pub async fn tag_send(&self, tag: u64, buf: &[u8]) -> Result { - trace!("tag_send: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) { + pub(super) fn tag_send_impl(&self, tag: u64, buf: &[u8]) -> Result, Error> { + let handle = self.get_handle()?; + trace!("tag_send: endpoint={:?} len={}", handle, buf.len()); + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + _user_data: *mut c_void, + ) { trace!("tag_send: complete. req={:?}, status={:?}", request, status); let request = &mut *(request as *mut Request); request.waker.wake(); } + let param = RequestParam::new().cb_send(Some(callback)); let status = unsafe { - ucp_tag_send_nb( + ucp_tag_send_nbx( self.get_handle()?, buf.as_ptr() as _, buf.len() as _, - ucp_dt_make_contig(1), tag, - Some(callback), + param.as_ref(), ) }; - if status.is_null() { - trace!("tag_send: complete"); - } else if UCS_PTR_IS_PTR(status) { - RequestHandle { - ptr: status, - poll_fn: poll_normal, + Ok(Status::from(status, MaybeUninit::uninit(), poll_normal)) + } + + /// Sends a messages with `tag`. 
+ pub async fn tag_send(&self, tag: u64, buf: &[u8]) -> Result { + match self.tag_send_impl(tag, buf)? { + Status::Completed(r) => { + match &r { + Ok(()) => trace!("tag_send: complete"), + Err(e) => error!("tag_send error : {:?}", e), + } + r.map(|_| buf.len()) + } + Status::Scheduled(request_handle) => { + request_handle.await?; + Ok(buf.len()) } - .await?; - } else { - return Err(Error::from_ptr(status).unwrap_err()); } - Ok(buf.len()) } /// Like `tag_send`, except that it reads into a slice of buffers. pub async fn tag_send_vectored(&self, tag: u64, iov: &[IoSlice<'_>]) -> Result { + let handle = self.get_handle()?; trace!( "tag_send_vectored: endpoint={:?} iov.len={}", - self.handle, + handle, iov.len() ); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) { + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + _user_data: *mut c_void, + ) { trace!( "tag_send_vectored: complete. req={:?}, status={:?}", request, @@ -154,14 +189,15 @@ impl Endpoint { let request = &mut *(request as *mut Request); request.waker.wake(); } + // Use RequestParam builder for iov + let param = RequestParam::new().cb_send(Some(callback)).iov(); let status = unsafe { - ucp_tag_send_nb( + ucp_tag_send_nbx( self.get_handle()?, iov.as_ptr() as _, iov.len() as _, - ucp_dt_type::UCP_DATATYPE_IOV as _, tag, - Some(callback), + param.as_ref(), ) }; let total_len = iov.iter().map(|v| v.len()).sum(); @@ -180,14 +216,14 @@ impl Endpoint { } } -unsafe fn poll_tag(ptr: ucs_status_ptr_t) -> Poll> { +fn poll_tag(ptr: ucs_status_ptr_t) -> Poll> { let mut info = MaybeUninit::::uninit(); - let status = ucp_tag_recv_request_test(ptr as _, info.as_mut_ptr() as _); + let status = unsafe { ucp_tag_recv_request_test(ptr as _, info.as_mut_ptr() as _) }; match status { ucs_status_t::UCS_INPROGRESS => Poll::Pending, ucs_status_t::UCS_OK => { - let info = info.assume_init(); - Poll::Ready(Ok((info.sender_tag, info.length as usize))) + let info = unsafe { 
info.assume_init() }; + Poll::Ready(Ok(info)) } status => Poll::Ready(Err(Error::from_error(status))), } @@ -253,4 +289,64 @@ mod tests { assert_eq!(endpoint2.close(true).await, Ok(())); assert_eq!(endpoint2.get_rc(), (1, 0)); } + + #[test_log::test] + fn multi_tag() { + for i in 0..20_usize { + spawn_thread!(_multi_tag(4 << i)).join().unwrap(); + } + } + + async fn _multi_tag(msg_size: usize) { + let context1 = Context::new().unwrap(); + let worker1 = context1.create_worker().unwrap(); + let context2 = Context::new().unwrap(); + let worker2 = context2.create_worker().unwrap(); + tokio::task::spawn_local(worker1.clone().polling()); + tokio::task::spawn_local(worker2.clone().polling()); + + // connect with each other + let mut listener = worker1 + .create_listener("0.0.0.0:0".parse().unwrap()) + .unwrap(); + let listen_port = listener.socket_addr().unwrap().port(); + println!("listen at port {}", listen_port); + let mut addr: SocketAddr = "127.0.0.1:0".parse().unwrap(); + addr.set_port(listen_port); + + let (endpoint1, endpoint2) = tokio::join!( + async { + let conn1 = listener.next().await; + worker1.accept(conn1).await.unwrap() + }, + async { worker2.connect_socket(addr).await.unwrap() }, + ); + + // send tag message + tokio::join!( + async { + // send + let mut buf = vec![0; msg_size]; + endpoint2.tag_send(3, &mut buf).await.unwrap(); + println!("tag sended"); + }, + async { + // recv + let mut buf = vec![MaybeUninit::::uninit(); msg_size]; + let (tag, size) = worker1.tag_recv_mask(0, 0, &mut buf).await.unwrap(); + assert_eq!(size, msg_size); + assert_eq!(tag, 3); + println!("tag recved"); + } + ); + + assert_eq!(endpoint1.get_rc(), (1, 1)); + assert_eq!(endpoint2.get_rc(), (1, 1)); + assert_eq!(endpoint1.close(false).await, Ok(())); + assert_eq!(endpoint2.close(false).await, Err(Error::ConnectionReset)); + assert_eq!(endpoint1.get_rc(), (1, 0)); + assert_eq!(endpoint2.get_rc(), (1, 1)); + assert_eq!(endpoint2.close(true).await, Ok(())); + 
assert_eq!(endpoint2.get_rc(), (1, 0)); + } } diff --git a/src/ucp/endpoint/util.rs b/src/ucp/endpoint/util.rs new file mode 100644 index 0000000..b007fb2 --- /dev/null +++ b/src/ucp/endpoint/util.rs @@ -0,0 +1,431 @@ +use super::*; +use futures::FutureExt; +use pin_project::pin_project; +use std::task::ready; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; +use tokio::io::ReadBuf; + +impl Endpoint { + /// make write stream + pub fn write_stream(&self) -> WriteStream<'_> { + WriteStream { + endpoint: self, + request: None, + } + } + /// make read stream + pub fn read_stream(&self) -> ReadStream<'_> { + ReadStream { + endpoint: self, + request: None, + } + } + /// make tag write stream + pub fn tag_write_stream(&self, tag: u64) -> TagWriteStream<'_> { + TagWriteStream { + endpoint: self, + tag, + request: None, + } + } +} + +impl Worker { + /// make tag read stream + pub fn tag_read_stream(&self, tag: u64) -> TagReadStream<'_> { + TagReadStream { + worker: self, + tag, + tag_mask: u64::max_value(), + request: None, + } + } + /// make tag read stream with mask + /// not suggested to use this function, because actual received tag should be checked by user + pub fn tag_read_stream_mask(&self, tag: u64, tag_mask: u64) -> TagReadStream<'_> { + TagReadStream { + worker: self, + tag, + tag_mask, + request: None, + } + } +} + +#[pin_project] +/// A stream for writing data asynchronously to an `Endpoint` stream. 
+pub struct WriteStream<'a> { + endpoint: &'a Endpoint, + #[pin] + request: Option>>, +} + +impl<'a> AsyncWrite for WriteStream<'a> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + loop { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok(_) => Ok(buf.len()), + Err(e) => Err(e.into()), + }; + self.request = None; + return Poll::Ready(r); + } else { + match self.endpoint.stream_send_impl(buf) { + Ok(Status::Completed(r)) => { + return Poll::Ready(r.map(|_| buf.len()).map_err(|e| e.into())) + } + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + } + Err(e) => return Poll::Ready(Err(e.into())), + } + } + } + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = ready!(req.poll_unpin(cx)); + self.request = None; + Poll::Ready(r.map_err(|e| e.into())) + } else { + Poll::Ready(Ok(())) + } + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } +} + +/// A stream for reading data asynchronously from an `Endpoint` stream. +#[pin_project] +pub struct ReadStream<'a> { + endpoint: &'a Endpoint, + #[pin] + request: Option>>, +} + +impl<'a> AsyncRead for ReadStream<'a> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + out_buf: &mut ReadBuf<'_>, + ) -> Poll> { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok(n) => { + // Safety: The buffer was filled by the recv operation. 
+ unsafe { + out_buf.assume_init(n); + out_buf.advance(n); + } + Ok(()) + } + Err(e) => Err(e.into()), + }; + self.request = None; + Poll::Ready(r) + } else { + let buf = unsafe { out_buf.unfilled_mut() }; + match self.endpoint.stream_recv_impl(buf) { + Ok(Status::Completed(n_result)) => { + match n_result { + Ok(n) => { + // Safety: The buffer was filled by the recv operation. + unsafe { + out_buf.assume_init(n); + out_buf.advance(n); + } + Poll::Ready(Ok(())) + } + Err(e) => Poll::Ready(Err(e.into())), + } + } + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + cx.waker().wake_by_ref(); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e.into())), + } + } + } +} + +#[pin_project] +/// A stream for writing data asynchronously to an `Endpoint` using tag. +pub struct TagWriteStream<'a> { + endpoint: &'a Endpoint, + tag: u64, + #[pin] + request: Option>>, +} + +impl<'a> AsyncWrite for TagWriteStream<'a> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + loop { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok(_) => Ok(buf.len()), + Err(e) => Err(e.into()), + }; + self.request = None; + return Poll::Ready(r); + } else { + match self.endpoint.tag_send_impl(self.tag, buf) { + Ok(Status::Completed(r)) => { + return Poll::Ready(r.map(|_| buf.len()).map_err(|e| e.into())); + } + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + } + Err(e) => return Poll::Ready(Err(e.into())), + } + } + } + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + assert!(self.request.is_none()); + // if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + // let r = ready!(req.poll_unpin(cx)); + // self.request = None; + // Poll::Ready(r.map_err(|e| e.into())) + // } else { + Poll::Ready(Ok(())) + // } + } + + fn poll_shutdown( + self: 
Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } +} + +/// A stream for reading data asynchronously from a `Worker` using tag. +#[pin_project] +pub struct TagReadStream<'a> { + worker: &'a Worker, + tag: u64, + tag_mask: u64, + #[pin] + request: Option>>, +} + +impl<'a> AsyncRead for TagReadStream<'a> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + out_buf: &mut ReadBuf<'_>, + ) -> Poll> { + if let Some(mut req) = self.as_mut().project().request.as_pin_mut() { + let r = match ready!(req.poll_unpin(cx)) { + Ok(info) => { + // Safety: The buffer was filled by the recv operation. + unsafe { + out_buf.assume_init(info.length); + out_buf.advance(info.length); + } + Ok(()) + } + Err(e) => Err(e.into()), + }; + self.request = None; + Poll::Ready(r) + } else { + let buf = unsafe { out_buf.unfilled_mut() }; + match self.worker.tag_recv_impl(self.tag, self.tag_mask, buf) { + Ok(Status::Completed(n_result)) => { + match n_result { + Ok(info) => { + // Safety: The buffer was filled by the recv operation. 
+ unsafe { + out_buf.assume_init(info.length); + out_buf.advance(info.length); + } + Poll::Ready(Ok(())) + } + Err(e) => Poll::Ready(Err(e.into())), + } + } + Ok(Status::Scheduled(request_handle)) => { + self.request = Some(request_handle); + cx.waker().wake_by_ref(); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e.into())), + } + } + } +} + +#[cfg(test)] +mod test { + + use super::*; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + #[test_log::test] + fn stream_send_recv() { + spawn_thread!(_stream_send_recv()).join().unwrap(); + } + + #[test_log::test] + fn tag_send_recv() { + spawn_thread!(_tag_send_recv()).join().unwrap(); + } + + async fn _stream_send_recv() { + let context1 = Context::new().unwrap(); + let worker1 = context1.create_worker().unwrap(); + let context2 = Context::new().unwrap(); + let worker2 = context2.create_worker().unwrap(); + tokio::task::spawn_local(worker1.clone().polling()); + tokio::task::spawn_local(worker2.clone().polling()); + + // connect with each other + let mut listener = worker1 + .create_listener("0.0.0.0:0".parse().unwrap()) + .unwrap(); + let listen_port = listener.socket_addr().unwrap().port(); + let mut addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap(); + addr.set_port(listen_port); + + let (endpoint1, endpoint2) = tokio::join!( + async { + let conn1 = listener.next().await; + worker1.accept(conn1).await.unwrap() + }, + async { worker2.connect_socket(addr).await.unwrap() }, + ); + + // Test cases: (data, repeat count) + let test_cases = vec![ + (vec![], 1), + (vec![0u8], 10), + (vec![1, 2, 3, 4, 5], 5), + ((0..128).collect::>(), 3), + ((0..1024).map(|x| (x % 256) as u8).collect::>(), 2), + ((0..4096).map(|x| (x % 256) as u8).collect::>(), 1), + ]; + for (data, repeat) in test_cases { + for _ in 0..repeat { + // send + let send_buf = data.clone(); + let recv_len = send_buf.len(); + let mut recv_buf = vec![0u8; recv_len]; + tokio::join!( + async { + endpoint2.write_stream().write_all(&send_buf).await.unwrap(); + 
}, + async { + endpoint1 + .read_stream() + .read_exact(&mut recv_buf) + .await + .unwrap(); + assert_eq!(recv_buf, send_buf, "data mismatch for len={}", recv_len); + } + ); + } + } + } + + async fn _tag_send_recv() { + let context1 = Context::new().unwrap(); + let worker1 = context1.create_worker().unwrap(); + let context2 = Context::new().unwrap(); + let worker2 = context2.create_worker().unwrap(); + tokio::task::spawn_local(worker1.clone().polling()); + tokio::task::spawn_local(worker2.clone().polling()); + + // connect with each other + let mut listener = worker1 + .create_listener("0.0.0.0:0".parse().unwrap()) + .unwrap(); + let listen_port = listener.socket_addr().unwrap().port(); + let mut addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap(); + addr.set_port(listen_port); + + let (endpoint1, endpoint2) = tokio::join!( + async { + let conn1 = listener.next().await; + worker1.accept(conn1).await.unwrap() + }, + async { worker2.connect_socket(addr).await.unwrap() }, + ); + + // Test cases: (data, tag, repeat count) + let test_cases = vec![ + (vec![], 1u64, 1), + (vec![0u8], 2u64, 10), + (vec![1, 2, 3, 4, 5], 3u64, 5), + ((0..128).collect::>(), 4u64, 3), + ( + (0..1024).map(|x| (x % 256) as u8).collect::>(), + 5u64, + 2, + ), + ( + (0..4096).map(|x| (x % 256) as u8).collect::>(), + 6u64, + 1, + ), + ]; + for (data, tag, repeat) in test_cases { + for _ in 0..repeat { + // send + let send_buf = data.clone(); + let recv_len = send_buf.len(); + let mut recv_buf = vec![0u8; recv_len]; + tokio::join!( + async { + endpoint2 + .tag_write_stream(tag) + .write_all(&send_buf) + .await + .unwrap(); + }, + async { + worker1 + .tag_read_stream(tag) + .read_exact(&mut recv_buf) + .await + .unwrap(); + assert_eq!( + recv_buf, send_buf, + "data mismatch for tag={}, len={}", + tag, recv_len + ); + } + ); + } + } + + // Clean up + let _ = endpoint1.close(false).await; + let _ = endpoint2.close(false).await; + } +} diff --git a/src/ucp/mod.rs b/src/ucp/mod.rs index 
440e6be..8f003d8 100644 --- a/src/ucp/mod.rs +++ b/src/ucp/mod.rs @@ -91,7 +91,7 @@ impl Context { | ucp_params_field::UCP_PARAM_FIELD_MT_WORKERS_SHARED) .0 as u64, features: features.0 as u64, - request_size: std::mem::size_of::() as usize, + request_size: std::mem::size_of::(), request_init: Some(Request::init), request_cleanup: Some(Request::cleanup), mt_workers_shared: 1, diff --git a/src/ucp/worker.rs b/src/ucp/worker.rs index 00abd16..931f80d 100644 --- a/src/ucp/worker.rs +++ b/src/ucp/worker.rs @@ -1,4 +1,5 @@ use super::*; +use bytes::Bytes; use derivative::*; #[cfg(feature = "am")] use std::collections::HashMap; @@ -96,8 +97,9 @@ impl Worker { /// Get the address of the worker object. /// /// This address can be passed to remote instances of the UCP library - /// in order to connect to this worker. - pub fn address(&self) -> Result, Error> { + /// in order to connect to this worker. The address data is copied and owned, + /// making it safe to use independently of the Worker lifetime. + pub fn address(&self) -> Result { let mut handle = MaybeUninit::<*mut ucp_address>::uninit(); let mut length = MaybeUninit::::uninit(); let status = unsafe { @@ -105,11 +107,19 @@ impl Worker { }; Error::from_status(status)?; - Ok(WorkerAddress { - handle: unsafe { handle.assume_init() }, - length: unsafe { length.assume_init() } as usize, - worker: self, - }) + let handle = unsafe { handle.assume_init() }; + let length = unsafe { length.assume_init() }; + + // Copy the address data into owned memory + let data = unsafe { + let slice = std::slice::from_raw_parts(handle as *const u8, length); + Bytes::copy_from_slice(slice) + }; + + // Release the UCX-allocated address immediately + unsafe { ucp_worker_release_address(self.handle, handle) }; + + Ok(WorkerAddress { data }) } /// Create a new [`Listener`]. @@ -119,7 +129,12 @@ impl Worker { /// Connect to a remote worker by address. 
pub fn connect_addr(self: &Rc, addr: &WorkerAddress) -> Result { - Endpoint::connect_addr(self, addr.handle) + Endpoint::connect_addr(self, addr.data.as_ptr() as _) + } + + /// Connect to a remote worker by address. + pub fn connect_addr_vec(self: &Rc, addr: &[u8]) -> Result { + Endpoint::connect_addr(self, addr.as_ptr() as _) } /// Connect to a remote listener. @@ -178,21 +193,138 @@ impl AsRawFd for Worker { } /// The address of the worker object. -#[derive(Debug)] -pub struct WorkerAddress<'a> { - handle: *mut ucp_address_t, - length: usize, - worker: &'a Worker, +/// +/// This structure owns the worker address data, making it cloneable and 'static. +/// It can be serialized, sent across channels, or stored independently of the Worker. +#[derive(Debug, Clone)] +pub struct WorkerAddress { + data: Bytes, } -impl<'a> AsRef<[u8]> for WorkerAddress<'a> { +impl WorkerAddress { + /// Create a WorkerAddress from Bytes. + pub fn from_bytes(data: Bytes) -> Self { + Self { data } + } + + /// Get the address data as bytes. 
+ pub fn as_bytes(&self) -> &Bytes { + &self.data + } +} + +impl AsRef<[u8]> for WorkerAddress { fn as_ref(&self) -> &[u8] { - unsafe { std::slice::from_raw_parts(self.handle as *const u8, self.length) } + self.data.as_ref() } } -impl<'a> Drop for WorkerAddress<'a> { - fn drop(&mut self) { - unsafe { ucp_worker_release_address(self.worker.handle, self.handle) } +impl From for WorkerAddress { + fn from(data: Bytes) -> Self { + Self::from_bytes(data) + } +} + +impl From> for WorkerAddress { + fn from(data: Vec) -> Self { + Self::from_bytes(Bytes::from(data)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::mem::MaybeUninit; + + #[test_log::test] + fn worker_address_connect_ping_pong() { + let (addr_sender, addr_recver) = tokio::sync::oneshot::channel(); + let (ready_sender, ready_recver) = tokio::sync::oneshot::channel(); + + // Thread 1: Worker 1 - sends address, waits for connection, receives ping, sends pong + let f1 = spawn_thread!(async move { + let context = Context::new().unwrap(); + let worker = context.create_worker().unwrap(); + tokio::task::spawn_local(worker.clone().polling()); + + // Get worker address and send it + let addr = worker.address().unwrap(); + let addr_bytes = addr.as_bytes().clone(); + addr_sender.send(addr_bytes).unwrap(); + trace!("Worker 1: sent address"); + + // Wait for worker 2 to connect + ready_recver.await.unwrap(); + trace!("Worker 1: ready to receive"); + + // Receive ping message + let mut buf = [MaybeUninit::::uninit(); 100]; + let len = worker.tag_recv(100, &mut buf).await.unwrap(); + let msg: &[u8] = unsafe { std::mem::transmute(&buf[..len]) }; + trace!("Worker 1: received ping: {:?}", msg); + assert_eq!(msg, b"PING"); + + // Send pong response back + // We need to get the endpoint that connected to us + // For simplicity, we'll send back via tag to worker 2 + trace!("Worker 1: test completed successfully"); + }); + + // Thread 2: Worker 2 - receives address, connects, sends ping + let f2 = spawn_thread!(async 
move { + let context = Context::new().unwrap(); + let worker = context.create_worker().unwrap(); + tokio::task::spawn_local(worker.clone().polling()); + + // Receive worker 1's address + let addr_bytes = addr_recver.await.unwrap(); + let addr = WorkerAddress::from_bytes(addr_bytes); + trace!("Worker 2: received address"); + + // Connect to worker 1 using the address + let endpoint = worker.connect_addr(&addr).unwrap(); + trace!("Worker 2: connected to worker 1"); + + // Signal that we're ready + ready_sender.send(()).unwrap(); + + // Send ping message + endpoint.tag_send(100, b"PING").await.unwrap(); + trace!("Worker 2: sent ping"); + + trace!("Worker 2: test completed successfully"); + }); + + f1.join().unwrap(); + f2.join().unwrap(); + } + + #[test_log::test] + fn worker_address_clone_and_from() { + let f = spawn_thread!(async move { + let context = Context::new().unwrap(); + let worker = context.create_worker().unwrap(); + + // Get address + let addr1 = worker.address().unwrap(); + let bytes = addr1.as_bytes().clone(); + + // Clone the address + let addr2 = addr1.clone(); + assert_eq!(addr1.as_ref(), addr2.as_ref()); + + // Create from Bytes + let addr3 = WorkerAddress::from_bytes(bytes.clone()); + assert_eq!(addr1.as_ref(), addr3.as_ref()); + + // Create from Vec + let vec = bytes.to_vec(); + let addr4 = WorkerAddress::from(vec); + assert_eq!(addr1.as_ref(), addr4.as_ref()); + + trace!("Worker address clone and from test completed"); + }); + + f.join().unwrap(); } } diff --git a/ucx1-sys/Cargo.toml b/ucx1-sys/Cargo.toml index d71ff71..f95990a 100644 --- a/ucx1-sys/Cargo.toml +++ b/ucx1-sys/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ucx1-sys" -version = "0.1.0" +version = "0.2.0" authors = ["Runji Wang "] edition = "2021" description = "Rust FFI bindings to UCX." 
@@ -14,3 +14,4 @@ categories = ["external-ffi-bindings"] [build-dependencies] bindgen = "0.66" +pkg-config = "0.3" diff --git a/ucx1-sys/build.rs b/ucx1-sys/build.rs index 2239b8e..f98e201 100644 --- a/ucx1-sys/build.rs +++ b/ucx1-sys/build.rs @@ -3,31 +3,27 @@ use std::path::{Path, PathBuf}; use std::process::Command; fn main() { - let dst = PathBuf::from(env::var_os("OUT_DIR").unwrap()); - - // Tell cargo to tell rustc to link the library. - println!("cargo:rustc-link-search=native={}/lib", dst.display()); - println!("cargo:rustc-link-lib=ucp"); - // println!("cargo:rustc-link-lib=uct"); - // println!("cargo:rustc-link-lib=ucs"); - // println!("cargo:rustc-link-lib=ucm"); - // Tell cargo to invalidate the built crate whenever the wrapper changes println!("cargo:rerun-if-changed=wrapper.h"); + println!("cargo:rerun-if-env-changed=UCX_NO_PKG_CONFIG"); - build_from_source(); + // Determine whether to use system UCX or build from source + let (include_path, use_system) = if env::var("UCX_NO_PKG_CONFIG").is_ok() { + println!("cargo:warning=UCX_NO_PKG_CONFIG set, building from source"); + (build_from_source(), false) + } else if let Some(include) = try_system_ucx() { + println!("cargo:warning=Using system UCX installation"); + (include, true) + } else { + println!("cargo:warning=System UCX not found or incompatible, building from source"); + (build_from_source(), false) + }; - // The bindgen::Builder is the main entry point - // to bindgen, and lets you build up options for - // the resulting bindings. + // Generate bindings let bindings = bindgen::Builder::default() - .clang_arg(format!("-I{}", dst.join("include").display())) - // The input header we would like to generate bindings for. + .clang_arg(format!("-I{}", include_path)) .header("wrapper.h") - // Tell cargo to invalidate the built crate whenever any of the - // included header files changed. 
.parse_callbacks(Box::new(bindgen::CargoCallbacks)) - // .parse_callbacks(Box::new(ignored_macros)) .allowlist_function("uc[tsmp]_.*") .allowlist_var("uc[tsmp]_.*") .allowlist_var("UC[TSMP]_.*") @@ -36,19 +32,58 @@ fn main() { .bitfield_enum("ucp_feature") .bitfield_enum(".*_field") .bitfield_enum(".*_flags(_t)?") - // Finish the builder and generate the bindings. .generate() - // Unwrap the Result and panic on failure. .expect("Unable to generate bindings"); - // Write the bindings to the $OUT_DIR/bindings.rs file. let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); bindings .write_to_file(out_path.join("bindings.rs")) .expect("Couldn't write bindings!"); + + // If we built from source, tell cargo where to find the libraries + if !use_system { + println!("cargo:rustc-link-search=native={}/lib", out_path.display()); + } +} + +/// Try to use system UCX via pkg-config. +/// Returns the include path if successful, None otherwise. +fn try_system_ucx() -> Option { + match pkg_config::Config::new() + .atleast_version("1.19") + .cargo_metadata(true) + .probe("ucx") + { + Ok(library) => { + // Check that version is < 2.0 + let version = &library.version; + let parts: Vec<&str> = version.split('.').collect(); + if let Some(major) = parts.first().and_then(|s| s.parse::().ok()) { + if major >= 2 { + println!( + "cargo:warning=Found UCX version {} but require < 2.0", + version + ); + return None; + } + } + + // pkg-config automatically adds link directives via cargo_metadata(true) + // Now we need to return an include path for bindgen + if let Some(include_path) = library.include_paths.first() { + return Some(include_path.display().to_string()); + } + None + } + Err(e) => { + println!("cargo:warning=pkg-config failed: {}", e); + None + } + } } -fn build_from_source() { +/// Build UCX from source and return the include path. +fn build_from_source() -> String { let dst = PathBuf::from(env::var_os("OUT_DIR").unwrap()); // Return if the outputs exist. 
@@ -57,7 +92,7 @@ fn build_from_source() { && dst.join("lib/libucm.a").exists() && dst.join("lib/libucp.a").exists() { - return; + return dst.join("include").display().to_string(); } // Initialize git submodule if necessary. @@ -107,4 +142,13 @@ fn build_from_source() { .arg("install") .status() .expect("failed to make install"); + + // Tell cargo to link all UCX libraries (only needed when building from source) + // When building static libraries, we need to link them in dependency order + println!("cargo:rustc-link-lib=static=ucp"); + println!("cargo:rustc-link-lib=static=uct"); + println!("cargo:rustc-link-lib=static=ucs"); + println!("cargo:rustc-link-lib=static=ucm"); + + dst.join("include").display().to_string() } diff --git a/ucx1-sys/ucx b/ucx1-sys/ucx index 938ffcd..e463614 160000 --- a/ucx1-sys/ucx +++ b/ucx1-sys/ucx @@ -1 +1 @@ -Subproject commit 938ffcd10122742d0f46a4f609e7395d1648c969 +Subproject commit e4636149592d5a435c2c911fe7727444a13bfa2e