Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Temporal] Pick function optimization #7417

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 26 additions & 44 deletions graphbolt/src/fused_csc_sampling_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1327,6 +1327,7 @@
torch::Tensor probs, int64_t fanout, bool replace) {
auto positive_probs_indices = probs.nonzero().squeeze(1);
auto num_positive_probs = positive_probs_indices.size(0);
int64_t num_neighbors = num_positive_probs;
if (num_positive_probs == 0) return torch::empty({0}, torch::kLong);
if ((fanout == -1) || (num_positive_probs <= fanout && !replace)) {
return positive_probs_indices;
Expand All @@ -1342,56 +1343,37 @@
positive_probs_indices.data_ptr<int64_t>();

if (!replace) {
// The algorithm is from gumbel softmax.
// s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1).
// Here we can apply exp to the formula which will not affect result
// of argmax or topk. Then we have
// s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1).
// We can also simplify the formula above by
// s = argmax( p / q ) where q ~ Exp(1).
if (fanout == 1) {
// Return argmax(p / q).
scalar_t max_prob = 0;
int64_t max_prob_index = -1;
// We only care about the neighbors with non-zero probability.
for (auto i = 0; i < num_positive_probs; ++i) {
// Calculate (p / q) for the current neighbor.
scalar_t current_prob =
probs_data_ptr[positive_probs_indices_ptr[i]] /
RandomEngine::ThreadLocal()->Exponential(1.);
if (current_prob > max_prob) {
max_prob = current_prob;
max_prob_index = positive_probs_indices_ptr[i];
}
if (fanout >= num_neighbors / 10) {
for (int64_t i = 0; i < fanout; ++i) {
auto j = RandomEngine::ThreadLocal()->RandInt(i, num_neighbors);

Check warning on line 1348 in graphbolt/src/fused_csc_sampling_graph.cc

View workflow job for this annotation

GitHub Actions / lintrunner

CLANGFORMAT format

See https://clang.llvm.org/docs/ClangFormat.html. Run `lintrunner -a` to apply this patch.
std::swap(positive_probs_indices_ptr[i], positive_probs_indices_ptr[j]);
}
ret_ptr[0] = max_prob_index;
} else {
// Return topk(p / q).
std::vector<std::pair<scalar_t, int64_t>> q(num_positive_probs);
for (auto i = 0; i < num_positive_probs; ++i) {
q[i].first = probs_data_ptr[positive_probs_indices_ptr[i]] /
RandomEngine::ThreadLocal()->Exponential(1.);
q[i].second = positive_probs_indices_ptr[i];
std::memcpy(ret_ptr, positive_probs_indices_ptr, fanout * sizeof(int64_t));
} else if (fanout < 64) {
auto begin = ret_ptr;
auto end = ret_ptr + fanout;

while (begin != end) {
// Put the new random number in the last position.
int64_t tmp = RandomEngine::ThreadLocal()->RandInt( static_cast<int64_t>(0), num_neighbors);
*begin = positive_probs_indices_ptr[tmp];
// Check if a new value doesn't exist in current
// range(picked_data_ptr, begin). Otherwise get a new
// value until we haven't unique range of elements.
auto it = std::find(ret_ptr, begin, *begin);
if (it == begin) ++begin;
}
if (fanout < num_positive_probs / 64) {
// Use partial_sort.
std::partial_sort(
q.begin(), q.begin() + fanout, q.end(), std::greater{});
for (auto i = 0; i < fanout; ++i) {
ret_ptr[i] = q[i].second;
}
} else {
// Use nth_element.
std::nth_element(
q.begin(), q.begin() + fanout - 1, q.end(), std::greater{});
for (auto i = 0; i < fanout; ++i) {
ret_ptr[i] = q[i].second;
}
} else {
std::unordered_set<int64_t> picked_set;
while (static_cast<int64_t>(picked_set.size()) < fanout) {
int64_t tmp = RandomEngine::ThreadLocal()->RandInt( static_cast<int64_t>(0), num_neighbors);
picked_set.insert(positive_probs_indices_ptr[tmp]);
}
std::copy(picked_set.begin(), picked_set.end(), ret_ptr);
}
} else {
// Calculate cumulative sum of probabilities.
std::vector<scalar_t> prefix_sum_probs(num_positive_probs);
std::vector<scalar_t> prefix_sum_probs(num_neighbors);
scalar_t sum_probs = 0;
for (auto i = 0; i < num_positive_probs; ++i) {
sum_probs += probs_data_ptr[positive_probs_indices_ptr[i]];
Expand Down
Loading