From 0054957397c3f9faf5dc54ac68350615adf4eab1 Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Mon, 4 Dec 2023 13:39:20 -0600 Subject: [PATCH] Fix Arena MR to support simultaneous access by PTDS and other streams (#1395) (#1396) This PR backports #1395 from 24.02 to 23.12. It contains an arena MR fix for simultaneous access by PTDS and other streams. Backport requested by @sameerz @GregoryKimball. Authors: - Thomas Graves (https://github.com/tgravescs) Approvers: - Lawrence Mitchell (https://github.com/wence-) - Mark Harris (https://github.com/harrism) --- .../rmm/mr/device/arena_memory_resource.hpp | 21 ++++++++++++- tests/mr/device/arena_mr_tests.cpp | 31 +++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/include/rmm/mr/device/arena_memory_resource.hpp b/include/rmm/mr/device/arena_memory_resource.hpp index 0dbd9c90e..929b8454f 100644 --- a/include/rmm/mr/device/arena_memory_resource.hpp +++ b/include/rmm/mr/device/arena_memory_resource.hpp @@ -235,7 +235,26 @@ class arena_memory_resource final : public device_memory_resource { } } - if (!global_arena_.deallocate(ptr, bytes)) { RMM_FAIL("allocation not found"); } + if (!global_arena_.deallocate(ptr, bytes)) { + // It's possible to use per thread default streams along with another pool of streams. + // This means that it's possible for an allocation to move from a thread or stream arena + // back into the global arena during a defragmentation and then move down into another arena + // type. For instance, thread arena -> global arena -> stream arena. If this happens and + // there was an allocation from it while it was a thread arena, we now have to check to + // see if the allocation is part of a stream arena, and vice versa. + // Only do this in exceptional cases to not affect performance and have to check all + // arenas all the time. + if (use_per_thread_arena(stream)) { + for (auto& stream_arena : stream_arenas_) { + if (stream_arena.second.deallocate(ptr, bytes)) { return; } + } + } else { + for (auto const& thread_arena : thread_arenas_) { + if (thread_arena.second->deallocate(ptr, bytes)) { return; } + } + } + RMM_FAIL("allocation not found"); + } } /** diff --git a/tests/mr/device/arena_mr_tests.cpp b/tests/mr/device/arena_mr_tests.cpp index 48967d06a..7525cac9f 100644 --- a/tests/mr/device/arena_mr_tests.cpp +++ b/tests/mr/device/arena_mr_tests.cpp @@ -533,6 +533,37 @@ TEST_F(ArenaTest, Defragment) // NOLINT }()); } +TEST_F(ArenaTest, PerThreadToStreamDealloc) // NOLINT +{ + // This is testing that deallocation of a ptr still works when + // it was originally allocated in a superblock that was in a thread + // arena that then moved to global arena during a defragmentation + // and then moved to a stream arena. + auto const arena_size = superblock::minimum_size * 2; + arena_mr mr(rmm::mr::get_current_device_resource(), arena_size); + // Create an allocation from a per thread arena + void* thread_ptr = mr.allocate(256, rmm::cuda_stream_per_thread); + // Create an allocation in a stream arena to force global arena + // to be empty + cuda_stream stream{}; + void* ptr = mr.allocate(32_KiB, stream); + mr.deallocate(ptr, 32_KiB, stream); + // at this point the global arena doesn't have any superblocks so + // the next allocation causes defrag. Defrag causes all superblocks + // from the thread and stream arena allocated above to go back to + // global arena and it allocates one superblock to the stream arena. + auto* ptr1 = mr.allocate(superblock::minimum_size, rmm::cuda_stream_view{}); + // Allocate again to make sure all superblocks from + // global arena are owned by a stream arena instead of a thread arena + // or the global arena. + auto* ptr2 = mr.allocate(32_KiB, rmm::cuda_stream_view{}); + // The original thread ptr is now owned by a stream arena so make + // sure deallocation works. + mr.deallocate(thread_ptr, 256, rmm::cuda_stream_per_thread); + mr.deallocate(ptr1, superblock::minimum_size, rmm::cuda_stream_view{}); + mr.deallocate(ptr2, 32_KiB, rmm::cuda_stream_view{}); +} + TEST_F(ArenaTest, DumpLogOnFailure) // NOLINT { arena_mr mr{rmm::mr::get_current_device_resource(), 1_MiB, true};