Skip to content

Commit

Permalink
Cleanup CUB temporary storage layout test (#1848)
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber authored Jun 12, 2024
1 parent b543fe7 commit fd001a4
Showing 1 changed file with 27 additions and 27 deletions.
54 changes: 27 additions & 27 deletions cub/test/catch2_test_temporary_storage_layout.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*******************************************************************************
* Copyright (c) 2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2011-2024, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
Expand Down Expand Up @@ -34,45 +34,45 @@
#include "catch2_test_helper.h"
#include "cub/detail/temporary_storage.cuh"

using values = c2h::enum_type_list<int, 1, 4, 42>;
using num_storage_slots = c2h::enum_type_list<int, 1, 4, 42>;

template <int Items>
std::size_t GetTemporaryStorageSize(std::size_t (&sizes)[Items])
std::size_t get_temporary_storage_size(std::size_t (&sizes)[Items])
{
void* pointers[Items]{};
std::size_t temp_storage_bytes{};
CubDebugExit(cub::AliasTemporaries(nullptr, temp_storage_bytes, pointers, sizes));
return temp_storage_bytes;
}

std::size_t GetActualZero()
std::size_t get_actual_zero()
{
std::size_t sizes[1]{};

return GetTemporaryStorageSize(sizes);
return get_temporary_storage_size(sizes);
}

CUB_TEST("Test empty storage", "[temporary_storage_layout]", values)
CUB_TEST("Test empty storage", "[temporary_storage_layout]", num_storage_slots)
{
constexpr auto StorageSlots = c2h::get<0, TestType>::value;
cub::detail::temporary_storage::layout<StorageSlots> temporary_storage;
CHECK(temporary_storage.get_size() == GetActualZero());
constexpr auto storage_slots = c2h::get<0, TestType>::value;
cub::detail::temporary_storage::layout<storage_slots> temporary_storage;
CHECK(temporary_storage.get_size() == get_actual_zero());
}

CUB_TEST("Test partially filled storage", "[temporary_storage_layout]", values)
CUB_TEST("Test partially filled storage", "[temporary_storage_layout]", num_storage_slots)
{
constexpr auto StorageSlots = c2h::get<0, TestType>::value;
constexpr auto storage_slots = c2h::get<0, TestType>::value;
using target_type = std::uint64_t;
constexpr std::size_t target_elements = 42;
constexpr std::size_t full_slot_elements = target_elements * sizeof(target_type);
constexpr std::size_t empty_slot_elements{};

cub::detail::temporary_storage::layout<StorageSlots> temporary_storage;
cub::detail::temporary_storage::layout<storage_slots> temporary_storage;

std::unique_ptr<cub::detail::temporary_storage::alias<target_type>> arrays[StorageSlots];
std::size_t sizes[StorageSlots]{};
std::unique_ptr<cub::detail::temporary_storage::alias<target_type>> arrays[storage_slots];
std::size_t sizes[storage_slots]{};

for (int slot_id = 0; slot_id < StorageSlots; slot_id++)
for (int slot_id = 0; slot_id < storage_slots; slot_id++)
{
auto slot = temporary_storage.get_slot(slot_id);

Expand All @@ -89,9 +89,9 @@ CUB_TEST("Test partially filled storage", "[temporary_storage_layout]", values)

temporary_storage.map_to_buffer(temp_storage.get(), temp_storage_bytes);

CHECK(temp_storage_bytes == GetTemporaryStorageSize(sizes));
CHECK(temp_storage_bytes == get_temporary_storage_size(sizes));

for (int slot_id = 0; slot_id < StorageSlots; slot_id++)
for (int slot_id = 0; slot_id < storage_slots; slot_id++)
{
if (slot_id % 2 == 0)
{
Expand All @@ -104,7 +104,7 @@ CUB_TEST("Test partially filled storage", "[temporary_storage_layout]", values)
}
}

CUB_TEST("Test grow", "[temporary_storage_layout]", values)
CUB_TEST("Test grow", "[temporary_storage_layout]", num_storage_slots)
{
constexpr auto StorageSlots = c2h::get<0, TestType>::value;
using target_type = std::uint64_t;
Expand Down Expand Up @@ -143,25 +143,25 @@ CUB_TEST("Test grow", "[temporary_storage_layout]", values)
}
}

CUB_TEST("Test double grow", "[temporary_storage_layout]", values)
CUB_TEST("Test double grow", "[temporary_storage_layout]", num_storage_slots)
{
constexpr auto StorageSlots = c2h::get<0, TestType>::value;
constexpr auto storage_slots = c2h::get<0, TestType>::value;
using target_type = std::uint64_t;
constexpr std::size_t target_elements_number = 42;

cub::detail::temporary_storage::layout<StorageSlots> preset_layout;
std::unique_ptr<cub::detail::temporary_storage::alias<target_type>> preset_arrays[StorageSlots];
cub::detail::temporary_storage::layout<storage_slots> preset_layout;
std::unique_ptr<cub::detail::temporary_storage::alias<target_type>> preset_arrays[storage_slots];

for (int slot_id = 0; slot_id < StorageSlots; slot_id++)
for (int slot_id = 0; slot_id < storage_slots; slot_id++)
{
preset_arrays[slot_id].reset(new cub::detail::temporary_storage::alias<target_type>(
preset_layout.get_slot(slot_id)->template create_alias<target_type>(2 * target_elements_number)));
}

cub::detail::temporary_storage::layout<StorageSlots> postset_layout;
std::unique_ptr<cub::detail::temporary_storage::alias<target_type>> postset_arrays[StorageSlots];
cub::detail::temporary_storage::layout<storage_slots> postset_layout;
std::unique_ptr<cub::detail::temporary_storage::alias<target_type>> postset_arrays[storage_slots];

for (int slot_id = 0; slot_id < StorageSlots; slot_id++)
for (int slot_id = 0; slot_id < storage_slots; slot_id++)
{
postset_arrays[slot_id].reset(new cub::detail::temporary_storage::alias<target_type>(
postset_layout.get_slot(slot_id)->template create_alias<target_type>(target_elements_number)));
Expand All @@ -176,7 +176,7 @@ CUB_TEST("Test double grow", "[temporary_storage_layout]", values)
preset_layout.map_to_buffer(temp_storage.get(), tmp_storage_bytes);
postset_layout.map_to_buffer(temp_storage.get(), tmp_storage_bytes);

for (int slot_id = 0; slot_id < StorageSlots; slot_id++)
for (int slot_id = 0; slot_id < storage_slots; slot_id++)
{
CHECK(postset_arrays[slot_id]->get() == preset_arrays[slot_id]->get());
}
Expand Down

0 comments on commit fd001a4

Please sign in to comment.