Skip to content

Commit

Permalink
fix: when there is exactly 15 arguments to an ABI function, the final…
Browse files Browse the repository at this point in the history
… argument should not be expected to be automatically tuple-packed
  • Loading branch information
achidlow committed Sep 2, 2024
1 parent eb3409f commit 860404e
Show file tree
Hide file tree
Showing 20 changed files with 3,762 additions and 2,132 deletions.
2 changes: 1 addition & 1 deletion examples/sizes.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Name O0 size O1 size O1 ⏷ O2 size O2 ⏷
abi_routing/CustomApproval 284 256 28 256 0
abi_routing/MinimumARC4 77 55 22 55 0
abi_routing/Reference 1284 1082 202 1082 0
abi_routing/Reference 1462 1223 239 1223 0
amm/ConstantProductAMM 1260 1109 151 1109 0
application/Reference 180 168 12 168 0
arc4_dynamic_arrays/DynamicArray 3008 1970 1038 1970 0
Expand Down
104 changes: 50 additions & 54 deletions src/puya/ir/arc4_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ def call(
location: SourceLocation, method: awst_nodes.ContractMethod, *args: awst_nodes.Expression
) -> awst_nodes.SubroutineCallExpression:
return awst_nodes.SubroutineCallExpression(
source_location=location,
wtype=method.return_type,
target=awst_nodes.ContractMethodTarget(cref=method.cref, member_name=method.member_name),
args=[awst_nodes.CallArg(name=None, value=arg) for arg in args],
wtype=method.return_type,
source_location=location,
)


Expand Down Expand Up @@ -290,77 +290,72 @@ def check_allowed_oca(


def _map_abi_args(
args: Sequence[awst_nodes.SubroutineArgument], location: SourceLocation
arg_types: Sequence[wtypes.WType], location: SourceLocation
) -> Iterable[awst_nodes.Expression]:
abi_arg_index = 1 # 0th arg is for method selector
transaction_arg_offset = sum(1 for a in args if isinstance(a.wtype, wtypes.WGroupTransaction))

non_transaction_args = [a for a in args if not isinstance(a.wtype, wtypes.WGroupTransaction)]
last_arg: awst_nodes.Expression | None = None
if len(non_transaction_args) > 15:
args_overflow_wtype = wtypes.ARC4Tuple(
types=[
_map_param_wtype_to_arc4_tuple_type(a.wtype) for a in non_transaction_args[14:]
],
source_location=location,
)
last_arg = app_arg(15, args_overflow_wtype, location)

def get_arg(index: int, arg_wtype: wtypes.WType) -> awst_nodes.Expression:
if index < 15:
return app_arg(index, arg_wtype, location)
transaction_arg_offset = 0
incoming_types = []
for a in arg_types:
if isinstance(a, wtypes.WGroupTransaction):
transaction_arg_offset += 1
else:
if last_arg is None:
raise InternalError("last_arg should not be None if there are more than 15 args")
assert isinstance(last_arg.wtype, wtypes.ARC4Tuple)
return awst_nodes.TupleItemExpression(
base=last_arg, index=index - 15, source_location=location
if isinstance(a, wtypes.ARC4Type):
arc4_type = a
else:
converted = _maybe_avm_to_arc4_equivalent_type(a)
if converted is not None:
arc4_type = converted
elif _reference_type_array(a) is not None:
arc4_type = wtypes.arc4_byte_alias
else:
raise CodeError(f"not an ARC4 type or native equivalent: {a}", location)
incoming_types.append(arc4_type)

if len(incoming_types) > 15:
unpacked_types, packed_types = incoming_types[:14], incoming_types[14:]
else:
unpacked_types, packed_types = incoming_types, []
abi_args = [
app_arg(array_index, arg_wtype, location)
for array_index, arg_wtype in enumerate(unpacked_types, start=1)
]
if packed_types:
packed_arg = app_arg(
15, wtypes.ARC4Tuple(types=packed_types, source_location=location), location
)
abi_args.extend(
awst_nodes.TupleItemExpression(
base=packed_arg, index=tuple_index, source_location=location
)
for tuple_index, _ in enumerate(packed_types)
)
abi_args.reverse() # reverse so we can pop off end

for arg in args:
if isinstance(arg.wtype, wtypes.WGroupTransaction):
for arg in arg_types:
if isinstance(arg, wtypes.WGroupTransaction):
transaction_index = uint64_sub(
intrinsic_factory.txn("GroupIndex", wtypes.uint64_wtype, location),
constant(
transaction_arg_offset,
location,
),
constant(transaction_arg_offset, location),
location,
)
yield awst_nodes.GroupTransactionReference(
index=transaction_index, wtype=arg.wtype, source_location=location
index=transaction_index, wtype=arg, source_location=location
)
transaction_arg_offset -= 1
else:
if (ref_array := _reference_type_array(arg.wtype)) is not None:
bytes_arg = get_arg(abi_arg_index, wtypes.bytes_wtype)
uint64_index = intrinsic_factory.btoi(bytes_arg, location)
abi_arg = abi_args.pop()
if (ref_array := _reference_type_array(arg)) is not None:
uint64_index = intrinsic_factory.btoi(abi_arg, location)
yield awst_nodes.IntrinsicCall(
op_code="txnas",
immediates=[ref_array],
stack_args=[uint64_index],
wtype=arg.wtype,
wtype=arg,
source_location=location,
)
else:
converted = _maybe_avm_to_arc4_equivalent_type(arg.wtype)
abi_arg = get_arg(abi_arg_index, converted or arg.wtype)
if converted is not None:
abi_arg = _arc4_decode(
bytes_arg=abi_arg, target_wtype=arg.wtype, location=location
)
if abi_arg.wtype != arg:
abi_arg = _arc4_decode(bytes_arg=abi_arg, target_wtype=arg, location=location)
yield abi_arg
abi_arg_index += 1


def _map_param_wtype_to_arc4_tuple_type(wtype: wtypes.WType) -> wtypes.WType:
converted = _maybe_avm_to_arc4_equivalent_type(wtype)
if converted is not None:
return converted
elif _reference_type_array(wtype) is not None:
return wtypes.arc4_byte_alias
else:
return wtype


def route_abi_methods(
Expand All @@ -371,7 +366,8 @@ def route_abi_methods(
seen_signatures = set[str]()
for method, config in methods.items():
abi_loc = config.source_location or location
method_result = call(abi_loc, method, *_map_abi_args(method.args, location))
abi_args = list(_map_abi_args([a.wtype for a in method.args], location))
method_result = call(abi_loc, method, *abi_args)
match method.return_type:
case wtypes.void_wtype:
call_and_maybe_log = awst_nodes.ExpressionStatement(method_result)
Expand Down
38 changes: 38 additions & 0 deletions test_cases/abi_routing/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,44 @@ def method_with_default_args(
assert int_from_storage.native == 2, "wrong int from storage"
assert int_from_function.native == 3, "wrong int from function"

@arc4.abimethod
def method_with_15_args(
self,
one: UInt64,
two: UInt64,
three: UInt64,
four: UInt64,
five: UInt64,
six: UInt64,
seven: UInt64,
eight: UInt64,
nine: UInt64,
ten: UInt64,
eleven: UInt64,
twelve: UInt64,
thirteen: UInt64,
fourteen: UInt64,
fifteen: Bytes,
) -> Bytes:
"""Fifteen args should not encode the last argument as a tuple"""
assert (
one
+ two
+ three
+ four
+ five
+ six
+ seven
+ eight
+ nine
+ ten
+ eleven
+ twelve
+ thirteen
+ fourteen
)
return fifteen

@arc4.abimethod
def method_with_more_than_15_args(
self,
Expand Down
Loading

0 comments on commit 860404e

Please sign in to comment.