Skip to content

Commit 7ece07d

Browse files
feat: implement missing async arg result handling (bytecodealliance#1342)
This commit implements some (but not all) of the async argument and parameter handling that is emitted for function calls.
1 parent 454d688 commit 7ece07d

File tree

1 file changed

+170
-68
lines changed

1 file changed

+170
-68
lines changed

crates/core/src/abi.rs

Lines changed: 170 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
use std::fmt;
2+
use std::iter;
3+
24
pub use wit_parser::abi::{AbiVariant, FlatTypes, WasmSignature, WasmType};
35
use wit_parser::{
46
align_to_arch, Alignment, ArchitectureSize, ElementInfo, Enum, Flags, FlagsRepr, Function,
@@ -920,6 +922,7 @@ struct Generator<'a, B: Bindgen> {
920922
}
921923

922924
const MAX_FLAT_PARAMS: usize = 16;
925+
const MAX_FLAT_ASYNC_PARAMS: usize = 4;
923926

924927
impl<'a, B: Bindgen> Generator<'a, B> {
925928
fn new(resolve: &'a Resolve, bindgen: &'a mut B) -> Generator<'a, B> {
@@ -1075,57 +1078,93 @@ impl<'a, B: Bindgen> Generator<'a, B> {
10751078
amt: usize::from(func.result.is_some()),
10761079
});
10771080
}
1081+
10781082
LiftLower::LiftArgsLowerResults => {
1079-
if let (AbiVariant::GuestImport, true) = (variant, async_) {
1080-
todo!("implement host-side support for async lift/lower");
1081-
}
1083+
let max_flat_params = match (variant, async_) {
1084+
(AbiVariant::GuestImport | AbiVariant::GuestImportAsync, _is_async @ true) => {
1085+
MAX_FLAT_ASYNC_PARAMS
1086+
}
1087+
_ => MAX_FLAT_PARAMS,
1088+
};
10821089

1090+
// Read parameters from memory
10831091
let read_from_memory = |self_: &mut Self| {
10841092
let mut offset = ArchitectureSize::default();
1085-
let ptr = self_.stack.pop().unwrap();
1093+
let ptr = self_
1094+
.stack
1095+
.pop()
1096+
.expect("empty stack during read param from memory");
10861097
for (_, ty) in func.params.iter() {
10871098
offset = align_to_arch(offset, self_.bindgen.sizes().align(ty));
10881099
self_.read_from_memory(ty, ptr.clone(), offset);
10891100
offset += self_.bindgen.sizes().size(ty);
10901101
}
10911102
};
10921103

1093-
if !sig.indirect_params {
1094-
// If parameters are not passed indirectly then we lift each
1104+
// Resolve parameters
1105+
if sig.indirect_params {
1106+
// If parameters were passed indirectly, arguments must be
1107+
// read in succession from memory, with the pointer to the arguments
1108+
// being the first argument to the function.
1109+
self.emit(&Instruction::GetArg { nth: 0 });
1110+
read_from_memory(self);
1111+
} else {
1112+
// ... otherwise, if parameters were passed directly then we lift each
10951113
// argument in succession from the component wasm types that
10961114
// make-up the type.
10971115
let mut offset = 0;
10981116
for (_, ty) in func.params.iter() {
1099-
let types = flat_types(self.resolve, ty).unwrap();
1117+
let types = flat_types(self.resolve, ty, Some(max_flat_params))
1118+
.expect("direct parameter load failed to produce types during generation of fn call");
11001119
for _ in 0..types.len() {
11011120
self.emit(&Instruction::GetArg { nth: offset });
11021121
offset += 1;
11031122
}
11041123
self.lift(ty);
11051124
}
1106-
} else {
1107-
// ... otherwise argument is read in succession from memory
1108-
// where the pointer to the arguments is the first argument
1109-
// to the function.
1110-
self.emit(&Instruction::GetArg { nth: 0 });
1111-
read_from_memory(self);
11121125
}
11131126

11141127
// ... and that allows us to call the interface types function
11151128
self.emit(&Instruction::CallInterface { func, async_ });
11161129

1130+
// The return value of an async function is *not* the result of the function
1131+
// itself or a pointer but rather a status code.
1132+
//
11171133
// Asynchronous functions will call `task.return` after the
11181134
// interface function completes, so lowering is conditional
11191135
// based on slightly different logic for the `task.return`
11201136
// intrinsic.
1121-
let (lower_to_memory, async_flat_results) = if async_ {
1122-
let results = match &func.result {
1123-
Some(ty) => flat_types(self.resolve, ty),
1124-
None => Some(Vec::new()),
1125-
};
1126-
(results.is_none(), Some(results))
1127-
} else {
1128-
(sig.retptr, None)
1137+
let (lower_to_memory, async_flat_results) = match (variant, async_, &func.result) {
1138+
// Async guest imports return a i32 status code
1139+
(
1140+
AbiVariant::GuestImport | AbiVariant::GuestImportAsync,
1141+
_is_async @ true,
1142+
None,
1143+
) => {
1144+
unreachable!("async guest imports always return a result")
1145+
}
1146+
// Async guest imports return a i32 status code
1147+
(
1148+
AbiVariant::GuestImport | AbiVariant::GuestImportAsync,
1149+
_is_async @ true,
1150+
Some(ty),
1151+
) => {
1152+
// For async guest imports, we know whether we must lower results
1153+
// if there are no params (i.e. the usual out pointer wasn't even required)
1154+
// and we always know the return value will be a i32 status code
1155+
assert!(matches!(ty, Type::U32 | Type::S32));
1156+
(sig.params.is_empty(), Some(Some(vec![WasmType::I32])))
1157+
}
1158+
// All other async cases
1159+
(_, _is_async @ true, func_result) => {
1160+
let results = match &func_result {
1161+
Some(ty) => flat_types(self.resolve, ty, Some(max_flat_params)),
1162+
None => Some(Vec::new()),
1163+
};
1164+
(results.is_none(), Some(results))
1165+
}
1166+
// All other non-async cases
1167+
(_, _is_async @ false, _) => (sig.retptr, None),
11291168
};
11301169

11311170
// This was dynamically allocated by the caller (or async start
@@ -1147,33 +1186,59 @@ impl<'a, B: Bindgen> Generator<'a, B> {
11471186

11481187
self.realloc = Some(realloc);
11491188

1150-
if !lower_to_memory {
1151-
// With no return pointer in use we simply lower the
1152-
// result(s) and return that directly from the function.
1153-
if let Some(ty) = &func.result {
1154-
self.lower(ty);
1189+
// Perform memory lowing of relevant results, including out pointers as well as traditional results
1190+
match (lower_to_memory, sig.retptr, variant) {
1191+
// Async guest imports with do no lowering cannot have ret pointers
1192+
// not having to do lowering implies that there was no return pointer provided
1193+
(_lower_to_memory @ false, _has_ret_ptr @ true, AbiVariant::GuestImport)
1194+
if async_ =>
1195+
{
1196+
unreachable!(
1197+
"async guest import cannot avoid lowering when a ret ptr is present ({async_note} func [{func_name}], variant {variant:#?})",
1198+
async_note = async_.then_some("async").unwrap_or("sync"),
1199+
func_name = func.name,
1200+
)
11551201
}
1156-
} else {
1157-
match variant {
1158-
// When a function is imported to a guest this means
1159-
// it's a host providing the implementation of the
1160-
// import. The result is stored in the pointer
1161-
// specified in the last argument, so we get the
1162-
// pointer here and then write the return value into
1163-
// it.
1164-
AbiVariant::GuestImport => {
1165-
self.emit(&Instruction::GetArg {
1166-
nth: sig.params.len() - 1,
1167-
});
1168-
let ptr = self.stack.pop().unwrap();
1169-
self.write_params_to_memory(&func.result, ptr, Default::default());
1202+
1203+
// For sync calls, if no lowering to memory is required and there *is* a return pointer in use
1204+
// then we need to lower then simply lower the result(s) and return that directly from the function.
1205+
(_lower_to_memory @ false, _, _) => {
1206+
if let Some(ty) = &func.result {
1207+
self.lower(ty);
11701208
}
1209+
}
1210+
1211+
// Lowering to memory for a guest import
1212+
//
1213+
// When a function is imported to a guest this means
1214+
// it's a host providing the implementation of the
1215+
// import. The result is stored in the pointer
1216+
// specified in the last argument, so we get the
1217+
// pointer here and then write the return value into
1218+
// it.
1219+
(
1220+
_lower_to_memory @ true,
1221+
_has_ret_ptr @ true,
1222+
AbiVariant::GuestImport | AbiVariant::GuestImportAsync,
1223+
) => {
1224+
self.emit(&Instruction::GetArg {
1225+
nth: sig.params.len() - 1,
1226+
});
1227+
let ptr = self
1228+
.stack
1229+
.pop()
1230+
.expect("empty stack during result lower to memory");
1231+
self.write_params_to_memory(&func.result, ptr, Default::default());
1232+
}
11711233

1172-
// For a guest import this is a function defined in
1173-
// wasm, so we're returning a pointer where the
1174-
// value was stored at. Allocate some space here
1175-
// (statically) and then write the result into that
1176-
// memory, returning the pointer at the end.
1234+
// Lowering to memory for a guest export
1235+
//
1236+
// For a guest import this is a function defined in
1237+
// wasm, so we're returning a pointer where the
1238+
// value was stored at. Allocate some space here
1239+
// (statically) and then write the result into that
1240+
// memory, returning the pointer at the end.
1241+
(_lower_to_memory @ true, _, variant) => match variant {
11771242
AbiVariant::GuestExport | AbiVariant::GuestExportAsync => {
11781243
let ElementInfo { size, align } =
11791244
self.bindgen.sizes().params(&func.result);
@@ -1185,24 +1250,56 @@ impl<'a, B: Bindgen> Generator<'a, B> {
11851250
);
11861251
self.stack.push(ptr);
11871252
}
1188-
1189-
AbiVariant::GuestImportAsync | AbiVariant::GuestExportAsyncStackful => {
1190-
unreachable!()
1253+
AbiVariant::GuestImport | AbiVariant::GuestImportAsync => {
1254+
unreachable!(
1255+
"lowering to memory cannot be performed without a return pointer ({async_note} func [{func_name}], variant {variant:#?})",
1256+
async_note = async_.then_some("async").unwrap_or("sync"),
1257+
func_name = func.name,
1258+
)
11911259
}
1192-
}
1260+
AbiVariant::GuestExportAsyncStackful => {
1261+
todo!("stackful exports are not yet supported")
1262+
}
1263+
},
11931264
}
11941265

1195-
if let Some(results) = async_flat_results {
1196-
let name = &format!("[task-return]{}", func.name);
1197-
let params = results.as_deref().unwrap_or(&[WasmType::Pointer]);
1266+
// Build and emit the appropriate return
1267+
match (variant, async_flat_results) {
1268+
// Async guest imports always return a i32 status code
1269+
(AbiVariant::GuestImport | AbiVariant::GuestImportAsync, None) if async_ => {
1270+
unreachable!("async guest imports must have a return")
1271+
}
11981272

1199-
self.emit(&Instruction::AsyncTaskReturn { name, params });
1200-
} else {
1201-
self.emit(&Instruction::Return {
1202-
func,
1203-
amt: sig.results.len(),
1204-
});
1273+
// Async guest imports with results return the status code, not a pointer to any results
1274+
(AbiVariant::GuestImport | AbiVariant::GuestImportAsync, Some(results))
1275+
if async_ =>
1276+
{
1277+
let name = &format!("[task-return]{}", func.name);
1278+
let params = results.as_deref().unwrap_or(&[WasmType::I32]);
1279+
self.emit(&Instruction::AsyncTaskReturn { name, params });
1280+
}
1281+
1282+
// All async/non-async cases with results that need to be returned are present here
1283+
//
1284+
// In practice, async imports should not end up here, as the returned result of an
1285+
// async import is *not* a pointer but instead a status code.
1286+
(_, Some(results)) => {
1287+
let name = &format!("[task-return]{}", func.name);
1288+
let params = results.as_deref().unwrap_or(&[WasmType::Pointer]);
1289+
self.emit(&Instruction::AsyncTaskReturn { name, params });
1290+
}
1291+
1292+
// All async/non-async cases with no results simply return
1293+
//
1294+
// In practice, an async import will never get here (it always has a result, the error code)
1295+
(_, None) => {
1296+
self.emit(&Instruction::Return {
1297+
func,
1298+
amt: sig.results.len(),
1299+
});
1300+
}
12051301
}
1302+
12061303
self.realloc = None;
12071304
}
12081305
}
@@ -1257,7 +1354,7 @@ impl<'a, B: Bindgen> Generator<'a, B> {
12571354
let mut operands = operands;
12581355
let mut operands_for_ty;
12591356
for ty in types {
1260-
let types = flat_types(self.resolve, ty).unwrap();
1357+
let types = flat_types(self.resolve, ty, None).unwrap();
12611358
(operands_for_ty, operands) = operands.split_at(types.len());
12621359
self.stack.extend_from_slice(operands_for_ty);
12631360
self.deallocate(ty, what);
@@ -1455,7 +1552,7 @@ impl<'a, B: Bindgen> Generator<'a, B> {
14551552
cases: impl IntoIterator<Item = Option<&'b Type>>,
14561553
) -> Vec<WasmType> {
14571554
use Instruction::*;
1458-
let results = flat_types(self.resolve, ty).unwrap();
1555+
let results = flat_types(self.resolve, ty, None).unwrap();
14591556
let mut casts = Vec::new();
14601557
for (i, ty) in cases.into_iter().enumerate() {
14611558
self.push_block();
@@ -1472,7 +1569,7 @@ impl<'a, B: Bindgen> Generator<'a, B> {
14721569
// Determine the types of all the wasm values we just
14731570
// pushed, and record how many. If we pushed too few
14741571
// then we'll need to push some zeros after this.
1475-
let temp = flat_types(self.resolve, ty).unwrap();
1572+
let temp = flat_types(self.resolve, ty, None).unwrap();
14761573
pushed += temp.len();
14771574

14781575
// For all the types pushed we may need to insert some
@@ -1638,13 +1735,13 @@ impl<'a, B: Bindgen> Generator<'a, B> {
16381735
types: impl Iterator<Item = &'b Type>,
16391736
mut iter: impl FnMut(&mut Self, &Type),
16401737
) {
1641-
let temp = flat_types(self.resolve, container).unwrap();
1738+
let temp = flat_types(self.resolve, container, None).unwrap();
16421739
let mut args = self
16431740
.stack
16441741
.drain(self.stack.len() - temp.len()..)
16451742
.collect::<Vec<_>>();
16461743
for ty in types {
1647-
let temp = flat_types(self.resolve, ty).unwrap();
1744+
let temp = flat_types(self.resolve, ty, None).unwrap();
16481745
self.stack.extend(args.drain(..temp.len()));
16491746
iter(self, ty);
16501747
}
@@ -1657,7 +1754,7 @@ impl<'a, B: Bindgen> Generator<'a, B> {
16571754
cases: impl IntoIterator<Item = Option<&'b Type>>,
16581755
mut iter: impl FnMut(&mut Self, &Type),
16591756
) {
1660-
let params = flat_types(self.resolve, ty).unwrap();
1757+
let params = flat_types(self.resolve, ty, None).unwrap();
16611758
let mut casts = Vec::new();
16621759
let block_inputs = self
16631760
.stack
@@ -1668,7 +1765,7 @@ impl<'a, B: Bindgen> Generator<'a, B> {
16681765
if let Some(ty) = ty {
16691766
// Push only the values we need for this variant onto
16701767
// the stack.
1671-
let temp = flat_types(self.resolve, ty).unwrap();
1768+
let temp = flat_types(self.resolve, ty, None).unwrap();
16721769
self.stack
16731770
.extend(block_inputs[..temp.len()].iter().cloned());
16741771

@@ -2399,9 +2496,14 @@ fn cast(from: WasmType, to: WasmType) -> Bitcast {
23992496
}
24002497
}
24012498

2402-
fn flat_types(resolve: &Resolve, ty: &Type) -> Option<Vec<WasmType>> {
2403-
let mut storage = [WasmType::I32; MAX_FLAT_PARAMS];
2404-
let mut flat = FlatTypes::new(&mut storage);
2499+
/// Flatten types in a given type
2500+
///
2501+
/// It is sometimes necessary to restrict the number of max parameters dynamically,
2502+
/// for example during an async guest import call (flat params are limited to 4)
2503+
fn flat_types(resolve: &Resolve, ty: &Type, max_params: Option<usize>) -> Option<Vec<WasmType>> {
2504+
let mut storage =
2505+
iter::repeat_n(WasmType::I32, max_params.unwrap_or(MAX_FLAT_PARAMS)).collect::<Vec<_>>();
2506+
let mut flat = FlatTypes::new(storage.as_mut_slice());
24052507
if resolve.push_flat(ty, &mut flat) {
24062508
Some(flat.to_vec())
24072509
} else {

0 commit comments

Comments
 (0)