Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve handling of and/or where the result is a type union #312

Merged
merged 2 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/sizes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
arc4_types/Arc4StringTypes 472 35 437 35 0
arc4_types/Arc4StructsFromAnotherModule 73 12 61 12 0
arc4_types/Arc4StructsType 318 239 79 239 0
arc4_types/Arc4TuplesType 882 138 744 138 0
arc4_types/Arc4TuplesType 865 138 727 138 0
arc_28/EventEmitter 191 133 58 133 0
asset/Reference 269 261 8 261 0
auction/Auction 601 522 79 522 0
augmented_assignment/Augmented 159 156 3 156 0
avm_types_in_abi/Test 423 351 72 351 0
biguint_binary_ops/BiguintBinaryOps 189 77 112 77 0
boolean_binary_ops/BooleanBinaryOps 345 280 65 280 0
boolean_binary_ops/BooleanBinaryOps 1154 471 683 471 0
box_storage/Box 1860 1435 425 1435 0
bytes_ops/BiguintBinaryOps 139 139 0 139 0
calculator 349 317 32 315 2
Expand Down
14 changes: 2 additions & 12 deletions src/puyapy/awst_build/arc4_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,19 +382,9 @@ def require_arg_name(arg: pytypes.FuncArg) -> str:
)
return arg.name

def require_single_type(arg: pytypes.FuncArg) -> pytypes.PyType:
try:
(typ,) = arg.types
except ValueError:
raise CodeError(
"union types are not supported as method arguments", location
) from None
else:
return typ

if not (
func_type.args
and set(require_single_type(func_type.args[0]).mro).intersection(
and set(func_type.args[0].type.mro).intersection(
(pytypes.ARC4ContractBaseType, pytypes.ARC4ClientBaseType)
)
):
Expand All @@ -403,7 +393,7 @@ def require_single_type(arg: pytypes.FuncArg) -> pytypes.PyType:
f" instance methods of classes derived from {pytypes.ARC4ContractBaseType}",
location,
)
result = {require_arg_name(arg): require_single_type(arg) for arg in func_type.args[1:]}
result = {require_arg_name(arg): arg.type for arg in func_type.args[1:]}
if "output" in result:
# https://github.com/algorandfoundation/ARCs/blob/main/assets/arc-0032/application.schema.json
raise CodeError(
Expand Down
38 changes: 16 additions & 22 deletions src/puyapy/awst_build/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@
from puya.errors import CodeError, InternalError, log_exceptions
from puya.models import ContractReference
from puya.parse import SourceLocation
from puya.utils import attrs_extend
from puya.utils import attrs_extend, unique

from puyapy.awst_build import pytypes
from puyapy.awst_build.contract_data import AppStorageDeclaration
from puyapy.awst_build.exceptions import TypeUnionError
from puyapy.parse import ParseResult, source_location_from_mypy

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -197,7 +196,7 @@ def type_to_pytype(
registry: Mapping[str, pytypes.PyType],
mypy_type: mypy.types.Type,
*,
source_location: SourceLocation | None,
source_location: SourceLocation,
in_type_args: bool = False,
in_func_sig: bool = False,
) -> pytypes.PyType:
Expand Down Expand Up @@ -269,13 +268,13 @@ def type_to_pytype(
our_literal_value = literal_value
return pytypes.TypingLiteralType(value=our_literal_value, source_location=loc)
case mypy.types.UnionType(items=items):
types = [recurse(it) for it in items]
types = unique(recurse(it) for it in items)
if not types:
raise CodeError("Cannot resolve empty type", loc)
if len(types) == 1:
return pytypes.NeverType
elif len(types) == 1:
return types[0]
else:
raise TypeUnionError(types, loc)
return pytypes.UnionType(types, loc)
case mypy.types.NoneType() | mypy.types.PartialType(type=None):
return pytypes.NoneType
case mypy.types.UninhabitedType():
Expand All @@ -301,19 +300,14 @@ def type_to_pytype(
for at, name, kind in zip(
func_like.arg_types, func_like.arg_names, func_like.arg_kinds, strict=True
):
try:
pt = type_to_pytype(
registry,
at,
source_location=loc,
in_type_args=in_type_args,
in_func_sig=True,
)
except TypeUnionError as union:
pts = union.types
else:
pts = [pt]
func_args.append(pytypes.FuncArg(types=pts, kind=kind, name=name))
arg_pytype = type_to_pytype(
registry,
at,
source_location=loc,
in_type_args=in_type_args,
in_func_sig=True,
)
func_args.append(pytypes.FuncArg(type=arg_pytype, kind=kind, name=name))
if None in func_like.bound_args:
logger.debug(
"None contained in bound args for function reference", location=loc
Expand Down Expand Up @@ -347,7 +341,7 @@ def _maybe_parameterise_pytype(
registry: Mapping[str, pytypes.PyType],
maybe_generic: pytypes.PyType,
mypy_type_args: Sequence[mypy.types.Type],
loc: SourceLocation | None,
loc: SourceLocation,
) -> pytypes.PyType:
if not mypy_type_args:
return maybe_generic
Expand All @@ -361,7 +355,7 @@ def _maybe_parameterise_pytype(
return result


def _type_of_any_to_error_message(type_of_any: int, source_location: SourceLocation | None) -> str:
def _type_of_any_to_error_message(type_of_any: int, source_location: SourceLocation) -> str:
from mypy.types import TypeOfAny

match type_of_any:
Expand Down
7 changes: 7 additions & 0 deletions src/puyapy/awst_build/eb/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing_extensions
from puya import log
from puya.awst.nodes import (
BinaryBooleanOperator,
CompileTimeConstantExpression,
Expression,
FieldExpression,
Expand Down Expand Up @@ -152,6 +153,12 @@ def binary_op(
) -> InstanceBuilder:
return NotImplemented

@typing.override
def bool_binary_op(
self, other: InstanceBuilder, op: BinaryBooleanOperator, location: SourceLocation
) -> InstanceBuilder:
return super().bool_binary_op(other, op, location)

@typing.override
def augmented_assignment(
self, op: BuilderBinaryOp, rhs: InstanceBuilder, location: SourceLocation
Expand Down
39 changes: 19 additions & 20 deletions src/puyapy/awst_build/eb/_expect.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ def at_most_one_arg_of_type(
logger.error(f"expected at most 1 argument, got {len(args)}", location=location)
if isinstance(first, InstanceBuilder) and is_type_or_subtype(first.pytype, of_any=valid_types):
return first
logger.error("unexpected argument type", location=first.source_location)
return None
return not_this_type(first, default=default_none)


def default_raise(msg: str, location: SourceLocation) -> typing.Never:
Expand All @@ -60,14 +59,24 @@ def defaulter(msg: str, location: SourceLocation) -> _T: # noqa: ARG001
def default_dummy_value(
pytype: pytypes.PyType,
) -> Callable[[str, SourceLocation], InstanceBuilder]:
assert not isinstance(pytype, pytypes.LiteralOnlyType)

def defaulter(msg: str, location: SourceLocation) -> InstanceBuilder: # noqa: ARG001
return dummy_value(pytype, location)

return defaulter


def not_this_type(node: NodeBuilder, default: Callable[[str, SourceLocation], _T]) -> _T:
"""Provide consistent error messages for unexpected types."""
if isinstance(node.pytype, pytypes.UnionType):
msg = "type unions are unsupported at this location"
else:
msg = "unexpected argument type"
result = default(msg, node.source_location)
logger.error(msg, location=node.source_location)
return result


def at_least_one_arg(
args: Sequence[_TBuilder],
location: SourceLocation,
Expand Down Expand Up @@ -120,10 +129,7 @@ def exactly_one_arg_of_type(
first = maybe_resolve_literal(first, pytype)
if isinstance(first, InstanceBuilder) and is_type_or_subtype(first.pytype, of=pytype):
return first
msg = "unexpected argument type"
result = default(msg, first.source_location)
logger.error(msg, location=first.source_location)
return result
return not_this_type(first, default=default)


def exactly_one_arg_of_type_else_dummy(
Expand All @@ -133,8 +139,6 @@ def exactly_one_arg_of_type_else_dummy(
*,
resolve_literal: bool = False,
) -> InstanceBuilder:
assert not isinstance(pytype, pytypes.LiteralOnlyType)

return exactly_one_arg_of_type(
args,
pytype,
Expand All @@ -152,8 +156,6 @@ def no_args(args: Sequence[NodeBuilder], location: SourceLocation) -> None:
def exactly_n_args_of_type_else_dummy(
args: Sequence[NodeBuilder], pytype: pytypes.PyType, location: SourceLocation, num_args: int
) -> Sequence[InstanceBuilder]:
assert not isinstance(pytype, pytypes.LiteralOnlyType)

if not exactly_n_args(args, location, num_args):
dummy_args = [dummy_value(pytype, location)] * num_args
args = [arg or default for arg, default in zip_longest(args, dummy_args)]
Expand Down Expand Up @@ -185,10 +187,7 @@ def argument_of_type(
builder.pytype, of_any=(target_type, *additional_types)
):
return builder
msg = "unexpected argument type"
result = default(msg, builder.source_location)
logger.error(msg, location=builder.source_location)
return result
return not_this_type(builder, default=default)


def argument_of_type_else_dummy(
Expand Down Expand Up @@ -218,11 +217,11 @@ def simple_string_literal(
return value
case InstanceBuilder(pytype=pytypes.StrLiteralType):
msg = "argument must be a simple str literal"
case _:
msg = "unexpected argument type"
result = default(msg, builder.source_location)
logger.error(msg, location=builder.source_location)
return result
result = default(msg, builder.source_location)
logger.error(msg, location=builder.source_location)
return result
case other:
return not_this_type(other, default=default)


def instance_builder(
Expand Down
10 changes: 10 additions & 0 deletions src/puyapy/awst_build/eb/_literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from puya import log
from puya.awst.nodes import (
BinaryBooleanOperator,
BoolConstant,
BytesConstant,
BytesEncoding,
Expand Down Expand Up @@ -121,6 +122,15 @@ def binary_op(
folded = fold_binary_expr(location, op.value, lhs, rhs)
return LiteralBuilderImpl(value=folded, source_location=location)

@typing.override
def bool_binary_op(
self, other: InstanceBuilder, op: BinaryBooleanOperator, location: SourceLocation
) -> InstanceBuilder:
if not isinstance(other, LiteralBuilder):
return super().bool_binary_op(other, op, location)
folded = fold_binary_expr(location, op.value, self.value, other.value)
return LiteralBuilderImpl(value=folded, source_location=location)

@typing.override
def augmented_assignment(
self, op: BuilderBinaryOp, rhs: InstanceBuilder, location: SourceLocation
Expand Down
10 changes: 7 additions & 3 deletions src/puyapy/awst_build/eb/_type_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import Callable

from puya.awst.nodes import Expression
from puya.errors import InternalError
from puya.errors import CodeError, InternalError
from puya.parse import SourceLocation

from puyapy.awst_build import constants, intrinsic_data, pytypes
Expand Down Expand Up @@ -247,7 +247,9 @@ def builder_for_instance(pytyp: pytypes.PyType, expr: Expression) -> InstanceBui
for base in pytyp.mro:
if eb_base := PYTYPE_BASE_TO_BUILDER.get(base):
return eb_base(expr, pytyp)
raise InternalError(f"No builder for instance: {pytyp}", expr.source_location)
if isinstance(pytyp, pytypes.UnionType):
raise CodeError("type unions are unsupported at this location", expr.source_location)
raise InternalError(f"no builder for instance: {pytyp}", expr.source_location)


def builder_for_type(pytyp: pytypes.PyType, expr_loc: SourceLocation) -> CallableBuilder:
Expand All @@ -258,4 +260,6 @@ def builder_for_type(pytyp: pytypes.PyType, expr_loc: SourceLocation) -> Callabl
for base in pytyp.mro:
if tb_base := PYTYPE_BASE_TO_TYPE_BUILDER.get(base):
return tb_base(pytyp, expr_loc)
raise InternalError(f"No builder for type: {pytyp}", expr_loc)
if isinstance(pytyp, pytypes.UnionType):
raise CodeError("type unions are unsupported at this location", expr_loc)
raise InternalError(f"no builder for type: {pytyp}", expr_loc)
4 changes: 4 additions & 0 deletions src/puyapy/awst_build/eb/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@


def dummy_value(pytype: pytypes.PyType, location: SourceLocation) -> InstanceBuilder:
if isinstance(pytype, pytypes.LiteralOnlyType):
from puyapy.awst_build.eb._literals import LiteralBuilderImpl

return LiteralBuilderImpl(pytype.python_type(), location)
expr = VarExpression(name="", wtype=pytype.wtype, source_location=location)
return builder_for_instance(pytype, expr)

Expand Down
4 changes: 2 additions & 2 deletions src/puyapy/awst_build/eb/arc4/abi_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,8 @@ def call(
kind="update" if is_update else "create",
location=method_or_type.source_location,
)
case _:
raise CodeError("unexpected argument type", method_or_type.source_location)
case other:
expect.not_this_type(other, default=expect.default_raise)
if compiled is None:
compiled = CompiledContractExpressionBuilder(
CompiledContract(
Expand Down
8 changes: 4 additions & 4 deletions src/puyapy/awst_build/eb/arc4/tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def call(
ARC4Encode(value=arg.resolve(), wtype=wtype, source_location=location), typ
)
case _:
# don't know expected type
raise CodeError("unexpected argument type", arg.source_location)
# don't know expected type, so raise
expect.not_this_type(arg, default=expect.default_raise)


class ARC4TupleTypeBuilder(ARC4TypeBuilder[pytypes.TupleType]):
Expand Down Expand Up @@ -91,8 +91,8 @@ def index(self, index: InstanceBuilder, location: SourceLocation) -> InstanceBui
pass
case InstanceBuilder(pytype=pytypes.IntLiteralType):
raise CodeError("tuple index must be a simple int literal", index.source_location)
case _:
raise CodeError("unexpected argument type", index.source_location)
case other:
expect.not_this_type(other, default=expect.default_raise)
try:
item_typ = self.pytype.items[index_value]
except IndexError:
Expand Down
Loading
Loading