[RFC] Support DLPACK C Functions for Speed Exchange and Stream Handling #174
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| /*! | ||
| * Copyright (c) 2017 by Contributors | ||
| * Copyright (c) 2017 - by Contributors | ||
| * \file dlpack.h | ||
| * \brief The common header of DLPack. | ||
| */ | ||
|
|
@@ -326,7 +326,7 @@ typedef struct DLManagedTensor { | |
| * | ||
| * \note This is the current standard DLPack exchange data structure. | ||
| */ | ||
| struct DLManagedTensorVersioned { | ||
| typedef struct DLManagedTensorVersioned { | ||
| /*! | ||
| * \brief The API and ABI version of the current managed Tensor | ||
| */ | ||
|
|
@@ -360,6 +360,258 @@ struct DLManagedTensorVersioned { | |
| uint64_t flags; | ||
| /*! \brief DLTensor which is being memory managed */ | ||
| DLTensor dl_tensor; | ||
| } DLManagedTensorVersioned; | ||
|
|
||
| //-------------------------------------------------------------------- | ||
| // DLPack C functions for speed exchange | ||
| //-------------------------------------------------------------------- | ||
| /*! | ||
| * \brief A generic C-style allocator that exposes allocation of a Tensor/Array. | ||
| * | ||
| * This information can then be used to set allocators of a callee to run allocations. | ||
| * This information can then be used to set the callee's allocator to perform allocations. | ||
| * This function can be exposed by the framework through the DLPackExchangeAPI. | ||
| * | ||
| * This particular function does not assume a Python environment; as a result, | ||
| * the error handling mechanism is different from Python-related functions. | ||
| * | ||
| * \param prototype The prototype DLTensor to offer details about the device and shape. | ||
| * Other field information will be ignored during allocation. | ||
| * \param out The output DLManagedTensorVersioned. | ||
| * \param error_ctx The context to set the error. | ||
| * \param SetError The function to set the error. | ||
| * \return 0 on success, -1 on failure. | ||
| * The callee should call SetError(error_ctx, kind, message) to set the error kind and message. | ||
| * \note Error propagation via SetError. | ||
| * | ||
| * \sa DLPackExchangeAPI | ||
| */ | ||
| typedef int (*DLPackManagedTensorAllocator)( // | ||
| DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx, // | ||
| void (*SetError)(void* error_ctx, const char* kind, const char* message) // | ||
| ); | ||
|
|
||
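To make the allocator contract above concrete, here is a minimal, hedged sketch of how a framework might implement this hook. `MyFrameworkAlloc`, `MyFrameworkFree`, `MyDeleter`, and `MyDLPackManagedTensorAllocator` are hypothetical names (not part of DLPack), the include of `dlpack.h` assumes a header containing the proposed additions, and error cleanup on partial failure is elided for brevity:

```cpp
#include <cstdlib>
#include <cstring>
#include "dlpack.h"  // assumes the header with the proposed additions

// Hypothetical framework device allocator hooks; not part of DLPack.
void* MyFrameworkAlloc(DLDevice device, size_t nbytes);
void MyFrameworkFree(DLDevice device, void* data);

static void MyDeleter(DLManagedTensorVersioned* self) {
  MyFrameworkFree(self->dl_tensor.device, self->dl_tensor.data);
  std::free(self->dl_tensor.shape);
  std::free(self);
}

static int MyDLPackManagedTensorAllocator(
    DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx,
    void (*SetError)(void* error_ctx, const char* kind, const char* message)) {
  // Only device, ndim, shape, and dtype of the prototype are consulted.
  int64_t numel = 1;
  for (int i = 0; i < prototype->ndim; ++i) numel *= prototype->shape[i];
  size_t nbytes = static_cast<size_t>(numel) *
                  ((prototype->dtype.bits * prototype->dtype.lanes + 7) / 8);

  auto* ret = static_cast<DLManagedTensorVersioned*>(
      std::malloc(sizeof(DLManagedTensorVersioned)));
  auto* shape = static_cast<int64_t*>(std::malloc(sizeof(int64_t) * prototype->ndim));
  void* data = MyFrameworkAlloc(prototype->device, nbytes);
  if (ret == nullptr || shape == nullptr || data == nullptr) {
    SetError(error_ctx, "RuntimeError", "allocation failed");  // cleanup elided
    return -1;
  }
  std::memcpy(shape, prototype->shape, sizeof(int64_t) * prototype->ndim);

  ret->version = {DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION};
  ret->manager_ctx = nullptr;
  ret->deleter = MyDeleter;
  ret->flags = 0;
  ret->dl_tensor = *prototype;       // copies device, ndim, dtype
  ret->dl_tensor.data = data;
  ret->dl_tensor.shape = shape;
  ret->dl_tensor.strides = nullptr;  // compact row-major layout
  ret->dl_tensor.byte_offset = 0;
  *out = ret;
  return 0;
}
```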
| /*! | ||
| * \brief Exports a PyObject* Tensor/NDArray to a DLManagedTensorVersioned. | ||
| * | ||
| * This function is a C-style function pointer to quickly convert a PyObject* Tensor/NDArray | ||
| * to a DLManagedTensorVersioned without going through the Python interpreter. | ||
| * | ||
| * This function does not perform any stream synchronization. The consumer should query | ||
| * DLPackCurrentWorkStream to get the current work stream and launch kernels on it. | ||
| * | ||
| * This function is exposed by the framework through the DLPackExchangeAPI. | ||
| * | ||
| * This information can then be picked up by importers and libraries to perform a fast conversion. | ||
| * This function should not throw any exceptions; if it fails, it should return -1 and | ||
| * set the error message via PyErr_SetXXX. | ||
| * | ||
| * \param py_object The Python object to convert; this should be PyObject*. | ||
| * We use void* to avoid dependency on Python.h. | ||
| * | ||
| * \param out The output DLManagedTensorVersioned. | ||
| * \return 0 on success, -1 on failure. PyError should be set if -1 is returned. | ||
| * \note We use void* to avoid dependency on Python.h, so this specific type is | ||
| * not dependent on Python.h and can be copied to dlpack.h. | ||
| * | ||
| * \sa DLPackExchangeAPI, DLPackCurrentWorkStream | ||
| */ | ||
| typedef int (*DLPackManagedTensorFromPyObjectNoSync)( // | ||
| void* py_object, // | ||
| DLManagedTensorVersioned** out // | ||
| ); | ||
|
|
||
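As an illustration of the consumer side, a hedged sketch of calling this hook through a previously obtained DLPackExchangeAPI pointer. The `ConsumeTensor` wrapper is a hypothetical helper, and how `api` is obtained is shown in a later sketch:

```cpp
#include <Python.h>
#include "dlpack.h"  // assumes the header with the proposed additions

// Consumer-side sketch: import a producer tensor as an owned
// DLManagedTensorVersioned, use it, then hand it back via its deleter.
static int ConsumeTensor(const DLPackExchangeAPI* api, PyObject* py_tensor) {
  DLManagedTensorVersioned* managed = nullptr;
  if (api->managed_tensor_from_py_object_no_sync(py_tensor, &managed) != 0) {
    return -1;  // the producer has already set a Python error
  }
  // ... launch kernels that read managed->dl_tensor on the producer's
  //     current work stream (see DLPackCurrentWorkStream) ...
  if (managed->deleter != nullptr) {
    managed->deleter(managed);  // release once the consumer is done
  }
  return 0;
}
```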
| /*! | ||
| * \brief Exports a PyObject* Tensor/NDArray to a DLTensor whose space is pre-allocated on the stack. | ||
| * | ||
| * This function is a C-style function pointer to quickly convert a PyObject* Tensor/NDArray | ||
| * to a DLTensor whose space is pre-allocated on the stack, without going through the Python interpreter. | ||
| * | ||
| * This is a non-owning conversion; the producer still owns the memory of data, strides, and shape. | ||
| * The liveness of the DLTensor is only guaranteed until the consumer returns control to the caller. | ||
| * | ||
| * In the context of this function, we expect the producer to allocate space for data, strides, and shape. | ||
| * | ||
| * This function does not perform any stream synchronization. The consumer should query | ||
| * DLPackCurrentWorkStream to get the current work stream and launch kernels on it. | ||
| * | ||
| * This function is useful when the consumer does not need to retain the tensor memory. | ||
| * It can generally provide about 2x faster conversion than DLPackManagedTensorFromPyObjectNoSync. | ||
| * | ||
| * For cases where the consumer may need to reorganize the tensor memory via a temporary managed copy, | ||
| * DLPackManagedTensorFromPyObjectNoSync should be used. | ||
| * | ||
| * This function is exposed by the framework through the DLPackExchangeAPI. | ||
| * | ||
| * This information can then be picked up by importers and libraries to perform a fast conversion. | ||
| * This function should not throw any exceptions; if it fails, it should return -1 and | ||
| * set the error message via PyErr_SetXXX. | ||
| * | ||
| * \param py_object The Python object to convert; this should be PyObject*. | ||
| * We use void* to avoid dependency on Python.h. | ||
| * | ||
| * \param out The output DLTensor, whose space is pre-allocated on the stack. | ||
| * \return 0 on success, -1 on failure. PyError should be set if -1 is returned. | ||
| * \note We use void* to avoid dependency on Python.h, so this specific type is | ||
| * not dependent on Python.h and can be copied to dlpack.h. | ||
| * | ||
| * \sa DLPackExchangeAPI, DLPackCurrentWorkStream | ||
| */ | ||
| typedef int (*DLPackDLTensorFromPyObjectNoSync)( // | ||
| void* py_object, // | ||
| DLTensor* out // | ||
| ); | ||
|
|
||
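A corresponding hedged sketch of the non-owning fast path; `ConsumeTensorNonOwning` is a hypothetical consumer helper:

```cpp
#include <Python.h>
#include "dlpack.h"  // assumes the header with the proposed additions

// Consumer-side sketch of the non-owning fast path: the DLTensor lives on the
// consumer's stack, the producer keeps ownership of data/shape/strides, and
// the view is only valid until this function returns control to the caller.
static int ConsumeTensorNonOwning(const DLPackExchangeAPI* api, PyObject* py_tensor) {
  DLTensor view;  // stack-allocated output, filled in by the producer
  if (api->dltensor_from_py_object_no_sync(py_tensor, &view) != 0) {
    return -1;  // the producer has already set a Python error
  }
  // ... launch kernels that only read `view` before returning ...
  return 0;  // nothing to free: the producer still owns the memory
}
```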
| /*! | ||
| * \brief Obtain the current work stream of a device. | ||
| * | ||
| * This function is a C-style function pointer to obtain the current work stream | ||
| * of a device for frameworks that rely on a context manager to manage the stream. | ||
| * For example, it should map to torch.cuda.current_stream in PyTorch. | ||
| * | ||
| * This function can be set to NULL if the framework does not rely on a context manager | ||
| * to manage the stream. However, we encourage frameworks to provide this function | ||
| * if possible. | ||
| * | ||
| * If this field is not set, the consumer likely cannot safely perform | ||
| * stream-based exchange with the producer. | ||
| * | ||
| * \param device_type The device type. | ||
| * \param device_id The device id. | ||
| * \param out_current_stream The output current work stream. | ||
| * \return 0 on success, -1 on failure. | ||
| * | ||
| * \sa DLPackExchangeAPI | ||
| */ | ||
| typedef int (*DLPackCurrentWorkStream)( // | ||
| DLDeviceType device_type, // | ||
| int32_t device_id, // | ||
| void** out_current_stream // | ||
| ); | ||
|
|
||
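A hedged sketch of how a consumer might use this hook for CUDA devices. `QueryProducerStream` is a hypothetical helper, and interpreting the returned void* as a cudaStream_t is an assumption about the convention:

```cpp
#include <cuda_runtime.h>
#include "dlpack.h"  // assumes the header with the proposed additions

// Consumer-side sketch: ask the producer for its current work stream and
// launch on that stream instead of synchronizing. Falls back to the default
// stream (NULL, i.e. the legacy default stream) when the hook is absent.
static cudaStream_t QueryProducerStream(const DLPackExchangeAPI* api, DLDevice device) {
  void* raw_stream = nullptr;
  if (api->current_work_stream != nullptr &&
      api->current_work_stream(device.device_type, device.device_id, &raw_stream) == 0) {
    return static_cast<cudaStream_t>(raw_stream);  // producer's current stream
  }
  return nullptr;  // no hook or failure: use the default stream
}
```

A caller would then launch with something like `my_kernel<<<grid, block, 0, stream>>>(...)` (a placeholder kernel), avoiding extra synchronization and staying compatible with CUDA graph capture on the producer side.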
| /*! | ||
| * \brief Imports a DLManagedTensorVersioned to a PyObject* Tensor/NDArray. | ||
| * | ||
| * This function is a C-style function pointer to quickly convert a DLManagedTensorVersioned | ||
| * to a PyObject* without going through the Python Interpreter. | ||
| * | ||
| * This function does not perform any stream synchronization. | ||
| * | ||
| * This function is exposed by the framework through the DLPackExchangeAPI. | ||
| * | ||
| * \param tensor The DLManagedTensorVersioned to convert. | ||
| * \param out_py_object The output Python object. | ||
| * \return 0 on success, -1 on failure. PyError should be set if -1 is returned. | ||
| * \note We use void* to avoid dependency on Python.h, so this specific type is | ||
| * not dependent on Python.h and can be copied to dlpack.h. | ||
| * | ||
| * \sa DLPackExchangeAPI | ||
| */ | ||
| typedef int (*DLPackManagedTensorToPyObjectNoSync)( // | ||
| DLManagedTensorVersioned* tensor, void** out_py_object // | ||
| ); | ||
|
|
||
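And the reverse direction: a hedged sketch of wrapping a consumer-produced DLManagedTensorVersioned (for example, one obtained from managed_tensor_allocator) back into the framework's Python tensor type; `WrapResult` is a hypothetical helper:

```cpp
#include <Python.h>
#include "dlpack.h"  // assumes the header with the proposed additions

// Sketch: hand an owned DLManagedTensorVersioned back to the framework as a
// Python tensor object. Ownership of `result` transfers on success.
static PyObject* WrapResult(const DLPackExchangeAPI* api,
                            DLManagedTensorVersioned* result) {
  void* py_out = nullptr;
  if (api->managed_tensor_to_py_object_no_sync(result, &py_out) != 0) {
    return nullptr;  // the framework has already set a Python error
  }
  return static_cast<PyObject*>(py_out);
}
```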
| /*! | ||
| * \brief Framework-specific function pointers table for DLPack exchange. | ||
| * | ||
| * Guidelines for leveraging DLPackExchangeAPI: | ||
| * | ||
| * There are generally two kinds of consumer needs for DLPack exchange: | ||
| * - N0: library support, where consumer.kernel(x, y, z) would like to run a kernel | ||
| * with the data from x, y, z. The consumer is also expected to run the kernel with the same | ||
| * stream context as the producer. For example, when x, y, z are torch.Tensor objects, | ||
| * the consumer should query exchange_api->current_work_stream to get the | ||
| * current stream and launch the kernel with the same stream. | ||
| * This setup avoids synchronization at kernel launch and maximizes compatibility | ||
| * with CUDA graph capture in the producer. | ||
| * This is the desired behavior for library extensions of frameworks such as PyTorch. | ||
| * - N1: data ingestion and retention | ||
| * | ||
| * Note that the obj.__dlpack__() API should provide suitable mechanisms for N1. | ||
| * The primary focus of the current DLPackExchangeAPI is to enable the faster exchange of N0 | ||
| * with the support of the current_work_stream function pointer. | ||
| * | ||
| * Array/Tensor libraries should statically create and initialize this structure, | ||
| * then expose a pointer to the DLPackExchangeAPI as an integer value on the Tensor/Array class. | ||
| * The DLPackExchangeAPI* should stay alive throughout the lifetime of the process. | ||
| * | ||
| * One simple way to do so is to create a static instance of DLPackExchangeAPI | ||
| * within the framework and return a pointer to it. The following code | ||
| * shows an example of doing so in C++. It should also be reasonably easy | ||
| * to do so in other languages. | ||
| * | ||
| * \code | ||
| * struct MyDLPackExchangeAPI : public DLPackExchangeAPI { | ||
| * MyDLPackExchangeAPI() { | ||
| * version.major = DLPACK_MAJOR_VERSION; | ||
| * version.minor = DLPACK_MINOR_VERSION; | ||
| * managed_tensor_allocator = MyDLPackManagedTensorAllocator; | ||
| * managed_tensor_from_py_object_no_sync = MyDLPackManagedTensorFromPyObjectNoSync; | ||
| * managed_tensor_to_py_object_no_sync = MyDLPackManagedTensorToPyObjectNoSync; | ||
| * dltensor_from_py_object_no_sync = MyDLPackDLTensorFromPyObjectNoSync; | ||
| * current_work_stream = MyDLPackCurrentWorkStream; | ||
| * prev_version_api = nullptr; | ||
| * } | ||
| * | ||
| * static const DLPackExchangeAPI* Global() { | ||
| * static MyDLPackExchangeAPI inst; | ||
| * return &inst; | ||
| * } | ||
| * }; | ||
| * \endcode | ||
| * | ||
| * Each framework should attach an integer-valued dunder attribute `__c_dlpack_exchange_api__` | ||
| * that holds the DLPackExchangeAPI* pointer. | ||
| * | ||
| * Importantly, the attribute should be attached to the class of the Tensor, not the instance. | ||
| * | ||
| * mypackage.Tensor.__c_dlpack_exchange_api__ = MyPackageDLPackExchangeAPI | ||
| * | ||
| * or equivalently: | ||
| * | ||
| * type(tensor_obj).__c_dlpack_exchange_api__ = MyPackageDLPackExchangeAPI | ||
| */ | ||
| struct DLPackExchangeAPI { | ||
| /*! | ||
| * \brief The current DLPack version. | ||
| */ | ||
| DLPackVersion version; | ||
| /*! | ||
| * \brief Optional pointer to an older DLPackExchangeAPI in the chain. | ||
| * | ||
| * It should be set to NULL if the framework does not support older versions. | ||
| * | ||
| * \sa DLPackExchangeAPI | ||
| */ | ||
| struct DLPackExchangeAPI* prev_version_api; | ||
|
||
| /*! | ||
| * \brief Framework-specific function pointer for DLPackManagedTensorAllocator | ||
| * \sa DLPackManagedTensorAllocator | ||
| */ | ||
| DLPackManagedTensorAllocator managed_tensor_allocator; | ||
| /*! | ||
| * \brief Framework-specific function pointer for DLPackManagedTensorFromPyObjectNoSync | ||
| * \sa DLPackManagedTensorFromPyObjectNoSync | ||
| */ | ||
| DLPackManagedTensorFromPyObjectNoSync managed_tensor_from_py_object_no_sync; | ||
| /*! | ||
| * \brief Framework-specific function pointer for DLPackManagedTensorToPyObjectNoSync | ||
| * \sa DLPackManagedTensorToPyObjectNoSync | ||
| */ | ||
| DLPackManagedTensorToPyObjectNoSync managed_tensor_to_py_object_no_sync; | ||
| /*! | ||
| * \brief Framework-specific function pointer for DLPackDLTensorFromPyObjectNoSync | ||
| * \sa DLPackDLTensorFromPyObjectNoSync | ||
| */ | ||
| DLPackDLTensorFromPyObjectNoSync dltensor_from_py_object_no_sync; | ||
| /*! | ||
| * \brief Framework-specific function pointer for DLPackCurrentWorkStream | ||
| * | ||
| * This function can be set to NULL if the framework does not rely on a context manager to manage the stream. | ||
| * | ||
| * \sa DLPackCurrentWorkStream | ||
| */ | ||
| DLPackCurrentWorkStream current_work_stream; | ||
| }; | ||
|
|
||
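For the attachment step described in the comment above, a hedged sketch using the CPython C API. `AttachExchangeAPI` and `MyTensorType` are hypothetical names, `MyDLPackExchangeAPI` refers to the example class in the documentation comment, and the sketch assumes the tensor class is a heap type or Python-level class (static extension types reject attribute assignment):

```cpp
#include <Python.h>
#include "dlpack.h"  // assumes the header with the proposed additions

// MyDLPackExchangeAPI is the example exchange-API singleton from the
// documentation comment above; MyTensorType is the framework's tensor class.
// Producer-side sketch: publish the address of the static DLPackExchangeAPI
// instance as an integer attribute on the tensor *class* at module init time.
static int AttachExchangeAPI(PyTypeObject* MyTensorType) {
  void* addr = const_cast<DLPackExchangeAPI*>(MyDLPackExchangeAPI::Global());
  PyObject* value = PyLong_FromVoidPtr(addr);
  if (value == nullptr) return -1;
  int rc = PyObject_SetAttrString(reinterpret_cast<PyObject*>(MyTensorType),
                                  "__c_dlpack_exchange_api__", value);
  Py_DECREF(value);
  return rc;  // 0 on success, -1 with a Python error set on failure
}
```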
| #ifdef __cplusplus | ||
|
|
||
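On the consumer side, a hedged sketch of retrieving the exchange table from a tensor object via the same dunder attribute; `GetExchangeAPI` is a hypothetical helper that falls back to the regular `__dlpack__` protocol when the attribute is absent:

```cpp
#include <Python.h>
#include "dlpack.h"  // assumes the header with the proposed additions

// Consumer-side sketch: look up the exchange table from the tensor's class.
// Returns NULL (without leaving an error set) if the producer does not expose it.
static const DLPackExchangeAPI* GetExchangeAPI(PyObject* py_tensor) {
  PyObject* attr = PyObject_GetAttrString(
      reinterpret_cast<PyObject*>(Py_TYPE(py_tensor)), "__c_dlpack_exchange_api__");
  if (attr == nullptr) {
    PyErr_Clear();  // attribute missing: fall back to obj.__dlpack__()
    return nullptr;
  }
  void* ptr = PyLong_AsVoidPtr(attr);
  Py_DECREF(attr);
  if (PyErr_Occurred()) {
    PyErr_Clear();  // not an integer address: treat as unsupported
    return nullptr;
  }
  return static_cast<const DLPackExchangeAPI*>(ptr);
}
```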
I am a bit unsure if this is specified as well as it needs to be? I.e. I think NULL would be the default stream.
The question is: is there any need (or not) for an "undefined or no synchronization" return value (such as -1)?
If not, we are all good, but if some producer might need this (for whatever reason), then we need to specify this here.
The alternative is that the producer just has to return the default stream (otherwise the consumer has to guess a stream anyway, probably, in the kernel use case!).
I don't think there is a need in this particular context. Returning the default stream is likely more well-defined.