Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions cpp/include/cugraph/mtmg/instance_manager.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
* Copyright (c) 2023-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -51,7 +51,11 @@ class instance_manager_t {
{
for (size_t i = 0; i < nccl_comms_.size(); ++i) {
rmm::cuda_set_device_raii local_set_device(device_ids_[i]);
RAFT_NCCL_TRY(ncclCommDestroy(*nccl_comms_[i]));
try {
RAFT_NCCL_TRY(ncclCommDestroy(*nccl_comms_[i]));
} catch (const std::exception& e) {
std::cerr << "Error destroying NCCL communication: " << e.what() << std::endl;
}
}
}

Expand Down
16 changes: 16 additions & 0 deletions cpp/include/cugraph/sampling_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,16 @@ biased_neighbor_sample(
bool dedupe_sources = false,
bool do_expensive_check = false);

enum class temporal_sampling_comparison_t {
STRICTLY_INCREASING = 0, /** Time strictly increasing (each time is after the previous one) */
MONOTONICALLY_INCREASING, /** Time monotonically increasing (could have multiple edges with same
time) */
STRICTLY_DECREASING, /** Time strictly decreasing (each time is before the previous one) */
MONOTONICALLY_DECREASING, /** Time monotonically decreasing (could have multiple edges with same
time) */
LAST /** Support last n behavior */
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LAST /** Support last n behavior */

This sounds vague to me.

Isn't this something like NUM_COMPARISON_TYPES? (basically the number of supported temporal sampling comparison methods/modes/types... and so on)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not the end of the enum. It's actually to support a feature that will be implemented later called "last n".

If LAST is specified, the software will (eventually, not part of this PR) use the fanout value (call it n) and instead of randomly selecting n edges that leave this vertex it will select the n edges with the largest time stamps. You can think of this as sorting the outgoing edges by time and selecting the "last n" of them.

This is a feature that we'll add later (perhaps later in 25.12, perhaps not until 26.01, not sure of the priorities and how far we'll progress).

};

struct sampling_flags_t {
/**
* Specifies how to handle prior sources. Default is DEFAULT.
Expand All @@ -277,6 +287,12 @@ struct sampling_flags_t {
* (true) or without replacement (false). Default is true.
*/
bool with_replacement{true};

/**
* Specifies how to handle temporal sampling. Default is STRICTLY_INCREASING.
*/
temporal_sampling_comparison_t temporal_sampling_comparison{
temporal_sampling_comparison_t::STRICTLY_INCREASING};
};

/**
Expand Down
33 changes: 29 additions & 4 deletions cpp/src/c_api/temporal_neighbor_sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,27 @@ struct temporal_neighbor_sampling_functor : public cugraph::c_api::abstract_func
std::optional<rmm::device_uvector<label_t>> edge_label{std::nullopt};
std::optional<rmm::device_uvector<size_t>> offsets{std::nullopt};

cugraph::temporal_sampling_comparison_t temporal_sampling_comparison{};
switch (options_.temporal_sampling_comparison_) {
case cugraph_temporal_sampling_comparison_t::STRICTLY_INCREASING:
temporal_sampling_comparison =
cugraph::temporal_sampling_comparison_t::STRICTLY_INCREASING;
break;
case cugraph_temporal_sampling_comparison_t::MONOTONICALLY_INCREASING:
temporal_sampling_comparison =
cugraph::temporal_sampling_comparison_t::MONOTONICALLY_INCREASING;
break;
case cugraph_temporal_sampling_comparison_t::STRICTLY_DECREASING:
temporal_sampling_comparison =
cugraph::temporal_sampling_comparison_t::STRICTLY_DECREASING;
break;
case cugraph_temporal_sampling_comparison_t::MONOTONICALLY_DECREASING:
temporal_sampling_comparison =
cugraph::temporal_sampling_comparison_t::MONOTONICALLY_DECREASING;
break;
default: CUGRAPH_FAIL("Invalid temporal sampling comparison type");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auto temporal_sampling_comparison = options_.temporal_sampling_comparison_;

Won't this work?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're converting between types here. I'm hesitant to assume that translating from the C enum definition in the .h file to the C++ enum class definition in the .hpp file is guaranteed to work by assignment. This was my simple solution.

};

// FIXME: For biased sampling, the user should pass either biases or edge weights,
// otherwise throw an error and suggest the user to call uniform neighbor sample instead

Expand Down Expand Up @@ -321,7 +342,8 @@ struct temporal_neighbor_sampling_functor : public cugraph::c_api::abstract_func
cugraph::sampling_flags_t{options_.prior_sources_behavior_,
options_.return_hops_,
options_.dedupe_sources_,
options_.with_replacement_},
options_.with_replacement_,
temporal_sampling_comparison},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or can we just call this with options_.temporal_sampling_comparison_? Why are we creating a temporary?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type conversion as mentioned above

do_expensive_check_);
} else {
std::tie(sampled_edge_srcs,
Expand Down Expand Up @@ -356,7 +378,8 @@ struct temporal_neighbor_sampling_functor : public cugraph::c_api::abstract_func
cugraph::sampling_flags_t{options_.prior_sources_behavior_,
options_.return_hops_,
options_.dedupe_sources_,
options_.with_replacement_},
options_.with_replacement_,
temporal_sampling_comparison},
do_expensive_check_);
}
} else {
Expand Down Expand Up @@ -394,7 +417,8 @@ struct temporal_neighbor_sampling_functor : public cugraph::c_api::abstract_func
cugraph::sampling_flags_t{options_.prior_sources_behavior_,
options_.return_hops_,
options_.dedupe_sources_,
options_.with_replacement_},
options_.with_replacement_,
temporal_sampling_comparison},
do_expensive_check_);
} else {
std::tie(sampled_edge_srcs,
Expand Down Expand Up @@ -428,7 +452,8 @@ struct temporal_neighbor_sampling_functor : public cugraph::c_api::abstract_func
cugraph::sampling_flags_t{options_.prior_sources_behavior_,
options_.return_hops_,
options_.dedupe_sources_,
options_.with_replacement_},
options_.with_replacement_,
temporal_sampling_comparison},
do_expensive_check_);
}
}
Expand Down
60 changes: 44 additions & 16 deletions cpp/src/sampling/detail/gather_one_hop_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <cugraph/arithmetic_variant_types.hpp>
#include <cugraph/edge_property.hpp>
#include <cugraph/sampling_functions.hpp>
#include <cugraph/utilities/mask_utils.cuh>

#include <raft/util/cudart_utils.hpp>
Expand Down Expand Up @@ -355,6 +356,7 @@ temporal_gather_one_hop_edgelist(
raft::device_span<edge_time_t const> active_major_times,
std::optional<raft::device_span<int32_t const>> active_major_labels,
std::optional<raft::device_span<uint8_t const>> gather_flags,
temporal_sampling_comparison_t temporal_sampling_comparison,
bool do_expensive_check)
{
constexpr bool store_transposed = false;
Expand Down Expand Up @@ -532,24 +534,48 @@ temporal_gather_one_hop_edgelist(
tmp_positions
? detail::mark_entries(handle,
edge_times.size(),
[d_tmp = edge_times.data(),
[temporal_sampling_comparison,
d_tmp = edge_times.data(),
d_tmp_positions = tmp_positions->data(),
kv_store_view =
kv_binary_search_store_device_view_t<decltype(kv_store.view())>{
kv_store.view()}] __device__(auto index) {
auto edge_time = d_tmp[index];
auto key_time =
cuda::std::get<0>(kv_store_view.find(d_tmp_positions[index]));
return (edge_time > key_time);

switch (temporal_sampling_comparison) {
case temporal_sampling_comparison_t::STRICTLY_INCREASING:
return (edge_time > key_time);
case temporal_sampling_comparison_t::MONOTONICALLY_INCREASING:
return (edge_time >= key_time);
case temporal_sampling_comparison_t::STRICTLY_DECREASING:
return (edge_time < key_time);
case temporal_sampling_comparison_t::MONOTONICALLY_DECREASING:
return (edge_time <= key_time);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert(false); // never be reached

I assume this part should never be reached...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll make that change.

assert(false);
})
: detail::mark_entries(
handle,
edge_times.size(),
[d_tmp = edge_times.data(), d_tmp_time = tmp_times->data()] __device__(auto index) {
auto edge_time = d_tmp[index];
auto key_time = d_tmp_time[index];
return (edge_time > key_time);
});
: detail::mark_entries(handle,
edge_times.size(),
[temporal_sampling_comparison,
d_tmp = edge_times.data(),
d_tmp_time = tmp_times->data()] __device__(auto index) {
auto edge_time = d_tmp[index];
auto key_time = d_tmp_time[index];

switch (temporal_sampling_comparison) {
case temporal_sampling_comparison_t::STRICTLY_INCREASING:
return (edge_time > key_time);
case temporal_sampling_comparison_t::MONOTONICALLY_INCREASING:
return (edge_time >= key_time);
case temporal_sampling_comparison_t::STRICTLY_DECREASING:
return (edge_time < key_time);
case temporal_sampling_comparison_t::MONOTONICALLY_DECREASING:
return (edge_time <= key_time);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

assert(false); // never be reached

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll make that change

assert(false);
});

raft::device_span<uint32_t const> marked_entry_span{marked_entries.data(),
marked_entries.size()};
Expand All @@ -574,12 +600,14 @@ temporal_gather_one_hop_edgelist(
handle, std::move(*tmp_positions), marked_entry_span, keep_count);
}

result_labels = rmm::device_uvector<label_t>(keep_count, handle.get_stream());
kv_store.view().find(
tmp_positions->begin(),
tmp_positions->end(),
thrust::make_zip_iterator(thrust::make_discard_iterator(), result_labels->begin()),
handle.get_stream());
if (active_major_labels) {
result_labels = rmm::device_uvector<label_t>(keep_count, handle.get_stream());
kv_store.view().find(
tmp_positions->begin(),
tmp_positions->end(),
thrust::make_zip_iterator(thrust::make_discard_iterator(), result_labels->begin()),
handle.get_stream());
}
}

std::tie(result_srcs, result_dsts, result_properties) =
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/sampling/detail/gather_one_hop_mg_v32_e32.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ temporal_gather_one_hop_edgelist(
raft::device_span<int32_t const> active_major_times,
std::optional<raft::device_span<int32_t const>> active_major_labels,
std::optional<raft::device_span<uint8_t const>> gather_flags,
temporal_sampling_comparison_t temporal_sampling_comparison,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<vertex_t>,
Expand All @@ -67,6 +68,7 @@ temporal_gather_one_hop_edgelist(
raft::device_span<int64_t const> active_major_times,
std::optional<raft::device_span<int32_t const>> active_major_labels,
std::optional<raft::device_span<uint8_t const>> gather_flags,
temporal_sampling_comparison_t temporal_sampling_comparison,
bool do_expensive_check);

} // namespace detail
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/sampling/detail/gather_one_hop_mg_v64_e64.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ temporal_gather_one_hop_edgelist(
raft::device_span<int32_t const> active_major_times,
std::optional<raft::device_span<int32_t const>> active_major_labels,
std::optional<raft::device_span<uint8_t const>> gather_flags,
temporal_sampling_comparison_t temporal_sampling_comparison,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<vertex_t>,
Expand All @@ -67,6 +68,7 @@ temporal_gather_one_hop_edgelist(
raft::device_span<int64_t const> active_major_times,
std::optional<raft::device_span<int32_t const>> active_major_labels,
std::optional<raft::device_span<uint8_t const>> gather_flags,
temporal_sampling_comparison_t temporal_sampling_comparison,
bool do_expensive_check);

} // namespace detail
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/sampling/detail/gather_one_hop_sg_v32_e32.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

#include "cugraph/sampling_functions.hpp"
#include "gather_one_hop_impl.cuh"

namespace cugraph {
Expand Down Expand Up @@ -51,6 +52,7 @@ temporal_gather_one_hop_edgelist(
raft::device_span<int32_t const> active_major_times,
std::optional<raft::device_span<int32_t const>> active_major_labels,
std::optional<raft::device_span<uint8_t const>> gather_flags,
temporal_sampling_comparison_t temporal_sampling_comparison,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<vertex_t>,
Expand All @@ -67,6 +69,7 @@ temporal_gather_one_hop_edgelist(
raft::device_span<int64_t const> active_major_times,
std::optional<raft::device_span<int32_t const>> active_major_labels,
std::optional<raft::device_span<uint8_t const>> gather_flags,
temporal_sampling_comparison_t temporal_sampling_comparison,
bool do_expensive_check);

} // namespace detail
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/sampling/detail/gather_one_hop_sg_v64_e64.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ temporal_gather_one_hop_edgelist(
raft::device_span<int32_t const> active_major_times,
std::optional<raft::device_span<int32_t const>> active_major_labels,
std::optional<raft::device_span<uint8_t const>> gather_flags,
temporal_sampling_comparison_t temporal_sampling_comparison,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<vertex_t>,
Expand All @@ -67,6 +68,7 @@ temporal_gather_one_hop_edgelist(
raft::device_span<int64_t const> active_major_times,
std::optional<raft::device_span<int32_t const>> active_major_labels,
std::optional<raft::device_span<uint8_t const>> gather_flags,
temporal_sampling_comparison_t temporal_sampling_comparison,
bool do_expensive_check);

} // namespace detail
Expand Down
Loading