Skip to content

Commit

Permalink
improve type union error messaging, since some framework functions ar…
Browse files Browse the repository at this point in the history
…e typed as having union parameters
  • Loading branch information
achidlow committed Sep 13, 2024
1 parent a9792fc commit 9f0820a
Show file tree
Hide file tree
Showing 11 changed files with 58 additions and 44 deletions.
32 changes: 17 additions & 15 deletions src/puyapy/awst_build/eb/_expect.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ def at_most_one_arg_of_type(
logger.error(f"expected at most 1 argument, got {len(args)}", location=location)
if isinstance(first, InstanceBuilder) and is_type_or_subtype(first.pytype, of_any=valid_types):
return first
logger.error("unexpected argument type", location=first.source_location)
return None
return not_the_type_of(first, default=default_none)


def default_raise(msg: str, location: SourceLocation) -> typing.Never:
Expand Down Expand Up @@ -68,6 +67,15 @@ def defaulter(msg: str, location: SourceLocation) -> InstanceBuilder: # noqa: A
return defaulter


def not_the_type_of(node: NodeBuilder, default: Callable[[str, SourceLocation], _T]) -> _T:
if isinstance(node.pytype, pytypes.UnionType):
msg = "type unions are unsupported at this location"
else:
msg = "unexpected argument type"
logger.error(msg, location=node.source_location)
return default(msg, node.source_location)


def at_least_one_arg(
args: Sequence[_TBuilder],
location: SourceLocation,
Expand Down Expand Up @@ -120,10 +128,7 @@ def exactly_one_arg_of_type(
first = maybe_resolve_literal(first, pytype)
if isinstance(first, InstanceBuilder) and is_type_or_subtype(first.pytype, of=pytype):
return first
msg = "unexpected argument type"
result = default(msg, first.source_location)
logger.error(msg, location=first.source_location)
return result
return not_the_type_of(first, default=default)


def exactly_one_arg_of_type_else_dummy(
Expand Down Expand Up @@ -185,10 +190,7 @@ def argument_of_type(
builder.pytype, of_any=(target_type, *additional_types)
):
return builder
msg = "unexpected argument type"
result = default(msg, builder.source_location)
logger.error(msg, location=builder.source_location)
return result
return not_the_type_of(builder, default=default)


def argument_of_type_else_dummy(
Expand Down Expand Up @@ -218,11 +220,11 @@ def simple_string_literal(
return value
case InstanceBuilder(pytype=pytypes.StrLiteralType):
msg = "argument must be a simple str literal"
case _:
msg = "unexpected argument type"
result = default(msg, builder.source_location)
logger.error(msg, location=builder.source_location)
return result
result = default(msg, builder.source_location)
logger.error(msg, location=builder.source_location)
return result
case other:
return not_the_type_of(other, default=default)


def instance_builder(
Expand Down
4 changes: 2 additions & 2 deletions src/puyapy/awst_build/eb/arc4/abi_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,8 @@ def call(
kind="update" if is_update else "create",
location=method_or_type.source_location,
)
case _:
raise CodeError("unexpected argument type", method_or_type.source_location)
case other:
expect.not_the_type_of(other, default=expect.default_raise)
if compiled is None:
compiled = CompiledContractExpressionBuilder(
CompiledContract(
Expand Down
8 changes: 4 additions & 4 deletions src/puyapy/awst_build/eb/arc4/tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def call(
ARC4Encode(value=arg.resolve(), wtype=wtype, source_location=location), typ
)
case _:
# don't know expected type
raise CodeError("unexpected argument type", arg.source_location)
# don't know expected type, so raise
expect.not_the_type_of(arg, default=expect.default_raise)


class ARC4TupleTypeBuilder(ARC4TypeBuilder[pytypes.TupleType]):
Expand Down Expand Up @@ -91,8 +91,8 @@ def index(self, index: InstanceBuilder, location: SourceLocation) -> InstanceBui
pass
case InstanceBuilder(pytype=pytypes.IntLiteralType):
raise CodeError("tuple index must be a simple int literal", index.source_location)
case _:
raise CodeError("unexpected argument type", index.source_location)
case other:
expect.not_the_type_of(other, default=expect.default_raise)
try:
item_typ = self.pytype.items[index_value]
except IndexError:
Expand Down
7 changes: 4 additions & 3 deletions src/puyapy/awst_build/eb/bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,10 @@ def call(
value=b"", encoding=BytesEncoding.unknown, source_location=location
)
return BytesExpressionBuilder(value)
case _:
logger.error("unexpected argument type", location=arg.source_location)
return dummy_value(self.produces(), location)
case other:
return expect.not_the_type_of(
other, default=expect.default_dummy_value(self.produces())
)

@typing.override
def member_access(self, name: str, location: SourceLocation) -> NodeBuilder:
Expand Down
18 changes: 8 additions & 10 deletions src/puyapy/awst_build/eb/compiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,12 @@ def call(
case NodeBuilder(pytype=pytypes.TypeType(typ=pytypes.ContractType() as contract_typ)):
contract = contract_typ.name
case invalid_or_none:
# if None (=missing), then error message already logged by get_arg_mapping
if invalid_or_none is not None:
logger.error(
"unexpected argument type", location=invalid_or_none.source_location
)
return dummy_value(result_type, location)
if invalid_or_none is None:
# if None (=missing), then error message already logged by get_arg_mapping
return dummy_value(result_type, location)
return expect.not_the_type_of(
invalid_or_none, default=expect.default_dummy_value(result_type)
)

return CompiledContractExpressionBuilder(
CompiledContract(
Expand Down Expand Up @@ -207,9 +207,7 @@ def call(
logic_sig = LogicSigReference("") # dummy reference
# if None (=missing), then error message already logged by get_arg_mapping
if missing_or_invalid is not None:
logger.error(
"unexpected argument type", location=missing_or_invalid.source_location
)
expect.not_the_type_of(missing_or_invalid, default=expect.default_none)
prefix, template_vars = _extract_prefix_template_args(arg_map)
return CompiledLogicSigExpressionBuilder(
CompiledLogicSig(
Expand All @@ -232,7 +230,7 @@ def _extract_prefix_template_args(
if isinstance(template_vars_node, DictLiteralBuilder):
template_vars = {k: v.resolve() for k, v in template_vars_node.mapping.items()}
else:
logger.error("unexpected argument type", location=template_vars_node.source_location)
expect.not_the_type_of(template_vars_node, default=expect.default_none)
if prefix_node := name_args.get("template_vars_prefix"):
prefix = expect.simple_string_literal(prefix_node, default=expect.default_none)
return prefix, template_vars
5 changes: 3 additions & 2 deletions src/puyapy/awst_build/eb/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from puya.parse import SourceLocation

from puyapy.awst_build import intrinsic_factory, pytypes
from puyapy.awst_build.eb import _expect as expect
from puyapy.awst_build.eb._base import FunctionBuilder
from puyapy.awst_build.eb.interface import InstanceBuilder, NodeBuilder
from puyapy.awst_build.eb.none import NoneExpressionBuilder
Expand Down Expand Up @@ -42,13 +43,13 @@ def call(
):
sep = sep_arg.to_bytes(sep_arg.source_location)
else:
logger.error("unexpected argument type", location=sep_arg.source_location)
expect.not_the_type_of(sep_arg, default=expect.default_none)
sep = empty_utf8

bytes_args = []
for arg in args:
if not isinstance(arg, InstanceBuilder):
logger.error("unexpected argument type", location=arg.source_location)
expect.not_the_type_of(arg, default=expect.default_none)
else:
if arg.pytype == pytypes.IntLiteralType:
arg = arg.resolve_literal(UInt64TypeBuilder(arg.source_location))
Expand Down
9 changes: 5 additions & 4 deletions src/puyapy/awst_build/eb/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,10 @@ def call(
case None:
str_const = StringConstant(value="", source_location=location)
return StringExpressionBuilder(str_const)
case _:
logger.error("unexpected argument type", location=arg.source_location)
return dummy_value(self.produces(), location)
case other:
return expect.not_the_type_of(
other, default=expect.default_dummy_value(self.produces())
)


class StringExpressionBuilder(BytesBackedInstanceExpressionBuilder):
Expand Down Expand Up @@ -277,7 +278,7 @@ def call(
arg = expect.exactly_one_arg(args, location, default=expect.default_none)
if not isinstance(arg, StaticSizedCollectionBuilder):
if arg is not None: # if None, already have an error logged
logger.error("unexpected argument type", location=arg.source_location)
expect.not_the_type_of(arg, default=expect.default_none)
return dummy_value(pytypes.StringType, location)
items = [
expect.argument_of_type_else_dummy(
Expand Down
2 changes: 1 addition & 1 deletion src/puyapy/awst_build/eb/subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def call(
arg = arg_map[arg_map_name]
if pytypes.ContractBaseType in arg_typ.mro:
if not is_type_or_subtype(arg.pytype, of=arg_typ):
logger.error("unexpected argument type", location=arg.source_location)
expect.not_the_type_of(arg, default=expect.default_none)
else:
arg = expect.argument_of_type_else_dummy(arg, arg_typ)
passed_name = arg_map_name if arg_map_name in arg_names else None
Expand Down
5 changes: 3 additions & 2 deletions src/puyapy/awst_build/eb/transaction/inner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from puya.parse import SourceLocation

from puyapy.awst_build import pytypes
from puyapy.awst_build.eb import _expect as expect
from puyapy.awst_build.eb._base import FunctionBuilder
from puyapy.awst_build.eb.factories import builder_for_instance
from puyapy.awst_build.eb.interface import InstanceBuilder, NodeBuilder, TypeBuilder
Expand Down Expand Up @@ -85,8 +86,8 @@ def call(
pytype=pytypes.TransactionRelatedType() as arg_pytype
) if arg_pytype in pytypes.InnerTransactionFieldsetTypes.values():
pass
case _:
raise CodeError("unexpected argument type", arg.source_location)
case other:
expect.not_the_type_of(other, default=expect.default_raise)

arg_exprs.append(arg.resolve())
arg_result_type = pytypes.InnerTransactionResultTypes[arg_pytype.transaction_type]
Expand Down
8 changes: 8 additions & 0 deletions src/puyapy/awst_build/pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,14 @@ class UnionType(PyType):
name: str = attrs.field(init=False)
source_location: SourceLocation

def __attrs_post_init__(self) -> None:
if len({isinstance(t, LiteralOnlyType) for t in self.types}) > 1:
# we can't support these for semantic compatibility reasons with testing
raise CodeError(
"type unions between literal and non-literals are not supported",
self.source_location,
)

@name.default
def _name(self) -> str:
return " | ".join(t.name for t in self.types)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1369,7 +1369,9 @@ def approval_program(self) -> bool:
def clear_state_program(self) -> bool:
return True

with pytest.raises(puya.errors.CodeError, match="unexpected argument type"):
with pytest.raises(
puya.errors.CodeError, match="type unions are unsupported at this location"
):
harness.deploy_from_closure(test)


Expand Down

0 comments on commit 9f0820a

Please sign in to comment.