Skip to content

Commit 6e6e2c2

Browse files
committed
refactor!: Store function argument names in FuncInput
1 parent e3ef76a commit 6e6e2c2

File tree

6 files changed

+37
-30
lines changed

6 files changed

+37
-30
lines changed

guppylang-internals/src/guppylang_internals/checker/expr_checker.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -496,12 +496,7 @@ def visit_Attribute(self, node: ast.Attribute) -> tuple[ast.expr, Type]:
496496
)
497497
# Make a closure by partially applying the `self` argument
498498
# TODO: Try to infer some type args based on `self`
499-
result_ty = FunctionType(
500-
func.ty.inputs[1:],
501-
func.ty.output,
502-
func.ty.input_names[1:] if func.ty.input_names else None,
503-
func.ty.params,
504-
)
499+
result_ty = FunctionType(func.ty.inputs[1:], func.ty.output, func.ty.params)
505500
return with_loc(node, PartialApply(func=name, args=[node.value])), result_ty
506501
raise GuppyTypeError(AttributeNotFoundError(attr_span, ty, node.attr))
507502

guppylang-internals/src/guppylang_internals/checker/func_checker.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,12 @@ def check_global_func_def(
134134
"""Type checks a top-level function definition."""
135135
args = func_def.args.args
136136
returns_none = isinstance(ty.output, NoneType)
137-
assert ty.input_names is not None
137+
assert all(inp.name is not None for inp in ty.inputs)
138138

139139
cfg = CFGBuilder().build(func_def.body, returns_none, globals)
140140
inputs = [
141-
Variable(x, inp.ty, loc, inp.flags, is_func_input=True)
142-
for x, inp, loc in zip(ty.input_names, ty.inputs, args, strict=True)
141+
Variable(cast(str, inp.name), inp.ty, loc, inp.flags, is_func_input=True)
142+
for inp, loc in zip(ty.inputs, args, strict=True)
143143
# Comptime inputs are turned into generic args, so are not included here
144144
if InputFlags.Comptime not in inp.flags
145145
]
@@ -194,10 +194,14 @@ def check_nested_func_def(
194194

195195
# Construct inputs for checking the body CFG
196196
inputs = [v for v, _ in captured.values()] + [
197-
Variable(x, inp.ty, func_def.args.args[i], inp.flags, is_func_input=True)
198-
for i, (x, inp) in enumerate(
199-
zip(func_ty.input_names, func_ty.inputs, strict=True)
197+
Variable(
198+
cast(str, inp.name),
199+
inp.ty,
200+
func_def.args.args[i],
201+
inp.flags,
202+
is_func_input=True,
200203
)
204+
for i, inp in enumerate(func_ty.inputs)
201205
# Comptime inputs are turned into generic args, so are not included here
202206
if InputFlags.Comptime not in inp.flags
203207
]
@@ -305,7 +309,6 @@ def check_signature(
305309
return FunctionType(
306310
inputs,
307311
output,
308-
input_names,
309312
sorted(param_var_mapping.values(), key=lambda v: v.idx),
310313
)
311314

guppylang-internals/src/guppylang_internals/definition/struct.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,13 +273,16 @@ def compile(self, args: list[Wire]) -> list[Wire]:
273273

274274
constructor_sig = FunctionType(
275275
inputs=[
276-
FuncInput(f.ty, InputFlags.Owned if f.ty.linear else InputFlags.NoFlags)
276+
FuncInput(
277+
f.ty,
278+
InputFlags.Owned if f.ty.linear else InputFlags.NoFlags,
279+
f.name,
280+
)
277281
for f in self.fields
278282
],
279283
output=StructType(
280284
defn=self, args=[p.to_bound(i) for i, p in enumerate(self.params)]
281285
),
282-
input_names=[f.name for f in self.fields],
283286
params=self.params,
284287
)
285288
constructor_def = CustomFunctionDef(

guppylang-internals/src/guppylang_internals/tys/parsing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def check_function_arg(
289289
ctx.param_var_mapping[name] = ConstParam(
290290
len(ctx.param_var_mapping), name, ty, from_comptime_arg=True
291291
)
292-
return FuncInput(ty, flags)
292+
return FuncInput(ty, flags, name)
293293

294294

295295
if sys.version_info >= (3, 12):

guppylang-internals/src/guppylang_internals/tys/printing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,10 @@ def _wrap(s: str, inside_row: bool) -> str:
165165

166166
def signature_to_str(name: str, sig: FunctionType) -> str:
167167
"""Displays a function signature in Python syntax including the function name."""
168-
assert sig.input_names is not None
168+
assert all(inp.name is not None for inp in sig.inputs)
169169
s = f"def {name}("
170170
s += ", ".join(
171-
f"{name}: {inp.ty}{TypePrinter._print_flags(inp.flags)}"
172-
for name, inp in zip(sig.input_names, sig.inputs, strict=True)
171+
f"{inp.name}: {inp.ty}{TypePrinter._print_flags(inp.flags)}"
172+
for inp in sig.inputs
173173
)
174174
return s + ") -> " + str(sig.output)

guppylang-internals/src/guppylang_internals/tys/ty.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import ABC, abstractmethod
22
from collections.abc import Sequence
3-
from dataclasses import dataclass, field
3+
from dataclasses import dataclass, field, replace
44
from enum import Enum, Flag, auto
55
from functools import cached_property, total_ordering
66
from typing import TYPE_CHECKING, ClassVar, TypeAlias, cast
@@ -389,6 +389,10 @@ class FuncInput:
389389
ty: "Type"
390390
flags: InputFlags
391391

392+
#: Name of this input, or `None` if it is an unnamed argument (e.g. inside a
393+
#: higher-order `Callable` type)
394+
name: str | None = field(default=None, compare=False)
395+
392396

393397
@dataclass(frozen=True, init=False)
394398
class FunctionType(ParametrizedTypeBase):
@@ -397,7 +401,6 @@ class FunctionType(ParametrizedTypeBase):
397401
inputs: Sequence[FuncInput]
398402
output: "Type"
399403
params: Sequence[Parameter]
400-
input_names: Sequence[str] | None
401404
comptime_args: Sequence[ConstArg]
402405

403406
args: Sequence[Argument] = field(init=False)
@@ -411,7 +414,6 @@ def __init__(
411414
self,
412415
inputs: Sequence[FuncInput],
413416
output: "Type",
414-
input_names: Sequence[str] | None = None,
415417
params: Sequence[Parameter] | None = None,
416418
comptime_args: Sequence[ConstArg] | None = None,
417419
) -> None:
@@ -433,7 +435,6 @@ def __init__(
433435
object.__setattr__(self, "comptime_args", comptime_args)
434436
object.__setattr__(self, "inputs", inputs)
435437
object.__setattr__(self, "output", output)
436-
object.__setattr__(self, "input_names", input_names or [])
437438
object.__setattr__(self, "params", params)
438439

439440
@property
@@ -449,6 +450,16 @@ def bound_vars(self) -> set[BoundVar]:
449450
return set()
450451
return super().bound_vars
451452

453+
@cached_property
454+
def input_names(self) -> Sequence[str] | None:
455+
"""Names of all inputs or `None` if there are unnamed inputs."""
456+
names: list[str] = []
457+
for inp in self.inputs:
458+
if inp.name is None:
459+
return None
460+
names.append(inp.name)
461+
return names
462+
452463
def cast(self) -> "Type":
453464
"""Casts an implementor of `TypeBase` into a `Type`."""
454465
return self
@@ -507,12 +518,8 @@ def visit(self, visitor: Visitor) -> None:
507518
def transform(self, transformer: Transformer) -> "Type":
508519
"""Accepts a transformer on this type."""
509520
return transformer.transform(self) or FunctionType(
510-
[
511-
FuncInput(inp.ty.transform(transformer), inp.flags)
512-
for inp in self.inputs
513-
],
521+
[replace(inp, ty=inp.ty.transform(transformer)) for inp in self.inputs],
514522
self.output.transform(transformer),
515-
self.input_names,
516523
self.params,
517524
)
518525

@@ -542,9 +549,8 @@ def instantiate_partial(self, args: "PartialInst") -> "FunctionType":
542549

543550
inst = Instantiator(full_inst)
544551
return FunctionType(
545-
[FuncInput(inp.ty.transform(inst), inp.flags) for inp in self.inputs],
552+
[replace(inp, ty=inp.ty.transform(inst)) for inp in self.inputs],
546553
self.output.transform(inst),
547-
self.input_names,
548554
remaining_params,
549555
# Comptime type arguments also need to be instantiated
550556
comptime_args=[

0 commit comments

Comments
 (0)