diff --git a/.travis.yml b/.travis.yml index ef043758..f2266f4e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -26,7 +26,7 @@ matrix: - clang-11 - musl-tools rust: - - stable + - nightly-2023-05-07 env: - RUST_BACKTRACE=1 - CFLAGS_x86_64_fortanix_unknown_sgx="-isystem/usr/include/x86_64-linux-gnu -mlvi-hardening -mllvm -x86-experimental-lvi-inline-asm-hardening" @@ -44,7 +44,8 @@ matrix: - rustup toolchain add nightly - rustup target add x86_64-fortanix-unknown-sgx --toolchain nightly script: - - cargo test --verbose --locked --all --exclude sgxs-loaders && [ "$(echo $(nm -D target/debug/sgx-detect|grep __vdso_sgx_enter_enclave))" = "w __vdso_sgx_enter_enclave" ] + - cargo test --verbose --locked --all --exclude sgxs-loaders --exclude async-usercalls && [ "$(echo $(nm -D target/debug/sgx-detect|grep __vdso_sgx_enter_enclave))" = "w __vdso_sgx_enter_enclave" ] + - cargo test --verbose --locked -p async-usercalls --target x86_64-fortanix-unknown-sgx --no-run - cargo test --verbose --locked -p dcap-ql --features link - cargo test --verbose --locked -p dcap-ql --features verify - cargo test --verbose --locked -p ias --features mbedtls diff --git a/Cargo.lock b/Cargo.lock index 8dd725f5..5fc3ff48 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -60,6 +60,17 @@ version = "1.0.47" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38d9ff5d688f1c13395289f67db01d4826b46dd694e7580accdc3e8430f2d98e" +[[package]] +name = "async-usercalls" +version = "0.5.0" +dependencies = [ + "crossbeam-channel", + "fnv", + "fortanix-sgx-abi", + "ipc-queue", + "lazy_static", +] + [[package]] name = "atty" version = "0.2.14" @@ -506,26 +517,26 @@ dependencies = [ [[package]] name = "crossbeam" -version = "0.7.3" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69323bff1fb41c635347b8ead484a5ca6c3f11914d784170b158d8449ab07f8e" +checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c" dependencies = [ - "cfg-if 0.1.10", + "cfg-if 1.0.0", "crossbeam-channel", - "crossbeam-deque", - "crossbeam-epoch", - "crossbeam-queue", - "crossbeam-utils", + "crossbeam-deque 0.8.3", + "crossbeam-epoch 0.9.15", + "crossbeam-queue 0.3.8", + "crossbeam-utils 0.8.16", ] [[package]] name = "crossbeam-channel" -version = "0.4.4" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b153fe7cbef478c567df0f972e02e6d736db11affe43dfc9c56a9374d1adfb87" +checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200" dependencies = [ - "crossbeam-utils", - "maybe-uninit", + "cfg-if 1.0.0", + "crossbeam-utils 0.8.16", ] [[package]] @@ -534,11 +545,22 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f02af974daeee82218205558e51ec8768b48cf524bd01d550abe5573a608285" dependencies = [ - "crossbeam-epoch", - "crossbeam-utils", + "crossbeam-epoch 0.8.2", + "crossbeam-utils 0.7.2", "maybe-uninit", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" +dependencies = [ + "cfg-if 1.0.0", + "crossbeam-epoch 0.9.15", + "crossbeam-utils 0.8.16", +] + [[package]] name = "crossbeam-epoch" version = "0.8.2" @@ -547,13 +569,26 @@ checksum = "058ed274caafc1f60c4997b5fc07bf7dc7cca454af7c6e81edffe5f33f70dace" dependencies = [ "autocfg 1.0.1", "cfg-if 0.1.10", - "crossbeam-utils", + "crossbeam-utils 0.7.2", "lazy_static", "maybe-uninit", "memoffset 0.5.6", "scopeguard", ] +[[package]] +name = "crossbeam-epoch" +version = "0.9.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" +dependencies = [ + "autocfg 1.0.1", + "cfg-if 1.0.0", + "crossbeam-utils 0.8.16", + "memoffset 0.9.0", + "scopeguard", +] + [[package]] name = "crossbeam-queue" version = "0.2.3" @@ -561,10 +596,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "774ba60a54c213d409d5353bda12d49cd68d14e45036a285234c8d6f91f92570" dependencies = [ "cfg-if 0.1.10", - "crossbeam-utils", + "crossbeam-utils 0.7.2", "maybe-uninit", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add" +dependencies = [ + "cfg-if 1.0.0", + "crossbeam-utils 0.8.16", +] + [[package]] name = "crossbeam-utils" version = "0.7.2" @@ -576,6 +621,15 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "crossbeam-utils" +version = "0.8.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" +dependencies = [ + "cfg-if 1.0.0", +] + [[package]] name = "crypto-hash" version = "0.3.4" @@ -877,7 +931,7 @@ dependencies = [ "failure", "failure_derive", "fnv", - "fortanix-sgx-abi 0.4.1", + "fortanix-sgx-abi", "futures 0.3.17", "ipc-queue", "lazy_static", @@ -1087,12 +1141,6 @@ dependencies = [ "percent-encoding 2.1.0", ] -[[package]] -name = "fortanix-sgx-abi" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "816a38bd53bd5c87dd7edf4f15a2ee6b989ad7a5b5e616b75d70de64ad2a1329" - [[package]] name = "fortanix-sgx-abi" version = "0.5.0" @@ -1764,7 +1812,7 @@ dependencies = [ name = "ipc-queue" version = "0.2.0" dependencies = [ - "fortanix-sgx-abi 0.4.1", + "fortanix-sgx-abi", "futures 0.3.17", "static_assertions", "tokio 0.2.22", @@ -2010,6 +2058,15 @@ dependencies = [ "autocfg 1.0.1", ] +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg 1.0.1", +] + [[package]] name = "mime" version = "0.2.6" @@ -2837,7 +2894,7 @@ version = "2.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6653d384a260fedff0a466e894e05c5b8d75e261a14e9f93e81e43ef86cad23" dependencies = [ - "log 0.3.9", + "log 0.4.14", "which 4.0.2", ] @@ -3949,7 +4006,7 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fb2d1b8f4548dbf5e1f7818512e9c406860678f29c300cdf0ebac72d1a3a1671" dependencies = [ - "crossbeam-utils", + "crossbeam-utils 0.7.2", "futures 0.1.30", ] @@ -4014,7 +4071,7 @@ version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09bc590ec4ba8ba87652da2068d150dcada2cfa2e07faae270a5e0409aa51351" dependencies = [ - "crossbeam-utils", + "crossbeam-utils 0.7.2", "futures 0.1.30", "lazy_static", "log 0.4.14", @@ -4057,9 +4114,9 @@ version = "0.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df720b6581784c118f0eb4310796b12b1d242a7eb95f716a8367855325c25f89" dependencies = [ - "crossbeam-deque", - "crossbeam-queue", - "crossbeam-utils", + "crossbeam-deque 0.7.3", + "crossbeam-queue 0.2.3", + "crossbeam-utils 0.7.2", "futures 0.1.30", "lazy_static", "log 0.4.14", @@ -4074,7 +4131,7 @@ version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "93044f2d313c95ff1cb7809ce9a7a05735b012288a888b62d4434fd58c94f296" dependencies = [ - "crossbeam-utils", + "crossbeam-utils 0.7.2", "futures 0.1.30", "slab", "tokio-executor", diff --git a/Cargo.toml b/Cargo.toml index 5195a54b..90acffde 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ members = [ "fortanix-vme/tests/iron", "fortanix-vme/vme-pkix", "intel-sgx/aesm-client", + "intel-sgx/async-usercalls", "intel-sgx/dcap-provider", "intel-sgx/dcap-ql-sys", "intel-sgx/dcap-ql", diff --git a/doc/generate-api-docs.sh b/doc/generate-api-docs.sh index b25386ff..ded6b9a7 100755 --- a/doc/generate-api-docs.sh +++ b/doc/generate-api-docs.sh @@ -58,6 +58,9 @@ for LIB in $LIBS_SORTED; do if FEATURES="$(cargo read-manifest|jq -r '.metadata.docs.rs.features | join(",")' 2> /dev/null)"; then ARGS="--features $FEATURES" fi + if grep -q 'feature(sgx_platform)' ./src/lib.rs; then + ARGS+=" --target x86_64-fortanix-unknown-sgx" + fi cargo doc --no-deps --lib $ARGS popd fi diff --git a/intel-sgx/async-usercalls/Cargo.toml b/intel-sgx/async-usercalls/Cargo.toml new file mode 100644 index 00000000..442d71aa --- /dev/null +++ b/intel-sgx/async-usercalls/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "async-usercalls" +version = "0.5.0" +authors = ["Fortanix, Inc."] +license = "MPL-2.0" +edition = "2018" +description = """ +An interface for asynchronous usercalls in SGX enclaves. +This is an SGX-only crate, you should compile it with the `x86_64-fortanix-unknown-sgx` target. +""" +repository = "https://github.com/fortanix/rust-sgx" +documentation = "https://edp.fortanix.com/docs/api/async_usercalls/" +homepage = "https://edp.fortanix.com/" +keywords = ["sgx", "async", "usercall"] +categories = ["asynchronous"] + +[dependencies] +# Project dependencies +ipc-queue = { version = "0.2", path = "../../ipc-queue" } +fortanix-sgx-abi = { version = "0.5.0", path = "../fortanix-sgx-abi" } + +# External dependencies +lazy_static = "1.4.0" # MIT/Apache-2.0 +crossbeam-channel = "0.5" # MIT/Apache-2.0 +fnv = "1.0" # MIT/Apache-2.0 + +# For cargo test --target x86_64-fortanix-unknown-sgx +[package.metadata.fortanix-sgx] +threads = 128 diff --git a/intel-sgx/async-usercalls/src/batch_drop.rs b/intel-sgx/async-usercalls/src/batch_drop.rs new file mode 100644 index 00000000..2f7bb698 --- /dev/null +++ b/intel-sgx/async-usercalls/src/batch_drop.rs @@ -0,0 +1,126 @@ +use crate::provider_core::ProviderCore; +use ipc_queue::Identified; +use std::cell::RefCell; +use std::mem; +use std::os::fortanix_sgx::usercalls::alloc::{User, UserSafe}; +use std::os::fortanix_sgx::usercalls::raw::{Usercall, UsercallNrs}; + +pub trait BatchDroppable: private::BatchDroppable {} +impl BatchDroppable for T {} + +/// Drop the given value at some point in the future (no rush!). This is useful +/// for freeing userspace memory when we don't particularly care about when the +/// buffer is freed. Multiple `free` usercalls are batched together and sent to +/// userspace asynchronously. It is also guaranteed that the memory is freed if +/// the current thread exits before there is a large enough batch. +/// +/// This is mainly an optimization to avoid exitting the enclave for each +/// usercall. Note that even when sending usercalls asynchronously, if the +/// usercall queue is empty we still need to exit the enclave to signal the +/// userspace that the queue is not empty anymore. The batch send would send +/// multiple usercalls and notify the userspace at most once. +pub fn batch_drop(t: T) { + t.batch_drop(); +} + +mod private { + use super::*; + + const BATCH_SIZE: usize = 8; + + struct BatchDropProvider { + core: ProviderCore, + deferred: Vec>, + } + + impl BatchDropProvider { + pub fn new() -> Self { + Self { + core: ProviderCore::new(None), + deferred: Vec::with_capacity(BATCH_SIZE), + } + } + + fn make_progress(&self, deferred: &[Identified]) -> usize { + let sent = self.core.try_send_multiple_usercalls(deferred); + if sent == 0 { + self.core.send_usercall(deferred[0]); + return 1; + } + sent + } + + fn maybe_send_usercall(&mut self, u: Usercall) { + self.deferred.push(self.core.assign_id(u)); + if self.deferred.len() < BATCH_SIZE { + return; + } + let sent = self.make_progress(&self.deferred); + let mut not_sent = self.deferred.split_off(sent); + self.deferred.clear(); + self.deferred.append(&mut not_sent); + } + + pub fn free(&mut self, buf: User) { + let ptr = buf.into_raw(); + let size = unsafe { mem::size_of_val(&mut *ptr) }; + let alignment = T::align_of(); + let ptr = ptr as *mut u8; + let u = Usercall(UsercallNrs::free as _, ptr as _, size as _, alignment as _, 0); + self.maybe_send_usercall(u); + } + } + + impl Drop for BatchDropProvider { + fn drop(&mut self) { + let mut sent = 0; + while sent < self.deferred.len() { + sent += self.make_progress(&self.deferred[sent..]); + } + } + } + + std::thread_local! { + static PROVIDER: RefCell = RefCell::new(BatchDropProvider::new()); + } + + pub trait BatchDroppable { + fn batch_drop(self); + } + + impl BatchDroppable for User { + fn batch_drop(self) { + PROVIDER.with(|p| p.borrow_mut().free(self)); + } + } +} + +#[cfg(test)] +mod tests { + use super::batch_drop; + use std::os::fortanix_sgx::usercalls::alloc::User; + use std::thread; + + #[test] + fn basic() { + for _ in 0..100 { + batch_drop(User::<[u8]>::uninitialized(100)); + } + } + + #[test] + fn multiple_threads() { + const THREADS: usize = 16; + let mut handles = Vec::with_capacity(THREADS); + for _ in 0..THREADS { + handles.push(thread::spawn(move || { + for _ in 0..1000 { + batch_drop(User::<[u8]>::uninitialized(100)); + } + })); + } + for h in handles { + h.join().unwrap(); + } + } +} diff --git a/intel-sgx/async-usercalls/src/callback.rs b/intel-sgx/async-usercalls/src/callback.rs new file mode 100644 index 00000000..46e2fded --- /dev/null +++ b/intel-sgx/async-usercalls/src/callback.rs @@ -0,0 +1,65 @@ +use fortanix_sgx_abi::{invoke_with_usercalls, Fd, Result}; +use std::io; +use std::os::fortanix_sgx::usercalls::raw::{Return, ReturnValue}; +use std::os::fortanix_sgx::usercalls::FromSgxResult; + +pub struct CbFn(Box); + +impl CbFn { + fn call(self, t: T) { + (self.0)(t); + } +} + +impl From for CbFn + where + F: FnOnce(T) + Send + 'static, +{ + fn from(f: F) -> Self { + Self(Box::new(f)) + } +} + +macro_rules! cbfn_type { + ( ) => { CbFn<()> }; + ( -> ! ) => { () }; + ( -> u64 ) => { CbFn }; + ( -> (Result, usize) ) => { CbFn> }; + ( -> (Result, u64) ) => { CbFn> }; + ( -> (Result, Fd) ) => { CbFn> }; + ( -> (Result, *mut u8) ) => { CbFn> }; + ( -> Result ) => { CbFn> }; +} + +macro_rules! call_cbfn { + ( $cb:ident, $rv:expr, ) => { let x: () = $rv; $cb.call(x); }; + ( $cb:ident, $rv:expr, -> ! ) => { let _: ! = $rv; }; + ( $cb:ident, $rv:expr, -> u64 ) => { let x: u64 = $rv; $cb.call(x); }; + ( $cb:ident, $rv:expr, -> $t:ty ) => { let x: $t = $rv; $cb.call(x.from_sgx_result()); }; +} + +macro_rules! define_callback { + ($(fn $name:ident($($n:ident: $t:ty),*) $(-> $r:tt)*; )*) => { + #[allow(unused)] + #[allow(non_camel_case_types)] + pub(crate) enum Callback { + $( $name(cbfn_type! { $(-> $r)* }), )* + } + + impl Callback { + pub(crate) fn call(self, ret: Return) { + match self {$( + Callback::$name(_cb) => { + call_cbfn!( + _cb, + ReturnValue::from_registers(stringify!($name), (ret.0, ret.1)), + $(-> $r)* + ); + } + )*} + } + } + }; +} + +invoke_with_usercalls!(define_callback); \ No newline at end of file diff --git a/intel-sgx/async-usercalls/src/io_bufs.rs b/intel-sgx/async-usercalls/src/io_bufs.rs new file mode 100644 index 00000000..e039aef9 --- /dev/null +++ b/intel-sgx/async-usercalls/src/io_bufs.rs @@ -0,0 +1,324 @@ +use std::cell::UnsafeCell; +use std::cmp; +use std::io::IoSlice; +use std::ops::{Deref, DerefMut, Range}; +use std::os::fortanix_sgx::usercalls::alloc::{User, UserRef}; +use std::sync::Arc; + +pub struct UserBuf(UserBufKind); + +enum UserBufKind { + Owned { + user: User<[u8]>, + range: Range, + }, + Shared { + user: Arc>>, + range: Range, + }, +} + +impl UserBuf { + pub fn into_user(self) -> Result, Self> { + match self.0 { + UserBufKind::Owned { user, .. } => Ok(user), + UserBufKind::Shared { user, range } => Err(Self(UserBufKind::Shared { user, range })), + } + } + + fn into_shared(self) -> Option>>> { + match self.0 { + UserBufKind::Owned { .. } => None, + UserBufKind::Shared { user, .. } => Some(user), + } + } +} + +unsafe impl Send for UserBuf {} + +impl Deref for UserBuf { + type Target = UserRef<[u8]>; + + fn deref(&self) -> &Self::Target { + match self.0 { + UserBufKind::Owned { ref user, ref range } => &user[range.start..range.end], + UserBufKind::Shared { ref user, ref range } => { + let user = unsafe { &*user.get() }; + &user[range.start..range.end] + } + } + } +} + +impl DerefMut for UserBuf { + fn deref_mut(&mut self) -> &mut Self::Target { + match self.0 { + UserBufKind::Owned { + ref mut user, + ref range, + } => &mut user[range.start..range.end], + UserBufKind::Shared { ref user, ref range } => { + let user = unsafe { &mut *user.get() }; + &mut user[range.start..range.end] + } + } + } +} + +impl From> for UserBuf { + fn from(user: User<[u8]>) -> Self { + UserBuf(UserBufKind::Owned { + range: 0..user.len(), + user, + }) + } +} + +impl From<(User<[u8]>, Range)> for UserBuf { + fn from(pair: (User<[u8]>, Range)) -> Self { + UserBuf(UserBufKind::Owned { + user: pair.0, + range: pair.1, + }) + } +} + +/// `WriteBuffer` provides a ring buffer that can be written to by the code +/// running in the enclave while a portion of it can be passed to a `write` +/// usercall running concurrently. It ensures that enclave code does not write +/// to the portion sent to userspace. +pub struct WriteBuffer { + userbuf: Arc>>, + buf_len: usize, + read: u32, + write: u32, +} + +unsafe impl Send for WriteBuffer {} + +impl WriteBuffer { + pub fn new(userbuf: User<[u8]>) -> Self { + Self { + buf_len: userbuf.len(), + userbuf: Arc::new(UnsafeCell::new(userbuf)), + read: 0, + write: 0, + } + } + + pub fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> usize { + if self.is_full() { + return 0; + } + let mut wrote = 0; + for buf in bufs { + wrote += self.write(buf); + } + wrote + } + + pub fn write(&mut self, buf: &[u8]) -> usize { + let (_, write_offset) = self.offsets(); + let rem = self.remaining_capacity(); + let can_write = cmp::min(buf.len(), rem); + let end = cmp::min(self.buf_len, write_offset + can_write); + let n = end - write_offset; + unsafe { + let userbuf = &mut *self.userbuf.get(); + userbuf[write_offset..write_offset + n].copy_from_enclave(&buf[..n]); + } + self.advance_write(n); + n + if n < can_write { self.write(&buf[n..]) } else { 0 } + } + + /// This function returns a slice of bytes appropriate for writing to a socket. + /// Once some or all of these bytes are successfully written to the socket, + /// `self.consume()` must be called to actually consume those bytes. + /// + /// Returns None if the buffer is empty. + /// + /// Panics if called more than once in a row without either calling `consume()` + /// or dropping the previously returned buffer. + pub fn consumable_chunk(&mut self) -> Option { + assert!( + Arc::strong_count(&self.userbuf) == 1, + "called consumable_chunk() more than once in a row" + ); + let range = match self.offsets() { + (_, _) if self.read == self.write => return None, // empty + (r, w) if r < w => r..w, + (r, _) => r..self.buf_len, + }; + Some(UserBuf(UserBufKind::Shared { + user: self.userbuf.clone(), + range, + })) + } + + /// Mark `n` bytes as consumed. `buf` must have been produced by a call + /// to `self.consumable_chunk()`. + /// Panics if: + /// - `n > buf.len()` + /// - `buf` was not produced by `self.consumable_chunk()` + /// + /// This function is supposed to be used in conjunction with `consumable_chunk()`. + pub fn consume(&mut self, buf: UserBuf, n: usize) { + assert!(n <= buf.len()); + const PANIC_MESSAGE: &'static str = "`buf` not produced by self.consumable_chunk()"; + let buf = buf.into_shared().expect(PANIC_MESSAGE); + assert!(Arc::ptr_eq(&self.userbuf, &buf), "{}", PANIC_MESSAGE); + drop(buf); + assert!(Arc::strong_count(&self.userbuf) == 1, "{}", PANIC_MESSAGE); + self.advance_read(n); + } + + fn len(&self) -> usize { + match self.offsets() { + (_, _) if self.read == self.write => 0, // empty + (r, w) if r == w && self.read != self.write => self.buf_len, // full + (r, w) if r < w => w - r, + (r, w) => w + self.buf_len - r, + } + } + + fn remaining_capacity(&self) -> usize { + let len = self.len(); + debug_assert!(len <= self.buf_len); + self.buf_len - len + } + + fn offsets(&self) -> (usize, usize) { + (self.read as usize % self.buf_len, self.write as usize % self.buf_len) + } + + pub fn is_empty(&self) -> bool { + self.read == self.write + } + + fn is_full(&self) -> bool { + let (read_offset, write_offset) = self.offsets(); + read_offset == write_offset && self.read != self.write + } + + fn advance_read(&mut self, by: usize) { + debug_assert!(by <= self.len()); + self.read = ((self.read as usize + by) % (self.buf_len * 2)) as _; + } + + fn advance_write(&mut self, by: usize) { + debug_assert!(by <= self.remaining_capacity()); + self.write = ((self.write as usize + by) % (self.buf_len * 2)) as _; + } +} + +pub struct ReadBuffer { + userbuf: User<[u8]>, + position: usize, + len: usize, +} + +impl ReadBuffer { + /// Constructs a new `ReadBuffer`, assuming `len` bytes of `userbuf` have + /// meaningful data. Panics if `len > userbuf.len()`. + pub fn new(userbuf: User<[u8]>, len: usize) -> ReadBuffer { + assert!(len <= userbuf.len()); + ReadBuffer { + userbuf, + position: 0, + len, + } + } + + pub fn read(&mut self, buf: &mut [u8]) -> usize { + debug_assert!(self.position <= self.len); + if self.position == self.len { + return 0; + } + let n = cmp::min(buf.len(), self.len - self.position); + self.userbuf[self.position..self.position + n].copy_to_enclave(&mut buf[..n]); + self.position += n; + n + } + + /// Returns the number of bytes that have not been read yet. + pub fn remaining_bytes(&self) -> usize { + debug_assert!(self.position <= self.len); + self.len - self.position + } + + pub fn len(&self) -> usize { + self.len + } + + /// Consumes self and returns the internal userspace buffer. + /// It's the caller's responsibility to ensure all bytes have been read + /// before calling this function. + pub fn into_inner(self) -> User<[u8]> { + self.userbuf + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::os::fortanix_sgx::usercalls::alloc::User; + + #[test] + fn write_buffer_basic() { + const LENGTH: usize = 1024; + let mut write_buffer = WriteBuffer::new(User::<[u8]>::uninitialized(1024)); + + let buf = vec![0u8; LENGTH]; + assert_eq!(write_buffer.write(&buf), LENGTH); + assert_eq!(write_buffer.write(&buf), 0); + + let chunk = write_buffer.consumable_chunk().unwrap(); + write_buffer.consume(chunk, 200); + assert_eq!(write_buffer.write(&buf), 200); + assert_eq!(write_buffer.write(&buf), 0); + } + + #[test] + #[should_panic] + fn call_consumable_chunk_twice() { + const LENGTH: usize = 1024; + let mut write_buffer = WriteBuffer::new(User::<[u8]>::uninitialized(1024)); + + let buf = vec![0u8; LENGTH]; + assert_eq!(write_buffer.write(&buf), LENGTH); + assert_eq!(write_buffer.write(&buf), 0); + + let chunk1 = write_buffer.consumable_chunk().unwrap(); + let _ = write_buffer.consumable_chunk().unwrap(); + drop(chunk1); + } + + #[test] + #[should_panic] + fn consume_wrong_buf() { + const LENGTH: usize = 1024; + let mut write_buffer = WriteBuffer::new(User::<[u8]>::uninitialized(1024)); + + let buf = vec![0u8; LENGTH]; + assert_eq!(write_buffer.write(&buf), LENGTH); + assert_eq!(write_buffer.write(&buf), 0); + + let unrelated_buf: UserBuf = User::<[u8]>::uninitialized(512).into(); + write_buffer.consume(unrelated_buf, 100); + } + + #[test] + fn read_buffer_basic() { + let mut buf = User::<[u8]>::uninitialized(64); + const DATA: &'static [u8] = b"hello"; + buf[0..DATA.len()].copy_from_enclave(DATA); + + let mut read_buffer = ReadBuffer::new(buf, DATA.len()); + assert_eq!(read_buffer.len(), DATA.len()); + assert_eq!(read_buffer.remaining_bytes(), DATA.len()); + let mut buf = [0u8; 8]; + assert_eq!(read_buffer.read(&mut buf), DATA.len()); + assert_eq!(read_buffer.remaining_bytes(), 0); + assert_eq!(&buf, b"hello\0\0\0"); + } +} \ No newline at end of file diff --git a/intel-sgx/async-usercalls/src/lib.rs b/intel-sgx/async-usercalls/src/lib.rs new file mode 100644 index 00000000..2ab4701f --- /dev/null +++ b/intel-sgx/async-usercalls/src/lib.rs @@ -0,0 +1,433 @@ +//! This crate provides an interface for performing asynchronous usercalls in +//! SGX enclaves. The motivation behind asynchronous usercalls and ABI +//! documentation can be found +//! [here](https://edp.fortanix.com/docs/api/fortanix_sgx_abi/async/index.html). +//! The API provided here is fairly low level and is not meant for general use. +//! These APIs can be used to implement [mio] abstractions which in turn +//! allows us to use [tokio] in SGX enclaves! +//! +//! The main interface is provided through `AsyncUsercallProvider` which works +//! in tandem with `CallbackHandler`: +//! ``` +//! use async_usercalls::AsyncUsercallProvider; +//! use std::{io::Result, net::TcpStream, sync::mpsc, time::Duration}; +//! +//! let (provider, callback_handler) = AsyncUsercallProvider::new(); +//! let (tx, rx) = mpsc::sync_channel(1); +//! // The closure is called when userspace sends back the result of the +//! // usercall. +//! let cancel_handle = provider.connect_stream("www.example.com:80", move |res| { +//! tx.send(res).unwrap(); +//! }); +//! // We can cancel the connect usercall using `cancel_handle.cancel()`, but +//! // note that we may still get a successful result. +//! // We need to poll `callback_handler` to make progress. +//! loop { +//! let n = callback_handler.poll(Some(Duration::from_millis(100))); +//! if n > 0 { +//! break; // at least 1 callback function was executed! +//! } +//! } +//! let connect_result: Result = rx.recv().unwrap(); +//! ``` +//! +//! [mio]: https://docs.rs/mio/latest/mio/ +//! [tokio]: https://docs.rs/tokio/latest/tokio/ + +#![feature(sgx_platform)] +#![feature(never_type)] +#![cfg_attr(test, feature(unboxed_closures))] +#![cfg_attr(test, feature(fn_traits))] + +use crossbeam_channel as mpmc; +use ipc_queue::Identified; +use std::collections::HashMap; +use std::os::fortanix_sgx::usercalls::raw::{Cancel, Return, Usercall}; +use std::sync::Mutex; +use std::time::Duration; + +mod batch_drop; +mod callback; +mod io_bufs; +mod provider_api; +mod provider_core; +mod queues; +mod raw; +#[cfg(test)] +mod test_support; +mod utils; + +pub use self::batch_drop::batch_drop; +pub use self::callback::CbFn; +pub use self::io_bufs::{ReadBuffer, UserBuf, WriteBuffer}; +pub use self::raw::RawApi; + +use self::callback::*; +use self::provider_core::ProviderCore; +use self::queues::*; + +pub struct CancelHandle(Identified); + +impl CancelHandle { + pub fn cancel(self) { + PROVIDERS + .cancel_sender() + .send(self.0) + .expect("failed to send cancellation"); + } + + pub(crate) fn new(c: Identified) -> Self { + CancelHandle(c) + } +} + +/// This type provides a mechanism for submitting usercalls asynchronously. +/// Usercalls are sent to the enclave runner through a queue. The results are +/// retrieved when `CallbackHandler::poll` is called. Users are notified of the +/// results through callback functions. +/// +/// Users of this type should take care not to block execution in callbacks. +/// Certain usercalls can be cancelled through a handle, but note that it is +/// still possible to receive successful results for cancelled usercalls. +pub struct AsyncUsercallProvider { + core: ProviderCore, + callback_tx: mpmc::Sender<(u64, Callback)>, +} + +impl AsyncUsercallProvider { + pub fn new() -> (Self, CallbackHandler) { + let (return_tx, return_rx) = mpmc::unbounded(); + let core = ProviderCore::new(Some(return_tx)); + let callbacks = Mutex::new(HashMap::new()); + let (callback_tx, callback_rx) = mpmc::unbounded(); + let provider = Self { core, callback_tx }; + let waker = CallbackHandlerWaker::new(); + let handler = CallbackHandler { + return_rx, + callbacks, + callback_rx, + waker, + }; + (provider, handler) + } + + #[cfg(test)] + pub(crate) fn provider_id(&self) -> u32 { + self.core.provider_id() + } + + fn send_usercall(&self, usercall: Usercall, callback: Option) -> CancelHandle { + let usercall = self.core.assign_id(usercall); + if let Some(callback) = callback { + self.callback_tx + .send((usercall.id, callback)) + .expect("failed to send callback"); + } + self.core.send_usercall(usercall) + } +} + +#[derive(Clone)] +pub struct CallbackHandlerWaker { + rx: mpmc::Receiver<()>, + tx: mpmc::Sender<()>, +} + +impl CallbackHandlerWaker { + fn new() -> Self { + let (tx, rx) = mpmc::bounded(1); + Self { tx, rx } + } + + /// Interrupts the currently running or a future call to the related + /// CallbackHandler's `poll()`. + pub fn wake(&self) { + let _ = self.tx.try_send(()); + } + + /// Clears the effect of a previous call to `self.wake()` that is not yet + /// observed by `CallbackHandler::poll()`. + pub fn clear(&self) { + let _ = self.rx.try_recv(); + } +} + +pub struct CallbackHandler { + return_rx: mpmc::Receiver>, + callbacks: Mutex>, + // This is used so that threads sending usercalls don't have to take the lock. + callback_rx: mpmc::Receiver<(u64, Callback)>, + waker: CallbackHandlerWaker, +} + +impl CallbackHandler { + const RECV_BATCH_SIZE: usize = 128; + + // Returns an object that can be used to interrupt a blocked `self.poll()`. + pub fn waker(&self) -> CallbackHandlerWaker { + self.waker.clone() + } + + #[inline] + fn recv_returns(&self, timeout: Option, returns: &mut [Identified]) -> usize { + let first = match timeout { + None => mpmc::select! { + recv(self.return_rx) -> res => res.ok(), + recv(self.waker.rx) -> _res => return 0, + }, + Some(timeout) => mpmc::select! { + recv(self.return_rx) -> res => res.ok(), + recv(self.waker.rx) -> _res => return 0, + default(timeout) => return 0, + }, + } + .expect("return channel closed unexpectedly"); + let mut count = 0; + for ret in std::iter::once(first).chain(self.return_rx.try_iter().take(returns.len() - 1)) { + returns[count] = ret; + count += 1; + } + count + } + + /// Poll for returned usercalls and execute their respective callback + /// functions. If `timeout` is `None`, it will block execution until at + /// least one return is received, otherwise it will block until there is a + /// return or timeout is elapsed. Returns the number of executed callbacks. + /// This can be interrupted using `CallbackHandlerWaker::wake()`. + pub fn poll(&self, timeout: Option) -> usize { + // 1. wait for returns + let mut returns = [Identified::default(); Self::RECV_BATCH_SIZE]; + let returns = match self.recv_returns(timeout, &mut returns) { + 0 => return 0, + n => &returns[..n], + }; + // 2. try to lock the mutex, if successful, receive all pending callbacks and put them in the hash map + let mut guard = match self.callbacks.try_lock() { + Ok(mut callbacks) => { + for (id, cb) in self.callback_rx.try_iter() { + callbacks.insert(id, cb); + } + callbacks + } + _ => self.callbacks.lock().unwrap(), + }; + // 3. remove callbacks for returns received in step 1 from the hash map + let mut ret_callbacks = Vec::with_capacity(returns.len()); + for ret in returns { + let cb = guard.remove(&ret.id); + ret_callbacks.push((ret, cb)); + } + drop(guard); + // 4. execute the callbacks without hugging the mutex + let mut count = 0; + for (ret, cb) in ret_callbacks { + if let Some(cb) = cb { + cb.call(ret.data); + count += 1; + } + } + count + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_support::*; + use crate::utils::MakeSend; + use crossbeam_channel as mpmc; + use std::io; + use std::net::{TcpListener, TcpStream}; + use std::os::fortanix_sgx::io::AsRawFd; + use std::os::fortanix_sgx::usercalls::alloc::User; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + use std::thread; + use std::time::Duration; + + #[test] + fn cancel_accept() { + let provider = AutoPollingProvider::new(); + let port = 6688; + let addr = format!("0.0.0.0:{}", port); + let (tx, rx) = mpmc::bounded(1); + provider.bind_stream(&addr, move |res| { + tx.send(res).unwrap(); + }); + let bind_res = rx.recv().unwrap(); + let listener = bind_res.unwrap(); + let fd = listener.as_raw_fd(); + let accept_count = Arc::new(AtomicUsize::new(0)); + let accept_count1 = Arc::clone(&accept_count); + let (tx, rx) = mpmc::bounded(1); + let accept = provider.accept_stream(fd, move |res| { + if let Ok(_) = res { + accept_count1.fetch_add(1, Ordering::Relaxed); + } + tx.send(()).unwrap(); + }); + accept.cancel(); + thread::sleep(Duration::from_millis(10)); + let _ = TcpStream::connect(&addr); + let _ = rx.recv(); + assert_eq!(accept_count.load(Ordering::Relaxed), 0); + } + + #[test] + fn connect() { + let listener = TcpListener::bind("0.0.0.0:0").unwrap(); + let addr = listener.local_addr().unwrap().to_string(); + let provider = AutoPollingProvider::new(); + let (tx, rx) = mpmc::bounded(1); + provider.connect_stream(&addr, move |res| { + tx.send(res).unwrap(); + }); + let res = rx.recv().unwrap(); + assert!(res.is_ok()); + } + + #[test] + fn safe_alloc_free() { + let provider = AutoPollingProvider::new(); + + const LEN: usize = 64 * 1024; + let (tx, rx) = mpmc::bounded(1); + provider.alloc_slice::(LEN, move |res| { + let buf = res.expect("failed to allocate memory"); + tx.send(MakeSend::new(buf)).unwrap(); + }); + let user_buf = rx.recv().unwrap().into_inner(); + assert_eq!(user_buf.len(), LEN); + + let (tx, rx) = mpmc::bounded(1); + let cb = move || { + tx.send(()).unwrap(); + }; + provider.free(user_buf, Some(cb)); + rx.recv().unwrap(); + } + + #[test] + fn callback_handler_waker() { + let (_provider, handler) = AsyncUsercallProvider::new(); + let waker = handler.waker(); + let (tx, rx) = mpmc::bounded(1); + let h = thread::spawn(move || { + let n1 = handler.poll(None); + tx.send(()).unwrap(); + let n2 = handler.poll(Some(Duration::from_secs(3))); + tx.send(()).unwrap(); + n1 + n2 + }); + for _ in 0..2 { + waker.wake(); + rx.recv().unwrap(); + } + assert_eq!(h.join().unwrap(), 0); + } + + #[test] + #[ignore] + fn echo() { + println!(); + let provider = Arc::new(AutoPollingProvider::new()); + const ADDR: &'static str = "0.0.0.0:7799"; + let (tx, rx) = mpmc::bounded(1); + provider.bind_stream(ADDR, move |res| { + tx.send(res).unwrap(); + }); + let bind_res = rx.recv().unwrap(); + let listener = bind_res.unwrap(); + println!("bind done: {:?}", listener); + let fd = listener.as_raw_fd(); + let cb = KeepAccepting { + listener, + provider: Arc::clone(&provider), + }; + provider.accept_stream(fd, cb); + thread::sleep(Duration::from_secs(60)); + } + + struct KeepAccepting { + listener: TcpListener, + provider: Arc, + } + + impl FnOnce<(io::Result,)> for KeepAccepting { + type Output = (); + + extern "rust-call" fn call_once(self, args: (io::Result,)) -> Self::Output { + let res = args.0; + println!("accept result: {:?}", res); + if let Ok(stream) = res { + let fd = stream.as_raw_fd(); + let cb = Echo { + stream, + read: true, + provider: self.provider.clone(), + }; + self.provider + .read(fd, User::<[u8]>::uninitialized(Echo::READ_BUF_SIZE), cb); + } + let provider = Arc::clone(&self.provider); + provider.accept_stream(self.listener.as_raw_fd(), self); + } + } + + struct Echo { + stream: TcpStream, + read: bool, + provider: Arc, + } + + impl Echo { + const READ_BUF_SIZE: usize = 1024; + + fn close(self) { + let fd = self.stream.as_raw_fd(); + println!("connection closed, fd = {}", fd); + self.provider.close(fd, None::>); + } + } + + // read callback + impl FnOnce<(io::Result, User<[u8]>)> for Echo { + type Output = (); + + extern "rust-call" fn call_once(mut self, args: (io::Result, User<[u8]>)) -> Self::Output { + let (res, user) = args; + assert!(self.read); + match res { + Ok(len) if len > 0 => { + self.read = false; + let provider = Arc::clone(&self.provider); + provider.write(self.stream.as_raw_fd(), (user, 0..len).into(), self); + } + _ => self.close(), + } + } + } + + // write callback + impl FnOnce<(io::Result, UserBuf)> for Echo { + type Output = (); + + extern "rust-call" fn call_once(mut self, args: (io::Result, UserBuf)) -> Self::Output { + let (res, _) = args; + assert!(!self.read); + match res { + Ok(len) if len > 0 => { + self.read = true; + let provider = Arc::clone(&self.provider); + provider.read( + self.stream.as_raw_fd(), + User::<[u8]>::uninitialized(Echo::READ_BUF_SIZE), + self, + ); + } + _ => self.close(), + } + } + } +} \ No newline at end of file diff --git a/intel-sgx/async-usercalls/src/provider_api.rs b/intel-sgx/async-usercalls/src/provider_api.rs new file mode 100644 index 00000000..7faf733d --- /dev/null +++ b/intel-sgx/async-usercalls/src/provider_api.rs @@ -0,0 +1,275 @@ +use crate::batch_drop; +use crate::io_bufs::UserBuf; +use crate::raw::RawApi; +use crate::utils::MakeSend; +use crate::{AsyncUsercallProvider, CancelHandle}; +use fortanix_sgx_abi::Fd; +use std::io; +use std::mem::{self, ManuallyDrop}; +use std::net::{TcpListener, TcpStream}; +use std::os::fortanix_sgx::io::{FromRawFd, TcpListenerMetadata, TcpStreamMetadata}; +use std::os::fortanix_sgx::usercalls::alloc::{User, UserRef, UserSafe}; +use std::os::fortanix_sgx::usercalls::raw::ByteBuffer; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +impl AsyncUsercallProvider { + /// Sends an asynchronous `read` usercall. `callback` is called when a + /// return value is received from userspace. `read_buf` is returned as an + /// argument to `callback` along with the result of the `read` usercall. + /// + /// Returns a handle that can be used to cancel the usercall if desired. + /// Please refer to the type-level documentation for general notes about + /// callbacks. + pub fn read(&self, fd: Fd, read_buf: User<[u8]>, callback: F) -> CancelHandle + where + F: FnOnce(io::Result, User<[u8]>) + Send + 'static, + { + let mut read_buf = ManuallyDrop::new(MakeSend::new(read_buf)); + let ptr = read_buf.as_mut_ptr(); + let len = read_buf.len(); + let cb = move |res: io::Result| { + let read_buf = ManuallyDrop::into_inner(read_buf).into_inner(); + callback(res, read_buf); + }; + unsafe { self.raw_read(fd, ptr, len, Some(cb.into())) } + } + + /// Sends an asynchronous `write` usercall. `callback` is called when a + /// return value is received from userspace. `write_buf` is returned as an + /// argument to `callback` along with the result of the `write` usercall. + /// + /// Returns a handle that can be used to cancel the usercall if desired. + /// Please refer to the type-level documentation for general notes about + /// callbacks. + pub fn write(&self, fd: Fd, write_buf: UserBuf, callback: F) -> CancelHandle + where + F: FnOnce(io::Result, UserBuf) + Send + 'static, + { + let mut write_buf = ManuallyDrop::new(write_buf); + let ptr = write_buf.as_mut_ptr(); + let len = write_buf.len(); + let cb = move |res| { + let write_buf = ManuallyDrop::into_inner(write_buf); + callback(res, write_buf); + }; + unsafe { self.raw_write(fd, ptr, len, Some(cb.into())) } + } + + /// Sends an asynchronous `flush` usercall. `callback` is called when a + /// return value is received from userspace. + /// + /// Please refer to the type-level documentation for general notes about + /// callbacks. + pub fn flush(&self, fd: Fd, callback: F) + where + F: FnOnce(io::Result<()>) + Send + 'static, + { + unsafe { + self.raw_flush(fd, Some(callback.into())); + } + } + + /// Sends an asynchronous `close` usercall. If specified, `callback` is + /// called when a return is received from userspace. + /// + /// Please refer to the type-level documentation for general notes about + /// callbacks. + pub fn close(&self, fd: Fd, callback: Option) + where + F: FnOnce() + Send + 'static, + { + let cb = callback.map(|callback| move |()| callback()); + unsafe { + self.raw_close(fd, cb.map(Into::into)); + } + } + + /// Sends an asynchronous `bind_stream` usercall. `callback` is called when + /// a return value is received from userspace. + /// + /// Please refer to the type-level documentation for general notes about + /// callbacks. + pub fn bind_stream(&self, addr: &str, callback: F) + where + F: FnOnce(io::Result) + Send + 'static, + { + let mut addr_buf = ManuallyDrop::new(MakeSend::new(User::<[u8]>::uninitialized(addr.len()))); + let mut local_addr_buf = ManuallyDrop::new(MakeSend::new(User::::uninitialized())); + + addr_buf[0..addr.len()].copy_from_enclave(addr.as_bytes()); + let addr_buf_ptr = addr_buf.as_raw_mut_ptr() as *mut u8; + let local_addr_ptr = local_addr_buf.as_raw_mut_ptr(); + + let cb = move |res: io::Result| { + let _addr_buf = ManuallyDrop::into_inner(addr_buf); + let local_addr_buf = ManuallyDrop::into_inner(local_addr_buf); + + let local_addr = Some(string_from_bytebuffer(&local_addr_buf, "bind_stream", "local_addr")); + let res = res.map(|fd| unsafe { TcpListener::from_raw_fd(fd, TcpListenerMetadata { local_addr }) }); + callback(res); + }; + unsafe { self.raw_bind_stream(addr_buf_ptr, addr.len(), local_addr_ptr, Some(cb.into())) } + } + + /// Sends an asynchronous `accept_stream` usercall. `callback` is called + /// when a return value is received from userspace. + /// + /// Returns a handle that can be used to cancel the usercall if desired. + /// Please refer to the type-level documentation for general notes about + /// callbacks. + pub fn accept_stream(&self, fd: Fd, callback: F) -> CancelHandle + where + F: FnOnce(io::Result) + Send + 'static, + { + let mut local_addr_buf = ManuallyDrop::new(MakeSend::new(User::::uninitialized())); + let mut peer_addr_buf = ManuallyDrop::new(MakeSend::new(User::::uninitialized())); + + let local_addr_ptr = local_addr_buf.as_raw_mut_ptr(); + let peer_addr_ptr = peer_addr_buf.as_raw_mut_ptr(); + + let cb = move |res: io::Result| { + let local_addr_buf = ManuallyDrop::into_inner(local_addr_buf); + let peer_addr_buf = ManuallyDrop::into_inner(peer_addr_buf); + + let local_addr = Some(string_from_bytebuffer(&*local_addr_buf, "accept_stream", "local_addr")); + let peer_addr = Some(string_from_bytebuffer(&*peer_addr_buf, "accept_stream", "peer_addr")); + let res = res.map(|fd| unsafe { TcpStream::from_raw_fd(fd, TcpStreamMetadata { local_addr, peer_addr }) }); + callback(res); + }; + unsafe { self.raw_accept_stream(fd, local_addr_ptr, peer_addr_ptr, Some(cb.into())) } + } + + /// Sends an asynchronous `connect_stream` usercall. `callback` is called + /// when a return value is received from userspace. + /// + /// Returns a handle that can be used to cancel the usercall if desired. + /// Please refer to the type-level documentation for general notes about + /// callbacks. + pub fn connect_stream(&self, addr: &str, callback: F) -> CancelHandle + where + F: FnOnce(io::Result) + Send + 'static, + { + let mut addr_buf = ManuallyDrop::new(MakeSend::new(User::<[u8]>::uninitialized(addr.len()))); + let mut local_addr_buf = ManuallyDrop::new(MakeSend::new(User::::uninitialized())); + let mut peer_addr_buf = ManuallyDrop::new(MakeSend::new(User::::uninitialized())); + + addr_buf[0..addr.len()].copy_from_enclave(addr.as_bytes()); + let addr_buf_ptr = addr_buf.as_raw_mut_ptr() as *mut u8; + let local_addr_ptr = local_addr_buf.as_raw_mut_ptr(); + let peer_addr_ptr = peer_addr_buf.as_raw_mut_ptr(); + + let cb = move |res: io::Result| { + let _addr_buf = ManuallyDrop::into_inner(addr_buf); + let local_addr_buf = ManuallyDrop::into_inner(local_addr_buf); + let peer_addr_buf = ManuallyDrop::into_inner(peer_addr_buf); + + let local_addr = Some(string_from_bytebuffer(&local_addr_buf, "connect_stream", "local_addr")); + let peer_addr = Some(string_from_bytebuffer(&peer_addr_buf, "connect_stream", "peer_addr")); + let res = res.map(|fd| unsafe { TcpStream::from_raw_fd(fd, TcpStreamMetadata { local_addr, peer_addr }) }); + callback(res); + }; + unsafe { self.raw_connect_stream(addr_buf_ptr, addr.len(), local_addr_ptr, peer_addr_ptr, Some(cb.into())) } + } + + /// Sends an asynchronous `alloc` usercall to allocate one instance of `T` + /// in userspace. `callback` is called when a return value is received from + /// userspace. + /// + /// Please refer to the type-level documentation for general notes about + /// callbacks. + pub fn alloc(&self, callback: F) + where + T: UserSafe, + F: FnOnce(io::Result>) + Send + 'static, + { + let cb = move |res: io::Result<*mut u8>| { + let res = res.map(|ptr| unsafe { User::::from_raw(ptr as _) }); + callback(res); + }; + unsafe { + self.raw_alloc(mem::size_of::(), T::align_of(), Some(cb.into())); + } + } + + /// Sends an asynchronous `alloc` usercall to allocate a slice of `T` in + /// userspace with the specified `len`. `callback` is called when a return + /// value is received from userspace. + /// + /// Please refer to the type-level documentation for general notes about + /// callbacks. + pub fn alloc_slice(&self, len: usize, callback: F) + where + [T]: UserSafe, + F: FnOnce(io::Result>) + Send + 'static, + { + let cb = move |res: io::Result<*mut u8>| { + let res = res.map(|ptr| unsafe { User::<[T]>::from_raw_parts(ptr as _, len) }); + callback(res); + }; + unsafe { + self.raw_alloc(len * mem::size_of::(), <[T]>::align_of(), Some(cb.into())); + } + } + + /// Sends an asynchronous `free` usercall to deallocate the userspace + /// buffer `buf`. If specified, `callback` is called when a return is + /// received from userspace. + /// + /// Please refer to the type-level documentation for general notes about + /// callbacks. + pub fn free(&self, mut buf: User, callback: Option) + where + T: ?Sized + UserSafe, + F: FnOnce() + Send + 'static, + { + let ptr = buf.as_raw_mut_ptr(); + let cb = callback.map(|callback| move |()| callback()); + unsafe { + self.raw_free( + buf.into_raw() as _, + mem::size_of_val(&mut *ptr), + T::align_of(), + cb.map(Into::into), + ); + } + } + + /// Sends an asynchronous `insecure_time` usercall. `callback` is called + /// when a return value is received from userspace. + /// + /// Please refer to the type-level documentation for general notes about + /// callbacks. + pub fn insecure_time(&self, callback: F) + where + F: FnOnce(SystemTime) + Send + 'static, + { + let cb = move |nanos_since_epoch| { + let t = UNIX_EPOCH + Duration::from_nanos(nanos_since_epoch); + callback(t); + }; + unsafe { + self.raw_insecure_time(Some(cb.into())); + } + } +} + +fn string_from_bytebuffer(buf: &UserRef, usercall: &str, arg: &str) -> String { + String::from_utf8(copy_user_buffer(buf)) + .unwrap_or_else(|_| panic!("Usercall {}: expected {} to be valid UTF-8", usercall, arg)) +} + +// adapted from libstd sys/sgx/abi/usercalls/alloc.rs +fn copy_user_buffer(buf: &UserRef) -> Vec { + unsafe { + let buf = buf.to_enclave(); + if buf.len > 0 { + let user = User::from_raw_parts(buf.data as _, buf.len); + let v = user.to_enclave(); + batch_drop(user); + v + } else { + // Mustn't look at `data` or call `free` if `len` is `0`. + Vec::new() + } + } +} \ No newline at end of file diff --git a/intel-sgx/async-usercalls/src/provider_core.rs b/intel-sgx/async-usercalls/src/provider_core.rs new file mode 100644 index 00000000..55fc7d34 --- /dev/null +++ b/intel-sgx/async-usercalls/src/provider_core.rs @@ -0,0 +1,76 @@ +use crate::queues::*; +use crate::CancelHandle; +use crossbeam_channel as mpmc; +use ipc_queue::Identified; +use std::os::fortanix_sgx::usercalls::raw::{Cancel, Return, Usercall}; +use std::sync::atomic::{AtomicU32, Ordering}; + +pub(crate) struct ProviderCore { + provider_id: u32, + next_id: AtomicU32, +} + +impl ProviderCore { + pub fn new(return_tx: Option>>) -> ProviderCore { + let provider_id = PROVIDERS.new_provider(return_tx); + ProviderCore { + provider_id, + next_id: AtomicU32::new(1), + } + } + + #[cfg(test)] + pub fn provider_id(&self) -> u32 { + self.provider_id + } + + fn next_id(&self) -> u32 { + let id = self.next_id.fetch_add(1, Ordering::Relaxed); + match id { + 0 => self.next_id(), + _ => id, + } + } + + pub fn assign_id(&self, usercall: Usercall) -> Identified { + let id = self.next_id(); + Identified { + id: ((self.provider_id as u64) << 32) | id as u64, + data: usercall, + } + } + + pub fn send_usercall(&self, usercall: Identified) -> CancelHandle { + assert!(usercall.id != 0); + let cancel = Identified { + id: usercall.id, + data: Cancel, + }; + PROVIDERS + .usercall_sender() + .send(usercall) + .expect("failed to send async usercall"); + CancelHandle::new(cancel) + } + + // returns the number of usercalls successfully sent. + pub fn try_send_multiple_usercalls(&self, usercalls: &[Identified]) -> usize { + PROVIDERS.usercall_sender().try_send_multiple(usercalls).unwrap_or(0) + } +} + +impl Drop for ProviderCore { + fn drop(&mut self) { + PROVIDERS.remove_provider(self.provider_id); + } +} + +pub trait ProviderId { + fn provider_id(&self) -> u32; +} + +impl ProviderId for Identified { + fn provider_id(&self) -> u32 { + (self.id >> 32) as u32 + } +} \ No newline at end of file diff --git a/intel-sgx/async-usercalls/src/queues.rs b/intel-sgx/async-usercalls/src/queues.rs new file mode 100644 index 00000000..d5a58f68 --- /dev/null +++ b/intel-sgx/async-usercalls/src/queues.rs @@ -0,0 +1,192 @@ +use crate::provider_core::ProviderId; +use crossbeam_channel as mpmc; +use fortanix_sgx_abi::{EV_CANCELQ_NOT_FULL, EV_RETURNQ_NOT_EMPTY, EV_USERCALLQ_NOT_FULL}; +use ipc_queue::{self, Identified, QueueEvent, RecvError, SynchronizationError, Synchronizer}; +use lazy_static::lazy_static; +use std::os::fortanix_sgx::usercalls::alloc::User; +use std::os::fortanix_sgx::usercalls::raw::{ + self, async_queues, Cancel, FifoDescriptor, Return, Usercall, +}; +use std::sync::{Arc, Mutex}; +use std::{io, iter, thread}; + +pub(crate) type Sender = ipc_queue::Sender; +pub(crate) type Receiver = ipc_queue::Receiver; + +pub(crate) struct Providers { + usercall_queue_tx: Sender, + cancel_queue_tx: Sender, + provider_map: Arc>>>>>, +} + +impl Providers { + pub(crate) fn new_provider(&self, return_tx: Option>>) -> u32 { + self.provider_map.lock().unwrap().insert(return_tx) + } + + pub(crate) fn remove_provider(&self, id: u32) { + let entry = self.provider_map.lock().unwrap().remove(id); + assert!(entry.is_some()); + } + + pub(crate) fn usercall_sender(&self) -> &Sender { + &self.usercall_queue_tx + } + + pub(crate) fn cancel_sender(&self) -> &Sender { + &self.cancel_queue_tx + } +} + +lazy_static! { + pub(crate) static ref PROVIDERS: Providers = { + let (utx, ctx, rx) = init_async_queues().expect("Failed to initialize async queues"); + let provider_map = Arc::new(Mutex::new(Map::new())); + let return_handler = ReturnHandler { + return_queue_rx: rx, + provider_map: Arc::clone(&provider_map), + }; + thread::spawn(move || return_handler.run()); + Providers { + usercall_queue_tx: utx, + cancel_queue_tx: ctx, + provider_map, + } + }; +} + +fn init_async_queues() -> io::Result<(Sender, Sender, Receiver)> { + let usercall_q = User::>::uninitialized().into_raw(); + let cancel_q = User::>::uninitialized().into_raw(); + let return_q = User::>::uninitialized().into_raw(); + + let r = unsafe { async_queues(usercall_q, return_q, cancel_q) }; + if r != 0 { + return Err(io::Error::from_raw_os_error(r)); + } + + let usercall_queue = unsafe { User::>::from_raw(usercall_q) }.to_enclave(); + let cancel_queue = unsafe { User::>::from_raw(cancel_q) }.to_enclave(); + let return_queue = unsafe { User::>::from_raw(return_q) }.to_enclave(); + + // FIXME: once `WithId` is exported from `std::os::fortanix_sgx::usercalls::raw`, we can remove + // `transmute` calls here and use FifoDescriptor/WithId from std everywhere including in ipc-queue. + let utx = unsafe { Sender::from_descriptor(std::mem::transmute(usercall_queue), QueueSynchronizer { queue: Queue::Usercall }) }; + let ctx = unsafe { Sender::from_descriptor(std::mem::transmute(cancel_queue), QueueSynchronizer { queue: Queue::Cancel }) }; + let rx = unsafe { Receiver::from_descriptor(std::mem::transmute(return_queue), QueueSynchronizer { queue: Queue::Return }) }; + Ok((utx, ctx, rx)) +} + +struct ReturnHandler { + return_queue_rx: Receiver, + provider_map: Arc>>>>>, +} + +impl ReturnHandler { + const RECV_BATCH_SIZE: usize = 1024; + + fn send(&self, returns: &[Identified]) { + // This should hold the lock only for a short amount of time + // since mpmc::Sender::send() will not block (unbounded channel). + // Also note that the lock is uncontested most of the time, so + // taking the lock should be fast. + let provider_map = self.provider_map.lock().unwrap(); + for ret in returns { + // NOTE: some providers might decide not to receive results of usercalls they send + // because the results are not interesting, e.g. BatchDropProvider. + if let Some(sender) = provider_map.get(ret.provider_id()).and_then(|entry| entry.as_ref()) { + let _ = sender.send(*ret); + } + } + } + + fn run(self) { + let mut returns = [Identified::default(); Self::RECV_BATCH_SIZE]; + loop { + // Block until there is a return. Then we receive any other values + // from the return queue **without** blocking using `try_iter()`. + let first = match self.return_queue_rx.recv() { + Ok(ret) => ret, + Err(RecvError::Closed) => break, + }; + let mut count = 0; + for ret in iter::once(first).chain(self.return_queue_rx.try_iter().take(Self::RECV_BATCH_SIZE - 1)) { + assert!(ret.id != 0); + returns[count] = ret; + count += 1; + } + self.send(&returns[..count]); + } + } +} + +#[derive(Clone, Copy, Debug)] +enum Queue { + Usercall, + Return, + Cancel, +} + +#[derive(Clone, Debug)] +pub(crate) struct QueueSynchronizer { + queue: Queue, +} + +impl Synchronizer for QueueSynchronizer { + fn wait(&self, event: QueueEvent) -> Result<(), SynchronizationError> { + let ev = match (self.queue, event) { + (Queue::Usercall, QueueEvent::NotEmpty) => panic!("enclave should not recv on usercall queue"), + (Queue::Cancel, QueueEvent::NotEmpty) => panic!("enclave should not recv on cancel queue"), + (Queue::Return, QueueEvent::NotFull) => panic!("enclave should not send on return queue"), + (Queue::Usercall, QueueEvent::NotFull) => EV_USERCALLQ_NOT_FULL, + (Queue::Cancel, QueueEvent::NotFull) => EV_CANCELQ_NOT_FULL, + (Queue::Return, QueueEvent::NotEmpty) => EV_RETURNQ_NOT_EMPTY, + }; + unsafe { + raw::wait(ev, raw::WAIT_INDEFINITE); + } + Ok(()) + } + + fn notify(&self, _event: QueueEvent) { + // any synchronous usercall would do + unsafe { + raw::wait(0, raw::WAIT_NO); + } + } +} + +use self::map::Map; +mod map { + use fnv::FnvHashMap; + + pub struct Map { + map: FnvHashMap, + next_id: u32, + } + + impl Map { + pub fn new() -> Self { + Self { + map: FnvHashMap::with_capacity_and_hasher(16, Default::default()), + next_id: 0, + } + } + + pub fn insert(&mut self, value: T) -> u32 { + let id = self.next_id; + self.next_id += 1; + let old = self.map.insert(id, value); + debug_assert!(old.is_none()); + id + } + + pub fn get(&self, id: u32) -> Option<&T> { + self.map.get(&id) + } + + pub fn remove(&mut self, id: u32) -> Option { + self.map.remove(&id) + } + } +} \ No newline at end of file diff --git a/intel-sgx/async-usercalls/src/raw.rs b/intel-sgx/async-usercalls/src/raw.rs new file mode 100644 index 00000000..189a1c0c --- /dev/null +++ b/intel-sgx/async-usercalls/src/raw.rs @@ -0,0 +1,243 @@ +use crate::callback::*; +use crate::{AsyncUsercallProvider, CancelHandle}; +use fortanix_sgx_abi::Fd; +use std::io; +use std::os::fortanix_sgx::usercalls::raw::ByteBuffer; +use std::os::fortanix_sgx::usercalls::raw::{Usercall, UsercallNrs}; + +pub trait RawApi { + unsafe fn raw_read( + &self, + fd: Fd, + buf: *mut u8, + len: usize, + callback: Option>>, + ) -> CancelHandle; + + unsafe fn raw_write( + &self, + fd: Fd, + buf: *const u8, + len: usize, + callback: Option>>, + ) -> CancelHandle; + + unsafe fn raw_flush(&self, fd: Fd, callback: Option>>); + + unsafe fn raw_close(&self, fd: Fd, callback: Option>); + + unsafe fn raw_bind_stream( + &self, + addr: *const u8, + len: usize, + local_addr: *mut ByteBuffer, + callback: Option>>, + ); + + unsafe fn raw_accept_stream( + &self, + fd: Fd, + local_addr: *mut ByteBuffer, + peer_addr: *mut ByteBuffer, + callback: Option>>, + ) -> CancelHandle; + + unsafe fn raw_connect_stream( + &self, + addr: *const u8, + len: usize, + local_addr: *mut ByteBuffer, + peer_addr: *mut ByteBuffer, + callback: Option>>, + ) -> CancelHandle; + + unsafe fn raw_insecure_time(&self, callback: Option>); + + unsafe fn raw_alloc(&self, size: usize, alignment: usize, callback: Option>>); + + unsafe fn raw_free(&self, ptr: *mut u8, size: usize, alignment: usize, callback: Option>); +} + +impl RawApi for AsyncUsercallProvider { + unsafe fn raw_read( + &self, + fd: Fd, + buf: *mut u8, + len: usize, + callback: Option>>, + ) -> CancelHandle { + let u = Usercall(UsercallNrs::read as _, fd as _, buf as _, len as _, 0); + self.send_usercall(u, callback.map(|cb| Callback::read(cb))) + } + + unsafe fn raw_write( + &self, + fd: Fd, + buf: *const u8, + len: usize, + callback: Option>>, + ) -> CancelHandle { + let u = Usercall(UsercallNrs::write as _, fd as _, buf as _, len as _, 0); + self.send_usercall(u, callback.map(|cb| Callback::write(cb))) + } + + unsafe fn raw_flush(&self, fd: Fd, callback: Option>>) { + let u = Usercall(UsercallNrs::flush as _, fd as _, 0, 0, 0); + self.send_usercall(u, callback.map(|cb| Callback::flush(cb))); + } + + unsafe fn raw_close(&self, fd: Fd, callback: Option>) { + let u = Usercall(UsercallNrs::close as _, fd as _, 0, 0, 0); + self.send_usercall(u, callback.map(|cb| Callback::close(cb))); + } + + unsafe fn raw_bind_stream( + &self, + addr: *const u8, + len: usize, + local_addr: *mut ByteBuffer, + callback: Option>>, + ) { + let u = Usercall(UsercallNrs::bind_stream as _, addr as _, len as _, local_addr as _, 0); + self.send_usercall(u, callback.map(|cb| Callback::bind_stream(cb))); + } + + unsafe fn raw_accept_stream( + &self, + fd: Fd, + local_addr: *mut ByteBuffer, + peer_addr: *mut ByteBuffer, + callback: Option>>, + ) -> CancelHandle { + let u = Usercall( + UsercallNrs::accept_stream as _, + fd as _, + local_addr as _, + peer_addr as _, + 0, + ); + self.send_usercall(u, callback.map(|cb| Callback::accept_stream(cb))) + } + + unsafe fn raw_connect_stream( + &self, + addr: *const u8, + len: usize, + local_addr: *mut ByteBuffer, + peer_addr: *mut ByteBuffer, + callback: Option>>, + ) -> CancelHandle { + let u = Usercall( + UsercallNrs::connect_stream as _, + addr as _, + len as _, + local_addr as _, + peer_addr as _, + ); + self.send_usercall(u, callback.map(|cb| Callback::connect_stream(cb))) + } + + unsafe fn raw_insecure_time(&self, callback: Option>) { + let u = Usercall(UsercallNrs::insecure_time as _, 0, 0, 0, 0); + self.send_usercall(u, callback.map(|cb| Callback::insecure_time(cb))); + } + + unsafe fn raw_alloc(&self, size: usize, alignment: usize, callback: Option>>) { + let u = Usercall(UsercallNrs::alloc as _, size as _, alignment as _, 0, 0); + self.send_usercall(u, callback.map(|cb| Callback::alloc(cb))); + } + + unsafe fn raw_free(&self, ptr: *mut u8, size: usize, alignment: usize, callback: Option>) { + let u = Usercall(UsercallNrs::free as _, ptr as _, size as _, alignment as _, 0); + self.send_usercall(u, callback.map(|cb| Callback::free(cb))); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_support::*; + use crossbeam_channel as mpmc; + use std::io; + use std::sync::atomic::{AtomicPtr, Ordering}; + use std::sync::Arc; + use std::thread; + use std::time::{Duration, UNIX_EPOCH}; + + #[test] + fn get_time_async_raw() { + fn run(tid: u32, provider: AutoPollingProvider) -> (u32, u32, Duration) { + let pid = provider.provider_id(); + const N: usize = 500; + let (tx, rx) = mpmc::bounded(N); + for _ in 0..N { + let tx = tx.clone(); + let cb = move |d| { + let system_time = UNIX_EPOCH + Duration::from_nanos(d); + tx.send(system_time).unwrap(); + }; + unsafe { + provider.raw_insecure_time(Some(cb.into())); + } + } + let mut all = Vec::with_capacity(N); + for _ in 0..N { + all.push(rx.recv().unwrap()); + } + + assert_eq!(all.len(), N); + // The results are returned in arbitrary order + all.sort(); + let t0 = *all.first().unwrap(); + let tn = *all.last().unwrap(); + let total = tn.duration_since(t0).unwrap(); + (tid, pid, total / N as u32) + } + + println!(); + const THREADS: usize = 4; + let mut providers = Vec::with_capacity(THREADS); + for _ in 0..THREADS { + providers.push(AutoPollingProvider::new()); + } + let mut handles = Vec::with_capacity(THREADS); + for (i, provider) in providers.into_iter().enumerate() { + handles.push(thread::spawn(move || run(i as u32, provider))); + } + for h in handles { + let res = h.join().unwrap(); + println!("[{}/{}] (Tn - T0) / N = {:?}", res.0, res.1, res.2); + } + } + + #[test] + fn raw_alloc_free() { + let provider = AutoPollingProvider::new(); + let ptr: Arc> = Arc::new(AtomicPtr::new(0 as _)); + let ptr2 = Arc::clone(&ptr); + const SIZE: usize = 1024; + const ALIGN: usize = 8; + + let (tx, rx) = mpmc::bounded(1); + let cb_alloc = move |p: io::Result<*mut u8>| { + let p = p.unwrap(); + ptr2.store(p, Ordering::Relaxed); + tx.send(()).unwrap(); + }; + unsafe { + provider.raw_alloc(SIZE, ALIGN, Some(cb_alloc.into())); + } + rx.recv().unwrap(); + let p = ptr.load(Ordering::Relaxed); + assert!(!p.is_null()); + + let (tx, rx) = mpmc::bounded(1); + let cb_free = move |()| { + tx.send(()).unwrap(); + }; + unsafe { + provider.raw_free(p, SIZE, ALIGN, Some(cb_free.into())); + } + rx.recv().unwrap(); + } +} \ No newline at end of file diff --git a/intel-sgx/async-usercalls/src/test_support.rs b/intel-sgx/async-usercalls/src/test_support.rs new file mode 100644 index 00000000..b73f50e6 --- /dev/null +++ b/intel-sgx/async-usercalls/src/test_support.rs @@ -0,0 +1,47 @@ +use crate::AsyncUsercallProvider; +use std::ops::Deref; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::thread; + +pub(crate) struct AutoPollingProvider { + provider: AsyncUsercallProvider, + shutdown: Arc, + join_handle: Option>, +} + +impl AutoPollingProvider { + pub fn new() -> Self { + let (provider, handler) = AsyncUsercallProvider::new(); + let shutdown = Arc::new(AtomicBool::new(false)); + let shutdown1 = shutdown.clone(); + let join_handle = Some(thread::spawn(move || loop { + handler.poll(None); + if shutdown1.load(Ordering::Relaxed) { + break; + } + })); + Self { + provider, + shutdown, + join_handle, + } + } +} + +impl Deref for AutoPollingProvider { + type Target = AsyncUsercallProvider; + + fn deref(&self) -> &Self::Target { + &self.provider + } +} + +impl Drop for AutoPollingProvider { + fn drop(&mut self) { + self.shutdown.store(true, Ordering::Relaxed); + // send a usercall to ensure thread wakes up + self.provider.insecure_time(|_| {}); + self.join_handle.take().unwrap().join().unwrap(); + } +} \ No newline at end of file diff --git a/intel-sgx/async-usercalls/src/utils.rs b/intel-sgx/async-usercalls/src/utils.rs new file mode 100644 index 00000000..78f3c051 --- /dev/null +++ b/intel-sgx/async-usercalls/src/utils.rs @@ -0,0 +1,38 @@ +use std::ops::{Deref, DerefMut}; +use std::os::fortanix_sgx::usercalls::alloc::User; +use std::os::fortanix_sgx::usercalls::raw::ByteBuffer; + +pub(crate) trait MakeSendMarker {} + +pub(crate) struct MakeSend(T); + +impl MakeSend { + pub fn new(t: T) -> Self { + Self(t) + } + + #[allow(unused)] + pub fn into_inner(self) -> T { + self.0 + } +} + +impl Deref for MakeSend { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for MakeSend { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +unsafe impl Send for MakeSend {} + +impl MakeSendMarker for ByteBuffer {} +impl MakeSendMarker for User {} +impl MakeSendMarker for User<[u8]> {} diff --git a/intel-sgx/async-usercalls/test.sh b/intel-sgx/async-usercalls/test.sh new file mode 100644 index 00000000..cdb85673 --- /dev/null +++ b/intel-sgx/async-usercalls/test.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +# Run this in parallel with: +# $ cargo test --target x86_64-fortanix-unknown-sgx --release -- --nocapture --ignored echo + +for i in $(seq 1 100); do + echo $i + telnet localhost 7799 < /dev/zero &> /dev/null & + sleep 0.01 +done + +sleep 10s +kill $(jobs -p) +wait diff --git a/intel-sgx/enclave-runner/Cargo.toml b/intel-sgx/enclave-runner/Cargo.toml index a49eb1b3..9b766772 100644 --- a/intel-sgx/enclave-runner/Cargo.toml +++ b/intel-sgx/enclave-runner/Cargo.toml @@ -21,7 +21,7 @@ exclude = ["fake-vdso/.gitignore", "fake-vdso/Makefile", "fake-vdso/main.S"] [dependencies] # Project dependencies sgxs = { version = "0.7.2", path = "../sgxs" } -fortanix-sgx-abi = { version = "0.4.0" } # TODO: add back `path = "../fortanix-sgx-abi"` +fortanix-sgx-abi = { version = "0.5.0", path = "../fortanix-sgx-abi" } sgx-isa = { version = "0.4.0", path = "../sgx-isa" } ipc-queue = { version = "0.2.0", path = "../../ipc-queue" } @@ -33,7 +33,7 @@ lazy_static = "1.2.0" # MIT/Apache-2.0 libc = "0.2.48" # MIT/Apache-2.0 nix = "0.13.0" # MIT openssl = { version = "0.10", optional = true } # Apache-2.0 -crossbeam = "0.7.1" # MIT/Apache-2.0 +crossbeam = "0.8.2" # MIT/Apache-2.0 num_cpus = "1.10.0" # MIT/Apache-2.0 tokio = { version = "0.2", features = ["full"] } # MIT futures = { version = "0.3", features = ["compat", "io-compat"] } # MIT/Apache-2.0 diff --git a/intel-sgx/enclave-runner/src/usercalls/abi.rs b/intel-sgx/enclave-runner/src/usercalls/abi.rs index 8345aa85..0fcc0c88 100644 --- a/intel-sgx/enclave-runner/src/usercalls/abi.rs +++ b/intel-sgx/enclave-runner/src/usercalls/abi.rs @@ -19,7 +19,7 @@ use futures::future::Future; type Register = u64; -trait RegisterArgument { +pub(super) trait RegisterArgument { fn from_register(_: Register) -> Self; fn into_register(self) -> Register; } @@ -29,7 +29,7 @@ type EnclaveAbort = super::EnclaveAbort; pub(crate) type UsercallResult = ::std::result::Result; pub(crate) type DispatchResult = UsercallResult<(Register, Register)>; -trait ReturnValue { +pub(super) trait ReturnValue { fn into_registers(self) -> DispatchResult; } diff --git a/intel-sgx/enclave-runner/src/usercalls/interface.rs b/intel-sgx/enclave-runner/src/usercalls/interface.rs index c5ec9ca1..c8731bcc 100644 --- a/intel-sgx/enclave-runner/src/usercalls/interface.rs +++ b/intel-sgx/enclave-runner/src/usercalls/interface.rs @@ -252,12 +252,13 @@ impl<'future, 'ioinput: 'future, 'tcs: 'ioinput> Usercalls<'future> for Handler< self, usercall_queue: *mut FifoDescriptor, return_queue: *mut FifoDescriptor, + cancel_queue: *mut FifoDescriptor, ) -> std::pin::Pin)> + 'future>> { async move { unsafe { let ret = match (usercall_queue.as_mut(), return_queue.as_mut()) { (Some(usercall_queue), Some(return_queue)) => { - self.0.async_queues(usercall_queue, return_queue).await.map(Ok) + self.0.async_queues(usercall_queue, return_queue, cancel_queue.as_mut()).await.map(Ok) }, _ => { Ok(Err(IoErrorKind::InvalidInput.into())) @@ -321,13 +322,13 @@ fn result_from_io_error(err: IoError) -> Result { ret as _ } -trait ToSgxResult { +pub(super) trait ToSgxResult { type Return; fn to_sgx_result(self) -> Self::Return; } -trait SgxReturn { +pub(super) trait SgxReturn { fn on_error() -> Self; } diff --git a/intel-sgx/enclave-runner/src/usercalls/mod.rs b/intel-sgx/enclave-runner/src/usercalls/mod.rs index 51dd389d..3624e891 100644 --- a/intel-sgx/enclave-runner/src/usercalls/mod.rs +++ b/intel-sgx/enclave-runner/src/usercalls/mod.rs @@ -6,7 +6,7 @@ use std::alloc::{GlobalAlloc, Layout, System}; use std::cell::RefCell; -use std::collections::VecDeque; +use std::collections::{HashMap, VecDeque}; use std::io::{self, ErrorKind as IoErrorKind, Read, Result as IoResult}; use std::pin::Pin; use std::result::Result as StdResult; @@ -29,12 +29,9 @@ use libc::*; use nix::sys::signal; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::stream::Stream as TokioStream; -use tokio::sync::broadcast; -use tokio::sync::mpsc as async_mpsc; -use tokio::sync::Semaphore; - +use tokio::sync::{broadcast, mpsc as async_mpsc, oneshot, Semaphore}; use fortanix_sgx_abi::*; -use ipc_queue::{self, DescriptorGuard, Identified, QueueEvent}; +use ipc_queue::{self, DescriptorGuard, Identified, QueueEvent, WritePosition}; use sgxs::loader::Tcs as SgxsTcs; use crate::loader::{EnclavePanic, ErasedTcs}; @@ -49,20 +46,25 @@ lazy_static! { static ref DEBUGGER_TOGGLE_SYNC: Mutex<()> = Mutex::new(()); } -const EV_ABORT: u64 = 0b0000_0000_0000_1000; +// This is not an event in the sense that it could be passed to `send()` or +// `wait()` usercalls in enclave code. However, it's easier for the enclave +// runner implementation to lump it in with events. Also note that this constant +// is not public. +const EV_ABORT: u64 = 0b0000_0000_0001_0000; const USERCALL_QUEUE_SIZE: usize = 16; const RETURN_QUEUE_SIZE: usize = 1024; +const CANCEL_QUEUE_SIZE: usize = USERCALL_QUEUE_SIZE * 2; enum UsercallSendData { Sync(ThreadResult, RunningTcs, RefCell<[u8; 1024]>), - Async(Identified), + Async(Identified, Option>), } // This is the same as UsercallSendData except that it can't be Sync(CoResult::Return(...), ...) enum UsercallHandleData { Sync(tcs::Usercall, RunningTcs, RefCell<[u8; 1024]>), - Async(Identified), + Async(Identified, Option>, Option>), } type EnclaveResult = StdResult<(u64, u64), EnclaveAbort>>; @@ -502,7 +504,7 @@ struct StoppedTcs { struct IOHandlerInput<'tcs> { tcs: Option<&'tcs mut RunningTcs>, enclave: Arc, - work_sender: &'tcs crossbeam::crossbeam_channel::Sender, + work_sender: &'tcs crossbeam::channel::Sender, } struct PendingEvents { @@ -515,7 +517,7 @@ struct PendingEvents { impl PendingEvents { // Will error if it doesn't fit in a `u64` - const EV_MAX_U64: u64 = (EV_USERCALLQ_NOT_FULL | EV_RETURNQ_NOT_EMPTY | EV_UNPARK) + 1; + const EV_MAX_U64: u64 = (EV_USERCALLQ_NOT_FULL | EV_RETURNQ_NOT_EMPTY | EV_UNPARK | EV_CANCELQ_NOT_FULL) + 1; const EV_MAX: usize = Self::EV_MAX_U64 as _; // Will error if it doesn't fit in a `usize` const _ERROR_IF_USIZE_TOO_SMALL: u64 = u64::MAX + (Self::EV_MAX_U64 - (Self::EV_MAX as u64)); @@ -528,6 +530,8 @@ impl PendingEvents { counts: [ Semaphore::new(0), Semaphore::new(0), Semaphore::new(0), Semaphore::new(0), Semaphore::new(0), Semaphore::new(0), Semaphore::new(0), Semaphore::new(0), + Semaphore::new(0), Semaphore::new(0), Semaphore::new(0), Semaphore::new(0), + Semaphore::new(0), Semaphore::new(0), Semaphore::new(0), Semaphore::new(0), ], abort: Semaphore::new(0), } @@ -639,6 +643,7 @@ impl EnclaveKind { struct FifoGuards { usercall_queue: DescriptorGuard, return_queue: DescriptorGuard, + cancel_queue: DescriptorGuard, async_queues_called: bool, } @@ -684,6 +689,27 @@ impl Work { } } +enum UsercallEvent { + Started(u64, oneshot::Sender<()>), + Finished(u64), + Cancelled(u64, WritePosition), +} + +trait IgnoreCancel { + fn ignore_cancel(&self) -> bool; +} + +impl IgnoreCancel for Identified { + fn ignore_cancel(&self) -> bool { + self.data.0 != UsercallList::read as u64 && + self.data.0 != UsercallList::read_alloc as u64 && + self.data.0 != UsercallList::write as u64 && + self.data.0 != UsercallList::accept_stream as u64 && + self.data.0 != UsercallList::connect_stream as u64 && + self.data.0 != UsercallList::wait as u64 + } +} + impl EnclaveState { fn event_queue_add_tcs( event_queues: &mut FnvHashMap, @@ -750,19 +776,40 @@ impl EnclaveState { async fn handle_usercall( enclave: Arc, - work_sender: crossbeam::crossbeam_channel::Sender, + work_sender: crossbeam::channel::Sender, tx_return_channel: tokio::sync::mpsc::UnboundedSender<(EnclaveResult, ReturnSource)>, mut handle_data: UsercallHandleData, ) { + let notifier_rx = match handle_data { + UsercallHandleData::Async(_, ref mut notifier_rx, _) => notifier_rx.take(), + _ => None, + }; let (parameters, mode, tcs) = match handle_data { UsercallHandleData::Sync(ref usercall, ref mut tcs, _) => (usercall.parameters(), tcs.mode.into(), Some(tcs)), - UsercallHandleData::Async(ref usercall) => (usercall.data.into(), ReturnSource::AsyncUsercall, None), + UsercallHandleData::Async(ref usercall, _, _) => (usercall.data.into(), ReturnSource::AsyncUsercall, None), }; let mut input = IOHandlerInput { enclave: enclave.clone(), tcs, work_sender: &work_sender }; let handler = Handler(&mut input); - let (_handler, result) = { + let result = { + use self::interface::ToSgxResult; + use self::abi::ReturnValue; + let (p1, p2, p3, p4, p5) = parameters; - dispatch(handler, p1, p2, p3, p4, p5).await + match notifier_rx { + None => dispatch(handler, p1, p2, p3, p4, p5).await.1, + Some(notifier_rx) => { + let a = dispatch(handler, p1, p2, p3, p4, p5).boxed_local(); + let b = notifier_rx; + match futures::future::select(a, b).await { + Either::Left((ret, _)) => ret.1, + Either::Right((Ok(()), _)) => { + let result: IoResult = Err(IoErrorKind::Interrupted.into()); + ReturnValue::into_registers(Ok(result.to_sgx_result())) + }, + Either::Right((Err(_), _)) => panic!("notifier channel closed unexpectedly"), + } + }, + } }; let ret = match result { Ok(ret) => { @@ -773,7 +820,11 @@ impl EnclaveState { entry: CoEntry::Resume(usercall, ret), }).expect("Work sender couldn't send data to receiver"); } - UsercallHandleData::Async(usercall) => { + UsercallHandleData::Async(usercall, _, usercall_event_tx) => { + if let Some(usercall_event_tx) = usercall_event_tx { + usercall_event_tx.send(UsercallEvent::Finished(usercall.id)).ok() + .expect("failed to send usercall event"); + } let return_queue_tx = enclave.return_queue_tx.lock().await.clone().expect("return_queue_tx not initialized"); let ret = Identified { id: usercall.id, @@ -794,7 +845,7 @@ impl EnclaveState { } EnclavePanic::from(debug_buf) } - UsercallHandleData::Async(_) => { + UsercallHandleData::Async(_, _, _) => { // TODO: https://github.com/fortanix/rust-sgx/issues/235#issuecomment-641811437 EnclavePanic::DebugStr("async exit with a panic".to_owned()) } @@ -817,7 +868,7 @@ impl EnclaveState { enclave: Arc, io_queue_receive: tokio::sync::mpsc::UnboundedReceiver, io_queue_send: tokio::sync::mpsc::UnboundedSender, - work_sender: crossbeam::crossbeam_channel::Sender, + work_sender: crossbeam::channel::Sender, ) -> EnclaveResult { let (tx_return_channel, mut rx_return_channel) = tokio::sync::mpsc::unbounded_channel(); let enclave_clone = enclave.clone(); @@ -872,23 +923,66 @@ impl EnclaveState { }; let enclave_clone = enclave.clone(); let io_future = async move { - let (usercall_queue_synchronizer, return_queue_synchronizer, sync_usercall_tx) = QueueSynchronizer::new(enclave_clone.clone()); + let (uqs, rqs, cqs, sync_usercall_tx) = QueueSynchronizer::new(enclave_clone.clone()); - let (usercall_queue_tx, usercall_queue_rx) = ipc_queue::bounded_async(USERCALL_QUEUE_SIZE, usercall_queue_synchronizer); - let (return_queue_tx, return_queue_rx) = ipc_queue::bounded_async(RETURN_QUEUE_SIZE, return_queue_synchronizer); + let (usercall_queue_tx, usercall_queue_rx) = ipc_queue::bounded_async(USERCALL_QUEUE_SIZE, uqs); + let (return_queue_tx, return_queue_rx) = ipc_queue::bounded_async(RETURN_QUEUE_SIZE, rqs); + let (cancel_queue_tx, cancel_queue_rx) = ipc_queue::bounded_async(CANCEL_QUEUE_SIZE, cqs); let fifo_guards = FifoGuards { usercall_queue: usercall_queue_tx.into_descriptor_guard(), return_queue: return_queue_rx.into_descriptor_guard(), + cancel_queue: cancel_queue_tx.into_descriptor_guard(), async_queues_called: false, }; *enclave_clone.fifo_guards.lock().await = Some(fifo_guards); *enclave_clone.return_queue_tx.lock().await = Some(return_queue_tx); + let usercall_queue_monitor = usercall_queue_rx.position_monitor(); + + let (usercall_event_tx, mut usercall_event_rx) = async_mpsc::unbounded_channel(); + let usercall_event_tx_clone = usercall_event_tx.clone(); tokio::task::spawn_local(async move { while let Ok(usercall) = usercall_queue_rx.recv().await { - let _ = io_queue_send.send(UsercallSendData::Async(usercall)); + let notifier_rx = if usercall.ignore_cancel() { + None + } else { + let (notifier_tx, notifier_rx) = oneshot::channel(); + usercall_event_tx_clone.send(UsercallEvent::Started(usercall.id, notifier_tx)).ok().expect("failed to send usercall event"); + Some(notifier_rx) + }; + let _ = io_queue_send.send(UsercallSendData::Async(usercall, notifier_rx)); + } + }); + + let usercall_event_tx_clone = usercall_event_tx.clone(); + let usercall_queue_monitor_clone = usercall_queue_monitor.clone(); + tokio::task::spawn_local(async move { + while let Ok(c) = cancel_queue_rx.recv().await { + let write_position = usercall_queue_monitor_clone.write_position(); + let _ = usercall_event_tx_clone.send(UsercallEvent::Cancelled(c.id, write_position)); + } + }); + + tokio::task::spawn_local(async move { + let mut notifiers = HashMap::new(); + let mut cancels: HashMap = HashMap::new(); + loop { + match usercall_event_rx.recv().await.expect("usercall_event channel closed unexpectedly") { + UsercallEvent::Started(id, notifier) => match cancels.remove(&id) { + Some(_) => { let _ = notifier.send(()); }, + _ => { notifiers.insert(id, notifier); }, + }, + UsercallEvent::Finished(id) => { notifiers.remove(&id); }, + UsercallEvent::Cancelled(id, wp) => match notifiers.remove(&id) { + Some(notifier) => { let _ = notifier.send(()); }, + None => { cancels.insert(id, wp); }, + }, + } + // cleanup old cancels + let read_position = usercall_queue_monitor.read_position(); + cancels.retain(|_id, wp| !read_position.is_past(wp)); } }); @@ -898,8 +992,9 @@ impl EnclaveState { let enclave_clone = enclave_clone.clone(); let tx_return_channel = tx_return_channel.clone(); match work { - UsercallSendData::Async(usercall) => { - let uchd = UsercallHandleData::Async(usercall); + UsercallSendData::Async(usercall, notifier_rx) => { + let usercall_event_tx = if usercall.ignore_cancel() { None } else { Some(usercall_event_tx.clone()) }; + let uchd = UsercallHandleData::Async(usercall, notifier_rx, usercall_event_tx); let fut = Self::handle_usercall(enclave_clone, work_sender.clone(), tx_return_channel, uchd); tokio::task::spawn_local(fut); } @@ -962,7 +1057,7 @@ impl EnclaveState { ) -> EnclaveResult { fn create_worker_threads( num_of_worker_threads: usize, - work_receiver: crossbeam::crossbeam_channel::Receiver, + work_receiver: crossbeam::channel::Receiver, io_queue_send: tokio::sync::mpsc::UnboundedSender, ) -> Vec> { let mut thread_handles = vec![]; @@ -981,7 +1076,7 @@ impl EnclaveState { let (io_queue_send, io_queue_receive) = tokio::sync::mpsc::unbounded_channel(); - let (work_sender, work_receiver) = crossbeam::crossbeam_channel::unbounded(); + let (work_sender, work_receiver) = crossbeam::channel::unbounded(); work_sender .send(start_work) .expect("Work sender couldn't send data to receiver"); @@ -1055,7 +1150,7 @@ impl EnclaveState { rt.block_on(async move { enclave.abort_all_threads(); //clear the threads_queue - while enclave.threads_queue.pop().is_ok() {} + while enclave.threads_queue.pop().is_some() {} let cmd = enclave.kind.as_command().unwrap(); let mut cmddata = cmd.panic_reason.lock().await; @@ -1445,8 +1540,8 @@ impl<'tcs> IOHandlerInput<'tcs> { .as_command() .ok_or(IoErrorKind::InvalidInput)?; let new_tcs = match self.enclave.threads_queue.pop() { - Ok(tcs) => tcs, - Err(_) => { + Some(tcs) => tcs, + None => { return Err(IoErrorKind::WouldBlock.into()); } }; @@ -1483,7 +1578,7 @@ impl<'tcs> IOHandlerInput<'tcs> { } fn check_event_set(set: u64) -> IoResult<()> { - const EV_ALL: u64 = EV_USERCALLQ_NOT_FULL | EV_RETURNQ_NOT_EMPTY | EV_UNPARK; + const EV_ALL: u64 = EV_USERCALLQ_NOT_FULL | EV_RETURNQ_NOT_EMPTY | EV_UNPARK | EV_CANCELQ_NOT_FULL; if (set & !EV_ALL) != 0 { return Err(IoErrorKind::InvalidInput.into()); } @@ -1593,12 +1688,16 @@ impl<'tcs> IOHandlerInput<'tcs> { &mut self, usercall_queue: &mut FifoDescriptor, return_queue: &mut FifoDescriptor, + cancel_queue: Option<&mut FifoDescriptor>, ) -> StdResult<(), EnclaveAbort> { let mut fifo_guards = self.enclave.fifo_guards.lock().await; match &mut *fifo_guards { Some(ref mut fifo_guards) if !fifo_guards.async_queues_called => { *usercall_queue = fifo_guards.usercall_queue.fifo_descriptor(); *return_queue = fifo_guards.return_queue.fifo_descriptor(); + if let Some(cancel_queue) = cancel_queue { + *cancel_queue = fifo_guards.cancel_queue.fifo_descriptor(); + } fifo_guards.async_queues_called = true; Ok(()) } @@ -1617,6 +1716,7 @@ impl<'tcs> IOHandlerInput<'tcs> { enum Queue { Usercall, Return, + Cancel, } struct QueueSynchronizer { @@ -1629,7 +1729,7 @@ struct QueueSynchronizer { } impl QueueSynchronizer { - fn new(enclave: Arc) -> (Self, Self, broadcast::Sender<()>) { + fn new(enclave: Arc) -> (Self, Self, Self, broadcast::Sender<()>) { // This broadcast channel is used to notify enclave-runner of any // synchronous usercalls made by the enclave for the purpose of // synchronizing access to usercall and return queues. @@ -1637,6 +1737,7 @@ impl QueueSynchronizer { // return RecvError::Lagged. let (tx, rx1) = broadcast::channel(1); let rx2 = tx.subscribe(); + let rx3 = tx.subscribe(); let usercall_queue_synchronizer = QueueSynchronizer { queue: Queue::Usercall, enclave: enclave.clone(), @@ -1645,11 +1746,17 @@ impl QueueSynchronizer { }; let return_queue_synchronizer = QueueSynchronizer { queue: Queue::Return, - enclave, + enclave: enclave.clone(), subscription: Mutex::new(rx2), subscription_maker: tx.clone(), }; - (usercall_queue_synchronizer, return_queue_synchronizer, tx) + let cancel_queue_synchronizer = QueueSynchronizer { + queue: Queue::Cancel, + enclave, + subscription: Mutex::new(rx3), + subscription_maker: tx.clone(), + }; + (usercall_queue_synchronizer, return_queue_synchronizer, cancel_queue_synchronizer, tx) } } @@ -1668,6 +1775,7 @@ impl ipc_queue::AsyncSynchronizer for QueueSynchronizer { fn wait(&self, event: QueueEvent) -> Pin> + '_>> { match (self.queue, event) { (Queue::Usercall, QueueEvent::NotFull) => panic!("enclave runner should not send on the usercall queue"), + (Queue::Cancel, QueueEvent::NotFull) => panic!("enclave runner should not send on the cancel queue"), (Queue::Return, QueueEvent::NotEmpty) => panic!("enclave runner should not receive on the return queue"), _ => {} } @@ -1686,12 +1794,14 @@ impl ipc_queue::AsyncSynchronizer for QueueSynchronizer { fn notify(&self, event: QueueEvent) { let ev = match (self.queue, event) { (Queue::Usercall, QueueEvent::NotEmpty) => panic!("enclave runner should not send on the usercall queue"), - (Queue::Return, QueueEvent::NotFull) => panic!("enclave runner should not receive on the return queue"), - (Queue::Usercall, QueueEvent::NotFull) => EV_USERCALLQ_NOT_FULL, - (Queue::Return, QueueEvent::NotEmpty) => EV_RETURNQ_NOT_EMPTY, + (Queue::Cancel, QueueEvent::NotEmpty) => panic!("enclave runner should not send on the cancel queue"), + (Queue::Return, QueueEvent::NotFull) => panic!("enclave runner should not receive on the return queue"), + (Queue::Usercall, QueueEvent::NotFull) => EV_USERCALLQ_NOT_FULL, + (Queue::Return, QueueEvent::NotEmpty) => EV_RETURNQ_NOT_EMPTY, + (Queue::Cancel, QueueEvent::NotFull) => EV_CANCELQ_NOT_FULL, }; // When the enclave needs to wait on a queue, it executes the wait() usercall synchronously, - // specifying EV_USERCALLQ_NOT_FULL, EV_RETURNQ_NOT_EMPTY, or both in the event_mask. + // specifying EV_USERCALLQ_NOT_FULL, EV_RETURNQ_NOT_EMPTY or EV_CANCELQ_NOT_FULL in the event_mask. // Userspace will wake any or all threads waiting on the appropriate event when it is triggered. for pending_events in self.enclave.event_queues.values() { pending_events.push(ev as _); diff --git a/ipc-queue/Cargo.toml b/ipc-queue/Cargo.toml index a5e8c1b5..2a4d1cd7 100644 --- a/ipc-queue/Cargo.toml +++ b/ipc-queue/Cargo.toml @@ -14,7 +14,7 @@ keywords = ["sgx", "fifo", "queue", "ipc"] categories = ["asynchronous"] [dependencies] -fortanix-sgx-abi = { version = "0.4.0" } # TODO: add back `path = "../intel-sgx/fortanix-sgx-abi"` +fortanix-sgx-abi = { version = "0.5.0", path = "../intel-sgx/fortanix-sgx-abi" } [dev-dependencies] static_assertions = "1.1.0" diff --git a/ipc-queue/src/fifo.rs b/ipc-queue/src/fifo.rs index b000562d..76b6d595 100644 --- a/ipc-queue/src/fifo.rs +++ b/ipc-queue/src/fifo.rs @@ -68,7 +68,7 @@ where let arc = Arc::new(FifoBuffer::new(len)); let inner = Fifo::from_arc(arc); let tx = AsyncSender { inner: inner.clone(), synchronizer: s.clone() }; - let rx = AsyncReceiver { inner, synchronizer: s }; + let rx = AsyncReceiver { inner, synchronizer: s, read_epoch: Arc::new(AtomicU64::new(0)) }; (tx, rx) } @@ -156,6 +156,12 @@ impl Clone for Fifo { } } +impl Fifo { + pub(crate) fn current_offsets(&self, ordering: Ordering) -> Offsets { + Offsets::new(self.offsets.load(ordering), self.data.len() as u32) + } +} + impl Fifo { pub(crate) unsafe fn from_descriptor(descriptor: FifoDescriptor) -> Self { assert!( @@ -209,7 +215,7 @@ impl Fifo { pub(crate) fn try_send_impl(&self, val: Identified) -> Result { let (new, was_empty) = loop { // 1. Load the current offsets. - let current = Offsets::new(self.offsets.load(SeqCst), self.data.len() as u32); + let current = self.current_offsets(Ordering::SeqCst); let was_empty = current.is_empty(); // 2. If the queue is full, wait, then go to step 1. @@ -218,7 +224,7 @@ impl Fifo { } // 3. Add 1 to the write offset and do an atomic compare-and-swap (CAS) - // with the current offsets. If the CAS was not succesful, go to step 1. + // with the current offsets. If the CAS was not successful, go to step 1. let new = current.increment_write_offset(); let current = current.as_usize(); if self.offsets.compare_exchange(current, new.as_usize(), SeqCst, SeqCst).is_ok() { @@ -237,9 +243,9 @@ impl Fifo { Ok(was_empty) } - pub(crate) fn try_recv_impl(&self) -> Result<(Identified, /*wake up writer:*/ bool), TryRecvError> { + pub(crate) fn try_recv_impl(&self) -> Result<(Identified, /*wake up writer:*/ bool, /*read offset wrapped around:*/ bool), TryRecvError> { // 1. Load the current offsets. - let current = Offsets::new(self.offsets.load(SeqCst), self.data.len() as u32); + let current = self.current_offsets(Ordering::SeqCst); // 2. If the queue is empty, wait, then go to step 1. if current.is_empty() { @@ -275,7 +281,7 @@ impl Fifo { // 8. If the queue was full before step 7, signal the writer to wake up. let was_full = Offsets::new(before, self.data.len() as u32).is_full(); - Ok((val, was_full)) + Ok((val, was_full, new.read_offset() == 0)) } } @@ -341,6 +347,14 @@ impl Offsets { ..*self } } + + pub(crate) fn read_high_bit(&self) -> bool { + self.read & self.len == self.len + } + + pub(crate) fn write_high_bit(&self) -> bool { + self.write & self.len == self.len + } } #[cfg(test)] @@ -366,7 +380,7 @@ mod tests { } for i in 1..=7 { - let (v, wake) = inner.try_recv_impl().unwrap(); + let (v, wake, _) = inner.try_recv_impl().unwrap(); assert!(!wake); assert_eq!(v.id, i); assert_eq!(v.data.0, i); @@ -385,7 +399,7 @@ mod tests { assert!(inner.try_send_impl(Identified { id: 9, data: TestValue(9) }).is_err()); for i in 1..=8 { - let (v, wake) = inner.try_recv_impl().unwrap(); + let (v, wake, _) = inner.try_recv_impl().unwrap(); assert!(if i == 1 { wake } else { !wake }); assert_eq!(v.id, i); assert_eq!(v.data.0, i); diff --git a/ipc-queue/src/interface_async.rs b/ipc-queue/src/interface_async.rs index 9478e93e..68fd63c3 100644 --- a/ipc-queue/src/interface_async.rs +++ b/ipc-queue/src/interface_async.rs @@ -5,6 +5,7 @@ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ use super::*; +use std::sync::atomic::Ordering; unsafe impl Send for AsyncSender {} unsafe impl Sync for AsyncSender {} @@ -53,10 +54,13 @@ impl AsyncReceiver { pub async fn recv(&self) -> Result, RecvError> { loop { match self.inner.try_recv_impl() { - Ok((val, wake_sender)) => { + Ok((val, wake_sender, read_wrapped_around)) => { if wake_sender { self.synchronizer.notify(QueueEvent::NotFull); } + if read_wrapped_around { + self.read_epoch.fetch_add(1, Ordering::Relaxed); + } return Ok(val); } Err(TryRecvError::QueueEmpty) => { @@ -69,6 +73,13 @@ impl AsyncReceiver { } } + pub fn position_monitor(&self) -> PositionMonitor { + PositionMonitor { + read_epoch: self.read_epoch.clone(), + fifo: self.inner.clone(), + } + } + /// Consumes `self` and returns a DescriptorGuard. /// The returned guard can be used to make `FifoDescriptor`s that remain /// valid as long as the guard is not dropped. @@ -155,6 +166,65 @@ mod tests { do_multi_sender(1024, 30, 100).await; } + #[tokio::test] + async fn positions() { + const LEN: usize = 16; + let s = TestAsyncSynchronizer::new(); + let (tx, rx) = bounded_async(LEN, s); + let monitor = rx.position_monitor(); + let mut id = 1; + + let p0 = monitor.write_position(); + tx.send(Identified { id, data: TestValue(1) }).await.unwrap(); + let p1 = monitor.write_position(); + tx.send(Identified { id: id + 1, data: TestValue(2) }).await.unwrap(); + let p2 = monitor.write_position(); + tx.send(Identified { id: id + 2, data: TestValue(3) }).await.unwrap(); + let p3 = monitor.write_position(); + id += 3; + assert!(monitor.read_position().is_past(&p0) == false); + assert!(monitor.read_position().is_past(&p1) == false); + assert!(monitor.read_position().is_past(&p2) == false); + assert!(monitor.read_position().is_past(&p3) == false); + + rx.recv().await.unwrap(); + assert!(monitor.read_position().is_past(&p0) == true); + assert!(monitor.read_position().is_past(&p1) == false); + assert!(monitor.read_position().is_past(&p2) == false); + assert!(monitor.read_position().is_past(&p3) == false); + + rx.recv().await.unwrap(); + assert!(monitor.read_position().is_past(&p0) == true); + assert!(monitor.read_position().is_past(&p1) == true); + assert!(monitor.read_position().is_past(&p2) == false); + assert!(monitor.read_position().is_past(&p3) == false); + + rx.recv().await.unwrap(); + assert!(monitor.read_position().is_past(&p0) == true); + assert!(monitor.read_position().is_past(&p1) == true); + assert!(monitor.read_position().is_past(&p2) == true); + assert!(monitor.read_position().is_past(&p3) == false); + + for i in 0..1000 { + let n = 1 + (i % LEN); + let p4 = monitor.write_position(); + for _ in 0..n { + tx.send(Identified { id, data: TestValue(id) }).await.unwrap(); + id += 1; + } + let p5 = monitor.write_position(); + for _ in 0..n { + rx.recv().await.unwrap(); + assert!(monitor.read_position().is_past(&p0) == true); + assert!(monitor.read_position().is_past(&p1) == true); + assert!(monitor.read_position().is_past(&p2) == true); + assert!(monitor.read_position().is_past(&p3) == true); + assert!(monitor.read_position().is_past(&p4) == true); + assert!(monitor.read_position().is_past(&p5) == false); + } + } + } + struct Subscription { tx: broadcast::Sender, rx: Mutex>, diff --git a/ipc-queue/src/interface_sync.rs b/ipc-queue/src/interface_sync.rs index 2096c3c6..1e07cafa 100644 --- a/ipc-queue/src/interface_sync.rs +++ b/ipc-queue/src/interface_sync.rs @@ -112,7 +112,7 @@ impl Receiver { } pub fn try_recv(&self) -> Result, TryRecvError> { - self.inner.try_recv_impl().map(|(val, wake_sender)| { + self.inner.try_recv_impl().map(|(val, wake_sender, _)| { if wake_sender { self.synchronizer.notify(QueueEvent::NotFull); } @@ -127,7 +127,7 @@ impl Receiver { pub fn recv(&self) -> Result, RecvError> { loop { match self.inner.try_recv_impl() { - Ok((val, wake_sender)) => { + Ok((val, wake_sender, _)) => { if wake_sender { self.synchronizer.notify(QueueEvent::NotFull); } diff --git a/ipc-queue/src/lib.rs b/ipc-queue/src/lib.rs index cbada6fe..7518f1c4 100644 --- a/ipc-queue/src/lib.rs +++ b/ipc-queue/src/lib.rs @@ -10,6 +10,8 @@ use std::future::Future; use std::pin::Pin; +use std::sync::atomic::AtomicU64; +use std::sync::Arc; use fortanix_sgx_abi::FifoDescriptor; @@ -21,13 +23,13 @@ use std::os::fortanix_sgx::usercalls::alloc::{UserRef, UserSafeSized}; #[cfg(not(target_env = "sgx"))] use { std::ptr, - std::sync::Arc, self::fifo::FifoBuffer, }; mod fifo; mod interface_sync; mod interface_async; +mod position; #[cfg(test)] mod test_support; @@ -152,6 +154,7 @@ pub struct AsyncSender { pub struct AsyncReceiver { inner: Fifo, synchronizer: S, + read_epoch: Arc, } /// `DescriptorGuard` can produce a `FifoDescriptor` that is guaranteed @@ -167,3 +170,19 @@ impl DescriptorGuard { self.descriptor } } + +/// `PositionMonitor` can be used to record the current read/write positions +/// of a queue. Even though a queue is comprised of a limited number of slots +/// arranged as a ring buffer, we can assign a position to each value written/ +/// read to/from the queue. This is useful in case we want to know whether or +/// not a particular value written to the queue has been read. +pub struct PositionMonitor { + read_epoch: Arc, + fifo: Fifo, +} + +/// A read position in a queue. +pub struct ReadPosition(u64); + +/// A write position in a queue. +pub struct WritePosition(u64); \ No newline at end of file diff --git a/ipc-queue/src/position.rs b/ipc-queue/src/position.rs new file mode 100644 index 00000000..eaa520e8 --- /dev/null +++ b/ipc-queue/src/position.rs @@ -0,0 +1,49 @@ +/* Copyright (c) Fortanix, Inc. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use super::*; +use std::sync::atomic::Ordering; + +impl PositionMonitor { + pub fn read_position(&self) -> ReadPosition { + let current = self.fifo.current_offsets(Ordering::Relaxed); + let read_epoch = self.read_epoch.load(Ordering::Relaxed); + ReadPosition(((read_epoch as u64) << 32) | (current.read_offset() as u64)) + } + + pub fn write_position(&self) -> WritePosition { + let current = self.fifo.current_offsets(Ordering::Relaxed); + let mut write_epoch = self.read_epoch.load(Ordering::Relaxed); + if current.read_high_bit() != current.write_high_bit() { + write_epoch += 1; + } + WritePosition(((write_epoch as u64) << 32) | (current.write_offset() as u64)) + } +} + +impl Clone for PositionMonitor { + fn clone(&self) -> Self { + Self { + read_epoch: self.read_epoch.clone(), + fifo: self.fifo.clone(), + } + } +} + +impl ReadPosition { + /// A `WritePosition` can be compared to a `ReadPosition` **correctly** if + /// at most 2³¹ (2 to the power of 31) writes + /// have occurred since the write position was recorded. + pub fn is_past(&self, write: &WritePosition) -> bool { + let (read, write) = (self.0, write.0); + let hr = read & (1 << 63); + let hw = write & (1 << 63); + if hr == hw { + return read > write; + } + true + } +} \ No newline at end of file