Skip to content

Commit

Permalink
get_primary_cuda_context(): leak
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk committed Sep 25, 2024
1 parent e1f952d commit ebc3706
Showing 1 changed file with 13 additions and 30 deletions.
43 changes: 13 additions & 30 deletions cpp/include/kvikio/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,33 +115,6 @@ constexpr bool is_host_memory(const void* ptr) { return true; }
return ret;
}

/**
* @brief RAII wrapper for a CUDA primary context
*/
class CudaPrimaryContext {
public:
CUdevice dev{};
CUcontext ctx{};

CudaPrimaryContext(int device_ordinal)
{
CUDA_DRIVER_TRY(cudaAPI::instance().DeviceGet(&dev, device_ordinal));
CUDA_DRIVER_TRY(cudaAPI::instance().DevicePrimaryCtxRetain(&ctx, dev));
}
CudaPrimaryContext(const CudaPrimaryContext&) = delete;
CudaPrimaryContext& operator=(CudaPrimaryContext const&) = delete;
CudaPrimaryContext(CudaPrimaryContext&&) = delete;
CudaPrimaryContext&& operator=(CudaPrimaryContext&&) = delete;
~CudaPrimaryContext()
{
try {
CUDA_DRIVER_TRY(cudaAPI::instance().DevicePrimaryCtxRelease(dev), CUfileException);
} catch (const CUfileException& e) {
std::cerr << e.what() << std::endl;
}
}
};

/**
* @brief Given a device ordinal, return the primary context of the device.
*
Expand All @@ -152,12 +125,22 @@ class CudaPrimaryContext {
*/
[[nodiscard]] KVIKIO_EXPORT inline CUcontext get_primary_cuda_context(int ordinal)
{
static std::map<int, CudaPrimaryContext> _primary_contexts;
static std::map<int, CUcontext> _cache;
static std::mutex _mutex;
std::lock_guard const lock(_mutex);

_primary_contexts.try_emplace(ordinal, ordinal);
return _primary_contexts.at(ordinal).ctx;
if (_cache.find(ordinal) == _cache.end()) {
CUdevice dev{};
CUcontext ctx{};
CUDA_DRIVER_TRY(cudaAPI::instance().DeviceGet(&dev, ordinal));

// Notice, we let the primary context leak at program exit. We do this because `_cache`
// is static and we are not allowed to call `cuDevicePrimaryCtxRelease()` after main:
// <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#initialization>
CUDA_DRIVER_TRY(cudaAPI::instance().DevicePrimaryCtxRetain(&ctx, dev));
_cache.emplace(ordinal, ctx);
}
return _cache.at(ordinal);
}

/**
Expand Down

0 comments on commit ebc3706

Please sign in to comment.