diff --git a/tket-py/src/rewrite.rs b/tket-py/src/rewrite.rs index f9825ff05..753749c27 100644 --- a/tket-py/src/rewrite.rs +++ b/tket-py/src/rewrite.rs @@ -83,20 +83,8 @@ pub enum PyRewriter { Vec(Vec), } -// impl> Rewriter for PyRewriter { -// fn get_rewrites(&self, circ: &H) -> Vec { -// match self { -// Self::ECC(ecc) => ecc.0.get_rewrites(circ), -// Self::Vec(rewriters) => rewriters -// .iter() -// .flat_map(|r| r.get_rewrites(circ)) -// .collect(), -// } -// } -// } - impl> Rewriter> for PyRewriter { - fn get_rewrites(&self, circ: &ResourceScope) -> Vec::Node>> { + fn get_rewrites(&self, circ: &ResourceScope) -> Vec> { match self { Self::ECC(ecc) => ecc.0.get_rewrites(circ), Self::Vec(rewriters) => rewriters @@ -157,7 +145,8 @@ impl PyECCRewriter { )?)) } - /// Returns a list of circuit rewrites that can be applied to the given Tk2Circuit. + /// Returns a list of circuit rewrites that can be applied to the given + /// Tk2Circuit. pub fn get_rewrites(&self, circ: &Tk2Circuit) -> Vec { self.0 .get_rewrites(&circ.circ) diff --git a/tket/examples/badger-hadamard-opt.rs b/tket/examples/badger-hadamard-opt.rs new file mode 100644 index 000000000..0b0b70765 --- /dev/null +++ b/tket/examples/badger-hadamard-opt.rs @@ -0,0 +1,140 @@ +//! Using Badger to perform Hadamard cancellation. + +use hugr::{ + builder::{endo_sig, DFGBuilder, Dataflow, DataflowHugr}, + extension::prelude::qb_t, + HugrView, +}; +use tket::{ + op_matches, + optimiser::BadgerOptimiser, + resource::{CircuitUnit, ResourceScope}, + rewrite::{ + matcher::{CircuitMatcher, MatchContext, MatchOutcome}, + replacer::CircuitReplacer, + strategy::LexicographicCostFunction, + MatchReplaceRewriter, + }, + Circuit, Subcircuit, TketOp, +}; + +/// A matcher that matches two Hadamard gates in a row. +#[derive(Clone, Copy, Debug)] +struct TwoHMatcher; + +/// A replacement that replaces two Hadamard gates in a row with the identity. +#[derive(Clone, Copy, Debug)] +struct HadamardCancellation; + +/// State to keep track of how much has been matched so far. +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)] +enum PartialMatchState { + /// No hadamard matched so far. + #[default] + NoMatch, + /// One hadamard matched so far. + MatchedOne, +} + +impl CircuitMatcher for TwoHMatcher { + type PartialMatchInfo = PartialMatchState; + type MatchInfo = (); + + fn match_tket_op( + &self, + op: tket::TketOp, + _op_args: &[CircuitUnit], + match_context: MatchContext, + ) -> MatchOutcome { + if op != TketOp::H { + // We are not intersted in matching this op + return MatchOutcome::stop(); + } + + match match_context.match_info { + PartialMatchState::NoMatch => { + // Making progress! Proceed to matching the second Hadamard + MatchOutcome::default().proceed(PartialMatchState::MatchedOne) + } + PartialMatchState::MatchedOne => { + // We have matched two Hadamards, so we can report the match + MatchOutcome::default() + .complete(()) // report a full match + .skip(PartialMatchState::MatchedOne) // consider skipping + // this op (and match + // another one) + } + } + } +} + +impl CircuitReplacer<()> for HadamardCancellation { + fn replace_match( + &self, + subcircuit: &Subcircuit, + circuit: &ResourceScope, + _match_info: (), + ) -> Vec { + let hugr = circuit.hugr(); + // subgraph should be a pair of Hadamards + assert_eq!(subcircuit.node_count(circuit), 2); + assert!(subcircuit + .nodes(circuit) + .all(|n| op_matches(hugr.get_optype(n), TketOp::H))); + + // The right hand side of the rewrite is just an empty one-qubit circuit + let h = DFGBuilder::new(endo_sig(qb_t())).unwrap(); + let inps = h.input_wires(); + let empty_circ = h.finish_hugr_with_outputs(inps).unwrap(); + + vec![Circuit::new(empty_circ)] + } +} + +/// A 4-qubit circuit made of three layers +/// - layer of 4x Hadamards on each qubit, +/// - layer of CX gates between pairs of qubits, +/// - layer of 4x Hadamards on each qubit, +fn h_cx_h() -> Circuit { + let mut h = DFGBuilder::new(endo_sig(vec![qb_t(); 4])).unwrap(); + let qbs = h.input_wires(); + let mut circ = h.as_circuit(qbs); + + for _ in 0..4 { + for i in 0..4 { + circ.append(TketOp::H, [i]).unwrap(); + } + } + + for i in (0..4).step_by(2) { + circ.append(TketOp::CX, [i, i + 1]).unwrap(); + } + + for _ in 0..4 { + for i in 0..4 { + circ.append(TketOp::H, [i]).unwrap(); + } + } + + let qbs = circ.finish(); + Circuit::new(h.finish_hugr_with_outputs(qbs).unwrap()) +} + +fn main() { + let rewriter = MatchReplaceRewriter::new(TwoHMatcher, HadamardCancellation); + + let optimiser = + BadgerOptimiser::new(rewriter, LexicographicCostFunction::default_cx_strategy()); + + let circuit = h_cx_h(); + + let optimised = optimiser.optimise(&circuit, Default::default()); + + // Only CX gates are left + assert_eq!(optimised.num_operations(), 2); + assert!(optimised + .operations() + .all(|cmd| op_matches(cmd.optype(), TketOp::CX))); + + println!("Success!"); +} diff --git a/tket/examples/circuit-matcher.rs b/tket/examples/circuit-matcher.rs index 79d906228..8f08a6c70 100644 --- a/tket/examples/circuit-matcher.rs +++ b/tket/examples/circuit-matcher.rs @@ -103,7 +103,7 @@ impl CircuitMatcher for CliffordMatcher { fn main() { const CIRCUIT: &str = r#"{"bits": [], "commands": [{"args": [["q", [0]], ["q", [1]]], "op": {"type": "CX"}}, {"args": [["q", [0]]], "op": {"params": ["0.5"], "type": "Rz"}}, {"args": [["q", [1]]], "op": {"type": "V"}}, {"args": [["q", [0]], ["q", [2]]], "op": {"type": "CX"}}, {"args": [["q", [0]], ["q", [1]]], "op": {"type": "CX"}}, {"args": [["q", [2]]], "op": {"type": "S"}}, {"args": [["q", [0]]], "op": {"params": ["0.111"], "type": "Rz"}}, {"args": [["q", [2]]], "op": {"type": "T"}}, {"args": [["q", [1]], ["q", [2]]], "op": {"type": "CX"}}, {"args": [["q", [1]]], "op": {"type": "T"}}, {"args": [["q", [2]]], "op": {"type": "S"}}, {"args": [["q", [0]], ["q", [1]]], "op": {"type": "CX"}}], "created_qubits": [], "discarded_qubits": [], "implicit_permutation": [[["q", [0]], ["q", [0]]], [["q", [1]], ["q", [1]]], [["q", [2]], ["q", [2]]]], "phase": "0.0", "qubits": [["q", [0]], ["q", [1]], ["q", [2]]]}"#; let ser_circ: SerialCircuit = serde_json::from_str(CIRCUIT).unwrap(); - let circuit = ResourceScope::from_circuit(ser_circ.decode().unwrap()); + let circuit = ResourceScope::from_circuit(ser_circ.decode(Default::default()).unwrap()); let matcher = CliffordMatcher { allowed_num_cx: 2..4, diff --git a/tket/src/optimiser/badger.rs b/tket/src/optimiser/badger.rs index f7cbca1c8..db237d1a7 100644 --- a/tket/src/optimiser/badger.rs +++ b/tket/src/optimiser/badger.rs @@ -1,15 +1,15 @@ //! Badger circuit optimiser. //! //! This module implements the Badger circuit optimiser. It relies on a rewriter -//! and a RewriteStrategy instance to repeatedly rewrite a circuit and optimising -//! it according to some cost metric (typically gate count). +//! and a RewriteStrategy instance to repeatedly rewrite a circuit and +//! optimising it according to some cost metric (typically gate count). //! -//! The optimiser is implemented as a priority queue of circuits to be processed. -//! On top of the queue are the circuits with the lowest cost. They are popped -//! from the queue and replaced by the new circuits obtained from the rewriter -//! and the rewrite strategy. A hash of every circuit computed is stored to -//! detect and ignore duplicates. The priority queue is truncated whenever -//! it gets too large. +//! The optimiser is implemented as a priority queue of circuits to be +//! processed. On top of the queue are the circuits with the lowest cost. They +//! are popped from the queue and replaced by the new circuits obtained from the +//! rewriter and the rewrite strategy. A hash of every circuit computed is +//! stored to detect and ignore duplicates. The priority queue is truncated +//! whenever it gets too large. mod eq_circ_class; pub mod log; @@ -49,7 +49,8 @@ pub struct BadgerOptions { /// /// Defaults to `None`, which means no timeout. pub progress_timeout: Option, - /// The maximum number of circuits to process before stopping the optimisation. + /// The maximum number of circuits to process before stopping the + /// optimisation. /// /// For data parallel multi-threading, (split_circuit=true), applies on a /// per-thread basis, otherwise applies globally. @@ -60,13 +61,14 @@ pub struct BadgerOptions { /// /// Defaults to `1`. pub n_threads: NonZeroUsize, - /// Whether to split the circuit into chunks and process each in a separate thread. + /// Whether to split the circuit into chunks and process each in a separate + /// thread. /// - /// If this option is set to `true`, the optimiser will split the circuit into `n_threads` - /// chunks. + /// If this option is set to `true`, the optimiser will split the circuit + /// into `n_threads` chunks. /// - /// If this option is set to `false`, the optimiser will run parallel searches on the whole - /// circuit. + /// If this option is set to `false`, the optimiser will run parallel + /// searches on the whole circuit. /// /// Defaults to `false`. pub split_circuit: bool, @@ -99,10 +101,11 @@ impl Default for BadgerOptions { /// /// Optimisation is done by maintaining a priority queue of circuits and /// always processing the circuit with the lowest cost first. Rewrites are -/// computed for that circuit and all new circuit obtained are added to the queue. +/// computed for that circuit and all new circuit obtained are added to the +/// queue. /// -/// There are a single-threaded and two multi-threaded versions of the optimiser, -/// controlled by setting the [`BadgerOptions::n_threads`] and +/// There are a single-threaded and two multi-threaded versions of the +/// optimiser, controlled by setting the [`BadgerOptions::n_threads`] and /// [`BadgerOptions::split_circuit`] fields. /// /// [Quartz]: https://arxiv.org/abs/2204.09033 @@ -366,7 +369,8 @@ where best_circ } - /// Run the Badger optimiser on a circuit, with data parallel multithreading. + /// Run the Badger optimiser on a circuit, with data parallel + /// multithreading. /// /// Split the circuit into chunks and process each in a separate thread. #[tracing::instrument(target = "badger::metrics", skip(self, circ, logger))] @@ -509,7 +513,8 @@ mod badger_default { Ok(BadgerOptimiser::new(rewriter, strategy)) } - /// An optimiser minimising Rz gate count using a precompiled binary rewriter. + /// An optimiser minimising Rz gate count using a precompiled binary + /// rewriter. #[cfg(feature = "binary-eccs")] pub fn rz_opt_with_rewriter_binary( rewriter_path: impl AsRef, @@ -578,7 +583,8 @@ mod tests { /// ``` const NON_COMPOSABLE: &str = r#"{"phase":"0.0","commands":[{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[4]],["q",[1]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[1]],["q",[2]]]},{"op":{"type":"U3","params":["0.5","0","0.5"],"signature":["Q"]},"args":[["q",[1]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[3]],["q",[4]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[4]],["q",[0]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[0]],["q",[2]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[0]],["q",[2]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[3]],["q",[1]]]}],"qubits":[["q",[0]],["q",[1]],["q",[2]],["q",[3]],["q",[4]]],"bits":[],"implicit_permutation":[[["q",[0]],["q",[0]]],[["q",[1]],["q",[1]]],[["q",[2]],["q",[2]]],[["q",[3]],["q",[3]]],[["q",[4]],["q",[4]]]]}"#; - /// A circuit that would trigger non-composable rewrites, if we applied them blindly from nam_6_3 matches. + /// A circuit that would trigger non-composable rewrites, if we applied them + /// blindly from nam_6_3 matches. #[fixture] fn non_composable_rw_hugr() -> Circuit { load_tk1_json_str(NON_COMPOSABLE, DecodeOptions::new()).unwrap() diff --git a/tket/src/optimiser/badger/worker.rs b/tket/src/optimiser/badger/worker.rs index 4852f3c45..64f0f76f8 100644 --- a/tket/src/optimiser/badger/worker.rs +++ b/tket/src/optimiser/badger/worker.rs @@ -2,11 +2,10 @@ use std::thread::{self, JoinHandle}; -use crate::circuit::cost::CircuitCost; use crate::circuit::CircuitHash; -use crate::resource::ResourceScope; use crate::rewrite::strategy::RewriteStrategy; use crate::rewrite::Rewriter; +use crate::{circuit::cost::CircuitCost, resource::ResourceScope}; use super::pqueue_worker::{StatePQueueChannels, Work}; diff --git a/tket/src/resource.rs b/tket/src/resource.rs index 834a156a2..300b403b8 100644 --- a/tket/src/resource.rs +++ b/tket/src/resource.rs @@ -54,7 +54,7 @@ use hugr::{ }; pub use interval::{Interval, InvalidInterval}; use itertools::Itertools; -pub use scope::{ResourceScope, ResourceScopeConfig}; +pub use scope::{CircuitRewriteError, ResourceScope, ResourceScopeConfig}; pub use types::{CircuitUnit, Position, ResourceAllocator, ResourceId}; use crate::{ @@ -82,7 +82,8 @@ impl ResourceScope { /// Register a rewrite applied to the circuit. /// - /// Returns `true` if the rewrite was successfully registered, or `false` if it was ignored. + /// Returns `true` if the rewrite was successfully registered, or `false` if + /// it was ignored. #[inline] pub fn add_rewrite_trace(&mut self, rewrite: impl Into) -> bool { self.as_circuit_mut().add_rewrite_trace(rewrite) diff --git a/tket/src/resource/scope.rs b/tket/src/resource/scope.rs index bbbb26f50..2583f51f6 100644 --- a/tket/src/resource/scope.rs +++ b/tket/src/resource/scope.rs @@ -4,8 +4,10 @@ //! tracking within a specific region of a HUGR, computing resource paths and //! providing efficient lookup of circuit units associated with ports. -use std::collections::BTreeSet; -use std::{cmp, iter}; +mod patch; +pub use patch::CircuitRewriteError; + +use std::{cmp, collections::BTreeSet, iter}; use crate::resource::flow::{DefaultResourceFlow, ResourceFlow}; use crate::resource::types::{CircuitUnit, PortMap}; @@ -185,17 +187,15 @@ impl ResourceScope { self.hugr } - pub(crate) fn hugr_mut(&mut self) -> &mut H { - &mut self.hugr - } - /// Wrap the underlying HUGR in a Circuit as reference. pub fn as_circuit(&self) -> Circuit<&H> { Circuit::new(self.hugr()) } - pub(crate) fn as_circuit_mut(&mut self) -> Circuit<&mut H> { - Circuit::new(self.hugr_mut()) + /// Careful: this will not update the circuit units, so do not modify + /// the HUGR using this. + pub(super) fn as_circuit_mut(&mut self) -> Circuit<&mut H> { + Circuit::new(&mut self.hugr) } /// Get the underlying subgraph, or `None` if the circuit is empty. diff --git a/tket/src/resource/scope/patch.rs b/tket/src/resource/scope/patch.rs new file mode 100644 index 000000000..8b45a2adf --- /dev/null +++ b/tket/src/resource/scope/patch.rs @@ -0,0 +1,369 @@ +//! Applying [`CircuitRewrite`]s to a [`ResourceScope`]. + +use derive_more::derive::{Display, Error, From}; +use hugr::{ + hugr::{ + hugrmut::HugrMut, patch::simple_replace, views::SiblingSubgraph, Patch, + SimpleReplacementError, + }, + Direction, HugrView, +}; +use indexmap::IndexMap; +use itertools::Itertools; + +use crate::{ + resource::{ + scope::node_circuit_units_mut, CircuitUnit, Position, ResourceScope, ResourceScopeConfig, + }, + rewrite::{CircuitRewrite, NewCircuitRewrite, OldCircuitRewrite}, + Circuit, +}; + +impl> ResourceScope { + /// Apply a rewrite to the circuit. + /// + /// This should ideally live within [`hugr_core::hugr::Patch`], but would + /// need ResourceScope: HugrView. + pub fn apply_rewrite( + &mut self, + rewrite: CircuitRewrite, + ) -> Result, CircuitRewriteError> { + match rewrite { + CircuitRewrite::New(rewrite) => self.apply_rewrite_new(rewrite), + CircuitRewrite::Old(OldCircuitRewrite(repl)) => { + repl.apply(&mut self.hugr).map_err(Into::into) + } + } + } + + fn apply_rewrite_new( + &mut self, + rewrite: NewCircuitRewrite, + ) -> Result, CircuitRewriteError> { + let simple_replacement = rewrite.to_simple_replacement(self); + let NewCircuitRewrite { + subcircuit, + replacement, + } = rewrite; + + let inputs = subcircuit.input_ports(self); + let outputs = subcircuit.output_ports(self); + + let units_at_inputs = inputs.iter().map(|inp| { + debug_assert!(inp + .iter() + .map(|&(node, port)| self.get_circuit_unit(node, port)) + .all_equal()); + debug_assert!(!inp.is_empty()); + + let &(node, port) = inp.first().expect("just checked"); + self.get_circuit_unit(node, port).expect("just checked") + }); + + let mut repl_scope = None; + if replacement.num_operations() > 0 { + let repl_scope = repl_scope.insert(ResourceScope::from_circuit_with_input_units( + replacement, + units_at_inputs, + )); + + let max_input_pos = self.get_nearest_position( + inputs.iter().flatten().map(|&(n, _)| n), + Direction::Incoming, + ); + let min_output_pos = + self.get_nearest_position(outputs.iter().map(|&(n, _)| n), Direction::Outgoing); + + let n_nodes = repl_scope.nodes().len() as i64; + let ideal_interval_size = Position::new_integer(n_nodes + 2); + let (start, end) = match (max_input_pos, min_output_pos) { + (None, None) => (Position::new_integer(0), ideal_interval_size), + (None, Some(end)) => (Position(end.0 - ideal_interval_size.0), end), + (Some(start), None) => (start, Position(start.0 + ideal_interval_size.0)), + (Some(start), Some(end)) => (start, end), + }; + + if start >= end { + return Err(CircuitRewriteError::EmptyPositionRange); + } + + repl_scope.rescale_positions(start, end); + } + + let simple_replace::Outcome { + node_map, + removed_nodes, + } = simple_replacement.apply(&mut self.hugr)?; + + for (&node, _) in &removed_nodes { + self.circuit_units.swap_remove(&node); + } + + for (&repl_node, &new_node) in &node_map { + let repl_scope = repl_scope.as_ref().expect("non-empty replacement"); + let repl_node_units = repl_scope + .circuit_units + .get(&repl_node) + .expect("valid node"); + self.circuit_units.insert(new_node, repl_node_units.clone()); + } + + let new_node_set = self + .subgraph() + .map(|subgraph| subgraph.nodes()) + .unwrap_or_default() + .iter() + .chain(node_map.values()) + .copied() + .filter(|n| !removed_nodes.contains_key(&n)) + .collect_vec(); + + let [inp_node, out_node] = self.as_circuit().io_nodes(); + let incoming_ports = self + .hugr() + .node_outputs(inp_node) + .map(|p| (inp_node, p)) + .map(|(n, p)| self.hugr().linked_inputs(n, p).collect_vec()) + .take_while(|ports| !ports.is_empty()) + .collect_vec(); + let outgoing_ports = self + .hugr() + .node_inputs(out_node) + .map(|p| (out_node, p)) + .map(|(n, p)| self.hugr().single_linked_output(n, p)) + .while_some() + .collect_vec(); + self.subgraph = Some(SiblingSubgraph::new_unchecked( + incoming_ports, + outgoing_ports, + vec![], + new_node_set, + )); + + Ok(simple_replace::Outcome { + node_map, + removed_nodes, + }) + } + + fn get_nearest_position( + &self, + nodes: impl IntoIterator, + dir: Direction, + ) -> Option { + let all_pos = nodes + .into_iter() + .flat_map(|n| self.hugr().neighbours(n, dir)) + .filter_map(|n| self.get_position(n)); + match dir { + Direction::Incoming => all_pos.max(), + Direction::Outgoing => all_pos.min(), + } + } +} + +impl> ResourceScope { + /// Create a new resource scope from a circuit, using the given input units + /// instead of allocating new ones. + fn from_circuit_with_input_units( + circuit: Circuit, + units: impl IntoIterator>, + ) -> Self { + let subgraph = circuit.subgraph().unwrap(); + let inputs = subgraph.incoming_ports().to_owned(); + + let mut this = Self { + hugr: circuit.into_hugr(), + subgraph: Some(subgraph), + circuit_units: IndexMap::new(), + }; + + for (inp, unit) in inputs.into_iter().zip_eq(units) { + for (node, port) in inp { + let Some(node_units) = + node_circuit_units_mut(&mut this.circuit_units, node, &this.hugr) + else { + continue; + }; + node_units.port_map.set(port, unit) + } + } + + let config = ResourceScopeConfig::default(); + this.compute_circuit_units(&config.flows); + + this + } + + /// Rescale the positions of all nodes to be within the given range. + fn rescale_positions(&mut self, start: Position, end: Position) { + let (curr_start, curr_end) = self + .nodes() + .iter() + .map(|&n| self.get_position(n).expect("valid node")) + .minmax() + .into_option() + .expect("non empty subgraph"); + + debug_assert!(curr_start < curr_end || (curr_start == curr_end && self.nodes().len() == 1)); + + for &node in self + .subgraph + .as_ref() + .map(|subgraph| subgraph.nodes()) + .unwrap_or_default() + { + let Some(node_units) = + node_circuit_units_mut(&mut self.circuit_units, node, &self.hugr) + else { + continue; + }; + node_units.position = node_units + .position + .rescale(curr_start..=curr_end, start..=end); + } + } +} + +/// Errors that can occur when applying a rewrite to a resource scope. +#[derive(Debug, Display, Error, From)] +pub enum CircuitRewriteError { + /// An error occurred while applying the rewrite. + SimpleReplacementError(#[from] SimpleReplacementError), + /// The replacement could not be inserted in topological order. Is the + /// subcircuit non-convex or disconnected? + #[display("replacement could not be inserted in topological order. Is the subcircuit non-convex or disconnected?")] + EmptyPositionRange, +} + +#[cfg(test)] +mod tests { + use crate::{ + resource::{CircuitUnit, ResourceScope}, + rewrite::CircuitRewrite, + utils::build_simple_circuit, + Subcircuit, TketOp, + }; + use hugr::{hugr::Patch, ops::OpType, Direction}; + use itertools::Itertools; + use rstest::rstest; + + fn simple_circuit() -> ResourceScope { + ResourceScope::from_circuit( + build_simple_circuit(2, |circ| { + circ.append(TketOp::H, [0])?; + circ.append(TketOp::CX, [0, 1])?; + circ.append(TketOp::CX, [0, 1])?; + circ.append(TketOp::H, [1])?; + Ok(()) + }) + .unwrap(), + ) + } + + fn rewrite_to_n_cx(n: usize) -> (ResourceScope, CircuitRewrite) { + let circ = simple_circuit(); + + let cx_nodes = circ + .as_circuit() + .commands() + .filter_map(|cmd| { + if &OpType::from(TketOp::CX) == cmd.optype() { + Some(cmd.node()) + } else { + None + } + }) + .collect_array::<2>() + .unwrap(); + + let repl = build_simple_circuit(2, |circ| { + for _ in 0..n { + circ.append(TketOp::CX, [0, 1])?; + } + Ok(()) + }) + .unwrap(); + + let rewrite = CircuitRewrite::try_new( + Subcircuit::try_from_nodes(cx_nodes, &circ).unwrap(), + &circ, + repl, + ) + .unwrap(); + + (circ, rewrite) + } + + #[rstest] + #[case(rewrite_to_n_cx(0), "rewrite_to_0cx")] + #[case(rewrite_to_n_cx(1), "rewrite_to_1cx")] + #[case(rewrite_to_n_cx(4), "rewrite_to_4cx")] + fn test_circuit_rewrite_preserves_circuit_units( + #[case] (circ, rewrite): (ResourceScope, CircuitRewrite), + #[case] name: &str, + ) { + // Apply rewrite to resource scope + + use std::collections::BTreeSet; + let mut rewritten_scope = circ.clone(); + rewritten_scope.apply_rewrite(rewrite.clone()).unwrap(); + + // Apply rewrite to circuit directly + let mut direct_circuit = circ.as_circuit().extract_dfg().unwrap(); + rewrite + .to_simple_replacement(&circ) + .apply(direct_circuit.hugr_mut()) + .unwrap(); + let direct_scope = ResourceScope::from_circuit(direct_circuit); + + assert_eq!( + BTreeSet::from_iter(rewritten_scope.nodes()), + BTreeSet::from_iter(direct_scope.nodes()) + ); + + // Check that circuit units are identical + for &node in rewritten_scope.nodes() { + for dir in Direction::BOTH { + let rewritten_units = rewritten_scope.get_circuit_units_slice(node, dir); + let direct_units = direct_scope.get_circuit_units_slice(node, dir); + assert_eq!( + rewritten_units, direct_units, + "Circuit units differ for node {:?}", + node + ); + } + } + + // Check that positions are strictly increasing along each resource path + for &(node, port) in rewritten_scope + .subgraph() + .unwrap() + .incoming_ports() + .iter() + .flatten() + { + let res = match rewritten_scope.get_circuit_unit(node, port).unwrap() { + CircuitUnit::Resource(resource_id) => resource_id, + CircuitUnit::Copyable(..) => { + continue; + } + }; + let all_pos = rewritten_scope + .resource_path_iter(res, node, Direction::Outgoing) + .map(|n| rewritten_scope.get_position(n).unwrap()) + .collect_vec(); + assert!(all_pos.is_sorted()); + } + + // Snapshot test of all circuit units + let mut circuit_units = rewritten_scope + .circuit_units + .clone() + .into_iter() + .collect_vec(); + circuit_units.sort_unstable_by_key(|&(k, _)| k); // For deterministic output + + insta::assert_debug_snapshot!(name, circuit_units); + } +} diff --git a/tket/src/resource/scope/snapshots/tket__resource__scope__patch__tests__rewrite_to_0cx.snap b/tket/src/resource/scope/snapshots/tket__resource__scope__patch__tests__rewrite_to_0cx.snap new file mode 100644 index 000000000..cf49dc89e --- /dev/null +++ b/tket/src/resource/scope/snapshots/tket__resource__scope__patch__tests__rewrite_to_0cx.snap @@ -0,0 +1,52 @@ +--- +source: tket/src/resource/scope/patch.rs +expression: circuit_units +--- +[ + ( + Node( + 4, + ), + NodeCircuitUnits { + port_map: PortMap { + vec: [ + Resource( + ResourceId( + 0, + ), + ), + Resource( + ResourceId( + 0, + ), + ), + ], + num_inputs: 1, + }, + position: Position(0), + }, + ), + ( + Node( + 7, + ), + NodeCircuitUnits { + port_map: PortMap { + vec: [ + Resource( + ResourceId( + 1, + ), + ), + Resource( + ResourceId( + 1, + ), + ), + ], + num_inputs: 1, + }, + position: Position(3), + }, + ), +] diff --git a/tket/src/resource/scope/snapshots/tket__resource__scope__patch__tests__rewrite_to_1cx.snap b/tket/src/resource/scope/snapshots/tket__resource__scope__patch__tests__rewrite_to_1cx.snap new file mode 100644 index 000000000..c5e272cce --- /dev/null +++ b/tket/src/resource/scope/snapshots/tket__resource__scope__patch__tests__rewrite_to_1cx.snap @@ -0,0 +1,85 @@ +--- +source: tket/src/resource/scope/patch.rs +expression: circuit_units +--- +[ + ( + Node( + 4, + ), + NodeCircuitUnits { + port_map: PortMap { + vec: [ + Resource( + ResourceId( + 0, + ), + ), + Resource( + ResourceId( + 0, + ), + ), + ], + num_inputs: 1, + }, + position: Position(0), + }, + ), + ( + Node( + 7, + ), + NodeCircuitUnits { + port_map: PortMap { + vec: [ + Resource( + ResourceId( + 1, + ), + ), + Resource( + ResourceId( + 1, + ), + ), + ], + num_inputs: 1, + }, + position: Position(3), + }, + ), + ( + Node( + 11, + ), + NodeCircuitUnits { + port_map: PortMap { + vec: [ + Resource( + ResourceId( + 0, + ), + ), + Resource( + ResourceId( + 1, + ), + ), + Resource( + ResourceId( + 0, + ), + ), + Resource( + ResourceId( + 1, + ), + ), + ], + num_inputs: 2, + }, + position: Position(3/2), + }, + ), +] diff --git a/tket/src/resource/scope/snapshots/tket__resource__scope__patch__tests__rewrite_to_4cx.snap b/tket/src/resource/scope/snapshots/tket__resource__scope__patch__tests__rewrite_to_4cx.snap new file mode 100644 index 000000000..ad52dee6d --- /dev/null +++ b/tket/src/resource/scope/snapshots/tket__resource__scope__patch__tests__rewrite_to_4cx.snap @@ -0,0 +1,184 @@ +--- +source: tket/src/resource/scope/patch.rs +expression: circuit_units +--- +[ + ( + Node( + 4, + ), + NodeCircuitUnits { + port_map: PortMap { + vec: [ + Resource( + ResourceId( + 0, + ), + ), + Resource( + ResourceId( + 0, + ), + ), + ], + num_inputs: 1, + }, + position: Position(0), + }, + ), + ( + Node( + 7, + ), + NodeCircuitUnits { + port_map: PortMap { + vec: [ + Resource( + ResourceId( + 1, + ), + ), + Resource( + ResourceId( + 1, + ), + ), + ], + num_inputs: 1, + }, + position: Position(3), + }, + ), + ( + Node( + 11, + ), + NodeCircuitUnits { + port_map: PortMap { + vec: [ + Resource( + ResourceId( + 0, + ), + ), + Resource( + ResourceId( + 1, + ), + ), + Resource( + ResourceId( + 0, + ), + ), + Resource( + ResourceId( + 1, + ), + ), + ], + num_inputs: 2, + }, + position: Position(3/5), + }, + ), + ( + Node( + 12, + ), + NodeCircuitUnits { + port_map: PortMap { + vec: [ + Resource( + ResourceId( + 0, + ), + ), + Resource( + ResourceId( + 1, + ), + ), + Resource( + ResourceId( + 0, + ), + ), + Resource( + ResourceId( + 1, + ), + ), + ], + num_inputs: 2, + }, + position: Position(6/5), + }, + ), + ( + Node( + 13, + ), + NodeCircuitUnits { + port_map: PortMap { + vec: [ + Resource( + ResourceId( + 0, + ), + ), + Resource( + ResourceId( + 1, + ), + ), + Resource( + ResourceId( + 0, + ), + ), + Resource( + ResourceId( + 1, + ), + ), + ], + num_inputs: 2, + }, + position: Position(9/5), + }, + ), + ( + Node( + 14, + ), + NodeCircuitUnits { + port_map: PortMap { + vec: [ + Resource( + ResourceId( + 0, + ), + ), + Resource( + ResourceId( + 1, + ), + ), + Resource( + ResourceId( + 0, + ), + ), + Resource( + ResourceId( + 1, + ), + ), + ], + num_inputs: 2, + }, + position: Position(12/5), + }, + ), +] diff --git a/tket/src/resource/types.rs b/tket/src/resource/types.rs index f1ffc4c51..abb602d18 100644 --- a/tket/src/resource/types.rs +++ b/tket/src/resource/types.rs @@ -4,6 +4,9 @@ //! copyable values throughout a HUGR circuit, including resource identifiers, //! positions, and the mapping structures that associate them with operations. +use std::ops::RangeInclusive; + +use cgmath::Zero; use derive_more::derive::From; use hugr::{ core::HugrNode, types::Signature, Direction, IncomingPort, OutgoingPort, Port, PortIndex, Wire, @@ -40,7 +43,7 @@ impl ResourceId { /// Initially assigned as contiguous integers, they may become non-integer /// when operations are inserted or removed. #[derive(Clone, Copy, Default, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct Position(Rational64); +pub struct Position(pub(crate) Rational64); impl std::fmt::Debug for Position { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -67,6 +70,27 @@ impl Position { pub fn increment(&self) -> Self { Self(self.0 + 1) } + + /// Rescale the position such that any number in the old range (including + /// ends) is within the new range (excluding ends). + pub(crate) fn rescale( + &self, + old_range: RangeInclusive, + new_range: RangeInclusive, + ) -> Self { + let Self(pos) = self; + let old_range_size = Rational64::from_integer(2) + old_range.end().0 - old_range.start().0; + let new_range_size = new_range.end().0 - new_range.start().0; + + if old_range_size == Rational64::zero() { + return *new_range.start(); + } + + let new_pos = new_range.start().0 + + (Rational64::from_integer(1) + pos - old_range.start().0) / old_range_size + * new_range_size; + Self(new_pos) + } } /// A value associated with a dataflow port, identified either by a resource ID diff --git a/tket/src/rewrite.rs b/tket/src/rewrite.rs index 247732168..3913dbdc8 100644 --- a/tket/src/rewrite.rs +++ b/tket/src/rewrite.rs @@ -3,6 +3,7 @@ #[cfg(feature = "portmatching")] pub mod ecc_rewriter; pub mod matcher; +pub mod replacer; pub mod strategy; pub mod trace; @@ -15,17 +16,15 @@ use hugr::core::HugrNode; use hugr::hugr::hugrmut::HugrMut; use hugr::hugr::patch::simple_replace; use hugr::hugr::views::sibling_subgraph::InvalidSubgraph; -use hugr::hugr::Patch; use hugr::types::Signature; -use hugr::{ - hugr::{views::SiblingSubgraph, SimpleReplacementError}, - SimpleReplacement, -}; +use hugr::{hugr::views::SiblingSubgraph, SimpleReplacement}; use hugr::{Hugr, HugrView}; use itertools::Either; +use matcher::{CircuitMatcher, MatchingOptions}; +use replacer::CircuitReplacer; use crate::circuit::Circuit; -use crate::resource::ResourceScope; +use crate::resource::{CircuitRewriteError, ResourceScope}; pub use crate::Subcircuit; /// A rewrite rule for circuits. @@ -45,17 +44,39 @@ pub enum CircuitRewrite { } /// A rewrite rule for circuits. +/// +/// The following invariants hold: +/// - the subcircuit is not empty +/// - the subcircuit is convex #[derive(Debug, Clone)] pub struct NewCircuitRewrite { - subcircuit: Subcircuit, - replacement: Circuit, + pub(crate) subcircuit: Subcircuit, + pub(crate) replacement: Circuit, +} + +impl NewCircuitRewrite { + /// Construct a [`SimpleReplacement`] that executes the rewrite as a HUGR + /// operation. + pub fn to_simple_replacement( + &self, + circuit: &ResourceScope>, + ) -> SimpleReplacement { + let subgraph = self + .subcircuit + .try_to_subgraph(circuit) + .expect("subcircuit is valid subgraph"); + subgraph + .create_simple_replacement(circuit.hugr(), self.replacement.clone().into_hugr()) + .expect("rewrite is valid simple replacement") + } } /// A rewrite rule for circuits, wrapping a HUGR [`SimpleReplacement`]. /// -/// You should migrate to using [`NewCircuitRewrite`] instead. It is much faster. +/// You should migrate to using [`NewCircuitRewrite`] instead. It is much +/// faster. #[derive(Debug, Clone, From, Into)] -pub struct OldCircuitRewrite(SimpleReplacement); +pub struct OldCircuitRewrite(pub(crate) SimpleReplacement); impl CircuitRewrite { /// Create a new rewrite that can be applied to `hugr`. @@ -64,9 +85,9 @@ impl CircuitRewrite { circuit: &ResourceScope>, replacement: Circuit, ) -> Result { - subcircuit - .validate_subgraph(circuit) - .map_err(|err| InvalidRewrite::try_from(err).unwrap_or_else(|err| panic!("{err}")))?; + subcircuit.validate_subgraph(circuit).map_err(|err| { + InvalidRewrite::try_from(err).unwrap_or_else(|err| panic!("unknown error: {err}")) + })?; let subcircuit_sig = subcircuit.dataflow_signature(circuit); let replacement_sig = replacement.circuit_signature(); @@ -150,24 +171,27 @@ impl CircuitRewrite { Self::Old(rewrite) => Either::Right(rewrite.0.subgraph().nodes().iter().copied()), } } +} +impl CircuitRewrite { /// Apply the rewrite rule to a circuit. #[inline] pub fn apply( self, - circ: &mut ResourceScope>, - ) -> Result, SimpleReplacementError> { - circ.as_circuit_mut().add_rewrite_trace(&self); - self.to_simple_replacement(circ).apply(circ.hugr_mut()) + circ: &mut ResourceScope>, + ) -> Result { + circ.add_rewrite_trace(&self); + circ.apply_rewrite(self) } - /// Apply the rewrite rule to a circuit, without registering it in the rewrite trace. + /// Apply the rewrite rule to a circuit, without registering it in the + /// rewrite trace. #[inline] pub fn apply_notrace( self, - circ: &mut ResourceScope>, - ) -> Result, SimpleReplacementError> { - self.to_simple_replacement(circ).apply(circ.hugr_mut()) + circ: &mut ResourceScope>, + ) -> Result { + circ.apply_rewrite(self) } } @@ -243,6 +267,25 @@ impl TryFrom> for InvalidRewrite { } } } +/// A rewriter made of a [`CircuitMatcher`] and a [`CircuitReplacer`]. +/// +/// The [`CircuitMatcher`] is used to find matches in the circuit, and the +/// [`CircuitReplacer`] is used to create [`CircuitRewrite`]s for each match. +#[derive(Clone, Debug)] +pub struct MatchReplaceRewriter { + matcher: C, + replacer: R, +} + +impl MatchReplaceRewriter { + /// Create a new [`MatchReplaceRewriter`]. + pub fn new(matcher: C, replacement: R) -> Self { + Self { + matcher, + replacer: replacement, + } + } +} fn compute_node_count_delta( subcircuit: &Subcircuit, @@ -253,3 +296,27 @@ fn compute_node_count_delta( let old_count = subcircuit.nodes(circuit).count() as isize; new_count - old_count } + +impl> Rewriter> for MatchReplaceRewriter +where + C: CircuitMatcher, + R: CircuitReplacer, +{ + fn get_rewrites(&self, circ: &ResourceScope) -> Vec> { + let matches = self + .matcher + .as_hugr_matcher() + .get_all_matches(circ, &MatchingOptions::default()); + matches + .into_iter() + .flat_map(|(subcirc, match_info)| { + self.replacer + .replace_match(&subcirc, circ, match_info) + .into_iter() + .filter_map(move |repl| { + CircuitRewrite::try_new(subcirc.clone(), circ, repl).ok() + }) + }) + .collect() + } +} diff --git a/tket/src/rewrite/ecc_rewriter.rs b/tket/src/rewrite/ecc_rewriter.rs index ea29a781e..f7cd03d32 100644 --- a/tket/src/rewrite/ecc_rewriter.rs +++ b/tket/src/rewrite/ecc_rewriter.rs @@ -69,9 +69,9 @@ pub struct ECCRewriter { impl ECCRewriter { /// Create a new rewriter from equivalent circuit classes in JSON file. /// - /// This uses the Quartz JSON file format to store equivalent circuit classes. - /// Generate such a file using the `gen_ecc_set.sh` script at the root of - /// the Quartz repository. + /// This uses the Quartz JSON file format to store equivalent circuit + /// classes. Generate such a file using the `gen_ecc_set.sh` script at + /// the root of the Quartz repository. /// /// Quartz: . pub fn try_from_eccs_json_file(path: impl AsRef) -> io::Result { @@ -184,8 +184,9 @@ impl ECCRewriter { Self::load_binary_io(&mut file) } - /// When the ECC gets loaded, all custom operations are an instance of `OpaqueOp`. - /// We need to resolve them into `ExtensionOp`s by validating the definitions. + /// When the ECC gets loaded, all custom operations are an instance of + /// `OpaqueOp`. We need to resolve them into `ExtensionOp`s by + /// validating the definitions. fn resolve_extension_ops(&mut self) -> Result<(), ExtensionResolutionError> { self.targets .iter_mut() @@ -194,10 +195,7 @@ impl ECCRewriter { } impl> Rewriter> for ECCRewriter { - fn get_rewrites( - &self, - circ: &ResourceScope, - ) -> Vec as super::hidden::CircuitLike>::Node>> { + fn get_rewrites(&self, circ: &ResourceScope) -> Vec> { self.get_rewrites(&circ.as_circuit()) } } diff --git a/tket/src/rewrite/matcher/adapter.rs b/tket/src/rewrite/matcher/adapter.rs index 30bceac8f..9350eb858 100644 --- a/tket/src/rewrite/matcher/adapter.rs +++ b/tket/src/rewrite/matcher/adapter.rs @@ -575,7 +575,7 @@ mod tests { // A circuit with two constant angle Rz gates, one of them is 0.123. const CIRC: &str = r#"{"bits": [], "commands": [{"args": [["q", [0]]], "op": {"params": ["0.123"], "type": "Rz"}}, {"args": [["q", [0]]], "op": {"params": ["0.5"], "type": "Rz"}}], "created_qubits": [], "discarded_qubits": [], "implicit_permutation": [[["q", [0]], ["q", [0]]]], "phase": "0.0", "qubits": [["q", [0]]]}"#; let ser_circ: tket_json_rs::SerialCircuit = serde_json::from_str(CIRC).unwrap(); - let circuit = ResourceScope::from_circuit(ser_circ.decode().unwrap()); + let circuit = ResourceScope::from_circuit(ser_circ.decode(Default::default()).unwrap()); let matcher = TestRzMatcher.as_hugr_matcher(); let (match_subcirc, ()) = matcher diff --git a/tket/src/rewrite/replacer.rs b/tket/src/rewrite/replacer.rs new file mode 100644 index 000000000..b8a2e693a --- /dev/null +++ b/tket/src/rewrite/replacer.rs @@ -0,0 +1,19 @@ +//! Providing replacements for pattern matches. + +use hugr::HugrView; + +use crate::{resource::ResourceScope, Circuit, Subcircuit}; + +/// Provide possible replacements for a pattern match. +pub trait CircuitReplacer { + /// Get the possible replacements for a pattern match. + /// + /// The order (and signature) of the inputs and outputs on the returned + /// circuits must match the order of the boundary ports in `subgraph`. + fn replace_match( + &self, + subcircuit: &Subcircuit, + circuit: &ResourceScope, + match_info: MatchInfo, + ) -> Vec; +} diff --git a/tket/src/rewrite/strategy.rs b/tket/src/rewrite/strategy.rs index 10191bd54..d8579b316 100644 --- a/tket/src/rewrite/strategy.rs +++ b/tket/src/rewrite/strategy.rs @@ -12,13 +12,13 @@ //! threshold function. //! //! The exhaustive strategies are parametrised by a strategy cost function: -//! - [`LexicographicCostFunction`] allows rewrites that do -//! not increase some coarse cost function (e.g. CX count), whilst -//! ordering them according to a lexicographic ordering of finer cost -//! functions (e.g. total gate count). See -//! [`LexicographicCostFunction::default_cx_strategy`]) for a default implementation. -//! - [`GammaStrategyCost`] ignores rewrites that increase the cost -//! function beyond a percentage given by a f64 parameter gamma. +//! - [`LexicographicCostFunction`] allows rewrites that do not increase some +//! coarse cost function (e.g. CX count), whilst ordering them according to +//! a lexicographic ordering of finer cost functions (e.g. total gate +//! count). See [`LexicographicCostFunction::default_cx_strategy`]) for a +//! default implementation. +//! - [`GammaStrategyCost`] ignores rewrites that increase the cost function +//! beyond a percentage given by a f64 parameter gamma. use std::iter; use std::{collections::HashSet, fmt::Debug}; @@ -73,7 +73,8 @@ pub trait RewriteStrategy { }) } - /// Returns the expected cost of a rewrite's matched subcircuit after replacing it. + /// Returns the expected cost of a rewrite's matched subcircuit after + /// replacing it. fn post_rewrite_cost(&self, rw: &CircuitRewrite) -> Self::Cost { Circuit::new(rw.replacement()).circuit_cost(|op| self.op_cost(op)) } @@ -142,15 +143,11 @@ impl RewriteStrategy for GreedyRewriteStrategy { let mut cost_delta = 0; let mut circ = circ.clone(); for rewrite in rewrites { - if rewrite - .to_subgraph(&circ) - .nodes() - .iter() - .any(|n| changed_nodes.contains(n)) - { + let subgraph = rewrite.to_subgraph(&circ); + if subgraph.nodes().iter().any(|n| changed_nodes.contains(n)) { continue; } - changed_nodes.extend(rewrite.to_subgraph(&circ).nodes().iter().copied()); + changed_nodes.extend(subgraph.nodes().iter().copied()); cost_delta += rewrite.node_count_delta(&circ); rewrite .apply(&mut circ) @@ -308,7 +305,8 @@ pub trait StrategyCost { /// The cost of a single operation. type OpCost: CircuitCost; - /// Returns true if the rewrite is allowed, based on the cost of the pattern and target. + /// Returns true if the rewrite is allowed, based on the cost of the pattern + /// and target. #[inline] fn under_threshold(&self, pattern_cost: &Self::OpCost, target_cost: &Self::OpCost) -> bool { target_cost.sub_cost(pattern_cost).as_isize() <= 0 @@ -478,7 +476,8 @@ impl GammaStrategyCost usize> { GammaStrategyCost::with_cost(|op| is_cx(op) as usize) } - /// Exhaustive rewrite strategy with CX count cost function and provided gamma. + /// Exhaustive rewrite strategy with CX count cost function and provided + /// gamma. #[inline] pub fn exhaustive_cx_with_gamma(gamma: f64) -> ExhaustiveThresholdStrategy { GammaStrategyCost::new(gamma, |op| is_cx(op) as usize)