-
Notifications
You must be signed in to change notification settings - Fork 335
Add support for other temporal comparisons #5283
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b28199b
1ff24ae
c5fd389
5c126ef
c145629
ddcfe18
1677266
72ffd62
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -279,6 +279,27 @@ struct temporal_neighbor_sampling_functor : public cugraph::c_api::abstract_func | |
std::optional<rmm::device_uvector<label_t>> edge_label{std::nullopt}; | ||
std::optional<rmm::device_uvector<size_t>> offsets{std::nullopt}; | ||
|
||
cugraph::temporal_sampling_comparison_t temporal_sampling_comparison{}; | ||
switch (options_.temporal_sampling_comparison_) { | ||
case cugraph_temporal_sampling_comparison_t::STRICTLY_INCREASING: | ||
temporal_sampling_comparison = | ||
cugraph::temporal_sampling_comparison_t::STRICTLY_INCREASING; | ||
break; | ||
case cugraph_temporal_sampling_comparison_t::MONOTONICALLY_INCREASING: | ||
temporal_sampling_comparison = | ||
cugraph::temporal_sampling_comparison_t::MONOTONICALLY_INCREASING; | ||
break; | ||
case cugraph_temporal_sampling_comparison_t::STRICTLY_DECREASING: | ||
temporal_sampling_comparison = | ||
cugraph::temporal_sampling_comparison_t::STRICTLY_DECREASING; | ||
break; | ||
case cugraph_temporal_sampling_comparison_t::MONOTONICALLY_DECREASING: | ||
temporal_sampling_comparison = | ||
cugraph::temporal_sampling_comparison_t::MONOTONICALLY_DECREASING; | ||
break; | ||
default: CUGRAPH_FAIL("Invalid temporal sampling comparison type"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Won't this work? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We're converting between types here. I'm hesitant to assume that translating from the C enum definition in the |
||
}; | ||
|
||
// FIXME: For biased sampling, the user should pass either biases or edge weights, | ||
// otherwised throw an error and suggest the user to call uniform neighbor sample instead | ||
|
||
|
@@ -321,7 +342,8 @@ struct temporal_neighbor_sampling_functor : public cugraph::c_api::abstract_func | |
cugraph::sampling_flags_t{options_.prior_sources_behavior_, | ||
options_.return_hops_, | ||
options_.dedupe_sources_, | ||
options_.with_replacement_}, | ||
options_.with_replacement_, | ||
temporal_sampling_comparison}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or can we just call this with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Type conversion as mentioned above |
||
do_expensive_check_); | ||
} else { | ||
std::tie(sampled_edge_srcs, | ||
|
@@ -356,7 +378,8 @@ struct temporal_neighbor_sampling_functor : public cugraph::c_api::abstract_func | |
cugraph::sampling_flags_t{options_.prior_sources_behavior_, | ||
options_.return_hops_, | ||
options_.dedupe_sources_, | ||
options_.with_replacement_}, | ||
options_.with_replacement_, | ||
temporal_sampling_comparison}, | ||
do_expensive_check_); | ||
} | ||
} else { | ||
|
@@ -394,7 +417,8 @@ struct temporal_neighbor_sampling_functor : public cugraph::c_api::abstract_func | |
cugraph::sampling_flags_t{options_.prior_sources_behavior_, | ||
options_.return_hops_, | ||
options_.dedupe_sources_, | ||
options_.with_replacement_}, | ||
options_.with_replacement_, | ||
temporal_sampling_comparison}, | ||
do_expensive_check_); | ||
} else { | ||
std::tie(sampled_edge_srcs, | ||
|
@@ -428,7 +452,8 @@ struct temporal_neighbor_sampling_functor : public cugraph::c_api::abstract_func | |
cugraph::sampling_flags_t{options_.prior_sources_behavior_, | ||
options_.return_hops_, | ||
options_.dedupe_sources_, | ||
options_.with_replacement_}, | ||
options_.with_replacement_, | ||
temporal_sampling_comparison}, | ||
do_expensive_check_); | ||
} | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ | |
|
||
#include <cugraph/arithmetic_variant_types.hpp> | ||
#include <cugraph/edge_property.hpp> | ||
#include <cugraph/sampling_functions.hpp> | ||
#include <cugraph/utilities/mask_utils.cuh> | ||
|
||
#include <raft/util/cudart_utils.hpp> | ||
|
@@ -355,6 +356,7 @@ temporal_gather_one_hop_edgelist( | |
raft::device_span<edge_time_t const> active_major_times, | ||
std::optional<raft::device_span<int32_t const>> active_major_labels, | ||
std::optional<raft::device_span<uint8_t const>> gather_flags, | ||
temporal_sampling_comparison_t temporal_sampling_comparison, | ||
bool do_expensive_check) | ||
{ | ||
constexpr bool store_transposed = false; | ||
|
@@ -532,24 +534,48 @@ temporal_gather_one_hop_edgelist( | |
tmp_positions | ||
? detail::mark_entries(handle, | ||
edge_times.size(), | ||
[d_tmp = edge_times.data(), | ||
[temporal_sampling_comparison, | ||
d_tmp = edge_times.data(), | ||
d_tmp_positions = tmp_positions->data(), | ||
kv_store_view = | ||
kv_binary_search_store_device_view_t<decltype(kv_store.view())>{ | ||
kv_store.view()}] __device__(auto index) { | ||
auto edge_time = d_tmp[index]; | ||
auto key_time = | ||
cuda::std::get<0>(kv_store_view.find(d_tmp_positions[index])); | ||
return (edge_time > key_time); | ||
|
||
switch (temporal_sampling_comparison) { | ||
case temporal_sampling_comparison_t::STRICTLY_INCREASING: | ||
return (edge_time > key_time); | ||
case temporal_sampling_comparison_t::MONOTONICALLY_INCREASING: | ||
return (edge_time >= key_time); | ||
case temporal_sampling_comparison_t::STRICTLY_DECREASING: | ||
return (edge_time < key_time); | ||
case temporal_sampling_comparison_t::MONOTONICALLY_DECREASING: | ||
return (edge_time <= key_time); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll make that change. |
||
assert(false); | ||
}) | ||
: detail::mark_entries( | ||
handle, | ||
edge_times.size(), | ||
[d_tmp = edge_times.data(), d_tmp_time = tmp_times->data()] __device__(auto index) { | ||
auto edge_time = d_tmp[index]; | ||
auto key_time = d_tmp_time[index]; | ||
return (edge_time > key_time); | ||
}); | ||
: detail::mark_entries(handle, | ||
edge_times.size(), | ||
[temporal_sampling_comparison, | ||
d_tmp = edge_times.data(), | ||
d_tmp_time = tmp_times->data()] __device__(auto index) { | ||
auto edge_time = d_tmp[index]; | ||
auto key_time = d_tmp_time[index]; | ||
|
||
switch (temporal_sampling_comparison) { | ||
case temporal_sampling_comparison_t::STRICTLY_INCREASING: | ||
return (edge_time > key_time); | ||
case temporal_sampling_comparison_t::MONOTONICALLY_INCREASING: | ||
return (edge_time >= key_time); | ||
case temporal_sampling_comparison_t::STRICTLY_DECREASING: | ||
return (edge_time < key_time); | ||
case temporal_sampling_comparison_t::MONOTONICALLY_DECREASING: | ||
return (edge_time <= key_time); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll make that change |
||
assert(false); | ||
}); | ||
|
||
raft::device_span<uint32_t const> marked_entry_span{marked_entries.data(), | ||
marked_entries.size()}; | ||
|
@@ -574,12 +600,14 @@ temporal_gather_one_hop_edgelist( | |
handle, std::move(*tmp_positions), marked_entry_span, keep_count); | ||
} | ||
|
||
result_labels = rmm::device_uvector<label_t>(keep_count, handle.get_stream()); | ||
kv_store.view().find( | ||
tmp_positions->begin(), | ||
tmp_positions->end(), | ||
thrust::make_zip_iterator(thrust::make_discard_iterator(), result_labels->begin()), | ||
handle.get_stream()); | ||
if (active_major_labels) { | ||
result_labels = rmm::device_uvector<label_t>(keep_count, handle.get_stream()); | ||
kv_store.view().find( | ||
tmp_positions->begin(), | ||
tmp_positions->end(), | ||
thrust::make_zip_iterator(thrust::make_discard_iterator(), result_labels->begin()), | ||
handle.get_stream()); | ||
} | ||
} | ||
|
||
std::tie(result_srcs, result_dsts, result_properties) = | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LAST /** Support last n behavior */
This sounds vague to me.
Isn't this something like
NUM_COMPARISON_TYPES
? (basically the number of supported temporal sampling comparison methods/modes/types... and so o n)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not the end of the enum. It's actually to support a feature that will be implemented later called "last n".
If
LAST
is specified, the software will (eventually, not part of this PR) use the fanout value (call itn
) and instead of randomly selectingn
edges that leave this vertex it will select then
edges with the largest time stamps. You can think of this as sorting the outgoing edges by time and selecting the "last n" of them.This is a feature that we'll add later (perhaps later in 25.12, perhaps not until 26.01, not sure of the priorities and how far we'll progress).