diff --git a/docs/manual/src/futures.md b/docs/manual/src/futures.md index 26a7e420be..f85074dcbc 100644 --- a/docs/manual/src/futures.md +++ b/docs/manual/src/futures.md @@ -93,3 +93,62 @@ In this case, we need an event loop to run the Python async function, but there' Use `uniffi_set_event_loop()` to handle this case. It should be called before the Rust code makes the async call and passed an eventloop to use. +## Blocking tasks + +Rust executors are designed around an assumption that the `Future::poll` function will return quickly. +This assumption, combined with cooperative scheduling, allows for a large number of futures to be handled by a small number of threads. +Foreign executors make similar assumptions and sometimes more extreme ones. +For example, the Python eventloop is single threaded -- if any task spends a long time between `await` points, then it will block all other tasks from progressing. + +This raises the question of how async code can interact with blocking code that performs blocking IO, long-running computations without `await` breaks, etc. +UniFFI defines the `BlockingTaskQueue` type, which is a foreign object that schedules work on a thread where it's okay to block. + +On Rust, `BlockingTaskQueue` is a UniFFI type that can safely run blocking code. +It's `execute` method works like tokio's [block_in_place](https://docs.rs/tokio/latest/tokio/task/fn.block_in_place.html) function. +It inputs a closure and runs it in the `BlockingTaskQueue`. +This closure can reference the outside scope (i.e. it does not need to be `'static`). +For example: + +```rust +#[derive(uniffi::Object)] +struct DataStore { + // Used to run blocking tasks + queue: uniffi::BlockingTaskQueue, + // Low-level DB object with blocking methods + db: Mutex, +} + +#[uniffi::export] +impl DataStore { + #[uniffi::constructor] + fn new(queue: uniffi::BlockingTaskQueue) -> Self { + Self { + queue, + db: Mutex::new(Database::new()) + } + } + + async fn fetch_all_items(&self) -> Vec { + self.queue.execute(|| self.db.lock().fetch_all_items()).await + } +} +``` + +On the foreign side `BlockingTaskQueue` corresponds to a language-dependent class. + +### Kotlin +Kotlin uses `CoroutineContext` for its `BlockingTaskQueue`. +Any `CoroutineContext` will work, but `Dispatchers.IO` is usually a good choice. +A DataStore from the example above can be created with `DataStore(Dispatchers.IO)`. + +### Swift +Swift uses `DispatchQueue` for its `BlockingTaskQueue`. +The user-initiated global queue is normally a good choice. +A DataStore from the example above can be created with `DataStore(queue: DispatchQueue.global(qos: .userInitiated)`. +The `DispatchQueue` should be concurrent. + +### Python + +Python uses a `futures.Executor` for its `BlockingTaskQueue`. +`ThreadPoolExecutor` is typically a good choice. +A DataStore from the example above can be created with `DataStore(ThreadPoolExecutor())`. diff --git a/fixtures/futures/src/lib.rs b/fixtures/futures/src/lib.rs index 15bc32b9cf..6869c70a8c 100644 --- a/fixtures/futures/src/lib.rs +++ b/fixtures/futures/src/lib.rs @@ -11,7 +11,10 @@ use std::{ time::Duration, }; -use futures::future::{AbortHandle, Abortable, Aborted}; +use futures::{ + future::{AbortHandle, Abortable, Aborted}, + stream::{FuturesUnordered, StreamExt}, +}; /// Non-blocking timer future. pub struct TimerFuture { @@ -456,4 +459,58 @@ async fn cancel_delay_using_trait(obj: Arc, delay_ms: i32) { assert_eq!(future.await, Err(Aborted)); } +/// Async function that uses a blocking task queue to do its work +#[uniffi::export] +pub async fn calc_square(queue: uniffi::BlockingTaskQueue, value: i32) -> i32 { + queue.execute(|| value * value).await +} + +/// Same as before, but this one runs multiple tasks +#[uniffi::export] +pub async fn calc_squares(queue: uniffi::BlockingTaskQueue, items: Vec) -> Vec { + // Use `FuturesUnordered` to test our blocking task queue code which is known to be a tricky API to work with. + // In particular, if we don't notify the waker then FuturesUnordered will not poll again. + let mut futures: FuturesUnordered<_> = (0..items.len()) + .map(|i| { + // Test that we can use references from the surrounding scope + let items = &items; + queue.execute(move || items[i] * items[i]) + }) + .collect(); + let mut results = vec![]; + while let Some(result) = futures.next().await { + results.push(result); + } + results.sort(); + results +} + +/// ...and this one uses multiple BlockingTaskQueues +#[uniffi::export] +pub async fn calc_squares_multi_queue( + queues: Vec, + items: Vec, +) -> Vec { + let mut futures: FuturesUnordered<_> = (0..items.len()) + .map(|i| { + // Test that we can use references from the surrounding scope + let items = &items; + queues[i].execute(move || items[i] * items[i]) + }) + .collect(); + let mut results = vec![]; + while let Some(result) = futures.next().await { + results.push(result); + } + results.sort(); + results +} + +/// Like calc_square, but it clones the BlockingTaskQueue first then drops both copies. Used to +/// test that a) the clone works and b) we correctly drop the references. +#[uniffi::export] +pub async fn calc_square_with_clone(queue: uniffi::BlockingTaskQueue, value: i32) -> i32 { + queue.clone().execute(|| value * value).await +} + uniffi::include_scaffolding!("futures"); diff --git a/fixtures/futures/tests/bindings/test_futures.kts b/fixtures/futures/tests/bindings/test_futures.kts index f853ddb4ea..37c988e2cf 100644 --- a/fixtures/futures/tests/bindings/test_futures.kts +++ b/fixtures/futures/tests/bindings/test_futures.kts @@ -1,9 +1,22 @@ import uniffi.fixture.futures.* +import java.util.concurrent.Executors import kotlinx.coroutines.* import kotlin.system.* +fun runAsyncTest(test: suspend CoroutineScope.() -> Unit) { + val initialBlockingTaskQueueHandleCount = uniffiBlockingTaskQueueHandleCount() + val initialPollHandleCount = uniffiPollHandleCount() + val time = runBlocking { + measureTimeMillis { + test() + } + } + assert(uniffiBlockingTaskQueueHandleCount() == initialBlockingTaskQueueHandleCount) + assert(uniffiPollHandleCount() == initialPollHandleCount) +} + // init UniFFI to get good measurements after that -runBlocking { +runAsyncTest { val time = measureTimeMillis { alwaysReady() } @@ -24,7 +37,7 @@ fun assertApproximateTime(actualTime: Long, expectedTime: Int, testName: String } // Test `always_ready`. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val result = alwaysReady() @@ -35,7 +48,7 @@ runBlocking { } // Test `void`. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val result = void() @@ -46,7 +59,7 @@ runBlocking { } // Test `sleep`. -runBlocking { +runAsyncTest { val time = measureTimeMillis { sleep(200U) } @@ -55,7 +68,7 @@ runBlocking { } // Test sequential futures. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val resultAlice = sayAfter(100U, "Alice") val resultBob = sayAfter(200U, "Bob") @@ -68,7 +81,7 @@ runBlocking { } // Test concurrent futures. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val resultAlice = async { sayAfter(100U, "Alice") } val resultBob = async { sayAfter(200U, "Bob") } @@ -81,7 +94,7 @@ runBlocking { } // Test async methods. -runBlocking { +runAsyncTest { val megaphone = newMegaphone() val time = measureTimeMillis { val resultAlice = megaphone.sayAfter(200U, "Alice") @@ -92,7 +105,7 @@ runBlocking { assertApproximateTime(time, 200, "async methods") } -runBlocking { +runAsyncTest { val megaphone = newMegaphone() val time = measureTimeMillis { val resultAlice = sayAfterWithMegaphone(megaphone, 200U, "Alice") @@ -104,7 +117,7 @@ runBlocking { } // Test async method returning optional object -runBlocking { +runAsyncTest { val megaphone = asyncMaybeNewMegaphone(true) assert(megaphone != null) @@ -213,7 +226,7 @@ runBlocking { // Test with the Tokio runtime. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val resultAlice = sayAfterWithTokio(200U, "Alice") @@ -224,7 +237,7 @@ runBlocking { } // Test fallible function/method. -runBlocking { +runAsyncTest { val time1 = measureTimeMillis { try { fallibleMe(false) @@ -289,7 +302,7 @@ runBlocking { } // Test record. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val result = newMyRecord("foo", 42U) @@ -303,7 +316,7 @@ runBlocking { } // Test a broken sleep. -runBlocking { +runAsyncTest { val time = measureTimeMillis { brokenSleep(100U, 0U) // calls the waker twice immediately sleep(100U) // wait for possible failure @@ -317,7 +330,7 @@ runBlocking { // Test a future that uses a lock and that is cancelled. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val job = launch { useSharedResource(SharedResourceOptions(releaseAfterMs=5000U, timeoutMs=100U)) @@ -336,7 +349,7 @@ runBlocking { } // Test a future that uses a lock and that is not cancelled. -runBlocking { +runAsyncTest { val time = measureTimeMillis { useSharedResource(SharedResourceOptions(releaseAfterMs=100U, timeoutMs=1000U)) @@ -344,3 +357,33 @@ runBlocking { } println("useSharedResource (not canceled): ${time}ms") } + +// Test blocking task queues +runAsyncTest { + withTimeout(1000) { + assert(calcSquare(Dispatchers.IO, 20) == 400) + } + + withTimeout(1000) { + assert(calcSquares(Dispatchers.IO, listOf(1, -2, 3)) == listOf(1, 4, 9)) + } + + val executors = listOf( + Executors.newSingleThreadExecutor(), + Executors.newSingleThreadExecutor(), + Executors.newSingleThreadExecutor(), + ) + withTimeout(1000) { + assert(calcSquaresMultiQueue(executors.map { it.asCoroutineDispatcher() }, listOf(1, -2, 3)) == listOf(1, 4, 9)) + } + for (executor in executors) { + executor.shutdown() + } +} + +// Test blocking task queue cloning +runAsyncTest { + withTimeout(1000) { + assert(calcSquareWithClone(Dispatchers.IO, 20) == 400) + } +} diff --git a/fixtures/futures/tests/bindings/test_futures.py b/fixtures/futures/tests/bindings/test_futures.py index 1b84451b5d..8c1a24184f 100644 --- a/fixtures/futures/tests/bindings/test_futures.py +++ b/fixtures/futures/tests/bindings/test_futures.py @@ -1,27 +1,33 @@ +import futures from futures import * +import contextlib import unittest from datetime import datetime import asyncio import typing import futures +from concurrent.futures import ThreadPoolExecutor def now(): return datetime.now() class TestFutures(unittest.TestCase): def test_always_ready(self): + @self.check_handle_counts() async def test(): self.assertEqual(await always_ready(), True) asyncio.run(test()) def test_void(self): + @self.check_handle_counts() async def test(): self.assertEqual(await void(), None) asyncio.run(test()) def test_sleep(self): + @self.check_handle_counts() async def test(): t0 = now() await sleep(2000) @@ -33,6 +39,7 @@ async def test(): asyncio.run(test()) def test_sequential_futures(self): + @self.check_handle_counts() async def test(): t0 = now() result_alice = await say_after(100, 'Alice') @@ -47,6 +54,7 @@ async def test(): asyncio.run(test()) def test_concurrent_tasks(self): + @self.check_handle_counts() async def test(): alice = asyncio.create_task(say_after(100, 'Alice')) bob = asyncio.create_task(say_after(200, 'Bob')) @@ -64,6 +72,7 @@ async def test(): asyncio.run(test()) def test_async_methods(self): + @self.check_handle_counts() async def test(): megaphone = new_megaphone() t0 = now() @@ -162,6 +171,7 @@ async def test(): self.assertEqual(len(futures.UNIFFI_FOREIGN_FUTURE_HANDLE_MAP), 0) def test_async_object_param(self): + @self.check_handle_counts() async def test(): megaphone = new_megaphone() t0 = now() @@ -175,6 +185,7 @@ async def test(): asyncio.run(test()) def test_with_tokio_runtime(self): + @self.check_handle_counts() async def test(): t0 = now() result_alice = await say_after_with_tokio(200, 'Alice') @@ -187,6 +198,7 @@ async def test(): asyncio.run(test()) def test_fallible(self): + @self.check_handle_counts() async def test(): result = await fallible_me(False) self.assertEqual(result, 42) @@ -211,6 +223,7 @@ async def test(): asyncio.run(test()) def test_fallible_struct(self): + @self.check_handle_counts() async def test(): megaphone = await fallible_struct(False) self.assertEqual(await megaphone.fallible_me(False), 42) @@ -224,6 +237,7 @@ async def test(): asyncio.run(test()) def test_record(self): + @self.check_handle_counts() async def test(): result = await new_my_record("foo", 42) self.assertEqual(result.__class__, MyRecord) @@ -233,6 +247,7 @@ async def test(): asyncio.run(test()) def test_cancel(self): + @self.check_handle_counts() async def test(): # Create a task task = asyncio.create_task(say_after(200, 'Alice')) @@ -250,6 +265,7 @@ async def test(): # Test a future that uses a lock and that is cancelled. def test_shared_resource_cancellation(self): + @self.check_handle_counts() async def test(): task = asyncio.create_task(use_shared_resource( SharedResourceOptions(release_after_ms=5000, timeout_ms=100))) @@ -260,6 +276,7 @@ async def test(): asyncio.run(test()) def test_shared_resource_no_cancellation(self): + @self.check_handle_counts() async def test(): await use_shared_resource(SharedResourceOptions(release_after_ms=100, timeout_ms=1000)) await use_shared_resource(SharedResourceOptions(release_after_ms=0, timeout_ms=1000)) @@ -271,5 +288,47 @@ async def test(): self.assertEqual(typing.get_type_hints(sleep_no_return), {"ms": int, "return": type(None)}) asyncio.run(test()) + # blocking task queue tests + + def test_calc_square(self): + @self.check_handle_counts() + async def test(): + executor = ThreadPoolExecutor() + self.assertEqual(await calc_square(executor, 20), 400) + asyncio.run(asyncio.wait_for(test(), timeout=1)) + + def test_calc_square_with_clone(self): + @self.check_handle_counts() + async def test(): + executor = ThreadPoolExecutor() + self.assertEqual(await calc_square_with_clone(executor, 20), 400) + asyncio.run(asyncio.wait_for(test(), timeout=1)) + + def test_calc_squares(self): + @self.check_handle_counts() + async def test(): + executor = ThreadPoolExecutor() + self.assertEqual(await calc_squares(executor, [1, -2, 3]), [1, 4, 9]) + asyncio.run(asyncio.wait_for(test(), timeout=1)) + + def test_calc_squares_multi_queue(self): + @self.check_handle_counts() + async def test(): + executors = [ + ThreadPoolExecutor(), + ThreadPoolExecutor(), + ThreadPoolExecutor(), + ] + self.assertEqual(await calc_squares_multi_queue(executors, [1, -2, 3]), [1, 4, 9]) + asyncio.run(asyncio.wait_for(test(), timeout=1)) + + @contextlib.asynccontextmanager + async def check_handle_counts(self): + initial_poll_handle_count = len(futures._UniffiPollDataHandleMap) + initial_blocking_task_queue_handle_count = len(futures._UniffiBlockingTaskQueueHandleMap) + yield + self.assertEqual(len(futures._UniffiPollDataHandleMap), initial_poll_handle_count) + self.assertEqual(len(futures._UniffiBlockingTaskQueueHandleMap), initial_blocking_task_queue_handle_count) + if __name__ == '__main__': unittest.main() diff --git a/fixtures/futures/tests/bindings/test_futures.swift b/fixtures/futures/tests/bindings/test_futures.swift index 11dacd870e..22b1aa5f64 100644 --- a/fixtures/futures/tests/bindings/test_futures.swift +++ b/fixtures/futures/tests/bindings/test_futures.swift @@ -3,10 +3,21 @@ import Foundation // To get `DispatchGroup` and `Date` types. var counter = DispatchGroup() -// Test `alwaysReady` -counter.enter() +func asyncTest(test: @escaping () async throws -> ()) { + let initialBlockingTaskQueueCount = uniffiBlockingTaskQueueHandleCountFutures() + let initialPollDataHandleCount = uniffiPollDataHandleCountFutures() + counter.enter() + Task { + try! await test() + counter.leave() + } + counter.wait() + assert(uniffiBlockingTaskQueueHandleCountFutures() == initialBlockingTaskQueueCount) + assert(uniffiPollDataHandleCountFutures() == initialPollDataHandleCount) +} -Task { +// Test `alwaysReady` +asyncTest { let t0 = Date() let result = await alwaysReady() let t1 = Date() @@ -14,40 +25,28 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration < 0.1) assert(result == true) - - counter.leave() } // Test record. -counter.enter() - -Task { +asyncTest { let result = await newMyRecord(a: "foo", b: 42) assert(result.a == "foo") assert(result.b == 42) - - counter.leave() } // Test `void` -counter.enter() - -Task { +asyncTest { let t0 = Date() await void() let t1 = Date() let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration < 0.1) - - counter.leave() } // Test `Sleep` -counter.enter() - -Task { +asyncTest { let t0 = Date() let result = await sleep(ms: 2000) let t1 = Date() @@ -55,14 +54,10 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 2 && tDelta.duration < 2.1) assert(result == true) - - counter.leave() } // Test sequential futures. -counter.enter() - -Task { +asyncTest { let t0 = Date() let result_alice = await sayAfter(ms: 1000, who: "Alice") let result_bob = await sayAfter(ms: 2000, who: "Bob") @@ -72,14 +67,10 @@ Task { assert(tDelta.duration > 3 && tDelta.duration < 3.1) assert(result_alice == "Hello, Alice!") assert(result_bob == "Hello, Bob!") - - counter.leave() } // Test concurrent futures. -counter.enter() - -Task { +asyncTest { async let alice = sayAfter(ms: 1000, who: "Alice") async let bob = sayAfter(ms: 2000, who: "Bob") @@ -91,14 +82,10 @@ Task { assert(tDelta.duration > 2 && tDelta.duration < 2.1) assert(result_alice == "Hello, Alice!") assert(result_bob == "Hello, Bob!") - - counter.leave() } // Test async methods -counter.enter() - -Task { +asyncTest { let megaphone = newMegaphone() let t0 = Date() @@ -108,8 +95,6 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 2 && tDelta.duration < 2.1) assert(result_alice == "HELLO, ALICE!") - - counter.leave() } // Test async trait interface methods @@ -242,21 +227,15 @@ Task { } // Test async function returning an object -counter.enter() - -Task { +asyncTest { let megaphone = await asyncNewMegaphone() let result = try await megaphone.fallibleMe(doFail: false) assert(result == 42) - - counter.leave() } // Test with the Tokio runtime. -counter.enter() - -Task { +asyncTest { let t0 = Date() let result_alice = await sayAfterWithTokio(ms: 2000, who: "Alice") let t1 = Date() @@ -264,15 +243,11 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 2 && tDelta.duration < 2.1) assert(result_alice == "Hello, Alice (with Tokio)!") - - counter.leave() } // Test fallible function/method… // … which doesn't throw. -counter.enter() - -Task { +asyncTest { let t0 = Date() let result = try await fallibleMe(doFail: false) let t1 = Date() @@ -280,19 +255,15 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 0 && tDelta.duration < 0.1) assert(result == 42) - - counter.leave() } -Task { +asyncTest { let m = try await fallibleStruct(doFail: false) let result = try await m.fallibleMe(doFail: false) assert(result == 42) } -counter.enter() - -Task { +asyncTest { let megaphone = newMegaphone() let t0 = Date() @@ -302,14 +273,10 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 0 && tDelta.duration < 0.1) assert(result == 42) - - counter.leave() } // … which does throw. -counter.enter() - -Task { +asyncTest { let t0 = Date() do { @@ -324,11 +291,9 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 0 && tDelta.duration < 0.1) - - counter.leave() } -Task { +asyncTest { do { let _ = try await fallibleStruct(doFail: true) } catch MyError.Foo { @@ -338,9 +303,7 @@ Task { } } -counter.enter() - -Task { +asyncTest { let megaphone = newMegaphone() let t0 = Date() @@ -357,13 +320,10 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 0 && tDelta.duration < 0.1) - - counter.leave() } // Test a future that uses a lock and that is cancelled. -counter.enter() -Task { +asyncTest { let task = Task { try! await useSharedResource(options: SharedResourceOptions(releaseAfterMs: 100, timeoutMs: 1000)) } @@ -379,15 +339,36 @@ Task { // Try accessing the shared resource again. The initial task should release the shared resource // before the timeout expires. try! await useSharedResource(options: SharedResourceOptions(releaseAfterMs: 0, timeoutMs: 1000)) - counter.leave() } // Test a future that uses a lock and that is not cancelled. -counter.enter() -Task { +asyncTest { try! await useSharedResource(options: SharedResourceOptions(releaseAfterMs: 100, timeoutMs: 1000)) try! await useSharedResource(options: SharedResourceOptions(releaseAfterMs: 0, timeoutMs: 1000)) - counter.leave() } -counter.wait() +// Test blocking task queues +asyncTest { + let calcSquareResult = await calcSquare(queue: DispatchQueue.global(qos: .userInitiated), value: 20) + assert(calcSquareResult == 400) + + let calcSquaresResult = await calcSquares(queue: DispatchQueue.global(qos: .userInitiated), items: [1, -2, 3]) + assert(calcSquaresResult == [1, 4, 9]) + + let calcSquaresMultiQueueResult = await calcSquaresMultiQueue( + queues: [ + DispatchQueue(label: "test-queue1", attributes: DispatchQueue.Attributes.concurrent), + DispatchQueue(label: "test-queue2", attributes: DispatchQueue.Attributes.concurrent), + DispatchQueue(label: "test-queue3", attributes: DispatchQueue.Attributes.concurrent) + ], + items: [1, -2, 3] + ) + assert(calcSquaresMultiQueueResult == [1, 4, 9]) +} + +// Test blocking task queue cloning +asyncTest { + let calcSquareResult = await calcSquareWithClone(queue: DispatchQueue.global(qos: .userInitiated), value: 20) + assert(calcSquareResult == 400) +} + diff --git a/fixtures/metadata/src/tests.rs b/fixtures/metadata/src/tests.rs index 2a280597ea..1a54cbe067 100644 --- a/fixtures/metadata/src/tests.rs +++ b/fixtures/metadata/src/tests.rs @@ -129,6 +129,7 @@ mod test_type_ids { check_type_id::(Type::Float64); check_type_id::(Type::Boolean); check_type_id::(Type::String); + check_type_id::(Type::BlockingTaskQueue); } #[test] diff --git a/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/blocking_task_queue.rs b/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/blocking_task_queue.rs new file mode 100644 index 0000000000..a664f1fcd3 --- /dev/null +++ b/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/blocking_task_queue.rs @@ -0,0 +1,19 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use super::CodeType; + +#[derive(Debug)] +pub struct BlockingTaskQueueCodeType; + +impl CodeType for BlockingTaskQueueCodeType { + fn type_label(&self, _ci: &crate::ComponentInterface) -> String { + // Kotlin uses CoroutineContext for BlockingTaskQueue + "CoroutineContext".into() + } + + fn canonical_name(&self) -> String { + "BlockingTaskQueue".into() + } +} diff --git a/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs b/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs index c4fc8e0ed6..bdae1a416a 100644 --- a/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs +++ b/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs @@ -18,6 +18,7 @@ use crate::bindings::kotlin; use crate::interface::*; use crate::{BindingGenerator, BindingsConfig}; +mod blocking_task_queue; mod callback_interface; mod compounds; mod custom; @@ -493,6 +494,8 @@ impl AsCodeType for T { Type::Timestamp => Box::new(miscellany::TimestampCodeType), Type::Duration => Box::new(miscellany::DurationCodeType), + Type::BlockingTaskQueue => Box::new(blocking_task_queue::BlockingTaskQueueCodeType), + Type::Enum { name, .. } => Box::new(enum_::EnumCodeType::new(name)), Type::Object { name, imp, .. } => Box::new(object::ObjectCodeType::new(name, imp)), Type::Record { name, .. } => Box::new(record::RecordCodeType::new(name)), @@ -668,7 +671,7 @@ mod filters { ) -> Result { let ffi_func = callable.ffi_rust_future_poll(ci); Ok(format!( - "{{ future, callback, continuation -> UniffiLib.INSTANCE.{ffi_func}(future, callback, continuation) }}" + "{{ future, callback, continuation, blockingTaskQueueHandle -> UniffiLib.INSTANCE.{ffi_func}(future, callback, continuation, blockingTaskQueueHandle) }}" )) } diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt b/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt index b28fbd2c80..de764444b7 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt @@ -3,18 +3,41 @@ internal const val UNIFFI_RUST_FUTURE_POLL_READY = 0.toByte() internal const val UNIFFI_RUST_FUTURE_POLL_MAYBE_READY = 1.toByte() -internal val uniffiContinuationHandleMap = UniffiHandleMap>() +/** + * Data for an in-progress poll of a RustFuture + */ +internal data class UniffiPollData( + val continuation: CancellableContinuation, + val rustFuture: Long, + val pollFunc: (Long, UniffiRustFutureContinuationCallback, Long, Long) -> Unit, +) + +// Stores the UniffiPollData instances that correspond to RustFuture callback data +internal val uniffiPollDataHandleMap = UniffiHandleMap() + +// Stores the CoroutineContext instances that correspond to blocking task queue handles +internal val uniffiBlockingTaskQueueHandleMap = UniffiHandleMap() // FFI type for Rust future continuations internal object uniffiRustFutureContinuationCallbackImpl: UniffiRustFutureContinuationCallback { - override fun callback(data: Long, pollResult: Byte) { - uniffiContinuationHandleMap.remove(data).resume(pollResult) + override fun callback(data: Long, pollResult: Byte, blockingTaskQueueHandle: Long) { + if (blockingTaskQueueHandle == 0L) { + // Complete the Kotlin continuation + uniffiPollDataHandleMap.remove(data).continuation.resume(pollResult) + } else { + // Call the poll function again, but inside the BlockingTaskQueue coroutine context + val coroutineContext = uniffiBlockingTaskQueueHandleMap.get(blockingTaskQueueHandle) + val pollData = uniffiPollDataHandleMap.get(data) + CoroutineScope(coroutineContext).launch { + pollData.pollFunc(pollData.rustFuture, uniffiRustFutureContinuationCallbackImpl, data, blockingTaskQueueHandle) + } + } } } internal suspend fun uniffiRustCallAsync( rustFuture: Long, - pollFunc: (Long, UniffiRustFutureContinuationCallback, Long) -> Unit, + pollFunc: (Long, UniffiRustFutureContinuationCallback, Long, Long) -> Unit, completeFunc: (Long, UniffiRustCallStatus) -> F, freeFunc: (Long) -> Unit, liftFunc: (F) -> T, @@ -23,10 +46,12 @@ internal suspend fun uniffiRustCallAsync( try { do { val pollResult = suspendCancellableCoroutine { continuation -> + val pollData = UniffiPollData(continuation, rustFuture, pollFunc) pollFunc( rustFuture, uniffiRustFutureContinuationCallbackImpl, - uniffiContinuationHandleMap.insert(continuation) + uniffiPollDataHandleMap.insert(pollData), + 0L ) } } while (pollResult != UNIFFI_RUST_FUTURE_POLL_READY); @@ -113,5 +138,7 @@ internal object uniffiForeignFutureFreeImpl: UniffiForeignFutureFree { // For testing public fun uniffiForeignFutureHandleCount() = uniffiForeignFutureHandleMap.size - {%- endif %} + +// For testing +public fun uniffiPollHandleCount() = uniffiPollDataHandleMap.size diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/BlockingTaskQueueTemplate.kt b/uniffi_bindgen/src/bindings/kotlin/templates/BlockingTaskQueueTemplate.kt new file mode 100644 index 0000000000..6ec55b4a7a --- /dev/null +++ b/uniffi_bindgen/src/bindings/kotlin/templates/BlockingTaskQueueTemplate.kt @@ -0,0 +1,41 @@ +object uniffiBlockingTaskQueueClone : UniffiBlockingTaskQueueClone { + override fun callback(handle: Long): Long { + val coroutineContext = uniffiBlockingTaskQueueHandleMap.get(handle) + return uniffiBlockingTaskQueueHandleMap.insert(coroutineContext) + } +} + +object uniffiBlockingTaskQueueFree : UniffiBlockingTaskQueueFree { + override fun callback(handle: Long) { + uniffiBlockingTaskQueueHandleMap.remove(handle) + } +} + +internal val uniffiBlockingTaskQueueVTable = UniffiBlockingTaskQueueVTable( + uniffiBlockingTaskQueueClone, + uniffiBlockingTaskQueueFree, +) + +public object {{ ffi_converter_name }}: FfiConverterRustBuffer { + override fun allocationSize(value: {{ type_name }}) = 16UL + + override fun write(value: CoroutineContext, buf: ByteBuffer) { + // Call `write()` to make sure the data is written to the JNA backing data + uniffiBlockingTaskQueueVTable.write() + val handle = uniffiBlockingTaskQueueHandleMap.insert(value) + buf.putLong(handle) + buf.putLong(Pointer.nativeValue(uniffiBlockingTaskQueueVTable.getPointer())) + } + + override fun read(buf: ByteBuffer): CoroutineContext { + val handle = buf.getLong() + val coroutineContext = uniffiBlockingTaskQueueHandleMap.remove(handle) + // Read the VTable pointer and throw it out. The vtable is only used by Rust and always the + // same value. + buf.getLong() + return coroutineContext + } +} + +// For testing +public fun uniffiBlockingTaskQueueHandleCount() = uniffiBlockingTaskQueueHandleMap.size diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/HandleMap.kt b/uniffi_bindgen/src/bindings/kotlin/templates/HandleMap.kt index 3a56648190..7215339958 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/HandleMap.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/HandleMap.kt @@ -3,7 +3,8 @@ // This is used pass an opaque 64-bit handle representing a foreign object to the Rust code. internal class UniffiHandleMap { private val map = ConcurrentHashMap() - private val counter = java.util.concurrent.atomic.AtomicLong(0) + // Start at 1, since `0` represents a NULL handle. + private val counter = java.util.concurrent.atomic.AtomicLong(1) val size: Int get() = map.size diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/Types.kt b/uniffi_bindgen/src/bindings/kotlin/templates/Types.kt index c27121b701..e8cedc5650 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/Types.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/Types.kt @@ -89,6 +89,9 @@ object NoPointer {%- when Type::Bytes %} {%- include "ByteArrayHelper.kt" %} +{%- when Type::BlockingTaskQueue %} +{%- include "BlockingTaskQueueTemplate.kt" %} + {%- when Type::Enum { name, module_path } %} {%- let e = ci.get_enum_definition(name).unwrap() %} {%- if !ci.is_name_used_as_error(name) %} @@ -134,10 +137,12 @@ object NoPointer {%- if ci.has_async_fns() %} {# Import types needed for async support #} {{ self.add_import("kotlin.coroutines.resume") }} -{{ self.add_import("kotlinx.coroutines.launch") }} -{{ self.add_import("kotlinx.coroutines.suspendCancellableCoroutine") }} +{{ self.add_import("kotlin.coroutines.CoroutineContext") }} {{ self.add_import("kotlinx.coroutines.CancellableContinuation") }} +{{ self.add_import("kotlinx.coroutines.CoroutineScope") }} {{ self.add_import("kotlinx.coroutines.DelicateCoroutinesApi") }} -{{ self.add_import("kotlinx.coroutines.Job") }} {{ self.add_import("kotlinx.coroutines.GlobalScope") }} +{{ self.add_import("kotlinx.coroutines.Job") }} +{{ self.add_import("kotlinx.coroutines.launch") }} +{{ self.add_import("kotlinx.coroutines.suspendCancellableCoroutine") }} {%- endif %} diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/wrapper.kt b/uniffi_bindgen/src/bindings/kotlin/templates/wrapper.kt index 2cdc72a5e2..ea326bb242 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/wrapper.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/wrapper.kt @@ -32,6 +32,7 @@ import java.nio.CharBuffer import java.nio.charset.CodingErrorAction import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.ConcurrentHashMap +import kotlin.concurrent.withLock {%- for req in self.imports() %} {{ req.render() }} diff --git a/uniffi_bindgen/src/bindings/python/gen_python/blocking_task_queue.rs b/uniffi_bindgen/src/bindings/python/gen_python/blocking_task_queue.rs new file mode 100644 index 0000000000..2f5a74fda0 --- /dev/null +++ b/uniffi_bindgen/src/bindings/python/gen_python/blocking_task_queue.rs @@ -0,0 +1,19 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use super::CodeType; + +#[derive(Debug)] +pub struct BlockingTaskQueueCodeType; + +impl CodeType for BlockingTaskQueueCodeType { + // On python we use an concurrent.futures.Executor for a BlockingTaskQueue + fn type_label(&self) -> String { + "concurrent.futures.Executor".into() + } + + fn canonical_name(&self) -> String { + "BlockingTaskQueue".into() + } +} diff --git a/uniffi_bindgen/src/bindings/python/gen_python/mod.rs b/uniffi_bindgen/src/bindings/python/gen_python/mod.rs index 6a10a38e7f..350f1f73a7 100644 --- a/uniffi_bindgen/src/bindings/python/gen_python/mod.rs +++ b/uniffi_bindgen/src/bindings/python/gen_python/mod.rs @@ -18,6 +18,7 @@ use crate::bindings::python; use crate::interface::*; use crate::{BindingGenerator, BindingsConfig}; +mod blocking_task_queue; mod callback_interface; mod compounds; mod custom; @@ -479,6 +480,8 @@ impl AsCodeType for T { Type::Timestamp => Box::new(miscellany::TimestampCodeType), Type::Duration => Box::new(miscellany::DurationCodeType), + Type::BlockingTaskQueue => Box::new(blocking_task_queue::BlockingTaskQueueCodeType), + Type::Enum { name, .. } => Box::new(enum_::EnumCodeType::new(name)), Type::Object { name, .. } => Box::new(object::ObjectCodeType::new(name)), Type::Record { name, .. } => Box::new(record::RecordCodeType::new(name)), diff --git a/uniffi_bindgen/src/bindings/python/templates/Async.py b/uniffi_bindgen/src/bindings/python/templates/Async.py index 26daa9ba5c..57ca4e03e3 100644 --- a/uniffi_bindgen/src/bindings/python/templates/Async.py +++ b/uniffi_bindgen/src/bindings/python/templates/Async.py @@ -2,8 +2,22 @@ _UNIFFI_RUST_FUTURE_POLL_READY = 0 _UNIFFI_RUST_FUTURE_POLL_MAYBE_READY = 1 -# Stores futures for _uniffi_continuation_callback -_UniffiContinuationHandleMap = _UniffiHandleMap() +""" +Data for an in-progress poll of a RustFuture +""" +class UniffiPoll(typing.NamedTuple): + eventloop: asyncio.AbstractEventLoop + future: asyncio.Future + rust_future: int + # Must be UNIFFI_RUST_FUTURE_CONTINUATION_CALLBACK, but it's not clear how to specify as valid + # type for mypy and our current Python version + ffi_poll: object + +# Stores the UniffiPoll instances that correspond to RustFuture callback data +_UniffiPollDataHandleMap = _UniffiHandleMap() + +# Stores the concurrent.futures.Executor instances that correspond to blocking task queue handles +_UniffiBlockingTaskQueueHandleMap = _UniffiHandleMap() UNIFFI_GLOBAL_EVENT_LOOP = None @@ -32,9 +46,22 @@ def _uniffi_get_event_loop(): # Continuation callback for async functions # lift the return value or error and resolve the future, causing the async function to resume. @UNIFFI_RUST_FUTURE_CONTINUATION_CALLBACK -def _uniffi_continuation_callback(future_ptr, poll_code): - (eventloop, future) = _UniffiContinuationHandleMap.remove(future_ptr) - eventloop.call_soon_threadsafe(_uniffi_set_future_result, future, poll_code) +def _uniffi_continuation_callback(poll_data_handle, poll_code, blocking_task_queue_handle): + if blocking_task_queue_handle == 0: + # Complete the Python Future + poll_data = _UniffiPollDataHandleMap.remove(poll_data_handle) + poll_data.eventloop.call_soon_threadsafe(_uniffi_set_future_result, poll_data.future, poll_code) + else: + # Call the poll function again, but inside the executor + poll_data = _UniffiPollDataHandleMap.get(poll_data_handle) + executor = _UniffiBlockingTaskQueueHandleMap.get(blocking_task_queue_handle) + executor.submit( + poll_data.ffi_poll, + poll_data.rust_future, + _uniffi_continuation_callback, + poll_data_handle, + blocking_task_queue_handle + ) def _uniffi_set_future_result(future, poll_code): if not future.cancelled(): @@ -47,10 +74,17 @@ async def _uniffi_rust_call_async(rust_future, ffi_poll, ffi_complete, ffi_free, # Loop and poll until we see a _UNIFFI_RUST_FUTURE_POLL_READY value while True: future = eventloop.create_future() + poll_data = UniffiPoll( + eventloop=eventloop, + future=future, + rust_future=rust_future, + ffi_poll=ffi_poll, + ) ffi_poll( rust_future, _uniffi_continuation_callback, - _UniffiContinuationHandleMap.insert((eventloop, future)), + _UniffiPollDataHandleMap.insert(poll_data), + 0, ) poll_code = await future if poll_code == _UNIFFI_RUST_FUTURE_POLL_READY: diff --git a/uniffi_bindgen/src/bindings/python/templates/BlockingTaskQueueTemplate.py b/uniffi_bindgen/src/bindings/python/templates/BlockingTaskQueueTemplate.py new file mode 100644 index 0000000000..3d1a96a2c5 --- /dev/null +++ b/uniffi_bindgen/src/bindings/python/templates/BlockingTaskQueueTemplate.py @@ -0,0 +1,37 @@ +{{ self.add_import("concurrent.futures") }} + +@UNIFFI_BLOCKING_TASK_QUEUE_CLONE +def uniffi_blocking_task_queue_clone(handle): + executor = _UniffiBlockingTaskQueueHandleMap.get(handle) + return _UniffiBlockingTaskQueueHandleMap.insert(executor) + +@UNIFFI_BLOCKING_TASK_QUEUE_FREE +def uniffi_blocking_task_queue_free(handle): + _UniffiBlockingTaskQueueHandleMap.remove(handle) + +UNIFFI_BLOCKING_TASK_QUEUE_VTABLE = UniffiBlockingTaskQueueVTable( + uniffi_blocking_task_queue_clone, + uniffi_blocking_task_queue_free, +) + +class {{ ffi_converter_name }}(_UniffiConverterRustBuffer): + @staticmethod + def check_lower(value): + if not isinstance(value, concurrent.futures.Executor): + raise TypeError("Expected concurrent.futures.Executor instance, {} found".format(type(value).__name__)) + + @staticmethod + def write(value, buf): + handle = _UniffiBlockingTaskQueueHandleMap.insert(value) + buf.write_u64(handle) + buf.write_u64(ctypes.addressof(UNIFFI_BLOCKING_TASK_QUEUE_VTABLE)) + + @staticmethod + def read(buf): + handle = buf.read_u64() + executor = _UniffiBlockingTaskQueueHandleMap.remove(handle) + # Read the VTable pointer and throw it out. The vtable is only used by Rust and always the + # same value. + buf.read_u64() + return executor + diff --git a/uniffi_bindgen/src/bindings/python/templates/HandleMap.py b/uniffi_bindgen/src/bindings/python/templates/HandleMap.py index f7c13cf745..c18d17048a 100644 --- a/uniffi_bindgen/src/bindings/python/templates/HandleMap.py +++ b/uniffi_bindgen/src/bindings/python/templates/HandleMap.py @@ -7,7 +7,8 @@ def __init__(self): # type Handle = int self._map = {} # type: Dict[Handle, Any] self._lock = threading.Lock() - self._counter = itertools.count() + # Start at 1, since `0` represents a NULL handle. + self._counter = itertools.count(1) def insert(self, obj): with self._lock: diff --git a/uniffi_bindgen/src/bindings/python/templates/Types.py b/uniffi_bindgen/src/bindings/python/templates/Types.py index 4aaed253e0..9f2840fefb 100644 --- a/uniffi_bindgen/src/bindings/python/templates/Types.py +++ b/uniffi_bindgen/src/bindings/python/templates/Types.py @@ -55,6 +55,9 @@ {%- when Type::Bytes %} {%- include "BytesHelper.py" %} +{%- when Type::BlockingTaskQueue %} +{%- include "BlockingTaskQueueTemplate.py" %} + {%- when Type::Enum { name, module_path } %} {%- let e = ci.get_enum_definition(name).unwrap() %} {# For enums, there are either an error *or* an enum, they can't be both. #} diff --git a/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs b/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs index 86234597fe..b26090e6a1 100644 --- a/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs +++ b/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs @@ -81,6 +81,7 @@ pub fn canonical_name(t: &Type) -> String { Type::CallbackInterface { name, .. } => format!("CallbackInterface{name}"), Type::Timestamp => "Timestamp".into(), Type::Duration => "Duration".into(), + Type::BlockingTaskQueue => "BlockingTaskQueue".into(), // Recursive types. // These add a prefix to the name of the underlying type. // The component API definition cannot give names to recursive types, so as long as the @@ -287,6 +288,7 @@ mod filters { } Type::External { .. } => panic!("No support for external types, yet"), Type::Custom { .. } => panic!("No support for custom types, yet"), + Type::BlockingTaskQueue => panic!("No support for async functions, yet"), }) } @@ -340,6 +342,7 @@ mod filters { ), Type::External { .. } => panic!("No support for lowering external types, yet"), Type::Custom { .. } => panic!("No support for lowering custom types, yet"), + Type::BlockingTaskQueue => panic!("No support for async functions, yet"), }) } @@ -380,6 +383,7 @@ mod filters { ), Type::External { .. } => panic!("No support for lifting external types, yet"), Type::Custom { .. } => panic!("No support for lifting custom types, yet"), + Type::BlockingTaskQueue => panic!("No support for async functions, yet"), }) } } diff --git a/uniffi_bindgen/src/bindings/swift/gen_swift/blocking_task_queue.rs b/uniffi_bindgen/src/bindings/swift/gen_swift/blocking_task_queue.rs new file mode 100644 index 0000000000..ab4df07f9b --- /dev/null +++ b/uniffi_bindgen/src/bindings/swift/gen_swift/blocking_task_queue.rs @@ -0,0 +1,19 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use super::CodeType; + +#[derive(Debug)] +pub struct BlockingTaskQueueCodeType; + +impl CodeType for BlockingTaskQueueCodeType { + fn type_label(&self) -> String { + // On Swift, we use a DispatchQueue for BlockingTaskQueue + "DispatchQueue".into() + } + + fn canonical_name(&self) -> String { + "BlockingTaskQueue".into() + } +} diff --git a/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs b/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs index e784e0268e..6f02f77479 100644 --- a/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs +++ b/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs @@ -20,6 +20,7 @@ use crate::bindings::swift; use crate::interface::*; use crate::{BindingGenerator, BindingsConfig}; +mod blocking_task_queue; mod callback_interface; mod compounds; mod custom; @@ -487,6 +488,8 @@ impl SwiftCodeOracle { Type::Timestamp => Box::new(miscellany::TimestampCodeType), Type::Duration => Box::new(miscellany::DurationCodeType), + Type::BlockingTaskQueue => Box::new(blocking_task_queue::BlockingTaskQueueCodeType), + Type::Enum { name, .. } => Box::new(enum_::EnumCodeType::new(name)), Type::Object { name, imp, .. } => Box::new(object::ObjectCodeType::new(name, imp)), Type::Record { name, .. } => Box::new(record::RecordCodeType::new(name)), diff --git a/uniffi_bindgen/src/bindings/swift/templates/Async.swift b/uniffi_bindgen/src/bindings/swift/templates/Async.swift index e16f3108e1..555bba41ab 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/Async.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/Async.swift @@ -1,11 +1,32 @@ private let UNIFFI_RUST_FUTURE_POLL_READY: Int8 = 0 private let UNIFFI_RUST_FUTURE_POLL_MAYBE_READY: Int8 = 1 -fileprivate let uniffiContinuationHandleMap = UniffiHandleMap>() +// Data for an in-progress poll of a RustFuture +fileprivate class UniffiPollData { + let continuation: UnsafeContinuation + let rustFuture: UInt64 + let pollFunc: (UInt64, @escaping UniffiRustFutureContinuationCallback, UInt64, UInt64) -> () + + init( + continuation: UnsafeContinuation, + rustFuture: UInt64, + pollFunc: @escaping (UInt64, @escaping UniffiRustFutureContinuationCallback, UInt64, UInt64) -> () + ) { + self.continuation = continuation + self.rustFuture = rustFuture + self.pollFunc = pollFunc + } +} + +// Stores the UniffiPollData instances that correspond to RustFuture callback data +fileprivate let uniffiPollDataHandleMap = UniffiHandleMap() + +// Stores the DispatchQueue instances that correspond to blocking task queue handles +fileprivate var uniffiBlockingTaskQueueHandleMap = UniffiHandleMap() fileprivate func uniffiRustCallAsync( rustFutureFunc: () -> UInt64, - pollFunc: (UInt64, @escaping UniffiRustFutureContinuationCallback, UInt64) -> (), + pollFunc: @escaping (UInt64, @escaping UniffiRustFutureContinuationCallback, UInt64, UInt64) -> (), completeFunc: (UInt64, UnsafeMutablePointer) -> F, freeFunc: (UInt64) -> (), liftFunc: (F) throws -> T, @@ -21,11 +42,18 @@ fileprivate func uniffiRustCallAsync( var pollResult: Int8; repeat { pollResult = await withUnsafeContinuation { + let pollData = UniffiPollData( + continuation: $0, + rustFuture: rustFuture, + pollFunc: pollFunc + ) pollFunc( rustFuture, uniffiFutureContinuationCallback, - uniffiContinuationHandleMap.insert(obj: $0) + uniffiPollDataHandleMap.insert(obj: pollData), + 0 ) + } } while pollResult != UNIFFI_RUST_FUTURE_POLL_READY @@ -37,11 +65,22 @@ fileprivate func uniffiRustCallAsync( // Callback handlers for an async calls. These are invoked by Rust when the future is ready. They // lift the return value or error and resume the suspended function. -fileprivate func uniffiFutureContinuationCallback(handle: UInt64, pollResult: Int8) { - if let continuation = try? uniffiContinuationHandleMap.remove(handle: handle) { - continuation.resume(returning: pollResult) +fileprivate func uniffiFutureContinuationCallback( + pollDataHandle: UInt64, + pollResult: Int8, + blockingTaskQueueHandle: UInt64 +) { + if (blockingTaskQueueHandle == 0) { + // Try to complete the Swift continutation + let pollData = try! uniffiPollDataHandleMap.remove(handle: pollDataHandle) + pollData.continuation.resume(returning: pollResult) } else { - print("uniffiFutureContinuationCallback invalid handle") + // Call the poll function again, but inside the DispatchQuee + let pollData = try! uniffiPollDataHandleMap.get(handle: pollDataHandle) + let queue = try! uniffiBlockingTaskQueueHandleMap.get(handle: blockingTaskQueueHandle) + queue.async { + pollData.pollFunc(pollData.rustFuture, uniffiFutureContinuationCallback, pollDataHandle, blockingTaskQueueHandle) + } } } @@ -114,3 +153,8 @@ public func uniffiForeignFutureHandleCount{{ ci.namespace()|class_name }}() -> I } {%- endif %} + +// For testing +public func uniffiPollDataHandleCount{{ ci.namespace()|class_name }}() -> Int { + return uniffiPollDataHandleMap.count +} diff --git a/uniffi_bindgen/src/bindings/swift/templates/BlockingTaskQueueTemplate.swift b/uniffi_bindgen/src/bindings/swift/templates/BlockingTaskQueueTemplate.swift new file mode 100644 index 0000000000..36d4cb6890 --- /dev/null +++ b/uniffi_bindgen/src/bindings/swift/templates/BlockingTaskQueueTemplate.swift @@ -0,0 +1,45 @@ +fileprivate var UNIFFI_BLOCKING_TASK_QUEUE_VTABLE = UniffiBlockingTaskQueueVTable( + clone: { (handle: UInt64) -> UInt64 in + do { + let dispatchQueue = try uniffiBlockingTaskQueueHandleMap.get(handle: handle) + return uniffiBlockingTaskQueueHandleMap.insert(obj: dispatchQueue) + } catch { + print("UniffiBlockingTaskQueueVTable.clone: invalid task queue handle") + return 0 + } + }, + free: { (handle: UInt64) in + do { + try uniffiBlockingTaskQueueHandleMap.remove(handle: handle) + } catch { + print("UniffiBlockingTaskQueueVTable.free: invalid task queue handle") + } + } +) + +fileprivate struct {{ ffi_converter_name }}: FfiConverterRustBuffer { + typealias SwiftType = DispatchQueue + + public static func write(_ value: DispatchQueue, into buf: inout [UInt8]) { + let handle = uniffiBlockingTaskQueueHandleMap.insert(obj: value) + writeInt(&buf, handle) + // From Apple: "You can safely use the address of a global variable as a persistent unique + // pointer value" (https://developer.apple.com/swift/blog/?id=6) + let vtablePointer = UnsafeMutablePointer(&UNIFFI_BLOCKING_TASK_QUEUE_VTABLE) + // Convert the pointer to a word-sized Int then to a 64-bit int then write it out. + writeInt(&buf, Int64(Int(bitPattern: vtablePointer))) + } + + public static func read(from buf: inout (data: Data, offset: Data.Index)) throws -> DispatchQueue { + let handle: UInt64 = try readInt(&buf) + // Read the VTable pointer and throw it out. The vtable is only used by Rust and always the + // same value. + let _: UInt64 = try readInt(&buf) + return try uniffiBlockingTaskQueueHandleMap.remove(handle: handle) + } +} + +// For testing +public func uniffiBlockingTaskQueueHandleCount{{ ci.namespace()|class_name }}() -> Int { + uniffiBlockingTaskQueueHandleMap.count +} diff --git a/uniffi_bindgen/src/bindings/swift/templates/HandleMap.swift b/uniffi_bindgen/src/bindings/swift/templates/HandleMap.swift index 6de9f085d6..033a76ce8a 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/HandleMap.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/HandleMap.swift @@ -1,6 +1,7 @@ fileprivate class UniffiHandleMap { private var map: [UInt64: T] = [:] private let lock = NSLock() + // Start at 1, since `0` represents a NULL handle. private var currentHandle: UInt64 = 1 func insert(obj: T) -> UInt64 { diff --git a/uniffi_bindgen/src/bindings/swift/templates/Types.swift b/uniffi_bindgen/src/bindings/swift/templates/Types.swift index 5e26758f3c..ba4d2059c8 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/Types.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/Types.swift @@ -64,6 +64,9 @@ {%- when Type::CallbackInterface { name, module_path } %} {%- include "CallbackInterfaceTemplate.swift" %} +{%- when Type::BlockingTaskQueue %} +{%- include "BlockingTaskQueueTemplate.swift" %} + {%- when Type::Custom { name, module_path, builtin } %} {%- include "CustomType.swift" %} diff --git a/uniffi_bindgen/src/interface/ffi.rs b/uniffi_bindgen/src/interface/ffi.rs index b27cb78477..eb7a269844 100644 --- a/uniffi_bindgen/src/interface/ffi.rs +++ b/uniffi_bindgen/src/interface/ffi.rs @@ -130,7 +130,8 @@ impl From<&Type> for FfiType { | Type::Sequence { .. } | Type::Map { .. } | Type::Timestamp - | Type::Duration => FfiType::RustBuffer(None), + | Type::Duration + | Type::BlockingTaskQueue => FfiType::RustBuffer(None), Type::External { name, kind: ExternalKind::Interface, diff --git a/uniffi_bindgen/src/interface/mod.rs b/uniffi_bindgen/src/interface/mod.rs index 90a941637a..e81aaada21 100644 --- a/uniffi_bindgen/src/interface/mod.rs +++ b/uniffi_bindgen/src/interface/mod.rs @@ -514,6 +514,10 @@ impl ComponentInterface { name: "callback_data".to_owned(), type_: FfiType::Handle, }, + FfiArgument { + name: "blocking_task_queue_handle".to_owned(), + type_: FfiType::UInt64, + }, ], return_type: None, has_rust_call_status_arg: false, @@ -628,6 +632,7 @@ impl ComponentInterface { arguments: vec![ FfiArgument::new("data", FfiType::UInt64), FfiArgument::new("poll_result", FfiType::Int8), + FfiArgument::new("blocking_task_queue_handle", FfiType::UInt64), ], return_type: None, has_rust_call_status_arg: false, @@ -647,6 +652,20 @@ impl ComponentInterface { has_rust_call_status_arg: false, } .into(), + FfiCallbackFunction { + name: "BlockingTaskQueueClone".to_owned(), + arguments: vec![FfiArgument::new("handle", FfiType::UInt64)], + return_type: Some(FfiType::UInt64), + has_rust_call_status_arg: false, + } + .into(), + FfiCallbackFunction { + name: "BlockingTaskQueueFree".to_owned(), + arguments: vec![FfiArgument::new("handle", FfiType::UInt64)], + return_type: None, + has_rust_call_status_arg: false, + } + .into(), FfiStruct { name: "ForeignFuture".to_owned(), fields: vec![ @@ -655,6 +674,20 @@ impl ComponentInterface { ], } .into(), + FfiStruct { + name: "BlockingTaskQueueVTable".to_owned(), + fields: vec![ + FfiField::new( + "clone", + FfiType::Callback("BlockingTaskQueueClone".to_owned()), + ), + FfiField::new( + "free", + FfiType::Callback("BlockingTaskQueueFree".to_owned()), + ), + ], + } + .into(), ] .into_iter() .chain( @@ -865,8 +898,11 @@ impl ComponentInterface { self.types.add_known_types(defn.iter_types())?; defn.throws_name() .map(|n| self.errors.insert(n.to_string())); + if defn.is_async() { + self.types + .add_known_type(&uniffi_meta::Type::BlockingTaskQueue)?; + } self.functions.push(defn); - Ok(()) } diff --git a/uniffi_bindgen/src/interface/universe.rs b/uniffi_bindgen/src/interface/universe.rs index 70bc61f8a9..2faef72fd6 100644 --- a/uniffi_bindgen/src/interface/universe.rs +++ b/uniffi_bindgen/src/interface/universe.rs @@ -84,6 +84,7 @@ impl TypeUniverse { Type::Bytes => self.add_type_definition("bytes", type_)?, Type::Timestamp => self.add_type_definition("timestamp", type_)?, Type::Duration => self.add_type_definition("duration", type_)?, + Type::BlockingTaskQueue => self.add_type_definition("BlockingTaskQueue", type_)?, Type::Object { name, .. } | Type::Record { name, .. } | Type::Enum { name, .. } diff --git a/uniffi_bindgen/src/scaffolding/mod.rs b/uniffi_bindgen/src/scaffolding/mod.rs index 7fd81831aa..231e0495d3 100644 --- a/uniffi_bindgen/src/scaffolding/mod.rs +++ b/uniffi_bindgen/src/scaffolding/mod.rs @@ -45,6 +45,7 @@ mod filters { format!("std::sync::Arc<{}>", imp.rust_name_for(name)) } Type::CallbackInterface { name, .. } => format!("Box"), + Type::BlockingTaskQueue => "::uniffi::BlockingTaskQueue".to_owned(), Type::Optional { inner_type } => { format!("std::option::Option<{}>", type_rs(inner_type)?) } diff --git a/uniffi_core/src/ffi/rustfuture/blocking_task_queue.rs b/uniffi_core/src/ffi/rustfuture/blocking_task_queue.rs new file mode 100644 index 0000000000..50abe5a66f --- /dev/null +++ b/uniffi_core/src/ffi/rustfuture/blocking_task_queue.rs @@ -0,0 +1,69 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +//! Defines the BlockingTaskQueue struct +//! +//! This module is responsible for the general handling of BlockingTaskQueue instances (cloning, droping, etc). +//! See `scheduler.rs` and the foreign bindings code for how the async functionality is implemented. + +use super::scheduler::schedule_in_blocking_task_queue; +use std::num::NonZeroU64; + +/// Foreign-managed blocking task queue that we can use to schedule futures +/// +/// On the foreign side this is a Kotlin `CoroutineContext`, Python `Executor` or Swift +/// `DispatchQueue`. UniFFI converts those objects into this struct for the Rust code to use. +/// +/// Rust async code can call [BlockingTaskQueue::execute] to run a closure in that +/// blocking task queue. Use this for functions with blocking operations that should not be executed +/// in a normal async context. Some examples are non-async file/network operations, long-running +/// CPU-bound tasks, blocking database operations, etc. +#[repr(C)] +pub struct BlockingTaskQueue { + /// Opaque handle for the task queue + pub handle: NonZeroU64, + /// Method VTable + /// + /// This is simply a C struct where each field is a function pointer that inputs a + /// BlockingTaskQueue handle + pub vtable: &'static BlockingTaskQueueVTable, +} + +#[repr(C)] +#[derive(Debug)] +pub struct BlockingTaskQueueVTable { + clone: extern "C" fn(u64) -> u64, + drop: extern "C" fn(u64), +} + +// Note: see `scheduler.rs` for details on how BlockingTaskQueue is used. +impl BlockingTaskQueue { + /// Run a closure in a blocking task queue + pub async fn execute(&self, f: F) -> R + where + F: FnOnce() -> R, + { + schedule_in_blocking_task_queue(self.handle).await; + f() + } +} + +impl Clone for BlockingTaskQueue { + fn clone(&self) -> Self { + let raw_handle = (self.vtable.clone)(self.handle.into()); + let handle = raw_handle + .try_into() + .expect("BlockingTaskQueue.clone() returned 0"); + Self { + handle, + vtable: self.vtable, + } + } +} + +impl Drop for BlockingTaskQueue { + fn drop(&mut self) { + (self.vtable.drop)(self.handle.into()) + } +} diff --git a/uniffi_core/src/ffi/rustfuture/future.rs b/uniffi_core/src/ffi/rustfuture/future.rs index 93c34e7543..26781206d2 100644 --- a/uniffi_core/src/ffi/rustfuture/future.rs +++ b/uniffi_core/src/ffi/rustfuture/future.rs @@ -21,6 +21,10 @@ //! 2b. If the async function is cancelled, then call [rust_future_cancel]. This causes the //! continuation function to be called with [RustFuturePoll::Ready] and the [RustFuture] to //! enter a cancelled state. +//! 2c. If the Rust code wants schedule work to be run in a `BlockingTaskQueue`, then the +//! continuation is called with [RustFuturePoll::MaybeReady] and the blocking task queue handle. +//! The foreign code is responsible for ensuring the next [rust_future_poll] call happens in +//! that blocking task queue and the handle is passed to [rust_future_poll]. //! 3. Call [rust_future_complete] to get the result of the future. //! 4. Call [rust_future_free] to free the future, ideally in a finally block. This: //! - Releases any resources held by the future @@ -78,6 +82,7 @@ use std::{ future::Future, marker::PhantomData, + num::NonZeroU64, ops::Deref, panic, pin::Pin, @@ -85,8 +90,8 @@ use std::{ task::{Context, Poll, Wake}, }; -use super::{RustFutureContinuationCallback, RustFuturePoll, Scheduler}; -use crate::{rust_call_with_out_status, FfiDefault, LowerReturn, RustCallStatus}; +use super::{scheduler, RustFutureContinuationCallback, Scheduler}; +use crate::{rust_call_with_out_status, FfiDefault, LowerReturn, RustCallStatus, RustFuturePoll}; /// Wraps the actual future we're polling struct WrappedFuture @@ -223,17 +228,26 @@ where }) } - pub(super) fn poll(self: Arc, callback: RustFutureContinuationCallback, data: u64) { + pub(super) fn poll( + self: Arc, + callback: RustFutureContinuationCallback, + data: u64, + blocking_task_queue_handle: Option, + ) { + scheduler::on_poll_start(blocking_task_queue_handle); + // Clear out the waked flag, since we're about to poll right now. + self.scheduler.lock().unwrap().clear_wake_flag(); let ready = self.is_cancelled() || { let mut locked = self.future.lock().unwrap(); let waker: std::task::Waker = Arc::clone(&self).into(); locked.poll(&mut Context::from_waker(&waker)) }; if ready { - callback(data, RustFuturePoll::Ready) + callback(data, RustFuturePoll::Ready, 0) } else { self.scheduler.lock().unwrap().store(callback, data); } + scheduler::on_poll_end(); } pub(super) fn is_cancelled(&self) -> bool { @@ -289,7 +303,12 @@ where /// only create those functions for each of the 13 possible FFI return types. #[doc(hidden)] pub trait RustFutureFfi: Send + Sync { - fn ffi_poll(self: Arc, callback: RustFutureContinuationCallback, data: u64); + fn ffi_poll( + self: Arc, + callback: RustFutureContinuationCallback, + data: u64, + blocking_task_queue_handle: Option, + ); fn ffi_cancel(&self); fn ffi_complete(&self, call_status: &mut RustCallStatus) -> ReturnType; fn ffi_free(self: Arc); @@ -302,8 +321,13 @@ where T: LowerReturn + Send + 'static, UT: Send + 'static, { - fn ffi_poll(self: Arc, callback: RustFutureContinuationCallback, data: u64) { - self.poll(callback, data) + fn ffi_poll( + self: Arc, + callback: RustFutureContinuationCallback, + data: u64, + blocking_task_queue_handle: Option, + ) { + self.poll(callback, data, blocking_task_queue_handle) } fn ffi_cancel(&self) { diff --git a/uniffi_core/src/ffi/rustfuture/mod.rs b/uniffi_core/src/ffi/rustfuture/mod.rs index 39529f2db1..3d236a04b6 100644 --- a/uniffi_core/src/ffi/rustfuture/mod.rs +++ b/uniffi_core/src/ffi/rustfuture/mod.rs @@ -4,8 +4,11 @@ use std::{future::Future, sync::Arc}; +mod blocking_task_queue; mod future; mod scheduler; + +pub use blocking_task_queue::*; use future::*; use scheduler::*; @@ -28,7 +31,16 @@ pub enum RustFuturePoll { /// /// The Rust side of things calls this when the foreign side should call [rust_future_poll] again /// to continue progress on the future. -pub type RustFutureContinuationCallback = extern "C" fn(callback_data: u64, RustFuturePoll); +/// +/// WARNING: the call to [rust_future_poll] must be scheduled to happen soon after the callback is +/// called, but not inside the callback itself. If [rust_future_poll] is called inside the +/// callback, some futures will deadlock and our scheduler code might as well. +/// +/// * `callback_data` is the handle that the foreign code passed to `poll()` +/// * `poll_result` is the result of the poll +/// * If `blocking_task_task_queue` is non-zero, it's the BlockingTaskQueue handle that the next `poll()` should run on +pub type RustFutureContinuationCallback = + extern "C" fn(callback_data: u64, poll_result: RustFuturePoll, blocking_task_queue_handle: u64); // === Public FFI API === @@ -62,6 +74,9 @@ where /// a [RustFuturePoll] value. For each [rust_future_poll] call the continuation will be called /// exactly once. /// +/// If this is running in a BlockingTaskQueue, then `blocking_task_queue_handle` must be the handle +/// for it. If not, `blocking_task_queue_handle` must be `0`. +/// /// # Safety /// /// The [Handle] must not previously have been passed to [rust_future_free] @@ -69,10 +84,15 @@ pub unsafe fn rust_future_poll( handle: Handle, callback: RustFutureContinuationCallback, data: u64, + blocking_task_queue_handle: u64, ) where dyn RustFutureFfi: HandleAlloc, { - as HandleAlloc>::get_arc(handle).ffi_poll(callback, data) + as HandleAlloc>::get_arc(handle).ffi_poll( + callback, + data, + blocking_task_queue_handle.try_into().ok(), + ) } /// Cancel a Rust future diff --git a/uniffi_core/src/ffi/rustfuture/scheduler.rs b/uniffi_core/src/ffi/rustfuture/scheduler.rs index 629ee0c109..26526701d3 100644 --- a/uniffi_core/src/ffi/rustfuture/scheduler.rs +++ b/uniffi_core/src/ffi/rustfuture/scheduler.rs @@ -2,10 +2,89 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ -use std::mem; +use std::{cell::RefCell, future::poll_fn, mem, num::NonZeroU64, task::Poll, thread_local}; use super::{RustFutureContinuationCallback, RustFuturePoll}; +/// Context of the current `RustFuture::poll` call +struct RustFutureContext { + /// Blocking task queue that the future is being polled on + current_blocking_task_queue_handle: Option, + /// Blocking task queue that we've been asked to schedule the next poll on + scheduled_blocking_task_queue_handle: Option, +} + +thread_local! { + static CONTEXT: RefCell = RefCell::new(RustFutureContext { + current_blocking_task_queue_handle: None, + scheduled_blocking_task_queue_handle: None, + }); +} + +fn with_context R, R>(operation: F) -> R { + CONTEXT.with(|context| operation(&mut context.borrow_mut())) +} + +pub fn on_poll_start(current_blocking_task_queue_handle: Option) { + with_context(|context| { + *context = RustFutureContext { + current_blocking_task_queue_handle, + scheduled_blocking_task_queue_handle: None, + } + }); +} + +pub fn on_poll_end() { + with_context(|context| { + *context = RustFutureContext { + current_blocking_task_queue_handle: None, + scheduled_blocking_task_queue_handle: None, + } + }); +} + +/// Schedule work in a blocking task queue +/// +/// The returned future will attempt to arrange for [RustFuture::poll] to be called in the +/// blocking task queue. Once [RustFuture::poll] is running in the blocking task queue, then the future +/// will be ready. +/// +/// There's one tricky issue here: how can we ensure that when the top-level task is run in the +/// blocking task queue, this future will be polled? What happens this future is a child of `join!`, +/// `FuturesUnordered` or some other Future that handles its own polling? +/// +/// We start with an assumption: if we notify the waker then this future will be polled when the +/// top-level task is polled next. If a future does not honor this then we consider it a broken +/// future. This seems fair, since that future would almost certainly break a lot of other future +/// code. +/// +/// Based on that, we can have a simple system. When we're polled: +/// * If we're running in the blocking task queue, then we return `Poll::Ready`. +/// * If not, we return `Poll::Pending` and notify the waker so that the future polls again on +/// the next top-level poll. +/// +/// Note that this can be inefficient if the code awaits multiple blocking task queues at once. We +/// can only run the next poll on one of them, but all futures will be woken up. This seems okay +/// for our intended use cases, it would be pretty odd for a library to use multiple blocking task +/// queues. The alternative would be to store the set of all pending blocking task queues, which +/// seems like complete overkill for our purposes. +pub(super) async fn schedule_in_blocking_task_queue(handle: NonZeroU64) { + poll_fn(|future_context| { + with_context(|poll_context| { + if poll_context.current_blocking_task_queue_handle == Some(handle) { + Poll::Ready(()) + } else { + poll_context + .scheduled_blocking_task_queue_handle + .get_or_insert(handle); + future_context.waker().wake_by_ref(); + Poll::Pending + } + }) + }) + .await +} + /// Schedules a [crate::RustFuture] by managing the continuation data /// /// This struct manages the continuation callback and data that comes from the foreign side. It @@ -41,21 +120,34 @@ impl Scheduler { /// Store new continuation data if we are in the `Empty` state. If we are in the `Waked` or /// `Cancelled` state, call the continuation immediately with the data. pub(super) fn store(&mut self, callback: RustFutureContinuationCallback, data: u64) { + if let Some(blocking_task_queue_handle) = + with_context(|context| context.scheduled_blocking_task_queue_handle) + { + // We were asked to schedule the future in a blocking task queue, call the callback + // rather than storing it + callback( + data, + RustFuturePoll::MaybeReady, + blocking_task_queue_handle.into(), + ); + return; + } + match self { Self::Empty => *self = Self::Set(callback, data), Self::Set(old_callback, old_data) => { log::error!( "store: observed `Self::Set` state. Is poll() being called from multiple threads at once?" ); - old_callback(*old_data, RustFuturePoll::Ready); + old_callback(*old_data, RustFuturePoll::Ready, 0); *self = Self::Set(callback, data); } Self::Waked => { *self = Self::Empty; - callback(data, RustFuturePoll::MaybeReady); + callback(data, RustFuturePoll::MaybeReady, 0); } Self::Cancelled => { - callback(data, RustFuturePoll::Ready); + callback(data, RustFuturePoll::Ready, 0); } } } @@ -67,7 +159,7 @@ impl Scheduler { let old_data = *old_data; let callback = *callback; *self = Self::Empty; - callback(old_data, RustFuturePoll::MaybeReady); + callback(old_data, RustFuturePoll::MaybeReady, 0); } // If we were in the `Empty` state, then transition to `Waked`. The next time `store` // is called, we will immediately call the continuation. @@ -79,7 +171,13 @@ impl Scheduler { pub(super) fn cancel(&mut self) { if let Self::Set(callback, old_data) = mem::replace(self, Self::Cancelled) { - callback(old_data, RustFuturePoll::Ready); + callback(old_data, RustFuturePoll::Ready, 0); + } + } + + pub(super) fn clear_wake_flag(&mut self) { + if let Self::Waked = self { + *self = Self::Empty } } diff --git a/uniffi_core/src/ffi/rustfuture/tests.rs b/uniffi_core/src/ffi/rustfuture/tests.rs index 886ee27c71..d4d111e44a 100644 --- a/uniffi_core/src/ffi/rustfuture/tests.rs +++ b/uniffi_core/src/ffi/rustfuture/tests.rs @@ -65,16 +65,34 @@ fn channel() -> (Sender, Arc>) { } /// Poll a Rust future and get an OnceCell that's set when the continuation is called -fn poll(rust_future: &Arc>) -> Arc> { +fn poll(rust_future: &Arc>) -> Arc> { let cell = Arc::new(OnceCell::new()); let handle = Arc::into_raw(cell.clone()) as u64; - rust_future.clone().ffi_poll(poll_continuation, handle); + rust_future + .clone() + .ffi_poll(poll_continuation, handle, None); cell } -extern "C" fn poll_continuation(data: u64, code: RustFuturePoll) { - let cell = unsafe { Arc::from_raw(data as *const OnceCell) }; - cell.set(code).expect("Error setting OnceCell"); +/// Like poll, but simulate `poll()` being called from a blocking task queue +fn poll_from_blocking_task_queue( + rust_future: &Arc>, + blocking_task_queue_handle: u64, +) -> Arc> { + let cell = Arc::new(OnceCell::new()); + let handle = Arc::into_raw(cell.clone()) as u64; + rust_future.clone().ffi_poll( + poll_continuation, + handle, + Some(blocking_task_queue_handle.try_into().unwrap()), + ); + cell +} + +extern "C" fn poll_continuation(data: u64, code: RustFuturePoll, blocking_task_queue_handle: u64) { + let cell = unsafe { Arc::from_raw(data as *const OnceCell<(RustFuturePoll, u64)>) }; + cell.set((code, blocking_task_queue_handle)) + .expect("Error setting OnceCell"); } fn complete(rust_future: Arc>) -> (RustBuffer, RustCallStatus) { @@ -83,25 +101,47 @@ fn complete(rust_future: Arc>) -> (RustBuffer, Rus (return_value, out_status_code) } +fn check_continuation_not_called(once_cell: &OnceCell<(RustFuturePoll, u64)>) { + assert_eq!(once_cell.get(), None); +} + +fn check_continuation_called( + once_cell: &OnceCell<(RustFuturePoll, u64)>, + poll_result: RustFuturePoll, +) { + assert_eq!(once_cell.get(), Some(&(poll_result, 0))); +} + +fn check_continuation_called_with_blocking_task_queue_handle( + once_cell: &OnceCell<(RustFuturePoll, u64)>, + poll_result: RustFuturePoll, + blocking_task_queue_handle: u64, +) { + assert_eq!( + once_cell.get(), + Some(&(poll_result, blocking_task_queue_handle)) + ) +} + #[test] fn test_success() { let (sender, rust_future) = channel(); // Test polling the rust future before it's ready let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), None); + check_continuation_not_called(&continuation_result); sender.wake(); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::MaybeReady)); + check_continuation_called(&continuation_result, RustFuturePoll::MaybeReady); // Test polling the rust future when it's ready let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), None); + check_continuation_not_called(&continuation_result); sender.send(Ok("All done".into())); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::MaybeReady)); + check_continuation_called(&continuation_result, RustFuturePoll::MaybeReady); // Future polls should immediately return ready let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); // Complete the future let (return_buf, call_status) = complete(rust_future); @@ -117,12 +157,12 @@ fn test_error() { let (sender, rust_future) = channel(); let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), None); + check_continuation_not_called(&continuation_result); sender.send(Err("Something went wrong".into())); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::MaybeReady)); + check_continuation_called(&continuation_result, RustFuturePoll::MaybeReady); let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); let (_, call_status) = complete(rust_future); assert_eq!(call_status.code, RustCallStatusCode::Error); @@ -144,14 +184,14 @@ fn test_cancel() { let (_sender, rust_future) = channel(); let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), None); + check_continuation_not_called(&continuation_result); rust_future.ffi_cancel(); // Cancellation should immediately invoke the callback with RustFuturePoll::Ready - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); // Future polls should immediately invoke the callback with RustFuturePoll::Ready let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); let (_, call_status) = complete(rust_future); assert_eq!(call_status.code, RustCallStatusCode::Cancelled); @@ -187,7 +227,7 @@ fn test_complete_with_stored_continuation() { let continuation_result = poll(&rust_future); rust_future.ffi_free(); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); } // Test what happens if we see a `wake()` call while we're polling the future. This can @@ -210,10 +250,47 @@ fn test_wake_during_poll() { let rust_future: Arc> = RustFuture::new(future, crate::UniFfiTag); let continuation_result = poll(&rust_future); // The continuation function should called immediately - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::MaybeReady)); + check_continuation_called(&continuation_result, RustFuturePoll::MaybeReady); // A second poll should finish the future let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); + let (return_buf, call_status) = complete(rust_future); + assert_eq!(call_status.code, RustCallStatusCode::Success); + assert_eq!( + >::try_lift(return_buf).unwrap(), + "All done" + ); +} + +#[test] +fn test_blocking_task() { + let blocking_task_queue_handle = 1001; + let future = async move { + schedule_in_blocking_task_queue(blocking_task_queue_handle.try_into().unwrap()).await; + "All done".to_owned() + }; + let rust_future: Arc> = RustFuture::new(future, crate::UniFfiTag); + // On the first poll, the future should not be ready and it should ask to be scheduled in the + // blocking task queue + let continuation_result = poll(&rust_future); + check_continuation_called_with_blocking_task_queue_handle( + &continuation_result, + RustFuturePoll::MaybeReady, + blocking_task_queue_handle, + ); + // If we poll it again not in a blocking task queue, then we get the same result + let continuation_result = poll(&rust_future); + check_continuation_called_with_blocking_task_queue_handle( + &continuation_result, + RustFuturePoll::MaybeReady, + blocking_task_queue_handle, + ); + // When we poll it in the blocking task queue, then the future is ready + let continuation_result = + poll_from_blocking_task_queue(&rust_future, blocking_task_queue_handle); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); + + // Complete the future let (return_buf, call_status) = complete(rust_future); assert_eq!(call_status.code, RustCallStatusCode::Success); assert_eq!( diff --git a/uniffi_core/src/ffi_converter_impls.rs b/uniffi_core/src/ffi_converter_impls.rs index aec093154a..3866f84eda 100644 --- a/uniffi_core/src/ffi_converter_impls.rs +++ b/uniffi_core/src/ffi_converter_impls.rs @@ -23,8 +23,9 @@ /// "UT" means an arbitrary `UniFfiTag` type. use crate::{ check_remaining, derive_ffi_traits, ffi_converter_rust_buffer_lift_and_lower, metadata, - ConvertError, FfiConverter, Lift, LiftRef, LiftReturn, Lower, LowerReturn, MetadataBuffer, - Result, RustBuffer, UnexpectedUniFFICallbackError, + BlockingTaskQueue, BlockingTaskQueueVTable, ConvertError, FfiConverter, Lift, LiftRef, + LiftReturn, Lower, LowerReturn, MetadataBuffer, Result, RustBuffer, + UnexpectedUniFFICallbackError, }; use anyhow::bail; use bytes::buf::{Buf, BufMut}; @@ -244,6 +245,35 @@ unsafe impl FfiConverter for Duration { const TYPE_ID_META: MetadataBuffer = MetadataBuffer::from_code(metadata::codes::TYPE_DURATION); } +/// Support for passing [BlockingTaskQueue] across the FFI +/// +/// Both fields of [BlockingTaskQueue] are serialized into a RustBuffer. The vtable pointer is +/// casted to a u64. +unsafe impl FfiConverter for BlockingTaskQueue { + ffi_converter_rust_buffer_lift_and_lower!(UT); + + fn write(obj: BlockingTaskQueue, buf: &mut Vec) { + let obj = obj.clone(); + buf.put_u64(obj.handle.into()); + buf.put_u64(obj.vtable as *const BlockingTaskQueueVTable as u64); + } + + fn try_read(buf: &mut &[u8]) -> Result { + check_remaining(buf, 16)?; + let handle = buf + .get_u64() + .try_into() + .expect("handle = 0 when reading BlockingTaskQueue"); + let vtable = unsafe { + &*(buf.get_u64() as *const BlockingTaskQueueVTable) as &'static BlockingTaskQueueVTable + }; + Ok(Self { handle, vtable }) + } + + const TYPE_ID_META: MetadataBuffer = + MetadataBuffer::from_code(metadata::codes::TYPE_BLOCKING_TASK_QUEUE); +} + // Support for passing optional values via the FFI. // // Optional values are currently always passed by serializing to a buffer. @@ -419,6 +449,7 @@ derive_ffi_traits!(blanket bool); derive_ffi_traits!(blanket String); derive_ffi_traits!(blanket Duration); derive_ffi_traits!(blanket SystemTime); +derive_ffi_traits!(blanket BlockingTaskQueue); // For composite types, derive LowerReturn, LiftReturn, etc, from Lift/Lower. // diff --git a/uniffi_core/src/metadata.rs b/uniffi_core/src/metadata.rs index b18036d533..0d2852a4ad 100644 --- a/uniffi_core/src/metadata.rs +++ b/uniffi_core/src/metadata.rs @@ -69,6 +69,7 @@ pub mod codes { pub const TYPE_RESULT: u8 = 23; pub const TYPE_TRAIT_INTERFACE: u8 = 24; pub const TYPE_CALLBACK_TRAIT_INTERFACE: u8 = 25; + pub const TYPE_BLOCKING_TASK_QUEUE: u8 = 26; pub const TYPE_UNIT: u8 = 255; // Literal codes for LiteralMetadata - note that we don't support diff --git a/uniffi_macros/src/setup_scaffolding.rs b/uniffi_macros/src/setup_scaffolding.rs index c82e9389bb..4deac6792e 100644 --- a/uniffi_macros/src/setup_scaffolding.rs +++ b/uniffi_macros/src/setup_scaffolding.rs @@ -166,8 +166,13 @@ fn rust_future_scaffolding_fns(module_path: &str) -> TokenStream { #[allow(clippy::missing_safety_doc, missing_docs)] #[doc(hidden)] #[no_mangle] - pub unsafe extern "C" fn #ffi_rust_future_poll(handle: ::uniffi::Handle, callback: ::uniffi::RustFutureContinuationCallback, data: u64) { - ::uniffi::ffi::rust_future_poll::<#return_type, crate::UniFfiTag>(handle, callback, data); + pub unsafe extern "C" fn #ffi_rust_future_poll( + handle: ::uniffi::Handle, + callback: ::uniffi::RustFutureContinuationCallback, + data: u64, + blocking_task_queue_handle: u64, + ) { + ::uniffi::ffi::rust_future_poll::<#return_type, crate::UniFfiTag>(handle, callback, data, blocking_task_queue_handle); } #[allow(clippy::missing_safety_doc, missing_docs)] diff --git a/uniffi_meta/src/metadata.rs b/uniffi_meta/src/metadata.rs index 66c2c63952..b553e0060e 100644 --- a/uniffi_meta/src/metadata.rs +++ b/uniffi_meta/src/metadata.rs @@ -52,6 +52,7 @@ pub mod codes { pub const TYPE_RESULT: u8 = 23; pub const TYPE_TRAIT_INTERFACE: u8 = 24; pub const TYPE_CALLBACK_TRAIT_INTERFACE: u8 = 25; + pub const TYPE_BLOCKING_TASK_QUEUE: u8 = 26; pub const TYPE_UNIT: u8 = 255; // Literal codes diff --git a/uniffi_meta/src/reader.rs b/uniffi_meta/src/reader.rs index f51bb786f8..f219871f8e 100644 --- a/uniffi_meta/src/reader.rs +++ b/uniffi_meta/src/reader.rs @@ -145,6 +145,7 @@ impl<'a> MetadataReader<'a> { codes::TYPE_STRING => Type::String, codes::TYPE_DURATION => Type::Duration, codes::TYPE_SYSTEM_TIME => Type::Timestamp, + codes::TYPE_BLOCKING_TASK_QUEUE => Type::BlockingTaskQueue, codes::TYPE_RECORD => Type::Record { module_path: self.read_string()?, name: self.read_string()?, diff --git a/uniffi_meta/src/types.rs b/uniffi_meta/src/types.rs index 51bf156b50..5003dd9f77 100644 --- a/uniffi_meta/src/types.rs +++ b/uniffi_meta/src/types.rs @@ -88,6 +88,7 @@ pub enum Type { // How the object is implemented. imp: ObjectImpl, }, + BlockingTaskQueue, // Types defined in the component API, each of which has a string name. Record { module_path: String, diff --git a/uniffi_udl/src/resolver.rs b/uniffi_udl/src/resolver.rs index ea98cd7a99..1409c3a6ff 100644 --- a/uniffi_udl/src/resolver.rs +++ b/uniffi_udl/src/resolver.rs @@ -209,6 +209,7 @@ pub(crate) fn resolve_builtin_type(name: &str) -> Option { "f64" => Some(Type::Float64), "timestamp" => Some(Type::Timestamp), "duration" => Some(Type::Duration), + "BlockingTaskQueue" => Some(Type::BlockingTaskQueue), _ => None, } }