diff --git a/Cargo.lock b/Cargo.lock index 6c73b85452..3e1c7c827f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -612,26 +612,53 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0845fa252299212f0389d64ba26f34fa32cfe41588355f21ed507c59a0f64541" +[[package]] +name = "futures" +version = "0.3.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da0290714b38af9b4a7b094b8a37086d1b4e61f2df9122c3cad2577669145335" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb" dependencies = [ "futures-core", + "futures-sink", ] [[package]] name = "futures-core" -version = "0.3.28" +version = "0.3.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c" + +[[package]] +name = "futures-executor" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +checksum = "0f4fb8693db0cf099eadcca0efe2a5a22e4550f98ed16aba6c48700da29597bc" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] [[package]] name = "futures-io" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" +checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa" [[package]] name = "futures-lite" @@ -648,6 +675,47 @@ dependencies = [ "waker-fn", ] +[[package]] +name = "futures-macro" +version = "0.3.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817" + +[[package]] +name = "futures-task" +version = "0.3.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2" + +[[package]] +name = "futures-util" +version = "0.3.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "gimli" version = "0.28.0" @@ -1632,6 +1700,7 @@ dependencies = [ name = "uniffi-fixture-futures" version = "0.21.0" dependencies = [ + "futures", "once_cell", "thiserror", "tokio", diff --git a/docs/manual/src/futures.md b/docs/manual/src/futures.md index dbe5e5a163..97f9aaf164 100644 --- a/docs/manual/src/futures.md +++ b/docs/manual/src/futures.md @@ -45,3 +45,63 @@ This code uses `asyncio` to drive the future to completion, while our exposed fu In Rust `Future` terminology this means the foreign bindings supply the "executor" - think event-loop, or async runtime. In this example it's `asyncio`. There's no requirement for a Rust event loop. There are [some great API docs](https://docs.rs/uniffi_core/latest/uniffi_core/ffi/rustfuture/index.html) on the implementation that are well worth a read. + +## Blocking tasks + +Rust executors are designed around an assumption that the `Future::poll` function will return quickly. +This assumption, combined with cooperative scheduling, allows for a large number of futures to be handled by a small number of threads. +Foreign executors make similar assumptions and sometimes more extreme ones. +For example, the Python eventloop is single threaded -- if any task spends a long time between `await` points, then it will block all other tasks from progressing. + +This raises the question of how async code can interact with blocking code that performs blocking IO, long-running computations without `await` breaks, etc. +UniFFI defines the `BlockingTaskQueue` type, which is a foreign object that schedules work on a thread where it's okay to block. + +On Rust, `BlockingTaskQueue` is a UniFFI type that can safely run blocking code. +It's `execute` method works like tokio's [block_in_place](https://docs.rs/tokio/latest/tokio/task/fn.block_in_place.html) function. +It inputs a closure and runs it in the `BlockingTaskQueue`. +This closure can reference the outside scope (i.e. it does not need to be `'static`). +For example: + +```rust +#[derive(uniffi::Object)] +struct DataStore { + // Used to run blocking tasks + queue: uniffi::BlockingTaskQueue, + // Low-level DB object with blocking methods + db: Mutex, +} + +#[uniffi::export] +impl DataStore { + #[uniffi::constructor] + fn new(queue: uniffi::BlockingTaskQueue) -> Self { + Self { + queue, + db: Mutex::new(Database::new()) + } + } + + fn fetch_all_items(&self) -> Vec { + self.queue.execute(|| self.db.lock().fetch_all_items()) + } +} +``` + +On the foreign side `BlockingTaskQueue` corresponds to a language-dependent class. + +### Kotlin +Kotlin uses `CoroutineContext` for its `BlockingTaskQueue`. +Any `CoroutineContext` will work, but `Dispatchers.IO` is usually a good choice. +A DataStore from the example above can be created with `DataStore(Dispatchers.IO)`. + +### Swift +Swift uses `DispatchQueue` for its `BlockingTaskQueue`. +The user-initiated global queue is normally a good choice. +A DataStore from the example above can be created with `DataStore(queue: DispatchQueue.global(qos: .userInitiated)`. +The `DispatchQueue` should be concurrent. + +### Python + +Python uses a `futures.Executor` for its `BlockingTaskQueue`. +`ThreadPoolExecutor` is typically a good choice. +A DataStore from the example above can be created with `DataStore(ThreadPoolExecutor())`. diff --git a/fixtures/futures/Cargo.toml b/fixtures/futures/Cargo.toml index 78b08cb689..67c850e9f0 100644 --- a/fixtures/futures/Cargo.toml +++ b/fixtures/futures/Cargo.toml @@ -16,6 +16,7 @@ path = "src/bin.rs" [dependencies] uniffi = { path = "../../uniffi", version = "0.25", features = ["tokio", "cli"] } +futures = "0.3.29" thiserror = "1.0" tokio = { version = "1.24.1", features = ["time", "sync"] } once_cell = "1.18.0" diff --git a/fixtures/futures/src/lib.rs b/fixtures/futures/src/lib.rs index 39a521495e..7466794740 100644 --- a/fixtures/futures/src/lib.rs +++ b/fixtures/futures/src/lib.rs @@ -11,6 +11,8 @@ use std::{ time::Duration, }; +use futures::stream::{FuturesUnordered, StreamExt}; + /// Non-blocking timer future. pub struct TimerFuture { shared_state: Arc>, @@ -326,4 +328,58 @@ pub async fn use_shared_resource(options: SharedResourceOptions) -> Result<(), A Ok(()) } +/// Async function that uses a blocking task queue to do its work +#[uniffi::export] +pub async fn calc_square(queue: uniffi::BlockingTaskQueue, value: i32) -> i32 { + queue.execute(|| value * value).await +} + +/// Same as before, but this one runs multiple tasks +#[uniffi::export] +pub async fn calc_squares(queue: uniffi::BlockingTaskQueue, items: Vec) -> Vec { + // Use `FuturesUnordered` to test our blocking task queue code which is known to be a tricky API to work with. + // In particular, if we don't notify the waker then FuturesUnordered will not poll again. + let mut futures: FuturesUnordered<_> = (0..items.len()) + .map(|i| { + // Test that we can use references from the surrounding scope + let items = &items; + queue.execute(move || items[i] * items[i]) + }) + .collect(); + let mut results = vec![]; + while let Some(result) = futures.next().await { + results.push(result); + } + results.sort(); + results +} + +/// ...and this one uses multiple BlockingTaskQueues +#[uniffi::export] +pub async fn calc_squares_multi_queue( + queues: Vec, + items: Vec, +) -> Vec { + let mut futures: FuturesUnordered<_> = (0..items.len()) + .map(|i| { + // Test that we can use references from the surrounding scope + let items = &items; + queues[i].execute(move || items[i] * items[i]) + }) + .collect(); + let mut results = vec![]; + while let Some(result) = futures.next().await { + results.push(result); + } + results.sort(); + results +} + +/// Like calc_square, but it clones the BlockingTaskQueue first then drops both copies. Used to +/// test that a) the clone works and b) we correctly drop the references. +#[uniffi::export] +pub async fn calc_square_with_clone(queue: uniffi::BlockingTaskQueue, value: i32) -> i32 { + queue.clone().execute(|| value * value).await +} + uniffi::include_scaffolding!("futures"); diff --git a/fixtures/futures/tests/bindings/test_futures.kts b/fixtures/futures/tests/bindings/test_futures.kts index 810bb40f41..0dece29410 100644 --- a/fixtures/futures/tests/bindings/test_futures.kts +++ b/fixtures/futures/tests/bindings/test_futures.kts @@ -1,9 +1,22 @@ import uniffi.fixture.futures.* +import java.util.concurrent.Executors import kotlinx.coroutines.* import kotlin.system.* +fun runAsyncTest(test: suspend CoroutineScope.() -> Unit) { + val initialBlockingTaskQueueHandleCount = uniffiBlockingTaskQueueHandleCount() + val initialPollHandleCount = uniffiPollHandleCount() + val time = runBlocking { + measureTimeMillis { + test() + } + } + assert(uniffiBlockingTaskQueueHandleCount() == initialBlockingTaskQueueHandleCount) + assert(uniffiPollHandleCount() == initialPollHandleCount) +} + // init UniFFI to get good measurements after that -runBlocking { +runAsyncTest { val time = measureTimeMillis { alwaysReady() } @@ -24,7 +37,7 @@ fun assertApproximateTime(actualTime: Long, expectedTime: Int, testName: String } // Test `always_ready`. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val result = alwaysReady() @@ -35,7 +48,7 @@ runBlocking { } // Test `void`. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val result = void() @@ -46,7 +59,7 @@ runBlocking { } // Test `sleep`. -runBlocking { +runAsyncTest { val time = measureTimeMillis { sleep(200U) } @@ -55,7 +68,7 @@ runBlocking { } // Test sequential futures. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val resultAlice = sayAfter(100U, "Alice") val resultBob = sayAfter(200U, "Bob") @@ -68,7 +81,7 @@ runBlocking { } // Test concurrent futures. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val resultAlice = async { sayAfter(100U, "Alice") } val resultBob = async { sayAfter(200U, "Bob") } @@ -81,7 +94,7 @@ runBlocking { } // Test async methods. -runBlocking { +runAsyncTest { val megaphone = newMegaphone() val time = measureTimeMillis { val resultAlice = megaphone.sayAfter(200U, "Alice") @@ -92,7 +105,7 @@ runBlocking { assertApproximateTime(time, 200, "async methods") } -runBlocking { +runAsyncTest { val megaphone = newMegaphone() val time = measureTimeMillis { val resultAlice = sayAfterWithMegaphone(megaphone, 200U, "Alice") @@ -104,7 +117,7 @@ runBlocking { } // Test async method returning optional object -runBlocking { +runAsyncTest { val megaphone = asyncMaybeNewMegaphone(true) assert(megaphone != null) @@ -113,7 +126,7 @@ runBlocking { } // Test with the Tokio runtime. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val resultAlice = sayAfterWithTokio(200U, "Alice") @@ -124,7 +137,7 @@ runBlocking { } // Test fallible function/method. -runBlocking { +runAsyncTest { val time1 = measureTimeMillis { try { fallibleMe(false) @@ -189,7 +202,7 @@ runBlocking { } // Test record. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val result = newMyRecord("foo", 42U) @@ -203,7 +216,7 @@ runBlocking { } // Test a broken sleep. -runBlocking { +runAsyncTest { val time = measureTimeMillis { brokenSleep(100U, 0U) // calls the waker twice immediately sleep(100U) // wait for possible failure @@ -217,7 +230,7 @@ runBlocking { // Test a future that uses a lock and that is cancelled. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val job = launch { useSharedResource(SharedResourceOptions(releaseAfterMs=5000U, timeoutMs=100U)) @@ -236,7 +249,7 @@ runBlocking { } // Test a future that uses a lock and that is not cancelled. -runBlocking { +runAsyncTest { val time = measureTimeMillis { useSharedResource(SharedResourceOptions(releaseAfterMs=100U, timeoutMs=1000U)) @@ -244,3 +257,33 @@ runBlocking { } println("useSharedResource (not canceled): ${time}ms") } + +// Test blocking task queues +runAsyncTest { + withTimeout(1000) { + assert(calcSquare(Dispatchers.IO, 20) == 400) + } + + withTimeout(1000) { + assert(calcSquares(Dispatchers.IO, listOf(1, -2, 3)) == listOf(1, 4, 9)) + } + + val executors = listOf( + Executors.newSingleThreadExecutor(), + Executors.newSingleThreadExecutor(), + Executors.newSingleThreadExecutor(), + ) + withTimeout(1000) { + assert(calcSquaresMultiQueue(executors.map { it.asCoroutineDispatcher() }, listOf(1, -2, 3)) == listOf(1, 4, 9)) + } + for (executor in executors) { + executor.shutdown() + } +} + +// Test blocking task queue cloning +runAsyncTest { + withTimeout(1000) { + assert(calcSquareWithClone(Dispatchers.IO, 20) == 400) + } +} diff --git a/fixtures/futures/tests/bindings/test_futures.py b/fixtures/futures/tests/bindings/test_futures.py index bfbeba86f8..37825db3a0 100644 --- a/fixtures/futures/tests/bindings/test_futures.py +++ b/fixtures/futures/tests/bindings/test_futures.py @@ -1,25 +1,31 @@ +import futures from futures import * +import contextlib import unittest from datetime import datetime import asyncio +from concurrent.futures import ThreadPoolExecutor def now(): return datetime.now() class TestFutures(unittest.TestCase): def test_always_ready(self): + @self.check_handle_counts() async def test(): self.assertEqual(await always_ready(), True) asyncio.run(test()) def test_void(self): + @self.check_handle_counts() async def test(): self.assertEqual(await void(), None) asyncio.run(test()) def test_sleep(self): + @self.check_handle_counts() async def test(): t0 = now() await sleep(2000) @@ -31,6 +37,7 @@ async def test(): asyncio.run(test()) def test_sequential_futures(self): + @self.check_handle_counts() async def test(): t0 = now() result_alice = await say_after(100, 'Alice') @@ -45,6 +52,7 @@ async def test(): asyncio.run(test()) def test_concurrent_tasks(self): + @self.check_handle_counts() async def test(): alice = asyncio.create_task(say_after(100, 'Alice')) bob = asyncio.create_task(say_after(200, 'Bob')) @@ -62,6 +70,7 @@ async def test(): asyncio.run(test()) def test_async_methods(self): + @self.check_handle_counts() async def test(): megaphone = new_megaphone() t0 = now() @@ -75,6 +84,7 @@ async def test(): asyncio.run(test()) def test_async_object_param(self): + @self.check_handle_counts() async def test(): megaphone = new_megaphone() t0 = now() @@ -88,6 +98,7 @@ async def test(): asyncio.run(test()) def test_with_tokio_runtime(self): + @self.check_handle_counts() async def test(): t0 = now() result_alice = await say_after_with_tokio(200, 'Alice') @@ -100,6 +111,7 @@ async def test(): asyncio.run(test()) def test_fallible(self): + @self.check_handle_counts() async def test(): result = await fallible_me(False) self.assertEqual(result, 42) @@ -124,6 +136,7 @@ async def test(): asyncio.run(test()) def test_fallible_struct(self): + @self.check_handle_counts() async def test(): megaphone = await fallible_struct(False) self.assertEqual(await megaphone.fallible_me(False), 42) @@ -137,6 +150,7 @@ async def test(): asyncio.run(test()) def test_record(self): + @self.check_handle_counts() async def test(): result = await new_my_record("foo", 42) self.assertEqual(result.__class__, MyRecord) @@ -146,6 +160,7 @@ async def test(): asyncio.run(test()) def test_cancel(self): + @self.check_handle_counts() async def test(): # Create a task task = asyncio.create_task(say_after(200, 'Alice')) @@ -163,6 +178,7 @@ async def test(): # Test a future that uses a lock and that is cancelled. def test_shared_resource_cancellation(self): + @self.check_handle_counts() async def test(): task = asyncio.create_task(use_shared_resource( SharedResourceOptions(release_after_ms=5000, timeout_ms=100))) @@ -173,10 +189,57 @@ async def test(): asyncio.run(test()) def test_shared_resource_no_cancellation(self): + @self.check_handle_counts() async def test(): await use_shared_resource(SharedResourceOptions(release_after_ms=100, timeout_ms=1000)) await use_shared_resource(SharedResourceOptions(release_after_ms=0, timeout_ms=1000)) asyncio.run(test()) + # blocking task queue tests + + def test_calc_square(self): + @self.check_handle_counts() + async def test(): + async with asyncio.timeout(1): + executor = ThreadPoolExecutor() + self.assertEqual(await calc_square(executor, 20), 400) + asyncio.run(test()) + + def test_calc_square_with_clone(self): + @self.check_handle_counts() + async def test(): + async with asyncio.timeout(1): + executor = ThreadPoolExecutor() + self.assertEqual(await calc_square_with_clone(executor, 20), 400) + asyncio.run(test()) + + def test_calc_squares(self): + @self.check_handle_counts() + async def test(): + async with asyncio.timeout(1): + executor = ThreadPoolExecutor() + self.assertEqual(await calc_squares(executor, [1, -2, 3]), [1, 4, 9]) + asyncio.run(test()) + + def test_calc_squares_multi_queue(self): + @self.check_handle_counts() + async def test(): + async with asyncio.timeout(1): + executors = [ + ThreadPoolExecutor(), + ThreadPoolExecutor(), + ThreadPoolExecutor(), + ] + self.assertEqual(await calc_squares_multi_queue(executors, [1, -2, 3]), [1, 4, 9]) + asyncio.run(test()) + + @contextlib.asynccontextmanager + async def check_handle_counts(self): + initial_poll_handle_count = len(futures.UNIFFI_POLL_DATA_POINTER_MANAGER) + initial_blocking_task_queue_handle_count = len(futures.UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP) + yield + self.assertEqual(len(futures.UNIFFI_POLL_DATA_POINTER_MANAGER), initial_poll_handle_count) + self.assertEqual(len(futures.UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP), initial_blocking_task_queue_handle_count) + if __name__ == '__main__': unittest.main() diff --git a/fixtures/futures/tests/bindings/test_futures.swift b/fixtures/futures/tests/bindings/test_futures.swift index 20e24c40ff..836a8f5a7b 100644 --- a/fixtures/futures/tests/bindings/test_futures.swift +++ b/fixtures/futures/tests/bindings/test_futures.swift @@ -3,10 +3,21 @@ import Foundation // To get `DispatchGroup` and `Date` types. var counter = DispatchGroup() -// Test `alwaysReady` -counter.enter() +func asyncTest(test: @escaping () async throws -> ()) { + let initialBlockingTaskQueueCount = uniffiBlockingTaskQueueHandleCount() + let initialPollDataHandleCount = uniffiPollDataHandleCount() + counter.enter() + Task { + try! await test() + counter.leave() + } + counter.wait() + assert(uniffiBlockingTaskQueueHandleCount() == initialBlockingTaskQueueCount) + assert(uniffiPollDataHandleCount() == initialPollDataHandleCount) +} -Task { +// Test `alwaysReady` +asyncTest { let t0 = Date() let result = await alwaysReady() let t1 = Date() @@ -14,40 +25,28 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration < 0.1) assert(result == true) - - counter.leave() } // Test record. -counter.enter() - -Task { +asyncTest { let result = await newMyRecord(a: "foo", b: 42) assert(result.a == "foo") assert(result.b == 42) - - counter.leave() } // Test `void` -counter.enter() - -Task { +asyncTest { let t0 = Date() await void() let t1 = Date() let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration < 0.1) - - counter.leave() } // Test `Sleep` -counter.enter() - -Task { +asyncTest { let t0 = Date() let result = await sleep(ms: 2000) let t1 = Date() @@ -55,14 +54,10 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 2 && tDelta.duration < 2.1) assert(result == true) - - counter.leave() } // Test sequential futures. -counter.enter() - -Task { +asyncTest { let t0 = Date() let result_alice = await sayAfter(ms: 1000, who: "Alice") let result_bob = await sayAfter(ms: 2000, who: "Bob") @@ -72,14 +67,10 @@ Task { assert(tDelta.duration > 3 && tDelta.duration < 3.1) assert(result_alice == "Hello, Alice!") assert(result_bob == "Hello, Bob!") - - counter.leave() } // Test concurrent futures. -counter.enter() - -Task { +asyncTest { async let alice = sayAfter(ms: 1000, who: "Alice") async let bob = sayAfter(ms: 2000, who: "Bob") @@ -91,14 +82,10 @@ Task { assert(tDelta.duration > 2 && tDelta.duration < 2.1) assert(result_alice == "Hello, Alice!") assert(result_bob == "Hello, Bob!") - - counter.leave() } // Test async methods -counter.enter() - -Task { +asyncTest { let megaphone = newMegaphone() let t0 = Date() @@ -108,26 +95,18 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 2 && tDelta.duration < 2.1) assert(result_alice == "HELLO, ALICE!") - - counter.leave() } // Test async function returning an object -counter.enter() - -Task { +asyncTest { let megaphone = await asyncNewMegaphone() let result = try await megaphone.fallibleMe(doFail: false) assert(result == 42) - - counter.leave() } // Test with the Tokio runtime. -counter.enter() - -Task { +asyncTest { let t0 = Date() let result_alice = await sayAfterWithTokio(ms: 2000, who: "Alice") let t1 = Date() @@ -135,15 +114,11 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 2 && tDelta.duration < 2.1) assert(result_alice == "Hello, Alice (with Tokio)!") - - counter.leave() } // Test fallible function/method… // … which doesn't throw. -counter.enter() - -Task { +asyncTest { let t0 = Date() let result = try await fallibleMe(doFail: false) let t1 = Date() @@ -151,19 +126,15 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 0 && tDelta.duration < 0.1) assert(result == 42) - - counter.leave() } -Task { +asyncTest { let m = try await fallibleStruct(doFail: false) let result = try await m.fallibleMe(doFail: false) assert(result == 42) } -counter.enter() - -Task { +asyncTest { let megaphone = newMegaphone() let t0 = Date() @@ -173,14 +144,10 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 0 && tDelta.duration < 0.1) assert(result == 42) - - counter.leave() } // … which does throw. -counter.enter() - -Task { +asyncTest { let t0 = Date() do { @@ -195,11 +162,9 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 0 && tDelta.duration < 0.1) - - counter.leave() } -Task { +asyncTest { do { let _ = try await fallibleStruct(doFail: true) } catch MyError.Foo { @@ -209,9 +174,7 @@ Task { } } -counter.enter() - -Task { +asyncTest { let megaphone = newMegaphone() let t0 = Date() @@ -228,13 +191,10 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 0 && tDelta.duration < 0.1) - - counter.leave() } // Test a future that uses a lock and that is cancelled. -counter.enter() -Task { +asyncTest { let task = Task { try! await useSharedResource(options: SharedResourceOptions(releaseAfterMs: 100, timeoutMs: 1000)) } @@ -250,15 +210,36 @@ Task { // Try accessing the shared resource again. The initial task should release the shared resource // before the timeout expires. try! await useSharedResource(options: SharedResourceOptions(releaseAfterMs: 0, timeoutMs: 1000)) - counter.leave() } // Test a future that uses a lock and that is not cancelled. -counter.enter() -Task { +asyncTest { try! await useSharedResource(options: SharedResourceOptions(releaseAfterMs: 100, timeoutMs: 1000)) try! await useSharedResource(options: SharedResourceOptions(releaseAfterMs: 0, timeoutMs: 1000)) - counter.leave() } -counter.wait() +// Test blocking task queues +asyncTest { + let calcSquareResult = await calcSquare(queue: DispatchQueue.global(qos: .userInitiated), value: 20) + assert(calcSquareResult == 400) + + let calcSquaresResult = await calcSquares(queue: DispatchQueue.global(qos: .userInitiated), items: [1, -2, 3]) + assert(calcSquaresResult == [1, 4, 9]) + + let calcSquaresMultiQueueResult = await calcSquaresMultiQueue( + queues: [ + DispatchQueue(label: "test-queue1", attributes: DispatchQueue.Attributes.concurrent), + DispatchQueue(label: "test-queue2", attributes: DispatchQueue.Attributes.concurrent), + DispatchQueue(label: "test-queue3", attributes: DispatchQueue.Attributes.concurrent) + ], + items: [1, -2, 3] + ) + assert(calcSquaresMultiQueueResult == [1, 4, 9]) +} + +// Test blocking task queue cloning +asyncTest { + let calcSquareResult = await calcSquareWithClone(queue: DispatchQueue.global(qos: .userInitiated), value: 20) + assert(calcSquareResult == 400) +} + diff --git a/fixtures/metadata/src/tests.rs b/fixtures/metadata/src/tests.rs index f4d8ae244f..cc21f14a62 100644 --- a/fixtures/metadata/src/tests.rs +++ b/fixtures/metadata/src/tests.rs @@ -123,6 +123,7 @@ mod test_type_ids { check_type_id::(Type::Float64); check_type_id::(Type::Boolean); check_type_id::(Type::String); + check_type_id::(Type::BlockingTaskQueue); } #[test] diff --git a/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/blocking_task_queue.rs b/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/blocking_task_queue.rs new file mode 100644 index 0000000000..a664f1fcd3 --- /dev/null +++ b/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/blocking_task_queue.rs @@ -0,0 +1,19 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use super::CodeType; + +#[derive(Debug)] +pub struct BlockingTaskQueueCodeType; + +impl CodeType for BlockingTaskQueueCodeType { + fn type_label(&self, _ci: &crate::ComponentInterface) -> String { + // Kotlin uses CoroutineContext for BlockingTaskQueue + "CoroutineContext".into() + } + + fn canonical_name(&self) -> String { + "BlockingTaskQueue".into() + } +} diff --git a/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs b/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs index 9b542adf85..cac4accbdf 100644 --- a/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs +++ b/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs @@ -16,6 +16,7 @@ use crate::backend::TemplateExpression; use crate::interface::*; use crate::BindingsConfig; +mod blocking_task_queue; mod callback_interface; mod compounds; mod custom; @@ -426,6 +427,8 @@ impl AsCodeType for T { Type::Timestamp => Box::new(miscellany::TimestampCodeType), Type::Duration => Box::new(miscellany::DurationCodeType), + Type::BlockingTaskQueue => Box::new(blocking_task_queue::BlockingTaskQueueCodeType), + Type::Enum { name, .. } => Box::new(enum_::EnumCodeType::new(name)), Type::Object { name, imp, .. } => Box::new(object::ObjectCodeType::new(name, imp)), Type::Record { name, .. } => Box::new(record::RecordCodeType::new(name)), @@ -561,7 +564,7 @@ mod filters { ) -> Result { let ffi_func = callable.ffi_rust_future_poll(ci); Ok(format!( - "{{ future, callback, continuation -> _UniFFILib.INSTANCE.{ffi_func}(future, callback, continuation) }}" + "{{ future, callback, continuation, blockingTaskQueueHandle -> _UniFFILib.INSTANCE.{ffi_func}(future, callback, continuation, blockingTaskQueueHandle) }}" )) } diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt b/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt index 4f4ece37f7..cf2c48f1fb 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt @@ -3,18 +3,37 @@ internal const val UNIFFI_RUST_FUTURE_POLL_READY = 0.toByte() internal const val UNIFFI_RUST_FUTURE_POLL_MAYBE_READY = 1.toByte() -internal val uniffiContinuationHandleMap = UniFfiHandleMap>() +/** + * Data for an in-progress poll of a RustFuture + */ +internal data class UniffiPollData( + val continuation: CancellableContinuation, + val rustFuture: Pointer, + val pollFunc: (Pointer, UniffiRustFutureContinuationCallback, USize, Long) -> Unit, +) + +internal val uniffiPollDataHandleMap = UniFfiHandleMap() // FFI type for Rust future continuations internal object uniffiRustFutureContinuationCallback: UniffiRustFutureContinuationCallback { - override fun callback(data: USize, pollResult: Byte) { - uniffiContinuationHandleMap.remove(data)?.resume(pollResult) + override fun callback(data: USize, pollResult: Byte, blockingTaskQueueHandle: Long) { + if (blockingTaskQueueHandle == 0L) { + // Complete the Kotlin continuation + uniffiPollDataHandleMap.remove(data)!!.continuation.resume(pollResult) + } else { + // Call the poll function again, but inside the BlockingTaskQueue coroutine context + val coroutineContext = uniffiBlockingTaskQueueHandleMap.get(blockingTaskQueueHandle) + val pollData = uniffiPollDataHandleMap.get(data)!! + CoroutineScope(coroutineContext).launch { + pollData.pollFunc(pollData.rustFuture, uniffiRustFutureContinuationCallback, data, blockingTaskQueueHandle) + } + } } } internal suspend fun uniffiRustCallAsync( rustFuture: Pointer, - pollFunc: (Pointer, UniffiRustFutureContinuationCallback, USize) -> Unit, + pollFunc: (Pointer, UniffiRustFutureContinuationCallback, USize, Long) -> Unit, completeFunc: (Pointer, RustCallStatus) -> F, freeFunc: (Pointer) -> Unit, liftFunc: (F) -> T, @@ -23,11 +42,9 @@ internal suspend fun uniffiRustCallAsync( try { do { val pollResult = suspendCancellableCoroutine { continuation -> - pollFunc( - rustFuture, - uniffiRustFutureContinuationCallback, - uniffiContinuationHandleMap.insert(continuation) - ) + val pollData = UniffiPollData(continuation, rustFuture, pollFunc) + val pollDataHandle = uniffiPollDataHandleMap.insert(pollData) + pollFunc(rustFuture, uniffiRustFutureContinuationCallback, pollDataHandle, 0L) } } while (pollResult != UNIFFI_RUST_FUTURE_POLL_READY); @@ -39,3 +56,6 @@ internal suspend fun uniffiRustCallAsync( } } +// For testing +public fun uniffiPollHandleCount() = uniffiPollDataHandleMap.size + diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/BlockingTaskQueueTemplate.kt b/uniffi_bindgen/src/bindings/kotlin/templates/BlockingTaskQueueTemplate.kt new file mode 100644 index 0000000000..717ff62df3 --- /dev/null +++ b/uniffi_bindgen/src/bindings/kotlin/templates/BlockingTaskQueueTemplate.kt @@ -0,0 +1,44 @@ +{{ self.add_import("kotlin.coroutines.CoroutineContext") }} + +object uniffiBlockingTaskQueueClone : UniffiBlockingTaskQueueClone { + override fun callback(handle: Long): Long { + val coroutineContext = uniffiBlockingTaskQueueHandleMap.get(handle) + return uniffiBlockingTaskQueueHandleMap.insert(coroutineContext) + } +} + +object uniffiBlockingTaskQueueFree : UniffiBlockingTaskQueueFree { + override fun callback(handle: Long) { + uniffiBlockingTaskQueueHandleMap.remove(handle) + } +} + +internal val uniffiBlockingTaskQueueVTable = UniffiBlockingTaskQueueVTable( + uniffiBlockingTaskQueueClone, + uniffiBlockingTaskQueueFree, +) +internal val uniffiBlockingTaskQueueHandleMap = ConcurrentHandleMap() + +public object {{ ffi_converter_name }}: FfiConverterRustBuffer { + override fun allocationSize(value: {{ type_name }}) = 16 + + override fun write(value: CoroutineContext, buf: ByteBuffer) { + // Call `write()` to make sure the data is written to the JNA backing data + uniffiBlockingTaskQueueVTable.write() + val handle = uniffiBlockingTaskQueueHandleMap.insert(value) + buf.putLong(handle) + buf.putLong(Pointer.nativeValue(uniffiBlockingTaskQueueVTable.getPointer())) + } + + override fun read(buf: ByteBuffer): CoroutineContext { + val handle = buf.getLong() + val coroutineContext = uniffiBlockingTaskQueueHandleMap.remove(handle)!! + // Read the VTable pointer and throw it out. The vtable is only used by Rust and always the + // same value. + buf.getLong() + return coroutineContext + } +} + +// For testing +public fun uniffiBlockingTaskQueueHandleCount() = uniffiBlockingTaskQueueHandleMap.size diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/CallbackInterfaceRuntime.kt b/uniffi_bindgen/src/bindings/kotlin/templates/CallbackInterfaceRuntime.kt index 5f4d12ddf2..b485779adc 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/CallbackInterfaceRuntime.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/CallbackInterfaceRuntime.kt @@ -1,37 +1,3 @@ -{{- self.add_import("java.util.concurrent.atomic.AtomicLong") }} -{{- self.add_import("java.util.concurrent.locks.ReentrantLock") }} -{{- self.add_import("kotlin.concurrent.withLock") }} - -internal typealias Handle = Long -internal class ConcurrentHandleMap( - private val leftMap: MutableMap = mutableMapOf(), -) { - private val lock = java.util.concurrent.locks.ReentrantLock() - private val currentHandle = AtomicLong(0L) - private val stride = 1L - - fun insert(obj: T): Handle = - lock.withLock { - currentHandle.getAndAdd(stride) - .also { handle -> - leftMap[handle] = obj - } - } - - fun get(handle: Handle) = lock.withLock { - leftMap[handle] ?: throw InternalException("No callback in handlemap; this is a Uniffi bug") - } - - fun delete(handle: Handle) { - this.remove(handle) - } - - fun remove(handle: Handle): T? = - lock.withLock { - leftMap.remove(handle) - } -} - // Magic number for the Rust proxy to call using the same mechanism as every other method, // to free the callback once it's dropped by Rust. internal const val IDX_CALLBACK_FREE = 0 diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/Helpers.kt b/uniffi_bindgen/src/bindings/kotlin/templates/Helpers.kt index c623c37734..b5f95f7779 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/Helpers.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/Helpers.kt @@ -189,3 +189,41 @@ internal class UniFfiHandleMap { return map.remove(handle) } } + +internal typealias Handle = Long +// Like UniFfiHandleMap, but allocates u64 values rather than usize ones. +// +// FIXME: consolidate this code with `UniFfiHandleMap` +// https://github.com/mozilla/uniffi-rs/pull/1823 +internal class ConcurrentHandleMap( + private val leftMap: MutableMap = mutableMapOf(), +) { + private val lock = java.util.concurrent.locks.ReentrantLock() + // Start with 1 so that 0 can be special-cased as the null value. + private val currentHandle = AtomicLong(1L) + private val stride = 1L + + val size: Int + get() = leftMap.size + + fun insert(obj: T): Handle = + lock.withLock { + currentHandle.getAndAdd(stride) + .also { handle -> + leftMap[handle] = obj + } + } + + fun get(handle: Handle) = lock.withLock { + leftMap[handle] ?: throw InternalException("No callback in handlemap; this is a Uniffi bug") + } + + fun delete(handle: Handle) { + this.remove(handle) + } + + fun remove(handle: Handle): T? = + lock.withLock { + leftMap.remove(handle) + } +} diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/ObjectRuntime.kt b/uniffi_bindgen/src/bindings/kotlin/templates/ObjectRuntime.kt index fa22b6ad2b..ed2d6c8f71 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/ObjectRuntime.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/ObjectRuntime.kt @@ -1,4 +1,3 @@ -{{- self.add_import("java.util.concurrent.atomic.AtomicLong") }} {{- self.add_import("java.util.concurrent.atomic.AtomicBoolean") }} // The base class for all UniFFI Object types. // diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/Types.kt b/uniffi_bindgen/src/bindings/kotlin/templates/Types.kt index fb7cb4afde..f720ff6e41 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/Types.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/Types.kt @@ -86,6 +86,9 @@ inline fun T.use(block: (T) -> R) = {%- when Type::Bytes %} {%- include "ByteArrayHelper.kt" %} +{%- when Type::BlockingTaskQueue %} +{%- include "BlockingTaskQueueTemplate.kt" %} + {%- when Type::Enum { name, module_path } %} {%- let e = ci.get_enum_definition(name).unwrap() %} {%- if !ci.is_name_used_as_error(name) %} @@ -131,6 +134,8 @@ inline fun T.use(block: (T) -> R) = {%- if ci.has_async_fns() %} {# Import types needed for async support #} {{ self.add_import("kotlin.coroutines.resume") }} +{{ self.add_import("kotlinx.coroutines.launch") }} {{ self.add_import("kotlinx.coroutines.suspendCancellableCoroutine") }} {{ self.add_import("kotlinx.coroutines.CancellableContinuation") }} +{{ self.add_import("kotlinx.coroutines.CoroutineScope") }} {%- endif %} diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/wrapper.kt b/uniffi_bindgen/src/bindings/kotlin/templates/wrapper.kt index 0f924e4d63..438254f4ae 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/wrapper.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/wrapper.kt @@ -30,7 +30,10 @@ import java.nio.ByteBuffer import java.nio.ByteOrder import java.nio.CharBuffer import java.nio.charset.CodingErrorAction +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.locks.ReentrantLock import java.util.concurrent.ConcurrentHashMap +import kotlin.concurrent.withLock {%- for req in self.imports() %} {{ req.render() }} diff --git a/uniffi_bindgen/src/bindings/python/gen_python/blocking_task_queue.rs b/uniffi_bindgen/src/bindings/python/gen_python/blocking_task_queue.rs new file mode 100644 index 0000000000..2f5a74fda0 --- /dev/null +++ b/uniffi_bindgen/src/bindings/python/gen_python/blocking_task_queue.rs @@ -0,0 +1,19 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use super::CodeType; + +#[derive(Debug)] +pub struct BlockingTaskQueueCodeType; + +impl CodeType for BlockingTaskQueueCodeType { + // On python we use an concurrent.futures.Executor for a BlockingTaskQueue + fn type_label(&self) -> String { + "concurrent.futures.Executor".into() + } + + fn canonical_name(&self) -> String { + "BlockingTaskQueue".into() + } +} diff --git a/uniffi_bindgen/src/bindings/python/gen_python/mod.rs b/uniffi_bindgen/src/bindings/python/gen_python/mod.rs index 2297627043..dffef2599a 100644 --- a/uniffi_bindgen/src/bindings/python/gen_python/mod.rs +++ b/uniffi_bindgen/src/bindings/python/gen_python/mod.rs @@ -16,6 +16,7 @@ use crate::backend::TemplateExpression; use crate::interface::*; use crate::BindingsConfig; +mod blocking_task_queue; mod callback_interface; mod compounds; mod custom; @@ -423,6 +424,8 @@ impl AsCodeType for T { Type::Timestamp => Box::new(miscellany::TimestampCodeType), Type::Duration => Box::new(miscellany::DurationCodeType), + Type::BlockingTaskQueue => Box::new(blocking_task_queue::BlockingTaskQueueCodeType), + Type::Enum { name, .. } => Box::new(enum_::EnumCodeType::new(name)), Type::Object { name, .. } => Box::new(object::ObjectCodeType::new(name)), Type::Record { name, .. } => Box::new(record::RecordCodeType::new(name)), diff --git a/uniffi_bindgen/src/bindings/python/templates/Async.py b/uniffi_bindgen/src/bindings/python/templates/Async.py index aae87ab98b..b31b421485 100644 --- a/uniffi_bindgen/src/bindings/python/templates/Async.py +++ b/uniffi_bindgen/src/bindings/python/templates/Async.py @@ -2,33 +2,63 @@ _UNIFFI_RUST_FUTURE_POLL_READY = 0 _UNIFFI_RUST_FUTURE_POLL_MAYBE_READY = 1 +""" +Data for an in-progress poll of a RustFuture +""" +class UniffiPoll(typing.NamedTuple): + eventloop: asyncio.AbstractEventLoop + py_future: asyncio.Future + rust_future: int + poll_fn: UNIFFI_RUST_FUTURE_CONTINUATION_CALLBACK + # Stores futures for _uniffi_continuation_callback -_UniffiContinuationPointerManager = _UniffiPointerManager() +UNIFFI_POLL_DATA_POINTER_MANAGER = _UniffiPointerManager() # Continuation callback for async functions # lift the return value or error and resolve the future, causing the async function to resume. @UNIFFI_RUST_FUTURE_CONTINUATION_CALLBACK -def _uniffi_continuation_callback(future_ptr, poll_code): - (eventloop, future) = _UniffiContinuationPointerManager.release_pointer(future_ptr) - eventloop.call_soon_threadsafe(_uniffi_set_future_result, future, poll_code) +def _uniffi_continuation_callback(poll_data_ptr, poll_code, blocking_task_queue_handle): + if blocking_task_queue_handle == 0: + # Complete the Python Future + poll_data = UNIFFI_POLL_DATA_POINTER_MANAGER.release_pointer(poll_data_ptr) + poll_data.eventloop.call_soon_threadsafe(_uniffi_set_future_result, poll_data.py_future, poll_code) + else: + # Call the poll function again, but inside the executor + poll_data = UNIFFI_POLL_DATA_POINTER_MANAGER.lookup(poll_data_ptr) + executor = UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.get(blocking_task_queue_handle) + executor.submit( + poll_data.poll_fn, + poll_data.rust_future, + _uniffi_continuation_callback, + poll_data_ptr, + blocking_task_queue_handle + ) def _uniffi_set_future_result(future, poll_code): if not future.cancelled(): future.set_result(poll_code) -async def _uniffi_rust_call_async(rust_future, ffi_poll, ffi_complete, ffi_free, lift_func, error_ffi_converter): +async def _uniffi_rust_call_async(rust_future, poll_fn, ffi_complete, ffi_free, lift_func, error_ffi_converter): try: eventloop = asyncio.get_running_loop() - # Loop and poll until we see a _UNIFFI_RUST_FUTURE_POLL_READY value + # Loop and poll until we see a UNIFFI_RUST_FUTURE_POLL_READY value while True: - future = eventloop.create_future() - ffi_poll( + py_future = eventloop.create_future() + poll_data = UniffiPoll( + eventloop=eventloop, + py_future=py_future, + rust_future=rust_future, + poll_fn=poll_fn, + ) + poll_handle = UNIFFI_POLL_DATA_POINTER_MANAGER.new_pointer(poll_data) + poll_fn( rust_future, _uniffi_continuation_callback, - _UniffiContinuationPointerManager.new_pointer((eventloop, future)), + poll_handle, + 0, ) - poll_code = await future + poll_code = await py_future if poll_code == _UNIFFI_RUST_FUTURE_POLL_READY: break diff --git a/uniffi_bindgen/src/bindings/python/templates/BlockingTaskQueueTemplate.py b/uniffi_bindgen/src/bindings/python/templates/BlockingTaskQueueTemplate.py new file mode 100644 index 0000000000..329ff37906 --- /dev/null +++ b/uniffi_bindgen/src/bindings/python/templates/BlockingTaskQueueTemplate.py @@ -0,0 +1,39 @@ +{{ self.add_import("concurrent.futures") }} + +@UNIFFI_BLOCKING_TASK_QUEUE_CLONE +def uniffi_blocking_task_queue_clone(handle): + executor = UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.get(handle) + return UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.insert(executor) + +@UNIFFI_BLOCKING_TASK_QUEUE_FREE +def uniffi_blocking_task_queue_free(handle): + UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.remove(handle) + +UNIFFI_BLOCKING_TASK_QUEUE_VTABLE = UniffiBlockingTaskQueueVTable( + uniffi_blocking_task_queue_clone, + uniffi_blocking_task_queue_free, +) + +UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP = ConcurrentHandleMap() + +class {{ ffi_converter_name }}(_UniffiConverterRustBuffer): + @staticmethod + def check_lower(value): + if not isinstance(value, concurrent.futures.Executor): + raise TypeError("Expected concurrent.futures.Executor instance, {} found".format(type(value).__name__)) + + @staticmethod + def write(value, buf): + handle = UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.insert(value) + buf.write_u64(handle) + buf.write_u64(ctypes.addressof(UNIFFI_BLOCKING_TASK_QUEUE_VTABLE)) + + @staticmethod + def read(buf): + handle = buf.read_u64() + executor = UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.remove(handle) + # Read the VTable pointer and throw it out. The vtable is only used by Rust and always the + # same value. + buf.read_u64() + return executor + diff --git a/uniffi_bindgen/src/bindings/python/templates/CallbackInterfaceRuntime.py b/uniffi_bindgen/src/bindings/python/templates/CallbackInterfaceRuntime.py index 1337d9685f..7862513db5 100644 --- a/uniffi_bindgen/src/bindings/python/templates/CallbackInterfaceRuntime.py +++ b/uniffi_bindgen/src/bindings/python/templates/CallbackInterfaceRuntime.py @@ -1,38 +1,3 @@ -import threading - -class ConcurrentHandleMap: - """ - A map where inserting, getting and removing data is synchronized with a lock. - """ - - def __init__(self): - # type Handle = int - self._left_map = {} # type: Dict[Handle, Any] - - self._lock = threading.Lock() - self._current_handle = 0 - self._stride = 1 - - def insert(self, obj): - with self._lock: - handle = self._current_handle - self._current_handle += self._stride - self._left_map[handle] = obj - return handle - - def get(self, handle): - with self._lock: - obj = self._left_map.get(handle) - if not obj: - raise InternalError("No callback in handlemap; this is a uniffi bug") - return obj - - def remove(self, handle): - with self._lock: - if handle in self._left_map: - obj = self._left_map.pop(handle) - return obj - # Magic number for the Rust proxy to call using the same mechanism as every other method, # to free the callback once it's dropped by Rust. IDX_CALLBACK_FREE = 0 diff --git a/uniffi_bindgen/src/bindings/python/templates/Helpers.py b/uniffi_bindgen/src/bindings/python/templates/Helpers.py index b4dad8da12..41f770b546 100644 --- a/uniffi_bindgen/src/bindings/python/templates/Helpers.py +++ b/uniffi_bindgen/src/bindings/python/templates/Helpers.py @@ -84,3 +84,43 @@ def _uniffi_trait_interface_call_with_error(call_status, make_call, write_return except Exception as e: call_status.code = _UniffiRustCallStatus.CALL_UNEXPECTED_ERROR call_status.error_buf = {{ Type::String.borrow()|lower_fn }}(repr(e)) + +class ConcurrentHandleMap: + """ + A map where inserting, getting and removing data is synchronized with a lock. + + TODO: consolidate this with `PointerManager` + https://github.com/mozilla/uniffi-rs/pull/1823 + """ + + def __init__(self): + # type Handle = int + self._left_map = {} # type: Dict[Handle, Any] + + self._lock = threading.Lock() + # Start with 1 so that 0 can be special-cased as the null value. + self._current_handle = 1 + self._stride = 1 + + def insert(self, obj): + with self._lock: + handle = self._current_handle + self._current_handle += self._stride + self._left_map[handle] = obj + return handle + + def get(self, handle): + with self._lock: + obj = self._left_map.get(handle) + if not obj: + raise InternalError("No callback in handlemap; this is a uniffi bug") + return obj + + def remove(self, handle): + with self._lock: + if handle in self._left_map: + obj = self._left_map.pop(handle) + return obj + + def __len__(self): + return len(self._left_map) diff --git a/uniffi_bindgen/src/bindings/python/templates/PointerManager.py b/uniffi_bindgen/src/bindings/python/templates/PointerManager.py index 23aa28eab4..526afc563a 100644 --- a/uniffi_bindgen/src/bindings/python/templates/PointerManager.py +++ b/uniffi_bindgen/src/bindings/python/templates/PointerManager.py @@ -5,6 +5,8 @@ class _UniffiPointerManagerCPython: This class is used to generate opaque pointers that reference Python objects to pass to Rust. It assumes a CPython platform. See _UniffiPointerManagerGeneral for the alternative. """ + def __init__(self): + self.count = 0 def new_pointer(self, obj): """ @@ -19,17 +21,22 @@ def new_pointer(self, obj): ctypes.pythonapi.Py_IncRef(ctypes.py_object(obj)) # id() is the object address on CPython # (https://docs.python.org/3/library/functions.html#id) + self.count += 1 return id(obj) def release_pointer(self, address): py_obj = ctypes.cast(address, ctypes.py_object) obj = py_obj.value ctypes.pythonapi.Py_DecRef(py_obj) + self.count -= 1 return obj def lookup(self, address): return ctypes.cast(address, ctypes.py_object).value + def __len__(self): + return self.count + class _UniffiPointerManagerGeneral: """ Manage giving out pointers to Python objects on non-CPython platforms @@ -61,6 +68,9 @@ def lookup(self, handle): with self._lock: return self._map[handle] + def __len__(self): + return len(self._map) + # Pick an pointer manager implementation based on the platform if platform.python_implementation() == 'CPython': _UniffiPointerManager = _UniffiPointerManagerCPython # type: ignore diff --git a/uniffi_bindgen/src/bindings/python/templates/Types.py b/uniffi_bindgen/src/bindings/python/templates/Types.py index 4aaed253e0..9f2840fefb 100644 --- a/uniffi_bindgen/src/bindings/python/templates/Types.py +++ b/uniffi_bindgen/src/bindings/python/templates/Types.py @@ -55,6 +55,9 @@ {%- when Type::Bytes %} {%- include "BytesHelper.py" %} +{%- when Type::BlockingTaskQueue %} +{%- include "BlockingTaskQueueTemplate.py" %} + {%- when Type::Enum { name, module_path } %} {%- let e = ci.get_enum_definition(name).unwrap() %} {# For enums, there are either an error *or* an enum, they can't be both. #} diff --git a/uniffi_bindgen/src/bindings/python/templates/wrapper.py b/uniffi_bindgen/src/bindings/python/templates/wrapper.py index 276ba868c3..7dd2f6a2ab 100644 --- a/uniffi_bindgen/src/bindings/python/templates/wrapper.py +++ b/uniffi_bindgen/src/bindings/python/templates/wrapper.py @@ -23,6 +23,7 @@ import contextlib import datetime import typing +import threading {%- if ci.has_async_fns() %} import asyncio {%- endif %} diff --git a/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs b/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs index c2c55e8018..78359b8d78 100644 --- a/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs +++ b/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs @@ -57,6 +57,7 @@ pub fn canonical_name(t: &Type) -> String { Type::CallbackInterface { name, .. } => format!("CallbackInterface{name}"), Type::Timestamp => "Timestamp".into(), Type::Duration => "Duration".into(), + Type::BlockingTaskQueue => "BlockingTaskQueue".into(), // Recursive types. // These add a prefix to the name of the underlying type. // The component API definition cannot give names to recursive types, so as long as the @@ -262,6 +263,7 @@ mod filters { } Type::External { .. } => panic!("No support for external types, yet"), Type::Custom { .. } => panic!("No support for custom types, yet"), + Type::BlockingTaskQueue => panic!("No support for async functions, yet"), }) } @@ -315,6 +317,7 @@ mod filters { ), Type::External { .. } => panic!("No support for lowering external types, yet"), Type::Custom { .. } => panic!("No support for lowering custom types, yet"), + Type::BlockingTaskQueue => panic!("No support for async functions, yet"), }) } @@ -355,6 +358,7 @@ mod filters { ), Type::External { .. } => panic!("No support for lifting external types, yet"), Type::Custom { .. } => panic!("No support for lifting custom types, yet"), + Type::BlockingTaskQueue => panic!("No support for async functions, yet"), }) } } diff --git a/uniffi_bindgen/src/bindings/swift/gen_swift/blocking_task_queue.rs b/uniffi_bindgen/src/bindings/swift/gen_swift/blocking_task_queue.rs new file mode 100644 index 0000000000..ab4df07f9b --- /dev/null +++ b/uniffi_bindgen/src/bindings/swift/gen_swift/blocking_task_queue.rs @@ -0,0 +1,19 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use super::CodeType; + +#[derive(Debug)] +pub struct BlockingTaskQueueCodeType; + +impl CodeType for BlockingTaskQueueCodeType { + fn type_label(&self) -> String { + // On Swift, we use a DispatchQueue for BlockingTaskQueue + "DispatchQueue".into() + } + + fn canonical_name(&self) -> String { + "BlockingTaskQueue".into() + } +} diff --git a/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs b/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs index bf55c6b9b8..4db5080eb1 100644 --- a/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs +++ b/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs @@ -18,6 +18,7 @@ use crate::backend::TemplateExpression; use crate::interface::*; use crate::BindingsConfig; +mod blocking_task_queue; mod callback_interface; mod compounds; mod custom; @@ -462,6 +463,8 @@ impl SwiftCodeOracle { Type::Timestamp => Box::new(miscellany::TimestampCodeType), Type::Duration => Box::new(miscellany::DurationCodeType), + Type::BlockingTaskQueue => Box::new(blocking_task_queue::BlockingTaskQueueCodeType), + Type::Enum { name, .. } => Box::new(enum_::EnumCodeType::new(name)), Type::Object { name, imp, .. } => Box::new(object::ObjectCodeType::new(name, imp)), Type::Record { name, .. } => Box::new(record::RecordCodeType::new(name)), diff --git a/uniffi_bindgen/src/bindings/swift/templates/Async.swift b/uniffi_bindgen/src/bindings/swift/templates/Async.swift index 578ff4ddf4..8fb37fc88f 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/Async.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/Async.swift @@ -1,9 +1,26 @@ private let UNIFFI_RUST_FUTURE_POLL_READY: Int8 = 0 private let UNIFFI_RUST_FUTURE_POLL_MAYBE_READY: Int8 = 1 +// Data for an in-progress poll of a RustFuture +fileprivate class UniffiPollData { + let continuation: UnsafeContinuation + let rustFuture: UnsafeMutableRawPointer + let pollFunc: (UnsafeMutableRawPointer, @escaping UniffiRustFutureContinuationCallback, UnsafeMutableRawPointer, UInt64) -> () + + init( + continuation: UnsafeContinuation, + rustFuture: UnsafeMutableRawPointer, + pollFunc: @escaping (UnsafeMutableRawPointer, @escaping UniffiRustFutureContinuationCallback, UnsafeMutableRawPointer, UInt64) -> () + ) { + self.continuation = continuation + self.rustFuture = rustFuture + self.pollFunc = pollFunc + } +} + fileprivate func uniffiRustCallAsync( rustFutureFunc: () -> UnsafeMutableRawPointer, - pollFunc: (UnsafeMutableRawPointer, @escaping UniffiRustFutureContinuationCallback, UnsafeMutableRawPointer) -> (), + pollFunc: @escaping (UnsafeMutableRawPointer, @escaping UniffiRustFutureContinuationCallback, UnsafeMutableRawPointer, UInt64) -> (), completeFunc: (UnsafeMutableRawPointer, UnsafeMutablePointer) -> F, freeFunc: (UnsafeMutableRawPointer) -> (), liftFunc: (F) throws -> T, @@ -19,7 +36,14 @@ fileprivate func uniffiRustCallAsync( var pollResult: Int8; repeat { pollResult = await withUnsafeContinuation { - pollFunc(rustFuture, uniffiFutureContinuationCallback, ContinuationHolder($0).toOpaque()) + let pollData = UniffiPollData( + continuation: $0, + rustFuture: rustFuture, + pollFunc: pollFunc + ) + let pollDataPtr = Unmanaged.passRetained(pollData).toOpaque() + UNIFFI_POLL_DATA_HANDLE_COUNT += 1 + pollFunc(rustFuture, uniffiFutureContinuationCallback, pollDataPtr, 0) } } while pollResult != UNIFFI_RUST_FUTURE_POLL_READY @@ -31,28 +55,28 @@ fileprivate func uniffiRustCallAsync( // Callback handlers for an async calls. These are invoked by Rust when the future is ready. They // lift the return value or error and resume the suspended function. -fileprivate func uniffiFutureContinuationCallback(ptr: UnsafeMutableRawPointer, pollResult: Int8) { - ContinuationHolder.fromOpaque(ptr).resume(pollResult) -} - -// Wraps UnsafeContinuation in a class so that we can use reference counting when passing it across -// the FFI -fileprivate class ContinuationHolder { - let continuation: UnsafeContinuation - - init(_ continuation: UnsafeContinuation) { - self.continuation = continuation - } - - func resume(_ pollResult: Int8) { - self.continuation.resume(returning: pollResult) - } - - func toOpaque() -> UnsafeMutableRawPointer { - return Unmanaged.passRetained(self).toOpaque() +fileprivate func uniffiFutureContinuationCallback( + pollDataPtr: UnsafeMutableRawPointer, + pollResult: Int8, + blockingTaskQueueHandle: UInt64 +) { + if (blockingTaskQueueHandle == 0) { + // Complete the Python Future + let pollData = Unmanaged.fromOpaque(pollDataPtr).takeRetainedValue() + UNIFFI_POLL_DATA_HANDLE_COUNT -= 1 + pollData.continuation.resume(returning: pollResult) + } else { + // Call the poll function again, but inside the DispatchQuee + let pollData = Unmanaged.fromOpaque(pollDataPtr).takeUnretainedValue() + let queue = UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.get(handle: blockingTaskQueueHandle)! + queue.async { + pollData.pollFunc(pollData.rustFuture, uniffiFutureContinuationCallback, pollDataPtr, blockingTaskQueueHandle) + } } +} - static func fromOpaque(_ ptr: UnsafeRawPointer) -> ContinuationHolder { - return Unmanaged.fromOpaque(ptr).takeRetainedValue() - } +// For testing +fileprivate var UNIFFI_POLL_DATA_HANDLE_COUNT: Int = 0 +public func uniffiPollDataHandleCount() -> Int { + return UNIFFI_POLL_DATA_HANDLE_COUNT } diff --git a/uniffi_bindgen/src/bindings/swift/templates/BlockingTaskQueueTemplate.swift b/uniffi_bindgen/src/bindings/swift/templates/BlockingTaskQueueTemplate.swift new file mode 100644 index 0000000000..1c81e72595 --- /dev/null +++ b/uniffi_bindgen/src/bindings/swift/templates/BlockingTaskQueueTemplate.swift @@ -0,0 +1,37 @@ +fileprivate var UNIFFI_BLOCKING_TASK_QUEUE_VTABLE = UniffiBlockingTaskQueueVTable( + clone: { (handle: UInt64) -> UInt64 in + let dispatchQueue = UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.get(handle: handle)! + return UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.insert(obj: dispatchQueue) + }, + free: { (handle: UInt64) in + UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.remove(handle: handle) + } +) +fileprivate var UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP = UniffiHandleMap() + +fileprivate struct {{ ffi_converter_name }}: FfiConverterRustBuffer { + typealias SwiftType = DispatchQueue + + public static func write(_ value: DispatchQueue, into buf: inout [UInt8]) { + let handle = UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.insert(obj: value) + writeInt(&buf, handle) + // From Apple: "You can safely use the address of a global variable as a persistent unique + // pointer value" (https://developer.apple.com/swift/blog/?id=6) + let vtablePointer = UnsafeMutablePointer(&UNIFFI_BLOCKING_TASK_QUEUE_VTABLE) + // Convert the pointer to a word-sized Int then to a 64-bit int then write it out. + writeInt(&buf, Int64(Int(bitPattern: vtablePointer))) + } + + public static func read(from buf: inout (data: Data, offset: Data.Index)) throws -> DispatchQueue { + let handle: UInt64 = try readInt(&buf) + // Read the VTable pointer and throw it out. The vtable is only used by Rust and always the + // same value. + let _: UInt64 = try readInt(&buf) + return UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.remove(handle: handle)! + } +} + +// For testing +public func uniffiBlockingTaskQueueHandleCount() -> Int { + UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.count +} diff --git a/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceRuntime.swift b/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceRuntime.swift index d03b7ccb3f..5863c2ad41 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceRuntime.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceRuntime.swift @@ -1,60 +1,3 @@ -fileprivate extension NSLock { - func withLock(f: () throws -> T) rethrows -> T { - self.lock() - defer { self.unlock() } - return try f() - } -} - -fileprivate typealias UniFFICallbackHandle = UInt64 -fileprivate class UniFFICallbackHandleMap { - private var leftMap: [UniFFICallbackHandle: T] = [:] - private var counter: [UniFFICallbackHandle: UInt64] = [:] - private var rightMap: [ObjectIdentifier: UniFFICallbackHandle] = [:] - - private let lock = NSLock() - private var currentHandle: UniFFICallbackHandle = 1 - private let stride: UniFFICallbackHandle = 1 - - func insert(obj: T) -> UniFFICallbackHandle { - lock.withLock { - let id = ObjectIdentifier(obj as AnyObject) - let handle = rightMap[id] ?? { - currentHandle += stride - let handle = currentHandle - leftMap[handle] = obj - rightMap[id] = handle - return handle - }() - counter[handle] = (counter[handle] ?? 0) + 1 - return handle - } - } - - func get(handle: UniFFICallbackHandle) -> T? { - lock.withLock { - leftMap[handle] - } - } - - func delete(handle: UniFFICallbackHandle) { - remove(handle: handle) - } - - @discardableResult - func remove(handle: UniFFICallbackHandle) -> T? { - lock.withLock { - defer { counter[handle] = (counter[handle] ?? 1) - 1 } - guard counter[handle] == 1 else { return leftMap[handle] } - let obj = leftMap.removeValue(forKey: handle) - if let obj = obj { - rightMap.removeValue(forKey: ObjectIdentifier(obj as AnyObject)) - } - return obj - } - } -} - // Magic number for the Rust proxy to call using the same mechanism as every other method, // to free the callback once it's dropped by Rust. private let IDX_CALLBACK_FREE: Int32 = 0 diff --git a/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceTemplate.swift b/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceTemplate.swift index 8cdd735b9a..f649b8b001 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceTemplate.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceTemplate.swift @@ -13,15 +13,15 @@ // FfiConverter protocol for callback interfaces fileprivate struct {{ ffi_converter_name }} { - fileprivate static var handleMap = UniFFICallbackHandleMap<{{ type_name }}>() + fileprivate static var handleMap = UniffiHandleMap<{{ type_name }}>() } extension {{ ffi_converter_name }} : FfiConverter { typealias SwiftType = {{ type_name }} // We can use Handle as the FfiType because it's a typealias to UInt64 - typealias FfiType = UniFFICallbackHandle + typealias FfiType = UInt64 - public static func lift(_ handle: UniFFICallbackHandle) throws -> SwiftType { + public static func lift(_ handle: UInt64) throws -> SwiftType { guard let callback = handleMap.get(handle: handle) else { throw UniffiInternalError.unexpectedStaleHandle } @@ -29,11 +29,11 @@ extension {{ ffi_converter_name }} : FfiConverter { } public static func read(from buf: inout (data: Data, offset: Data.Index)) throws -> SwiftType { - let handle: UniFFICallbackHandle = try readInt(&buf) + let handle: UInt64 = try readInt(&buf) return try lift(handle) } - public static func lower(_ v: SwiftType) -> UniFFICallbackHandle { + public static func lower(_ v: SwiftType) -> UInt64 { return handleMap.insert(obj: v) } diff --git a/uniffi_bindgen/src/bindings/swift/templates/Helpers.swift b/uniffi_bindgen/src/bindings/swift/templates/Helpers.swift index d233d4b762..3cc19914f3 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/Helpers.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/Helpers.swift @@ -44,6 +44,69 @@ fileprivate extension RustCallStatus { } } +fileprivate class UniffiHandleMap { + private var leftMap: [UInt64: T] = [:] + private var counter: [UInt64: UInt64] = [:] + private var rightMap: [ObjectIdentifier: UInt64] = [:] + + private let lock = NSLock() + // Start with 1 so that 0 can be special-cased as the null value. + private var currentHandle: UInt64 = 1 + private let stride: UInt64 = 1 + + func insert(obj: T) -> UInt64 { + lock.withLock { + let id = ObjectIdentifier(obj as AnyObject) + let handle = rightMap[id] ?? { + currentHandle += stride + let handle = currentHandle + leftMap[handle] = obj + rightMap[id] = handle + return handle + }() + counter[handle] = (counter[handle] ?? 0) + 1 + return handle + } + } + + func get(handle: UInt64) -> T? { + lock.withLock { + leftMap[handle] + } + } + + func delete(handle: UInt64) { + remove(handle: handle) + } + + @discardableResult + func remove(handle: UInt64) -> T? { + lock.withLock { + defer { counter[handle] = (counter[handle] ?? 1) - 1 } + guard counter[handle] == 1 else { return leftMap[handle] } + let obj = leftMap.removeValue(forKey: handle) + if let obj = obj { + rightMap.removeValue(forKey: ObjectIdentifier(obj as AnyObject)) + } + return obj + } + } + + var count: Int { + get { + leftMap.count + } + } +} + +fileprivate extension NSLock { + func withLock(f: () throws -> T) rethrows -> T { + self.lock() + defer { self.unlock() } + return try f() + } +} + private func rustCall(_ callback: (UnsafeMutablePointer) -> T) throws -> T { try makeRustCall(callback, errorHandler: nil) } diff --git a/uniffi_bindgen/src/bindings/swift/templates/ObjectTemplate.swift b/uniffi_bindgen/src/bindings/swift/templates/ObjectTemplate.swift index 553e60045f..b1eee0dfd0 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/ObjectTemplate.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/ObjectTemplate.swift @@ -154,7 +154,7 @@ public class {{ impl_class_name }}: public struct {{ ffi_converter_name }}: FfiConverter { {%- if obj.is_trait_interface() %} - fileprivate static var handleMap = UniFFICallbackHandleMap<{{ type_name }}>() + fileprivate static var handleMap = UniffiHandleMap<{{ type_name }}>() {%- endif %} typealias FfiType = UnsafeMutableRawPointer diff --git a/uniffi_bindgen/src/bindings/swift/templates/Types.swift b/uniffi_bindgen/src/bindings/swift/templates/Types.swift index 5e26758f3c..ba4d2059c8 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/Types.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/Types.swift @@ -64,6 +64,9 @@ {%- when Type::CallbackInterface { name, module_path } %} {%- include "CallbackInterfaceTemplate.swift" %} +{%- when Type::BlockingTaskQueue %} +{%- include "BlockingTaskQueueTemplate.swift" %} + {%- when Type::Custom { name, module_path, builtin } %} {%- include "CustomType.swift" %} diff --git a/uniffi_bindgen/src/interface/ffi.rs b/uniffi_bindgen/src/interface/ffi.rs index 5ecb582567..b7a5bd06d0 100644 --- a/uniffi_bindgen/src/interface/ffi.rs +++ b/uniffi_bindgen/src/interface/ffi.rs @@ -106,7 +106,8 @@ impl From<&Type> for FfiType { | Type::Sequence { .. } | Type::Map { .. } | Type::Timestamp - | Type::Duration => FfiType::RustBuffer(None), + | Type::Duration + | Type::BlockingTaskQueue => FfiType::RustBuffer(None), Type::External { name, kind: ExternalKind::Interface, diff --git a/uniffi_bindgen/src/interface/mod.rs b/uniffi_bindgen/src/interface/mod.rs index 81090ec37d..91febfeb50 100644 --- a/uniffi_bindgen/src/interface/mod.rs +++ b/uniffi_bindgen/src/interface/mod.rs @@ -67,7 +67,7 @@ mod record; pub use record::{Field, Record}; pub mod ffi; -pub use ffi::{FfiArgument, FfiCallbackFunction, FfiFunction, FfiStruct, FfiType}; +pub use ffi::{FfiArgument, FfiCallbackFunction, FfiField, FfiFunction, FfiStruct, FfiType}; pub use uniffi_meta::Radix; use uniffi_meta::{ ConstructorMetadata, LiteralMetadata, NamespaceMetadata, ObjectMetadata, TraitMethodMetadata, @@ -232,6 +232,7 @@ impl ComponentInterface { arguments: vec![ FfiArgument::new("data", FfiType::RustFutureContinuationData), FfiArgument::new("poll_result", FfiType::Int8), + FfiArgument::new("blocking_task_queue_handle", FfiType::UInt64), ], return_type: None, has_rust_call_status_arg: false, @@ -242,6 +243,18 @@ impl ComponentInterface { return_type: None, has_rust_call_status_arg: false, }, + FfiCallbackFunction { + name: "BlockingTaskQueueClone".to_owned(), + arguments: vec![FfiArgument::new("handle", FfiType::UInt64)], + return_type: Some(FfiType::UInt64), + has_rust_call_status_arg: false, + }, + FfiCallbackFunction { + name: "BlockingTaskQueueFree".to_owned(), + arguments: vec![FfiArgument::new("handle", FfiType::UInt64)], + return_type: None, + has_rust_call_status_arg: false, + }, ] } @@ -249,6 +262,30 @@ impl ComponentInterface { /// /// These are defined by the foreign code and invoked by Rust. pub fn ffi_struct_definitions(&self) -> impl IntoIterator + '_ { + self.builtin_vtable_definitions() + .into_iter() + .chain(self.callback_interface_vtable_definitions()) + } + + pub fn builtin_vtable_definitions(&self) -> impl IntoIterator + '_ { + [FfiStruct { + name: "BlockingTaskQueueVTable".to_owned(), + fields: vec![ + FfiField::new( + "clone", + FfiType::Callback("BlockingTaskQueueClone".to_owned()), + ), + FfiField::new( + "free", + FfiType::Callback("BlockingTaskQueueFree".to_owned()), + ), + ], + }] + } + + pub fn callback_interface_vtable_definitions( + &self, + ) -> impl IntoIterator + '_ { self.callback_interface_definitions() .iter() .map(|cbi| cbi.vtable_definition()) @@ -517,6 +554,10 @@ impl ComponentInterface { name: "callback_data".to_owned(), type_: FfiType::RustFutureContinuationData, }, + FfiArgument { + name: "blocking_task_queue_handle".to_owned(), + type_: FfiType::UInt64, + }, ], return_type: None, has_rust_call_status_arg: false, @@ -809,8 +850,12 @@ impl ComponentInterface { bail!("Conflicting type definition for \"{}\"", defn.name()); } self.types.add_known_types(defn.iter_types())?; - self.functions.push(defn); + if defn.is_async() { + self.types + .add_known_type(&uniffi_meta::Type::BlockingTaskQueue)?; + } + self.functions.push(defn); Ok(()) } diff --git a/uniffi_bindgen/src/interface/universe.rs b/uniffi_bindgen/src/interface/universe.rs index 70bc61f8a9..2faef72fd6 100644 --- a/uniffi_bindgen/src/interface/universe.rs +++ b/uniffi_bindgen/src/interface/universe.rs @@ -84,6 +84,7 @@ impl TypeUniverse { Type::Bytes => self.add_type_definition("bytes", type_)?, Type::Timestamp => self.add_type_definition("timestamp", type_)?, Type::Duration => self.add_type_definition("duration", type_)?, + Type::BlockingTaskQueue => self.add_type_definition("BlockingTaskQueue", type_)?, Type::Object { name, .. } | Type::Record { name, .. } | Type::Enum { name, .. } diff --git a/uniffi_bindgen/src/scaffolding/mod.rs b/uniffi_bindgen/src/scaffolding/mod.rs index 7fd81831aa..231e0495d3 100644 --- a/uniffi_bindgen/src/scaffolding/mod.rs +++ b/uniffi_bindgen/src/scaffolding/mod.rs @@ -45,6 +45,7 @@ mod filters { format!("std::sync::Arc<{}>", imp.rust_name_for(name)) } Type::CallbackInterface { name, .. } => format!("Box"), + Type::BlockingTaskQueue => "::uniffi::BlockingTaskQueue".to_owned(), Type::Optional { inner_type } => { format!("std::option::Option<{}>", type_rs(inner_type)?) } diff --git a/uniffi_core/src/ffi/rustfuture/blocking_task_queue.rs b/uniffi_core/src/ffi/rustfuture/blocking_task_queue.rs new file mode 100644 index 0000000000..50abe5a66f --- /dev/null +++ b/uniffi_core/src/ffi/rustfuture/blocking_task_queue.rs @@ -0,0 +1,69 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +//! Defines the BlockingTaskQueue struct +//! +//! This module is responsible for the general handling of BlockingTaskQueue instances (cloning, droping, etc). +//! See `scheduler.rs` and the foreign bindings code for how the async functionality is implemented. + +use super::scheduler::schedule_in_blocking_task_queue; +use std::num::NonZeroU64; + +/// Foreign-managed blocking task queue that we can use to schedule futures +/// +/// On the foreign side this is a Kotlin `CoroutineContext`, Python `Executor` or Swift +/// `DispatchQueue`. UniFFI converts those objects into this struct for the Rust code to use. +/// +/// Rust async code can call [BlockingTaskQueue::execute] to run a closure in that +/// blocking task queue. Use this for functions with blocking operations that should not be executed +/// in a normal async context. Some examples are non-async file/network operations, long-running +/// CPU-bound tasks, blocking database operations, etc. +#[repr(C)] +pub struct BlockingTaskQueue { + /// Opaque handle for the task queue + pub handle: NonZeroU64, + /// Method VTable + /// + /// This is simply a C struct where each field is a function pointer that inputs a + /// BlockingTaskQueue handle + pub vtable: &'static BlockingTaskQueueVTable, +} + +#[repr(C)] +#[derive(Debug)] +pub struct BlockingTaskQueueVTable { + clone: extern "C" fn(u64) -> u64, + drop: extern "C" fn(u64), +} + +// Note: see `scheduler.rs` for details on how BlockingTaskQueue is used. +impl BlockingTaskQueue { + /// Run a closure in a blocking task queue + pub async fn execute(&self, f: F) -> R + where + F: FnOnce() -> R, + { + schedule_in_blocking_task_queue(self.handle).await; + f() + } +} + +impl Clone for BlockingTaskQueue { + fn clone(&self) -> Self { + let raw_handle = (self.vtable.clone)(self.handle.into()); + let handle = raw_handle + .try_into() + .expect("BlockingTaskQueue.clone() returned 0"); + Self { + handle, + vtable: self.vtable, + } + } +} + +impl Drop for BlockingTaskQueue { + fn drop(&mut self) { + (self.vtable.drop)(self.handle.into()) + } +} diff --git a/uniffi_core/src/ffi/rustfuture/future.rs b/uniffi_core/src/ffi/rustfuture/future.rs index b104b20a32..b4eed8fe9d 100644 --- a/uniffi_core/src/ffi/rustfuture/future.rs +++ b/uniffi_core/src/ffi/rustfuture/future.rs @@ -21,6 +21,10 @@ //! 2b. If the async function is cancelled, then call [rust_future_cancel]. This causes the //! continuation function to be called with [RustFuturePoll::Ready] and the [RustFuture] to //! enter a cancelled state. +//! 2c. If the Rust code wants schedule work to be run in a `BlockingTaskQueue`, then the +//! continuation is called with [RustFuturePoll::MaybeReady] and the blocking task queue handle. +//! The foreign code is responsible for ensuring the next [rust_future_poll] call happens in +//! that blocking task queue and the handle is passed to [rust_future_poll]. //! 3. Call [rust_future_complete] to get the result of the future. //! 4. Call [rust_future_free] to free the future, ideally in a finally block. This: //! - Releases any resources held by the future @@ -78,6 +82,7 @@ use std::{ future::Future, marker::PhantomData, + num::NonZeroU64, ops::Deref, panic, pin::Pin, @@ -85,8 +90,8 @@ use std::{ task::{Context, Poll, Wake}, }; -use super::{RustFutureContinuationCallback, RustFuturePoll, Scheduler}; -use crate::{rust_call_with_out_status, FfiDefault, LowerReturn, RustCallStatus}; +use super::{scheduler, RustFutureContinuationCallback, Scheduler}; +use crate::{rust_call_with_out_status, FfiDefault, LowerReturn, RustCallStatus, RustFuturePoll}; /// Wraps the actual future we're polling struct WrappedFuture @@ -223,17 +228,26 @@ where }) } - pub(super) fn poll(self: Arc, callback: RustFutureContinuationCallback, data: *const ()) { + pub(super) fn poll( + self: Arc, + callback: RustFutureContinuationCallback, + data: *const (), + blocking_task_queue_handle: Option, + ) { + scheduler::on_poll_start(blocking_task_queue_handle); + // Clear out the waked flag, since we're about to poll right now. + self.scheduler.lock().unwrap().clear_wake_flag(); let ready = self.is_cancelled() || { let mut locked = self.future.lock().unwrap(); let waker: std::task::Waker = Arc::clone(&self).into(); locked.poll(&mut Context::from_waker(&waker)) }; if ready { - callback(data, RustFuturePoll::Ready) + callback(data, RustFuturePoll::Ready, 0) } else { self.scheduler.lock().unwrap().store(callback, data); } + scheduler::on_poll_end(); } pub(super) fn is_cancelled(&self) -> bool { @@ -289,7 +303,12 @@ where /// only create those functions for each of the 13 possible FFI return types. #[doc(hidden)] pub trait RustFutureFfi { - fn ffi_poll(self: Arc, callback: RustFutureContinuationCallback, data: *const ()); + fn ffi_poll( + self: Arc, + callback: RustFutureContinuationCallback, + data: *const (), + blocking_task_queue_handle: Option, + ); fn ffi_cancel(&self); fn ffi_complete(&self, call_status: &mut RustCallStatus) -> ReturnType; fn ffi_free(self: Arc); @@ -302,8 +321,13 @@ where T: LowerReturn + Send + 'static, UT: Send + 'static, { - fn ffi_poll(self: Arc, callback: RustFutureContinuationCallback, data: *const ()) { - self.poll(callback, data) + fn ffi_poll( + self: Arc, + callback: RustFutureContinuationCallback, + data: *const (), + blocking_task_queue_handle: Option, + ) { + self.poll(callback, data, blocking_task_queue_handle) } fn ffi_cancel(&self) { diff --git a/uniffi_core/src/ffi/rustfuture/mod.rs b/uniffi_core/src/ffi/rustfuture/mod.rs index 4aaf013fd5..a7a1142c7a 100644 --- a/uniffi_core/src/ffi/rustfuture/mod.rs +++ b/uniffi_core/src/ffi/rustfuture/mod.rs @@ -4,8 +4,11 @@ use std::{future::Future, sync::Arc}; +mod blocking_task_queue; mod future; mod scheduler; + +pub use blocking_task_queue::*; use future::*; use scheduler::*; @@ -28,10 +31,23 @@ pub enum RustFuturePoll { /// /// The Rust side of things calls this when the foreign side should call [rust_future_poll] again /// to continue progress on the future. -pub type RustFutureContinuationCallback = extern "C" fn(callback_data: *const (), RustFuturePoll); +/// +/// WARNING: the call to [rust_future_poll] must be scheduled to happen soon after the callback is +/// called, but not inside the callback itself. If [rust_future_poll] is called inside the +/// callback, some futures will deadlock and our scheduler code might as well. +/// +/// * `callback_data` is the handle that the foreign code passed to `poll()` +/// * `poll_result` is the result of the poll +/// * If `blocking_task_task_queue` is non-zero, it's the BlockingTaskQueue handle that the next `poll()` should run on +pub type RustFutureContinuationCallback = extern "C" fn( + callback_data: *const (), + poll_result: RustFuturePoll, + blocking_task_queue_handle: u64, +); /// Opaque handle for a Rust future that's stored by the foreign language code #[repr(transparent)] +#[derive(Debug)] pub struct RustFutureHandle(*const ()); // === Public FFI API === @@ -69,16 +85,22 @@ where /// a [RustFuturePoll] value. For each [rust_future_poll] call the continuation will be called /// exactly once. /// +/// If this is running in a BlockingTaskQueue, then `blocking_task_queue_handle` must be the handle +/// for it. If not, `blocking_task_queue_handle` must be `0`. +/// /// # Safety /// /// The [RustFutureHandle] must not previously have been passed to [rust_future_free] pub unsafe fn rust_future_poll( - handle: RustFutureHandle, + future: RustFutureHandle, callback: RustFutureContinuationCallback, data: *const (), + blocking_task_queue_handle: u64, ) { - let future = &*(handle.0 as *mut Arc>); - future.clone().ffi_poll(callback, data) + let future = &*(future.0 as *mut Arc>); + future + .clone() + .ffi_poll(callback, data, blocking_task_queue_handle.try_into().ok()) } /// Cancel a Rust future diff --git a/uniffi_core/src/ffi/rustfuture/scheduler.rs b/uniffi_core/src/ffi/rustfuture/scheduler.rs index aae5a0c1cf..4440cd7ba7 100644 --- a/uniffi_core/src/ffi/rustfuture/scheduler.rs +++ b/uniffi_core/src/ffi/rustfuture/scheduler.rs @@ -2,10 +2,89 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ -use std::mem; +use std::{cell::RefCell, future::poll_fn, mem, num::NonZeroU64, task::Poll, thread_local}; use super::{RustFutureContinuationCallback, RustFuturePoll}; +/// Context of the current `RustFuture::poll` call +struct RustFutureContext { + /// Blocking task queue that the future is being polled on + current_blocking_task_queue_handle: Option, + /// Blocking task queue that we've been asked to schedule the next poll on + scheduled_blocking_task_queue_handle: Option, +} + +thread_local! { + static CONTEXT: RefCell = RefCell::new(RustFutureContext { + current_blocking_task_queue_handle: None, + scheduled_blocking_task_queue_handle: None, + }); +} + +fn with_context R, R>(operation: F) -> R { + CONTEXT.with(|context| operation(&mut context.borrow_mut())) +} + +pub fn on_poll_start(current_blocking_task_queue_handle: Option) { + with_context(|context| { + *context = RustFutureContext { + current_blocking_task_queue_handle, + scheduled_blocking_task_queue_handle: None, + } + }); +} + +pub fn on_poll_end() { + with_context(|context| { + *context = RustFutureContext { + current_blocking_task_queue_handle: None, + scheduled_blocking_task_queue_handle: None, + } + }); +} + +/// Schedule work in a blocking task queue +/// +/// The returned future will attempt to arrange for [RustFuture::poll] to be called in the +/// blocking task queue. Once [RustFuture::poll] is running in the blocking task queue, then the future +/// will be ready. +/// +/// There's one tricky issue here: how can we ensure that when the top-level task is run in the +/// blocking task queue, this future will be polled? What happens this future is a child of `join!`, +/// `FuturesUnordered` or some other Future that handles its own polling? +/// +/// We start with an assumption: if we notify the waker then this future will be polled when the +/// top-level task is polled next. If a future does not honor this then we consider it a broken +/// future. This seems fair, since that future would almost certainly break a lot of other future +/// code. +/// +/// Based on that, we can have a simple system. When we're polled: +/// * If we're running in the blocking task queue, then we return `Poll::Ready`. +/// * If not, we return `Poll::Pending` and notify the waker so that the future polls again on +/// the next top-level poll. +/// +/// Note that this can be inefficient if the code awaits multiple blocking task queues at once. We +/// can only run the next poll on one of them, but all futures will be woken up. This seems okay +/// for our intended use cases, it would be pretty odd for a library to use multiple blocking task +/// queues. The alternative would be to store the set of all pending blocking task queues, which +/// seems like complete overkill for our purposes. +pub(super) async fn schedule_in_blocking_task_queue(handle: NonZeroU64) { + poll_fn(|future_context| { + with_context(|poll_context| { + if poll_context.current_blocking_task_queue_handle == Some(handle) { + Poll::Ready(()) + } else { + poll_context + .scheduled_blocking_task_queue_handle + .get_or_insert(handle); + future_context.waker().wake_by_ref(); + Poll::Pending + } + }) + }) + .await +} + /// Schedules a [crate::RustFuture] by managing the continuation data /// /// This struct manages the continuation callback and data that comes from the foreign side. It @@ -41,21 +120,34 @@ impl Scheduler { /// Store new continuation data if we are in the `Empty` state. If we are in the `Waked` or /// `Cancelled` state, call the continuation immediately with the data. pub(super) fn store(&mut self, callback: RustFutureContinuationCallback, data: *const ()) { + if let Some(blocking_task_queue_handle) = + with_context(|context| context.scheduled_blocking_task_queue_handle) + { + // We were asked to schedule the future in a blocking task queue, call the callback + // rather than storing it + callback( + data, + RustFuturePoll::MaybeReady, + blocking_task_queue_handle.into(), + ); + return; + } + match self { Self::Empty => *self = Self::Set(callback, data), Self::Set(old_callback, old_data) => { log::error!( "store: observed `Self::Set` state. Is poll() being called from multiple threads at once?" ); - old_callback(*old_data, RustFuturePoll::Ready); + old_callback(*old_data, RustFuturePoll::Ready, 0); *self = Self::Set(callback, data); } Self::Waked => { *self = Self::Empty; - callback(data, RustFuturePoll::MaybeReady); + callback(data, RustFuturePoll::MaybeReady, 0); } Self::Cancelled => { - callback(data, RustFuturePoll::Ready); + callback(data, RustFuturePoll::Ready, 0); } } } @@ -67,7 +159,7 @@ impl Scheduler { let old_data = *old_data; let callback = *callback; *self = Self::Empty; - callback(old_data, RustFuturePoll::MaybeReady); + callback(old_data, RustFuturePoll::MaybeReady, 0); } // If we were in the `Empty` state, then transition to `Waked`. The next time `store` // is called, we will immediately call the continuation. @@ -79,7 +171,13 @@ impl Scheduler { pub(super) fn cancel(&mut self) { if let Self::Set(callback, old_data) = mem::replace(self, Self::Cancelled) { - callback(old_data, RustFuturePoll::Ready); + callback(old_data, RustFuturePoll::Ready, 0); + } + } + + pub(super) fn clear_wake_flag(&mut self) { + if let Self::Waked = self { + *self = Self::Empty } } diff --git a/uniffi_core/src/ffi/rustfuture/tests.rs b/uniffi_core/src/ffi/rustfuture/tests.rs index 1f68085562..c988678678 100644 --- a/uniffi_core/src/ffi/rustfuture/tests.rs +++ b/uniffi_core/src/ffi/rustfuture/tests.rs @@ -65,16 +65,38 @@ fn channel() -> (Sender, Arc>) { } /// Poll a Rust future and get an OnceCell that's set when the continuation is called -fn poll(rust_future: &Arc>) -> Arc> { +fn poll(rust_future: &Arc>) -> Arc> { let cell = Arc::new(OnceCell::new()); let cell_ptr = Arc::into_raw(cell.clone()) as *const (); - rust_future.clone().ffi_poll(poll_continuation, cell_ptr); + rust_future + .clone() + .ffi_poll(poll_continuation, cell_ptr, None); cell } -extern "C" fn poll_continuation(data: *const (), code: RustFuturePoll) { - let cell = unsafe { Arc::from_raw(data as *const OnceCell) }; - cell.set(code).expect("Error setting OnceCell"); +/// Like poll, but simulate `poll()` being called from a blocking task queue +fn poll_from_blocking_task_queue( + rust_future: &Arc>, + blocking_task_queue_handle: u64, +) -> Arc> { + let cell = Arc::new(OnceCell::new()); + let cell_ptr = Arc::into_raw(cell.clone()) as *const (); + rust_future.clone().ffi_poll( + poll_continuation, + cell_ptr, + Some(blocking_task_queue_handle.try_into().unwrap()), + ); + cell +} + +extern "C" fn poll_continuation( + data: *const (), + code: RustFuturePoll, + blocking_task_queue_handle: u64, +) { + let cell = unsafe { Arc::from_raw(data as *const OnceCell<(RustFuturePoll, u64)>) }; + cell.set((code, blocking_task_queue_handle)) + .expect("Error setting OnceCell"); } fn complete(rust_future: Arc>) -> (RustBuffer, RustCallStatus) { @@ -83,25 +105,47 @@ fn complete(rust_future: Arc>) -> (RustBuffer, Rus (return_value, out_status_code) } +fn check_continuation_not_called(once_cell: &OnceCell<(RustFuturePoll, u64)>) { + assert_eq!(once_cell.get(), None); +} + +fn check_continuation_called( + once_cell: &OnceCell<(RustFuturePoll, u64)>, + poll_result: RustFuturePoll, +) { + assert_eq!(once_cell.get(), Some(&(poll_result, 0))); +} + +fn check_continuation_called_with_blocking_task_queue_handle( + once_cell: &OnceCell<(RustFuturePoll, u64)>, + poll_result: RustFuturePoll, + blocking_task_queue_handle: u64, +) { + assert_eq!( + once_cell.get(), + Some(&(poll_result, blocking_task_queue_handle)) + ) +} + #[test] fn test_success() { let (sender, rust_future) = channel(); // Test polling the rust future before it's ready let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), None); + check_continuation_not_called(&continuation_result); sender.wake(); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::MaybeReady)); + check_continuation_called(&continuation_result, RustFuturePoll::MaybeReady); // Test polling the rust future when it's ready let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), None); + check_continuation_not_called(&continuation_result); sender.send(Ok("All done".into())); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::MaybeReady)); + check_continuation_called(&continuation_result, RustFuturePoll::MaybeReady); // Future polls should immediately return ready let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); // Complete the future let (return_buf, call_status) = complete(rust_future); @@ -117,12 +161,12 @@ fn test_error() { let (sender, rust_future) = channel(); let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), None); + check_continuation_not_called(&continuation_result); sender.send(Err("Something went wrong".into())); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::MaybeReady)); + check_continuation_called(&continuation_result, RustFuturePoll::MaybeReady); let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); let (_, call_status) = complete(rust_future); assert_eq!(call_status.code, RustCallStatusCode::Error); @@ -144,14 +188,14 @@ fn test_cancel() { let (_sender, rust_future) = channel(); let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), None); + check_continuation_not_called(&continuation_result); rust_future.ffi_cancel(); // Cancellation should immediately invoke the callback with RustFuturePoll::Ready - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); // Future polls should immediately invoke the callback with RustFuturePoll::Ready let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); let (_, call_status) = complete(rust_future); assert_eq!(call_status.code, RustCallStatusCode::Cancelled); @@ -187,7 +231,7 @@ fn test_complete_with_stored_continuation() { let continuation_result = poll(&rust_future); rust_future.ffi_free(); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); } // Test what happens if we see a `wake()` call while we're polling the future. This can @@ -210,10 +254,47 @@ fn test_wake_during_poll() { let rust_future: Arc> = RustFuture::new(future, crate::UniFfiTag); let continuation_result = poll(&rust_future); // The continuation function should called immediately - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::MaybeReady)); + check_continuation_called(&continuation_result, RustFuturePoll::MaybeReady); // A second poll should finish the future let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); + let (return_buf, call_status) = complete(rust_future); + assert_eq!(call_status.code, RustCallStatusCode::Success); + assert_eq!( + >::try_lift(return_buf).unwrap(), + "All done" + ); +} + +#[test] +fn test_blocking_task() { + let blocking_task_queue_handle = 1001; + let future = async move { + schedule_in_blocking_task_queue(blocking_task_queue_handle.try_into().unwrap()).await; + "All done".to_owned() + }; + let rust_future: Arc> = RustFuture::new(future, crate::UniFfiTag); + // On the first poll, the future should not be ready and it should ask to be scheduled in the + // blocking task queue + let continuation_result = poll(&rust_future); + check_continuation_called_with_blocking_task_queue_handle( + &continuation_result, + RustFuturePoll::MaybeReady, + blocking_task_queue_handle, + ); + // If we poll it again not in a blocking task queue, then we get the same result + let continuation_result = poll(&rust_future); + check_continuation_called_with_blocking_task_queue_handle( + &continuation_result, + RustFuturePoll::MaybeReady, + blocking_task_queue_handle, + ); + // When we poll it in the blocking task queue, then the future is ready + let continuation_result = + poll_from_blocking_task_queue(&rust_future, blocking_task_queue_handle); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); + + // Complete the future let (return_buf, call_status) = complete(rust_future); assert_eq!(call_status.code, RustCallStatusCode::Success); assert_eq!( diff --git a/uniffi_core/src/ffi_converter_impls.rs b/uniffi_core/src/ffi_converter_impls.rs index e2f9109a60..f452d2cb77 100644 --- a/uniffi_core/src/ffi_converter_impls.rs +++ b/uniffi_core/src/ffi_converter_impls.rs @@ -23,8 +23,8 @@ /// "UT" means an abitrary `UniFfiTag` type. use crate::{ check_remaining, derive_ffi_traits, ffi_converter_rust_buffer_lift_and_lower, metadata, - ConvertError, FfiConverter, Lift, LiftReturn, Lower, LowerReturn, MetadataBuffer, Result, - RustBuffer, UnexpectedUniFFICallbackError, + BlockingTaskQueue, BlockingTaskQueueVTable, ConvertError, FfiConverter, Lift, LiftReturn, + Lower, LowerReturn, MetadataBuffer, Result, RustBuffer, UnexpectedUniFFICallbackError, }; use anyhow::bail; use bytes::buf::{Buf, BufMut}; @@ -244,6 +244,35 @@ unsafe impl FfiConverter for Duration { const TYPE_ID_META: MetadataBuffer = MetadataBuffer::from_code(metadata::codes::TYPE_DURATION); } +/// Support for passing [BlockingTaskQueue] across the FFI +/// +/// Both fields of [BlockingTaskQueue] are serialized into a RustBuffer. The vtable pointer is +/// casted to a u64. +unsafe impl FfiConverter for BlockingTaskQueue { + ffi_converter_rust_buffer_lift_and_lower!(UT); + + fn write(obj: BlockingTaskQueue, buf: &mut Vec) { + let obj = obj.clone(); + buf.put_u64(obj.handle.into()); + buf.put_u64(obj.vtable as *const BlockingTaskQueueVTable as u64); + } + + fn try_read(buf: &mut &[u8]) -> Result { + check_remaining(buf, 16)?; + let handle = buf + .get_u64() + .try_into() + .expect("handle = 0 when reading BlockingTaskQueue"); + let vtable = unsafe { + &*(buf.get_u64() as *const BlockingTaskQueueVTable) as &'static BlockingTaskQueueVTable + }; + Ok(Self { handle, vtable }) + } + + const TYPE_ID_META: MetadataBuffer = + MetadataBuffer::from_code(metadata::codes::TYPE_BLOCKING_TASK_QUEUE); +} + // Support for passing optional values via the FFI. // // Optional values are currently always passed by serializing to a buffer. @@ -419,6 +448,7 @@ derive_ffi_traits!(blanket bool); derive_ffi_traits!(blanket String); derive_ffi_traits!(blanket Duration); derive_ffi_traits!(blanket SystemTime); +derive_ffi_traits!(blanket BlockingTaskQueue); // For composite types, derive LowerReturn, LiftReturn, etc, from Lift/Lower. // diff --git a/uniffi_core/src/metadata.rs b/uniffi_core/src/metadata.rs index f6a42e9876..9667bc627c 100644 --- a/uniffi_core/src/metadata.rs +++ b/uniffi_core/src/metadata.rs @@ -67,6 +67,7 @@ pub mod codes { pub const TYPE_CUSTOM: u8 = 22; pub const TYPE_RESULT: u8 = 23; pub const TYPE_FUTURE: u8 = 24; + pub const TYPE_BLOCKING_TASK_QUEUE: u8 = 25; pub const TYPE_UNIT: u8 = 255; // Literal codes for LiteralMetadata - note that we don't support diff --git a/uniffi_macros/src/setup_scaffolding.rs b/uniffi_macros/src/setup_scaffolding.rs index 1d0c368504..425aee2540 100644 --- a/uniffi_macros/src/setup_scaffolding.rs +++ b/uniffi_macros/src/setup_scaffolding.rs @@ -166,8 +166,13 @@ fn rust_future_scaffolding_fns(module_path: &str) -> TokenStream { #[allow(clippy::missing_safety_doc, missing_docs)] #[doc(hidden)] #[no_mangle] - pub unsafe extern "C" fn #ffi_rust_future_poll(handle: ::uniffi::RustFutureHandle, callback: ::uniffi::RustFutureContinuationCallback, data: *const ()) { - ::uniffi::ffi::rust_future_poll::<#return_type>(handle, callback, data); + pub unsafe extern "C" fn #ffi_rust_future_poll( + handle: ::uniffi::RustFutureHandle, + callback: ::uniffi::RustFutureContinuationCallback, + data: *const (), + blocking_task_queue_handle: u64, + ) { + ::uniffi::ffi::rust_future_poll::<#return_type>(handle, callback, data, blocking_task_queue_handle); } #[allow(clippy::missing_safety_doc, missing_docs)] diff --git a/uniffi_meta/src/metadata.rs b/uniffi_meta/src/metadata.rs index 7506b9d7ab..a5a04bfb16 100644 --- a/uniffi_meta/src/metadata.rs +++ b/uniffi_meta/src/metadata.rs @@ -50,6 +50,7 @@ pub mod codes { pub const TYPE_CUSTOM: u8 = 22; pub const TYPE_RESULT: u8 = 23; //pub const TYPE_FUTURE: u8 = 24; + pub const TYPE_BLOCKING_TASK_QUEUE: u8 = 25; pub const TYPE_UNIT: u8 = 255; // Literal codes diff --git a/uniffi_meta/src/reader.rs b/uniffi_meta/src/reader.rs index fa7f4447e9..e3875bb407 100644 --- a/uniffi_meta/src/reader.rs +++ b/uniffi_meta/src/reader.rs @@ -144,6 +144,7 @@ impl<'a> MetadataReader<'a> { codes::TYPE_STRING => Type::String, codes::TYPE_DURATION => Type::Duration, codes::TYPE_SYSTEM_TIME => Type::Timestamp, + codes::TYPE_BLOCKING_TASK_QUEUE => Type::BlockingTaskQueue, codes::TYPE_RECORD => Type::Record { module_path: self.read_string()?, name: self.read_string()?, diff --git a/uniffi_meta/src/types.rs b/uniffi_meta/src/types.rs index 647f4e9929..e0ec13991d 100644 --- a/uniffi_meta/src/types.rs +++ b/uniffi_meta/src/types.rs @@ -86,6 +86,7 @@ pub enum Type { // How the object is implemented. imp: ObjectImpl, }, + BlockingTaskQueue, // Types defined in the component API, each of which has a string name. Record { module_path: String, diff --git a/uniffi_udl/src/resolver.rs b/uniffi_udl/src/resolver.rs index ea98cd7a99..1409c3a6ff 100644 --- a/uniffi_udl/src/resolver.rs +++ b/uniffi_udl/src/resolver.rs @@ -209,6 +209,7 @@ pub(crate) fn resolve_builtin_type(name: &str) -> Option { "f64" => Some(Type::Float64), "timestamp" => Some(Type::Timestamp), "duration" => Some(Type::Duration), + "BlockingTaskQueue" => Some(Type::BlockingTaskQueue), _ => None, } }