diff --git a/guppylang-internals/src/guppylang_internals/tracing/function.py b/guppylang-internals/src/guppylang_internals/tracing/function.py index 51ffff7c7..880c9e50b 100644 --- a/guppylang-internals/src/guppylang_internals/tracing/function.py +++ b/guppylang-internals/src/guppylang_internals/tracing/function.py @@ -11,6 +11,7 @@ from guppylang_internals.checker.errors.type_errors import TypeMismatchError from guppylang_internals.compiler.core import CompilerContext, DFContainer from guppylang_internals.compiler.expr_compiler import ExprCompiler +from guppylang_internals.definition.custom import CustomFunctionDef from guppylang_internals.definition.value import CallableDef from guppylang_internals.diagnostic import Error from guppylang_internals.error import GuppyComptimeError, GuppyError, exception_hook @@ -44,6 +45,13 @@ class TracingReturnError(Error): msg: str +@dataclass(frozen=True) +class VarArgsError(Error): + title: ClassVar[str] = "Varargs functions not supported in comptime" + message: ClassVar[str] = "{msg}" + msg: str + + def trace_function( python_func: Callable[..., Any], ty: FunctionType, @@ -173,6 +181,12 @@ def trace_call(func: CallableDef, *args: Any) -> Any: # Update inouts # If the input types of the function aren't known, we can't check this. # This is the case for functions with a custom checker and no type annotations. + if isinstance(func, CustomFunctionDef) and not func.has_signature: + msg = f"Try wrapping `{func.name}` in a non-comptime @guppy function" + err = VarArgsError(call_node, msg) + + raise GuppyError(err) from None + if len(func.ty.inputs) != 0: for inp, arg, var in zip(func.ty.inputs, args, arg_vars, strict=True): if InputFlags.Inout in inp.flags: diff --git a/tests/emulator/test_builder.py b/tests/emulator/test_builder.py index af82d6b33..3140c6eba 100644 --- a/tests/emulator/test_builder.py +++ b/tests/emulator/test_builder.py @@ -152,7 +152,6 @@ def test_emulator_builder_immutability(): assert builder.verbose is False - def test_emulator_builder_reuse(): """Test that the same builder can be used multiple times.""" builder = EmulatorBuilder().with_name("reusable").with_verbose(True) diff --git a/tests/error/tracing_errors/varargs.err b/tests/error/tracing_errors/varargs.err new file mode 100644 index 000000000..a2a8af9e2 --- /dev/null +++ b/tests/error/tracing_errors/varargs.err @@ -0,0 +1,7 @@ +Traceback (most recent call last): + File "$FILE", line 12, in + main.compile() + File "$FILE", line 8, in main + barrier(qs) +guppylang_internals.error.GuppyComptimeError: Varargs functions not supported in comptime +Try wrapping `barrier` in a non-comptime @guppy function diff --git a/tests/error/tracing_errors/varargs.py b/tests/error/tracing_errors/varargs.py new file mode 100644 index 000000000..a8daaf725 --- /dev/null +++ b/tests/error/tracing_errors/varargs.py @@ -0,0 +1,12 @@ +from guppylang import guppy +from guppylang.std.quantum import qubit, h, discard_array +from guppylang.std.builtins import array, barrier + +@guppy.comptime +def main() -> None: + qs = array(qubit() for _ in range(5)) + barrier(qs) + h(qs[0]) + discard_array(qs) + +main.compile() diff --git a/tests/integration/tracing/test_quantum.py b/tests/integration/tracing/test_quantum.py index e8538d40f..7f0ddb849 100644 --- a/tests/integration/tracing/test_quantum.py +++ b/tests/integration/tracing/test_quantum.py @@ -1,8 +1,8 @@ from guppylang import qubit from guppylang.decorator import guppy from guppylang.std.angles import angle -from guppylang.std.builtins import array -from guppylang.std.quantum import h, measure, cx, rz +from guppylang.std.builtins import array, barrier +from guppylang.std.quantum import h, measure, cx, rz, discard_array import itertools @@ -40,3 +40,20 @@ def test(qs: array[qubit, 10], theta: angle) -> None: theta /= 2 validate(test.compile()) + + +def test_barrier_wrapper(validate): + """Test workaround for https://github.com/CQCL/guppylang/issues/1189""" + + @guppy + def barrier_wrapper(qs: array[qubit, 5]) -> None: + barrier(qs) + + @guppy.comptime + def main() -> None: + qs = array(qubit() for _ in range(5)) + barrier_wrapper(qs) + h(qs[0]) + discard_array(qs) + + validate(main.compile())