-
Notifications
You must be signed in to change notification settings - Fork 17
feat!: Use borrow_array instead of value_array for array lowering
#1166
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 16 commits
40d1255
60dccb3
4054f8f
a1b19af
169aa76
8272f8c
eeecf53
b9d961b
3df7013
db52f55
c9fe2d4
eedecd8
001da68
1926639
c70670f
71b9fc6
31d92cc
1fa119e
8c87895
b75047d
ab08667
7575884
7342398
b8a33c3
9a63367
5525ea5
e81872d
dda2742
025880d
60665ba
d5b4ca6
9332015
f1fc799
3635907
982d8a9
5f3e366
a6d3d0b
f38682c
c1c0808
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -63,13 +63,17 @@ | |||||||||
| from guppylang_internals.std._internal.compiler.arithmetic import ( | ||||||||||
| UnsignedIntVal, | ||||||||||
| convert_ifromusize, | ||||||||||
| convert_itousize, | ||||||||||
| ) | ||||||||||
| from guppylang_internals.std._internal.compiler.array import ( | ||||||||||
| array_clone, | ||||||||||
| array_convert_from_std_array, | ||||||||||
| array_convert_to_std_array, | ||||||||||
| array_map, | ||||||||||
| array_new, | ||||||||||
| array_repeat, | ||||||||||
| array_new_all_borrowed, | ||||||||||
| array_return, | ||||||||||
| array_type, | ||||||||||
| standard_array_type, | ||||||||||
| unpack_array, | ||||||||||
| ) | ||||||||||
|
|
@@ -79,7 +83,6 @@ | |||||||||
| from guppylang_internals.std._internal.compiler.prelude import ( | ||||||||||
| build_error, | ||||||||||
| build_panic, | ||||||||||
| build_unwrap, | ||||||||||
| panic, | ||||||||||
| ) | ||||||||||
| from guppylang_internals.std._internal.compiler.tket_bool import ( | ||||||||||
|
|
@@ -555,23 +558,30 @@ def visit_ResultExpr(self, node: ResultExpr) -> Wire: | |||||||||
| op_name = f"result_array_{base_name}" | ||||||||||
| size_arg = node.array_len.to_arg().to_hugr(self.ctx) | ||||||||||
| extra_args = [size_arg, *extra_args] | ||||||||||
| # Remove the option wrapping in the array | ||||||||||
| unwrap = array_unwrap_elem(self.ctx) | ||||||||||
| unwrap = self.builder.load_function( | ||||||||||
| unwrap, | ||||||||||
| instantiation=ht.FunctionType([ht.Option(base_ty)], [base_ty]), | ||||||||||
| type_args=[ht.TypeTypeArg(base_ty)], | ||||||||||
| # As `borrow_array`s used by Guppy are linear, we need to clone it (knowing | ||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
what are the chances of defining return as a pass-through operation, i.e. such that it returns the array? (much like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would require the signature of the result which is a breaking change in the extension which was a subject of much discussion and though it was implemented it was then reverted again so I don't think we are doing that for now |
||||||||||
| # that all elements in it are copyable) to avoid linearity violations when | ||||||||||
| # both passing it to the result operation and returning it (as an inout | ||||||||||
| # argument). | ||||||||||
| value_wire, inout_wire = self.builder.add_op( | ||||||||||
| array_clone(base_ty, size_arg), value_wire | ||||||||||
| ) | ||||||||||
| map_op = array_map(ht.Option(base_ty), size_arg, base_ty) | ||||||||||
| value_wire = self.builder.add_op(map_op, value_wire, unwrap) | ||||||||||
| func_ty = FunctionType( | ||||||||||
| [ | ||||||||||
| FuncInput( | ||||||||||
| array_type(node.base_ty, node.array_len), InputFlags.Inout | ||||||||||
| ), | ||||||||||
| ], | ||||||||||
| NoneType(), | ||||||||||
| ) | ||||||||||
| self._update_inout_ports(node.args, iter([inout_wire]), func_ty) | ||||||||||
| if is_bool_type(node.base_ty): | ||||||||||
| # We need to coerce a read on all the array elements if they are bools. | ||||||||||
| array_read = array_read_bool(self.ctx) | ||||||||||
| array_read = self.builder.load_function(array_read) | ||||||||||
| map_op = array_map(OpaqueBool, size_arg, ht.Bool) | ||||||||||
| value_wire = self.builder.add_op(map_op, value_wire, array_read) | ||||||||||
| base_ty = ht.Bool | ||||||||||
| # Turn `value_array` into regular linear `array` | ||||||||||
| # Turn `borrow_array` into regular `array` | ||||||||||
| value_wire = self.builder.add_op( | ||||||||||
| array_convert_to_std_array(base_ty, size_arg), value_wire | ||||||||||
| ) | ||||||||||
|
|
@@ -647,8 +657,8 @@ def visit_StateResultExpr(self, node: StateResultExpr) -> Wire: | |||||||||
| ) | ||||||||||
| qubits_out = unpack_array(self.builder, qubit_arr_out) | ||||||||||
| else: | ||||||||||
| # If the input is an array of qubits, we need to unwrap the elements first, | ||||||||||
| # and then convert to a value array and back. | ||||||||||
| # If the input is an array of qubits, we need to convert to a standard | ||||||||||
| # array. | ||||||||||
| qubits_in = [self.visit(node.args[1])] | ||||||||||
| qubits_out = [ | ||||||||||
| apply_array_op_with_conversions( | ||||||||||
|
|
@@ -680,26 +690,23 @@ def visit_DesugaredArrayComp(self, node: DesugaredArrayComp) -> Wire: | |||||||||
| assert isinstance(array_ty, OpaqueType) | ||||||||||
| array_var = Variable(next(tmp_vars), array_ty, node) | ||||||||||
| count_var = Variable(next(tmp_vars), int_type(), node) | ||||||||||
| # See https://github.com/CQCL/guppylang/issues/629 | ||||||||||
| hugr_elt_ty = ht.Option(node.elt_ty.to_hugr(self.ctx)) | ||||||||||
| # Initialise array with `None`s | ||||||||||
| make_none = array_comprehension_init_func(self.ctx) | ||||||||||
| make_none = self.builder.load_function( | ||||||||||
| make_none, | ||||||||||
| instantiation=ht.FunctionType([], [hugr_elt_ty]), | ||||||||||
| type_args=[ht.TypeTypeArg(node.elt_ty.to_hugr(self.ctx))], | ||||||||||
| ) | ||||||||||
| hugr_elt_ty = node.elt_ty.to_hugr(self.ctx) | ||||||||||
| # Initialise empty array. | ||||||||||
| self.dfg[array_var] = self.builder.add_op( | ||||||||||
| array_repeat(hugr_elt_ty, node.length.to_arg().to_hugr(self.ctx)), make_none | ||||||||||
| array_new_all_borrowed(hugr_elt_ty, node.length.to_arg().to_hugr(self.ctx)) | ||||||||||
| ) | ||||||||||
| self.dfg[count_var] = self.builder.load( | ||||||||||
| hugr.std.int.IntVal(0, width=NumericType.INT_WIDTH) | ||||||||||
| ) | ||||||||||
| with self._build_generators([node.generator], [array_var, count_var]): | ||||||||||
| elt = self.visit(node.elt) | ||||||||||
| array, count = self.dfg[array_var], self.dfg[count_var] | ||||||||||
| [], [self.dfg[array_var]] = self._build_method_call( | ||||||||||
| array_ty, "__setitem__", node, [array, count, elt], array_ty.args | ||||||||||
| idx = self.builder.add_op(convert_itousize(), count) | ||||||||||
| self.dfg[array_var] = self.builder.add_op( | ||||||||||
| array_return(hugr_elt_ty, node.length.to_arg().to_hugr(self.ctx)), | ||||||||||
| array, | ||||||||||
| idx, | ||||||||||
| elt, | ||||||||||
| ) | ||||||||||
| # Update `count += 1` | ||||||||||
| one = self.builder.load(hugr.std.int.IntVal(1, width=NumericType.INT_WIDTH)) | ||||||||||
|
|
@@ -835,10 +842,6 @@ def python_value_to_hugr(v: Any, exp_ty: Type, ctx: CompilerContext) -> hv.Value | |||||||||
| return None | ||||||||||
|
|
||||||||||
|
|
||||||||||
| ARRAY_COMPREHENSION_INIT: Final[GlobalConstId] = GlobalConstId.fresh( | ||||||||||
| "array.__comprehension.init" | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| ARRAY_UNWRAP_ELEM: Final[GlobalConstId] = GlobalConstId.fresh("array.__unwrap_elem") | ||||||||||
| ARRAY_WRAP_ELEM: Final[GlobalConstId] = GlobalConstId.fresh("array.__wrap_elem") | ||||||||||
|
|
||||||||||
|
|
@@ -848,54 +851,6 @@ def python_value_to_hugr(v: Any, exp_ty: Type, ctx: CompilerContext) -> hv.Value | |||||||||
| ) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def array_comprehension_init_func(ctx: CompilerContext) -> hf.Function: | ||||||||||
| """Returns the Hugr function that is used to initialise arrays elements before a | ||||||||||
| comprehension. | ||||||||||
|
|
||||||||||
| Just returns the `None` variant of the optional element type. | ||||||||||
|
|
||||||||||
| See https://github.com/CQCL/guppylang/issues/629 | ||||||||||
| """ | ||||||||||
| v = ht.Variable(0, ht.TypeBound(ht.TypeBound.Linear)) | ||||||||||
| sig = ht.PolyFuncType( | ||||||||||
| params=[ht.TypeTypeParam(ht.TypeBound.Linear)], | ||||||||||
| body=ht.FunctionType([], [ht.Option(v)]), | ||||||||||
| ) | ||||||||||
| func, already_defined = ctx.declare_global_func(ARRAY_COMPREHENSION_INIT, sig) | ||||||||||
| if not already_defined: | ||||||||||
| func.set_outputs(func.add_op(ops.Tag(0, ht.Option(v)))) | ||||||||||
| return func | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def array_unwrap_elem(ctx: CompilerContext) -> hf.Function: | ||||||||||
| """Returns the Hugr function that is used to unwrap the elements in an option array | ||||||||||
| to turn it into a regular array.""" | ||||||||||
| v = ht.Variable(0, ht.TypeBound(ht.TypeBound.Linear)) | ||||||||||
| sig = ht.PolyFuncType( | ||||||||||
| params=[ht.TypeTypeParam(ht.TypeBound.Linear)], | ||||||||||
| body=ht.FunctionType([ht.Option(v)], [v]), | ||||||||||
| ) | ||||||||||
| func, already_defined = ctx.declare_global_func(ARRAY_UNWRAP_ELEM, sig) | ||||||||||
| if not already_defined: | ||||||||||
| msg = "Linear array element has already been used" | ||||||||||
| func.set_outputs(build_unwrap(func, func.inputs()[0], msg)) | ||||||||||
| return func | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def array_wrap_elem(ctx: CompilerContext) -> hf.Function: | ||||||||||
| """Returns the Hugr function that is used to wrap the elements in an regular array | ||||||||||
| to turn it into a option array.""" | ||||||||||
| v = ht.Variable(0, ht.TypeBound(ht.TypeBound.Linear)) | ||||||||||
| sig = ht.PolyFuncType( | ||||||||||
| params=[ht.TypeTypeParam(ht.TypeBound.Linear)], | ||||||||||
| body=ht.FunctionType([v], [ht.Option(v)]), | ||||||||||
| ) | ||||||||||
| func, already_defined = ctx.declare_global_func(ARRAY_WRAP_ELEM, sig) | ||||||||||
| if not already_defined: | ||||||||||
| func.set_outputs(func.add_op(ops.Tag(1, ht.Option(v)), func.inputs()[0])) | ||||||||||
| return func | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def array_read_bool(ctx: CompilerContext) -> hf.Function: | ||||||||||
| """Returns the Hugr function that is used to unwrap the elements in an option array | ||||||||||
| to turn it into a regular array.""" | ||||||||||
|
|
@@ -944,31 +899,21 @@ def apply_array_op_with_conversions( | |||||||||
| output array. | ||||||||||
|
|
||||||||||
| Transformations: | ||||||||||
| 1. Unwraps / wraps elements in options. | ||||||||||
| 3. (Optional) Converts from / to opaque bool to / from Hugr bool. | ||||||||||
| 1. (Optional) Converts from / to opaque bool to / from Hugr bool. | ||||||||||
acl-cqc marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
| 2. Converts from / to value array to / from standard Hugr array. | ||||||||||
| """ | ||||||||||
| unwrap = array_unwrap_elem(ctx) | ||||||||||
| unwrap = builder.load_function( | ||||||||||
| unwrap, | ||||||||||
| instantiation=ht.FunctionType([ht.Option(elem_ty)], [elem_ty]), | ||||||||||
| type_args=[ht.TypeTypeArg(elem_ty)], | ||||||||||
| ) | ||||||||||
| map_op = array_map(ht.Option(elem_ty), size_arg, elem_ty) | ||||||||||
| unwrapped_array = builder.add_op(map_op, input_array, unwrap) | ||||||||||
|
|
||||||||||
| if convert_bool: | ||||||||||
| array_read = array_read_bool(ctx) | ||||||||||
| array_read = builder.load_function(array_read) | ||||||||||
| map_op = array_map(OpaqueBool, size_arg, ht.Bool) | ||||||||||
| unwrapped_array = builder.add_op(map_op, unwrapped_array, array_read) | ||||||||||
| input_array = builder.add_op(map_op, input_array, array_read) | ||||||||||
| elem_ty = ht.Bool | ||||||||||
|
|
||||||||||
| unwrapped_array = builder.add_op( | ||||||||||
| array_convert_to_std_array(elem_ty, size_arg), unwrapped_array | ||||||||||
| input_array = builder.add_op( | ||||||||||
| array_convert_to_std_array(elem_ty, size_arg), input_array | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| result_array = builder.add_op(op, unwrapped_array) | ||||||||||
| result_array = builder.add_op(op, input_array) | ||||||||||
|
|
||||||||||
| result_array = builder.add_op( | ||||||||||
| array_convert_from_std_array(elem_ty, size_arg), result_array | ||||||||||
|
|
@@ -981,11 +926,4 @@ def apply_array_op_with_conversions( | |||||||||
| result_array = builder.add_op(map_op, result_array, array_make_opaque) | ||||||||||
| elem_ty = OpaqueBool | ||||||||||
|
|
||||||||||
| wrap = array_wrap_elem(ctx) | ||||||||||
| wrap = builder.load_function( | ||||||||||
| wrap, | ||||||||||
| instantiation=ht.FunctionType([elem_ty], [ht.Option(elem_ty)]), | ||||||||||
| type_args=[ht.TypeTypeArg(elem_ty)], | ||||||||||
| ) | ||||||||||
| map_op = array_map(elem_ty, size_arg, ht.Option(elem_ty)) | ||||||||||
| return builder.add_op(map_op, result_array, wrap) | ||||||||||
| return result_array | ||||||||||
Uh oh!
There was an error while loading. Please reload this page.