From 6f3fa1d3ea6a8fbcd5859bbeb869b7707139608b Mon Sep 17 00:00:00 2001 From: Ben Dean-Kawamura Date: Fri, 8 Dec 2023 11:40:02 -0500 Subject: [PATCH] Async blocking task support Added the `BlockingTaskQueue` type BlockingTaskQueue allows a Rust closure to be scheduled on a foreign thread where blocking operations are okay. The closure runs inside the parent future, which is nice because it allows the closure to reference its outside scope. On the foreign side, a `BlockingTaskQueue` is a native type that runs a task in some sort of thread queue (`DispatchQueue`, `CoroutineContext`, `futures.Executor`, etc.). Added new tests for this in the futures fixtures. Updated the tests to check that handles are being released properly. --- Cargo.lock | 81 ++++++++++- docs/manual/src/futures.md | 60 ++++++++ fixtures/futures/Cargo.toml | 1 + fixtures/futures/src/lib.rs | 56 ++++++++ .../futures/tests/bindings/test_futures.kts | 73 ++++++++-- .../futures/tests/bindings/test_futures.py | 63 +++++++++ .../futures/tests/bindings/test_futures.swift | 129 ++++++++---------- fixtures/metadata/src/tests.rs | 1 + .../kotlin/gen_kotlin/blocking_task_queue.rs | 19 +++ .../src/bindings/kotlin/gen_kotlin/mod.rs | 5 +- .../src/bindings/kotlin/templates/Async.kt | 38 ++++-- .../templates/BlockingTaskQueueTemplate.kt | 44 ++++++ .../templates/CallbackInterfaceRuntime.kt | 34 ----- .../src/bindings/kotlin/templates/Helpers.kt | 38 ++++++ .../kotlin/templates/ObjectRuntime.kt | 1 - .../src/bindings/kotlin/templates/Types.kt | 5 + .../src/bindings/kotlin/templates/wrapper.kt | 3 + .../python/gen_python/blocking_task_queue.rs | 19 +++ .../src/bindings/python/gen_python/mod.rs | 3 + .../src/bindings/python/templates/Async.py | 50 +++++-- .../templates/BlockingTaskQueueTemplate.py | 39 ++++++ .../templates/CallbackInterfaceRuntime.py | 35 ----- .../src/bindings/python/templates/Helpers.py | 40 ++++++ .../python/templates/PointerManager.py | 10 ++ .../src/bindings/python/templates/Types.py | 3 + .../src/bindings/python/templates/wrapper.py | 1 + .../src/bindings/ruby/gen_ruby/mod.rs | 4 + .../swift/gen_swift/blocking_task_queue.rs | 19 +++ .../src/bindings/swift/gen_swift/mod.rs | 3 + .../src/bindings/swift/templates/Async.swift | 72 ++++++---- .../templates/BlockingTaskQueueTemplate.swift | 37 +++++ .../templates/CallbackInterfaceRuntime.swift | 57 -------- .../templates/CallbackInterfaceTemplate.swift | 10 +- .../bindings/swift/templates/Helpers.swift | 63 +++++++++ .../swift/templates/ObjectTemplate.swift | 2 +- .../src/bindings/swift/templates/Types.swift | 3 + uniffi_bindgen/src/interface/ffi.rs | 3 +- uniffi_bindgen/src/interface/mod.rs | 49 ++++++- uniffi_bindgen/src/interface/universe.rs | 1 + uniffi_bindgen/src/scaffolding/mod.rs | 1 + .../src/ffi/rustfuture/blocking_task_queue.rs | 69 ++++++++++ uniffi_core/src/ffi/rustfuture/future.rs | 38 +++++- uniffi_core/src/ffi/rustfuture/mod.rs | 30 +++- uniffi_core/src/ffi/rustfuture/scheduler.rs | 110 ++++++++++++++- uniffi_core/src/ffi/rustfuture/tests.rs | 119 +++++++++++++--- uniffi_core/src/ffi_converter_impls.rs | 34 ++++- uniffi_core/src/metadata.rs | 1 + uniffi_macros/src/setup_scaffolding.rs | 9 +- uniffi_meta/src/metadata.rs | 1 + uniffi_meta/src/reader.rs | 1 + uniffi_meta/src/types.rs | 1 + uniffi_udl/src/resolver.rs | 1 + 52 files changed, 1274 insertions(+), 315 deletions(-) create mode 100644 uniffi_bindgen/src/bindings/kotlin/gen_kotlin/blocking_task_queue.rs create mode 100644 uniffi_bindgen/src/bindings/kotlin/templates/BlockingTaskQueueTemplate.kt create mode 100644 uniffi_bindgen/src/bindings/python/gen_python/blocking_task_queue.rs create mode 100644 uniffi_bindgen/src/bindings/python/templates/BlockingTaskQueueTemplate.py create mode 100644 uniffi_bindgen/src/bindings/swift/gen_swift/blocking_task_queue.rs create mode 100644 uniffi_bindgen/src/bindings/swift/templates/BlockingTaskQueueTemplate.swift create mode 100644 uniffi_core/src/ffi/rustfuture/blocking_task_queue.rs diff --git a/Cargo.lock b/Cargo.lock index 6c73b85452..3e1c7c827f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -612,26 +612,53 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0845fa252299212f0389d64ba26f34fa32cfe41588355f21ed507c59a0f64541" +[[package]] +name = "futures" +version = "0.3.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da0290714b38af9b4a7b094b8a37086d1b4e61f2df9122c3cad2577669145335" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb" dependencies = [ "futures-core", + "futures-sink", ] [[package]] name = "futures-core" -version = "0.3.28" +version = "0.3.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c" + +[[package]] +name = "futures-executor" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +checksum = "0f4fb8693db0cf099eadcca0efe2a5a22e4550f98ed16aba6c48700da29597bc" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] [[package]] name = "futures-io" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" +checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa" [[package]] name = "futures-lite" @@ -648,6 +675,47 @@ dependencies = [ "waker-fn", ] +[[package]] +name = "futures-macro" +version = "0.3.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817" + +[[package]] +name = "futures-task" +version = "0.3.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2" + +[[package]] +name = "futures-util" +version = "0.3.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "gimli" version = "0.28.0" @@ -1632,6 +1700,7 @@ dependencies = [ name = "uniffi-fixture-futures" version = "0.21.0" dependencies = [ + "futures", "once_cell", "thiserror", "tokio", diff --git a/docs/manual/src/futures.md b/docs/manual/src/futures.md index dbe5e5a163..97f9aaf164 100644 --- a/docs/manual/src/futures.md +++ b/docs/manual/src/futures.md @@ -45,3 +45,63 @@ This code uses `asyncio` to drive the future to completion, while our exposed fu In Rust `Future` terminology this means the foreign bindings supply the "executor" - think event-loop, or async runtime. In this example it's `asyncio`. There's no requirement for a Rust event loop. There are [some great API docs](https://docs.rs/uniffi_core/latest/uniffi_core/ffi/rustfuture/index.html) on the implementation that are well worth a read. + +## Blocking tasks + +Rust executors are designed around an assumption that the `Future::poll` function will return quickly. +This assumption, combined with cooperative scheduling, allows for a large number of futures to be handled by a small number of threads. +Foreign executors make similar assumptions and sometimes more extreme ones. +For example, the Python eventloop is single threaded -- if any task spends a long time between `await` points, then it will block all other tasks from progressing. + +This raises the question of how async code can interact with blocking code that performs blocking IO, long-running computations without `await` breaks, etc. +UniFFI defines the `BlockingTaskQueue` type, which is a foreign object that schedules work on a thread where it's okay to block. + +On Rust, `BlockingTaskQueue` is a UniFFI type that can safely run blocking code. +It's `execute` method works like tokio's [block_in_place](https://docs.rs/tokio/latest/tokio/task/fn.block_in_place.html) function. +It inputs a closure and runs it in the `BlockingTaskQueue`. +This closure can reference the outside scope (i.e. it does not need to be `'static`). +For example: + +```rust +#[derive(uniffi::Object)] +struct DataStore { + // Used to run blocking tasks + queue: uniffi::BlockingTaskQueue, + // Low-level DB object with blocking methods + db: Mutex, +} + +#[uniffi::export] +impl DataStore { + #[uniffi::constructor] + fn new(queue: uniffi::BlockingTaskQueue) -> Self { + Self { + queue, + db: Mutex::new(Database::new()) + } + } + + fn fetch_all_items(&self) -> Vec { + self.queue.execute(|| self.db.lock().fetch_all_items()) + } +} +``` + +On the foreign side `BlockingTaskQueue` corresponds to a language-dependent class. + +### Kotlin +Kotlin uses `CoroutineContext` for its `BlockingTaskQueue`. +Any `CoroutineContext` will work, but `Dispatchers.IO` is usually a good choice. +A DataStore from the example above can be created with `DataStore(Dispatchers.IO)`. + +### Swift +Swift uses `DispatchQueue` for its `BlockingTaskQueue`. +The user-initiated global queue is normally a good choice. +A DataStore from the example above can be created with `DataStore(queue: DispatchQueue.global(qos: .userInitiated)`. +The `DispatchQueue` should be concurrent. + +### Python + +Python uses a `futures.Executor` for its `BlockingTaskQueue`. +`ThreadPoolExecutor` is typically a good choice. +A DataStore from the example above can be created with `DataStore(ThreadPoolExecutor())`. diff --git a/fixtures/futures/Cargo.toml b/fixtures/futures/Cargo.toml index 78b08cb689..67c850e9f0 100644 --- a/fixtures/futures/Cargo.toml +++ b/fixtures/futures/Cargo.toml @@ -16,6 +16,7 @@ path = "src/bin.rs" [dependencies] uniffi = { path = "../../uniffi", version = "0.25", features = ["tokio", "cli"] } +futures = "0.3.29" thiserror = "1.0" tokio = { version = "1.24.1", features = ["time", "sync"] } once_cell = "1.18.0" diff --git a/fixtures/futures/src/lib.rs b/fixtures/futures/src/lib.rs index 39a521495e..7466794740 100644 --- a/fixtures/futures/src/lib.rs +++ b/fixtures/futures/src/lib.rs @@ -11,6 +11,8 @@ use std::{ time::Duration, }; +use futures::stream::{FuturesUnordered, StreamExt}; + /// Non-blocking timer future. pub struct TimerFuture { shared_state: Arc>, @@ -326,4 +328,58 @@ pub async fn use_shared_resource(options: SharedResourceOptions) -> Result<(), A Ok(()) } +/// Async function that uses a blocking task queue to do its work +#[uniffi::export] +pub async fn calc_square(queue: uniffi::BlockingTaskQueue, value: i32) -> i32 { + queue.execute(|| value * value).await +} + +/// Same as before, but this one runs multiple tasks +#[uniffi::export] +pub async fn calc_squares(queue: uniffi::BlockingTaskQueue, items: Vec) -> Vec { + // Use `FuturesUnordered` to test our blocking task queue code which is known to be a tricky API to work with. + // In particular, if we don't notify the waker then FuturesUnordered will not poll again. + let mut futures: FuturesUnordered<_> = (0..items.len()) + .map(|i| { + // Test that we can use references from the surrounding scope + let items = &items; + queue.execute(move || items[i] * items[i]) + }) + .collect(); + let mut results = vec![]; + while let Some(result) = futures.next().await { + results.push(result); + } + results.sort(); + results +} + +/// ...and this one uses multiple BlockingTaskQueues +#[uniffi::export] +pub async fn calc_squares_multi_queue( + queues: Vec, + items: Vec, +) -> Vec { + let mut futures: FuturesUnordered<_> = (0..items.len()) + .map(|i| { + // Test that we can use references from the surrounding scope + let items = &items; + queues[i].execute(move || items[i] * items[i]) + }) + .collect(); + let mut results = vec![]; + while let Some(result) = futures.next().await { + results.push(result); + } + results.sort(); + results +} + +/// Like calc_square, but it clones the BlockingTaskQueue first then drops both copies. Used to +/// test that a) the clone works and b) we correctly drop the references. +#[uniffi::export] +pub async fn calc_square_with_clone(queue: uniffi::BlockingTaskQueue, value: i32) -> i32 { + queue.clone().execute(|| value * value).await +} + uniffi::include_scaffolding!("futures"); diff --git a/fixtures/futures/tests/bindings/test_futures.kts b/fixtures/futures/tests/bindings/test_futures.kts index 810bb40f41..0dece29410 100644 --- a/fixtures/futures/tests/bindings/test_futures.kts +++ b/fixtures/futures/tests/bindings/test_futures.kts @@ -1,9 +1,22 @@ import uniffi.fixture.futures.* +import java.util.concurrent.Executors import kotlinx.coroutines.* import kotlin.system.* +fun runAsyncTest(test: suspend CoroutineScope.() -> Unit) { + val initialBlockingTaskQueueHandleCount = uniffiBlockingTaskQueueHandleCount() + val initialPollHandleCount = uniffiPollHandleCount() + val time = runBlocking { + measureTimeMillis { + test() + } + } + assert(uniffiBlockingTaskQueueHandleCount() == initialBlockingTaskQueueHandleCount) + assert(uniffiPollHandleCount() == initialPollHandleCount) +} + // init UniFFI to get good measurements after that -runBlocking { +runAsyncTest { val time = measureTimeMillis { alwaysReady() } @@ -24,7 +37,7 @@ fun assertApproximateTime(actualTime: Long, expectedTime: Int, testName: String } // Test `always_ready`. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val result = alwaysReady() @@ -35,7 +48,7 @@ runBlocking { } // Test `void`. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val result = void() @@ -46,7 +59,7 @@ runBlocking { } // Test `sleep`. -runBlocking { +runAsyncTest { val time = measureTimeMillis { sleep(200U) } @@ -55,7 +68,7 @@ runBlocking { } // Test sequential futures. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val resultAlice = sayAfter(100U, "Alice") val resultBob = sayAfter(200U, "Bob") @@ -68,7 +81,7 @@ runBlocking { } // Test concurrent futures. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val resultAlice = async { sayAfter(100U, "Alice") } val resultBob = async { sayAfter(200U, "Bob") } @@ -81,7 +94,7 @@ runBlocking { } // Test async methods. -runBlocking { +runAsyncTest { val megaphone = newMegaphone() val time = measureTimeMillis { val resultAlice = megaphone.sayAfter(200U, "Alice") @@ -92,7 +105,7 @@ runBlocking { assertApproximateTime(time, 200, "async methods") } -runBlocking { +runAsyncTest { val megaphone = newMegaphone() val time = measureTimeMillis { val resultAlice = sayAfterWithMegaphone(megaphone, 200U, "Alice") @@ -104,7 +117,7 @@ runBlocking { } // Test async method returning optional object -runBlocking { +runAsyncTest { val megaphone = asyncMaybeNewMegaphone(true) assert(megaphone != null) @@ -113,7 +126,7 @@ runBlocking { } // Test with the Tokio runtime. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val resultAlice = sayAfterWithTokio(200U, "Alice") @@ -124,7 +137,7 @@ runBlocking { } // Test fallible function/method. -runBlocking { +runAsyncTest { val time1 = measureTimeMillis { try { fallibleMe(false) @@ -189,7 +202,7 @@ runBlocking { } // Test record. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val result = newMyRecord("foo", 42U) @@ -203,7 +216,7 @@ runBlocking { } // Test a broken sleep. -runBlocking { +runAsyncTest { val time = measureTimeMillis { brokenSleep(100U, 0U) // calls the waker twice immediately sleep(100U) // wait for possible failure @@ -217,7 +230,7 @@ runBlocking { // Test a future that uses a lock and that is cancelled. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val job = launch { useSharedResource(SharedResourceOptions(releaseAfterMs=5000U, timeoutMs=100U)) @@ -236,7 +249,7 @@ runBlocking { } // Test a future that uses a lock and that is not cancelled. -runBlocking { +runAsyncTest { val time = measureTimeMillis { useSharedResource(SharedResourceOptions(releaseAfterMs=100U, timeoutMs=1000U)) @@ -244,3 +257,33 @@ runBlocking { } println("useSharedResource (not canceled): ${time}ms") } + +// Test blocking task queues +runAsyncTest { + withTimeout(1000) { + assert(calcSquare(Dispatchers.IO, 20) == 400) + } + + withTimeout(1000) { + assert(calcSquares(Dispatchers.IO, listOf(1, -2, 3)) == listOf(1, 4, 9)) + } + + val executors = listOf( + Executors.newSingleThreadExecutor(), + Executors.newSingleThreadExecutor(), + Executors.newSingleThreadExecutor(), + ) + withTimeout(1000) { + assert(calcSquaresMultiQueue(executors.map { it.asCoroutineDispatcher() }, listOf(1, -2, 3)) == listOf(1, 4, 9)) + } + for (executor in executors) { + executor.shutdown() + } +} + +// Test blocking task queue cloning +runAsyncTest { + withTimeout(1000) { + assert(calcSquareWithClone(Dispatchers.IO, 20) == 400) + } +} diff --git a/fixtures/futures/tests/bindings/test_futures.py b/fixtures/futures/tests/bindings/test_futures.py index bfbeba86f8..37825db3a0 100644 --- a/fixtures/futures/tests/bindings/test_futures.py +++ b/fixtures/futures/tests/bindings/test_futures.py @@ -1,25 +1,31 @@ +import futures from futures import * +import contextlib import unittest from datetime import datetime import asyncio +from concurrent.futures import ThreadPoolExecutor def now(): return datetime.now() class TestFutures(unittest.TestCase): def test_always_ready(self): + @self.check_handle_counts() async def test(): self.assertEqual(await always_ready(), True) asyncio.run(test()) def test_void(self): + @self.check_handle_counts() async def test(): self.assertEqual(await void(), None) asyncio.run(test()) def test_sleep(self): + @self.check_handle_counts() async def test(): t0 = now() await sleep(2000) @@ -31,6 +37,7 @@ async def test(): asyncio.run(test()) def test_sequential_futures(self): + @self.check_handle_counts() async def test(): t0 = now() result_alice = await say_after(100, 'Alice') @@ -45,6 +52,7 @@ async def test(): asyncio.run(test()) def test_concurrent_tasks(self): + @self.check_handle_counts() async def test(): alice = asyncio.create_task(say_after(100, 'Alice')) bob = asyncio.create_task(say_after(200, 'Bob')) @@ -62,6 +70,7 @@ async def test(): asyncio.run(test()) def test_async_methods(self): + @self.check_handle_counts() async def test(): megaphone = new_megaphone() t0 = now() @@ -75,6 +84,7 @@ async def test(): asyncio.run(test()) def test_async_object_param(self): + @self.check_handle_counts() async def test(): megaphone = new_megaphone() t0 = now() @@ -88,6 +98,7 @@ async def test(): asyncio.run(test()) def test_with_tokio_runtime(self): + @self.check_handle_counts() async def test(): t0 = now() result_alice = await say_after_with_tokio(200, 'Alice') @@ -100,6 +111,7 @@ async def test(): asyncio.run(test()) def test_fallible(self): + @self.check_handle_counts() async def test(): result = await fallible_me(False) self.assertEqual(result, 42) @@ -124,6 +136,7 @@ async def test(): asyncio.run(test()) def test_fallible_struct(self): + @self.check_handle_counts() async def test(): megaphone = await fallible_struct(False) self.assertEqual(await megaphone.fallible_me(False), 42) @@ -137,6 +150,7 @@ async def test(): asyncio.run(test()) def test_record(self): + @self.check_handle_counts() async def test(): result = await new_my_record("foo", 42) self.assertEqual(result.__class__, MyRecord) @@ -146,6 +160,7 @@ async def test(): asyncio.run(test()) def test_cancel(self): + @self.check_handle_counts() async def test(): # Create a task task = asyncio.create_task(say_after(200, 'Alice')) @@ -163,6 +178,7 @@ async def test(): # Test a future that uses a lock and that is cancelled. def test_shared_resource_cancellation(self): + @self.check_handle_counts() async def test(): task = asyncio.create_task(use_shared_resource( SharedResourceOptions(release_after_ms=5000, timeout_ms=100))) @@ -173,10 +189,57 @@ async def test(): asyncio.run(test()) def test_shared_resource_no_cancellation(self): + @self.check_handle_counts() async def test(): await use_shared_resource(SharedResourceOptions(release_after_ms=100, timeout_ms=1000)) await use_shared_resource(SharedResourceOptions(release_after_ms=0, timeout_ms=1000)) asyncio.run(test()) + # blocking task queue tests + + def test_calc_square(self): + @self.check_handle_counts() + async def test(): + async with asyncio.timeout(1): + executor = ThreadPoolExecutor() + self.assertEqual(await calc_square(executor, 20), 400) + asyncio.run(test()) + + def test_calc_square_with_clone(self): + @self.check_handle_counts() + async def test(): + async with asyncio.timeout(1): + executor = ThreadPoolExecutor() + self.assertEqual(await calc_square_with_clone(executor, 20), 400) + asyncio.run(test()) + + def test_calc_squares(self): + @self.check_handle_counts() + async def test(): + async with asyncio.timeout(1): + executor = ThreadPoolExecutor() + self.assertEqual(await calc_squares(executor, [1, -2, 3]), [1, 4, 9]) + asyncio.run(test()) + + def test_calc_squares_multi_queue(self): + @self.check_handle_counts() + async def test(): + async with asyncio.timeout(1): + executors = [ + ThreadPoolExecutor(), + ThreadPoolExecutor(), + ThreadPoolExecutor(), + ] + self.assertEqual(await calc_squares_multi_queue(executors, [1, -2, 3]), [1, 4, 9]) + asyncio.run(test()) + + @contextlib.asynccontextmanager + async def check_handle_counts(self): + initial_poll_handle_count = len(futures.UNIFFI_POLL_DATA_POINTER_MANAGER) + initial_blocking_task_queue_handle_count = len(futures.UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP) + yield + self.assertEqual(len(futures.UNIFFI_POLL_DATA_POINTER_MANAGER), initial_poll_handle_count) + self.assertEqual(len(futures.UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP), initial_blocking_task_queue_handle_count) + if __name__ == '__main__': unittest.main() diff --git a/fixtures/futures/tests/bindings/test_futures.swift b/fixtures/futures/tests/bindings/test_futures.swift index 20e24c40ff..836a8f5a7b 100644 --- a/fixtures/futures/tests/bindings/test_futures.swift +++ b/fixtures/futures/tests/bindings/test_futures.swift @@ -3,10 +3,21 @@ import Foundation // To get `DispatchGroup` and `Date` types. var counter = DispatchGroup() -// Test `alwaysReady` -counter.enter() +func asyncTest(test: @escaping () async throws -> ()) { + let initialBlockingTaskQueueCount = uniffiBlockingTaskQueueHandleCount() + let initialPollDataHandleCount = uniffiPollDataHandleCount() + counter.enter() + Task { + try! await test() + counter.leave() + } + counter.wait() + assert(uniffiBlockingTaskQueueHandleCount() == initialBlockingTaskQueueCount) + assert(uniffiPollDataHandleCount() == initialPollDataHandleCount) +} -Task { +// Test `alwaysReady` +asyncTest { let t0 = Date() let result = await alwaysReady() let t1 = Date() @@ -14,40 +25,28 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration < 0.1) assert(result == true) - - counter.leave() } // Test record. -counter.enter() - -Task { +asyncTest { let result = await newMyRecord(a: "foo", b: 42) assert(result.a == "foo") assert(result.b == 42) - - counter.leave() } // Test `void` -counter.enter() - -Task { +asyncTest { let t0 = Date() await void() let t1 = Date() let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration < 0.1) - - counter.leave() } // Test `Sleep` -counter.enter() - -Task { +asyncTest { let t0 = Date() let result = await sleep(ms: 2000) let t1 = Date() @@ -55,14 +54,10 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 2 && tDelta.duration < 2.1) assert(result == true) - - counter.leave() } // Test sequential futures. -counter.enter() - -Task { +asyncTest { let t0 = Date() let result_alice = await sayAfter(ms: 1000, who: "Alice") let result_bob = await sayAfter(ms: 2000, who: "Bob") @@ -72,14 +67,10 @@ Task { assert(tDelta.duration > 3 && tDelta.duration < 3.1) assert(result_alice == "Hello, Alice!") assert(result_bob == "Hello, Bob!") - - counter.leave() } // Test concurrent futures. -counter.enter() - -Task { +asyncTest { async let alice = sayAfter(ms: 1000, who: "Alice") async let bob = sayAfter(ms: 2000, who: "Bob") @@ -91,14 +82,10 @@ Task { assert(tDelta.duration > 2 && tDelta.duration < 2.1) assert(result_alice == "Hello, Alice!") assert(result_bob == "Hello, Bob!") - - counter.leave() } // Test async methods -counter.enter() - -Task { +asyncTest { let megaphone = newMegaphone() let t0 = Date() @@ -108,26 +95,18 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 2 && tDelta.duration < 2.1) assert(result_alice == "HELLO, ALICE!") - - counter.leave() } // Test async function returning an object -counter.enter() - -Task { +asyncTest { let megaphone = await asyncNewMegaphone() let result = try await megaphone.fallibleMe(doFail: false) assert(result == 42) - - counter.leave() } // Test with the Tokio runtime. -counter.enter() - -Task { +asyncTest { let t0 = Date() let result_alice = await sayAfterWithTokio(ms: 2000, who: "Alice") let t1 = Date() @@ -135,15 +114,11 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 2 && tDelta.duration < 2.1) assert(result_alice == "Hello, Alice (with Tokio)!") - - counter.leave() } // Test fallible function/method… // … which doesn't throw. -counter.enter() - -Task { +asyncTest { let t0 = Date() let result = try await fallibleMe(doFail: false) let t1 = Date() @@ -151,19 +126,15 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 0 && tDelta.duration < 0.1) assert(result == 42) - - counter.leave() } -Task { +asyncTest { let m = try await fallibleStruct(doFail: false) let result = try await m.fallibleMe(doFail: false) assert(result == 42) } -counter.enter() - -Task { +asyncTest { let megaphone = newMegaphone() let t0 = Date() @@ -173,14 +144,10 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 0 && tDelta.duration < 0.1) assert(result == 42) - - counter.leave() } // … which does throw. -counter.enter() - -Task { +asyncTest { let t0 = Date() do { @@ -195,11 +162,9 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 0 && tDelta.duration < 0.1) - - counter.leave() } -Task { +asyncTest { do { let _ = try await fallibleStruct(doFail: true) } catch MyError.Foo { @@ -209,9 +174,7 @@ Task { } } -counter.enter() - -Task { +asyncTest { let megaphone = newMegaphone() let t0 = Date() @@ -228,13 +191,10 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 0 && tDelta.duration < 0.1) - - counter.leave() } // Test a future that uses a lock and that is cancelled. -counter.enter() -Task { +asyncTest { let task = Task { try! await useSharedResource(options: SharedResourceOptions(releaseAfterMs: 100, timeoutMs: 1000)) } @@ -250,15 +210,36 @@ Task { // Try accessing the shared resource again. The initial task should release the shared resource // before the timeout expires. try! await useSharedResource(options: SharedResourceOptions(releaseAfterMs: 0, timeoutMs: 1000)) - counter.leave() } // Test a future that uses a lock and that is not cancelled. -counter.enter() -Task { +asyncTest { try! await useSharedResource(options: SharedResourceOptions(releaseAfterMs: 100, timeoutMs: 1000)) try! await useSharedResource(options: SharedResourceOptions(releaseAfterMs: 0, timeoutMs: 1000)) - counter.leave() } -counter.wait() +// Test blocking task queues +asyncTest { + let calcSquareResult = await calcSquare(queue: DispatchQueue.global(qos: .userInitiated), value: 20) + assert(calcSquareResult == 400) + + let calcSquaresResult = await calcSquares(queue: DispatchQueue.global(qos: .userInitiated), items: [1, -2, 3]) + assert(calcSquaresResult == [1, 4, 9]) + + let calcSquaresMultiQueueResult = await calcSquaresMultiQueue( + queues: [ + DispatchQueue(label: "test-queue1", attributes: DispatchQueue.Attributes.concurrent), + DispatchQueue(label: "test-queue2", attributes: DispatchQueue.Attributes.concurrent), + DispatchQueue(label: "test-queue3", attributes: DispatchQueue.Attributes.concurrent) + ], + items: [1, -2, 3] + ) + assert(calcSquaresMultiQueueResult == [1, 4, 9]) +} + +// Test blocking task queue cloning +asyncTest { + let calcSquareResult = await calcSquareWithClone(queue: DispatchQueue.global(qos: .userInitiated), value: 20) + assert(calcSquareResult == 400) +} + diff --git a/fixtures/metadata/src/tests.rs b/fixtures/metadata/src/tests.rs index f4d8ae244f..cc21f14a62 100644 --- a/fixtures/metadata/src/tests.rs +++ b/fixtures/metadata/src/tests.rs @@ -123,6 +123,7 @@ mod test_type_ids { check_type_id::(Type::Float64); check_type_id::(Type::Boolean); check_type_id::(Type::String); + check_type_id::(Type::BlockingTaskQueue); } #[test] diff --git a/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/blocking_task_queue.rs b/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/blocking_task_queue.rs new file mode 100644 index 0000000000..a664f1fcd3 --- /dev/null +++ b/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/blocking_task_queue.rs @@ -0,0 +1,19 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use super::CodeType; + +#[derive(Debug)] +pub struct BlockingTaskQueueCodeType; + +impl CodeType for BlockingTaskQueueCodeType { + fn type_label(&self, _ci: &crate::ComponentInterface) -> String { + // Kotlin uses CoroutineContext for BlockingTaskQueue + "CoroutineContext".into() + } + + fn canonical_name(&self) -> String { + "BlockingTaskQueue".into() + } +} diff --git a/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs b/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs index 9b542adf85..cac4accbdf 100644 --- a/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs +++ b/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs @@ -16,6 +16,7 @@ use crate::backend::TemplateExpression; use crate::interface::*; use crate::BindingsConfig; +mod blocking_task_queue; mod callback_interface; mod compounds; mod custom; @@ -426,6 +427,8 @@ impl AsCodeType for T { Type::Timestamp => Box::new(miscellany::TimestampCodeType), Type::Duration => Box::new(miscellany::DurationCodeType), + Type::BlockingTaskQueue => Box::new(blocking_task_queue::BlockingTaskQueueCodeType), + Type::Enum { name, .. } => Box::new(enum_::EnumCodeType::new(name)), Type::Object { name, imp, .. } => Box::new(object::ObjectCodeType::new(name, imp)), Type::Record { name, .. } => Box::new(record::RecordCodeType::new(name)), @@ -561,7 +564,7 @@ mod filters { ) -> Result { let ffi_func = callable.ffi_rust_future_poll(ci); Ok(format!( - "{{ future, callback, continuation -> _UniFFILib.INSTANCE.{ffi_func}(future, callback, continuation) }}" + "{{ future, callback, continuation, blockingTaskQueueHandle -> _UniFFILib.INSTANCE.{ffi_func}(future, callback, continuation, blockingTaskQueueHandle) }}" )) } diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt b/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt index 4f4ece37f7..cf2c48f1fb 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt @@ -3,18 +3,37 @@ internal const val UNIFFI_RUST_FUTURE_POLL_READY = 0.toByte() internal const val UNIFFI_RUST_FUTURE_POLL_MAYBE_READY = 1.toByte() -internal val uniffiContinuationHandleMap = UniFfiHandleMap>() +/** + * Data for an in-progress poll of a RustFuture + */ +internal data class UniffiPollData( + val continuation: CancellableContinuation, + val rustFuture: Pointer, + val pollFunc: (Pointer, UniffiRustFutureContinuationCallback, USize, Long) -> Unit, +) + +internal val uniffiPollDataHandleMap = UniFfiHandleMap() // FFI type for Rust future continuations internal object uniffiRustFutureContinuationCallback: UniffiRustFutureContinuationCallback { - override fun callback(data: USize, pollResult: Byte) { - uniffiContinuationHandleMap.remove(data)?.resume(pollResult) + override fun callback(data: USize, pollResult: Byte, blockingTaskQueueHandle: Long) { + if (blockingTaskQueueHandle == 0L) { + // Complete the Kotlin continuation + uniffiPollDataHandleMap.remove(data)!!.continuation.resume(pollResult) + } else { + // Call the poll function again, but inside the BlockingTaskQueue coroutine context + val coroutineContext = uniffiBlockingTaskQueueHandleMap.get(blockingTaskQueueHandle) + val pollData = uniffiPollDataHandleMap.get(data)!! + CoroutineScope(coroutineContext).launch { + pollData.pollFunc(pollData.rustFuture, uniffiRustFutureContinuationCallback, data, blockingTaskQueueHandle) + } + } } } internal suspend fun uniffiRustCallAsync( rustFuture: Pointer, - pollFunc: (Pointer, UniffiRustFutureContinuationCallback, USize) -> Unit, + pollFunc: (Pointer, UniffiRustFutureContinuationCallback, USize, Long) -> Unit, completeFunc: (Pointer, RustCallStatus) -> F, freeFunc: (Pointer) -> Unit, liftFunc: (F) -> T, @@ -23,11 +42,9 @@ internal suspend fun uniffiRustCallAsync( try { do { val pollResult = suspendCancellableCoroutine { continuation -> - pollFunc( - rustFuture, - uniffiRustFutureContinuationCallback, - uniffiContinuationHandleMap.insert(continuation) - ) + val pollData = UniffiPollData(continuation, rustFuture, pollFunc) + val pollDataHandle = uniffiPollDataHandleMap.insert(pollData) + pollFunc(rustFuture, uniffiRustFutureContinuationCallback, pollDataHandle, 0L) } } while (pollResult != UNIFFI_RUST_FUTURE_POLL_READY); @@ -39,3 +56,6 @@ internal suspend fun uniffiRustCallAsync( } } +// For testing +public fun uniffiPollHandleCount() = uniffiPollDataHandleMap.size + diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/BlockingTaskQueueTemplate.kt b/uniffi_bindgen/src/bindings/kotlin/templates/BlockingTaskQueueTemplate.kt new file mode 100644 index 0000000000..717ff62df3 --- /dev/null +++ b/uniffi_bindgen/src/bindings/kotlin/templates/BlockingTaskQueueTemplate.kt @@ -0,0 +1,44 @@ +{{ self.add_import("kotlin.coroutines.CoroutineContext") }} + +object uniffiBlockingTaskQueueClone : UniffiBlockingTaskQueueClone { + override fun callback(handle: Long): Long { + val coroutineContext = uniffiBlockingTaskQueueHandleMap.get(handle) + return uniffiBlockingTaskQueueHandleMap.insert(coroutineContext) + } +} + +object uniffiBlockingTaskQueueFree : UniffiBlockingTaskQueueFree { + override fun callback(handle: Long) { + uniffiBlockingTaskQueueHandleMap.remove(handle) + } +} + +internal val uniffiBlockingTaskQueueVTable = UniffiBlockingTaskQueueVTable( + uniffiBlockingTaskQueueClone, + uniffiBlockingTaskQueueFree, +) +internal val uniffiBlockingTaskQueueHandleMap = ConcurrentHandleMap() + +public object {{ ffi_converter_name }}: FfiConverterRustBuffer { + override fun allocationSize(value: {{ type_name }}) = 16 + + override fun write(value: CoroutineContext, buf: ByteBuffer) { + // Call `write()` to make sure the data is written to the JNA backing data + uniffiBlockingTaskQueueVTable.write() + val handle = uniffiBlockingTaskQueueHandleMap.insert(value) + buf.putLong(handle) + buf.putLong(Pointer.nativeValue(uniffiBlockingTaskQueueVTable.getPointer())) + } + + override fun read(buf: ByteBuffer): CoroutineContext { + val handle = buf.getLong() + val coroutineContext = uniffiBlockingTaskQueueHandleMap.remove(handle)!! + // Read the VTable pointer and throw it out. The vtable is only used by Rust and always the + // same value. + buf.getLong() + return coroutineContext + } +} + +// For testing +public fun uniffiBlockingTaskQueueHandleCount() = uniffiBlockingTaskQueueHandleMap.size diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/CallbackInterfaceRuntime.kt b/uniffi_bindgen/src/bindings/kotlin/templates/CallbackInterfaceRuntime.kt index 5f4d12ddf2..b485779adc 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/CallbackInterfaceRuntime.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/CallbackInterfaceRuntime.kt @@ -1,37 +1,3 @@ -{{- self.add_import("java.util.concurrent.atomic.AtomicLong") }} -{{- self.add_import("java.util.concurrent.locks.ReentrantLock") }} -{{- self.add_import("kotlin.concurrent.withLock") }} - -internal typealias Handle = Long -internal class ConcurrentHandleMap( - private val leftMap: MutableMap = mutableMapOf(), -) { - private val lock = java.util.concurrent.locks.ReentrantLock() - private val currentHandle = AtomicLong(0L) - private val stride = 1L - - fun insert(obj: T): Handle = - lock.withLock { - currentHandle.getAndAdd(stride) - .also { handle -> - leftMap[handle] = obj - } - } - - fun get(handle: Handle) = lock.withLock { - leftMap[handle] ?: throw InternalException("No callback in handlemap; this is a Uniffi bug") - } - - fun delete(handle: Handle) { - this.remove(handle) - } - - fun remove(handle: Handle): T? = - lock.withLock { - leftMap.remove(handle) - } -} - // Magic number for the Rust proxy to call using the same mechanism as every other method, // to free the callback once it's dropped by Rust. internal const val IDX_CALLBACK_FREE = 0 diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/Helpers.kt b/uniffi_bindgen/src/bindings/kotlin/templates/Helpers.kt index c623c37734..b5f95f7779 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/Helpers.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/Helpers.kt @@ -189,3 +189,41 @@ internal class UniFfiHandleMap { return map.remove(handle) } } + +internal typealias Handle = Long +// Like UniFfiHandleMap, but allocates u64 values rather than usize ones. +// +// FIXME: consolidate this code with `UniFfiHandleMap` +// https://github.com/mozilla/uniffi-rs/pull/1823 +internal class ConcurrentHandleMap( + private val leftMap: MutableMap = mutableMapOf(), +) { + private val lock = java.util.concurrent.locks.ReentrantLock() + // Start with 1 so that 0 can be special-cased as the null value. + private val currentHandle = AtomicLong(1L) + private val stride = 1L + + val size: Int + get() = leftMap.size + + fun insert(obj: T): Handle = + lock.withLock { + currentHandle.getAndAdd(stride) + .also { handle -> + leftMap[handle] = obj + } + } + + fun get(handle: Handle) = lock.withLock { + leftMap[handle] ?: throw InternalException("No callback in handlemap; this is a Uniffi bug") + } + + fun delete(handle: Handle) { + this.remove(handle) + } + + fun remove(handle: Handle): T? = + lock.withLock { + leftMap.remove(handle) + } +} diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/ObjectRuntime.kt b/uniffi_bindgen/src/bindings/kotlin/templates/ObjectRuntime.kt index fa22b6ad2b..ed2d6c8f71 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/ObjectRuntime.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/ObjectRuntime.kt @@ -1,4 +1,3 @@ -{{- self.add_import("java.util.concurrent.atomic.AtomicLong") }} {{- self.add_import("java.util.concurrent.atomic.AtomicBoolean") }} // The base class for all UniFFI Object types. // diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/Types.kt b/uniffi_bindgen/src/bindings/kotlin/templates/Types.kt index fb7cb4afde..f720ff6e41 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/Types.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/Types.kt @@ -86,6 +86,9 @@ inline fun T.use(block: (T) -> R) = {%- when Type::Bytes %} {%- include "ByteArrayHelper.kt" %} +{%- when Type::BlockingTaskQueue %} +{%- include "BlockingTaskQueueTemplate.kt" %} + {%- when Type::Enum { name, module_path } %} {%- let e = ci.get_enum_definition(name).unwrap() %} {%- if !ci.is_name_used_as_error(name) %} @@ -131,6 +134,8 @@ inline fun T.use(block: (T) -> R) = {%- if ci.has_async_fns() %} {# Import types needed for async support #} {{ self.add_import("kotlin.coroutines.resume") }} +{{ self.add_import("kotlinx.coroutines.launch") }} {{ self.add_import("kotlinx.coroutines.suspendCancellableCoroutine") }} {{ self.add_import("kotlinx.coroutines.CancellableContinuation") }} +{{ self.add_import("kotlinx.coroutines.CoroutineScope") }} {%- endif %} diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/wrapper.kt b/uniffi_bindgen/src/bindings/kotlin/templates/wrapper.kt index 0f924e4d63..438254f4ae 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/wrapper.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/wrapper.kt @@ -30,7 +30,10 @@ import java.nio.ByteBuffer import java.nio.ByteOrder import java.nio.CharBuffer import java.nio.charset.CodingErrorAction +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.locks.ReentrantLock import java.util.concurrent.ConcurrentHashMap +import kotlin.concurrent.withLock {%- for req in self.imports() %} {{ req.render() }} diff --git a/uniffi_bindgen/src/bindings/python/gen_python/blocking_task_queue.rs b/uniffi_bindgen/src/bindings/python/gen_python/blocking_task_queue.rs new file mode 100644 index 0000000000..2f5a74fda0 --- /dev/null +++ b/uniffi_bindgen/src/bindings/python/gen_python/blocking_task_queue.rs @@ -0,0 +1,19 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use super::CodeType; + +#[derive(Debug)] +pub struct BlockingTaskQueueCodeType; + +impl CodeType for BlockingTaskQueueCodeType { + // On python we use an concurrent.futures.Executor for a BlockingTaskQueue + fn type_label(&self) -> String { + "concurrent.futures.Executor".into() + } + + fn canonical_name(&self) -> String { + "BlockingTaskQueue".into() + } +} diff --git a/uniffi_bindgen/src/bindings/python/gen_python/mod.rs b/uniffi_bindgen/src/bindings/python/gen_python/mod.rs index 2297627043..dffef2599a 100644 --- a/uniffi_bindgen/src/bindings/python/gen_python/mod.rs +++ b/uniffi_bindgen/src/bindings/python/gen_python/mod.rs @@ -16,6 +16,7 @@ use crate::backend::TemplateExpression; use crate::interface::*; use crate::BindingsConfig; +mod blocking_task_queue; mod callback_interface; mod compounds; mod custom; @@ -423,6 +424,8 @@ impl AsCodeType for T { Type::Timestamp => Box::new(miscellany::TimestampCodeType), Type::Duration => Box::new(miscellany::DurationCodeType), + Type::BlockingTaskQueue => Box::new(blocking_task_queue::BlockingTaskQueueCodeType), + Type::Enum { name, .. } => Box::new(enum_::EnumCodeType::new(name)), Type::Object { name, .. } => Box::new(object::ObjectCodeType::new(name)), Type::Record { name, .. } => Box::new(record::RecordCodeType::new(name)), diff --git a/uniffi_bindgen/src/bindings/python/templates/Async.py b/uniffi_bindgen/src/bindings/python/templates/Async.py index aae87ab98b..b31b421485 100644 --- a/uniffi_bindgen/src/bindings/python/templates/Async.py +++ b/uniffi_bindgen/src/bindings/python/templates/Async.py @@ -2,33 +2,63 @@ _UNIFFI_RUST_FUTURE_POLL_READY = 0 _UNIFFI_RUST_FUTURE_POLL_MAYBE_READY = 1 +""" +Data for an in-progress poll of a RustFuture +""" +class UniffiPoll(typing.NamedTuple): + eventloop: asyncio.AbstractEventLoop + py_future: asyncio.Future + rust_future: int + poll_fn: UNIFFI_RUST_FUTURE_CONTINUATION_CALLBACK + # Stores futures for _uniffi_continuation_callback -_UniffiContinuationPointerManager = _UniffiPointerManager() +UNIFFI_POLL_DATA_POINTER_MANAGER = _UniffiPointerManager() # Continuation callback for async functions # lift the return value or error and resolve the future, causing the async function to resume. @UNIFFI_RUST_FUTURE_CONTINUATION_CALLBACK -def _uniffi_continuation_callback(future_ptr, poll_code): - (eventloop, future) = _UniffiContinuationPointerManager.release_pointer(future_ptr) - eventloop.call_soon_threadsafe(_uniffi_set_future_result, future, poll_code) +def _uniffi_continuation_callback(poll_data_ptr, poll_code, blocking_task_queue_handle): + if blocking_task_queue_handle == 0: + # Complete the Python Future + poll_data = UNIFFI_POLL_DATA_POINTER_MANAGER.release_pointer(poll_data_ptr) + poll_data.eventloop.call_soon_threadsafe(_uniffi_set_future_result, poll_data.py_future, poll_code) + else: + # Call the poll function again, but inside the executor + poll_data = UNIFFI_POLL_DATA_POINTER_MANAGER.lookup(poll_data_ptr) + executor = UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.get(blocking_task_queue_handle) + executor.submit( + poll_data.poll_fn, + poll_data.rust_future, + _uniffi_continuation_callback, + poll_data_ptr, + blocking_task_queue_handle + ) def _uniffi_set_future_result(future, poll_code): if not future.cancelled(): future.set_result(poll_code) -async def _uniffi_rust_call_async(rust_future, ffi_poll, ffi_complete, ffi_free, lift_func, error_ffi_converter): +async def _uniffi_rust_call_async(rust_future, poll_fn, ffi_complete, ffi_free, lift_func, error_ffi_converter): try: eventloop = asyncio.get_running_loop() - # Loop and poll until we see a _UNIFFI_RUST_FUTURE_POLL_READY value + # Loop and poll until we see a UNIFFI_RUST_FUTURE_POLL_READY value while True: - future = eventloop.create_future() - ffi_poll( + py_future = eventloop.create_future() + poll_data = UniffiPoll( + eventloop=eventloop, + py_future=py_future, + rust_future=rust_future, + poll_fn=poll_fn, + ) + poll_handle = UNIFFI_POLL_DATA_POINTER_MANAGER.new_pointer(poll_data) + poll_fn( rust_future, _uniffi_continuation_callback, - _UniffiContinuationPointerManager.new_pointer((eventloop, future)), + poll_handle, + 0, ) - poll_code = await future + poll_code = await py_future if poll_code == _UNIFFI_RUST_FUTURE_POLL_READY: break diff --git a/uniffi_bindgen/src/bindings/python/templates/BlockingTaskQueueTemplate.py b/uniffi_bindgen/src/bindings/python/templates/BlockingTaskQueueTemplate.py new file mode 100644 index 0000000000..329ff37906 --- /dev/null +++ b/uniffi_bindgen/src/bindings/python/templates/BlockingTaskQueueTemplate.py @@ -0,0 +1,39 @@ +{{ self.add_import("concurrent.futures") }} + +@UNIFFI_BLOCKING_TASK_QUEUE_CLONE +def uniffi_blocking_task_queue_clone(handle): + executor = UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.get(handle) + return UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.insert(executor) + +@UNIFFI_BLOCKING_TASK_QUEUE_FREE +def uniffi_blocking_task_queue_free(handle): + UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.remove(handle) + +UNIFFI_BLOCKING_TASK_QUEUE_VTABLE = UniffiBlockingTaskQueueVTable( + uniffi_blocking_task_queue_clone, + uniffi_blocking_task_queue_free, +) + +UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP = ConcurrentHandleMap() + +class {{ ffi_converter_name }}(_UniffiConverterRustBuffer): + @staticmethod + def check_lower(value): + if not isinstance(value, concurrent.futures.Executor): + raise TypeError("Expected concurrent.futures.Executor instance, {} found".format(type(value).__name__)) + + @staticmethod + def write(value, buf): + handle = UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.insert(value) + buf.write_u64(handle) + buf.write_u64(ctypes.addressof(UNIFFI_BLOCKING_TASK_QUEUE_VTABLE)) + + @staticmethod + def read(buf): + handle = buf.read_u64() + executor = UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.remove(handle) + # Read the VTable pointer and throw it out. The vtable is only used by Rust and always the + # same value. + buf.read_u64() + return executor + diff --git a/uniffi_bindgen/src/bindings/python/templates/CallbackInterfaceRuntime.py b/uniffi_bindgen/src/bindings/python/templates/CallbackInterfaceRuntime.py index 1337d9685f..7862513db5 100644 --- a/uniffi_bindgen/src/bindings/python/templates/CallbackInterfaceRuntime.py +++ b/uniffi_bindgen/src/bindings/python/templates/CallbackInterfaceRuntime.py @@ -1,38 +1,3 @@ -import threading - -class ConcurrentHandleMap: - """ - A map where inserting, getting and removing data is synchronized with a lock. - """ - - def __init__(self): - # type Handle = int - self._left_map = {} # type: Dict[Handle, Any] - - self._lock = threading.Lock() - self._current_handle = 0 - self._stride = 1 - - def insert(self, obj): - with self._lock: - handle = self._current_handle - self._current_handle += self._stride - self._left_map[handle] = obj - return handle - - def get(self, handle): - with self._lock: - obj = self._left_map.get(handle) - if not obj: - raise InternalError("No callback in handlemap; this is a uniffi bug") - return obj - - def remove(self, handle): - with self._lock: - if handle in self._left_map: - obj = self._left_map.pop(handle) - return obj - # Magic number for the Rust proxy to call using the same mechanism as every other method, # to free the callback once it's dropped by Rust. IDX_CALLBACK_FREE = 0 diff --git a/uniffi_bindgen/src/bindings/python/templates/Helpers.py b/uniffi_bindgen/src/bindings/python/templates/Helpers.py index b4dad8da12..41f770b546 100644 --- a/uniffi_bindgen/src/bindings/python/templates/Helpers.py +++ b/uniffi_bindgen/src/bindings/python/templates/Helpers.py @@ -84,3 +84,43 @@ def _uniffi_trait_interface_call_with_error(call_status, make_call, write_return except Exception as e: call_status.code = _UniffiRustCallStatus.CALL_UNEXPECTED_ERROR call_status.error_buf = {{ Type::String.borrow()|lower_fn }}(repr(e)) + +class ConcurrentHandleMap: + """ + A map where inserting, getting and removing data is synchronized with a lock. + + TODO: consolidate this with `PointerManager` + https://github.com/mozilla/uniffi-rs/pull/1823 + """ + + def __init__(self): + # type Handle = int + self._left_map = {} # type: Dict[Handle, Any] + + self._lock = threading.Lock() + # Start with 1 so that 0 can be special-cased as the null value. + self._current_handle = 1 + self._stride = 1 + + def insert(self, obj): + with self._lock: + handle = self._current_handle + self._current_handle += self._stride + self._left_map[handle] = obj + return handle + + def get(self, handle): + with self._lock: + obj = self._left_map.get(handle) + if not obj: + raise InternalError("No callback in handlemap; this is a uniffi bug") + return obj + + def remove(self, handle): + with self._lock: + if handle in self._left_map: + obj = self._left_map.pop(handle) + return obj + + def __len__(self): + return len(self._left_map) diff --git a/uniffi_bindgen/src/bindings/python/templates/PointerManager.py b/uniffi_bindgen/src/bindings/python/templates/PointerManager.py index 23aa28eab4..526afc563a 100644 --- a/uniffi_bindgen/src/bindings/python/templates/PointerManager.py +++ b/uniffi_bindgen/src/bindings/python/templates/PointerManager.py @@ -5,6 +5,8 @@ class _UniffiPointerManagerCPython: This class is used to generate opaque pointers that reference Python objects to pass to Rust. It assumes a CPython platform. See _UniffiPointerManagerGeneral for the alternative. """ + def __init__(self): + self.count = 0 def new_pointer(self, obj): """ @@ -19,17 +21,22 @@ def new_pointer(self, obj): ctypes.pythonapi.Py_IncRef(ctypes.py_object(obj)) # id() is the object address on CPython # (https://docs.python.org/3/library/functions.html#id) + self.count += 1 return id(obj) def release_pointer(self, address): py_obj = ctypes.cast(address, ctypes.py_object) obj = py_obj.value ctypes.pythonapi.Py_DecRef(py_obj) + self.count -= 1 return obj def lookup(self, address): return ctypes.cast(address, ctypes.py_object).value + def __len__(self): + return self.count + class _UniffiPointerManagerGeneral: """ Manage giving out pointers to Python objects on non-CPython platforms @@ -61,6 +68,9 @@ def lookup(self, handle): with self._lock: return self._map[handle] + def __len__(self): + return len(self._map) + # Pick an pointer manager implementation based on the platform if platform.python_implementation() == 'CPython': _UniffiPointerManager = _UniffiPointerManagerCPython # type: ignore diff --git a/uniffi_bindgen/src/bindings/python/templates/Types.py b/uniffi_bindgen/src/bindings/python/templates/Types.py index 4aaed253e0..9f2840fefb 100644 --- a/uniffi_bindgen/src/bindings/python/templates/Types.py +++ b/uniffi_bindgen/src/bindings/python/templates/Types.py @@ -55,6 +55,9 @@ {%- when Type::Bytes %} {%- include "BytesHelper.py" %} +{%- when Type::BlockingTaskQueue %} +{%- include "BlockingTaskQueueTemplate.py" %} + {%- when Type::Enum { name, module_path } %} {%- let e = ci.get_enum_definition(name).unwrap() %} {# For enums, there are either an error *or* an enum, they can't be both. #} diff --git a/uniffi_bindgen/src/bindings/python/templates/wrapper.py b/uniffi_bindgen/src/bindings/python/templates/wrapper.py index 276ba868c3..7dd2f6a2ab 100644 --- a/uniffi_bindgen/src/bindings/python/templates/wrapper.py +++ b/uniffi_bindgen/src/bindings/python/templates/wrapper.py @@ -23,6 +23,7 @@ import contextlib import datetime import typing +import threading {%- if ci.has_async_fns() %} import asyncio {%- endif %} diff --git a/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs b/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs index c2c55e8018..78359b8d78 100644 --- a/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs +++ b/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs @@ -57,6 +57,7 @@ pub fn canonical_name(t: &Type) -> String { Type::CallbackInterface { name, .. } => format!("CallbackInterface{name}"), Type::Timestamp => "Timestamp".into(), Type::Duration => "Duration".into(), + Type::BlockingTaskQueue => "BlockingTaskQueue".into(), // Recursive types. // These add a prefix to the name of the underlying type. // The component API definition cannot give names to recursive types, so as long as the @@ -262,6 +263,7 @@ mod filters { } Type::External { .. } => panic!("No support for external types, yet"), Type::Custom { .. } => panic!("No support for custom types, yet"), + Type::BlockingTaskQueue => panic!("No support for async functions, yet"), }) } @@ -315,6 +317,7 @@ mod filters { ), Type::External { .. } => panic!("No support for lowering external types, yet"), Type::Custom { .. } => panic!("No support for lowering custom types, yet"), + Type::BlockingTaskQueue => panic!("No support for async functions, yet"), }) } @@ -355,6 +358,7 @@ mod filters { ), Type::External { .. } => panic!("No support for lifting external types, yet"), Type::Custom { .. } => panic!("No support for lifting custom types, yet"), + Type::BlockingTaskQueue => panic!("No support for async functions, yet"), }) } } diff --git a/uniffi_bindgen/src/bindings/swift/gen_swift/blocking_task_queue.rs b/uniffi_bindgen/src/bindings/swift/gen_swift/blocking_task_queue.rs new file mode 100644 index 0000000000..ab4df07f9b --- /dev/null +++ b/uniffi_bindgen/src/bindings/swift/gen_swift/blocking_task_queue.rs @@ -0,0 +1,19 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use super::CodeType; + +#[derive(Debug)] +pub struct BlockingTaskQueueCodeType; + +impl CodeType for BlockingTaskQueueCodeType { + fn type_label(&self) -> String { + // On Swift, we use a DispatchQueue for BlockingTaskQueue + "DispatchQueue".into() + } + + fn canonical_name(&self) -> String { + "BlockingTaskQueue".into() + } +} diff --git a/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs b/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs index bf55c6b9b8..4db5080eb1 100644 --- a/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs +++ b/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs @@ -18,6 +18,7 @@ use crate::backend::TemplateExpression; use crate::interface::*; use crate::BindingsConfig; +mod blocking_task_queue; mod callback_interface; mod compounds; mod custom; @@ -462,6 +463,8 @@ impl SwiftCodeOracle { Type::Timestamp => Box::new(miscellany::TimestampCodeType), Type::Duration => Box::new(miscellany::DurationCodeType), + Type::BlockingTaskQueue => Box::new(blocking_task_queue::BlockingTaskQueueCodeType), + Type::Enum { name, .. } => Box::new(enum_::EnumCodeType::new(name)), Type::Object { name, imp, .. } => Box::new(object::ObjectCodeType::new(name, imp)), Type::Record { name, .. } => Box::new(record::RecordCodeType::new(name)), diff --git a/uniffi_bindgen/src/bindings/swift/templates/Async.swift b/uniffi_bindgen/src/bindings/swift/templates/Async.swift index 578ff4ddf4..8fb37fc88f 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/Async.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/Async.swift @@ -1,9 +1,26 @@ private let UNIFFI_RUST_FUTURE_POLL_READY: Int8 = 0 private let UNIFFI_RUST_FUTURE_POLL_MAYBE_READY: Int8 = 1 +// Data for an in-progress poll of a RustFuture +fileprivate class UniffiPollData { + let continuation: UnsafeContinuation + let rustFuture: UnsafeMutableRawPointer + let pollFunc: (UnsafeMutableRawPointer, @escaping UniffiRustFutureContinuationCallback, UnsafeMutableRawPointer, UInt64) -> () + + init( + continuation: UnsafeContinuation, + rustFuture: UnsafeMutableRawPointer, + pollFunc: @escaping (UnsafeMutableRawPointer, @escaping UniffiRustFutureContinuationCallback, UnsafeMutableRawPointer, UInt64) -> () + ) { + self.continuation = continuation + self.rustFuture = rustFuture + self.pollFunc = pollFunc + } +} + fileprivate func uniffiRustCallAsync( rustFutureFunc: () -> UnsafeMutableRawPointer, - pollFunc: (UnsafeMutableRawPointer, @escaping UniffiRustFutureContinuationCallback, UnsafeMutableRawPointer) -> (), + pollFunc: @escaping (UnsafeMutableRawPointer, @escaping UniffiRustFutureContinuationCallback, UnsafeMutableRawPointer, UInt64) -> (), completeFunc: (UnsafeMutableRawPointer, UnsafeMutablePointer) -> F, freeFunc: (UnsafeMutableRawPointer) -> (), liftFunc: (F) throws -> T, @@ -19,7 +36,14 @@ fileprivate func uniffiRustCallAsync( var pollResult: Int8; repeat { pollResult = await withUnsafeContinuation { - pollFunc(rustFuture, uniffiFutureContinuationCallback, ContinuationHolder($0).toOpaque()) + let pollData = UniffiPollData( + continuation: $0, + rustFuture: rustFuture, + pollFunc: pollFunc + ) + let pollDataPtr = Unmanaged.passRetained(pollData).toOpaque() + UNIFFI_POLL_DATA_HANDLE_COUNT += 1 + pollFunc(rustFuture, uniffiFutureContinuationCallback, pollDataPtr, 0) } } while pollResult != UNIFFI_RUST_FUTURE_POLL_READY @@ -31,28 +55,28 @@ fileprivate func uniffiRustCallAsync( // Callback handlers for an async calls. These are invoked by Rust when the future is ready. They // lift the return value or error and resume the suspended function. -fileprivate func uniffiFutureContinuationCallback(ptr: UnsafeMutableRawPointer, pollResult: Int8) { - ContinuationHolder.fromOpaque(ptr).resume(pollResult) -} - -// Wraps UnsafeContinuation in a class so that we can use reference counting when passing it across -// the FFI -fileprivate class ContinuationHolder { - let continuation: UnsafeContinuation - - init(_ continuation: UnsafeContinuation) { - self.continuation = continuation - } - - func resume(_ pollResult: Int8) { - self.continuation.resume(returning: pollResult) - } - - func toOpaque() -> UnsafeMutableRawPointer { - return Unmanaged.passRetained(self).toOpaque() +fileprivate func uniffiFutureContinuationCallback( + pollDataPtr: UnsafeMutableRawPointer, + pollResult: Int8, + blockingTaskQueueHandle: UInt64 +) { + if (blockingTaskQueueHandle == 0) { + // Complete the Python Future + let pollData = Unmanaged.fromOpaque(pollDataPtr).takeRetainedValue() + UNIFFI_POLL_DATA_HANDLE_COUNT -= 1 + pollData.continuation.resume(returning: pollResult) + } else { + // Call the poll function again, but inside the DispatchQuee + let pollData = Unmanaged.fromOpaque(pollDataPtr).takeUnretainedValue() + let queue = UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.get(handle: blockingTaskQueueHandle)! + queue.async { + pollData.pollFunc(pollData.rustFuture, uniffiFutureContinuationCallback, pollDataPtr, blockingTaskQueueHandle) + } } +} - static func fromOpaque(_ ptr: UnsafeRawPointer) -> ContinuationHolder { - return Unmanaged.fromOpaque(ptr).takeRetainedValue() - } +// For testing +fileprivate var UNIFFI_POLL_DATA_HANDLE_COUNT: Int = 0 +public func uniffiPollDataHandleCount() -> Int { + return UNIFFI_POLL_DATA_HANDLE_COUNT } diff --git a/uniffi_bindgen/src/bindings/swift/templates/BlockingTaskQueueTemplate.swift b/uniffi_bindgen/src/bindings/swift/templates/BlockingTaskQueueTemplate.swift new file mode 100644 index 0000000000..1c81e72595 --- /dev/null +++ b/uniffi_bindgen/src/bindings/swift/templates/BlockingTaskQueueTemplate.swift @@ -0,0 +1,37 @@ +fileprivate var UNIFFI_BLOCKING_TASK_QUEUE_VTABLE = UniffiBlockingTaskQueueVTable( + clone: { (handle: UInt64) -> UInt64 in + let dispatchQueue = UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.get(handle: handle)! + return UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.insert(obj: dispatchQueue) + }, + free: { (handle: UInt64) in + UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.remove(handle: handle) + } +) +fileprivate var UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP = UniffiHandleMap() + +fileprivate struct {{ ffi_converter_name }}: FfiConverterRustBuffer { + typealias SwiftType = DispatchQueue + + public static func write(_ value: DispatchQueue, into buf: inout [UInt8]) { + let handle = UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.insert(obj: value) + writeInt(&buf, handle) + // From Apple: "You can safely use the address of a global variable as a persistent unique + // pointer value" (https://developer.apple.com/swift/blog/?id=6) + let vtablePointer = UnsafeMutablePointer(&UNIFFI_BLOCKING_TASK_QUEUE_VTABLE) + // Convert the pointer to a word-sized Int then to a 64-bit int then write it out. + writeInt(&buf, Int64(Int(bitPattern: vtablePointer))) + } + + public static func read(from buf: inout (data: Data, offset: Data.Index)) throws -> DispatchQueue { + let handle: UInt64 = try readInt(&buf) + // Read the VTable pointer and throw it out. The vtable is only used by Rust and always the + // same value. + let _: UInt64 = try readInt(&buf) + return UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.remove(handle: handle)! + } +} + +// For testing +public func uniffiBlockingTaskQueueHandleCount() -> Int { + UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.count +} diff --git a/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceRuntime.swift b/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceRuntime.swift index d03b7ccb3f..5863c2ad41 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceRuntime.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceRuntime.swift @@ -1,60 +1,3 @@ -fileprivate extension NSLock { - func withLock(f: () throws -> T) rethrows -> T { - self.lock() - defer { self.unlock() } - return try f() - } -} - -fileprivate typealias UniFFICallbackHandle = UInt64 -fileprivate class UniFFICallbackHandleMap { - private var leftMap: [UniFFICallbackHandle: T] = [:] - private var counter: [UniFFICallbackHandle: UInt64] = [:] - private var rightMap: [ObjectIdentifier: UniFFICallbackHandle] = [:] - - private let lock = NSLock() - private var currentHandle: UniFFICallbackHandle = 1 - private let stride: UniFFICallbackHandle = 1 - - func insert(obj: T) -> UniFFICallbackHandle { - lock.withLock { - let id = ObjectIdentifier(obj as AnyObject) - let handle = rightMap[id] ?? { - currentHandle += stride - let handle = currentHandle - leftMap[handle] = obj - rightMap[id] = handle - return handle - }() - counter[handle] = (counter[handle] ?? 0) + 1 - return handle - } - } - - func get(handle: UniFFICallbackHandle) -> T? { - lock.withLock { - leftMap[handle] - } - } - - func delete(handle: UniFFICallbackHandle) { - remove(handle: handle) - } - - @discardableResult - func remove(handle: UniFFICallbackHandle) -> T? { - lock.withLock { - defer { counter[handle] = (counter[handle] ?? 1) - 1 } - guard counter[handle] == 1 else { return leftMap[handle] } - let obj = leftMap.removeValue(forKey: handle) - if let obj = obj { - rightMap.removeValue(forKey: ObjectIdentifier(obj as AnyObject)) - } - return obj - } - } -} - // Magic number for the Rust proxy to call using the same mechanism as every other method, // to free the callback once it's dropped by Rust. private let IDX_CALLBACK_FREE: Int32 = 0 diff --git a/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceTemplate.swift b/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceTemplate.swift index 8cdd735b9a..f649b8b001 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceTemplate.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceTemplate.swift @@ -13,15 +13,15 @@ // FfiConverter protocol for callback interfaces fileprivate struct {{ ffi_converter_name }} { - fileprivate static var handleMap = UniFFICallbackHandleMap<{{ type_name }}>() + fileprivate static var handleMap = UniffiHandleMap<{{ type_name }}>() } extension {{ ffi_converter_name }} : FfiConverter { typealias SwiftType = {{ type_name }} // We can use Handle as the FfiType because it's a typealias to UInt64 - typealias FfiType = UniFFICallbackHandle + typealias FfiType = UInt64 - public static func lift(_ handle: UniFFICallbackHandle) throws -> SwiftType { + public static func lift(_ handle: UInt64) throws -> SwiftType { guard let callback = handleMap.get(handle: handle) else { throw UniffiInternalError.unexpectedStaleHandle } @@ -29,11 +29,11 @@ extension {{ ffi_converter_name }} : FfiConverter { } public static func read(from buf: inout (data: Data, offset: Data.Index)) throws -> SwiftType { - let handle: UniFFICallbackHandle = try readInt(&buf) + let handle: UInt64 = try readInt(&buf) return try lift(handle) } - public static func lower(_ v: SwiftType) -> UniFFICallbackHandle { + public static func lower(_ v: SwiftType) -> UInt64 { return handleMap.insert(obj: v) } diff --git a/uniffi_bindgen/src/bindings/swift/templates/Helpers.swift b/uniffi_bindgen/src/bindings/swift/templates/Helpers.swift index d233d4b762..3cc19914f3 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/Helpers.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/Helpers.swift @@ -44,6 +44,69 @@ fileprivate extension RustCallStatus { } } +fileprivate class UniffiHandleMap { + private var leftMap: [UInt64: T] = [:] + private var counter: [UInt64: UInt64] = [:] + private var rightMap: [ObjectIdentifier: UInt64] = [:] + + private let lock = NSLock() + // Start with 1 so that 0 can be special-cased as the null value. + private var currentHandle: UInt64 = 1 + private let stride: UInt64 = 1 + + func insert(obj: T) -> UInt64 { + lock.withLock { + let id = ObjectIdentifier(obj as AnyObject) + let handle = rightMap[id] ?? { + currentHandle += stride + let handle = currentHandle + leftMap[handle] = obj + rightMap[id] = handle + return handle + }() + counter[handle] = (counter[handle] ?? 0) + 1 + return handle + } + } + + func get(handle: UInt64) -> T? { + lock.withLock { + leftMap[handle] + } + } + + func delete(handle: UInt64) { + remove(handle: handle) + } + + @discardableResult + func remove(handle: UInt64) -> T? { + lock.withLock { + defer { counter[handle] = (counter[handle] ?? 1) - 1 } + guard counter[handle] == 1 else { return leftMap[handle] } + let obj = leftMap.removeValue(forKey: handle) + if let obj = obj { + rightMap.removeValue(forKey: ObjectIdentifier(obj as AnyObject)) + } + return obj + } + } + + var count: Int { + get { + leftMap.count + } + } +} + +fileprivate extension NSLock { + func withLock(f: () throws -> T) rethrows -> T { + self.lock() + defer { self.unlock() } + return try f() + } +} + private func rustCall(_ callback: (UnsafeMutablePointer) -> T) throws -> T { try makeRustCall(callback, errorHandler: nil) } diff --git a/uniffi_bindgen/src/bindings/swift/templates/ObjectTemplate.swift b/uniffi_bindgen/src/bindings/swift/templates/ObjectTemplate.swift index 553e60045f..b1eee0dfd0 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/ObjectTemplate.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/ObjectTemplate.swift @@ -154,7 +154,7 @@ public class {{ impl_class_name }}: public struct {{ ffi_converter_name }}: FfiConverter { {%- if obj.is_trait_interface() %} - fileprivate static var handleMap = UniFFICallbackHandleMap<{{ type_name }}>() + fileprivate static var handleMap = UniffiHandleMap<{{ type_name }}>() {%- endif %} typealias FfiType = UnsafeMutableRawPointer diff --git a/uniffi_bindgen/src/bindings/swift/templates/Types.swift b/uniffi_bindgen/src/bindings/swift/templates/Types.swift index 5e26758f3c..ba4d2059c8 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/Types.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/Types.swift @@ -64,6 +64,9 @@ {%- when Type::CallbackInterface { name, module_path } %} {%- include "CallbackInterfaceTemplate.swift" %} +{%- when Type::BlockingTaskQueue %} +{%- include "BlockingTaskQueueTemplate.swift" %} + {%- when Type::Custom { name, module_path, builtin } %} {%- include "CustomType.swift" %} diff --git a/uniffi_bindgen/src/interface/ffi.rs b/uniffi_bindgen/src/interface/ffi.rs index 5ecb582567..b7a5bd06d0 100644 --- a/uniffi_bindgen/src/interface/ffi.rs +++ b/uniffi_bindgen/src/interface/ffi.rs @@ -106,7 +106,8 @@ impl From<&Type> for FfiType { | Type::Sequence { .. } | Type::Map { .. } | Type::Timestamp - | Type::Duration => FfiType::RustBuffer(None), + | Type::Duration + | Type::BlockingTaskQueue => FfiType::RustBuffer(None), Type::External { name, kind: ExternalKind::Interface, diff --git a/uniffi_bindgen/src/interface/mod.rs b/uniffi_bindgen/src/interface/mod.rs index 81090ec37d..91febfeb50 100644 --- a/uniffi_bindgen/src/interface/mod.rs +++ b/uniffi_bindgen/src/interface/mod.rs @@ -67,7 +67,7 @@ mod record; pub use record::{Field, Record}; pub mod ffi; -pub use ffi::{FfiArgument, FfiCallbackFunction, FfiFunction, FfiStruct, FfiType}; +pub use ffi::{FfiArgument, FfiCallbackFunction, FfiField, FfiFunction, FfiStruct, FfiType}; pub use uniffi_meta::Radix; use uniffi_meta::{ ConstructorMetadata, LiteralMetadata, NamespaceMetadata, ObjectMetadata, TraitMethodMetadata, @@ -232,6 +232,7 @@ impl ComponentInterface { arguments: vec![ FfiArgument::new("data", FfiType::RustFutureContinuationData), FfiArgument::new("poll_result", FfiType::Int8), + FfiArgument::new("blocking_task_queue_handle", FfiType::UInt64), ], return_type: None, has_rust_call_status_arg: false, @@ -242,6 +243,18 @@ impl ComponentInterface { return_type: None, has_rust_call_status_arg: false, }, + FfiCallbackFunction { + name: "BlockingTaskQueueClone".to_owned(), + arguments: vec![FfiArgument::new("handle", FfiType::UInt64)], + return_type: Some(FfiType::UInt64), + has_rust_call_status_arg: false, + }, + FfiCallbackFunction { + name: "BlockingTaskQueueFree".to_owned(), + arguments: vec![FfiArgument::new("handle", FfiType::UInt64)], + return_type: None, + has_rust_call_status_arg: false, + }, ] } @@ -249,6 +262,30 @@ impl ComponentInterface { /// /// These are defined by the foreign code and invoked by Rust. pub fn ffi_struct_definitions(&self) -> impl IntoIterator + '_ { + self.builtin_vtable_definitions() + .into_iter() + .chain(self.callback_interface_vtable_definitions()) + } + + pub fn builtin_vtable_definitions(&self) -> impl IntoIterator + '_ { + [FfiStruct { + name: "BlockingTaskQueueVTable".to_owned(), + fields: vec![ + FfiField::new( + "clone", + FfiType::Callback("BlockingTaskQueueClone".to_owned()), + ), + FfiField::new( + "free", + FfiType::Callback("BlockingTaskQueueFree".to_owned()), + ), + ], + }] + } + + pub fn callback_interface_vtable_definitions( + &self, + ) -> impl IntoIterator + '_ { self.callback_interface_definitions() .iter() .map(|cbi| cbi.vtable_definition()) @@ -517,6 +554,10 @@ impl ComponentInterface { name: "callback_data".to_owned(), type_: FfiType::RustFutureContinuationData, }, + FfiArgument { + name: "blocking_task_queue_handle".to_owned(), + type_: FfiType::UInt64, + }, ], return_type: None, has_rust_call_status_arg: false, @@ -809,8 +850,12 @@ impl ComponentInterface { bail!("Conflicting type definition for \"{}\"", defn.name()); } self.types.add_known_types(defn.iter_types())?; - self.functions.push(defn); + if defn.is_async() { + self.types + .add_known_type(&uniffi_meta::Type::BlockingTaskQueue)?; + } + self.functions.push(defn); Ok(()) } diff --git a/uniffi_bindgen/src/interface/universe.rs b/uniffi_bindgen/src/interface/universe.rs index 70bc61f8a9..2faef72fd6 100644 --- a/uniffi_bindgen/src/interface/universe.rs +++ b/uniffi_bindgen/src/interface/universe.rs @@ -84,6 +84,7 @@ impl TypeUniverse { Type::Bytes => self.add_type_definition("bytes", type_)?, Type::Timestamp => self.add_type_definition("timestamp", type_)?, Type::Duration => self.add_type_definition("duration", type_)?, + Type::BlockingTaskQueue => self.add_type_definition("BlockingTaskQueue", type_)?, Type::Object { name, .. } | Type::Record { name, .. } | Type::Enum { name, .. } diff --git a/uniffi_bindgen/src/scaffolding/mod.rs b/uniffi_bindgen/src/scaffolding/mod.rs index 7fd81831aa..231e0495d3 100644 --- a/uniffi_bindgen/src/scaffolding/mod.rs +++ b/uniffi_bindgen/src/scaffolding/mod.rs @@ -45,6 +45,7 @@ mod filters { format!("std::sync::Arc<{}>", imp.rust_name_for(name)) } Type::CallbackInterface { name, .. } => format!("Box"), + Type::BlockingTaskQueue => "::uniffi::BlockingTaskQueue".to_owned(), Type::Optional { inner_type } => { format!("std::option::Option<{}>", type_rs(inner_type)?) } diff --git a/uniffi_core/src/ffi/rustfuture/blocking_task_queue.rs b/uniffi_core/src/ffi/rustfuture/blocking_task_queue.rs new file mode 100644 index 0000000000..50abe5a66f --- /dev/null +++ b/uniffi_core/src/ffi/rustfuture/blocking_task_queue.rs @@ -0,0 +1,69 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +//! Defines the BlockingTaskQueue struct +//! +//! This module is responsible for the general handling of BlockingTaskQueue instances (cloning, droping, etc). +//! See `scheduler.rs` and the foreign bindings code for how the async functionality is implemented. + +use super::scheduler::schedule_in_blocking_task_queue; +use std::num::NonZeroU64; + +/// Foreign-managed blocking task queue that we can use to schedule futures +/// +/// On the foreign side this is a Kotlin `CoroutineContext`, Python `Executor` or Swift +/// `DispatchQueue`. UniFFI converts those objects into this struct for the Rust code to use. +/// +/// Rust async code can call [BlockingTaskQueue::execute] to run a closure in that +/// blocking task queue. Use this for functions with blocking operations that should not be executed +/// in a normal async context. Some examples are non-async file/network operations, long-running +/// CPU-bound tasks, blocking database operations, etc. +#[repr(C)] +pub struct BlockingTaskQueue { + /// Opaque handle for the task queue + pub handle: NonZeroU64, + /// Method VTable + /// + /// This is simply a C struct where each field is a function pointer that inputs a + /// BlockingTaskQueue handle + pub vtable: &'static BlockingTaskQueueVTable, +} + +#[repr(C)] +#[derive(Debug)] +pub struct BlockingTaskQueueVTable { + clone: extern "C" fn(u64) -> u64, + drop: extern "C" fn(u64), +} + +// Note: see `scheduler.rs` for details on how BlockingTaskQueue is used. +impl BlockingTaskQueue { + /// Run a closure in a blocking task queue + pub async fn execute(&self, f: F) -> R + where + F: FnOnce() -> R, + { + schedule_in_blocking_task_queue(self.handle).await; + f() + } +} + +impl Clone for BlockingTaskQueue { + fn clone(&self) -> Self { + let raw_handle = (self.vtable.clone)(self.handle.into()); + let handle = raw_handle + .try_into() + .expect("BlockingTaskQueue.clone() returned 0"); + Self { + handle, + vtable: self.vtable, + } + } +} + +impl Drop for BlockingTaskQueue { + fn drop(&mut self) { + (self.vtable.drop)(self.handle.into()) + } +} diff --git a/uniffi_core/src/ffi/rustfuture/future.rs b/uniffi_core/src/ffi/rustfuture/future.rs index b104b20a32..b4eed8fe9d 100644 --- a/uniffi_core/src/ffi/rustfuture/future.rs +++ b/uniffi_core/src/ffi/rustfuture/future.rs @@ -21,6 +21,10 @@ //! 2b. If the async function is cancelled, then call [rust_future_cancel]. This causes the //! continuation function to be called with [RustFuturePoll::Ready] and the [RustFuture] to //! enter a cancelled state. +//! 2c. If the Rust code wants schedule work to be run in a `BlockingTaskQueue`, then the +//! continuation is called with [RustFuturePoll::MaybeReady] and the blocking task queue handle. +//! The foreign code is responsible for ensuring the next [rust_future_poll] call happens in +//! that blocking task queue and the handle is passed to [rust_future_poll]. //! 3. Call [rust_future_complete] to get the result of the future. //! 4. Call [rust_future_free] to free the future, ideally in a finally block. This: //! - Releases any resources held by the future @@ -78,6 +82,7 @@ use std::{ future::Future, marker::PhantomData, + num::NonZeroU64, ops::Deref, panic, pin::Pin, @@ -85,8 +90,8 @@ use std::{ task::{Context, Poll, Wake}, }; -use super::{RustFutureContinuationCallback, RustFuturePoll, Scheduler}; -use crate::{rust_call_with_out_status, FfiDefault, LowerReturn, RustCallStatus}; +use super::{scheduler, RustFutureContinuationCallback, Scheduler}; +use crate::{rust_call_with_out_status, FfiDefault, LowerReturn, RustCallStatus, RustFuturePoll}; /// Wraps the actual future we're polling struct WrappedFuture @@ -223,17 +228,26 @@ where }) } - pub(super) fn poll(self: Arc, callback: RustFutureContinuationCallback, data: *const ()) { + pub(super) fn poll( + self: Arc, + callback: RustFutureContinuationCallback, + data: *const (), + blocking_task_queue_handle: Option, + ) { + scheduler::on_poll_start(blocking_task_queue_handle); + // Clear out the waked flag, since we're about to poll right now. + self.scheduler.lock().unwrap().clear_wake_flag(); let ready = self.is_cancelled() || { let mut locked = self.future.lock().unwrap(); let waker: std::task::Waker = Arc::clone(&self).into(); locked.poll(&mut Context::from_waker(&waker)) }; if ready { - callback(data, RustFuturePoll::Ready) + callback(data, RustFuturePoll::Ready, 0) } else { self.scheduler.lock().unwrap().store(callback, data); } + scheduler::on_poll_end(); } pub(super) fn is_cancelled(&self) -> bool { @@ -289,7 +303,12 @@ where /// only create those functions for each of the 13 possible FFI return types. #[doc(hidden)] pub trait RustFutureFfi { - fn ffi_poll(self: Arc, callback: RustFutureContinuationCallback, data: *const ()); + fn ffi_poll( + self: Arc, + callback: RustFutureContinuationCallback, + data: *const (), + blocking_task_queue_handle: Option, + ); fn ffi_cancel(&self); fn ffi_complete(&self, call_status: &mut RustCallStatus) -> ReturnType; fn ffi_free(self: Arc); @@ -302,8 +321,13 @@ where T: LowerReturn + Send + 'static, UT: Send + 'static, { - fn ffi_poll(self: Arc, callback: RustFutureContinuationCallback, data: *const ()) { - self.poll(callback, data) + fn ffi_poll( + self: Arc, + callback: RustFutureContinuationCallback, + data: *const (), + blocking_task_queue_handle: Option, + ) { + self.poll(callback, data, blocking_task_queue_handle) } fn ffi_cancel(&self) { diff --git a/uniffi_core/src/ffi/rustfuture/mod.rs b/uniffi_core/src/ffi/rustfuture/mod.rs index 4aaf013fd5..a7a1142c7a 100644 --- a/uniffi_core/src/ffi/rustfuture/mod.rs +++ b/uniffi_core/src/ffi/rustfuture/mod.rs @@ -4,8 +4,11 @@ use std::{future::Future, sync::Arc}; +mod blocking_task_queue; mod future; mod scheduler; + +pub use blocking_task_queue::*; use future::*; use scheduler::*; @@ -28,10 +31,23 @@ pub enum RustFuturePoll { /// /// The Rust side of things calls this when the foreign side should call [rust_future_poll] again /// to continue progress on the future. -pub type RustFutureContinuationCallback = extern "C" fn(callback_data: *const (), RustFuturePoll); +/// +/// WARNING: the call to [rust_future_poll] must be scheduled to happen soon after the callback is +/// called, but not inside the callback itself. If [rust_future_poll] is called inside the +/// callback, some futures will deadlock and our scheduler code might as well. +/// +/// * `callback_data` is the handle that the foreign code passed to `poll()` +/// * `poll_result` is the result of the poll +/// * If `blocking_task_task_queue` is non-zero, it's the BlockingTaskQueue handle that the next `poll()` should run on +pub type RustFutureContinuationCallback = extern "C" fn( + callback_data: *const (), + poll_result: RustFuturePoll, + blocking_task_queue_handle: u64, +); /// Opaque handle for a Rust future that's stored by the foreign language code #[repr(transparent)] +#[derive(Debug)] pub struct RustFutureHandle(*const ()); // === Public FFI API === @@ -69,16 +85,22 @@ where /// a [RustFuturePoll] value. For each [rust_future_poll] call the continuation will be called /// exactly once. /// +/// If this is running in a BlockingTaskQueue, then `blocking_task_queue_handle` must be the handle +/// for it. If not, `blocking_task_queue_handle` must be `0`. +/// /// # Safety /// /// The [RustFutureHandle] must not previously have been passed to [rust_future_free] pub unsafe fn rust_future_poll( - handle: RustFutureHandle, + future: RustFutureHandle, callback: RustFutureContinuationCallback, data: *const (), + blocking_task_queue_handle: u64, ) { - let future = &*(handle.0 as *mut Arc>); - future.clone().ffi_poll(callback, data) + let future = &*(future.0 as *mut Arc>); + future + .clone() + .ffi_poll(callback, data, blocking_task_queue_handle.try_into().ok()) } /// Cancel a Rust future diff --git a/uniffi_core/src/ffi/rustfuture/scheduler.rs b/uniffi_core/src/ffi/rustfuture/scheduler.rs index aae5a0c1cf..4440cd7ba7 100644 --- a/uniffi_core/src/ffi/rustfuture/scheduler.rs +++ b/uniffi_core/src/ffi/rustfuture/scheduler.rs @@ -2,10 +2,89 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ -use std::mem; +use std::{cell::RefCell, future::poll_fn, mem, num::NonZeroU64, task::Poll, thread_local}; use super::{RustFutureContinuationCallback, RustFuturePoll}; +/// Context of the current `RustFuture::poll` call +struct RustFutureContext { + /// Blocking task queue that the future is being polled on + current_blocking_task_queue_handle: Option, + /// Blocking task queue that we've been asked to schedule the next poll on + scheduled_blocking_task_queue_handle: Option, +} + +thread_local! { + static CONTEXT: RefCell = RefCell::new(RustFutureContext { + current_blocking_task_queue_handle: None, + scheduled_blocking_task_queue_handle: None, + }); +} + +fn with_context R, R>(operation: F) -> R { + CONTEXT.with(|context| operation(&mut context.borrow_mut())) +} + +pub fn on_poll_start(current_blocking_task_queue_handle: Option) { + with_context(|context| { + *context = RustFutureContext { + current_blocking_task_queue_handle, + scheduled_blocking_task_queue_handle: None, + } + }); +} + +pub fn on_poll_end() { + with_context(|context| { + *context = RustFutureContext { + current_blocking_task_queue_handle: None, + scheduled_blocking_task_queue_handle: None, + } + }); +} + +/// Schedule work in a blocking task queue +/// +/// The returned future will attempt to arrange for [RustFuture::poll] to be called in the +/// blocking task queue. Once [RustFuture::poll] is running in the blocking task queue, then the future +/// will be ready. +/// +/// There's one tricky issue here: how can we ensure that when the top-level task is run in the +/// blocking task queue, this future will be polled? What happens this future is a child of `join!`, +/// `FuturesUnordered` or some other Future that handles its own polling? +/// +/// We start with an assumption: if we notify the waker then this future will be polled when the +/// top-level task is polled next. If a future does not honor this then we consider it a broken +/// future. This seems fair, since that future would almost certainly break a lot of other future +/// code. +/// +/// Based on that, we can have a simple system. When we're polled: +/// * If we're running in the blocking task queue, then we return `Poll::Ready`. +/// * If not, we return `Poll::Pending` and notify the waker so that the future polls again on +/// the next top-level poll. +/// +/// Note that this can be inefficient if the code awaits multiple blocking task queues at once. We +/// can only run the next poll on one of them, but all futures will be woken up. This seems okay +/// for our intended use cases, it would be pretty odd for a library to use multiple blocking task +/// queues. The alternative would be to store the set of all pending blocking task queues, which +/// seems like complete overkill for our purposes. +pub(super) async fn schedule_in_blocking_task_queue(handle: NonZeroU64) { + poll_fn(|future_context| { + with_context(|poll_context| { + if poll_context.current_blocking_task_queue_handle == Some(handle) { + Poll::Ready(()) + } else { + poll_context + .scheduled_blocking_task_queue_handle + .get_or_insert(handle); + future_context.waker().wake_by_ref(); + Poll::Pending + } + }) + }) + .await +} + /// Schedules a [crate::RustFuture] by managing the continuation data /// /// This struct manages the continuation callback and data that comes from the foreign side. It @@ -41,21 +120,34 @@ impl Scheduler { /// Store new continuation data if we are in the `Empty` state. If we are in the `Waked` or /// `Cancelled` state, call the continuation immediately with the data. pub(super) fn store(&mut self, callback: RustFutureContinuationCallback, data: *const ()) { + if let Some(blocking_task_queue_handle) = + with_context(|context| context.scheduled_blocking_task_queue_handle) + { + // We were asked to schedule the future in a blocking task queue, call the callback + // rather than storing it + callback( + data, + RustFuturePoll::MaybeReady, + blocking_task_queue_handle.into(), + ); + return; + } + match self { Self::Empty => *self = Self::Set(callback, data), Self::Set(old_callback, old_data) => { log::error!( "store: observed `Self::Set` state. Is poll() being called from multiple threads at once?" ); - old_callback(*old_data, RustFuturePoll::Ready); + old_callback(*old_data, RustFuturePoll::Ready, 0); *self = Self::Set(callback, data); } Self::Waked => { *self = Self::Empty; - callback(data, RustFuturePoll::MaybeReady); + callback(data, RustFuturePoll::MaybeReady, 0); } Self::Cancelled => { - callback(data, RustFuturePoll::Ready); + callback(data, RustFuturePoll::Ready, 0); } } } @@ -67,7 +159,7 @@ impl Scheduler { let old_data = *old_data; let callback = *callback; *self = Self::Empty; - callback(old_data, RustFuturePoll::MaybeReady); + callback(old_data, RustFuturePoll::MaybeReady, 0); } // If we were in the `Empty` state, then transition to `Waked`. The next time `store` // is called, we will immediately call the continuation. @@ -79,7 +171,13 @@ impl Scheduler { pub(super) fn cancel(&mut self) { if let Self::Set(callback, old_data) = mem::replace(self, Self::Cancelled) { - callback(old_data, RustFuturePoll::Ready); + callback(old_data, RustFuturePoll::Ready, 0); + } + } + + pub(super) fn clear_wake_flag(&mut self) { + if let Self::Waked = self { + *self = Self::Empty } } diff --git a/uniffi_core/src/ffi/rustfuture/tests.rs b/uniffi_core/src/ffi/rustfuture/tests.rs index 1f68085562..c988678678 100644 --- a/uniffi_core/src/ffi/rustfuture/tests.rs +++ b/uniffi_core/src/ffi/rustfuture/tests.rs @@ -65,16 +65,38 @@ fn channel() -> (Sender, Arc>) { } /// Poll a Rust future and get an OnceCell that's set when the continuation is called -fn poll(rust_future: &Arc>) -> Arc> { +fn poll(rust_future: &Arc>) -> Arc> { let cell = Arc::new(OnceCell::new()); let cell_ptr = Arc::into_raw(cell.clone()) as *const (); - rust_future.clone().ffi_poll(poll_continuation, cell_ptr); + rust_future + .clone() + .ffi_poll(poll_continuation, cell_ptr, None); cell } -extern "C" fn poll_continuation(data: *const (), code: RustFuturePoll) { - let cell = unsafe { Arc::from_raw(data as *const OnceCell) }; - cell.set(code).expect("Error setting OnceCell"); +/// Like poll, but simulate `poll()` being called from a blocking task queue +fn poll_from_blocking_task_queue( + rust_future: &Arc>, + blocking_task_queue_handle: u64, +) -> Arc> { + let cell = Arc::new(OnceCell::new()); + let cell_ptr = Arc::into_raw(cell.clone()) as *const (); + rust_future.clone().ffi_poll( + poll_continuation, + cell_ptr, + Some(blocking_task_queue_handle.try_into().unwrap()), + ); + cell +} + +extern "C" fn poll_continuation( + data: *const (), + code: RustFuturePoll, + blocking_task_queue_handle: u64, +) { + let cell = unsafe { Arc::from_raw(data as *const OnceCell<(RustFuturePoll, u64)>) }; + cell.set((code, blocking_task_queue_handle)) + .expect("Error setting OnceCell"); } fn complete(rust_future: Arc>) -> (RustBuffer, RustCallStatus) { @@ -83,25 +105,47 @@ fn complete(rust_future: Arc>) -> (RustBuffer, Rus (return_value, out_status_code) } +fn check_continuation_not_called(once_cell: &OnceCell<(RustFuturePoll, u64)>) { + assert_eq!(once_cell.get(), None); +} + +fn check_continuation_called( + once_cell: &OnceCell<(RustFuturePoll, u64)>, + poll_result: RustFuturePoll, +) { + assert_eq!(once_cell.get(), Some(&(poll_result, 0))); +} + +fn check_continuation_called_with_blocking_task_queue_handle( + once_cell: &OnceCell<(RustFuturePoll, u64)>, + poll_result: RustFuturePoll, + blocking_task_queue_handle: u64, +) { + assert_eq!( + once_cell.get(), + Some(&(poll_result, blocking_task_queue_handle)) + ) +} + #[test] fn test_success() { let (sender, rust_future) = channel(); // Test polling the rust future before it's ready let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), None); + check_continuation_not_called(&continuation_result); sender.wake(); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::MaybeReady)); + check_continuation_called(&continuation_result, RustFuturePoll::MaybeReady); // Test polling the rust future when it's ready let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), None); + check_continuation_not_called(&continuation_result); sender.send(Ok("All done".into())); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::MaybeReady)); + check_continuation_called(&continuation_result, RustFuturePoll::MaybeReady); // Future polls should immediately return ready let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); // Complete the future let (return_buf, call_status) = complete(rust_future); @@ -117,12 +161,12 @@ fn test_error() { let (sender, rust_future) = channel(); let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), None); + check_continuation_not_called(&continuation_result); sender.send(Err("Something went wrong".into())); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::MaybeReady)); + check_continuation_called(&continuation_result, RustFuturePoll::MaybeReady); let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); let (_, call_status) = complete(rust_future); assert_eq!(call_status.code, RustCallStatusCode::Error); @@ -144,14 +188,14 @@ fn test_cancel() { let (_sender, rust_future) = channel(); let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), None); + check_continuation_not_called(&continuation_result); rust_future.ffi_cancel(); // Cancellation should immediately invoke the callback with RustFuturePoll::Ready - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); // Future polls should immediately invoke the callback with RustFuturePoll::Ready let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); let (_, call_status) = complete(rust_future); assert_eq!(call_status.code, RustCallStatusCode::Cancelled); @@ -187,7 +231,7 @@ fn test_complete_with_stored_continuation() { let continuation_result = poll(&rust_future); rust_future.ffi_free(); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); } // Test what happens if we see a `wake()` call while we're polling the future. This can @@ -210,10 +254,47 @@ fn test_wake_during_poll() { let rust_future: Arc> = RustFuture::new(future, crate::UniFfiTag); let continuation_result = poll(&rust_future); // The continuation function should called immediately - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::MaybeReady)); + check_continuation_called(&continuation_result, RustFuturePoll::MaybeReady); // A second poll should finish the future let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); + let (return_buf, call_status) = complete(rust_future); + assert_eq!(call_status.code, RustCallStatusCode::Success); + assert_eq!( + >::try_lift(return_buf).unwrap(), + "All done" + ); +} + +#[test] +fn test_blocking_task() { + let blocking_task_queue_handle = 1001; + let future = async move { + schedule_in_blocking_task_queue(blocking_task_queue_handle.try_into().unwrap()).await; + "All done".to_owned() + }; + let rust_future: Arc> = RustFuture::new(future, crate::UniFfiTag); + // On the first poll, the future should not be ready and it should ask to be scheduled in the + // blocking task queue + let continuation_result = poll(&rust_future); + check_continuation_called_with_blocking_task_queue_handle( + &continuation_result, + RustFuturePoll::MaybeReady, + blocking_task_queue_handle, + ); + // If we poll it again not in a blocking task queue, then we get the same result + let continuation_result = poll(&rust_future); + check_continuation_called_with_blocking_task_queue_handle( + &continuation_result, + RustFuturePoll::MaybeReady, + blocking_task_queue_handle, + ); + // When we poll it in the blocking task queue, then the future is ready + let continuation_result = + poll_from_blocking_task_queue(&rust_future, blocking_task_queue_handle); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); + + // Complete the future let (return_buf, call_status) = complete(rust_future); assert_eq!(call_status.code, RustCallStatusCode::Success); assert_eq!( diff --git a/uniffi_core/src/ffi_converter_impls.rs b/uniffi_core/src/ffi_converter_impls.rs index e2f9109a60..f452d2cb77 100644 --- a/uniffi_core/src/ffi_converter_impls.rs +++ b/uniffi_core/src/ffi_converter_impls.rs @@ -23,8 +23,8 @@ /// "UT" means an abitrary `UniFfiTag` type. use crate::{ check_remaining, derive_ffi_traits, ffi_converter_rust_buffer_lift_and_lower, metadata, - ConvertError, FfiConverter, Lift, LiftReturn, Lower, LowerReturn, MetadataBuffer, Result, - RustBuffer, UnexpectedUniFFICallbackError, + BlockingTaskQueue, BlockingTaskQueueVTable, ConvertError, FfiConverter, Lift, LiftReturn, + Lower, LowerReturn, MetadataBuffer, Result, RustBuffer, UnexpectedUniFFICallbackError, }; use anyhow::bail; use bytes::buf::{Buf, BufMut}; @@ -244,6 +244,35 @@ unsafe impl FfiConverter for Duration { const TYPE_ID_META: MetadataBuffer = MetadataBuffer::from_code(metadata::codes::TYPE_DURATION); } +/// Support for passing [BlockingTaskQueue] across the FFI +/// +/// Both fields of [BlockingTaskQueue] are serialized into a RustBuffer. The vtable pointer is +/// casted to a u64. +unsafe impl FfiConverter for BlockingTaskQueue { + ffi_converter_rust_buffer_lift_and_lower!(UT); + + fn write(obj: BlockingTaskQueue, buf: &mut Vec) { + let obj = obj.clone(); + buf.put_u64(obj.handle.into()); + buf.put_u64(obj.vtable as *const BlockingTaskQueueVTable as u64); + } + + fn try_read(buf: &mut &[u8]) -> Result { + check_remaining(buf, 16)?; + let handle = buf + .get_u64() + .try_into() + .expect("handle = 0 when reading BlockingTaskQueue"); + let vtable = unsafe { + &*(buf.get_u64() as *const BlockingTaskQueueVTable) as &'static BlockingTaskQueueVTable + }; + Ok(Self { handle, vtable }) + } + + const TYPE_ID_META: MetadataBuffer = + MetadataBuffer::from_code(metadata::codes::TYPE_BLOCKING_TASK_QUEUE); +} + // Support for passing optional values via the FFI. // // Optional values are currently always passed by serializing to a buffer. @@ -419,6 +448,7 @@ derive_ffi_traits!(blanket bool); derive_ffi_traits!(blanket String); derive_ffi_traits!(blanket Duration); derive_ffi_traits!(blanket SystemTime); +derive_ffi_traits!(blanket BlockingTaskQueue); // For composite types, derive LowerReturn, LiftReturn, etc, from Lift/Lower. // diff --git a/uniffi_core/src/metadata.rs b/uniffi_core/src/metadata.rs index f6a42e9876..9667bc627c 100644 --- a/uniffi_core/src/metadata.rs +++ b/uniffi_core/src/metadata.rs @@ -67,6 +67,7 @@ pub mod codes { pub const TYPE_CUSTOM: u8 = 22; pub const TYPE_RESULT: u8 = 23; pub const TYPE_FUTURE: u8 = 24; + pub const TYPE_BLOCKING_TASK_QUEUE: u8 = 25; pub const TYPE_UNIT: u8 = 255; // Literal codes for LiteralMetadata - note that we don't support diff --git a/uniffi_macros/src/setup_scaffolding.rs b/uniffi_macros/src/setup_scaffolding.rs index 1d0c368504..425aee2540 100644 --- a/uniffi_macros/src/setup_scaffolding.rs +++ b/uniffi_macros/src/setup_scaffolding.rs @@ -166,8 +166,13 @@ fn rust_future_scaffolding_fns(module_path: &str) -> TokenStream { #[allow(clippy::missing_safety_doc, missing_docs)] #[doc(hidden)] #[no_mangle] - pub unsafe extern "C" fn #ffi_rust_future_poll(handle: ::uniffi::RustFutureHandle, callback: ::uniffi::RustFutureContinuationCallback, data: *const ()) { - ::uniffi::ffi::rust_future_poll::<#return_type>(handle, callback, data); + pub unsafe extern "C" fn #ffi_rust_future_poll( + handle: ::uniffi::RustFutureHandle, + callback: ::uniffi::RustFutureContinuationCallback, + data: *const (), + blocking_task_queue_handle: u64, + ) { + ::uniffi::ffi::rust_future_poll::<#return_type>(handle, callback, data, blocking_task_queue_handle); } #[allow(clippy::missing_safety_doc, missing_docs)] diff --git a/uniffi_meta/src/metadata.rs b/uniffi_meta/src/metadata.rs index 7506b9d7ab..a5a04bfb16 100644 --- a/uniffi_meta/src/metadata.rs +++ b/uniffi_meta/src/metadata.rs @@ -50,6 +50,7 @@ pub mod codes { pub const TYPE_CUSTOM: u8 = 22; pub const TYPE_RESULT: u8 = 23; //pub const TYPE_FUTURE: u8 = 24; + pub const TYPE_BLOCKING_TASK_QUEUE: u8 = 25; pub const TYPE_UNIT: u8 = 255; // Literal codes diff --git a/uniffi_meta/src/reader.rs b/uniffi_meta/src/reader.rs index fa7f4447e9..e3875bb407 100644 --- a/uniffi_meta/src/reader.rs +++ b/uniffi_meta/src/reader.rs @@ -144,6 +144,7 @@ impl<'a> MetadataReader<'a> { codes::TYPE_STRING => Type::String, codes::TYPE_DURATION => Type::Duration, codes::TYPE_SYSTEM_TIME => Type::Timestamp, + codes::TYPE_BLOCKING_TASK_QUEUE => Type::BlockingTaskQueue, codes::TYPE_RECORD => Type::Record { module_path: self.read_string()?, name: self.read_string()?, diff --git a/uniffi_meta/src/types.rs b/uniffi_meta/src/types.rs index 647f4e9929..e0ec13991d 100644 --- a/uniffi_meta/src/types.rs +++ b/uniffi_meta/src/types.rs @@ -86,6 +86,7 @@ pub enum Type { // How the object is implemented. imp: ObjectImpl, }, + BlockingTaskQueue, // Types defined in the component API, each of which has a string name. Record { module_path: String, diff --git a/uniffi_udl/src/resolver.rs b/uniffi_udl/src/resolver.rs index ea98cd7a99..1409c3a6ff 100644 --- a/uniffi_udl/src/resolver.rs +++ b/uniffi_udl/src/resolver.rs @@ -209,6 +209,7 @@ pub(crate) fn resolve_builtin_type(name: &str) -> Option { "f64" => Some(Type::Float64), "timestamp" => Some(Type::Timestamp), "duration" => Some(Type::Duration), + "BlockingTaskQueue" => Some(Type::BlockingTaskQueue), _ => None, } }