From 35a1e223bfab02b488494db23d532565eec10cf2 Mon Sep 17 00:00:00 2001 From: beizha Date: Fri, 29 Mar 2024 14:11:46 +0800 Subject: [PATCH] feat: qfasm support parameter --- quafu/elements/utils.py | 4 +- quafu/qfasm/qfasm_lexer.py | 5 + quafu/qfasm/qfasm_parser.py | 121 +++++++++++++-------- tests/quafu/qasm/parameter_test.py | 168 +++++++++++++++++++++++++++++ tests/quafu/qasm/parser_test.py | 2 +- 5 files changed, 250 insertions(+), 50 deletions(-) create mode 100644 tests/quafu/qasm/parameter_test.py diff --git a/quafu/elements/utils.py b/quafu/elements/utils.py index be0954f..19c4b79 100644 --- a/quafu/elements/utils.py +++ b/quafu/elements/utils.py @@ -55,14 +55,14 @@ def handle_expression(param: ParameterType): retstr = f"({retstr} - {handle_expression(param.operands[i])})" elif param.funcs[i] == _operator.truediv: retstr = f"{retstr} / {handle_expression(param.operands[i])}" + elif param.funcs[i] == _operator.pow: + retstr = f"({retstr}) ^ {handle_expression(param.operands[i])}" elif param.funcs[i] == anp.sin: retstr = f"sin({retstr})" elif param.funcs[i] == anp.cos: retstr = f"cos({retstr})" elif param.funcs[i] == anp.tan: retstr = f"tan({retstr})" - elif param.funcs[i] == _operator.pow: - retstr = f"pow({retstr}, {handle_expression(param.operands[i])})" elif param.funcs[i] == anp.arcsin: retstr = f"asin({retstr})" elif param.funcs[i] == anp.arccos: diff --git a/quafu/qfasm/qfasm_lexer.py b/quafu/qfasm/qfasm_lexer.py index c73cc6c..6b49ef1 100644 --- a/quafu/qfasm/qfasm_lexer.py +++ b/quafu/qfasm/qfasm_lexer.py @@ -78,6 +78,7 @@ def token(self): "STRING", "ASSIGN", "MATCHES", + "EQUAL", "ID", "UNIT", "CHANNEL", @@ -147,6 +148,10 @@ def t_MATCHES(self, t): r"==" return t + def t_EQUAL(self, t): + r"=" + return t + def t_UNIT(self, t): r"ns|us" return t diff --git a/quafu/qfasm/qfasm_parser.py b/quafu/qfasm/qfasm_parser.py index f3dc44c..0be90e3 100644 --- a/quafu/qfasm/qfasm_parser.py +++ b/quafu/qfasm/qfasm_parser.py @@ -23,11 +23,12 @@ from quafu.qfasm.exceptions import ParserError from quafu import QuantumCircuit - +from quafu.elements import Parameter, ParameterExpression from .qfasm_lexer import QfasmLexer from .qfasm_utils import * -unaryop = ["sin", "cos", "tan", "exp", "ln", "sqrt", "acos", "atan", "asin"] +unaryop = {"sin": "sin", "cos": "cos", "tan": "tan", "exp": "exp", + "ln": "log", "sqrt": "sqrt", "acos": "arccos", "atan": "arctan", "asin": "arcsin"} unarynp = { "sin": np.sin, "cos": np.cos, @@ -84,6 +85,8 @@ def __init__(self, filepath: str = None, debug=False): self.qnum = 0 # cbit num used self.cnum = 0 + # param + self.params = {} def add_U_CX(self): # Add U and CX in global_symtab @@ -172,7 +175,9 @@ def handle_gateins(self, gateins: GateInstruction): for i in range(symnode.num): tempargs.append(symnode.start + i) args.append(tempargs) - + # change carg to parameter + for i in range(len(gateins.cargs)): + gateins.cargs[i] = self.compute_exp(gateins.cargs[i]) # call many times for i in range(len(args[0])): oneargs = [] @@ -252,52 +257,49 @@ def handle_gateins(self, gateins: GateInstruction): # change newins's qarg to real q for i in range(len(newins.qargs)): newins.qargs[i] = qargdict[newins.qargs[i].name] - # change newins's carg to real carg (consider exp) + # change newins's carg to real carg (consider exp and parameter) for i in range(len(newins.cargs)): - if not ( - isinstance(newins.cargs[i], int) - or isinstance(newins.cargs[i], float) - ): - # for expression - newins.cargs[i] = self.compute_exp(newins.cargs[i], cargdict) + # for expression and parameter, it will return parameter or int/float + newins.cargs[i] = self.compute_exp(newins.cargs[i], cargdict) # now, recurse gate_list.extend(self.handle_gateins(newins)) return gate_list - def compute_exp(self, carg, cargdict: dict): + def compute_exp(self, carg, cargdict: dict={}): # recurse - if isinstance(carg, int) or isinstance(carg, float): + if isinstance(carg, int) or isinstance(carg, float) or isinstance(carg, ParameterExpression): return carg # if it's id, should get real number from gateins elif isinstance(carg, Id): - return cargdict[carg.name] + if carg.name in cargdict: + return cargdict[carg.name] + # if it's parameter, just return + else: + return self.params[carg.name] elif isinstance(carg, UnaryExpr): if carg.type == "-": return -self.compute_exp(carg.children[0], cargdict) elif carg.type in unaryop: - return unarynp[carg.type](self.compute_exp(carg.children[0], cargdict)) + nowcarg = self.compute_exp(carg.children[0], cargdict) + if isinstance(nowcarg, ParameterExpression): + func = getattr(nowcarg, unaryop[carg.type]) + return func() + else: + return unarynp[carg.type](nowcarg) elif isinstance(carg, BinaryExpr): + cargl = self.compute_exp(carg.children[0], cargdict) + cargr = self.compute_exp(carg.children[1], cargdict) if carg.type == "+": - return self.compute_exp(carg.children[0], cargdict) + self.compute_exp( - carg.children[1], cargdict - ) + return cargl + cargr elif carg.type == "-": - return self.compute_exp(carg.children[0], cargdict) - self.compute_exp( - carg.children[1], cargdict - ) + return cargl - cargr elif carg.type == "*": - return self.compute_exp(carg.children[0], cargdict) * self.compute_exp( - carg.children[1], cargdict - ) + return cargl * cargr elif carg.type == "/": - return self.compute_exp(carg.children[0], cargdict) / self.compute_exp( - carg.children[1], cargdict - ) + return cargl / cargr elif carg.type == "^": - return self.compute_exp(carg.children[0], cargdict) ** self.compute_exp( - carg.children[1], cargdict - ) + return cargl ** cargr def addInstruction(self, qc: QuantumCircuit, ins): if ins is None: @@ -416,8 +418,19 @@ def check_qargs(self, gateins: GateInstruction): f"Qubit used as different argument when call gate {gateins.name} at line {gateins.lineno} file {gateins.filename}" ) + def check_param(self, carg): + if isinstance(carg, int) or isinstance(carg, float): + return + elif isinstance(carg, Id) and carg.name not in self.params: + raise ParserError(f"The parameter {carg.name} is undefined at line {carg.lineno} file {carg.filename}") + elif isinstance(carg, UnaryExpr): + self.check_param(carg.children[0]) + elif isinstance(carg, BinaryExpr): + self.check_param(carg.children[0]) + self.check_param(carg.children[1]) + def check_cargs(self, gateins: GateInstruction): - # check that cargs belongs to unary (they must be int or float) + # check that cargs belongs to unary (they must be int or float or parameter) # cargs is different from CREG if gateins.name not in self.nuop and gateins.name not in self.mulctrl: if gateins.name not in self.global_symtab: @@ -429,12 +442,9 @@ def check_cargs(self, gateins: GateInstruction): raise ParserError( f"The {gateins.name} is not declared as a gate at line {gateins.lineno} file {gateins.filename}" ) - # check every carg in [int, float] + # check every carg in [int, float, parameter] for carg in gateins.cargs: - if not (isinstance(carg, int) or isinstance(carg, float)): - raise ParserError( - f"Classical argument must be of type int or float at line {gateins.lineno} file {gateins.filename}" - ) + self.check_param(carg) # check cargs's num matches gate's delcared cargs if len(gateins.cargs) != len(gatenote.cargs): raise ParserError( @@ -482,11 +492,11 @@ def check_gate_qargs(self, gateins: GateInstruction): def check_gate_cargs(self, gateins: GateInstruction): # check gate_op's classcal args, must matches num declared by gate - if gateins.name == "barrier" and len(gateins.cargs) > 0: + if gateins.name in ["barrier", "reset", "measure"] and len(gateins.cargs) > 0: raise ParserError( f"Barrier can not receive classical argument at line {gateins.lineno} file {gateins.filename}" ) - if gateins.name != "barrier": + if gateins.name not in ["barrier", "reset", "measure"]: if gateins.name not in self.global_symtab: raise ParserError( f"The gate {gateins.name} is undefined at line {gateins.lineno} file {gateins.filename}" @@ -500,7 +510,7 @@ def check_gate_cargs(self, gateins: GateInstruction): raise ParserError( f"The number of classical argument declared in gate {gateins.name} is inconsistent with instruction at line {gateins.lineno} file {gateins.filename}" ) - # check carg must from gate declared argument or int/float + # check carg must from gate declared argument or int/float or parameter for carg in gateins.cargs: # recurse check expression self.check_carg_declartion(carg) @@ -510,16 +520,17 @@ def check_carg_declartion(self, node): return if isinstance(node, Id): # check declaration - if node.name not in self.symtab: + if node.name in self.symtab: + symnode = self.symtab[node.name] + if symnode.type != "CARG": + raise ParserError( + f"The {node.name} is not declared as a classical bit at line {node.lineno} file {node.filename}" + ) + return + elif node.name not in self.params: raise ParserError( f"The classical argument {node.name} is undefined at line {node.lineno} file {node.filename}" ) - symnode = self.symtab[node.name] - if symnode.type != "CARG": - raise ParserError( - f"The {node.name} is not declared as a classical bit at line {node.lineno} file {node.filename}" - ) - return if isinstance(node, UnaryExpr): self.check_carg_declartion(node.children[0]) elif isinstance(node, BinaryExpr): @@ -579,7 +590,7 @@ def p_statement_qop(self, p): | qif error """ if p[2] != ";": - raise ParserError(f"Expecting ';' behind statement") + raise ParserError(f"Expecting ';' behind statement at line {p[1].lineno} file {p[1].filename}") p[0] = p[1] def p_statement_empty(self, p): @@ -979,6 +990,8 @@ def p_statement_bitdecl(self, p): """ statement : qdecl ';' | cdecl ';' + | defparam ';' + | defparam error | qdecl error | cdecl error | error @@ -991,6 +1004,20 @@ def p_statement_bitdecl(self, p): ) p[0] = p[1] + def p_statement_defparam(self, p): + """ + defparam : id EQUAL FLOAT + | id EQUAL INT + | id EQUAL error + """ + if not isinstance(p[3], int) and not isinstance(p[3], float): + raise ParserError(f"Expecting 'INT' or 'FLOAT behind '=' at line {p[1].lineno} file {p[1].filename}") + param_name = p[1].name + if param_name in self.params: + raise ParserError(f"Duplicate declaration for parameter {p[1].name} at line {p[1].lineno} file {p[1].filename}") + self.params[param_name] = Parameter(param_name, p[3]) + p[0] = None + def p_qdecl(self, p): """ qdecl : QREG indexed_id @@ -1169,7 +1196,7 @@ def p_expr_mathfunc(self, p): """ if p[1].name not in unaryop: raise ParserError( - f"Math function {p[1].name} not supported, only support {unaryop} line {p[1].lineno} file {p[1].filename}" + f"Math function {p[1].name} not supported, only support {unaryop.keys()} line {p[1].lineno} file {p[1].filename}" ) if not isinstance(p[3], Node): p[0] = unarynp[p[1].name](p[3]) diff --git a/tests/quafu/qasm/parameter_test.py b/tests/quafu/qasm/parameter_test.py new file mode 100644 index 0000000..db6d28f --- /dev/null +++ b/tests/quafu/qasm/parameter_test.py @@ -0,0 +1,168 @@ +import math + +from quafu import QuantumCircuit +from quafu.qfasm.qfasm_convertor import qasm_to_quafu +from quafu.elements import ParameterExpression, Parameter + +class TestParser: + """ + Test for PLY parser + """ + + def compare_cir(self, qc1: QuantumCircuit, qc2: QuantumCircuit): + # compare reg and compare gates + assert len(qc1.qregs) == len(qc2.qregs) + for i in range(len(qc1.qregs)): + reg1 = qc1.qregs[i] + reg2 = qc2.qregs[i] + assert len(reg1.qubits) == len(reg2.qubits) + assert len(qc1.gates) == len(qc2.gates) + assert len(qc1.variables) == len(qc2.variables) + for i in range(len(qc1.variables)): + self.compare_parameter(qc1.variables[i], qc2.variables[i]) + for i in range(len(qc1.gates)): + gate1 = qc1.gates[i] + gate2 = qc2.gates[i] + assert gate1.name == gate2.name + if hasattr(gate1, "pos"): + assert gate1.pos == gate2.pos + if hasattr(gate1, "paras"): + assert len(gate1.paras) == len(gate2.paras) + for j in range(len(gate1.paras)): + self.compare_parameter(gate1.paras[j], gate2.paras[j]) + + def compare_parameter(self, param1, param2): + if isinstance(param1, ParameterExpression) or isinstance(param2, ParameterExpression): + assert isinstance(param2, ParameterExpression) + assert isinstance(param1, ParameterExpression) + assert param1.latex == param2.latex + assert param1.value == param2.value + assert len(param1.funcs) == len(param2.funcs) + assert len(param1.operands) == len(param2.operands) + if not param1.latex: + self.compare_parameter(param1.pivot, param2.pivot) + for i in range(len(param1.funcs)): + assert param1.funcs[i] == param2.funcs[i] + for i in range(len(param1.operands)): + self.compare_parameter(param1.operands[i], param2.operands[i]) + else: + assert param1 == param2 + # ---------------------------------------- + # test for parameter + # ---------------------------------------- + def test_parameter_plain(self): + qasm = """ + theta1 = 1.0; theta2 = 2.0; qreg q[2]; rx(theta1) q[0]; rx(theta2) q[1]; + """ + cir = qasm_to_quafu(openqasm=qasm) + assert cir.gates[0].name == "RX" + assert cir.gates[1].name == "RX" + self.compare_parameter(cir.variables[0], Parameter("theta1", 1.0)) + self.compare_parameter(cir.variables[1], Parameter("theta2", 2.0)) + self.compare_parameter(cir.gates[0].paras[0], Parameter("theta1", 1.0)) + self.compare_parameter(cir.gates[1].paras[0], Parameter("theta2", 2.0)) + + def test_parameter_func(self): + qasm = """ + theta1 = 1.0; theta2 = 2.0; + gate test(rz, ry) a { + rz(rz) a; + ry(ry) a; + } + qreg q[1]; + test(theta1, theta2) q[0]; + """ + cir = qasm_to_quafu(openqasm=qasm) + assert cir.gates[0].name == "RZ" + assert cir.gates[1].name == "RY" + self.compare_parameter(cir.variables[0], Parameter("theta1", 1.0)) + self.compare_parameter(cir.variables[1], Parameter("theta2", 2.0)) + self.compare_parameter(cir.gates[0].paras[0], Parameter("theta1", 1.0)) + self.compare_parameter(cir.gates[1].paras[0], Parameter("theta2", 2.0)) + + def test_parameter_func_mix(self): + qasm = """ + theta = 1.0; + theta1 = 3.0; + gate test(rz, theta1) a { + rz(rz) a; + ry(theta1) a; + } + qreg q[1]; + test(theta, 2.0) q[0]; + """ + cir = qasm_to_quafu(openqasm=qasm) + assert cir.gates[0].name == "RZ" + assert cir.gates[1].name == "RY" + assert len(cir.variables) == 1 + assert cir.gates[1].paras[0] == 2.0 + self.compare_parameter(cir.variables[0], Parameter("theta", 1.0)) + self.compare_parameter(cir.gates[0].paras[0], Parameter("theta", 1.0)) + + def test_parameter_func_mix2(self): + qasm = """ + theta = 1.0; + theta1 = 3.0; + gate test(rz, theta1) a { + rz(rz) a; + ry(theta) a; + } + qreg q[1]; + test(theta, 2.0) q[0]; + """ + cir = qasm_to_quafu(openqasm=qasm) + assert cir.gates[0].name == "RZ" + assert cir.gates[1].name == "RY" + assert len(cir.variables) == 1 + self.compare_parameter(cir.variables[0], Parameter("theta", 1.0)) + self.compare_parameter(cir.gates[0].paras[0], Parameter("theta", 1.0)) + self.compare_parameter(cir.gates[1].paras[0], Parameter("theta", 1.0)) + + def test_parameter_expression(self): + qasm = """ + theta1 = 1.0; theta2 = 2.0; + qreg q[2]; rx(theta1+theta2) q[0]; rx(theta1+theta1*theta2) q[1]; + """ + cir = qasm_to_quafu(openqasm=qasm) + assert cir.gates[0].name == "RX" + assert cir.gates[1].name == "RX" + theta1 = Parameter("theta1", 1.0) + theta2 = Parameter("theta2", 2.0) + self.compare_parameter(cir.variables[0], theta1) + self.compare_parameter(cir.variables[1], theta2) + self.compare_parameter(cir.gates[0].paras[0], theta1+theta2) + self.compare_parameter(cir.gates[1].paras[0], theta1+theta1*theta2) + qasm = """ + theta1 = 1.0; theta2 = 2.0; + theta3 = 3.0; theta4 = 4.0; + qreg q[2]; + rx(theta1+theta2-theta3*theta4^2) q[0]; + rx(sin(theta1*2+theta2)) q[1]; + """ + cir = qasm_to_quafu(openqasm=qasm) + assert cir.gates[0].name == "RX" + assert cir.gates[1].name == "RX" + theta1 = Parameter("theta1", 1.0) + theta2 = Parameter("theta2", 2.0) + theta3 = Parameter("theta3", 3.0) + theta4 = Parameter("theta4", 4.0) + assert len(cir.variables) == 4 + theta = theta1 + theta2 - theta3 * theta4 ** 2 + self.compare_parameter(cir.variables[0], theta1) + self.compare_parameter(cir.variables[1], theta2) + self.compare_parameter(cir.variables[2], theta3) + self.compare_parameter(cir.variables[3], theta4) + self.compare_parameter(cir.gates[0].paras[0], theta) + theta = (theta1*2+theta2).sin() + self.compare_parameter(cir.gates[1].paras[0], theta) + + + def test_parameter_to_from(self): + qc = QuantumCircuit(4) + theta1 = Parameter("theta1", 1.0) + theta2 = Parameter("theta2", 2.0) + qc.rx(0, theta1) + qc.rx(0, theta2) + qc2 = QuantumCircuit(4) + qc2.from_openqasm(qc.to_openqasm(with_para=True)) + self.compare_cir(qc,qc2) \ No newline at end of file diff --git a/tests/quafu/qasm/parser_test.py b/tests/quafu/qasm/parser_test.py index 609a768..07f886d 100644 --- a/tests/quafu/qasm/parser_test.py +++ b/tests/quafu/qasm/parser_test.py @@ -118,7 +118,7 @@ def test_include_cannot_find_file(self): qasm_to_quafu(token) def test_single_equals_error(self): - with pytest.raises(LexerError, match=r"Illegal character =.*") as e: + with pytest.raises(ParserError, match=r"Illegal IF statement, .*") as e: qasm = f"if (a=2) U(0,0,0)q[0];" qasm_to_quafu(qasm)