diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 56790b93b..00bc909d2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -221,7 +221,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Add requirements - run: python -m pip install --upgrade cmake>=3.12 ninja pytest flake8 pytest-cov setuptools + run: python -m pip install --upgrade cmake>=3.12 ninja pytest ruff pytest-cov setuptools - name: Build and install run: python -m pip install --verbose -e . @@ -235,8 +235,11 @@ jobs: - name: Test with stim and rustworkx using coverage run: python -m pytest tests --cov=./src/pymatching --cov-report term - - name: flake8 - run: flake8 ./src ./tests + - name: ruff lint + run: ruff check src tests + + - name: ruff format + run: ruff format --check src tests build_docs: runs-on: ubuntu-latest @@ -272,7 +275,7 @@ jobs: with: python-version: '3.x' - name: Add requirements - run: python -m pip install --upgrade cmake>=3.12 ninja pytest flake8 pytest-cov stim rustworkx + run: python -m pip install --upgrade cmake>=3.12 ninja pytest ruff pytest-cov stim rustworkx - name: Build and install run: pip install --verbose -e . - name: Run tests and collect coverage diff --git a/src/pymatching/__init__.py b/src/pymatching/__init__.py index f0c9d753a..764fea0f8 100644 --- a/src/pymatching/__init__.py +++ b/src/pymatching/__init__.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pymatching._cpp_pymatching import (randomize, set_seed, rand_float) # noqa +from pymatching._cpp_pymatching import randomize, set_seed, rand_float # noqa from pymatching._cpp_pymatching import main as cli # noqa from pymatching.matching import Matching # noqa -from pymatching._version import __version__ +from pymatching._version import __version__ as __version__ # noqa randomize() # Set random seed using std::random_device diff --git a/src/pymatching/_version.py b/src/pymatching/_version.py index 8743aa06d..95f70461f 100644 --- a/src/pymatching/_version.py +++ b/src/pymatching/_version.py @@ -1 +1 @@ -__version__ = "2.3.1" # pragma no cover +__version__ = "2.3.1" # pragma no cover diff --git a/src/pymatching/matching.py b/src/pymatching/matching.py index 3ac8fed0e..cfd002930 100644 --- a/src/pymatching/matching.py +++ b/src/pymatching/matching.py @@ -38,18 +38,26 @@ class Matching: a `stim.DetectorErrorModel`. 
""" - def __init__(self, - graph: Union[csc_matrix, np.ndarray, "rx.PyGraph", nx.Graph, List[ - List[int]], 'stim.DetectorErrorModel', spmatrix] = None, - weights: Union[float, np.ndarray, List[float]] = None, - error_probabilities: Union[float, np.ndarray, List[float]] = None, - repetitions: int = None, - timelike_weights: Union[float, np.ndarray, List[float]] = None, - measurement_error_probabilities: Union[float, np.ndarray, List[float]] = None, - *, - enable_correlations: bool = False, - **kwargs - ): + def __init__( + self, + graph: Union[ + csc_matrix, + np.ndarray, + "rx.PyGraph", + nx.Graph, + List[List[int]], + "stim.DetectorErrorModel", + spmatrix, + ] = None, + weights: Union[float, np.ndarray, List[float]] = None, + error_probabilities: Union[float, np.ndarray, List[float]] = None, + repetitions: int = None, + timelike_weights: Union[float, np.ndarray, List[float]] = None, + measurement_error_probabilities: Union[float, np.ndarray, List[float]] = None, + *, + enable_correlations: bool = False, + **kwargs, + ): r"""Constructor for the Matching class Parameters @@ -151,6 +159,7 @@ def __init__(self, # Rustworkx PyGraph try: import rustworkx as rx + if isinstance(graph, rx.PyGraph): self.load_from_rustworkx(graph) return @@ -159,12 +168,18 @@ def __init__(self, # stim.DetectorErrorModel try: import stim + if isinstance(graph, stim.DetectorErrorModel): - self._load_from_detector_error_model(graph, enable_correlations=enable_correlations) + self._load_from_detector_error_model( + graph, enable_correlations=enable_correlations + ) return elif isinstance(graph, stim.Circuit): self.from_stim_circuit - self._load_from_detector_error_model(graph.detector_error_model(decompose_errors=True), enable_correlations=enable_correlations) + self._load_from_detector_error_model( + graph.detector_error_model(decompose_errors=True), + enable_correlations=enable_correlations, + ) return except ImportError: # pragma no cover pass @@ -172,12 +187,20 @@ def __init__(self, try: graph = csc_matrix(graph) except TypeError: - raise TypeError("The type of the input graph is not recognised. `graph` must be " - "a scipy.sparse or numpy matrix, networkx or rustworkx graph, or " - "stim.DetectorErrorModel.") - self.load_from_check_matrix(graph, weights, error_probabilities, - repetitions, timelike_weights, measurement_error_probabilities, - **kwargs) + raise TypeError( + "The type of the input graph is not recognised. `graph` must be " + "a scipy.sparse or numpy matrix, networkx or rustworkx graph, or " + "stim.DetectorErrorModel." 
+ ) + self.load_from_check_matrix( + graph, + weights, + error_probabilities, + repetitions, + timelike_weights, + measurement_error_probabilities, + **kwargs, + ) def add_noise(self) -> Union[Tuple[np.ndarray, np.ndarray], None]: """Add noise by flipping edges in the matching graph with @@ -199,29 +222,38 @@ def add_noise(self) -> Union[Tuple[np.ndarray, np.ndarray], None]: return None return self._matching_graph.add_noise() - def _syndrome_array_to_detection_events(self, z: Union[np.ndarray, List[int]]) -> np.ndarray: + def _syndrome_array_to_detection_events( + self, z: Union[np.ndarray, List[int]] + ) -> np.ndarray: try: z = np.array(z, dtype=np.uint8) except ValueError: - raise ValueError("Syndrome must be of type numpy.ndarray or " - "convertible to numpy.ndarray, not {}".format(z)) - if len(z.shape) == 1 and (self.num_detectors <= z.shape[0] - <= self.num_detectors + len(self.boundary)): + raise ValueError( + "Syndrome must be of type numpy.ndarray or " + "convertible to numpy.ndarray, not {}".format(z) + ) + if len(z.shape) == 1 and ( + self.num_detectors <= z.shape[0] <= self.num_detectors + len(self.boundary) + ): detection_events = z.nonzero()[0] elif len(z.shape) == 2 and z.shape[0] * z.shape[1] == self.num_detectors: times, checks = z.T.nonzero() detection_events = times * z.shape[0] + checks else: - raise ValueError("The shape ({}) of the syndrome vector z is not valid.".format(z.shape)) + raise ValueError( + "The shape ({}) of the syndrome vector z is not valid.".format(z.shape) + ) return detection_events - def decode(self, - z: Union[np.ndarray, List[bool], List[int]], - *, - return_weight: bool = False, - enable_correlations: bool = False, - **kwargs - ) -> Union[np.ndarray, Tuple[np.ndarray, int]]: + def decode( + self, + z: Union[np.ndarray, List[bool], List[int]], + *, + return_weight: bool = False, + enable_correlations: bool = False, + edge_reweights: Optional[np.ndarray] = None, + **kwargs, + ) -> Union[np.ndarray, Tuple[np.ndarray, int]]: r""" Decode the syndrome `z` using minimum-weight perfect matching @@ -254,6 +286,15 @@ def decode(self, `stim.DetectorErrorModel` with `enable_correlations=True`. For a description of the correlated matching algorithm, see https://arxiv.org/abs/1310.0863. By default, False + edge_reweights: np.ndarray, optional + A 2D numpy array of edge reweights, of shape (N, 3). Each row of `edge_reweights` + specifies an edge to be temporarily reweighted for the duration of the decoding + of the shot `z`. The first two columns of `edge_reweights` specify the node endpoints + of the edge to reweight, and the third column is the new edge weight (as a float) to reweight to. + For example, `edge_reweights[i, :] == np.array([4, 5, 2.4], dtype=np.float64)` means + "give edge (4, 5) a new weight of 2.4". + For a boundary edge, the second node index is -1. + By default None. 
Returns ------- @@ -338,7 +379,9 @@ def decode(self, """ detection_events = self._syndrome_array_to_detection_events(z) correction, weight = self._matching_graph.decode( - detection_events, enable_correlations=enable_correlations + detection_events, + enable_correlations=enable_correlations, + edge_reweights=edge_reweights, ) if return_weight: return correction, weight @@ -346,13 +389,15 @@ def decode(self, return correction def decode_batch( - self, - shots: np.ndarray, - *, - return_weights: bool = False, - bit_packed_shots: bool = False, - bit_packed_predictions: bool = False, - enable_correlations: bool = False) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: + self, + shots: np.ndarray, + *, + return_weights: bool = False, + bit_packed_shots: bool = False, + bit_packed_predictions: bool = False, + enable_correlations: bool = False, + edge_reweights: Optional[List[Optional[np.ndarray]]] = None, + ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: """ Decode from a 2D `shots` array containing a batch of syndrome measurements. A faster alternative to using `pymatching.Matching.decode` and iterating over the shots in Python. @@ -396,6 +441,12 @@ def decode_batch( `stim.DetectorErrorModel` with `enable_correlations=True`. For a description of the correlated matching algorithm, see https://arxiv.org/abs/1310.0863. By default, False + edge_reweights : list[numpy.ndarray], optional + A list of reweighting rules for each shot. `edge_reweights[i]` corresponds to the edge reweights + for shot `i`. `edge_reweights[i]` can be `None` (no reweighting) or a 2D numpy array of shape (M, 3), + where each row specifies an edge to reweight (see `Matching.decode` documentation for more details). + The length of `edge_reweights` must equal the number of shots (the first dimension of `shots`). + By default None. Returns ------- @@ -452,18 +503,20 @@ def decode_batch( shots, bit_packed_predictions=bit_packed_predictions, bit_packed_shots=bit_packed_shots, - enable_correlations=enable_correlations + enable_correlations=enable_correlations, + edge_reweights=edge_reweights, ) if return_weights: return predictions, weights else: return predictions - def decode_to_edges_array(self, - syndrome: Union[np.ndarray, List[bool], List[int]], - *, - enable_correlations: bool = False - ) -> np.ndarray: + def decode_to_edges_array( + self, + syndrome: Union[np.ndarray, List[bool], List[int]], + *, + enable_correlations: bool = False, + ) -> np.ndarray: """ Decode the syndrome `syndrome` using minimum-weight perfect matching, returning the edges in the solution, given as pairs of detector node indices in a numpy array. @@ -523,9 +576,9 @@ def decode_to_edges_array(self, detection_events, enable_correlations=enable_correlations ) - def decode_to_matched_dets_array(self, - syndrome: Union[np.ndarray, List[bool], List[int]] - ) -> np.ndarray: + def decode_to_matched_dets_array( + self, syndrome: Union[np.ndarray, List[bool], List[int]] + ) -> np.ndarray: """ Decode the syndrome `syndrome` using minimum-weight perfect matching, returning the pairs of matched detection events (or detection events matched to the boundary) as a 2D numpy array. 
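The new `edge_reweights` arguments to `decode` and `decode_batch` documented above are easiest to read next to a concrete call. Below is a minimal usage sketch, assuming a hand-built repetition-code check matrix; the node indices, weights and syndromes are illustrative only and are not taken from this diff:

    import numpy as np
    from pymatching import Matching

    # Check matrix: columns 0 and 3 give boundary edges at nodes 0 and 2,
    # columns 1 and 2 give the edges (0, 1) and (1, 2), all with default weight 1.0.
    matching = Matching(np.array([[1, 1, 0, 0],
                                  [0, 1, 1, 0],
                                  [0, 0, 1, 1]]))

    # Single shot: temporarily give edge (1, 2) weight 2.4 and the boundary
    # edge at node 0 (second index -1) weight 0.5, only for this decode call.
    reweights = np.array([[1, 2, 2.4],
                          [0, -1, 0.5]], dtype=np.float64)
    correction = matching.decode(np.array([0, 1, 1]), edge_reweights=reweights)

    # Batch decoding: one entry per shot; None leaves that shot's weights unchanged.
    shots = np.array([[0, 1, 1],
                      [1, 1, 0]], dtype=np.uint8)
    per_shot_reweights = [None, np.array([[1, 2, 2.4]], dtype=np.float64)]
    predictions = matching.decode_batch(shots, edge_reweights=per_shot_reweights)

As in the pybind driver further down in this diff, the reweights are applied before decoding and undone afterwards, so the original weights are restored after every call; note also that a reweight may not flip the sign of an edge, so the positive edges in this sketch must be given positive replacement weights.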
@@ -572,11 +625,13 @@ def decode_to_matched_dets_array(self, [ 4 6]] """ detection_events = self._syndrome_array_to_detection_events(syndrome) - return self._matching_graph.decode_to_matched_detection_events_array(detection_events) + return self._matching_graph.decode_to_matched_detection_events_array( + detection_events + ) - def decode_to_matched_dets_dict(self, - syndrome: Union[np.ndarray, List[bool], List[int]] - ) -> Union[np.ndarray, Tuple[np.ndarray, int]]: + def decode_to_matched_dets_dict( + self, syndrome: Union[np.ndarray, List[bool], List[int]] + ) -> Union[np.ndarray, Tuple[np.ndarray, int]]: """ Decode the syndrome `syndrome` using minimum-weight perfect matching, returning a dictionary giving the detection event that each detection event was matched to (or None if it was matched @@ -618,7 +673,9 @@ def decode_to_matched_dets_dict(self, {0: None, 3: 4, 4: 3} """ detection_events = self._syndrome_array_to_detection_events(syndrome) - return self._matching_graph.decode_to_matched_detection_events_dict(detection_events) + return self._matching_graph.decode_to_matched_detection_events_dict( + detection_events + ) def draw(self) -> None: """Draw the matching graph using matplotlib @@ -631,15 +688,17 @@ def draw(self) -> None: this function. """ # Ignore matplotlib deprecation warnings from networkx.draw_networkx - warnings.filterwarnings("ignore", category=matplotlib.MatplotlibDeprecationWarning) + warnings.filterwarnings( + "ignore", category=matplotlib.MatplotlibDeprecationWarning + ) warnings.filterwarnings("ignore", category=DeprecationWarning) G = self.to_networkx() pos = nx.spectral_layout(G, weight=None) c = "#bfbfbf" - ncolors = ['w' if n[1]['is_boundary'] else c for n in G.nodes(data=True)] + ncolors = ["w" if n[1]["is_boundary"] else c for n in G.nodes(data=True)] nx.draw_networkx_nodes(G, pos=pos, node_color=ncolors, edgecolors=c) nx.draw_networkx_labels(G, pos=pos) - weights = np.array([e[2]['weight'] for e in G.edges(data=True)]) + weights = np.array([e[2]["weight"] for e in G.edges(data=True)]) normalised_weights = 0.2 + 2 * weights / np.max(weights) nx.draw_networkx_edges(G, pos=pos, width=normalised_weights) @@ -651,30 +710,39 @@ def qid_to_str(qid): else: return str(qid) - edge_labels = {(s, t): qid_to_str(d['fault_ids']) for (s, t, d) in G.edges(data=True)} + edge_labels = { + (s, t): qid_to_str(d["fault_ids"]) for (s, t, d) in G.edges(data=True) + } nx.draw_networkx_edge_labels(G, pos=pos, edge_labels=edge_labels) def __repr__(self) -> str: m = self.num_detectors b = len(self.boundary) e = self._matching_graph.get_num_edges() - return "".format( - m, 's' if m != 1 else '', b, 's' if b != 1 else '', - e, 's' if e != 1 else '') + return ( + "".format( + m, + "s" if m != 1 else "", + b, + "s" if b != 1 else "", + e, + "s" if e != 1 else "", + ) + ) def add_edge( - self, - node1: int, - node2: int, - fault_ids: Union[int, Set[int]] = None, - weight: float = 1.0, - error_probability: float = None, - *, - merge_strategy: str = "disallow", - **kwargs + self, + node1: int, + node2: int, + fault_ids: Union[int, Set[int]] = None, + weight: float = 1.0, + error_probability: float = None, + *, + merge_strategy: str = "disallow", + **kwargs, ) -> None: """ Add an edge to the matching graph @@ -743,27 +811,30 @@ def add_edge( [(0, 1, {'fault_ids': {1}, 'weight': 1.0, 'error_probability': -1.0})] """ if fault_ids is not None and "qubit_id" in kwargs: - raise ValueError("Both `fault_ids` and `qubit_id` were provided as arguments. 
Please " - "provide `fault_ids` instead of `qubit_id` as an argument, as use of `qubit_id` has " - "been deprecated.") + raise ValueError( + "Both `fault_ids` and `qubit_id` were provided as arguments. Please " + "provide `fault_ids` instead of `qubit_id` as an argument, as use of `qubit_id` has " + "been deprecated." + ) if fault_ids is None and "qubit_id" in kwargs: fault_ids = kwargs["qubit_id"] if isinstance(fault_ids, (int, np.integer)): fault_ids = set() if fault_ids == -1 else {int(fault_ids)} fault_ids = set() if fault_ids is None else fault_ids error_probability = error_probability if error_probability is not None else -1 - self._matching_graph.add_edge(node1, node2, fault_ids, weight, - error_probability, merge_strategy) + self._matching_graph.add_edge( + node1, node2, fault_ids, weight, error_probability, merge_strategy + ) def add_boundary_edge( - self, - node: int, - fault_ids: Union[int, Set[int]] = None, - weight: float = 1.0, - error_probability: float = None, - *, - merge_strategy: str = "disallow", - **kwargs + self, + node: int, + fault_ids: Union[int, Set[int]] = None, + weight: float = 1.0, + error_probability: float = None, + *, + merge_strategy: str = "disallow", + **kwargs, ) -> None: """ Add an edge connecting `node` to the boundary @@ -831,8 +902,9 @@ def add_boundary_edge( fault_ids = set() if fault_ids == -1 else {int(fault_ids)} fault_ids = set() if fault_ids is None else fault_ids error_probability = error_probability if error_probability is not None else -1 - self._matching_graph.add_boundary_edge(node, fault_ids, weight, - error_probability, merge_strategy) + self._matching_graph.add_boundary_edge( + node, fault_ids, weight, error_probability, merge_strategy + ) def has_edge(self, node1: int, node2: int) -> bool: """ @@ -872,7 +944,9 @@ def has_boundary_edge(self, node: int) -> bool: """ return self._matching_graph.has_boundary_edge(node) - def get_edge_data(self, node1: int, node2: int) -> Dict[str, Union[Set[int], float]]: + def get_edge_data( + self, node1: int, node2: int + ) -> Dict[str, Union[Set[int], float]]: """ Returns the edge data associated with the edge `(node1, node2)`. 
@@ -927,18 +1001,18 @@ def edges(self) -> List[Tuple[int, Optional[int], Dict]]: @staticmethod def from_check_matrix( - check_matrix: Union[csc_matrix, spmatrix, np.ndarray, List[List[int]]], - weights: Union[float, np.ndarray, List[float]] = None, - error_probabilities: Union[float, np.ndarray, List[float]] = None, - repetitions: int = None, - timelike_weights: Union[float, np.ndarray, List[float]] = None, - measurement_error_probabilities: Union[float, np.ndarray, List[float]] = None, - *, - faults_matrix: Union[csc_matrix, spmatrix, np.ndarray, List[List[int]]] = None, - merge_strategy: str = "smallest-weight", - use_virtual_boundary_node: bool = False, - **kwargs - ) -> 'pymatching.Matching': + check_matrix: Union[csc_matrix, spmatrix, np.ndarray, List[List[int]]], + weights: Union[float, np.ndarray, List[float]] = None, + error_probabilities: Union[float, np.ndarray, List[float]] = None, + repetitions: int = None, + timelike_weights: Union[float, np.ndarray, List[float]] = None, + measurement_error_probabilities: Union[float, np.ndarray, List[float]] = None, + *, + faults_matrix: Union[csc_matrix, spmatrix, np.ndarray, List[List[int]]] = None, + merge_strategy: str = "smallest-weight", + use_virtual_boundary_node: bool = False, + **kwargs, + ) -> "pymatching.Matching": r""" Load a matching graph from a check matrix @@ -1047,23 +1121,24 @@ def from_check_matrix( faults_matrix=faults_matrix, merge_strategy=merge_strategy, use_virtual_boundary_node=use_virtual_boundary_node, - **kwargs + **kwargs, ) return m - def load_from_check_matrix(self, - check_matrix: Union[csc_matrix, spmatrix, np.ndarray, List[List[int]]] = None, - weights: Union[float, np.ndarray, List[float]] = None, - error_probabilities: Union[float, np.ndarray, List[float]] = None, - repetitions: int = None, - timelike_weights: Union[float, np.ndarray, List[float]] = None, - measurement_error_probabilities: Union[float, np.ndarray, List[float]] = None, - *, - faults_matrix: Union[csc_matrix, spmatrix, np.ndarray, List[List[int]]] = None, - merge_strategy: str = "smallest-weight", - use_virtual_boundary_node: bool = False, - **kwargs - ) -> None: + def load_from_check_matrix( + self, + check_matrix: Union[csc_matrix, spmatrix, np.ndarray, List[List[int]]] = None, + weights: Union[float, np.ndarray, List[float]] = None, + error_probabilities: Union[float, np.ndarray, List[float]] = None, + repetitions: int = None, + timelike_weights: Union[float, np.ndarray, List[float]] = None, + measurement_error_probabilities: Union[float, np.ndarray, List[float]] = None, + *, + faults_matrix: Union[csc_matrix, spmatrix, np.ndarray, List[List[int]]] = None, + merge_strategy: str = "smallest-weight", + use_virtual_boundary_node: bool = False, + **kwargs, + ) -> None: """ Load a matching graph from a check matrix @@ -1169,13 +1244,17 @@ def load_from_check_matrix(self, try: check_matrix = csc_matrix(check_matrix) except TypeError: - raise TypeError("`check_matrix` must be convertible to a `scipy.sparse.csc_matrix`") + raise TypeError( + "`check_matrix` must be convertible to a `scipy.sparse.csc_matrix`" + ) if faults_matrix is not None: try: faults_matrix = csc_matrix(faults_matrix) except TypeError: - raise TypeError("`faults` must be convertible to `scipy.sparse.csc_matrix`") + raise TypeError( + "`faults` must be convertible to `scipy.sparse.csc_matrix`" + ) num_edges = check_matrix.shape[1] @@ -1183,9 +1262,11 @@ def load_from_check_matrix(self, if weights is None and slw is not None: weights = slw elif weights is not None and slw is not 
None: - raise ValueError("Both `weights` and `spacelike_weights` were provided as arguments, but these " - "two arguments are equivalent. Please provide only `weights` as an argument, as " - "the `spacelike_weights` argument has been deprecated.") + raise ValueError( + "Both `weights` and `spacelike_weights` were provided as arguments, but these " + "two arguments are equivalent. Please provide only `weights` as an argument, as " + "the `spacelike_weights` argument has been deprecated." + ) weights = 1.0 if weights is None else weights if isinstance(weights, (int, float, np.integer, np.floating)): @@ -1204,43 +1285,59 @@ def load_from_check_matrix(self, if repetitions > 1: timelike_weights = 1.0 if timelike_weights is None else timelike_weights if isinstance(timelike_weights, (int, float, np.integer, np.floating)): - timelike_weights = np.ones(check_matrix.shape[0], dtype=float) * timelike_weights + timelike_weights = ( + np.ones(check_matrix.shape[0], dtype=float) * timelike_weights + ) elif isinstance(timelike_weights, (np.ndarray, list)): timelike_weights = np.array(timelike_weights, dtype=float) else: - raise ValueError("timelike_weights should be a float or a 1d numpy array") + raise ValueError( + "timelike_weights should be a float or a 1d numpy array" + ) mep = kwargs.get("measurement_error_probability") if measurement_error_probabilities is not None and mep is not None: - raise ValueError("Both `measurement_error_probabilities` and `measurement_error_probability` " - "were provided as arguments. Please " - "provide `measurement_error_probabilities` instead of `measurement_error_probability` " - "as an argument, as use of `measurement_error_probability` has been deprecated.") + raise ValueError( + "Both `measurement_error_probabilities` and `measurement_error_probability` " + "were provided as arguments. Please " + "provide `measurement_error_probabilities` instead of `measurement_error_probability` " + "as an argument, as use of `measurement_error_probability` has been deprecated." 
+ ) if measurement_error_probabilities is None and mep is not None: measurement_error_probabilities = mep - p_meas = measurement_error_probabilities if measurement_error_probabilities is not None else -1 + p_meas = ( + measurement_error_probabilities + if measurement_error_probabilities is not None + else -1 + ) if isinstance(p_meas, (int, float, np.integer, np.floating)): p_meas = np.ones(check_matrix.shape[0], dtype=float) * p_meas elif isinstance(p_meas, (np.ndarray, list)): p_meas = np.array(p_meas, dtype=float) else: - raise ValueError("measurement_error_probabilities should be a float or 1d numpy array") + raise ValueError( + "measurement_error_probabilities should be a float or 1d numpy array" + ) else: timelike_weights = None p_meas = None - self._matching_graph = _cpp_pm.sparse_column_check_matrix_to_matching_graph(check_matrix, weights, - error_probabilities, - merge_strategy, - use_virtual_boundary_node, - repetitions, - timelike_weights, p_meas, - faults_matrix) + self._matching_graph = _cpp_pm.sparse_column_check_matrix_to_matching_graph( + check_matrix, + weights, + error_probabilities, + merge_strategy, + use_virtual_boundary_node, + repetitions, + timelike_weights, + p_meas, + faults_matrix, + ) @staticmethod def from_detector_error_model( - model: 'stim.DetectorErrorModel', *, enable_correlations: bool = False - ) -> 'pymatching.Matching': + model: "stim.DetectorErrorModel", *, enable_correlations: bool = False + ) -> "pymatching.Matching": """ Constructs a `pymatching.Matching` object by loading from a `stim.DetectorErrorModel`. @@ -1292,15 +1389,15 @@ def from_detector_error_model( """ m = Matching() - m._load_from_detector_error_model(model, enable_correlations=enable_correlations) + m._load_from_detector_error_model( + model, enable_correlations=enable_correlations + ) return m @staticmethod def from_detector_error_model_file( - dem_path: Union[str, Path], - *, - enable_correlations: bool = False - ) -> 'pymatching.Matching': + dem_path: Union[str, Path], *, enable_correlations: bool = False + ) -> "pymatching.Matching": """ Construct a `pymatching.Matching` by loading from a stim DetectorErrorModel file path. @@ -1324,13 +1421,14 @@ def from_detector_error_model_file( dem_path = str(dem_path) m = Matching() m._matching_graph = _cpp_pm.detector_error_model_file_to_matching_graph( - dem_path, - enable_correlations=enable_correlations + dem_path, enable_correlations=enable_correlations ) return m @staticmethod - def from_stim_circuit(circuit: 'stim.Circuit', *, enable_correlations=False) -> 'pymatching.Matching': + def from_stim_circuit( + circuit: "stim.Circuit", *, enable_correlations=False + ) -> "pymatching.Matching": """ Constructs a `pymatching.Matching` object by loading from a `stim.Circuit` @@ -1374,20 +1472,20 @@ def from_stim_circuit(circuit: 'stim.Circuit', *, enable_correlations=False) -> "To install stim using pip, run `pip install stim`." ) if not isinstance(circuit, stim.Circuit): - raise TypeError(f"`circuit` must be a `stim.Circuit`. Instead, got {type(circuit)}") + raise TypeError( + f"`circuit` must be a `stim.Circuit`. 
Instead, got {type(circuit)}" + ) m = Matching() m._matching_graph = _cpp_pm.detector_error_model_to_matching_graph( str(circuit.detector_error_model(decompose_errors=True)), - enable_correlations=enable_correlations + enable_correlations=enable_correlations, ) return m @staticmethod def from_stim_circuit_file( - stim_circuit_path: Union[str, Path], - *, - enable_correlations: bool = False - ) -> 'pymatching.Matching': + stim_circuit_path: Union[str, Path], *, enable_correlations: bool = False + ) -> "pymatching.Matching": """ Construct a `pymatching.Matching` by loading from a stim circuit file path. @@ -1412,12 +1510,13 @@ def from_stim_circuit_file( stim_circuit_path = str(stim_circuit_path) m = Matching() m._matching_graph = _cpp_pm.stim_circuit_file_to_matching_graph( - stim_circuit_path, - enable_correlations=enable_correlations + stim_circuit_path, enable_correlations=enable_correlations ) return m - def _load_from_detector_error_model(self, model: 'stim.DetectorErrorModel', *, enable_correlations: bool = False) -> None: + def _load_from_detector_error_model( + self, model: "stim.DetectorErrorModel", *, enable_correlations: bool = False + ) -> None: try: import stim except ImportError: # pragma no cover @@ -1427,13 +1526,17 @@ def _load_from_detector_error_model(self, model: 'stim.DetectorErrorModel', *, e "To install stim using pip, run `pip install stim`." ) if not isinstance(model, stim.DetectorErrorModel): - raise TypeError(f"'model' must be `stim.DetectorErrorModel`. Instead, got: {type(model)}") + raise TypeError( + f"'model' must be `stim.DetectorErrorModel`. Instead, got: {type(model)}" + ) self._matching_graph = _cpp_pm.detector_error_model_to_matching_graph( str(model), enable_correlations=enable_correlations ) @staticmethod - def from_networkx(graph: nx.Graph, *, min_num_fault_ids: int = None) -> 'pymatching.Matching': + def from_networkx( + graph: nx.Graph, *, min_num_fault_ids: int = None + ) -> "pymatching.Matching": r""" Returns a new `pymatching.Matching` object from a NetworkX graph @@ -1480,12 +1583,12 @@ def from_networkx(graph: nx.Graph, *, min_num_fault_ids: int = None) -> 'pymatch """ m = Matching() - m.load_from_networkx( - graph=graph, min_num_fault_ids=min_num_fault_ids - ) + m.load_from_networkx(graph=graph, min_num_fault_ids=min_num_fault_ids) return m - def load_from_networkx(self, graph: nx.Graph, *, min_num_fault_ids: int = None) -> None: + def load_from_networkx( + self, graph: nx.Graph, *, min_num_fault_ids: int = None + ) -> None: r""" Load a matching graph from a NetworkX graph into a `pymatching.Matching` object @@ -1534,19 +1637,22 @@ def load_from_networkx(self, graph: nx.Graph, *, min_num_fault_ids: int = None) if not isinstance(graph, nx.Graph): raise TypeError("G must be a NetworkX graph") - boundary = {i for i, attr in graph.nodes(data=True) - if attr.get("is_boundary", False)} + boundary = { + i for i, attr in graph.nodes(data=True) if attr.get("is_boundary", False) + } num_nodes = graph.number_of_nodes() all_fault_ids = set() num_fault_ids = 0 if min_num_fault_ids is None else min_num_fault_ids g = _cpp_pm.MatchingGraph(num_nodes, num_fault_ids) g.set_boundary(boundary) - for (u, v, attr) in graph.edges(data=True): + for u, v, attr in graph.edges(data=True): u, v = int(u), int(v) if "fault_ids" in attr and "qubit_id" in attr: - raise ValueError("Both `fault_ids` and `qubit_id` were provided as edge attributes, however use " - "of `qubit_id` has been deprecated in favour of `fault_ids`. 
Please only supply " - "`fault_ids` as an edge attribute.") + raise ValueError( + "Both `fault_ids` and `qubit_id` were provided as edge attributes, however use " + "of `qubit_id` has been deprecated in favour of `fault_ids`. Please only supply " + "`fault_ids` as an edge attribute." + ) if "fault_ids" not in attr and "qubit_id" in attr: fault_ids = attr["qubit_id"] # Still accept qubit_id as well for now else: @@ -1557,28 +1663,41 @@ def load_from_networkx(self, graph: nx.Graph, *, min_num_fault_ids: int = None) try: fault_ids = set(fault_ids) if not all(isinstance(q, (int, np.integer)) for q in fault_ids): - raise TypeError("fault_ids must be a set of ints, not {}".format(fault_ids)) + raise TypeError( + "fault_ids must be a set of ints, not {}".format(fault_ids) + ) except TypeError: raise TypeError( "fault_ids property must be an int or a set of int" - " (or convertible to a set), not {}".format(fault_ids)) + " (or convertible to a set), not {}".format(fault_ids) + ) all_fault_ids = all_fault_ids | fault_ids weight = attr.get("weight", 1) # Default weight is 1 if not provided e_prob = attr.get("error_probability", -1) # Note: NetworkX graphs do not support parallel edges (merge strategy is redundant) - g.add_edge(u, v, fault_ids, weight, e_prob, merge_strategy="smallest-weight") + g.add_edge( + u, v, fault_ids, weight, e_prob, merge_strategy="smallest-weight" + ) self._matching_graph = g - def load_from_retworkx(self, graph: "rx.PyGraph", *, min_num_fault_ids: int = None) -> None: + def load_from_retworkx( + self, graph: "rx.PyGraph", *, min_num_fault_ids: int = None + ) -> None: r""" Load a matching graph from a retworkX graph. This method is deprecated since the retworkx package has been renamed to rustworkx. Please use ``pymatching.Matching.load_from_rustworkx`` instead. """ - warnings.warn("`pymatching.Matching.load_from_retworkx` is now deprecated since the `retworkx` library has been " - "renamed to `rustworkx`. Please use `pymatching.Matching.load_from_rustworkx` instead.", DeprecationWarning, stacklevel=2) + warnings.warn( + "`pymatching.Matching.load_from_retworkx` is now deprecated since the `retworkx` library has been " + "renamed to `rustworkx`. 
Please use `pymatching.Matching.load_from_rustworkx` instead.", + DeprecationWarning, + stacklevel=2, + ) self.load_from_rustworkx(graph=graph, min_num_fault_ids=min_num_fault_ids) - def load_from_rustworkx(self, graph: "rx.PyGraph", *, min_num_fault_ids: int = None) -> None: + def load_from_rustworkx( + self, graph: "rx.PyGraph", *, min_num_fault_ids: int = None + ) -> None: r""" Load a matching graph from a rustworkX graph @@ -1623,20 +1742,26 @@ def load_from_rustworkx(self, graph: "rx.PyGraph", *, min_num_fault_ids: int = N try: import rustworkx as rx except ImportError: # pragma no cover - raise ImportError("rustworkx must be installed to use Matching.load_from_rustworkx") + raise ImportError( + "rustworkx must be installed to use Matching.load_from_rustworkx" + ) if not isinstance(graph, rx.PyGraph): raise TypeError("G must be a rustworkx graph") - boundary = {i for i in graph.node_indices() if graph[i].get("is_boundary", False)} + boundary = { + i for i in graph.node_indices() if graph[i].get("is_boundary", False) + } num_nodes = len(graph) num_fault_ids = 0 if min_num_fault_ids is None else min_num_fault_ids g = _cpp_pm.MatchingGraph(num_nodes, num_fault_ids) g.set_boundary(boundary) - for (u, v, attr) in graph.weighted_edge_list(): + for u, v, attr in graph.weighted_edge_list(): u, v = int(u), int(v) if "fault_ids" in attr and "qubit_id" in attr: - raise ValueError("Both `fault_ids` and `qubit_id` were provided as edge attributes, however use " - "of `qubit_id` has been deprecated in favour of `fault_ids`. Please only supply " - "`fault_ids` as an edge attribute.") + raise ValueError( + "Both `fault_ids` and `qubit_id` were provided as edge attributes, however use " + "of `qubit_id` has been deprecated in favour of `fault_ids`. Please only supply " + "`fault_ids` as an edge attribute." + ) if "fault_ids" not in attr and "qubit_id" in attr: fault_ids = attr["qubit_id"] # Still accept qubit_id as well for now else: @@ -1647,15 +1772,20 @@ def load_from_rustworkx(self, graph: "rx.PyGraph", *, min_num_fault_ids: int = N try: fault_ids = set(fault_ids) if not all(isinstance(q, (int, np.integer)) for q in fault_ids): - raise TypeError("fault_ids must be a set of ints, not {}".format(fault_ids)) + raise TypeError( + "fault_ids must be a set of ints, not {}".format(fault_ids) + ) except TypeError: raise TypeError( "fault_ids property must be an int or a set of int" - " (or convertible to a set), not {}".format(fault_ids)) + " (or convertible to a set), not {}".format(fault_ids) + ) weight = attr.get("weight", 1) # Default weight is 1 if not provided e_prob = attr.get("error_probability", -1) # Note: rustworkx graphs do not support parallel edges (merge strategy is redundant) - g.add_edge(u, v, fault_ids, weight, e_prob, merge_strategy="smallest-weight") + g.add_edge( + u, v, fault_ids, weight, e_prob, merge_strategy="smallest-weight" + ) self._matching_graph = g def to_networkx(self) -> nx.Graph: @@ -1681,9 +1811,9 @@ def to_networkx(self) -> nx.Graph: boundary = self.boundary for i in graph.nodes: is_boundary = i in boundary - graph.nodes[i]['is_boundary'] = is_boundary + graph.nodes[i]["is_boundary"] = is_boundary if has_virtual_boundary: - graph.nodes[num_nodes]['is_boundary'] = True + graph.nodes[num_nodes]["is_boundary"] = True return graph def to_retworkx(self) -> "rx.PyGraph": @@ -1692,8 +1822,12 @@ def to_retworkx(self) -> "rx.PyGraph": ``retworkx.PyGraph``. Note that in the future, only the `rustworkx` package name will be supported, see: https://pypi.org/project/retworkx/. 
""" - warnings.warn("`pymatching.Matching.to_retworkx` is now deprecated since the `retworkx` library has been " - "renamed to `rustworkx`. Please use `pymatching.Matching.to_rustworkx` instead.", DeprecationWarning, stacklevel=2) + warnings.warn( + "`pymatching.Matching.to_retworkx` is now deprecated since the `retworkx` library has been " + "renamed to `rustworkx`. Please use `pymatching.Matching.to_rustworkx` instead.", + DeprecationWarning, + stacklevel=2, + ) return self.to_rustworkx() def to_rustworkx(self) -> "rx.PyGraph": @@ -1711,7 +1845,9 @@ def to_rustworkx(self) -> "rx.PyGraph": try: import rustworkx as rx except ImportError: # pragma no cover - raise ImportError("rustworkx must be installed to use Matching.to_rustworkx.") + raise ImportError( + "rustworkx must be installed to use Matching.to_rustworkx." + ) graph = rx.PyGraph(multigraph=False) num_nodes = self.num_nodes @@ -1728,7 +1864,7 @@ def to_rustworkx(self) -> "rx.PyGraph": boundary = self.boundary for i in graph.node_indices(): is_boundary = i in boundary - graph[i]['is_boundary'] = is_boundary + graph[i]["is_boundary"] = is_boundary if has_virtual_boundary: graph[num_nodes]["is_boundary"] = True return graph diff --git a/src/pymatching/sparse_blossom/driver/mwpm_decoding.cc b/src/pymatching/sparse_blossom/driver/mwpm_decoding.cc index 3f08c378b..fbd709db6 100644 --- a/src/pymatching/sparse_blossom/driver/mwpm_decoding.cc +++ b/src/pymatching/sparse_blossom/driver/mwpm_decoding.cc @@ -183,7 +183,7 @@ pm::MatchingResult pm::decode_detection_events_for_up_to_64_observables( res += shatter_blossoms_for_all_detection_events_and_extract_obs_mask_and_weight( mwpm, mwpm.flooder.negative_weight_detection_events); res.obs_mask ^= mwpm.flooder.negative_weight_obs_mask; - res.weight += mwpm.flooder.negative_weight_sum; + res.weight += mwpm.flooder.graph.negative_weight_sum; if (edge_correlations) { mwpm.flooder.graph.undo_reweights(); @@ -225,7 +225,7 @@ void pm::decode_detection_events( for (auto& obs : mwpm.flooder.negative_weight_observables) *(obs_begin_ptr + obs) ^= 1; // Add negative weight sum to blossom solution weight - weight += mwpm.flooder.negative_weight_sum; + weight += mwpm.flooder.graph.negative_weight_sum; } else { pm::MatchingResult bit_packed_res = @@ -238,7 +238,7 @@ void pm::decode_detection_events( // Translate observable mask into bit vector fill_bit_vector_from_obs_mask(bit_packed_res.obs_mask, obs_begin_ptr, num_observables); // Add negative weight sum to blossom solution weight - weight = bit_packed_res.weight + mwpm.flooder.negative_weight_sum; + weight = bit_packed_res.weight + mwpm.flooder.graph.negative_weight_sum; } if (edge_correlations) { @@ -248,7 +248,7 @@ void pm::decode_detection_events( } void pm::decode_detection_events_to_match_edges(pm::Mwpm& mwpm, const std::vector& detection_events) { - if (mwpm.flooder.negative_weight_sum != 0) + if (mwpm.flooder.graph.negative_weight_sum != 0) throw std::invalid_argument( "Decoding to matched detection events not supported for graphs containing edges with negative weights."); process_timeline_until_completion(mwpm, detection_events); diff --git a/src/pymatching/sparse_blossom/driver/mwpm_decoding.perf.cc b/src/pymatching/sparse_blossom/driver/mwpm_decoding.perf.cc index 0daa298a9..26f3bea4e 100644 --- a/src/pymatching/sparse_blossom/driver/mwpm_decoding.perf.cc +++ b/src/pymatching/sparse_blossom/driver/mwpm_decoding.perf.cc @@ -72,7 +72,7 @@ BENCHMARK(Decode_surface_r5_d5_p1000) { } } }) - .goal_micros(290) + .goal_micros(350) .show_rate("dets", 
(double)num_dets) .show_rate("layers", (double)rounds * (double)shots.size()) .show_rate("shots", (double)shots.size()); @@ -105,7 +105,7 @@ BENCHMARK(Decode_surface_r11_d11_p100) { } } }) - .goal_millis(10) + .goal_millis(17) .show_rate("dets", (double)num_dets) .show_rate("layers", (double)rounds * (double)shots.size()) .show_rate("shots", (double)shots.size()); @@ -114,6 +114,52 @@ BENCHMARK(Decode_surface_r11_d11_p100) { } } +BENCHMARK(Decode_surface_r11_d11_p100_reweight) { + size_t rounds = 11; + auto data = generate_data(11, rounds, 0.01, 128); + const auto &dem = data.first; + const auto &shots = data.second; + + size_t num_buckets = pm::NUM_DISTINCT_WEIGHTS; + auto mwpm = pm::detector_error_model_to_mwpm(dem, num_buckets); + + size_t num_dets = 0; + for (const auto &shot : shots) { + num_dets += shot.hits.size(); + } + + // Create a reweight vector (reweight edge 0-1 to weight 5.0) + // We need to find valid node indices. Just pick 0 and 1 if they exist. + std::vector> reweights_r11_local; + if (mwpm.flooder.graph.nodes.size() > 1) { + // Find an edge + size_t u = 0; + int64_t v = -1; + if (mwpm.flooder.graph.nodes[0].neighbors[0]) { + v = mwpm.flooder.graph.nodes[0].neighbors[0] - &mwpm.flooder.graph.nodes[0]; + } + reweights_r11_local.emplace_back(u, v, 5.0); + } + + size_t num_mistakes = 0; + benchmark_go([&]() { + for (const auto &shot : shots) { + mwpm.flooder.graph.apply_temp_reweights(reweights_r11_local); + auto res = + pm::decode_detection_events_for_up_to_64_observables(mwpm, shot.hits, /*enable_correlations=*/false); + mwpm.flooder.graph.undo_reweights(); + if (shot.obs_mask_as_u64() != res.obs_mask) { + num_mistakes++; + } + } + }) + .goal_millis(17) + .show_rate("dets", (double)num_dets) + .show_rate("layers", (double)rounds * (double)shots.size()) + .show_rate("shots", (double)shots.size()); + // Mistakes are expected since we are messing up the weights +} + BENCHMARK(Decode_surface_r11_d11_p1000) { size_t rounds = 11; auto data = generate_data(11, rounds, 0.001, 512); @@ -138,7 +184,7 @@ BENCHMARK(Decode_surface_r11_d11_p1000) { } } }) - .goal_millis(1.5) + .goal_millis(2.4) .show_rate("dets", (double)num_dets) .show_rate("layers", (double)rounds * (double)shots.size()) .show_rate("shots", (double)shots.size()); @@ -171,7 +217,7 @@ BENCHMARK(Decode_surface_r11_d11_p10000) { } } }) - .goal_micros(83) + .goal_micros(78) .show_rate("dets", (double)num_dets) .show_rate("layers", (double)rounds * (double)shots.size()) .show_rate("shots", (double)shots.size()); @@ -204,7 +250,7 @@ BENCHMARK(Decode_surface_r11_d11_p100000) { } } }) - .goal_micros(33) + .goal_micros(35) .show_rate("dets", (double)num_dets) .show_rate("layers", (double)rounds * (double)shots.size()) .show_rate("shots", (double)shots.size()); @@ -237,7 +283,7 @@ BENCHMARK(Decode_surface_r21_d21_p100) { } } }) - .goal_millis(7.5) + .goal_millis(13) .show_rate("dets", (double)num_dets) .show_rate("layers", (double)rounds * (double)shots.size()) .show_rate("shots", (double)shots.size()); @@ -273,7 +319,7 @@ BENCHMARK(Decode_surface_r21_d21_p100_with_dijkstra) { res.reset(); } }) - .goal_millis(7.8) + .goal_millis(16) .show_rate("dets", (double)num_dets) .show_rate("layers", (double)rounds * (double)shots.size()) .show_rate("shots", (double)shots.size()); @@ -303,7 +349,7 @@ BENCHMARK(Decode_surface_r21_d21_p100_to_edges) { pm::decode_detection_events_to_edges(mwpm, shot.hits, edges); } }) - .goal_millis(8.2) + .goal_millis(16) .show_rate("dets", (double)num_dets) .show_rate("layers", (double)rounds * 
(double)shots.size()) .show_rate("shots", (double)shots.size()); @@ -330,7 +376,7 @@ BENCHMARK(Decode_surface_r21_d21_p100_to_edges_with_correlations) { pm::decode_detection_events_to_edges_with_edge_correlations(mwpm, shot.hits, edges); } }) - .goal_millis(8.2) + .goal_millis(34) .show_rate("dets", (double)num_dets) .show_rate("layers", (double)rounds * (double)shots.size()) .show_rate("shots", (double)shots.size()); @@ -360,7 +406,7 @@ BENCHMARK(Decode_surface_r21_d21_p1000) { } } }) - .goal_millis(6.3) + .goal_millis(9.6) .show_rate("dets", (double)num_dets) .show_rate("layers", (double)rounds * (double)shots.size()) .show_rate("shots", (double)shots.size()); @@ -369,6 +415,48 @@ BENCHMARK(Decode_surface_r21_d21_p1000) { } } +BENCHMARK(Decode_surface_r21_d21_p1000_reweight) { + size_t rounds = 21; + auto data = generate_data(21, rounds, 0.001, 256); + const auto &dem = data.first; + const auto &shots = data.second; + + size_t num_buckets = pm::NUM_DISTINCT_WEIGHTS; + auto mwpm = pm::detector_error_model_to_mwpm(dem, num_buckets); + + size_t num_dets = 0; + for (const auto &shot : shots) { + num_dets += shot.hits.size(); + } + + std::vector> reweights_r21_local; + if (mwpm.flooder.graph.nodes.size() > 1) { + size_t u = 0; + int64_t v = -1; + if (mwpm.flooder.graph.nodes[0].neighbors[0]) { + v = mwpm.flooder.graph.nodes[0].neighbors[0] - &mwpm.flooder.graph.nodes[0]; + } + reweights_r21_local.emplace_back(u, v, 5.0); + } + + size_t num_mistakes = 0; + benchmark_go([&]() { + for (const auto &shot : shots) { + mwpm.flooder.graph.apply_temp_reweights(reweights_r21_local); + auto res = + pm::decode_detection_events_for_up_to_64_observables(mwpm, shot.hits, /*enable_correlations=*/false); + mwpm.flooder.graph.undo_reweights(); + if (shot.obs_mask_as_u64() != res.obs_mask) { + num_mistakes++; + } + } + }) + .goal_millis(9.6) + .show_rate("dets", (double)num_dets) + .show_rate("layers", (double)rounds * (double)shots.size()) + .show_rate("shots", (double)shots.size()); +} + BENCHMARK(Decode_surface_r21_d21_p1000_with_dijkstra) { size_t rounds = 21; auto data = generate_data(21, rounds, 0.001, 256); @@ -397,7 +485,7 @@ BENCHMARK(Decode_surface_r21_d21_p1000_with_dijkstra) { res.reset(); } }) - .goal_millis(7.7) + .goal_millis(16) .show_rate("dets", (double)num_dets) .show_rate("layers", (double)rounds * (double)shots.size()) .show_rate("shots", (double)shots.size()); @@ -427,7 +515,7 @@ BENCHMARK(Decode_surface_r21_d21_p1000_to_edges) { pm::decode_detection_events_to_edges(mwpm, shot.hits, edges); } }) - .goal_millis(8.4) + .goal_millis(17) .show_rate("dets", (double)num_dets) .show_rate("layers", (double)rounds * (double)shots.size()) .show_rate("shots", (double)shots.size()); @@ -454,7 +542,55 @@ BENCHMARK(Decode_surface_r21_d21_p1000_to_edges_with_correlations) { pm::decode_detection_events_to_edges_with_edge_correlations(mwpm, shot.hits, edges); } }) - .goal_millis(8.4) + .goal_millis(48) + .show_rate("dets", (double)num_dets) + .show_rate("layers", (double)rounds * (double)shots.size()) + .show_rate("shots", (double)shots.size()); +} + +BENCHMARK(Decode_surface_r21_d21_p1000_reweight_with_correlations) { + size_t rounds = 21; + auto data = generate_data(21, rounds, 0.001, 256, true); + auto &dem = data.first; + const auto &shots = data.second; + + size_t num_buckets = pm::NUM_DISTINCT_WEIGHTS; + auto mwpm = pm::detector_error_model_to_mwpm(dem, num_buckets, /*ensure_search_flooder_included=*/true, /*enable_correlations=*/true); + + size_t num_dets = 0; + for (const auto &shot : shots) { + 
num_dets += shot.hits.size(); + } + + std::vector> reweights_corr_local; + if (mwpm.flooder.graph.nodes.size() > 1) { + size_t u = 0; + int64_t v = -1; + if (mwpm.flooder.graph.nodes[0].neighbors[0]) { + v = mwpm.flooder.graph.nodes[0].neighbors[0] - &mwpm.flooder.graph.nodes[0]; + } + reweights_corr_local.emplace_back(u, v, 5.0); + } + + size_t num_mistakes = 0; + pm::ExtendedMatchingResult res(mwpm.flooder.graph.num_observables); + benchmark_go([&]() { + for (const auto &shot : shots) { + mwpm.flooder.graph.apply_temp_reweights(reweights_corr_local); + mwpm.search_flooder.graph.apply_temp_reweights(reweights_corr_local, mwpm.flooder.graph.normalising_constant); + + pm::decode_detection_events(mwpm, shot.hits, res.obs_crossed.data(), res.weight, /*enable_correlations=*/true); + + mwpm.flooder.graph.undo_reweights(); + mwpm.search_flooder.graph.undo_reweights(); + + if (shot.obs_mask_as_u64() != res.obs_crossed[0]) { + num_mistakes++; + } + res.reset(); + } + }) + .goal_millis(44) .show_rate("dets", (double)num_dets) .show_rate("layers", (double)rounds * (double)shots.size()) .show_rate("shots", (double)shots.size()); @@ -484,7 +620,7 @@ BENCHMARK(Decode_surface_r21_d21_p10000) { } } }) - .goal_millis(0.980) + .goal_millis(1.7) .show_rate("dets", (double)num_dets) .show_rate("layers", (double)rounds * (double)shots.size()) .show_rate("shots", (double)shots.size()); @@ -521,7 +657,7 @@ BENCHMARK(Decode_surface_r21_d21_p10000_with_dijkstra) { res.reset(); } }) - .goal_millis(1.3) + .goal_millis(2.9) .show_rate("dets", (double)num_dets) .show_rate("layers", (double)rounds * (double)shots.size()) .show_rate("shots", (double)shots.size()); @@ -551,7 +687,7 @@ BENCHMARK(Decode_surface_r21_d21_p10000_to_edges) { pm::decode_detection_events_to_edges(mwpm, shot.hits, edges); } }) - .goal_millis(1.4) + .goal_millis(3.1) .show_rate("dets", (double)num_dets) .show_rate("layers", (double)rounds * (double)shots.size()) .show_rate("shots", (double)shots.size()); @@ -578,7 +714,7 @@ BENCHMARK(Decode_surface_r21_d21_p10000_to_edges_with_correlations) { pm::decode_detection_events_to_edges_with_edge_correlations(mwpm, shot.hits, edges); } }) - .goal_millis(1.4) + .goal_millis(8.0) .show_rate("dets", (double)num_dets) .show_rate("layers", (double)rounds * (double)shots.size()) .show_rate("shots", (double)shots.size()); @@ -608,7 +744,7 @@ BENCHMARK(Decode_surface_r21_d21_p100000) { } } }) - .goal_micros(94) + .goal_micros(110) .show_rate("dets", (double)num_dets) .show_rate("layers", (double)rounds * (double)shots.size()) .show_rate("shots", (double)shots.size()); @@ -645,7 +781,7 @@ BENCHMARK(Decode_surface_r21_d21_p100000_with_dijkstra) { res.reset(); } }) - .goal_micros(130) + .goal_micros(230) .show_rate("dets", (double)num_dets) .show_rate("layers", (double)rounds * (double)shots.size()) .show_rate("shots", (double)shots.size()); @@ -675,7 +811,7 @@ BENCHMARK(Decode_surface_r21_d21_p100000_to_edges) { pm::decode_detection_events_to_edges(mwpm, shot.hits, edges); } }) - .goal_micros(130) + .goal_micros(250) .show_rate("dets", (double)num_dets) .show_rate("layers", (double)rounds * (double)shots.size()) .show_rate("shots", (double)shots.size()); @@ -702,7 +838,7 @@ BENCHMARK(Decode_surface_r21_d21_p100000_to_edges_with_correlations) { pm::decode_detection_events_to_edges_with_edge_correlations(mwpm, shot.hits, edges); } }) - .goal_micros(130) + .goal_micros(590) .show_rate("dets", (double)num_dets) .show_rate("layers", (double)rounds * (double)shots.size()) .show_rate("shots", (double)shots.size()); diff 
--git a/src/pymatching/sparse_blossom/driver/user_graph.cc b/src/pymatching/sparse_blossom/driver/user_graph.cc index 313343978..aba8c7a5c 100644 --- a/src/pymatching/sparse_blossom/driver/user_graph.cc +++ b/src/pymatching/sparse_blossom/driver/user_graph.cc @@ -337,7 +337,7 @@ pm::Mwpm pm::UserGraph::to_mwpm(pm::weight_int num_distinct_weights, bool ensure } pm::Mwpm& pm::UserGraph::get_mwpm_with_search_graph() { - if (!_mwpm_needs_updating && _mwpm.flooder.graph.nodes.size() == _mwpm.search_flooder.graph.nodes.size()) { + if (!_mwpm_needs_updating && _mwpm.search_flooder_available()) { return _mwpm; } else { _mwpm = to_mwpm(pm::NUM_DISTINCT_WEIGHTS, true); diff --git a/src/pymatching/sparse_blossom/driver/user_graph.perf.cc b/src/pymatching/sparse_blossom/driver/user_graph.perf.cc index 8678afc66..48ad39e64 100644 --- a/src/pymatching/sparse_blossom/driver/user_graph.perf.cc +++ b/src/pymatching/sparse_blossom/driver/user_graph.perf.cc @@ -35,7 +35,7 @@ BENCHMARK(Load_dem_r11_d11_p100) { auto mwpm = pm::detector_error_model_to_mwpm(dem, num_buckets); } }) - .goal_millis(35) + .goal_millis(54) .show_rate("loads", (double)num_loads); } @@ -48,10 +48,9 @@ BENCHMARK(Load_dem_r21_d21_p100) { auto mwpm = pm::detector_error_model_to_mwpm(dem, num_buckets); } }) - .goal_millis(280) + .goal_millis(460) .show_rate("loads", (double)num_loads); } - BENCHMARK(Load_dem_r11_d11_p100_correlations) { auto dem = generate_dem(11, 11, 0.01, true); size_t num_buckets = 1024; @@ -62,7 +61,7 @@ BENCHMARK(Load_dem_r11_d11_p100_correlations) { dem, num_buckets, /*ensure_search_flooder_included=*/false, /*enable_correlations=*/true); } }) - .goal_millis(35) + .goal_millis(300) .show_rate("loads", (double)num_loads); } @@ -76,6 +75,6 @@ BENCHMARK(Load_dem_r21_d21_p100_correlations) { dem, num_buckets, /*ensure_search_flooder_included=*/false, /*enable_correlations=*/true); } }) - .goal_millis(280) + .goal_millis(3100) .show_rate("loads", (double)num_loads); } diff --git a/src/pymatching/sparse_blossom/driver/user_graph.pybind.cc b/src/pymatching/sparse_blossom/driver/user_graph.pybind.cc index 7b227e5cc..b613afce7 100644 --- a/src/pymatching/sparse_blossom/driver/user_graph.pybind.cc +++ b/src/pymatching/sparse_blossom/driver/user_graph.pybind.cc @@ -178,15 +178,41 @@ void pm_pybind::pybind_user_graph_methods(py::module &m, py::class_ &detection_events, bool enable_correlations) { + [](pm::UserGraph &self, + const py::array_t &detection_events, + bool enable_correlations, + py::object edge_reweights) { std::vector detection_events_vec( detection_events.data(), detection_events.data() + detection_events.size()); auto &mwpm = enable_correlations ? 
self.get_mwpm_with_search_graph() : self.get_mwpm(); + + if (!edge_reweights.is_none()) { + auto rw_array = edge_reweights.cast>(); + if (rw_array.ndim() != 2 || rw_array.shape(1) != 3) + throw std::invalid_argument("edge_reweights must be (N, 3)"); + auto r = rw_array.unchecked<2>(); + auto& reweights = mwpm.flooder.graph.reweight_buffer; + reweights.clear(); + reweights.reserve(r.shape(0)); + for (py::ssize_t i = 0; i < r.shape(0); i++) { + reweights.emplace_back((size_t)r(i, 0), (int64_t)r(i, 1), r(i, 2)); + } + mwpm.flooder.graph.apply_temp_reweights(reweights); + if (mwpm.search_flooder_available()) + mwpm.search_flooder.graph.apply_temp_reweights(reweights, mwpm.flooder.graph.normalising_constant); + } + auto obs_crossed = new std::vector(self.get_num_observables(), 0); pm::total_weight_int weight = 0; pm::decode_detection_events(mwpm, detection_events_vec, obs_crossed->data(), weight, enable_correlations); double rescaled_weight = (double)weight / mwpm.flooder.graph.normalising_constant; + if (!edge_reweights.is_none()) { + mwpm.flooder.graph.undo_reweights(); + if (mwpm.search_flooder_available()) + mwpm.search_flooder.graph.undo_reweights(); + } + auto err_capsule = py::capsule(obs_crossed, [](void *x) { delete reinterpret_cast *>(x); }); @@ -196,7 +222,8 @@ void pm_pybind::pybind_user_graph_methods(py::module &m, py::class_ &detection_events, bool enable_correlations) { @@ -250,7 +277,8 @@ void pm_pybind::pybind_user_graph_methods(py::module &m, py::class_ &shots, bool bit_packed_shots, bool bit_packed_predictions, - bool enable_correlations) { + bool enable_correlations, + py::object edge_reweights) { if (shots.ndim() != 2) throw std::invalid_argument( "`shots` array should have two dimensions, not " + std::to_string(shots.ndim())); @@ -288,6 +316,14 @@ void pm_pybind::pybind_user_graph_methods(py::module &m, py::class_ detection_events; + py::list reweights_list; + bool has_reweights = !edge_reweights.is_none(); + if (has_reweights) { + reweights_list = edge_reweights.cast(); + if (reweights_list.size() != (size_t)shots.shape(0)) + throw std::invalid_argument("edge_reweights list must have same length as shots"); + } + // Vector used to extract predicted observables when decoding if bit_packed_predictions is true std::vector temp_predictions; if (bit_packed_predictions) @@ -310,6 +346,26 @@ void pm_pybind::pybind_user_graph_methods(py::module &m, py::class_>(); + if (rw_array.ndim() != 2 || rw_array.shape(1) != 3) + throw std::invalid_argument("edge_reweights element must be (N, 3)"); + auto r = rw_array.unchecked<2>(); + auto& reweights = mwpm.flooder.graph.reweight_buffer; + reweights.clear(); + reweights.reserve(r.shape(0)); + for (py::ssize_t k = 0; k < r.shape(0); k++) { + reweights.emplace_back((size_t)r(k, 0), (int64_t)r(k, 1), r(k, 2)); + } + mwpm.flooder.graph.apply_temp_reweights(reweights); + if (mwpm.search_flooder_available()) + mwpm.search_flooder.graph.apply_temp_reweights(reweights, mwpm.flooder.graph.normalising_constant); + } + } + pm::total_weight_int solution_weight = 0; if (bit_packed_predictions) { std::fill(temp_predictions.begin(), temp_predictions.end(), 0); @@ -328,6 +384,15 @@ void pm_pybind::pybind_user_graph_methods(py::module &m, py::class_ &detection_events) { diff --git a/src/pymatching/sparse_blossom/flooder/detector_node.h b/src/pymatching/sparse_blossom/flooder/detector_node.h index 446de586d..b8fc4b565 100644 --- a/src/pymatching/sparse_blossom/flooder/detector_node.h +++ b/src/pymatching/sparse_blossom/flooder/detector_node.h @@ -25,6 +25,8 
@@ namespace pm { +enum NodeFlags : uint8_t { WEIGHT_SIGN = 2 }; + /// A detector node is a location where a detection event might occur. /// /// It corresponds to a potential symptom that could be seen, and can @@ -66,6 +68,7 @@ class DetectorNode { std::vector neighbors; /// The node's neighbors. std::vector neighbor_weights; /// Distance crossed by the edge to each neighbor. std::vector neighbor_observables; /// Observables crossed by the edge to each neighbor. + std::vector neighbor_markers; /// Markers for each edge (e.g. WEIGHT_SIGN). /// After it reached this node, how much further did the owning search region grow? Also is it currently growing? inline VaryingCT local_radius() const { diff --git a/src/pymatching/sparse_blossom/flooder/graph.cc b/src/pymatching/sparse_blossom/flooder/graph.cc index 5c1b8850d..708a6b094 100644 --- a/src/pymatching/sparse_blossom/flooder/graph.cc +++ b/src/pymatching/sparse_blossom/flooder/graph.cc @@ -61,15 +61,22 @@ void MatchingGraph::add_edge( // all_edges_to_implied_weights_unconverted[u][v] for a node u corresponds to the edge weights conditioned by (u, v) // where v is the v'th neighbour of u in nodes[u].neighbors. + uint8_t weight_sign = 0; + if (weight < 0) { + weight_sign = pm::WEIGHT_SIGN; + } + nodes[u].neighbors.push_back(&(nodes[v])); nodes[u].neighbor_weights.push_back(std::abs(weight)); nodes[u].neighbor_observables.push_back(obs_mask); + nodes[u].neighbor_markers.push_back(weight_sign); nodes[u].neighbor_implied_weights.push_back({}); edges_to_implied_weights_unconverted[u].emplace_back(implied_weights_for_other_edges); nodes[v].neighbors.push_back(&(nodes[u])); nodes[v].neighbor_weights.push_back(std::abs(weight)); nodes[v].neighbor_observables.push_back(obs_mask); + nodes[v].neighbor_markers.push_back(weight_sign); nodes[v].neighbor_implied_weights.push_back({}); edges_to_implied_weights_unconverted[v].emplace_back(implied_weights_for_other_edges); } @@ -99,6 +106,11 @@ void MatchingGraph::add_boundary_edge( negative_weight_sum += weight; } + uint8_t weight_sign = 0; + if (weight < 0) { + weight_sign = pm::WEIGHT_SIGN; + } + auto& n = nodes[u]; if (!n.neighbors.empty() && n.neighbors[0] == nullptr) { throw std::invalid_argument("Max one boundary edge."); @@ -106,6 +118,7 @@ void MatchingGraph::add_boundary_edge( n.neighbors.insert(n.neighbors.begin(), 1, nullptr); n.neighbor_weights.insert(n.neighbor_weights.begin(), 1, std::abs(weight)); n.neighbor_observables.insert(n.neighbor_observables.begin(), 1, obs_mask); + n.neighbor_markers.insert(n.neighbor_markers.begin(), 1, weight_sign); n.neighbor_implied_weights.insert(n.neighbor_implied_weights.begin(), 1, {}); edges_to_implied_weights_unconverted[u].insert( edges_to_implied_weights_unconverted[u].begin(), 1, implied_weights_for_other_edges); @@ -228,6 +241,16 @@ void MatchingGraph::undo_reweights() { *prev.ptr = prev.val; } previous_weights.clear(); + negative_weight_sum -= negative_weight_sum_delta; + negative_weight_sum_delta = 0; +} + +void MatchingGraph::apply_temp_reweights(const std::vector>& reweights) { + apply_temp_reweights_generic( + *this, reweights, normalising_constant, [this](pm::signed_weight_int delta) { + negative_weight_sum += delta; + negative_weight_sum_delta += delta; + }); } } // namespace pm diff --git a/src/pymatching/sparse_blossom/flooder/graph.h b/src/pymatching/sparse_blossom/flooder/graph.h index abfddfc25..22bad6051 100644 --- a/src/pymatching/sparse_blossom/flooder/graph.h +++ b/src/pymatching/sparse_blossom/flooder/graph.h @@ -15,8 +15,10 @@ #ifndef 
 #define PYMATCHING2_GRAPH_H
+#include
 #include
 #include
+#include
 #include
 #include
@@ -62,6 +64,8 @@ class MatchingGraph {
     // alert a user if they try to decode with enable_correlations=true, but forgot to load from the
     // dem with enable_correlations=true.
     bool loaded_from_dem_without_correlations = false;
+    pm::total_weight_int negative_weight_sum_delta = 0;
+    std::vector<std::tuple<size_t, int64_t, double>> reweight_buffer;
     MatchingGraph();
     MatchingGraph(size_t num_nodes, size_t num_observables);
@@ -86,6 +90,7 @@ class MatchingGraph {
     void reweight(std::vector& implied_weights);
     void reweight_for_edge(const int64_t& u, const int64_t& v);
     void reweight_for_edges(const std::vector& edges);
+    void apply_temp_reweights(const std::vector<std::tuple<size_t, int64_t, double>>& reweights);
 };
 void apply_reweights(
@@ -113,6 +118,67 @@ inline void MatchingGraph::reweight(std::vector& implied_weights)
     apply_reweights(implied_weights, previous_weights);
 }
+template <typename GraphType, typename OnNegativeEdge>
+void apply_temp_reweights_generic(
+    GraphType& graph,
+    const std::vector<std::tuple<size_t, int64_t, double>>& reweights,
+    double normalising_constant,
+    OnNegativeEdge on_negative_edge) {
+    for (const auto& rw : reweights) {
+        size_t u = std::get<0>(rw);
+        int64_t v = std::get<1>(rw);
+        double weight = std::get<2>(rw);
+
+        double rescaled_normalising_constant = normalising_constant / 2;
+        pm::signed_weight_int w = (pm::signed_weight_int)round(weight * rescaled_normalising_constant);
+        w *= 2;
+        pm::weight_int new_w = std::abs(w);
+
+        if (u >= graph.nodes.size())
+            throw std::invalid_argument("Node index " + std::to_string(u) + " out of range");
+        auto* u_node_ptr = &graph.nodes[u];
+        auto* v_node_ptr = (decltype(u_node_ptr)) nullptr;
+        if (v != -1) {
+            if (v < 0 || (size_t)v >= graph.nodes.size())
+                throw std::invalid_argument("Node index " + std::to_string(v) + " out of range");
+            v_node_ptr = &graph.nodes[(size_t)v];
+        }
+
+        size_t idx = u_node_ptr->index_of_neighbor(v_node_ptr);
+        if (idx == SIZE_MAX)
+            throw std::invalid_argument("Edge (" + std::to_string(u) + ", " + std::to_string(v) + ") not found");
+
+        // Check sign consistency
+        bool new_is_negative = w < 0;
+        bool old_is_negative = u_node_ptr->neighbor_markers[idx] & pm::WEIGHT_SIGN;
+        if (new_is_negative != old_is_negative) {
+            throw std::invalid_argument(
+                "Reweighting edge (" + std::to_string(u) + ", " + std::to_string(v) +
+                ") failed: sign flip not allowed. "
+                "Original sign: " +
+                (old_is_negative ? "negative" : "positive") +
+                ", New sign: " + (new_is_negative ? "negative" : "positive"));
+        }
+
+        if (new_is_negative) {
+            pm::signed_weight_int old_w = -(pm::signed_weight_int)u_node_ptr->neighbor_weights[idx];
+            pm::signed_weight_int delta = w - old_w;
+            on_negative_edge(delta);
+        }
+
+        weight_int* w_ptr = &u_node_ptr->neighbor_weights[idx];
+        graph.previous_weights.emplace_back(w_ptr, *w_ptr);
+        *w_ptr = new_w;
+
+        if (v_node_ptr) {
+            size_t idx_v = v_node_ptr->index_of_neighbor(u_node_ptr);
+            weight_int* w_ptr_v = &v_node_ptr->neighbor_weights[idx_v];
+            graph.previous_weights.emplace_back(w_ptr_v, *w_ptr_v);
+            *w_ptr_v = new_w;
+        }
+    }
+}
+
 }  // namespace pm
 #endif  // PYMATCHING2_GRAPH_H
diff --git a/src/pymatching/sparse_blossom/flooder/graph.test.cc b/src/pymatching/sparse_blossom/flooder/graph.test.cc
index 8569c6130..6cc740f60 100644
--- a/src/pymatching/sparse_blossom/flooder/graph.test.cc
+++ b/src/pymatching/sparse_blossom/flooder/graph.test.cc
@@ -89,3 +89,51 @@ TEST(Graph, AddBoundaryEdgeWithImpliedWeights) {
     ASSERT_EQ(g.edges_to_implied_weights_unconverted[0][0][0].node2, 2);
     ASSERT_EQ(g.edges_to_implied_weights_unconverted[0][0][0].implied_weight, 7);
 }
+
+TEST(Graph, ApplyTempReweights) {
+    pm::MatchingGraph g(4, 64);
+    g.add_edge(0, 1, 10, {0}, {});
+    g.add_edge(1, 2, -20, {1}, {});
+
+    ASSERT_EQ(g.nodes[0].neighbor_weights[0], 10);
+    ASSERT_EQ(g.nodes[1].neighbor_weights[0], 10);
+    ASSERT_EQ(g.nodes[1].neighbor_weights[1], 20);
+    ASSERT_EQ(g.negative_weight_sum, -20);
+    ASSERT_EQ(g.negative_weight_sum_delta, 0);
+
+    // Reweight positive to positive
+    std::vector<std::tuple<size_t, int64_t, double>> reweights;
+    g.normalising_constant = 1.0;
+    // 30.0 * (1.0/2) = 15.0. 15 * 2 = 30.
+    reweights.emplace_back(0, 1, 30.0);
+    g.apply_temp_reweights(reweights);
+    ASSERT_EQ(g.nodes[0].neighbor_weights[0], 30);
+    ASSERT_EQ(g.negative_weight_sum, -20);  // No change
+    ASSERT_EQ(g.negative_weight_sum_delta, 0);
+
+    g.undo_reweights();
+    ASSERT_EQ(g.nodes[0].neighbor_weights[0], 10);
+    ASSERT_EQ(g.negative_weight_sum_delta, 0);
+
+    // Reweight negative to negative
+    // Orig -20. Reweight to -40.0.
+    // -40 * 0.5 = -20. -20 * 2 = -40.
+    reweights.clear();
+    reweights.emplace_back(1, 2, -40.0);
+    g.apply_temp_reweights(reweights);
+    ASSERT_EQ(g.nodes[1].neighbor_weights[1], 40);
+    // Delta: -40 - (-20) = -20.
+    // Sum: -20 + (-20) = -40.
+    ASSERT_EQ(g.negative_weight_sum, -40);
+    ASSERT_EQ(g.negative_weight_sum_delta, -20);
+
+    g.undo_reweights();
+    ASSERT_EQ(g.nodes[1].neighbor_weights[1], 20);
+    ASSERT_EQ(g.negative_weight_sum, -20);
+    ASSERT_EQ(g.negative_weight_sum_delta, 0);
+
+    // Sign flip (should throw)
+    reweights.clear();
+    reweights.emplace_back(0, 1, -5.0);
+    ASSERT_THROW(g.apply_temp_reweights(reweights), std::invalid_argument);
+}
diff --git a/src/pymatching/sparse_blossom/flooder_matcher_interop/varying.perf.cc b/src/pymatching/sparse_blossom/flooder_matcher_interop/varying.perf.cc
index b02f7d68e..d7e3ad883 100644
--- a/src/pymatching/sparse_blossom/flooder_matcher_interop/varying.perf.cc
+++ b/src/pymatching/sparse_blossom/flooder_matcher_interop/varying.perf.cc
@@ -37,7 +37,7 @@ BENCHMARK(Varying32_get_distance_at_time) {
             total += varyings[k].get_distance_at_time(times[k]);
         }
     })
-        .goal_millis(3.5)
+        .goal_millis(6.7)
         .show_rate("calls", NUM_ITEMS);
     if (total == 0) {
         std::cerr << "data dependence";
@@ -62,7 +62,7 @@ BENCHMARK(Varying64_get_distance_at_time) {
             total += varyings[k].get_distance_at_time(times[k]);
         }
     })
-        .goal_millis(5.9)
+        .goal_millis(16)
         .show_rate("calls", NUM_ITEMS);
     if (total == 0) {
         std::cerr << "data dependence";
diff --git a/src/pymatching/sparse_blossom/matcher/mwpm.h b/src/pymatching/sparse_blossom/matcher/mwpm.h
index a2067101b..996611b49 100644
--- a/src/pymatching/sparse_blossom/matcher/mwpm.h
+++ b/src/pymatching/sparse_blossom/matcher/mwpm.h
@@ -83,6 +83,10 @@ struct Mwpm {
     void create_detection_event(DetectorNode* node);
     void reset();
+
+    bool search_flooder_available() const {
+        return flooder.graph.nodes.size() == search_flooder.graph.nodes.size();
+    }
 };
 }  // namespace pm
diff --git a/src/pymatching/sparse_blossom/search/search_detector_node.h b/src/pymatching/sparse_blossom/search/search_detector_node.h
index 7cfc6d14c..50cb28241 100644
--- a/src/pymatching/sparse_blossom/search/search_detector_node.h
+++ b/src/pymatching/sparse_blossom/search/search_detector_node.h
@@ -15,18 +15,18 @@
 #ifndef PYMATCHING2_SEARCH_DETECTOR_NODE_H
 #define PYMATCHING2_SEARCH_DETECTOR_NODE_H
-#include "pymatching/sparse_blossom/driver/implied_weights.h"
+#include
+#include
+
+#include "pymatching/sparse_blossom/flooder/detector_node.h"
 #include "pymatching/sparse_blossom/tracker/queued_event_tracker.h"
 namespace pm {
-const uint8_t FLIPPED = 1;
-const uint8_t WEIGHT_SIGN = 2;
+enum SearchNodeFlags : uint8_t { FLIPPED = 1 };
-class SearchDetectorNode {
- public:
-    SearchDetectorNode() : reached_from_source(nullptr), index_of_predecessor(SIZE_MAX), distance_from_source(0) {
-    }
+struct SearchDetectorNode {
+    SearchDetectorNode() : reached_from_source(nullptr), index_of_predecessor(SIZE_MAX), distance_from_source(0), node_event_tracker() {}
     /// The SearchDetectorNode that this node was reached from in the Dijkstra search
     SearchDetectorNode *reached_from_source;
diff --git a/src/pymatching/sparse_blossom/search/search_graph.cc b/src/pymatching/sparse_blossom/search/search_graph.cc
index d4be3ae5d..ac77f4a7f 100644
--- a/src/pymatching/sparse_blossom/search/search_graph.cc
+++ b/src/pymatching/sparse_blossom/search/search_graph.cc
@@ -159,3 +159,8 @@ void pm::SearchGraph::convert_implied_weights(const double normalising_constant)
         }
     }
 }
+
+void pm::SearchGraph::apply_temp_reweights(
+    const std::vector<std::tuple<size_t, int64_t, double>>& reweights, double normalising_constant) {
+    apply_temp_reweights_generic(*this, reweights, normalising_constant, [](pm::signed_weight_int delta) {});
+}
diff --git a/src/pymatching/sparse_blossom/search/search_graph.h b/src/pymatching/sparse_blossom/search/search_graph.h
index e57006c66..02bca70ea 100644
--- a/src/pymatching/sparse_blossom/search/search_graph.h
+++ b/src/pymatching/sparse_blossom/search/search_graph.h
@@ -57,6 +57,7 @@ class SearchGraph {
     void reweight_for_edge(const int64_t& u, const int64_t& v);
     void reweight_for_edges(const std::vector& edges);
     void undo_reweights();
+    void apply_temp_reweights(const std::vector<std::tuple<size_t, int64_t, double>>& reweights, double normalising_constant);
 };
 inline void SearchGraph::reweight(std::vector& implied_weights) {
diff --git a/src/pymatching/sparse_blossom/search/search_graph.test.cc b/src/pymatching/sparse_blossom/search/search_graph.test.cc
index 03ffafe0a..49d09bf52 100644
--- a/src/pymatching/sparse_blossom/search/search_graph.test.cc
+++ b/src/pymatching/sparse_blossom/search/search_graph.test.cc
@@ -46,3 +46,32 @@ TEST(SearchGraph, AddBoundaryEdge) {
     ASSERT_EQ(g.nodes[0].neighbor_observable_indices[0], v1);
     ASSERT_EQ(g.nodes[0].neighbor_weights[0], 7);
 }
+
+TEST(SearchGraph, ApplyTempReweights) {
+    pm::SearchGraph g(3);
+    g.add_edge(0, 1, 10, {0});
+    g.add_edge(1, 2, -20, {1});
+
+    std::vector<std::tuple<size_t, int64_t, double>> reweights;
+    reweights.emplace_back(0, 1, 30.0);
+    // normalising_constant = 1.0
+    g.apply_temp_reweights(reweights, 1.0);
+
+    // 30.0 * 0.5 = 15. 15*2 = 30.
+    ASSERT_EQ(g.nodes[0].neighbor_weights[0], 30);
+
+    g.undo_reweights();
+    ASSERT_EQ(g.nodes[0].neighbor_weights[0], 10);
+
+    reweights.clear();
+    reweights.emplace_back(1, 2, -40.0);
+    g.apply_temp_reweights(reweights, 1.0);
+    ASSERT_EQ(g.nodes[1].neighbor_weights[1], 40);
+
+    g.undo_reweights();
+    ASSERT_EQ(g.nodes[1].neighbor_weights[1], 20);
+
+    reweights.clear();
+    reweights.emplace_back(0, 1, -5.0);
+    ASSERT_THROW(g.apply_temp_reweights(reweights, 1.0), std::invalid_argument);
+}
diff --git a/src/pymatching/sparse_blossom/tracker/radix_heap_queue.perf.cc b/src/pymatching/sparse_blossom/tracker/radix_heap_queue.perf.cc
index b1b151f14..c3dcb7c0f 100644
--- a/src/pymatching/sparse_blossom/tracker/radix_heap_queue.perf.cc
+++ b/src/pymatching/sparse_blossom/tracker/radix_heap_queue.perf.cc
@@ -45,7 +45,7 @@ BENCHMARK(bucket_queue_sort) {
             }
         }
     })
-        .goal_micros(4.9)
+        .goal_micros(3.4)
        .show_rate("EnqueueDequeues", (double)v.size());
    if (dependence) {
        std::cerr << "data dependence";
@@ -67,7 +67,7 @@ BENCHMARK(bucket_queue_stream) {
            q.enqueue(FloodCheckEvent(q.dequeue().time + cyclic_time_int{100}));
        }
    })
-        .goal_micros(99)
+        .goal_micros(120)
        .show_rate("EnqueueDequeues", (double)n);
    if (dependence) {
        std::cerr << "data dependence";
diff --git a/tests/cli_test.py b/tests/cli_test.py
index b476bcc37..b8c6f8a80 100644
--- a/tests/cli_test.py
+++ b/tests/cli_test.py
@@ -13,11 +13,16 @@ def predict_args(dem_file: Path, input_file: Path, output_file: Path) -> List[str]:
     return [
         "predict",
-        "--dem", str(dem_file),
-        "--in", str(input_file),
-        "--in_format", input_file.suffix[1:],
-        "--out", str(output_file),
-        "--out_format", "dets",
+        "--dem",
+        str(dem_file),
+        "--in",
+        str(input_file),
+        "--in_format",
+        input_file.suffix[1:],
+        "--out",
+        str(output_file),
+        "--out_format",
+        "dets",
     ]
@@ -27,47 +32,69 @@ def test_calling_cli_creates_expected_file(
     tmp_path: Path,
     data_dir: Path,
     cli_function: Callable[[List[str]], None],
-    input_format: str
+    input_format: str,
 ):
     output_file = tmp_path / "three_errors_predictions.dets"
-    cli_function(command_line_args=predict_args(
-        data_dir / "three_errors.dem",
-        data_dir / f"three_errors.{input_format}",
-        output_file))
+ 
cli_function( + command_line_args=predict_args( + data_dir / "three_errors.dem", + data_dir / f"three_errors.{input_format}", + output_file, + ) + ) assert output_file.is_file() with open(output_file, encoding="utf-8") as prediction_file: - assert prediction_file.readlines() == ["shot\n", "shot L0\n", "shot L2\n", "shot L1\n"] + assert prediction_file.readlines() == [ + "shot\n", + "shot L0\n", + "shot L2\n", + "shot L1\n", + ] @pytest.mark.parametrize("input_format", ["dets", "b8"]) def test_patching_cli_argv_creates_expected_file( - tmp_path: Path, - data_dir: Path, - input_format: str + tmp_path: Path, data_dir: Path, input_format: str ): output_file = tmp_path / "three_errors_predictions.dets" args = predict_args( data_dir / "three_errors.dem", data_dir / f"three_errors.{input_format}", - output_file) + output_file, + ) with patch.object(sys, "argv", ["cli"] + args): cli_argv() assert output_file.is_file() with open(output_file) as prediction_file: - assert prediction_file.readlines() == ["shot\n", "shot L0\n", "shot L2\n", "shot L1\n"] + assert prediction_file.readlines() == [ + "shot\n", + "shot L0\n", + "shot L2\n", + "shot L1\n", + ] def test_load_surface_code_b8_cli(tmp_path: Path, data_dir: Path): dem_path = data_dir / "surface_code_rotated_memory_x_13_0.01.dem" dets_b8_in_path = data_dir / "surface_code_rotated_memory_x_13_0.01_1000_shots.b8" - out_fn = tmp_path / "surface_code_rotated_memory_x_13_0.01_1000_shots_temp_predictions.b8" + out_fn = ( + tmp_path + / "surface_code_rotated_memory_x_13_0.01_1000_shots_temp_predictions.b8" + ) - pymatching._cpp_pymatching.main(command_line_args=[ - "predict", - "--dem", str(dem_path), - "--in", str(dets_b8_in_path), - "--in_format", "b8", - "--out", str(out_fn), - "--out_format", "b8", - "--in_includes_appended_observables" - ]) + pymatching._cpp_pymatching.main( + command_line_args=[ + "predict", + "--dem", + str(dem_path), + "--in", + str(dets_b8_in_path), + "--in_format", + "b8", + "--out", + str(out_fn), + "--out_format", + "b8", + "--in_includes_appended_observables", + ] + ) diff --git a/tests/matching/add_noise_test.py b/tests/matching/add_noise_test.py index 510b0921a..6ea0b02de 100644 --- a/tests/matching/add_noise_test.py +++ b/tests/matching/add_noise_test.py @@ -38,13 +38,10 @@ def test_add_noise_with_boundary(): for i in range(11): g.add_edge(i, i + 1, fault_ids=i, error_probability=(i + 1) % 2) for i in range(5, 12): - g.nodes()[i]['is_boundary'] = True + g.nodes()[i]["is_boundary"] = True m = Matching(g) noise, syndrome = m.add_noise() assert sum(syndrome) == 5 assert np.array_equal(noise, (np.arange(11) + 1) % 2) assert m.boundary == set(range(5, 12)) - assert np.array_equal( - syndrome, - np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]) - ) + assert np.array_equal(syndrome, np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0])) diff --git a/tests/matching/decode_test.py b/tests/matching/decode_test.py index 3e2bcd513..19cf08759 100644 --- a/tests/matching/decode_test.py +++ b/tests/matching/decode_test.py @@ -166,13 +166,13 @@ def test_surface_code_solution_weights(data_dir: Path): predicted_observables = [] for i in range(min(shots.shape[0], 1000)): prediction, weight = m.decode( - shots[i, 0: -m.num_fault_ids], return_weight=True + shots[i, 0 : -m.num_fault_ids], return_weight=True ) weights.append(weight) predicted_observables.append(prediction) for weight, expected_weight in zip(weights, expected_weights): assert weight == pytest.approx(expected_weight, rel=1e-8) - assert predicted_observables == expected_observables[0: 
len(predicted_observables)] + assert predicted_observables == expected_observables[0 : len(predicted_observables)] expected_observables_arr = np.zeros((shots.shape[0], 1), dtype=np.uint8) expected_observables_arr[:, 0] = np.array(expected_observables) @@ -181,17 +181,17 @@ def test_surface_code_solution_weights(data_dir: Path): temp_shots, _, _ = sampler.sample(shots=10, bit_packed=True) assert temp_shots.shape[1] == np.ceil(dem.num_detectors // 8) - batch_predictions = m.decode_batch(shots[:, 0: -m.num_fault_ids]) + batch_predictions = m.decode_batch(shots[:, 0 : -m.num_fault_ids]) assert np.array_equal(batch_predictions, expected_observables_arr) batch_predictions, batch_weights = m.decode_batch( - shots[:, 0: -m.num_fault_ids], return_weights=True + shots[:, 0 : -m.num_fault_ids], return_weights=True ) assert np.array_equal(batch_predictions, expected_observables_arr) assert np.allclose(batch_weights, expected_weights, rtol=1e-8) bitpacked_shots = np.packbits( - shots[:, 0: dem.num_detectors], bitorder="little", axis=1 + shots[:, 0 : dem.num_detectors], bitorder="little", axis=1 ) batch_predictions_from_bitpacked, bitpacked_batch_weights = m.decode_batch( bitpacked_shots, return_weights=True, bit_packed_shots=True @@ -226,8 +226,14 @@ def test_surface_code_solution_weights_with_correlations(data_dir: Path): num_observables=m.num_fault_ids, ) # Test correlated decoding - corr_weights_path = data_dir / "surface_code_rotated_memory_x_13_0.01_1000_shots_no_buckets_weights_pymatching_correlated.txt" - corr_predictions_path = data_dir / "surface_code_rotated_memory_x_13_0.01_1000_shots_no_buckets_predictions_pymatching_correlated.txt" + corr_weights_path = ( + data_dir + / "surface_code_rotated_memory_x_13_0.01_1000_shots_no_buckets_weights_pymatching_correlated.txt" + ) + corr_predictions_path = ( + data_dir + / "surface_code_rotated_memory_x_13_0.01_1000_shots_no_buckets_predictions_pymatching_correlated.txt" + ) with open( corr_weights_path, "r", @@ -246,7 +252,7 @@ def test_surface_code_solution_weights_with_correlations(data_dir: Path): m_corr = Matching.from_detector_error_model(dem, enable_correlations=True) corr_predictions, corr_weights = m_corr.decode_batch( - shots[:, 0: -m.num_fault_ids], + shots[:, 0 : -m.num_fault_ids], return_weights=True, enable_correlations=True, ) @@ -409,9 +415,7 @@ def test_correlated_matching_handles_single_detector_components(): rounds=5, before_round_data_depolarization=p, ) - circ_str = str(circuit).replace( - f"DEPOLARIZE1({p})", f"PAULI_CHANNEL_1(0, {p}, 0)" - ) + circ_str = str(circuit).replace(f"DEPOLARIZE1({p})", f"PAULI_CHANNEL_1(0, {p}, 0)") noisy_circuit = stim.Circuit(circ_str) dem = noisy_circuit.detector_error_model( decompose_errors=True, approximate_disjoint_errors=True @@ -426,13 +430,17 @@ def test_load_from_circuit_with_correlations(): code_task="surface_code:rotated_memory_x", distance=3, rounds=3, - after_clifford_depolarization=0.001 + after_clifford_depolarization=0.001, ) shots = circuit.compile_detector_sampler().sample(shots=10) matching_1 = pymatching.Matching(circuit, enable_correlations=True) - matching_2 = pymatching.Matching.from_stim_circuit(circuit=circuit, enable_correlations=True) + matching_2 = pymatching.Matching.from_stim_circuit( + circuit=circuit, enable_correlations=True + ) for m in (matching_1, matching_2): - predictions, weights = m.decode_batch(shots=shots, return_weights=True, enable_correlations=True) + predictions, weights = m.decode_batch( + shots=shots, return_weights=True, enable_correlations=True + ) def 
test_use_correlations_with_uncorrelated_dem_load_raises_value_error(tmp_path): @@ -443,23 +451,28 @@ def test_use_correlations_with_uncorrelated_dem_load_raises_value_error(tmp_path code_task="surface_code:rotated_memory_x", distance=d, rounds=d, - after_clifford_depolarization=p + after_clifford_depolarization=p, ) dem = circuit.detector_error_model(decompose_errors=True) shots = circuit.compile_detector_sampler().sample(shots=10) matching_1 = pymatching.Matching(circuit, enable_correlations=False) - matching_2 = pymatching.Matching.from_stim_circuit(circuit=circuit, enable_correlations=False) + matching_2 = pymatching.Matching.from_stim_circuit( + circuit=circuit, enable_correlations=False + ) matching_3 = pymatching.Matching.from_detector_error_model( - model=dem, - enable_correlations=False + model=dem, enable_correlations=False ) fn = f"surface_code_x_d{d}_r{d}_p{p}" stim_file = tmp_path / f"{fn}.stim" circuit.to_file(stim_file) - matching_4 = pymatching.Matching.from_stim_circuit_file(stim_file, enable_correlations=False) + matching_4 = pymatching.Matching.from_stim_circuit_file( + stim_file, enable_correlations=False + ) dem_file = tmp_path / f"{fn}.dem" dem.to_file(dem_file) - matching_5 = pymatching.Matching.from_detector_error_model_file(dem_file, enable_correlations=False) + matching_5 = pymatching.Matching.from_detector_error_model_file( + dem_file, enable_correlations=False + ) for m in (matching_1, matching_2, matching_3, matching_4, matching_5): with pytest.raises(ValueError): m.decode_batch(shots=shots, return_weights=True, enable_correlations=True) @@ -479,7 +492,7 @@ def test_use_correlations_without_decompose_errors_raises_value_error(tmp_path): code_task="surface_code:rotated_memory_x", distance=d, rounds=d, - after_clifford_depolarization=p + after_clifford_depolarization=p, ) dem = circuit.detector_error_model(decompose_errors=False) dem_file = tmp_path / "surface_code.dem" @@ -489,4 +502,6 @@ def test_use_correlations_without_decompose_errors_raises_value_error(tmp_path): with pytest.raises(ValueError): pymatching.Matching(dem, enable_correlations=True) with pytest.raises(ValueError): - pymatching.Matching.from_detector_error_model_file(dem_file, enable_correlations=True) + pymatching.Matching.from_detector_error_model_file( + dem_file, enable_correlations=True + ) diff --git a/tests/matching/draw_test.py b/tests/matching/draw_test.py index 0667f1fc7..28a4b49cf 100644 --- a/tests/matching/draw_test.py +++ b/tests/matching/draw_test.py @@ -23,8 +23,8 @@ def test_draw_matching(): g.add_edge(0, 1, fault_ids={0}, weight=1.1, error_probability=0.1) g.add_edge(1, 2, fault_ids={1}, weight=2.1, error_probability=0.2) g.add_edge(2, 3, fault_ids={2, 3}, weight=0.9, error_probability=0.3) - g.nodes[0]['is_boundary'] = True - g.nodes[3]['is_boundary'] = True + g.nodes[0]["is_boundary"] = True + g.nodes[3]["is_boundary"] = True g.add_edge(0, 3, weight=0.0) m = Matching(g) plt.figure() diff --git a/tests/matching/edges_test.py b/tests/matching/edges_test.py index d01c62a5d..974d17771 100644 --- a/tests/matching/edges_test.py +++ b/tests/matching/edges_test.py @@ -23,8 +23,8 @@ def test_qubit_id_accepted_using_add_edge(): m.add_edge(1, 2, qubit_id={1, 2}) es = list(m.edges()) expected_edges = [ - (0, 1, {'fault_ids': {0}, 'weight': 1.0, 'error_probability': -1.0}), - (1, 2, {'fault_ids': {1, 2}, 'weight': 1.0, 'error_probability': -1.0}) + (0, 1, {"fault_ids": {0}, "weight": 1.0, "error_probability": -1.0}), + (1, 2, {"fault_ids": {1, 2}, "weight": 1.0, "error_probability": -1.0}), 
] assert es == expected_edges @@ -47,50 +47,128 @@ def test_add_edge(): m.add_edge(0, 1, weight=0.123, error_probability=0.6) m.add_edge(1, 2, weight=0.6, error_probability=0.3, fault_ids=0) m.add_edge(2, 3, weight=0.01, error_probability=0.5, fault_ids={1, 2}) - expected = [(0, 1, {'fault_ids': set(), 'weight': 0.123, 'error_probability': 0.6}), - (1, 2, {'fault_ids': {0}, 'weight': 0.6, 'error_probability': 0.3}), - (2, 3, {'fault_ids': {1, 2}, 'weight': 0.01, 'error_probability': 0.5})] + expected = [ + (0, 1, {"fault_ids": set(), "weight": 0.123, "error_probability": 0.6}), + (1, 2, {"fault_ids": {0}, "weight": 0.6, "error_probability": 0.3}), + (2, 3, {"fault_ids": {1, 2}, "weight": 0.01, "error_probability": 0.5}), + ] assert m.edges() == expected def test_add_edge_merge_strategy(): m = Matching() m.add_edge(0, 10, fault_ids={0}, weight=1.2, error_probability=0.3) - assert m.edges() == [(0, 10, {'fault_ids': {0}, 'weight': 1.2, 'error_probability': 0.3})] - m.add_edge(0, 10, fault_ids={1}, weight=1.0, error_probability=0.35, merge_strategy="smallest-weight") - assert m.edges() == [(0, 10, {'fault_ids': {1}, 'weight': 1.0, 'error_probability': 0.35})] + assert m.edges() == [ + (0, 10, {"fault_ids": {0}, "weight": 1.2, "error_probability": 0.3}) + ] + m.add_edge( + 0, + 10, + fault_ids={1}, + weight=1.0, + error_probability=0.35, + merge_strategy="smallest-weight", + ) + assert m.edges() == [ + (0, 10, {"fault_ids": {1}, "weight": 1.0, "error_probability": 0.35}) + ] with pytest.raises(ValueError): - m.add_edge(0, 10, fault_ids={1}, weight=1.5, error_probability=0.6, merge_strategy="disallow") - m.add_edge(0, 10, fault_ids={2}, weight=4.0, error_probability=0.2, merge_strategy="independent") + m.add_edge( + 0, + 10, + fault_ids={1}, + weight=1.5, + error_probability=0.6, + merge_strategy="disallow", + ) + m.add_edge( + 0, + 10, + fault_ids={2}, + weight=4.0, + error_probability=0.2, + merge_strategy="independent", + ) es = m.edges() es[0][2]["weight"] = round(es[0][2]["weight"], 6) - assert es == [(0, 10, {'fault_ids': {1}, 'weight': 0.958128, 'error_probability': 0.41})] + assert es == [ + (0, 10, {"fault_ids": {1}, "weight": 0.958128, "error_probability": 0.41}) + ] m = Matching() m.add_edge(1, 10, fault_ids={0}, weight=2, error_probability=0.3) - m.add_edge(1, 10, fault_ids={1}, weight=5, error_probability=0.1, merge_strategy="keep-original") - assert m.edges() == [(1, 10, {'fault_ids': {0}, 'weight': 2, 'error_probability': 0.3})] - m.add_edge(1, 10, fault_ids={2}, weight=5, error_probability=0.1, merge_strategy="replace") - assert m.edges() == [(1, 10, {'fault_ids': {2}, 'weight': 5, 'error_probability': 0.1})] + m.add_edge( + 1, + 10, + fault_ids={1}, + weight=5, + error_probability=0.1, + merge_strategy="keep-original", + ) + assert m.edges() == [ + (1, 10, {"fault_ids": {0}, "weight": 2, "error_probability": 0.3}) + ] + m.add_edge( + 1, 10, fault_ids={2}, weight=5, error_probability=0.1, merge_strategy="replace" + ) + assert m.edges() == [ + (1, 10, {"fault_ids": {2}, "weight": 5, "error_probability": 0.1}) + ] def test_add_boundary_edge(): m = Matching() m.add_boundary_edge(0, fault_ids={0}, weight=1.2, error_probability=0.3) - assert m.edges() == [(0, None, {'fault_ids': {0}, 'weight': 1.2, 'error_probability': 0.3})] - m.add_boundary_edge(0, fault_ids={1}, weight=1.0, error_probability=0.35, merge_strategy="smallest-weight") - assert m.edges() == [(0, None, {'fault_ids': {1}, 'weight': 1.0, 'error_probability': 0.35})] + assert m.edges() == [ + (0, None, {"fault_ids": {0}, 
"weight": 1.2, "error_probability": 0.3}) + ] + m.add_boundary_edge( + 0, + fault_ids={1}, + weight=1.0, + error_probability=0.35, + merge_strategy="smallest-weight", + ) + assert m.edges() == [ + (0, None, {"fault_ids": {1}, "weight": 1.0, "error_probability": 0.35}) + ] with pytest.raises(ValueError): - m.add_boundary_edge(0, fault_ids={1}, weight=1.5, error_probability=0.6, merge_strategy="disallow") - m.add_boundary_edge(0, fault_ids={2}, weight=4.0, error_probability=0.2, merge_strategy="independent") + m.add_boundary_edge( + 0, + fault_ids={1}, + weight=1.5, + error_probability=0.6, + merge_strategy="disallow", + ) + m.add_boundary_edge( + 0, + fault_ids={2}, + weight=4.0, + error_probability=0.2, + merge_strategy="independent", + ) es = m.edges() es[0][2]["weight"] = round(es[0][2]["weight"], 6) - assert es == [(0, None, {'fault_ids': {1}, 'weight': 0.958128, 'error_probability': 0.41})] + assert es == [ + (0, None, {"fault_ids": {1}, "weight": 0.958128, "error_probability": 0.41}) + ] m = Matching() m.add_boundary_edge(1, fault_ids={0}, weight=2, error_probability=0.3) - m.add_boundary_edge(1, fault_ids={1}, weight=5, error_probability=0.1, merge_strategy="keep-original") - assert m.edges() == [(1, None, {'fault_ids': {0}, 'weight': 2, 'error_probability': 0.3})] - m.add_boundary_edge(1, fault_ids={2}, weight=5, error_probability=0.1, merge_strategy="replace") - assert m.edges() == [(1, None, {'fault_ids': {2}, 'weight': 5, 'error_probability': 0.1})] + m.add_boundary_edge( + 1, + fault_ids={1}, + weight=5, + error_probability=0.1, + merge_strategy="keep-original", + ) + assert m.edges() == [ + (1, None, {"fault_ids": {0}, "weight": 2, "error_probability": 0.3}) + ] + m.add_boundary_edge( + 1, fault_ids={2}, weight=5, error_probability=0.1, merge_strategy="replace" + ) + assert m.edges() == [ + (1, None, {"fault_ids": {2}, "weight": 5, "error_probability": 0.1}) + ] def test_has_edge(): @@ -111,9 +189,17 @@ def test_has_edge(): def test_get_edge_data(): m = Matching() m.add_edge(0, 1, {0, 1, 2}, 2.5, 0.1) - assert m.get_edge_data(0, 1) == {"fault_ids": {0, 1, 2}, "weight": 2.5, "error_probability": 0.1} + assert m.get_edge_data(0, 1) == { + "fault_ids": {0, 1, 2}, + "weight": 2.5, + "error_probability": 0.1, + } m.add_boundary_edge(5, {5, 6, 7}, 5.0, 0.34) - assert m.get_boundary_edge_data(5) == {"fault_ids": {5, 6, 7}, "weight": 5.0, "error_probability": 0.34} + assert m.get_boundary_edge_data(5) == { + "fault_ids": {5, 6, 7}, + "weight": 5.0, + "error_probability": 0.34, + } def test_large_edge_weight_not_added_to_graph(): @@ -128,7 +214,9 @@ def test_large_edge_weight_not_added_to_graph(): with pytest.warns(UserWarning): m.add_boundary_edge(5, weight=-9999999999) assert m.num_edges == 1 - assert m.edges() == [(0, 1, {"fault_ids": set(), "weight": 1.0, "error_probability": -1.0})] + assert m.edges() == [ + (0, 1, {"fault_ids": set(), "weight": 1.0, "error_probability": -1.0}) + ] def test_add_self_loop(): @@ -141,7 +229,7 @@ def test_add_self_loop(): (0, 1, {"fault_ids": set(), "weight": 2.0, "error_probability": -1.0}), (2, 2, {"fault_ids": set(), "weight": 5.0, "error_probability": -1.0}), (3, 4, {"fault_ids": set(), "weight": 10.0, "error_probability": -1.0}), - (4, 4, {"fault_ids": set(), "weight": 11.0, "error_probability": -1.0}) + (4, 4, {"fault_ids": set(), "weight": 11.0, "error_probability": -1.0}), ] with pytest.raises(ValueError): m.add_edge(4, 4, weight=12) @@ -151,5 +239,5 @@ def test_add_self_loop(): (0, 1, {"fault_ids": set(), "weight": 2.0, "error_probability": 
-1.0}), (2, 2, {"fault_ids": set(), "weight": 5.0, "error_probability": -1.0}), (3, 4, {"fault_ids": set(), "weight": 14.0, "error_probability": -1.0}), - (4, 4, {"fault_ids": set(), "weight": 10.0, "error_probability": -1.0}) + (4, 4, {"fault_ids": set(), "weight": 10.0, "error_probability": -1.0}), ] diff --git a/tests/matching/load_from_networkx_test.py b/tests/matching/load_from_networkx_test.py index 4e69d5534..0f8b4e5f1 100644 --- a/tests/matching/load_from_networkx_test.py +++ b/tests/matching/load_from_networkx_test.py @@ -22,7 +22,7 @@ def test_bad_fault_ids_raises_value_error(): g = nx.Graph() - g.add_edge(0, 1, fault_ids='test') + g.add_edge(0, 1, fault_ids="test") with pytest.raises(TypeError): Matching(g) g = nx.Graph() @@ -38,7 +38,7 @@ def test_boundary_from_networkx(): g.add_edge(1, 2, fault_ids=2) g.add_edge(2, 3, fault_ids=3) g.add_edge(3, 4, fault_ids=4) - g.nodes()[4]['is_boundary'] = True + g.nodes()[4]["is_boundary"] = True m = Matching.from_networkx(g) assert m.boundary == {4} assert np.array_equal(m.decode(np.array([1, 0, 0, 0])), np.array([1, 0, 0, 0, 0])) @@ -55,14 +55,22 @@ def test_boundaries_from_networkx(): g.add_edge(3, 4, fault_ids=3) g.add_edge(4, 5, fault_ids=4) g.add_edge(0, 5, fault_ids=-1, weight=0.0) - g.nodes()[0]['is_boundary'] = True - g.nodes()[5]['is_boundary'] = True + g.nodes()[0]["is_boundary"] = True + g.nodes()[5]["is_boundary"] = True m = Matching.from_networkx(g) assert m.boundary == {0, 5} - assert np.array_equal(m.decode(np.array([0, 1, 0, 0, 0, 0])), np.array([1, 0, 0, 0, 0])) - assert np.array_equal(m.decode(np.array([0, 0, 1, 0, 0])), np.array([1, 1, 0, 0, 0])) - assert np.array_equal(m.decode(np.array([0, 0, 1, 1, 0])), np.array([0, 0, 1, 0, 0])) - assert np.array_equal(m.decode(np.array([0, 0, 0, 1, 0])), np.array([0, 0, 0, 1, 1])) + assert np.array_equal( + m.decode(np.array([0, 1, 0, 0, 0, 0])), np.array([1, 0, 0, 0, 0]) + ) + assert np.array_equal( + m.decode(np.array([0, 0, 1, 0, 0])), np.array([1, 1, 0, 0, 0]) + ) + assert np.array_equal( + m.decode(np.array([0, 0, 1, 1, 0])), np.array([0, 0, 1, 0, 0]) + ) + assert np.array_equal( + m.decode(np.array([0, 0, 0, 1, 0])), np.array([0, 0, 0, 1, 1]) + ) def test_wrong_networkx_graph_type_raises_type_error(): @@ -83,23 +91,20 @@ def test_unweighted_stabiliser_graph_from_networkx(): w.add_edge(3, 4, fault_ids=5, weight=6.0) w.add_edge(4, 5, fault_ids=6, weight=9.0) m = Matching(w) - assert (m.num_fault_ids == 7) - assert (m.num_detectors == 6) - assert (np.array_equal( - m.decode(np.array([1, 0, 1, 0, 0, 0])), - np.array([0, 0, 1, 0, 0, 0, 0])) + assert m.num_fault_ids == 7 + assert m.num_detectors == 6 + assert np.array_equal( + m.decode(np.array([1, 0, 1, 0, 0, 0])), np.array([0, 0, 1, 0, 0, 0, 0]) ) with pytest.raises(ValueError): m.decode(np.array([1, 1, 0])) with pytest.raises(ValueError): m.decode(np.array([1, 1, 1, 0, 0, 0])) - assert (np.array_equal( - m.decode(np.array([1, 0, 0, 0, 0, 1])), - np.array([0, 0, 1, 0, 1, 0, 0])) + assert np.array_equal( + m.decode(np.array([1, 0, 0, 0, 0, 1])), np.array([0, 0, 1, 0, 1, 0, 0]) ) - assert (np.array_equal( - m.decode(np.array([0, 1, 0, 0, 0, 1])), - np.array([0, 0, 0, 0, 1, 0, 0])) + assert np.array_equal( + m.decode(np.array([0, 1, 0, 0, 0, 1])), np.array([0, 0, 0, 0, 1, 0, 0]) ) @@ -109,27 +114,27 @@ def test_mwpm_from_networkx(): g.add_edge(0, 2, fault_ids=1) g.add_edge(1, 2, fault_ids=2) m = Matching(g) - assert (isinstance(m._matching_graph, MatchingGraph)) - assert (m.num_detectors == 3) - assert (m.num_fault_ids == 3) + assert 
isinstance(m._matching_graph, MatchingGraph) + assert m.num_detectors == 3 + assert m.num_fault_ids == 3 g = nx.Graph() g.add_edge(0, 1) g.add_edge(0, 2) g.add_edge(1, 2) m = Matching(g) - assert (isinstance(m._matching_graph, MatchingGraph)) - assert (m.num_detectors == 3) - assert (m.num_fault_ids == 0) + assert isinstance(m._matching_graph, MatchingGraph) + assert m.num_detectors == 3 + assert m.num_fault_ids == 0 g = nx.Graph() g.add_edge(0, 1, weight=1.5) g.add_edge(0, 2, weight=1.7) g.add_edge(1, 2, weight=1.2) m = Matching(g) - assert (isinstance(m._matching_graph, MatchingGraph)) - assert (m.num_detectors == 3) - assert (m.num_fault_ids == 0) + assert isinstance(m._matching_graph, MatchingGraph) + assert m.num_detectors == 3 + assert m.num_fault_ids == 0 def test_matching_edges_from_networkx(): @@ -137,17 +142,16 @@ def test_matching_edges_from_networkx(): g.add_edge(0, 1, fault_ids=0, weight=1.1, error_probability=0.1) g.add_edge(1, 2, fault_ids=1, weight=2.1, error_probability=0.2) g.add_edge(2, 3, fault_ids={2, 3}, weight=0.9, error_probability=0.3) - g.nodes[0]['is_boundary'] = True - g.nodes[3]['is_boundary'] = True + g.nodes[0]["is_boundary"] = True + g.nodes[3]["is_boundary"] = True g.add_edge(0, 3, weight=0.0) m = Matching(g) es = list(m.edges()) expected_edges = [ - (0, 1, {'fault_ids': {0}, 'weight': 1.1, 'error_probability': 0.1}), - (0, 3, {'fault_ids': set(), 'weight': 0.0, 'error_probability': -1.0}), - (1, 2, {'fault_ids': {1}, 'weight': 2.1, 'error_probability': 0.2}), - (2, 3, {'fault_ids': {2, 3}, 'weight': 0.9, 'error_probability': 0.3}) - + (0, 1, {"fault_ids": {0}, "weight": 1.1, "error_probability": 0.1}), + (0, 3, {"fault_ids": set(), "weight": 0.0, "error_probability": -1.0}), + (1, 2, {"fault_ids": {1}, "weight": 2.1, "error_probability": 0.2}), + (2, 3, {"fault_ids": {2, 3}, "weight": 0.9, "error_probability": 0.3}), ] assert es == expected_edges @@ -157,17 +161,16 @@ def test_qubit_id_accepted_via_networkx(): g.add_edge(0, 1, qubit_id=0, weight=1.1, error_probability=0.1) g.add_edge(1, 2, qubit_id=1, weight=2.1, error_probability=0.2) g.add_edge(2, 3, qubit_id={2, 3}, weight=0.9, error_probability=0.3) - g.nodes[0]['is_boundary'] = True - g.nodes[3]['is_boundary'] = True + g.nodes[0]["is_boundary"] = True + g.nodes[3]["is_boundary"] = True g.add_edge(0, 3, weight=0.0) m = Matching(g) es = list(m.edges()) expected_edges = [ - (0, 1, {'fault_ids': {0}, 'weight': 1.1, 'error_probability': 0.1}), - (0, 3, {'fault_ids': set(), 'weight': 0.0, 'error_probability': -1.0}), - (1, 2, {'fault_ids': {1}, 'weight': 2.1, 'error_probability': 0.2}), - (2, 3, {'fault_ids': {2, 3}, 'weight': 0.9, 'error_probability': 0.3}) - + (0, 1, {"fault_ids": {0}, "weight": 1.1, "error_probability": 0.1}), + (0, 3, {"fault_ids": set(), "weight": 0.0, "error_probability": -1.0}), + (1, 2, {"fault_ids": {1}, "weight": 2.1, "error_probability": 0.2}), + (2, 3, {"fault_ids": {2, 3}, "weight": 0.9, "error_probability": 0.3}), ] assert es == expected_edges diff --git a/tests/matching/load_from_rustworkx_test.py b/tests/matching/load_from_rustworkx_test.py index 3bd486d40..e34808d0b 100644 --- a/tests/matching/load_from_rustworkx_test.py +++ b/tests/matching/load_from_rustworkx_test.py @@ -28,7 +28,7 @@ def test_boundary_from_rustworkx(): g.add_edge(1, 2, dict(fault_ids=2)) g.add_edge(2, 3, dict(fault_ids=3)) g.add_edge(3, 4, dict(fault_ids=4)) - g[4]['is_boundary'] = True + g[4]["is_boundary"] = True m = Matching(g) assert m.boundary == {4} assert np.array_equal(m.decode(np.array([1, 0, 0, 
0])), np.array([1, 0, 0, 0, 0])) @@ -47,14 +47,22 @@ def test_boundaries_from_rustworkx(): g.add_edge(3, 4, dict(fault_ids=3)) g.add_edge(4, 5, dict(fault_ids=4)) g.add_edge(0, 5, dict(fault_ids=-1, weight=0.0)) - g.nodes()[0]['is_boundary'] = True - g.nodes()[5]['is_boundary'] = True + g.nodes()[0]["is_boundary"] = True + g.nodes()[5]["is_boundary"] = True m = Matching(g) assert m.boundary == {0, 5} - assert np.array_equal(m.decode(np.array([0, 1, 0, 0, 0, 0])), np.array([1, 0, 0, 0, 0])) - assert np.array_equal(m.decode(np.array([0, 0, 1, 0, 0])), np.array([1, 1, 0, 0, 0])) - assert np.array_equal(m.decode(np.array([0, 0, 1, 1, 0])), np.array([0, 0, 1, 0, 0])) - assert np.array_equal(m.decode(np.array([0, 0, 0, 1, 0])), np.array([0, 0, 0, 1, 1])) + assert np.array_equal( + m.decode(np.array([0, 1, 0, 0, 0, 0])), np.array([1, 0, 0, 0, 0]) + ) + assert np.array_equal( + m.decode(np.array([0, 0, 1, 0, 0])), np.array([1, 1, 0, 0, 0]) + ) + assert np.array_equal( + m.decode(np.array([0, 0, 1, 1, 0])), np.array([0, 0, 1, 0, 0]) + ) + assert np.array_equal( + m.decode(np.array([0, 0, 0, 1, 0])), np.array([0, 0, 0, 1, 1]) + ) def test_unweighted_stabiliser_graph_from_rustworkx(): @@ -71,23 +79,20 @@ def test_unweighted_stabiliser_graph_from_rustworkx(): w.add_edge(3, 4, dict(fault_ids=5, weight=6.0)) w.add_edge(4, 5, dict(fault_ids=6, weight=9.0)) m = Matching(w) - assert (m.num_fault_ids == 7) - assert (m.num_detectors == 6) - assert (np.array_equal( - m.decode(np.array([1, 0, 1, 0, 0, 0])), - np.array([0, 0, 1, 0, 0, 0, 0])) + assert m.num_fault_ids == 7 + assert m.num_detectors == 6 + assert np.array_equal( + m.decode(np.array([1, 0, 1, 0, 0, 0])), np.array([0, 0, 1, 0, 0, 0, 0]) ) with pytest.raises(ValueError): m.decode(np.array([1, 1, 0])) with pytest.raises(ValueError): m.decode(np.array([1, 1, 1, 0, 0, 0])) - assert (np.array_equal( - m.decode(np.array([1, 0, 0, 0, 0, 1])), - np.array([0, 0, 1, 0, 1, 0, 0])) + assert np.array_equal( + m.decode(np.array([1, 0, 0, 0, 0, 1])), np.array([0, 0, 1, 0, 1, 0, 0]) ) - assert (np.array_equal( - m.decode(np.array([0, 1, 0, 0, 0, 1])), - np.array([0, 0, 0, 0, 1, 0, 0])) + assert np.array_equal( + m.decode(np.array([0, 1, 0, 0, 0, 1])), np.array([0, 0, 0, 0, 1, 0, 0]) ) @@ -99,9 +104,9 @@ def test_mwpm_from_rustworkx(): g.add_edge(0, 2, dict(fault_ids=1)) g.add_edge(1, 2, dict(fault_ids=2)) m = Matching(g) - assert (isinstance(m._matching_graph, MatchingGraph)) - assert (m.num_detectors == 3) - assert (m.num_fault_ids == 3) + assert isinstance(m._matching_graph, MatchingGraph) + assert m.num_detectors == 3 + assert m.num_fault_ids == 3 g = rx.PyGraph() g.add_nodes_from([{} for _ in range(3)]) @@ -109,9 +114,9 @@ def test_mwpm_from_rustworkx(): g.add_edge(0, 2, {}) g.add_edge(1, 2, {}) m = Matching(g) - assert (isinstance(m._matching_graph, MatchingGraph)) - assert (m.num_detectors == 3) - assert (m.num_fault_ids == 0) + assert isinstance(m._matching_graph, MatchingGraph) + assert m.num_detectors == 3 + assert m.num_fault_ids == 0 g = rx.PyGraph() g.add_nodes_from([{} for _ in range(3)]) @@ -119,9 +124,9 @@ def test_mwpm_from_rustworkx(): g.add_edge(0, 2, dict(weight=1.7)) g.add_edge(1, 2, dict(weight=1.2)) m = Matching(g) - assert (isinstance(m._matching_graph, MatchingGraph)) - assert (m.num_detectors == 3) - assert (m.num_fault_ids == 0) + assert isinstance(m._matching_graph, MatchingGraph) + assert m.num_detectors == 3 + assert m.num_fault_ids == 0 def test_matching_edges_from_rustworkx(): @@ -131,16 +136,16 @@ def 
test_matching_edges_from_rustworkx(): g.add_edge(0, 1, dict(fault_ids=0, weight=1.1, error_probability=0.1)) g.add_edge(1, 2, dict(fault_ids=1, weight=2.1, error_probability=0.2)) g.add_edge(2, 3, dict(fault_ids={2, 3}, weight=0.9, error_probability=0.3)) - g[0]['is_boundary'] = True - g[3]['is_boundary'] = True + g[0]["is_boundary"] = True + g[3]["is_boundary"] = True g.add_edge(0, 3, dict(weight=0.0)) m = Matching(g) es = list(m.edges()) expected_edges = [ - (0, 1, {'fault_ids': {0}, 'weight': 1.1, 'error_probability': 0.1}), - (1, 2, {'fault_ids': {1}, 'weight': 2.1, 'error_probability': 0.2}), - (2, 3, {'fault_ids': {2, 3}, 'weight': 0.9, 'error_probability': 0.3}), - (0, 3, {'fault_ids': set(), 'weight': 0.0, 'error_probability': -1.0}), + (0, 1, {"fault_ids": {0}, "weight": 1.1, "error_probability": 0.1}), + (1, 2, {"fault_ids": {1}, "weight": 2.1, "error_probability": 0.2}), + (2, 3, {"fault_ids": {2, 3}, "weight": 0.9, "error_probability": 0.3}), + (0, 3, {"fault_ids": set(), "weight": 0.0, "error_probability": -1.0}), ] print(es) assert es == expected_edges @@ -153,16 +158,16 @@ def test_qubit_id_accepted_via_rustworkx(): g.add_edge(0, 1, dict(qubit_id=0, weight=1.1, error_probability=0.1)) g.add_edge(1, 2, dict(qubit_id=1, weight=2.1, error_probability=0.2)) g.add_edge(2, 3, dict(qubit_id={2, 3}, weight=0.9, error_probability=0.3)) - g[0]['is_boundary'] = True - g[3]['is_boundary'] = True + g[0]["is_boundary"] = True + g[3]["is_boundary"] = True g.add_edge(0, 3, dict(weight=0.0)) m = Matching(g) es = list(m.edges()) expected_edges = [ - (0, 1, {'fault_ids': {0}, 'weight': 1.1, 'error_probability': 0.1}), - (1, 2, {'fault_ids': {1}, 'weight': 2.1, 'error_probability': 0.2}), - (2, 3, {'fault_ids': {2, 3}, 'weight': 0.9, 'error_probability': 0.3}), - (0, 3, {'fault_ids': set(), 'weight': 0.0, 'error_probability': -1.0}) + (0, 1, {"fault_ids": {0}, "weight": 1.1, "error_probability": 0.1}), + (1, 2, {"fault_ids": {1}, "weight": 2.1, "error_probability": 0.2}), + (2, 3, {"fault_ids": {2, 3}, "weight": 0.9, "error_probability": 0.3}), + (0, 3, {"fault_ids": set(), "weight": 0.0, "error_probability": -1.0}), ] assert es == expected_edges diff --git a/tests/matching/load_from_stim_test.py b/tests/matching/load_from_stim_test.py index b4e15a8a2..b857ba437 100644 --- a/tests/matching/load_from_stim_test.py +++ b/tests/matching/load_from_stim_test.py @@ -21,11 +21,15 @@ def test_load_from_stim_objects(): stim = pytest.importorskip("stim") - c = stim.Circuit.generated("surface_code:rotated_memory_x", distance=5, rounds=5, - after_clifford_depolarization=0.01, - before_measure_flip_probability=0.01, - after_reset_flip_probability=0.01, - before_round_data_depolarization=0.01) + c = stim.Circuit.generated( + "surface_code:rotated_memory_x", + distance=5, + rounds=5, + after_clifford_depolarization=0.01, + before_measure_flip_probability=0.01, + after_reset_flip_probability=0.01, + before_round_data_depolarization=0.01, + ) dem = c.detector_error_model(decompose_errors=True) m = Matching.from_detector_error_model(dem) assert m.num_detectors == dem.num_detectors @@ -62,8 +66,12 @@ def test_load_from_stim_files(data_dir: Path): def test_load_from_stim_wrong_type_raises_type_error(): stim = pytest.importorskip("stim") - c = stim.Circuit.generated("surface_code:rotated_memory_x", distance=3, rounds=1, - after_clifford_depolarization=0.01) + c = stim.Circuit.generated( + "surface_code:rotated_memory_x", + distance=3, + rounds=1, + after_clifford_depolarization=0.01, + ) with 
pytest.raises(TypeError): Matching.from_detector_error_model(c) with pytest.raises(TypeError): diff --git a/tests/matching/output_graph_test.py b/tests/matching/output_graph_test.py index 98a60badf..7950f792b 100644 --- a/tests/matching/output_graph_test.py +++ b/tests/matching/output_graph_test.py @@ -23,15 +23,15 @@ def test_matching_to_networkx(): g.add_edge(0, 1, fault_ids={0}, weight=1.1, error_probability=0.1) g.add_edge(1, 2, fault_ids={1}, weight=2.1, error_probability=0.2) g.add_edge(2, 3, fault_ids={2, 3}, weight=0.9, error_probability=0.3) - g.nodes[0]['is_boundary'] = True - g.nodes[3]['is_boundary'] = True + g.nodes[0]["is_boundary"] = True + g.nodes[3]["is_boundary"] = True g.add_edge(0, 3, weight=0.0) m = Matching(g) - g.edges[(0, 3)]['fault_ids'] = set() - g.edges[(0, 3)]['error_probability'] = -1.0 - g.nodes[1]['is_boundary'] = False - g.nodes[2]['is_boundary'] = False + g.edges[(0, 3)]["fault_ids"] = set() + g.edges[(0, 3)]["error_probability"] = -1.0 + g.nodes[1]["is_boundary"] = False + g.nodes[2]["is_boundary"] = False g2 = m.to_networkx() @@ -46,17 +46,28 @@ def test_matching_to_networkx(): m.add_edge(1, 2, weight=4) g = m.to_networkx() es = list(g.edges(data=True)) - assert es == [(0, 3, {"weight": 2.0, "error_probability": -1, "fault_ids": set()}), - (0, 1, {"weight": 3.0, "error_probability": -1, "fault_ids": set()}), - (1, 2, {"weight": 4.0, "error_probability": -1, "fault_ids": set()})] - assert sorted(list(g.nodes(data=True))) == [(0, {"is_boundary": False}), (1, {"is_boundary": False}), - (2, {"is_boundary": False}), (3, {"is_boundary": True})] + assert es == [ + (0, 3, {"weight": 2.0, "error_probability": -1, "fault_ids": set()}), + (0, 1, {"weight": 3.0, "error_probability": -1, "fault_ids": set()}), + (1, 2, {"weight": 4.0, "error_probability": -1, "fault_ids": set()}), + ] + assert sorted(list(g.nodes(data=True))) == [ + (0, {"is_boundary": False}), + (1, {"is_boundary": False}), + (2, {"is_boundary": False}), + (3, {"is_boundary": True}), + ] m = Matching() m.add_edge(0, 1) g = m.to_networkx() - assert list(g.edges(data=True)) == [(0, 1, {"weight": 1.0, "error_probability": -1, "fault_ids": set()})] - assert list(g.nodes(data=True)) == [(0, {"is_boundary": False}), (1, {"is_boundary": False})] + assert list(g.edges(data=True)) == [ + (0, 1, {"weight": 1.0, "error_probability": -1, "fault_ids": set()}) + ] + assert list(g.nodes(data=True)) == [ + (0, {"is_boundary": False}), + (1, {"is_boundary": False}), + ] def test_matching_to_rustworkx(): @@ -66,16 +77,16 @@ def test_matching_to_rustworkx(): g.add_edge(0, 1, dict(fault_ids={0}, weight=1.1, error_probability=0.1)) g.add_edge(1, 2, dict(fault_ids={1}, weight=2.1, error_probability=0.2)) g.add_edge(2, 3, dict(fault_ids={2, 3}, weight=0.9, error_probability=0.3)) - g[0]['is_boundary'] = True - g[3]['is_boundary'] = True + g[0]["is_boundary"] = True + g[3]["is_boundary"] = True g.add_edge(0, 3, dict(weight=0.0)) m = Matching(g) edge_0_3 = g.get_edge_data(0, 3) - edge_0_3['fault_ids'] = set() - edge_0_3['error_probability'] = -1.0 - g[1]['is_boundary'] = False - g[2]['is_boundary'] = False + edge_0_3["fault_ids"] = set() + edge_0_3["error_probability"] = -1.0 + g[1]["is_boundary"] = False + g[2]["is_boundary"] = False g2 = m.to_rustworkx() @@ -90,16 +101,24 @@ def test_matching_to_rustworkx(): m.add_edge(1, 2, weight=4) g = m.to_rustworkx() es = list(g.weighted_edge_list()) - assert es == [(0, 3, {"weight": 2.0, "error_probability": -1, "fault_ids": set()}), - (0, 1, {"weight": 3.0, "error_probability": -1, 
"fault_ids": set()}), - (1, 2, {"weight": 4.0, "error_probability": -1, "fault_ids": set()})] - assert list(g.nodes()) == [{"is_boundary": False}, {"is_boundary": False}, - {"is_boundary": False}, {"is_boundary": True}] + assert es == [ + (0, 3, {"weight": 2.0, "error_probability": -1, "fault_ids": set()}), + (0, 1, {"weight": 3.0, "error_probability": -1, "fault_ids": set()}), + (1, 2, {"weight": 4.0, "error_probability": -1, "fault_ids": set()}), + ] + assert list(g.nodes()) == [ + {"is_boundary": False}, + {"is_boundary": False}, + {"is_boundary": False}, + {"is_boundary": True}, + ] m = Matching() m.add_edge(0, 1) g = m.to_rustworkx() - assert list(g.weighted_edge_list()) == [(0, 1, {"weight": 1.0, "error_probability": -1, "fault_ids": set()})] + assert list(g.weighted_edge_list()) == [ + (0, 1, {"weight": 1.0, "error_probability": -1, "fault_ids": set()}) + ] assert list(g.nodes()) == [{"is_boundary": False}, {"is_boundary": False}] @@ -108,9 +127,11 @@ def test_negative_weight_edge_returned(): m.add_edge(0, 1, weight=0.5, error_probability=0.3) m.add_edge(1, 2, weight=0.5, error_probability=0.3, fault_ids=0) m.add_edge(2, 3, weight=-0.5, error_probability=0.7, fault_ids={1, 2}) - expected = [(0, 1, {'fault_ids': set(), 'weight': 0.5, 'error_probability': 0.3}), - (1, 2, {'fault_ids': {0}, 'weight': 0.5, 'error_probability': 0.3}), - (2, 3, {'fault_ids': {1, 2}, 'weight': -0.5, 'error_probability': 0.7})] + expected = [ + (0, 1, {"fault_ids": set(), "weight": 0.5, "error_probability": 0.3}), + (1, 2, {"fault_ids": {0}, "weight": 0.5, "error_probability": 0.3}), + (2, 3, {"fault_ids": {1, 2}, "weight": -0.5, "error_probability": 0.7}), + ] assert m.edges() == expected @@ -120,6 +141,8 @@ def test_self_loop_to_networkx(): m.add_edge(1, 1, weight=2) m.add_edge(1, 2, weight=1) g = m.to_networkx() - assert list(g.edges(data=True)) == [(0, 0, {'fault_ids': set(), 'weight': 3.0, 'error_probability': -1.0}), - (1, 1, {'fault_ids': set(), 'weight': 2.0, 'error_probability': -1.0}), - (1, 2, {'fault_ids': set(), 'weight': 1.0, 'error_probability': -1.0})] + assert list(g.edges(data=True)) == [ + (0, 0, {"fault_ids": set(), "weight": 3.0, "error_probability": -1.0}), + (1, 1, {"fault_ids": set(), "weight": 2.0, "error_probability": -1.0}), + (1, 2, {"fault_ids": set(), "weight": 1.0, "error_probability": -1.0}), + ] diff --git a/tests/matching/repr_test.py b/tests/matching/repr_test.py index af0c12980..8e5267829 100644 --- a/tests/matching/repr_test.py +++ b/tests/matching/repr_test.py @@ -22,9 +22,10 @@ def test_repr(): g.add_edge(0, 1, fault_ids=0) g.add_edge(1, 2, fault_ids=1) g.add_edge(2, 3, fault_ids=2) - g.nodes[0]['is_boundary'] = True - g.nodes[3]['is_boundary'] = True + g.nodes[0]["is_boundary"] = True + g.nodes[3]["is_boundary"] = True g.add_edge(0, 3, weight=0.0) m = Matching(g) - assert m.__repr__() == ("") + assert m.__repr__() == ( + "" + ) diff --git a/tests/matching/test_reweight.py b/tests/matching/test_reweight.py new file mode 100644 index 000000000..9951f0b4c --- /dev/null +++ b/tests/matching/test_reweight.py @@ -0,0 +1,173 @@ +import numpy as np +import pymatching +import pytest + + +def test_decode_reweight(): + # Simple graph: 0 -- 1 -- 2 + # Edge (0, 1) weight 2 + # Edge (1, 2) weight 2 + m = pymatching.Matching() + m.add_edge(0, 1, fault_ids={0}, weight=2) + m.add_edge(1, 2, fault_ids={1}, weight=2) + + # decode([1, 0, 1]) -> detection events 0, 2 + # Should match 0 to 2 via 1. Weight 4. 
+ res, weight = m.decode(np.array([1, 0, 1]), return_weight=True) + assert weight == 4.0 + + # Reweight (0, 1) to 5. New weight 5+2 = 7. + reweights = np.array([[0, 1, 5.0]]) + res, weight = m.decode( + np.array([1, 0, 1]), return_weight=True, edge_reweights=reweights + ) + assert weight == 7.0 + + # Check weights restored + res, weight = m.decode(np.array([1, 0, 1]), return_weight=True) + assert weight == 4.0 + + +def test_decode_reweight_boundary(): + m = pymatching.Matching() + m.add_boundary_edge(0, fault_ids={0}, weight=2) + m.add_edge(0, 1, fault_ids={1}, weight=3) + + # decode([1, 0]) -> event at 0. Matches to boundary (weight 2). + res, weight = m.decode(np.array([1, 0]), return_weight=True) + assert weight == 2.0 + + # Reweight boundary edge to 5. + reweights = np.array([[0, -1, 5.0]]) + res, weight = m.decode( + np.array([1, 0]), return_weight=True, edge_reweights=reweights + ) + assert weight == 5.0 + + # Restored + res, weight = m.decode(np.array([1, 0]), return_weight=True) + assert weight == 2.0 + + +def test_decode_batch_reweight(): + m = pymatching.Matching() + m.add_edge(0, 1, fault_ids={0}, weight=2) + m.add_edge(1, 2, fault_ids={1}, weight=2) + + shots = np.array([[1, 0, 1], [1, 0, 1]], dtype=np.uint8) + + # Shot 0: reweight (0, 1) to 5. Expected weight 7. + # Shot 1: no reweight. Expected weight 4. + + reweights = [np.array([[0, 1, 5.0]]), None] + + preds, weights = m.decode_batch( + shots, return_weights=True, edge_reweights=reweights + ) + assert weights[0] == 7.0 + assert weights[1] == 4.0 + + # Check restored + preds, weights = m.decode_batch(shots, return_weights=True) + assert weights[0] == 4.0 + assert weights[1] == 4.0 + + +def test_decode_batch_reweight_all_same(): + m = pymatching.Matching() + m.add_edge(0, 1, fault_ids={0}, weight=2) + + shots = np.array([[1, 1], [1, 1]], dtype=np.uint8) + # Reweight to 5 + rw = np.array([[0, 1, 5.0]]) + reweights = [rw, rw] + + preds, weights = m.decode_batch( + shots, return_weights=True, edge_reweights=reweights + ) + assert weights[0] == 5.0 + assert weights[1] == 5.0 + + preds, weights = m.decode_batch(shots, return_weights=True) + assert weights[0] == 2.0 + + +def test_decode_reweight_large_observables(): + # If num_observables > 64, the search graph should be present even if enable_correlations=False. + m = pymatching.Matching() + # Add enough edges with unique fault_ids to exceed 64 observables + for i in range(70): + m.add_edge(i, i + 1, fault_ids={i}, weight=1) + + assert m.num_fault_ids >= 70 + + # Decode a simple case: error on edge (0, 1) + # Expected weight 1. + syndrome = np.zeros(m.num_nodes, dtype=np.uint8) + syndrome[0] = 1 + syndrome[1] = 1 + + res, weight = m.decode(syndrome, return_weight=True) + assert weight == 1.0 + + # Reweight edge (0, 1) to 10. 
+ reweights = np.array([[0, 1, 10.0]]) + res, weight = m.decode(syndrome, return_weight=True, edge_reweights=reweights) + assert weight == 10.0 + + # Verify the weight is restored + res, weight = m.decode(syndrome, return_weight=True) + assert weight == 1.0 + + +def test_reweight_sign_flip_raises_error(): + m = pymatching.Matching() + m.add_edge(0, 1, weight=2) + m.add_edge(1, 2, weight=-3) + + # Positive to negative (flip) + with pytest.raises(ValueError, match="sign flip not allowed"): + m.decode(np.array([1, 0, 1]), edge_reweights=np.array([[0, 1, -5.0]])) + + # Negative to positive (flip) + with pytest.raises(ValueError, match="sign flip not allowed"): + m.decode(np.array([1, 0, 1]), edge_reweights=np.array([[1, 2, 3.0]])) + + +def test_reweight_negative_to_negative(): + # Graph: 0 -- 1 -- 2 + # (0, 1) weight 5 + # (1, 2) weight -3. + # Solution for detection events at 0, 2. + # Standard matching: 0 matches to 2 via 1. Path: (0,1), (1,2). + # Cost: 5 + (-3) = 2. + # Note: Negative weight -3 means edge (1,2) is pre-flipped. + # Events at 0, 2 means syndrome is 1 at 0, 1 at 2. + # If (1,2) is pre-flipped, it causes events at 1, 2. + # Observed syndrome: 0:1, 1:0, 2:1. + # Adjusted syndrome (xor with negative weight syndrome): + # 0:1, 1:1, 2:0. + # Now we match 0 and 1. Path (0, 1) cost 5. + # Total cost = 5 + (-3) = 2. + + m = pymatching.Matching() + m.add_edge(0, 1, weight=5) + m.add_edge(1, 2, weight=-3) + + # Check baseline + res, weight = m.decode(np.array([1, 0, 1]), return_weight=True) + assert weight == 2.0 + + # Reweight (1, 2) to -10. + # New cost calculation: + # Path (0, 1) cost 5. + # Total cost = 5 + (-10) = -5. + reweights = np.array([[1, 2, -10.0]]) + res, weight = m.decode( + np.array([1, 0, 1]), return_weight=True, edge_reweights=reweights + ) + assert weight == -5.0 + + # Verify restoration + res, weight = m.decode(np.array([1, 0, 1]), return_weight=True) + assert weight == 2.0 diff --git a/tests/rand/rand_gen_test.py b/tests/rand/rand_gen_test.py index b13fc2bf4..bc6a0ef49 100644 --- a/tests/rand/rand_gen_test.py +++ b/tests/rand/rand_gen_test.py @@ -6,12 +6,14 @@ def test_add_noise_without_error_probabilities_returns_none(): m = Matching(csr_matrix(np.array([[1, 1, 0], [0, 1, 1]]))) assert m.add_noise() is None - m = Matching(csr_matrix(np.array([[1, 1, 0], [0, 1, 1]])), - error_probabilities=np.array([0.5, 0.7, -0.1])) + m = Matching( + csr_matrix(np.array([[1, 1, 0], [0, 1, 1]])), + error_probabilities=np.array([0.5, 0.7, -0.1]), + ) assert m.add_noise() is None def test_rand_float(): N = 1000 - s = sum(rand_float(0., 1.) for i in range(N)) + s = sum(rand_float(0.0, 1.0) for i in range(N)) assert 430 < s < 570
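Usage sketch, distilled from tests/matching/test_reweight.py above: the per-shot reweighting is driven by an `edge_reweights` argument to `Matching.decode` and `Matching.decode_batch`, whose rows are `[u, v, new_weight]`, with `v = -1` selecting the boundary edge of node `u`; weights are restored after each call and flipping a weight's sign raises `ValueError`. The small graph below is purely illustrative and not taken from the diff.

import numpy as np
import pymatching

# Illustrative graph: boundary -- 0 -- 1 -- 2
m = pymatching.Matching()
m.add_boundary_edge(0, fault_ids={0}, weight=2)
m.add_edge(0, 1, fault_ids={1}, weight=2)
m.add_edge(1, 2, fault_ids={2}, weight=2)

syndrome = np.array([1, 0, 1])

# Decode with the weights stored in the graph.
prediction, weight = m.decode(syndrome, return_weight=True)

# Temporarily override edge weights for this call only. Each row is
# [u, v, new_weight]; v = -1 selects the boundary edge of node u.
# The original weights are restored afterwards, and changing the
# sign of a weight raises ValueError.
reweights = np.array([[0, 1, 5.0], [0, -1, 4.0]])
prediction, weight = m.decode(syndrome, return_weight=True, edge_reweights=reweights)

# decode_batch accepts one reweight array (or None) per shot.
shots = np.array([[1, 0, 1], [1, 0, 1]], dtype=np.uint8)
predictions, weights = m.decode_batch(
    shots, return_weights=True, edge_reweights=[reweights, None]
)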