diff --git a/pytket/phir/phirgen_parallel.py b/pytket/phir/phirgen_parallel.py index cc428ad..02bb387 100644 --- a/pytket/phir/phirgen_parallel.py +++ b/pytket/phir/phirgen_parallel.py @@ -280,33 +280,31 @@ def format_and_add_primary_commands( groups2qops(fmt_g2q, ops) +def get_transport_time_for_gate(gate: str, machine: "Machine") -> float: + """Return the transport time on the machine for the given gate type.""" + match gate: + case "RZ" | "R1XY": + return machine.sq_time + case "RZZ": + return machine.tq_time + case "Measure": + return machine.meas_prep_time + case "Init": + return 0 + case _: + logger.warning("Gate type %s not assigned a transport duration", gate) + return 0 + + def adjust_phir_transport_time(ops: list["JsonDict"], machine: "Machine") -> None: """Analyze the generated phir and adjust the transport time.""" adjustment = 0.0 for op in ops: if "qop" in op: - match op["qop"]: - case "RZ" | "R1XY": - adjustment += machine.sq_time - case "RZZ": - adjustment += machine.tq_time - case "Measure": - adjustment += machine.meas_prep_time - case _: - logger.warning( - "Gate type %s not assigned a transport duration", op["qop"] - ) + adjustment += get_transport_time_for_gate(op["qop"], machine) if "block" in op and op["block"] == "qparallel": first_op = op["ops"][0]["qop"] - match first_op: - case "RZ" | "R1XY": - adjustment += machine.sq_time - case "RZZ": - adjustment += machine.tq_time - case _: - logger.warning( - "Gate type %s not assigned a transport duration", first_op - ) + adjustment += get_transport_time_for_gate(first_op, machine) if "mop" in op and op["mop"] == "Transport": cost, units = op["duration"] op["duration"] = cost + adjustment, units