diff --git a/releasenotes/notes/general-transitive-closure-2b6bee50af59cf8a.yaml b/releasenotes/notes/general-transitive-closure-2b6bee50af59cf8a.yaml new file mode 100644 index 0000000000..312721c349 --- /dev/null +++ b/releasenotes/notes/general-transitive-closure-2b6bee50af59cf8a.yaml @@ -0,0 +1,7 @@ +--- +features: + - | + Added a new function ``transitive_closure`` to rustworkx which returns the + transitive closure of a graph. The transitive closure of G = (V,E) is a graph + G+ = (V,E+) such that for all v, w in V there is an edge (v, w) in E+ if and + only if there is a path from v to w in G. \ No newline at end of file diff --git a/rustworkx/__init__.pyi b/rustworkx/__init__.pyi index 11edc5922e..82ce213803 100644 --- a/rustworkx/__init__.pyi +++ b/rustworkx/__init__.pyi @@ -214,6 +214,7 @@ from .rustworkx import graph_tensor_product as graph_tensor_product from .rustworkx import graph_token_swapper as graph_token_swapper from .rustworkx import digraph_transitivity as digraph_transitivity from .rustworkx import graph_transitivity as graph_transitivity +from .rustworkx import transitive_closure as transitive_closure from .rustworkx import digraph_bfs_search as digraph_bfs_search from .rustworkx import graph_bfs_search as graph_bfs_search from .rustworkx import digraph_dfs_search as digraph_dfs_search diff --git a/src/lib.rs b/src/lib.rs index cce7c91755..2cd35c9c44 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -542,6 +542,7 @@ fn rustworkx(py: Python<'_>, m: &Bound) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(minimum_spanning_edges))?; m.add_wrapped(wrap_pyfunction!(minimum_spanning_tree))?; m.add_wrapped(wrap_pyfunction!(graph_transitivity))?; + m.add_wrapped(wrap_pyfunction!(transitive_closure))?; m.add_wrapped(wrap_pyfunction!(digraph_transitivity))?; m.add_wrapped(wrap_pyfunction!(graph_token_swapper))?; m.add_wrapped(wrap_pyfunction!(graph_core_number))?; diff --git a/src/transitivity.rs b/src/transitivity.rs index 6bcb25f94f..f8dc164e28 100644 --- a/src/transitivity.rs +++ b/src/transitivity.rs @@ -13,11 +13,20 @@ use super::{digraph, graph}; use hashbrown::HashSet; +use crate::digraph::PyDiGraph; +use petgraph::algo::kosaraju_scc; +use petgraph::algo::DfsSpace; +use petgraph::graph::DiGraph; use pyo3::prelude::*; +use petgraph::visit::EdgeRef; +use petgraph::visit::IntoEdgeReferences; +use petgraph::visit::NodeCount; use petgraph::graph::NodeIndex; use rayon::prelude::*; +use rustworkx_core::traversal::build_transitive_closure_dag; + fn _graph_triangles(graph: &graph::PyGraph, node: usize) -> (usize, usize) { let mut triangles: usize = 0; @@ -186,3 +195,59 @@ pub fn digraph_transitivity(graph: &digraph::PyDiGraph) -> f64 { _ => triangles as f64 / triples as f64, } } + +/// Returns the transitive closure of a graph +#[pyfunction] +#[pyo3(text_signature = "(graph, /")] +pub fn transitive_closure(py: Python, graph: &PyDiGraph) -> PyResult { + let sccs = kosaraju_scc(&graph.graph); + + let mut condensed_graph = DiGraph::new(); + let mut scc_nodes = Vec::new(); + let mut scc_map: Vec = vec![NodeIndex::end(); graph.node_count()]; + + for scc in &sccs { + let scc_node = condensed_graph.add_node(()); + scc_nodes.push(scc_node); + for node in scc { + scc_map[node.index()] = scc_node; + } + } + for edge in graph.graph.edge_references() { + let (source, target) = (edge.source(), edge.target()); + + if scc_map[source.index()] != scc_map[target.index()] { + condensed_graph.add_edge(scc_map[source.index()], scc_map[target.index()], ()); + } + } + + let closure_graph_result = build_transitive_closure_dag(condensed_graph, None, || {}); + let out_graph = closure_graph_result.unwrap(); + + let mut new_graph = graph.graph.clone(); + new_graph.clear(); + + let mut result_map: Vec = vec![NodeIndex::end(); out_graph.node_count()]; + for (_index, node) in out_graph.node_indices().enumerate() { + let result_node = new_graph.add_node(py.None()); + result_map[node.index()] = result_node; + } + for edge in out_graph.edge_references() { + let (source, target) = (edge.source(), edge.target()); + new_graph.add_edge( + result_map[source.index()], + result_map[target.index()], + py.None(), + ); + } + let out = PyDiGraph { + graph: new_graph, + cycle_state: DfsSpace::default(), + check_cycle: false, + node_removed: false, + multigraph: true, + attrs: py.None(), + }; + + Ok(out) +} diff --git a/tests/graph/test_transitive_closure.py b/tests/graph/test_transitive_closure.py new file mode 100644 index 0000000000..cc0d9c6674 --- /dev/null +++ b/tests/graph/test_transitive_closure.py @@ -0,0 +1,70 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest + +import rustworkx + + +class TestTransitive(unittest.TestCase): + def test_transitive_closure(self): + + graph = rustworkx.PyDiGraph() + graph.add_nodes_from(list(range(4))) + graph.add_edge(0, 1, ()) + graph.add_edge(1, 2, ()) + graph.add_edge(2, 0, ()) + graph.add_edge(2, 3, ()) + + closure_graph = rustworkx.transitive_closure(graph) + self.expected_edges = [ + (0, 1), + (0, 2), + (0, 3), + (1, 0), + (1, 2), + (1, 3), + (2, 0), + (2, 1), + (2, 3), + ] + + self.assertEqualEdgeList(self.expected_edges, closure_graph.edge_list()) + + def test_transitive_closure_single_node(self): + graph = rustworkx.PyDiGraph() + graph.add_node(()) + closure_graph = rustworkx.transitive_closure(graph) + expected_edges = [] + self.assertEqualEdgeList(expected_edges, closure_graph.edge_list()) + + def test_transitive_closure_no_edges(self): + graph = rustworkx.PyDiGraph() + graph.add_nodes_from(list(range(4))) + closure_graph = rustworkx.transitive_closure(graph) + expected_edges = [] + self.assertEqualEdgeList(expected_edges, closure_graph.edge_list()) + + def test_transitive_closure_complete_graph(self): + graph = rustworkx.PyDiGraph() + graph.add_nodes_from(list(range(4))) + for i in range(4): + for j in range(4): + if i != j: + graph.add_edge(i, j, ()) + closure_graph = rustworkx.transitive_closure(graph) + expected_edges = [(i, j) for i in range(4) for j in range(4) if i != j] + self.assertEqualEdgeList(expected_edges, closure_graph.edge_list()) + + def assertEqualEdgeList(self, expected, actual): + for edge in actual: + self.assertTrue(edge in expected) \ No newline at end of file