-
Notifications
You must be signed in to change notification settings - Fork 320
DLPack to mdspan
#7047
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
fbusato
wants to merge
60
commits into
NVIDIA:main
Choose a base branch
from
fbusato:dlpack-to-mdspan
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
DLPack to mdspan
#7047
Changes from all commits
Commits
Show all changes
60 commits
Select commit
Hold shift + click to select a range
750ca5a
first version
fbusato f040c10
add unit test
fbusato 464ccc2
documentation
fbusato 6f32ae9
Update libcudacxx/include/cuda/__mdspan/mdspan_to_dlpack.h
fbusato 3457d3a
Merge branch 'mdspan-to-dlpack' of github.com:fbusato/cccl into mdspa…
fbusato ee05eda
add many types
fbusato 4d2e0da
remove operator->
fbusato f290320
formatting
fbusato 7a22848
fix MSVC warning
fbusato f78db30
improve documentation
fbusato 1467ab2
fix MSVC warning
fbusato d844f65
first version
fbusato 3843556
complete the implementation
fbusato 977909f
add unit test
fbusato b0e1fbc
cuda.coop: Use cuda.core.experimental.Linker instead of internal numb…
shwina 50da3d4
Make c2h vector comparisons `constexpr` (#7009)
davebayer f8a4d06
improves comments on decoupled lookback example (#7015)
elstehle e9f0a13
Extract reduce_op_sync into a free function (#7004)
bernhardmgruber 362d316
Remove experimental namespace from cuda.core import (#7022)
NaderAlAwar 28d22c9
reexpress completion signature transform alias to make clangd happy (…
ericniebler 1e28e8c
Qualify call to `__launch_impl` in launch.h to avoid ambiguity errors…
ericniebler f21a158
Rework hierarchy levels (#6957)
davebayer 1ef85d4
Use vectorized tuning for triad benchmark for dtypes of size 2 (#7019)
NaderAlAwar 00a1b95
[libcu++] Fix synchronous resource adapter property passing (#6976)
pciolkosz adc23f5
[libcu++] Remove _view from the shared memory getter name (#6997)
pciolkosz 33aa542
[thrust] Ignore CUDA free errors in thrust memory resource (#7002)
pciolkosz 262b718
[libcu++] Correctly handle extended lambda in cuda::launch (#6987)
pciolkosz 6402bc6
the `<stdexcept>` header must be included when using `_CCCL_THROW`, …
fbusato 5546b87
Error out when nvrtcc cannot parse cuda_thread_count (#7035)
bernhardmgruber 58aba1d
Allow all public headers to be included with host compilers only (#7012)
davebayer e80cee2
[cuda.compute]: Fixes and updates to benchmarks (#6999)
shwina d91b711
Support operations with side-effects (state) in `cuda.compute` (#7008)
shwina c40c68d
Fix `cuda::memcpy async` edge cases and add more tests (#6608)
bernhardmgruber 16bdfbf
Explicitly set `CCCL_TOPLEVEL_PROJECT` to `OFF` when needed (#7016)
KyleFromNVIDIA 11d32ec
[libcu++] Add explicit alignment specification in buffer (#7005)
pciolkosz d1dcaa5
Use the sccache-dist build cluster for RAPIDS CI jobs (#7014)
trxcllnt 52834f8
first version
fbusato dec5dca
complete the implementation
fbusato b38a6a7
add unit test
fbusato fe72fcc
Merge branch 'dlpack-to-mdspan' of github.com:fbusato/cccl into dlpac…
fbusato 5be1893
fix unit test
fbusato f0909df
formatting
fbusato d149dff
minor fixes
fbusato e96ebea
fix compiler warnings
fbusato 136ab59
refactor vector type traits by removing conditional compilation for v…
fbusato 501f48c
reenable vector types for CTK 13
fbusato bd6094c
Merge branch 'main' into mdspan-to-dlpack
fbusato 604257d
fix msvc warning
fbusato f7c5eb4
Merge branch 'mdspan-to-dlpack' into dlpack-to-mdspan
fbusato 14cf251
documentation and copyright
fbusato eb2635a
fix index_operator.pass
fbusato ea7e4e4
fix formatting
fbusato 8e813f1
Merge branch 'main' into mdspan-to-dlpack
fbusato c20e897
Merge branch 'main' into dlpack-to-mdspan
fbusato b6a52cd
use internal type
fbusato 0f8d8b7
Merge branch 'main' into mdspan-to-dlpack
fbusato 1c7f5d4
Merge branch 'mdspan-to-dlpack' into dlpack-to-mdspan
fbusato 9bbf73b
address comments
fbusato 83b6eb2
Merge branch 'main' into dlpack-to-mdspan
fbusato ba38968
use _CCCL_HAS_DLPACK
fbusato File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
129 changes: 129 additions & 0 deletions
129
docs/libcudacxx/extended_api/mdspan/dlpack_to_mdspan.rst
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,129 @@ | ||
| .. _libcudacxx-extended-api-mdspan-dlpack-to-mdspan: | ||
|
|
||
| DLPack to ``mdspan`` | ||
| ==================== | ||
|
|
||
| This functionality provides a conversion from `DLPack <https://dmlc.github.io/dlpack/latest/>`__ ``DLTensor`` to ``cuda::host_mdspan``, ``cuda::device_mdspan``, and ``cuda::managed_mdspan``. | ||
|
|
||
| Defined in the ``<cuda/mdspan>`` header. | ||
|
|
||
| Conversion functions | ||
| -------------------- | ||
|
|
||
| .. code:: cuda | ||
|
|
||
| namespace cuda { | ||
|
|
||
| template <typename ElementType, size_t Rank, typename LayoutPolicy = cuda::std::layout_stride> | ||
| [[nodiscard]] cuda::host_mdspan<ElementType, cuda::std::dims<Rank, int64_t>, LayoutPolicy> | ||
| to_host_mdspan(const DLTensor& tensor); | ||
|
|
||
| template <typename ElementType, size_t Rank, typename LayoutPolicy = cuda::std::layout_stride> | ||
| [[nodiscard]] cuda::device_mdspan<ElementType, cuda::std::dims<Rank, int64_t>, LayoutPolicy> | ||
| to_device_mdspan(const DLTensor& tensor); | ||
|
|
||
| template <typename ElementType, size_t Rank, typename LayoutPolicy = cuda::std::layout_stride> | ||
| [[nodiscard]] cuda::managed_mdspan<ElementType, cuda::std::dims<Rank, int64_t>, LayoutPolicy> | ||
| to_managed_mdspan(const DLTensor& tensor); | ||
|
|
||
| } // namespace cuda | ||
|
|
||
| Template parameters | ||
| ------------------- | ||
|
|
||
| - ``ElementType``: The element type of the resulting ``mdspan``. Must match the ``DLTensor::dtype``. | ||
| - ``Rank``: The number of dimensions. Must match ``DLTensor::ndim``. | ||
| - ``LayoutPolicy``: The layout policy for the resulting ``mdspan``. Defaults to ``cuda::std::layout_stride``. Supported layouts are: | ||
|
|
||
| - ``cuda::std::layout_right`` (C-contiguous, row-major) | ||
| - ``cuda::std::layout_left`` (Fortran-contiguous, column-major) | ||
| - ``cuda::std::layout_stride`` (general strided layout) | ||
|
|
||
| Semantics | ||
| --------- | ||
|
|
||
| The conversion produces a non-owning ``mdspan`` view of the ``DLTensor`` data: | ||
|
|
||
| - The ``mdspan`` data pointer is computed as ``static_cast<char*>(tensor.data) + tensor.byte_offset``. | ||
| - For ``rank > 0``, ``mdspan.extent(i)`` is ``tensor.shape[i]``. | ||
| - For ``layout_stride``, ``mdspan.stride(i)`` is ``tensor.strides[i]`` (or computed as row-major if ``strides`` is ``nullptr`` for DLPack < v1.2). | ||
| - The device type is validated: | ||
|
|
||
| - ``kDLCPU`` for ``to_host_mdspan`` | ||
| - ``kDLCUDA`` for ``to_device_mdspan`` | ||
| - ``kDLCUDAManaged`` for ``to_managed_mdspan`` | ||
|
|
||
| Constraints | ||
| ----------- | ||
|
|
||
| - ``LayoutPolicy`` must be one of ``cuda::std::layout_right``, ``cuda::std::layout_left``, or ``cuda::std::layout_stride``. | ||
| - For ``layout_right`` and ``layout_left``, the ``DLTensor`` strides must be compatible with the layout. | ||
|
|
||
| Runtime errors | ||
| -------------- | ||
|
|
||
| The conversion throws ``std::invalid_argument`` in the following cases: | ||
|
|
||
| - ``DLTensor::ndim`` does not match the specified ``Rank``. | ||
| - ``DLTensor::dtype`` does not match ``ElementType``. | ||
| - ``DLTensor::data`` is ``nullptr``. | ||
| - ``DLTensor::shape`` is ``nullptr`` (for rank > 0). | ||
| - Any ``DLTensor::shape[i]`` is negative. | ||
| - ``DLTensor::strides`` is ``nullptr`` for DLPack v1.2 or later. | ||
| - ``DLTensor::strides`` is ``nullptr`` for ``layout_left`` with rank > 1 (DLPack < v1.2). | ||
| - ``DLTensor::strides[i]`` is not positive for ``layout_stride``. | ||
| - ``DLTensor::strides`` are not compatible with the requested ``layout_right`` or ``layout_left``. | ||
| - ``DLTensor::device.device_type`` does not match the target mdspan type. | ||
| - Data pointer is not properly aligned for the element type. | ||
|
|
||
| Availability notes | ||
| ------------------ | ||
|
|
||
| - This API is available only when DLPack header is present, namely ``<dlpack/dlpack.h>`` is found in the include path. | ||
| - Requires DLPack major version 1. | ||
|
|
||
| References | ||
| ---------- | ||
|
|
||
| - `DLPack C API <https://dmlc.github.io/dlpack/latest/c_api.html>`__ documentation. | ||
|
|
||
| Example | ||
| ------- | ||
|
|
||
| .. code:: cuda | ||
|
|
||
| #include <dlpack/dlpack.h> | ||
| #include <cuda/mdspan> | ||
| #include <cuda/std/cassert> | ||
| #include <cuda/std/cstdint> | ||
|
|
||
| int main() { | ||
| int data[6] = {0, 1, 2, 3, 4, 5}; | ||
|
|
||
| // Create a DLTensor manually for demonstration | ||
| int64_t shape[2] = {2, 3}; | ||
| int64_t strides[2] = {3, 1}; // row-major strides | ||
|
|
||
| DLTensor tensor{}; | ||
| tensor.data = data; | ||
| tensor.device = {kDLCPU, 0}; | ||
| tensor.ndim = 2; | ||
| tensor.dtype = {kDLInt, 32, 1}; | ||
| tensor.shape = shape; | ||
| tensor.strides = strides; | ||
| tensor.byte_offset = 0; | ||
|
|
||
| // Convert to host_mdspan | ||
| auto md = cuda::to_host_mdspan<int, 2>(tensor); | ||
|
|
||
| assert(md.rank() == 2); | ||
| assert(md.extent(0) == 2 && md.extent(1) == 3); | ||
| assert(md.stride(0) == 3 && md.stride(1) == 1); | ||
| assert(md.data_handle() == data); | ||
| assert(md(0, 0) == 0 && md(1, 2) == 5); | ||
| } | ||
|
|
||
| See also | ||
| -------- | ||
|
|
||
| - :ref:`libcudacxx-extended-api-mdspan-mdspan-to-dlpack` for the reverse conversion. |
137 changes: 137 additions & 0 deletions
137
docs/libcudacxx/extended_api/mdspan/mdspan_to_dlpack.rst
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,137 @@ | ||
| .. _libcudacxx-extended-api-mdspan-mdspan-to-dlpack: | ||
|
|
||
| ``mdspan`` to DLPack | ||
| ==================== | ||
|
|
||
| This functionality provides a conversion from ``cuda::host_mdspan``, ``cuda::device_mdspan``, and ``cuda::managed_mdspan`` to `DLPack <https://dmlc.github.io/dlpack/latest/>`__ ``DLTensor`` view. | ||
|
|
||
| Defined in the ``<cuda/mdspan>`` header. | ||
|
|
||
| Conversion functions | ||
| -------------------- | ||
|
|
||
| .. code:: cuda | ||
|
|
||
| namespace cuda { | ||
|
|
||
| template <typename T, typename Extents, typename Layout, typename Accessor> | ||
| [[nodiscard]] __dlpack_tensor<Extents::rank()> | ||
| to_dlpack(const cuda::host_mdspan<T, Extents, Layout, Accessor>& mdspan); | ||
|
|
||
| template <typename T, typename Extents, typename Layout, typename Accessor> | ||
| [[nodiscard]] __dlpack_tensor<Extents::rank()> | ||
| to_dlpack(const cuda::device_mdspan<T, Extents, Layout, Accessor>& mdspan, | ||
| cuda::device_ref device = cuda::device_ref{0}); | ||
fbusato marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| template <typename T, typename Extents, typename Layout, typename Accessor> | ||
| [[nodiscard]] __dlpack_tensor<Extents::rank()> | ||
| to_dlpack(const cuda::managed_mdspan<T, Extents, Layout, Accessor>& mdspan); | ||
|
|
||
| } // namespace cuda | ||
|
|
||
| Types | ||
| ----- | ||
|
|
||
| ``__dlpack_tensor`` is an internal class that stores a ``DLTensor`` and owns the backing storage for its ``shape`` and ``strides`` pointers. The class does not use any heap allocation. | ||
|
|
||
| .. code:: cuda | ||
|
|
||
| namespace cuda { | ||
|
|
||
| template <size_t Rank> | ||
| class __dlpack_tensor { | ||
| public: | ||
| __dlpack_tensor(); | ||
| __dlpack_tensor(const __dlpack_tensor&) noexcept; | ||
| __dlpack_tensor(__dlpack_tensor&&) noexcept; | ||
| __dlpack_tensor& operator=(const __dlpack_tensor&) noexcept; | ||
| __dlpack_tensor& operator=(__dlpack_tensor&&) noexcept; | ||
| ~__dlpack_tensor() noexcept = default; | ||
|
|
||
| DLTensor& get() noexcept; | ||
| const DLTensor& get() const noexcept; | ||
| }; | ||
|
|
||
| } // namespace cuda | ||
|
|
||
| ``cuda::__dlpack_tensor`` stores a ``DLTensor`` and owns the backing storage for its ``shape`` and ``strides`` pointers. The class does not use any heap allocation. | ||
|
|
||
| .. note:: **Lifetime** | ||
|
|
||
| The ``DLTensor`` associated with ``cuda::__dlpack_tensor`` must not outlive the wrapper. If the wrapper is destroyed, the returned ``DLTensor::shape`` and ``DLTensor::strides`` pointers will dangle. | ||
|
|
||
| .. note:: **Const-correctness** | ||
|
|
||
| ``DLTensor::data`` points at ``mdspan.data_handle()`` (or is ``nullptr`` if ``mdspan.size() == 0``). If ``T`` is ``const``, the pointer is ``const_cast``'d because ``DLTensor::data`` is unqualified. | ||
|
|
||
| Semantics | ||
| --------- | ||
|
|
||
| The conversion produces a non-owning DLPack view of the ``mdspan`` data and metadata: | ||
|
|
||
| - ``DLTensor::ndim`` is ``mdspan.rank()``. | ||
| - For rank > 0, ``DLTensor::shape[i]`` is ``mdspan.extent(i)``. | ||
| - For rank > 0, ``DLTensor::strides[i]`` is ``mdspan.stride(i)``. | ||
| - ``DLTensor::byte_offset`` is always ``0``. | ||
| - ``DLTensor::device`` is: | ||
|
|
||
| - ``{kDLCPU, 0}`` for ``cuda::host_mdspan`` | ||
| - ``{kDLCUDA, device.get()}`` for ``cuda::device_mdspan`` | ||
| - ``{kDLCUDAManaged, 0}`` for ``cuda::managed_mdspan`` | ||
|
|
||
| Element types are mapped to ``DLDataType`` according to the DLPack conventions, including: | ||
|
|
||
| - ``bool``. | ||
| - Signed and unsigned integers. | ||
| - IEEE-754 Floating-point and extended precision floating-point, including ``__half``, ``__nv_bfloat16``, ``__float128``, FP8, FP6, FP4 when available. | ||
| - Complex: ``cuda::std::complex<__half>``, ``cuda::std::complex<float>``, and ``cuda::std::complex<double>``. | ||
| - `CUDA built-in vector types <https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/cpp-language-extensions.html#built-in-types>`__, such as ``int2``, ``float4``, etc. | ||
| - Vector types for extended floating-point, such as ``__half2``, ``__nv_fp8x4_e4m3``, etc. | ||
|
|
||
| Constraints | ||
| ----------- | ||
|
|
||
| - The accessor ``data_handle_type`` must be a pointer type. | ||
|
|
||
| Runtime errors | ||
| -------------- | ||
|
|
||
| - If any ``extent(i)`` or ``stride(i)`` cannot be represented in ``int64_t``, the conversion raises an exception. | ||
fbusato marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| Availability notes | ||
| ------------------ | ||
|
|
||
| - This API is available only when DLPack header is present, namely ``<dlpack/dlpack.h>`` is found in the include path. | ||
|
|
||
| References | ||
| ---------- | ||
|
|
||
| - `DLPack C API <https://dmlc.github.io/dlpack/latest/c_api.html>`__ documentation. | ||
|
|
||
| Example | ||
| ------- | ||
|
|
||
| .. code:: cuda | ||
|
|
||
| #include <dlpack/dlpack.h> | ||
| #include <cuda/mdspan> | ||
| #include <cuda/std/cassert> | ||
| #include <cuda/std/cstdint> | ||
|
|
||
| int main() { | ||
| using extents_t = cuda::std::extents<size_t, 2, 3>; | ||
|
|
||
| int data[6] = {0, 1, 2, 3, 4, 5}; | ||
| cuda::host_mdspan<int, extents_t> md{data, extents_t{}}; | ||
|
|
||
| auto dl = cuda::to_dlpack(md); | ||
| const auto& dltensor = dl.get(); | ||
| // auto dltensor = dl.get(); is incorrect; it returns a reference to a temporary object that will be destroyed at the end of the statement. | ||
|
|
||
| // `dl` owns the shape/stride storage; `dltensor.data` is a non-owning pointer to `data`. | ||
| assert(dltensor.device.device_type == kDLCPU); | ||
| assert(dltensor.ndim == 2); | ||
| assert(dltensor.shape[0] == 2 && dltensor.shape[1] == 3); | ||
| assert(dltensor.strides[0] == 3 && dltensor.strides[1] == 1); | ||
| assert(dltensor.data == data); | ||
| } | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.