diff --git a/crates/wasmparser/src/readers/component/types.rs b/crates/wasmparser/src/readers/component/types.rs index c33feaa954..a1c2372ad7 100644 --- a/crates/wasmparser/src/readers/component/types.rs +++ b/crates/wasmparser/src/readers/component/types.rs @@ -186,7 +186,7 @@ impl PrimitiveValType { }) } - pub(crate) fn requires_realloc(&self) -> bool { + pub(crate) fn contains_ptr(&self) -> bool { matches!(self, Self::String) } diff --git a/crates/wasmparser/src/validator/types.rs b/crates/wasmparser/src/validator/types.rs index 222172f08d..84bf5de34b 100644 --- a/crates/wasmparser/src/validator/types.rs +++ b/crates/wasmparser/src/validator/types.rs @@ -322,10 +322,10 @@ pub enum ComponentValType { } impl ComponentValType { - pub(crate) fn requires_realloc(&self, types: &TypeList) -> bool { + pub(crate) fn contains_ptr(&self, types: &TypeList) -> bool { match self { - ComponentValType::Primitive(ty) => ty.requires_realloc(), - ComponentValType::Type(ty) => types[*ty].unwrap_defined().requires_realloc(types), + ComponentValType::Primitive(ty) => ty.contains_ptr(), + ComponentValType::Type(ty) => types[*ty].unwrap_defined().contains_ptr(types), } } @@ -660,14 +660,24 @@ pub struct ComponentFuncType { impl ComponentFuncType { /// Lowers the component function type to core parameter and result types for the /// canonical ABI. - pub(crate) fn lower(&self, types: &TypeList, import: bool) -> LoweringInfo { + pub(crate) fn lower(&self, types: &TypeList, is_lower: bool) -> LoweringInfo { let mut info = LoweringInfo::default(); for (_, ty) in self.params.iter() { - // When `import` is false, it means we're lifting a core function, - // check if the parameters needs realloc - if !import && !info.requires_realloc { - info.requires_realloc = ty.requires_realloc(types); + // Check to see if `ty` has a pointer somewhere in it, needed for + // any type that transitively contains either a string or a list. + // In this situation lowered functions must specify `memory`, and + // lifted functions must specify `realloc` as well. Lifted functions + // gain their memory requirement through the final clause of this + // function. + if is_lower { + if !info.requires_memory { + info.requires_memory = ty.contains_ptr(types); + } + } else { + if !info.requires_realloc { + info.requires_realloc = ty.contains_ptr(types); + } } if !ty.push_wasm_types(types, &mut info.params) { @@ -679,7 +689,7 @@ impl ComponentFuncType { info.requires_memory = true; // We need realloc as well when lifting a function - if !import { + if !is_lower { info.requires_realloc = true; } break; @@ -687,17 +697,19 @@ impl ComponentFuncType { } for (_, ty) in self.results.iter() { - // When `import` is true, it means we're lowering a component function, - // check if the result needs realloc - if import && !info.requires_realloc { - info.requires_realloc = ty.requires_realloc(types); + // Results of lowered functions that contains pointers must be + // allocated by the callee meaning that realloc is required. + // Results of lifted function are allocated by the guest which + // means that no realloc option is necessary. + if is_lower && !info.requires_realloc { + info.requires_realloc = ty.contains_ptr(types); } if !ty.push_wasm_types(types, &mut info.results) { // Too many results to return directly, either a retptr parameter will be used (import) // or a single pointer will be returned (export) info.results.clear(); - if import { + if is_lower { info.params.max = MAX_LOWERED_TYPES; assert!(info.params.push(ValType::I32)); } else { @@ -795,23 +807,22 @@ pub enum ComponentDefinedType { } impl ComponentDefinedType { - pub(crate) fn requires_realloc(&self, types: &TypeList) -> bool { + pub(crate) fn contains_ptr(&self, types: &TypeList) -> bool { match self { - Self::Primitive(ty) => ty.requires_realloc(), - Self::Record(r) => r.fields.values().any(|ty| ty.requires_realloc(types)), - Self::Variant(v) => v.cases.values().any(|case| { - case.ty - .map(|ty| ty.requires_realloc(types)) - .unwrap_or(false) - }), + Self::Primitive(ty) => ty.contains_ptr(), + Self::Record(r) => r.fields.values().any(|ty| ty.contains_ptr(types)), + Self::Variant(v) => v + .cases + .values() + .any(|case| case.ty.map(|ty| ty.contains_ptr(types)).unwrap_or(false)), Self::List(_) => true, - Self::Tuple(t) => t.types.iter().any(|ty| ty.requires_realloc(types)), - Self::Union(u) => u.types.iter().any(|ty| ty.requires_realloc(types)), + Self::Tuple(t) => t.types.iter().any(|ty| ty.contains_ptr(types)), + Self::Union(u) => u.types.iter().any(|ty| ty.contains_ptr(types)), Self::Flags(_) | Self::Enum(_) | Self::Own(_) | Self::Borrow(_) => false, - Self::Option(ty) => ty.requires_realloc(types), + Self::Option(ty) => ty.contains_ptr(types), Self::Result { ok, err } => { - ok.map(|ty| ty.requires_realloc(types)).unwrap_or(false) - || err.map(|ty| ty.requires_realloc(types)).unwrap_or(false) + ok.map(|ty| ty.contains_ptr(types)).unwrap_or(false) + || err.map(|ty| ty.contains_ptr(types)).unwrap_or(false) } } } diff --git a/tests/local/component-model/lower.wast b/tests/local/component-model/lower.wast new file mode 100644 index 0000000000..827645d3e0 --- /dev/null +++ b/tests/local/component-model/lower.wast @@ -0,0 +1,15 @@ +(assert_invalid + (component + (import "f" (func $f (param "x" (list u8)))) + (core func $f (canon lower (func $f) + )) + ) + "canonical option `memory` is required") + +(assert_invalid + (component + (import "f" (func $f (result (list u8)))) + (core func $f (canon lower (func $f) + )) + ) + "canonical option `memory` is required")