Skip to content

Commit d660a6d

Browse files
authored
add C implementation of MST (#5044)
This PR adds the C, PLC and python API for Minimum Spanning Tree closes #4882 Authors: - Joseph Nke (https://github.com/jnke2016) Approvers: - Rick Ratzel (https://github.com/rlratzel) - Chuck Hastings (https://github.com/ChuckHastings) URL: #5044
1 parent d692694 commit d660a6d

File tree

11 files changed

+688
-3
lines changed

11 files changed

+688
-3
lines changed

cpp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,7 @@ add_library(cugraph_c
554554
src/c_api/edgelist.cpp
555555
src/c_api/renumber_arbitrary_edgelist.cu
556556
src/c_api/legacy_fa2.cpp
557+
src/c_api/legacy_mst.cpp
557558
)
558559
add_library(cugraph::cugraph_c ALIAS cugraph_c)
559560

cpp/include/cugraph_c/algorithms.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -32,6 +32,7 @@
3232
#include <cugraph_c/sampling_algorithms.h>
3333
#include <cugraph_c/similarity_algorithms.h>
3434
#include <cugraph_c/traversal_algorithms.h>
35+
#include <cugraph_c/tree_algorithms.h>
3536
/**
3637
* @}
3738
*/
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Copyright (c) 2025, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#pragma once
18+
19+
#include <cugraph_c/error.h>
20+
#include <cugraph_c/graph.h>
21+
#include <cugraph_c/graph_functions.h>
22+
#include <cugraph_c/random.h>
23+
#include <cugraph_c/resource_handle.h>
24+
25+
/** @defgroup layout Layout algorithms
26+
*/
27+
28+
#ifdef __cplusplus
29+
extern "C" {
30+
#endif
31+
32+
/**
33+
* @brief Opaque layout output
34+
*/
35+
typedef struct {
36+
int32_t align_;
37+
} cugraph_layout_result_t;
38+
39+
/**
40+
* @brief Minimum Spanning Tree
41+
*
42+
* NOTE: This currently wraps the legacy minimum implementation and is only
43+
* available in Single GPU implementation.
44+
*
45+
* @param [in] handle Handle for accessing resources
46+
* @param [in] graph Pointer to graph. NOTE: Graph might be modified if the storage
47+
* needs to be transposed
48+
* @param [in] do_expensive_check
49+
* A flag to run expensive checks for input arguments (if set to true)
50+
* @param [out] result Opaque object containing the extracted subgraph
51+
* @param [out] error Pointer to an error object storing details of any error. Will
52+
* be populated if error code is not CUGRAPH_SUCCESS
53+
* @return error code
54+
*/
55+
cugraph_error_code_t cugraph_minimum_spanning_tree(const cugraph_resource_handle_t* handle,
56+
cugraph_graph_t* graph,
57+
bool_t do_expensive_check,
58+
cugraph_induced_subgraph_result_t** result,
59+
cugraph_error_t** error);
60+
61+
#ifdef __cplusplus
62+
}
63+
#endif

cpp/src/c_api/legacy_mst.cpp

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
/*
2+
* Copyright (c) 2025, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "c_api/abstract_functor.hpp"
18+
#include "c_api/capi_helper.hpp"
19+
#include "c_api/graph.hpp"
20+
#include "c_api/induced_subgraph_result.hpp"
21+
#include "c_api/random.hpp"
22+
#include "c_api/resource_handle.hpp"
23+
#include "c_api/utils.hpp"
24+
25+
#include <cugraph_c/algorithms.h>
26+
27+
#include <cugraph/algorithms.hpp>
28+
#include <cugraph/detail/utility_wrappers.hpp>
29+
#include <cugraph/graph_functions.hpp>
30+
31+
#include <optional>
32+
33+
namespace {
34+
35+
struct minimum_spanning_tree_functor : public cugraph::c_api::abstract_functor {
36+
raft::handle_t const& handle_;
37+
cugraph::c_api::cugraph_graph_t* graph_{nullptr};
38+
bool do_expensive_check_{};
39+
cugraph::c_api::cugraph_induced_subgraph_result_t* result_{};
40+
;
41+
42+
minimum_spanning_tree_functor(::cugraph_resource_handle_t const* handle,
43+
::cugraph_graph_t* graph,
44+
bool do_expensive_check)
45+
: abstract_functor(),
46+
handle_(*reinterpret_cast<cugraph::c_api::cugraph_resource_handle_t const*>(handle)->handle_),
47+
graph_(reinterpret_cast<cugraph::c_api::cugraph_graph_t*>(graph)),
48+
do_expensive_check_(do_expensive_check)
49+
{
50+
}
51+
52+
template <typename vertex_t,
53+
typename edge_t,
54+
typename weight_t,
55+
typename edge_type_type_t,
56+
bool store_transposed,
57+
bool multi_gpu>
58+
void operator()()
59+
{
60+
if constexpr (!cugraph::is_candidate<vertex_t, edge_t, weight_t>::value) {
61+
unsupported();
62+
} else if constexpr (multi_gpu) {
63+
unsupported();
64+
} else if constexpr (!std::is_same_v<edge_t, int32_t>) {
65+
unsupported();
66+
} else {
67+
auto graph =
68+
reinterpret_cast<cugraph::graph_t<vertex_t, edge_t, false, false>*>(graph_->graph_);
69+
70+
auto edge_weights =
71+
reinterpret_cast<cugraph::edge_property_t<edge_t, weight_t>*>(graph_->edge_weights_);
72+
73+
auto number_map = reinterpret_cast<rmm::device_uvector<vertex_t>*>(graph_->number_map_);
74+
75+
auto graph_view = graph->view();
76+
auto edge_partition_view = graph_view.local_edge_partition_view();
77+
78+
rmm::device_uvector<weight_t> tmp_weights(0, handle_.get_stream());
79+
if (edge_weights == nullptr) {
80+
tmp_weights.resize(edge_partition_view.indices().size(), handle_.get_stream());
81+
cugraph::detail::scalar_fill(handle_, tmp_weights.data(), tmp_weights.size(), weight_t{1});
82+
}
83+
84+
cugraph::legacy::GraphCSRView<vertex_t, edge_t, weight_t> legacy_csr_graph_view(
85+
const_cast<edge_t*>(edge_partition_view.offsets().data()),
86+
const_cast<vertex_t*>(edge_partition_view.indices().data()),
87+
(edge_weights == nullptr)
88+
? tmp_weights.data()
89+
: const_cast<weight_t*>(edge_weights->view().value_firsts().front()),
90+
edge_partition_view.offsets().size() - 1,
91+
edge_partition_view.indices().size());
92+
93+
auto result_legacy_coo_graph =
94+
cugraph::minimum_spanning_tree<vertex_t, edge_t, weight_t>(handle_, legacy_csr_graph_view);
95+
96+
const size_t num_edges = result_legacy_coo_graph->view().number_of_edges;
97+
98+
// FIXME: Add new constructor for cugraph_type_erased_host_array_t that takes an
99+
// rmm::device_buffer with the goa of skipping copies
100+
101+
rmm::device_uvector<vertex_t> result_src(num_edges, handle_.get_stream());
102+
raft::copy(result_src.data(),
103+
result_legacy_coo_graph->view().src_indices,
104+
result_src.size(),
105+
handle_.get_stream());
106+
107+
rmm::device_uvector<vertex_t> result_dst(num_edges, handle_.get_stream());
108+
raft::copy(result_dst.data(),
109+
result_legacy_coo_graph->view().dst_indices,
110+
result_dst.size(),
111+
handle_.get_stream());
112+
113+
std::optional<rmm::device_uvector<weight_t>> result_wgt{std::nullopt};
114+
115+
result_wgt = rmm::device_uvector<weight_t>{num_edges, handle_.get_stream()};
116+
raft::copy(result_wgt->data(),
117+
result_legacy_coo_graph->view().edge_data,
118+
result_wgt->size(),
119+
handle_.get_stream());
120+
121+
cugraph::unrenumber_int_vertices<vertex_t, multi_gpu>(
122+
handle_,
123+
result_src.data(),
124+
result_src.size(),
125+
number_map->data(),
126+
graph_view.vertex_partition_range_lasts(),
127+
do_expensive_check_);
128+
129+
cugraph::unrenumber_int_vertices<vertex_t, multi_gpu>(
130+
handle_,
131+
result_dst.data(),
132+
result_dst.size(),
133+
number_map->data(),
134+
graph_view.vertex_partition_range_lasts(),
135+
do_expensive_check_);
136+
137+
rmm::device_uvector<size_t> edge_offsets(2, handle_.get_stream());
138+
std::vector<size_t> h_edge_offsets{{0, num_edges}};
139+
raft::update_device(
140+
edge_offsets.data(), h_edge_offsets.data(), h_edge_offsets.size(), handle_.get_stream());
141+
142+
// FIXME: Add support for edge_id and edge_type_id.
143+
result_ = new cugraph::c_api::cugraph_induced_subgraph_result_t{
144+
new cugraph::c_api::cugraph_type_erased_device_array_t(result_src, graph_->vertex_type_),
145+
new cugraph::c_api::cugraph_type_erased_device_array_t(result_dst, graph_->vertex_type_),
146+
result_wgt ? new cugraph::c_api::cugraph_type_erased_device_array_t(*result_wgt,
147+
graph_->weight_type_)
148+
: NULL,
149+
NULL,
150+
NULL,
151+
new cugraph::c_api::cugraph_type_erased_device_array_t(edge_offsets,
152+
cugraph_data_type_id_t::SIZE_T)};
153+
}
154+
}
155+
};
156+
157+
} // namespace
158+
159+
extern "C" cugraph_error_code_t cugraph_minimum_spanning_tree(
160+
const cugraph_resource_handle_t* handle,
161+
cugraph_graph_t* graph,
162+
bool_t do_expensive_check,
163+
cugraph_induced_subgraph_result_t** result,
164+
cugraph_error_t** error)
165+
{
166+
minimum_spanning_tree_functor functor(handle, graph, do_expensive_check);
167+
168+
return cugraph::c_api::run_algorithm(graph, functor, result, error);
169+
}

cpp/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,7 @@ ConfigureCTest(CAPI_COUNT_MULTI_EDGES c_api/count_multi_edges_test.c)
935935
ConfigureCTest(CAPI_EGONET_TEST c_api/egonet_test.c)
936936
ConfigureCTest(CAPI_TWO_HOP_NEIGHBORS_TEST c_api/two_hop_neighbors_test.c)
937937
ConfigureCTest(CAPI_K_TRUSS_TEST c_api/k_truss_test.c)
938+
ConfigureCTest(CAPI_MST_TEST c_api/legacy_mst_test.c)
938939

939940
if (BUILD_CUGRAPH_MTMG_TESTS)
940941
###################################################################################################

0 commit comments

Comments
 (0)