diff --git a/cpp/include/kvikio/utils.hpp b/cpp/include/kvikio/utils.hpp index e48088ec09..3bef106ff8 100644 --- a/cpp/include/kvikio/utils.hpp +++ b/cpp/include/kvikio/utils.hpp @@ -115,33 +115,6 @@ constexpr bool is_host_memory(const void* ptr) { return true; } return ret; } -/** - * @brief RAII wrapper for a CUDA primary context - */ -class CudaPrimaryContext { - public: - CUdevice dev{}; - CUcontext ctx{}; - - CudaPrimaryContext(int device_ordinal) - { - CUDA_DRIVER_TRY(cudaAPI::instance().DeviceGet(&dev, device_ordinal)); - CUDA_DRIVER_TRY(cudaAPI::instance().DevicePrimaryCtxRetain(&ctx, dev)); - } - CudaPrimaryContext(const CudaPrimaryContext&) = delete; - CudaPrimaryContext& operator=(CudaPrimaryContext const&) = delete; - CudaPrimaryContext(CudaPrimaryContext&&) = delete; - CudaPrimaryContext&& operator=(CudaPrimaryContext&&) = delete; - ~CudaPrimaryContext() - { - try { - CUDA_DRIVER_TRY(cudaAPI::instance().DevicePrimaryCtxRelease(dev), CUfileException); - } catch (const CUfileException& e) { - std::cerr << e.what() << std::endl; - } - } -}; - /** * @brief Given a device ordinal, return the primary context of the device. * @@ -152,12 +125,22 @@ class CudaPrimaryContext { */ [[nodiscard]] KVIKIO_EXPORT inline CUcontext get_primary_cuda_context(int ordinal) { - static std::map _primary_contexts; + static std::map _cache; static std::mutex _mutex; std::lock_guard const lock(_mutex); - _primary_contexts.try_emplace(ordinal, ordinal); - return _primary_contexts.at(ordinal).ctx; + if (_cache.find(ordinal) == _cache.end()) { + CUdevice dev{}; + CUcontext ctx{}; + CUDA_DRIVER_TRY(cudaAPI::instance().DeviceGet(&dev, ordinal)); + + // Notice, we let the primary context leak at program exit. We do this because `_cache` + // is static and we are not allowed to call `cuDevicePrimaryCtxRelease()` after main: + // + CUDA_DRIVER_TRY(cudaAPI::instance().DevicePrimaryCtxRetain(&ctx, dev)); + _cache.emplace(ordinal, ctx); + } + return _cache.at(ordinal); } /**