Skip to content

Commit

Permalink
Add static_multiset and its insert APIs (#437)
Browse files Browse the repository at this point in the history
This PR adds `static_multiset` and its insert-related APIs. This is a
proof of concept of sharing implementation details between set and
multiset containers.
  • Loading branch information
PointKernel authored Feb 15, 2024
1 parent 4fdc73b commit f54ff98
Show file tree
Hide file tree
Showing 9 changed files with 1,293 additions and 14 deletions.
44 changes: 35 additions & 9 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,14 @@ struct window_probing_results {
* @tparam KeyEqual Binary callable type used to compare two keys for equality
* @tparam ProbingScheme Probing scheme (see `include/cuco/probing_scheme.cuh` for options)
* @tparam StorageRef Storage ref type
* @tparam AllowsDuplicates Flag indicating whether duplicate keys are allowed or not
*/
template <typename Key,
cuda::thread_scope Scope,
typename KeyEqual,
typename ProbingScheme,
typename StorageRef>
typename StorageRef,
bool AllowsDuplicates>
class open_addressing_ref_impl {
static_assert(sizeof(Key) <= 8, "Container does not support key types larger than 8 bytes.");

Expand All @@ -94,6 +96,9 @@ class open_addressing_ref_impl {
/// Determines if the container is a key/value or key-only store
static constexpr auto has_payload = not std::is_same_v<Key, typename StorageRef::value_type>;

/// Flag indicating whether duplicate keys are allowed or not
static constexpr auto allows_duplicates = AllowsDuplicates;

// TODO: how to re-enable this check?
// static_assert(is_window_extent_v<typename StorageRef::extent_type>,
// "Extent is not a valid cuco::window_extent");
Expand Down Expand Up @@ -360,18 +365,26 @@ class open_addressing_ref_impl {
for (auto& slot_content : window_slots) {
auto const eq_res = this->predicate_(this->extract_key(slot_content), key);

// If the key is already in the container, return false
if (eq_res == detail::equal_result::EQUAL) { return false; }
if constexpr (not allows_duplicates) {
// If the key is already in the container, return false
if (eq_res == detail::equal_result::EQUAL) { return false; }
}
if (eq_res == detail::equal_result::EMPTY or
cuco::detail::bitwise_compare(this->extract_key(slot_content),
this->erased_key_sentinel())) {
auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content);
switch (attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index,
slot_content,
val)) {
case insert_result::DUPLICATE: {
if constexpr (allows_duplicates) {
[[fallthrough]];
} else {
return false;
}
}
case insert_result::CONTINUE: continue;
case insert_result::SUCCESS: return true;
case insert_result::DUPLICATE: return false;
}
}
}
Expand Down Expand Up @@ -405,8 +418,13 @@ class open_addressing_ref_impl {
switch (this->predicate_(this->extract_key(window_slots[i]), key)) {
case detail::equal_result::EMPTY:
return window_probing_results{detail::equal_result::EMPTY, i};
case detail::equal_result::EQUAL:
return window_probing_results{detail::equal_result::EQUAL, i};
case detail::equal_result::EQUAL: {
if constexpr (allows_duplicates) {
continue;
} else {
return window_probing_results{detail::equal_result::EQUAL, i};
}
}
default: {
if (cuco::detail::bitwise_compare(this->extract_key(window_slots[i]),
this->erased_key_sentinel())) {
Expand All @@ -421,8 +439,10 @@ class open_addressing_ref_impl {
return window_probing_results{detail::equal_result::UNEQUAL, -1};
}();

// If the key is already in the container, return false
if (group.any(state == detail::equal_result::EQUAL)) { return false; }
if constexpr (not allows_duplicates) {
// If the key is already in the container, return false
if (group.any(state == detail::equal_result::EQUAL)) { return false; }
}

auto const group_contains_available =
group.ballot(state == detail::equal_result::EMPTY or state == detail::equal_result::ERASED);
Expand All @@ -437,7 +457,13 @@ class open_addressing_ref_impl {

switch (group.shfl(status, src_lane)) {
case insert_result::SUCCESS: return true;
case insert_result::DUPLICATE: return false;
case insert_result::DUPLICATE: {
if constexpr (allows_duplicates) {
[[fallthrough]];
} else {
return false;
}
}
default: continue;
}
} else {
Expand Down
258 changes: 258 additions & 0 deletions include/cuco/detail/static_multiset/static_multiset.inl
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cuco/detail/bitwise_compare.cuh>
#include <cuco/detail/utility/cuda.hpp>
#include <cuco/detail/utils.hpp>
#include <cuco/operator.hpp>
#include <cuco/static_multiset_ref.cuh>

#include <cstddef>

namespace cuco {

// Constructs a multiset from a capacity and an empty-key sentinel.
//
// The container state is held by an `impl_type` object created with `std::make_unique`; the
// capacity, sentinel, key-equality predicate, probing scheme, allocator and stream are handed to
// its constructor as-is. The `cuda_thread_scope<Scope>` and `Storage` parameters are unnamed and
// not used in the body (presumably tag arguments for template deduction -- confirm against the
// class declaration).
template <class Key,
          class Extent,
          cuda::thread_scope Scope,
          class KeyEqual,
          class ProbingScheme,
          class Allocator,
          class Storage>
constexpr static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
  static_multiset(Extent capacity,
                  empty_key<Key> empty_key_sentinel,
                  KeyEqual const& pred,
                  ProbingScheme const& probing_scheme,
                  cuda_thread_scope<Scope>,
                  Storage,
                  Allocator const& alloc,
                  cuda_stream_ref stream)
  : impl_{std::make_unique<impl_type>(
      capacity, empty_key_sentinel, pred, probing_scheme, alloc, stream)}
{
}

// Constructs a multiset from an element count `n` and a desired load factor.
//
// Both values are forwarded untouched to the `impl_type` constructor together with the empty-key
// sentinel, key-equality predicate, probing scheme, allocator and stream; how the final capacity
// is derived from `n` and `desired_load_factor` is decided by `impl_type` and is not visible
// here. The `cuda_thread_scope<Scope>` and `Storage` parameters are unnamed and unused in the
// body.
template <class Key,
          class Extent,
          cuda::thread_scope Scope,
          class KeyEqual,
          class ProbingScheme,
          class Allocator,
          class Storage>
constexpr static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
  static_multiset(Extent n,
                  double desired_load_factor,
                  empty_key<Key> empty_key_sentinel,
                  KeyEqual const& pred,
                  ProbingScheme const& probing_scheme,
                  cuda_thread_scope<Scope>,
                  Storage,
                  Allocator const& alloc,
                  cuda_stream_ref stream)
  : impl_{std::make_unique<impl_type>(
      n, desired_load_factor, empty_key_sentinel, pred, probing_scheme, alloc, stream)}
{
}

// Constructs a multiset from a capacity plus distinct empty-key and erased-key sentinels.
//
// Identical to the capacity-based constructor except that `erased_key_sentinel` is forwarded to
// the `impl_type` constructor as well. The `cuda_thread_scope<Scope>` and `Storage` parameters
// are unnamed and unused in the body.
template <class Key,
          class Extent,
          cuda::thread_scope Scope,
          class KeyEqual,
          class ProbingScheme,
          class Allocator,
          class Storage>
constexpr static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
  static_multiset(Extent capacity,
                  empty_key<Key> empty_key_sentinel,
                  erased_key<Key> erased_key_sentinel,
                  KeyEqual const& pred,
                  ProbingScheme const& probing_scheme,
                  cuda_thread_scope<Scope>,
                  Storage,
                  Allocator const& alloc,
                  cuda_stream_ref stream)
  : impl_{std::make_unique<impl_type>(
      capacity, empty_key_sentinel, erased_key_sentinel, pred, probing_scheme, alloc, stream)}
{
}

// Clears the container by delegating to `impl_type::clear` with the given stream.
template <class Key,
          class Extent,
          cuda::thread_scope Scope,
          class KeyEqual,
          class ProbingScheme,
          class Allocator,
          class Storage>
void static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::clear(
  cuda_stream_ref stream) noexcept
{
  impl_->clear(stream);
}

// Asynchronous counterpart of `clear`: delegates to `impl_type::clear_async` with the given
// stream and performs no synchronization of its own.
template <class Key,
          class Extent,
          cuda::thread_scope Scope,
          class KeyEqual,
          class ProbingScheme,
          class Allocator,
          class Storage>
void static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::clear_async(
  cuda_stream_ref stream) noexcept
{
  impl_->clear_async(stream);
}

// Inserts the elements of [first, last).
//
// Implemented on top of `insert_async`: the asynchronous insert is issued on `stream`, then
// `stream.synchronize()` is called before returning.
template <class Key,
          class Extent,
          cuda::thread_scope Scope,
          class KeyEqual,
          class ProbingScheme,
          class Allocator,
          class Storage>
template <typename InputIt>
void static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::insert(
  InputIt first, InputIt last, cuda_stream_ref stream)
{
  insert_async(first, last, stream);
  stream.synchronize();
}

// Issues an insert of [first, last) on `stream` without synchronizing.
//
// The work is delegated to `impl_type::insert_async`, which receives a container ref built with
// the `op::insert` operator (`ref(op::insert)`).
template <class Key,
          class Extent,
          cuda::thread_scope Scope,
          class KeyEqual,
          class ProbingScheme,
          class Allocator,
          class Storage>
template <typename InputIt>
void static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::insert_async(
  InputIt first, InputIt last, cuda_stream_ref stream) noexcept
{
  impl_->insert_async(first, last, ref(op::insert), stream);
}

// Conditional insert of [first, last), driven by `stencil` and `pred`.
//
// Implemented on top of `insert_if_async`: the asynchronous conditional insert is issued on
// `stream`, then `stream.synchronize()` is called before returning. How `stencil` and `pred`
// select elements is determined by `impl_type::insert_if_async` and is not visible here.
template <class Key,
          class Extent,
          cuda::thread_scope Scope,
          class KeyEqual,
          class ProbingScheme,
          class Allocator,
          class Storage>
template <typename InputIt, typename StencilIt, typename Predicate>
void static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::insert_if(
  InputIt first, InputIt last, StencilIt stencil, Predicate pred, cuda_stream_ref stream)
{
  insert_if_async(first, last, stencil, pred, stream);
  stream.synchronize();
}

// Issues a conditional insert of [first, last) on `stream` without synchronizing.
//
// All arguments are passed through to `impl_type::insert_if_async`, along with a container ref
// built with the `op::insert` operator (`ref(op::insert)`).
template <class Key,
          class Extent,
          cuda::thread_scope Scope,
          class KeyEqual,
          class ProbingScheme,
          class Allocator,
          class Storage>
template <typename InputIt, typename StencilIt, typename Predicate>
void static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
  insert_if_async(
    InputIt first, InputIt last, StencilIt stencil, Predicate pred, cuda_stream_ref stream) noexcept
{
  impl_->insert_if_async(first, last, stencil, pred, ref(op::insert), stream);
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::size_type
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::size(
cuda_stream_ref stream) const noexcept
{
return impl_->size(stream);
}

// Returns the capacity reported by `impl_type::capacity`; the return type is deduced from it.
template <class Key,
          class Extent,
          cuda::thread_scope Scope,
          class KeyEqual,
          class ProbingScheme,
          class Allocator,
          class Storage>
constexpr auto
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::capacity()
  const noexcept
{
  return impl_->capacity();
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
constexpr static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::key_type
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
empty_key_sentinel() const noexcept
{
return impl_->empty_key_sentinel();
}

template <class Key,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
constexpr static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::key_type
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
erased_key_sentinel() const noexcept
{
return impl_->erased_key_sentinel();
}

// Builds a `ref_type<Operators...>` object for this container.
//
// The operator arguments are unnamed: only their types are used, to select the
// `ref_type<Operators...>` specialization. At least one operator must be given; this is enforced
// at compile time.
//
// The ref is assembled from the key-equality predicate, probing scheme and storage ref reported
// by the implementation object. If the empty-key and erased-key sentinels compare bitwise equal,
// the ref is constructed from the empty-key sentinel only; otherwise both sentinels are passed.
template <class Key,
          class Extent,
          cuda::thread_scope Scope,
          class KeyEqual,
          class ProbingScheme,
          class Allocator,
          class Storage>
template <typename... Operators>
auto static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::ref(
  Operators...) const noexcept
{
  // Spell the condition as a real boolean expression: relying on an implicit size_t -> bool
  // conversion inside `static_assert` is not valid under the strict (pre-P1401) C++17 wording.
  static_assert(sizeof...(Operators) > 0, "No operators specified");
  return cuco::detail::bitwise_compare(this->empty_key_sentinel(), this->erased_key_sentinel())
           ? ref_type<Operators...>{cuco::empty_key<key_type>(this->empty_key_sentinel()),
                                    impl_->key_eq(),
                                    impl_->probing_scheme(),
                                    cuda_thread_scope<Scope>{},
                                    impl_->storage_ref()}
           : ref_type<Operators...>{cuco::empty_key<key_type>(this->empty_key_sentinel()),
                                    cuco::erased_key<key_type>(this->erased_key_sentinel()),
                                    impl_->key_eq(),
                                    impl_->probing_scheme(),
                                    cuda_thread_scope<Scope>{},
                                    impl_->storage_ref()};
}
} // namespace cuco
Loading

0 comments on commit f54ff98

Please sign in to comment.