Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport] Fix equality and pickling of DAGCircuit with stretches #14114

Merged
merged 2 commits into from
Mar 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 39 additions & 7 deletions crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,14 @@ impl DAGCircuit {
})
.into_py_dict(py)?,
)?;
out_dict.set_item(
"captured_stretches",
self.captured_stretches.iter().into_py_dict(py)?,
)?;
out_dict.set_item(
"declared_stretches",
self.declared_stretches.iter().into_py_dict(py)?,
)?;
out_dict.set_item("vars_by_type", self.vars_by_type.clone())?;
out_dict.set_item("qubits", self.qubits.bits())?;
out_dict.set_item("clbits", self.clbits.bits())?;
Expand Down Expand Up @@ -743,7 +751,14 @@ impl DAGCircuit {
};
self.vars_info.insert(key.extract()?, info);
}

self.captured_stretches = dict_state
.get_item("captured_stretches")?
.unwrap()
.extract()?;
self.declared_stretches = dict_state
.get_item("declared_stretches")?
.unwrap()
.extract()?;
let binding = dict_state.get_item("qubits")?.unwrap();
let qubits_raw = binding.extract::<Vec<ShareableQubit>>()?;
for bit in qubits_raw.into_iter() {
Expand Down Expand Up @@ -2324,12 +2339,11 @@ impl DAGCircuit {
return Ok(false);
}

for (our_stretch, their_stretch) in self
.captured_stretches
.values()
.zip(other.captured_stretches.values())
{
if !our_stretch.bind(py).eq(their_stretch)? {
// Note that `captured_stretches` is a set and thus order of captured stretches
// does not influence equality.
let our_captured_stretches = PySet::new(py, self.captured_stretches.values())?;
for their_stretch in other.captured_stretches.values() {
if !our_captured_stretches.contains(their_stretch)? {
return Ok(false);
}
}
Expand Down Expand Up @@ -4380,6 +4394,7 @@ impl DAGCircuit {
if !self.vars_by_type[DAGVarType::Capture as usize]
.bind(py)
.is_empty()
|| !self.captured_stretches.is_empty()
{
return Err(DAGCircuitError::new_err(
"cannot add inputs to a circuit with captures",
Expand Down Expand Up @@ -4517,6 +4532,23 @@ impl DAGCircuit {
.len()
}

// These three stretch getter methods are for testing purposes, they'll be public
// APIs in Qiskit 2.1
#[getter]
fn _num_stretches(&self) -> usize {
self._num_captured_stretches() + self._num_declared_stretches()
}

#[getter]
fn _num_captured_stretches(&self) -> usize {
self.captured_stretches.len()
}

#[getter]
fn _num_declared_stretches(&self) -> usize {
self.declared_stretches.len()
}

/// Is this realtime variable in the DAG?
///
/// Args:
Expand Down
17 changes: 16 additions & 1 deletion test/python/circuit/test_circuit_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,25 @@ def test_stretch_circuit_equality(self):
b = expr.Stretch.new("b")
c = expr.Stretch.new("c")

# Capture order doesn't matter in circuit equality!
qc1 = QuantumCircuit(captures=[a, b, c])
self.assertEqual(qc1, QuantumCircuit(captures=[a, b, c]))
self.assertNotEqual(qc1, QuantumCircuit(captures=[c, b, a]))
self.assertEqual(qc1, QuantumCircuit(captures=[c, b, a]))

qc1 = QuantumCircuit()
qc1.add_stretch(a)
qc1.add_stretch(b)
qc1.add_stretch(c)

qc2 = QuantumCircuit()
qc2.add_stretch(c)
qc2.add_stretch(b)
qc2.add_stretch(a)

# But declaration order does!
self.assertNotEqual(qc1, qc2)

qc1 = QuantumCircuit(captures=[a, b, c])
qc2 = QuantumCircuit(captures=[a])
qc2.add_stretch(b)
qc2.add_stretch(c)
Expand Down
111 changes: 109 additions & 2 deletions test/python/dagcircuit/test_dagcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

# pylint: disable=invalid-name

"""Test for the DAGCircuit object"""

from __future__ import annotations
Expand Down Expand Up @@ -508,13 +510,16 @@ def test_copy_empty_like_vars(self):
dag.add_input_var(expr.Var.new("b", types.Uint(8)))
dag.add_declared_var(expr.Var.new("c", types.Bool()))
dag.add_declared_var(expr.Var.new("d", types.Uint(8)))
dag.add_declared_stretch(expr.Stretch.new("e"))
self.assertEqual(dag, dag.copy_empty_like())

dag = DAGCircuit()
dag.add_captured_var(expr.Var.new("a", types.Bool()))
dag.add_captured_var(expr.Var.new("b", types.Uint(8)))
dag.add_declared_var(expr.Var.new("c", types.Bool()))
dag.add_declared_var(expr.Var.new("d", types.Uint(8)))
dag.add_declared_stretch(expr.Stretch.new("c"))
dag.add_declared_var(expr.Var.new("d", types.Bool()))
dag.add_declared_var(expr.Var.new("e", types.Uint(8)))
dag.add_declared_stretch(expr.Stretch.new("f"))
self.assertEqual(dag, dag.copy_empty_like())

def test_copy_empty_like_vars_captures(self):
Expand All @@ -523,22 +528,26 @@ def test_copy_empty_like_vars_captures(self):
b = expr.Var.new("b", types.Uint(8))
c = expr.Var.new("c", types.Bool())
d = expr.Var.new("d", types.Uint(8))
e = expr.Stretch.new("e")
all_captures = DAGCircuit()
for var in [a, b, c, d]:
all_captures.add_captured_var(var)
all_captures.add_captured_stretch(e)

dag = DAGCircuit()
dag.add_input_var(a)
dag.add_input_var(b)
dag.add_declared_var(c)
dag.add_declared_var(d)
dag.add_declared_stretch(e)
self.assertEqual(all_captures, dag.copy_empty_like(vars_mode="captures"))

dag = DAGCircuit()
dag.add_captured_var(a)
dag.add_captured_var(b)
dag.add_declared_var(c)
dag.add_declared_var(d)
dag.add_declared_stretch(e)
self.assertEqual(all_captures, dag.copy_empty_like(vars_mode="captures"))

def test_copy_empty_like_vars_drop(self):
Expand All @@ -547,19 +556,22 @@ def test_copy_empty_like_vars_drop(self):
b = expr.Var.new("b", types.Uint(8))
c = expr.Var.new("c", types.Bool())
d = expr.Var.new("d", types.Uint(8))
e = expr.Stretch.new("e")

dag = DAGCircuit()
dag.add_input_var(a)
dag.add_input_var(b)
dag.add_declared_var(c)
dag.add_declared_var(d)
dag.add_declared_stretch(e)
self.assertEqual(DAGCircuit(), dag.copy_empty_like(vars_mode="drop"))

dag = DAGCircuit()
dag.add_captured_var(a)
dag.add_captured_var(b)
dag.add_declared_var(c)
dag.add_declared_var(d)
dag.add_captured_stretch(e)
self.assertEqual(DAGCircuit(), dag.copy_empty_like(vars_mode="drop"))

def test_remove_busy_clbit(self):
Expand Down Expand Up @@ -1850,8 +1862,10 @@ def test_present_vars(self):
"""The vars should be compared whether or not they're used."""
a_bool = expr.Var.new("a", types.Bool())
a_u8 = expr.Var.new("a", types.Uint(8))
a_stretch = expr.Stretch.new("a")
a_u8_other = expr.Var.new("a", types.Uint(8))
b_bool = expr.Var.new("b", types.Bool())
b_stretch = expr.Stretch.new("b")

left = DAGCircuit()
left.add_input_var(a_bool)
Expand Down Expand Up @@ -1895,6 +1909,17 @@ def test_present_vars(self):
self.assertEqual(right.num_vars, 2)
self.assertNotEqual(left, right)

right = DAGCircuit()
right.add_captured_stretch(a_stretch)
right.add_captured_stretch(b_stretch)
self.assertEqual(right.num_input_vars, 0)
self.assertEqual(right.num_captured_vars, 0)
self.assertEqual(right.num_declared_vars, 0)
self.assertEqual(right._num_captured_stretches, 2)
self.assertEqual(right._num_declared_stretches, 0)
self.assertEqual(right._num_stretches, 2)
self.assertNotEqual(left, right)

right = DAGCircuit()
right.add_declared_var(a_bool)
right.add_declared_var(b_bool)
Expand All @@ -1904,11 +1929,24 @@ def test_present_vars(self):
self.assertEqual(right.num_vars, 2)
self.assertNotEqual(left, right)

right = DAGCircuit()
right.add_declared_stretch(a_stretch)
right.add_declared_stretch(b_stretch)
self.assertEqual(right.num_input_vars, 0)
self.assertEqual(right.num_captured_vars, 0)
self.assertEqual(right.num_declared_vars, 0)
self.assertEqual(right._num_captured_stretches, 0)
self.assertEqual(right._num_declared_stretches, 2)
self.assertEqual(right._num_stretches, 2)
self.assertNotEqual(left, right)

left = DAGCircuit()
left.add_captured_var(a_u8)
left.add_captured_stretch(b_stretch)

right = DAGCircuit()
right.add_captured_var(a_u8)
right.add_captured_stretch(b_stretch)
self.assertEqual(left, right)

right = DAGCircuit()
Expand Down Expand Up @@ -2111,13 +2149,18 @@ def test_forbid_mixing_captures_inputs(self):
"""Test that a DAG can't have both captures and inputs."""
a = expr.Var.new("a", types.Bool())
b = expr.Var.new("b", types.Bool())
c = expr.Stretch.new("c")

dag = DAGCircuit()
dag.add_input_var(a)
with self.assertRaisesRegex(
DAGCircuitError, "cannot add captures to a circuit with inputs"
):
dag.add_captured_var(b)
with self.assertRaisesRegex(
DAGCircuitError, "cannot add captures to a circuit with inputs"
):
dag.add_captured_stretch(c)

dag = DAGCircuit()
dag.add_captured_var(a)
Expand All @@ -2126,6 +2169,13 @@ def test_forbid_mixing_captures_inputs(self):
):
dag.add_input_var(b)

dag = DAGCircuit()
dag.add_captured_stretch(c)
with self.assertRaisesRegex(
DAGCircuitError, "cannot add inputs to a circuit with captures"
):
dag.add_input_var(a)

def test_forbid_adding_nonstandalone_var(self):
"""Temporary "wrapping" vars aren't standalone and can't be tracked separately."""
dag = DAGCircuit()
Expand All @@ -2138,12 +2188,69 @@ def test_forbid_adding_conflicting_vars(self):
"""Can't re-add a variable that exists, nor a shadowing variable in the same scope."""
a1 = expr.Var.new("a", types.Bool())
a2 = expr.Var.new("a", types.Bool())
a3 = expr.Stretch.new("a")
dag = DAGCircuit()
dag.add_declared_var(a1)
with self.assertRaisesRegex(DAGCircuitError, "already present in the circuit"):
dag.add_declared_var(a1)
with self.assertRaisesRegex(DAGCircuitError, "cannot add .* as its name shadows"):
dag.add_declared_var(a2)
with self.assertRaisesRegex(DAGCircuitError, "cannot add .* as its name shadows"):
dag.add_declared_stretch(a3)

dag = DAGCircuit()
dag.add_declared_stretch(a3)
with self.assertRaisesRegex(DAGCircuitError, "already present in the circuit"):
dag.add_captured_stretch(a3)
with self.assertRaisesRegex(DAGCircuitError, "cannot add .* as its name shadows"):
dag.add_declared_var(a1)
with self.assertRaisesRegex(DAGCircuitError, "cannot add .* as its name shadows"):
dag.add_declared_var(a2)

def test_pickle_stretches(self):
"""Test stretches preserved through pickle."""
a = expr.Stretch.new("a")
b = expr.Stretch.new("b")

# Check captures and declarations.
dag = DAGCircuit()
dag.add_declared_stretch(a)
dag.add_captured_stretch(b)

self.assertEqual(dag._num_stretches, 2)
self.assertEqual(dag._num_captured_stretches, 1)
self.assertEqual(dag._num_declared_stretches, 1)

with io.BytesIO() as buf:
pickle.dump(dag, buf)
buf.seek(0)
output = pickle.load(buf)

self.assertEqual(output._num_stretches, 2)
self.assertEqual(output._num_captured_stretches, 1)
self.assertEqual(output._num_declared_stretches, 1)
self.assertEqual(output, dag)

def test_deepcopy_stretches(self):
"""Test stretches preserved through deepcopy."""
a = expr.Stretch.new("a")
b = expr.Stretch.new("b")

# Check captures and declarations.
dag = DAGCircuit()
dag.add_declared_stretch(a)
dag.add_captured_stretch(b)

self.assertEqual(dag._num_stretches, 2)
self.assertEqual(dag._num_captured_stretches, 1)
self.assertEqual(dag._num_declared_stretches, 1)

output = copy.deepcopy(dag)

self.assertEqual(output._num_stretches, 2)
self.assertEqual(output._num_captured_stretches, 1)
self.assertEqual(output._num_declared_stretches, 1)
self.assertEqual(output, dag)


class TestDagSubstitute(QiskitTestCase):
Expand Down