From 9f0820a33a938b4061d16a9fe7686b21794297fb Mon Sep 17 00:00:00 2001 From: Adam Chidlow Date: Fri, 13 Sep 2024 16:20:43 +0800 Subject: [PATCH] improve type union error messaging, since some framework functions are typed as having union parameters --- src/puyapy/awst_build/eb/_expect.py | 32 ++++++++++--------- src/puyapy/awst_build/eb/arc4/abi_call.py | 4 +-- src/puyapy/awst_build/eb/arc4/tuple.py | 8 ++--- src/puyapy/awst_build/eb/bytes.py | 7 ++-- src/puyapy/awst_build/eb/compiled.py | 18 +++++------ src/puyapy/awst_build/eb/log.py | 5 +-- src/puyapy/awst_build/eb/string.py | 9 +++--- src/puyapy/awst_build/eb/subroutine.py | 2 +- src/puyapy/awst_build/eb/transaction/inner.py | 5 +-- src/puyapy/awst_build/pytypes.py | 8 +++++ tests/test_execution.py | 4 ++- 11 files changed, 58 insertions(+), 44 deletions(-) diff --git a/src/puyapy/awst_build/eb/_expect.py b/src/puyapy/awst_build/eb/_expect.py index ead6e1644c..d480e03526 100644 --- a/src/puyapy/awst_build/eb/_expect.py +++ b/src/puyapy/awst_build/eb/_expect.py @@ -37,8 +37,7 @@ def at_most_one_arg_of_type( logger.error(f"expected at most 1 argument, got {len(args)}", location=location) if isinstance(first, InstanceBuilder) and is_type_or_subtype(first.pytype, of_any=valid_types): return first - logger.error("unexpected argument type", location=first.source_location) - return None + return not_the_type_of(first, default=default_none) def default_raise(msg: str, location: SourceLocation) -> typing.Never: @@ -68,6 +67,15 @@ def defaulter(msg: str, location: SourceLocation) -> InstanceBuilder: # noqa: A return defaulter +def not_the_type_of(node: NodeBuilder, default: Callable[[str, SourceLocation], _T]) -> _T: + if isinstance(node.pytype, pytypes.UnionType): + msg = "type unions are unsupported at this location" + else: + msg = "unexpected argument type" + logger.error(msg, location=node.source_location) + return default(msg, node.source_location) + + def at_least_one_arg( args: Sequence[_TBuilder], location: SourceLocation, @@ -120,10 +128,7 @@ def exactly_one_arg_of_type( first = maybe_resolve_literal(first, pytype) if isinstance(first, InstanceBuilder) and is_type_or_subtype(first.pytype, of=pytype): return first - msg = "unexpected argument type" - result = default(msg, first.source_location) - logger.error(msg, location=first.source_location) - return result + return not_the_type_of(first, default=default) def exactly_one_arg_of_type_else_dummy( @@ -185,10 +190,7 @@ def argument_of_type( builder.pytype, of_any=(target_type, *additional_types) ): return builder - msg = "unexpected argument type" - result = default(msg, builder.source_location) - logger.error(msg, location=builder.source_location) - return result + return not_the_type_of(builder, default=default) def argument_of_type_else_dummy( @@ -218,11 +220,11 @@ def simple_string_literal( return value case InstanceBuilder(pytype=pytypes.StrLiteralType): msg = "argument must be a simple str literal" - case _: - msg = "unexpected argument type" - result = default(msg, builder.source_location) - logger.error(msg, location=builder.source_location) - return result + result = default(msg, builder.source_location) + logger.error(msg, location=builder.source_location) + return result + case other: + return not_the_type_of(other, default=default) def instance_builder( diff --git a/src/puyapy/awst_build/eb/arc4/abi_call.py b/src/puyapy/awst_build/eb/arc4/abi_call.py index 4cc064ce4a..699095d36e 100644 --- a/src/puyapy/awst_build/eb/arc4/abi_call.py +++ b/src/puyapy/awst_build/eb/arc4/abi_call.py @@ -240,8 +240,8 @@ def call( kind="update" if is_update else "create", location=method_or_type.source_location, ) - case _: - raise CodeError("unexpected argument type", method_or_type.source_location) + case other: + expect.not_the_type_of(other, default=expect.default_raise) if compiled is None: compiled = CompiledContractExpressionBuilder( CompiledContract( diff --git a/src/puyapy/awst_build/eb/arc4/tuple.py b/src/puyapy/awst_build/eb/arc4/tuple.py index 2c236a2984..e73ed09699 100644 --- a/src/puyapy/awst_build/eb/arc4/tuple.py +++ b/src/puyapy/awst_build/eb/arc4/tuple.py @@ -48,8 +48,8 @@ def call( ARC4Encode(value=arg.resolve(), wtype=wtype, source_location=location), typ ) case _: - # don't know expected type - raise CodeError("unexpected argument type", arg.source_location) + # don't know expected type, so raise + expect.not_the_type_of(arg, default=expect.default_raise) class ARC4TupleTypeBuilder(ARC4TypeBuilder[pytypes.TupleType]): @@ -91,8 +91,8 @@ def index(self, index: InstanceBuilder, location: SourceLocation) -> InstanceBui pass case InstanceBuilder(pytype=pytypes.IntLiteralType): raise CodeError("tuple index must be a simple int literal", index.source_location) - case _: - raise CodeError("unexpected argument type", index.source_location) + case other: + expect.not_the_type_of(other, default=expect.default_raise) try: item_typ = self.pytype.items[index_value] except IndexError: diff --git a/src/puyapy/awst_build/eb/bytes.py b/src/puyapy/awst_build/eb/bytes.py index ef5285b2af..590762200b 100644 --- a/src/puyapy/awst_build/eb/bytes.py +++ b/src/puyapy/awst_build/eb/bytes.py @@ -85,9 +85,10 @@ def call( value=b"", encoding=BytesEncoding.unknown, source_location=location ) return BytesExpressionBuilder(value) - case _: - logger.error("unexpected argument type", location=arg.source_location) - return dummy_value(self.produces(), location) + case other: + return expect.not_the_type_of( + other, default=expect.default_dummy_value(self.produces()) + ) @typing.override def member_access(self, name: str, location: SourceLocation) -> NodeBuilder: diff --git a/src/puyapy/awst_build/eb/compiled.py b/src/puyapy/awst_build/eb/compiled.py index ef6d8d0b5d..4adb748cd6 100644 --- a/src/puyapy/awst_build/eb/compiled.py +++ b/src/puyapy/awst_build/eb/compiled.py @@ -163,12 +163,12 @@ def call( case NodeBuilder(pytype=pytypes.TypeType(typ=pytypes.ContractType() as contract_typ)): contract = contract_typ.name case invalid_or_none: - # if None (=missing), then error message already logged by get_arg_mapping - if invalid_or_none is not None: - logger.error( - "unexpected argument type", location=invalid_or_none.source_location - ) - return dummy_value(result_type, location) + if invalid_or_none is None: + # if None (=missing), then error message already logged by get_arg_mapping + return dummy_value(result_type, location) + return expect.not_the_type_of( + invalid_or_none, default=expect.default_dummy_value(result_type) + ) return CompiledContractExpressionBuilder( CompiledContract( @@ -207,9 +207,7 @@ def call( logic_sig = LogicSigReference("") # dummy reference # if None (=missing), then error message already logged by get_arg_mapping if missing_or_invalid is not None: - logger.error( - "unexpected argument type", location=missing_or_invalid.source_location - ) + expect.not_the_type_of(missing_or_invalid, default=expect.default_none) prefix, template_vars = _extract_prefix_template_args(arg_map) return CompiledLogicSigExpressionBuilder( CompiledLogicSig( @@ -232,7 +230,7 @@ def _extract_prefix_template_args( if isinstance(template_vars_node, DictLiteralBuilder): template_vars = {k: v.resolve() for k, v in template_vars_node.mapping.items()} else: - logger.error("unexpected argument type", location=template_vars_node.source_location) + expect.not_the_type_of(template_vars_node, default=expect.default_none) if prefix_node := name_args.get("template_vars_prefix"): prefix = expect.simple_string_literal(prefix_node, default=expect.default_none) return prefix, template_vars diff --git a/src/puyapy/awst_build/eb/log.py b/src/puyapy/awst_build/eb/log.py index 166f7c288a..33e6038913 100644 --- a/src/puyapy/awst_build/eb/log.py +++ b/src/puyapy/awst_build/eb/log.py @@ -7,6 +7,7 @@ from puya.parse import SourceLocation from puyapy.awst_build import intrinsic_factory, pytypes +from puyapy.awst_build.eb import _expect as expect from puyapy.awst_build.eb._base import FunctionBuilder from puyapy.awst_build.eb.interface import InstanceBuilder, NodeBuilder from puyapy.awst_build.eb.none import NoneExpressionBuilder @@ -42,13 +43,13 @@ def call( ): sep = sep_arg.to_bytes(sep_arg.source_location) else: - logger.error("unexpected argument type", location=sep_arg.source_location) + expect.not_the_type_of(sep_arg, default=expect.default_none) sep = empty_utf8 bytes_args = [] for arg in args: if not isinstance(arg, InstanceBuilder): - logger.error("unexpected argument type", location=arg.source_location) + expect.not_the_type_of(arg, default=expect.default_none) else: if arg.pytype == pytypes.IntLiteralType: arg = arg.resolve_literal(UInt64TypeBuilder(arg.source_location)) diff --git a/src/puyapy/awst_build/eb/string.py b/src/puyapy/awst_build/eb/string.py index 05cb361ddd..d16e8472ad 100644 --- a/src/puyapy/awst_build/eb/string.py +++ b/src/puyapy/awst_build/eb/string.py @@ -86,9 +86,10 @@ def call( case None: str_const = StringConstant(value="", source_location=location) return StringExpressionBuilder(str_const) - case _: - logger.error("unexpected argument type", location=arg.source_location) - return dummy_value(self.produces(), location) + case other: + return expect.not_the_type_of( + other, default=expect.default_dummy_value(self.produces()) + ) class StringExpressionBuilder(BytesBackedInstanceExpressionBuilder): @@ -277,7 +278,7 @@ def call( arg = expect.exactly_one_arg(args, location, default=expect.default_none) if not isinstance(arg, StaticSizedCollectionBuilder): if arg is not None: # if None, already have an error logged - logger.error("unexpected argument type", location=arg.source_location) + expect.not_the_type_of(arg, default=expect.default_none) return dummy_value(pytypes.StringType, location) items = [ expect.argument_of_type_else_dummy( diff --git a/src/puyapy/awst_build/eb/subroutine.py b/src/puyapy/awst_build/eb/subroutine.py index bfa862e843..b461b3037c 100644 --- a/src/puyapy/awst_build/eb/subroutine.py +++ b/src/puyapy/awst_build/eb/subroutine.py @@ -109,7 +109,7 @@ def call( arg = arg_map[arg_map_name] if pytypes.ContractBaseType in arg_typ.mro: if not is_type_or_subtype(arg.pytype, of=arg_typ): - logger.error("unexpected argument type", location=arg.source_location) + expect.not_the_type_of(arg, default=expect.default_none) else: arg = expect.argument_of_type_else_dummy(arg, arg_typ) passed_name = arg_map_name if arg_map_name in arg_names else None diff --git a/src/puyapy/awst_build/eb/transaction/inner.py b/src/puyapy/awst_build/eb/transaction/inner.py index cf25bec72b..7d5c168ce6 100644 --- a/src/puyapy/awst_build/eb/transaction/inner.py +++ b/src/puyapy/awst_build/eb/transaction/inner.py @@ -8,6 +8,7 @@ from puya.parse import SourceLocation from puyapy.awst_build import pytypes +from puyapy.awst_build.eb import _expect as expect from puyapy.awst_build.eb._base import FunctionBuilder from puyapy.awst_build.eb.factories import builder_for_instance from puyapy.awst_build.eb.interface import InstanceBuilder, NodeBuilder, TypeBuilder @@ -85,8 +86,8 @@ def call( pytype=pytypes.TransactionRelatedType() as arg_pytype ) if arg_pytype in pytypes.InnerTransactionFieldsetTypes.values(): pass - case _: - raise CodeError("unexpected argument type", arg.source_location) + case other: + expect.not_the_type_of(other, default=expect.default_raise) arg_exprs.append(arg.resolve()) arg_result_type = pytypes.InnerTransactionResultTypes[arg_pytype.transaction_type] diff --git a/src/puyapy/awst_build/pytypes.py b/src/puyapy/awst_build/pytypes.py index 3939442001..fab7b4404d 100644 --- a/src/puyapy/awst_build/pytypes.py +++ b/src/puyapy/awst_build/pytypes.py @@ -226,6 +226,14 @@ class UnionType(PyType): name: str = attrs.field(init=False) source_location: SourceLocation + def __attrs_post_init__(self) -> None: + if len({isinstance(t, LiteralOnlyType) for t in self.types}) > 1: + # we can't support these for semantic compatibility reasons with testing + raise CodeError( + "type unions between literal and non-literals are not supported", + self.source_location, + ) + @name.default def _name(self) -> str: return " | ".join(t.name for t in self.types) diff --git a/tests/test_execution.py b/tests/test_execution.py index ebe8319036..09373f0479 100644 --- a/tests/test_execution.py +++ b/tests/test_execution.py @@ -1369,7 +1369,9 @@ def approval_program(self) -> bool: def clear_state_program(self) -> bool: return True - with pytest.raises(puya.errors.CodeError, match="unexpected argument type"): + with pytest.raises( + puya.errors.CodeError, match="type unions are unsupported at this location" + ): harness.deploy_from_closure(test)