-
Notifications
You must be signed in to change notification settings - Fork 160
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Ensure that
cuda_memory_resource
allocates memory on the proper dev…
…ice (#2073) * Ensure that `cuda_memory_resource` allocates memory on the proper device * Move `__ensure_current_device` to own header
- Loading branch information
Showing
4 changed files
with
99 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
65 changes: 65 additions & 0 deletions
65
libcudacxx/include/cuda/std/__cuda/ensure_current_device.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Part of libcu++, the C++ Standard Library for your entire system, | ||
// under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef _CUDA__STD__CUDA_ENSURE_CURRENT_DEVICE_H | ||
#define _CUDA__STD__CUDA_ENSURE_CURRENT_DEVICE_H | ||
|
||
#include <cuda/std/detail/__config> | ||
|
||
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) | ||
# pragma GCC system_header | ||
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) | ||
# pragma clang system_header | ||
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) | ||
# pragma system_header | ||
#endif // no system header | ||
|
||
#if !defined(_CCCL_CUDA_COMPILER_NVCC) && !defined(_CCCL_CUDA_COMPILER_NVHPC) | ||
# include <cuda_runtime_api.h> | ||
#endif // !_CCCL_CUDA_COMPILER_NVCC && !_CCCL_CUDA_COMPILER_NVHPC | ||
|
||
#include <cuda/std/__cuda/api_wrapper.h> | ||
#include <cuda/std/__exception/cuda_error.h> | ||
|
||
_LIBCUDACXX_BEGIN_NAMESPACE_CUDA | ||
|
||
//! @brief `__ensure_current_device` is a simple helper that the current device is set to the right one. | ||
//! Only changes the current device if the target device is not the current one | ||
struct __ensure_current_device | ||
{ | ||
int __target_device_ = 0; | ||
int __original_device_ = 0; | ||
|
||
//! @brief Querries the current device and if that is different than \p __target_device sets the current device to | ||
//! \p __target_device | ||
__ensure_current_device(const int __target_device) | ||
: __target_device_(__target_device) | ||
{ | ||
_CCCL_TRY_CUDA_API(::cudaGetDevice, "Failed to query current device", &__original_device_); | ||
if (__original_device_ != __target_device_) | ||
{ | ||
_CCCL_TRY_CUDA_API(::cudaSetDevice, "Failed to set device", __target_device_); | ||
} | ||
} | ||
|
||
//! @brief If the \p __original_device was not equal to \p __target_device sets the current device back to | ||
//! \p __original_device | ||
~__ensure_current_device() | ||
{ | ||
if (__original_device_ != __target_device_) | ||
{ | ||
_CCCL_TRY_CUDA_API(::cudaSetDevice, "Failed to set device", __original_device_); | ||
} | ||
} | ||
}; | ||
|
||
_LIBCUDACXX_END_NAMESPACE_CUDA | ||
|
||
#endif //_CUDA__STD__CUDA_ENSURE_CURRENT_DEVICE_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters