diff --git a/ci/dcu_test.sh b/ci/dcu_test.sh
index be2d0e96369c75..cc303f5466ea50 100644
--- a/ci/dcu_test.sh
+++ b/ci/dcu_test.sh
@@ -75,6 +75,7 @@ function hybrid_paddlex() {
 function main(){
     cd ${PADDLE_ROOT}/build
     pip install hypothesis
+    /opt/py310/bin/pip install -r ${PADDLE_ROOT}/python/unittest_py/requirements.txt
     /opt/py310/bin/pip install safetensors
     if ls ${PADDLE_ROOT}/build/python/dist/*whl >/dev/null 2>&1; then
         pip install ${PADDLE_ROOT}/build/python/dist/*whl
diff --git a/paddle/fluid/framework/dlpack_tensor.cc b/paddle/fluid/framework/dlpack_tensor.cc
index 793d0bbdf6e695..02b27cbe0ef9ad 100644
--- a/paddle/fluid/framework/dlpack_tensor.cc
+++ b/paddle/fluid/framework/dlpack_tensor.cc
@@ -265,7 +265,7 @@ ::DLDataType PhiDataTypeToDLDataType(phi::DataType dtype) {
       framework::TransToProtoVarType(dtype));
 }
 
-phi::Place DLDeviceToPlace(const DLDevice &dl_device) {
+phi::Place DLDeviceToPlace(const ::DLDevice &dl_device) {
   phi::Place place;
   if (dl_device.device_type == kDLCPU) {
     place = phi::CPUPlace();
@@ -279,7 +279,7 @@ phi::Place DLDeviceToPlace(const DLDevice &dl_device) {
   return place;
 }
 
-DLDevice PlaceToDLDevice(const phi::Place &place) {
+::DLDevice PlaceToDLDevice(const phi::Place &place) {
   return phi::VisitPlace(place, internal::DLDeviceVisitor());
 }
 
@@ -358,6 +358,22 @@ DLManagedTensorVersioned *ToDLPackVersioned(const phi::DenseTensor &src,
   return ToDLPackImpl<DLManagedTensorVersioned>(src, flags);
 }
 
+void ToDLPackNonOwningImpl(const phi::DenseTensor &tensor,
+                           ::DLTensor &out) {  // NOLINT
+  // Fill in the pre-allocated DLTensor struct with direct pointers.
+  // This is a non-owning conversion: the caller owns the tensor and must
+  // keep it alive for as long as the DLTensor is in use.
+  out.data = const_cast<void *>(tensor.data());
+  out.device = PlaceToDLDevice(tensor.place());
+  out.ndim = static_cast<int32_t>(tensor.dims().size());
+  out.dtype = PhiDataTypeToDLDataType(tensor.dtype());
+  // dims().Get() and strides().Get() return pointers to the DenseTensor's
+  // stable metadata storage, which remains valid as long as the tensor is
+  // alive.
+  out.shape = const_cast<int64_t *>(tensor.dims().Get());
+  out.strides = const_cast<int64_t *>(tensor.strides().Get());
+  out.byte_offset = 0;
+}
+
 template <typename T>
 phi::DenseTensor FromDLPackImpl(T *src, Deleter deleter) {
   std::vector<int64_t> shape_vec;
diff --git a/paddle/fluid/framework/dlpack_tensor.h b/paddle/fluid/framework/dlpack_tensor.h
index e287ce342fa78c..1aa8e79f93e7de 100644
--- a/paddle/fluid/framework/dlpack_tensor.h
+++ b/paddle/fluid/framework/dlpack_tensor.h
@@ -29,15 +29,19 @@ and paddle/phi/api/lib/tensor_utils.cc */
 using Deleter = std::function<void(void *)>;
 
-phi::Place DLDeviceToPlace(const DLDevice& device);
-DLDevice PlaceToDLDevice(const phi::Place& place);
-
-TEST_API DLManagedTensor* ToDLPack(const phi::DenseTensor& src,
-                                   uint64_t flags = 0);
-DLManagedTensorVersioned* ToDLPackVersioned(const phi::DenseTensor& src,
-                                            uint64_t flags = 0);
-TEST_API phi::DenseTensor FromDLPack(DLManagedTensor* src);
-phi::DenseTensor FromDLPackVersioned(DLManagedTensorVersioned* src);
+::DLDataType PhiDataTypeToDLDataType(phi::DataType dtype);
+phi::DataType DLDataTypeToPhiDataType(::DLDataType type);
+phi::Place DLDeviceToPlace(const ::DLDevice& device);
+::DLDevice PlaceToDLDevice(const phi::Place& place);
+
+TEST_API ::DLManagedTensor* ToDLPack(const phi::DenseTensor& src,
+                                     uint64_t flags = 0);
+::DLManagedTensorVersioned* ToDLPackVersioned(const phi::DenseTensor& src,
+                                              uint64_t flags = 0);
+void ToDLPackNonOwningImpl(const phi::DenseTensor& tensor,
+                           ::DLTensor& out);  // NOLINT
+TEST_API phi::DenseTensor FromDLPack(::DLManagedTensor* src);
+phi::DenseTensor FromDLPackVersioned(::DLManagedTensorVersioned* src);
 
 // A traits to support both DLManagedTensor and DLManagedTensorVersioned
 template <typename T>
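The header above now exposes two conversion flavors: the owning ToDLPack/FromDLPack pair, whose managed tensors carry a deleter, and the new non-owning ToDLPackNonOwningImpl, which borrows the caller's tensor. The non-owning path has no direct Python surface; for contrast, a minimal sketch of the owning pair through the existing paddle.utils.dlpack helpers (standard usage shown in the dlpack.py hunk below, not new behavior from this patch):

    import paddle
    from paddle.utils import dlpack

    x = paddle.to_tensor([1.0, 2.0, 3.0])
    cap = dlpack.to_dlpack(x)    # owning handoff: the capsule's deleter releases the data
    y = dlpack.from_dlpack(cap)  # a DLPack capsule may be consumed exactly once
    assert (y.numpy() == [1.0, 2.0, 3.0]).all()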
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index d3b17ad377b7cf..3119464f9cb974 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -763,6 +763,108 @@ class PyLayerBlockContextManager {
   PyLayerBlockContextManager() = default;
 };
 
+int DLPackDLTensorFromPyObjectNoSync(void *py_obj, DLTensor *out) {
+  try {
+    // Use handle (non-owning) to avoid unnecessary refcount operations
+    py::handle handle(static_cast<PyObject *>(py_obj));
+    paddle::Tensor tensor = handle.cast<paddle::Tensor>();
+    std::shared_ptr<phi::DenseTensor> dense_tensor =
+        std::static_pointer_cast<phi::DenseTensor>(tensor.impl());
+    paddle::framework::ToDLPackNonOwningImpl(*dense_tensor, *out);
+    return 0;
+  } catch (const std::exception &e) {
+    PyErr_SetString(PyExc_RuntimeError, e.what());
+    return -1;
+  }
+}
+
+int DLPackManagedTensorFromPyObjectNoSync(void *py_obj,
+                                          DLManagedTensorVersioned **out) {
+  try {
+    py::handle handle(static_cast<PyObject *>(py_obj));
+    paddle::Tensor tensor = handle.cast<paddle::Tensor>();
+    std::shared_ptr<phi::DenseTensor> dense_tensor =
+        std::static_pointer_cast<phi::DenseTensor>(tensor.impl());
+    *out = paddle::framework::ToDLPackVersioned(*dense_tensor);
+    return 0;
+  } catch (const std::exception &e) {
+    PyErr_SetString(PyExc_RuntimeError, e.what());
+    return -1;
+  }
+}
+
+int DLPackManagedTensorToPyObjectNoSync(DLManagedTensorVersioned *src,
+                                        void **py_obj_out) {
+  try {
+    phi::DenseTensor dense_tensor = paddle::framework::FromDLPackVersioned(src);
+    paddle::Tensor tensor(std::make_shared<phi::DenseTensor>(dense_tensor));
+    egr::EagerUtils::autograd_meta(&tensor)->SetPersistable(false);
+    *py_obj_out = ToPyObject(tensor);
+    return 0;
+  } catch (const std::exception &e) {
+    PyErr_SetString(PyExc_RuntimeError, e.what());
+    return -1;
+  }
+}
+
+int DLPackManagedTensorAllocator(::DLTensor *prototype,
+                                 ::DLManagedTensorVersioned **out,
+                                 void *error_ctx,
+                                 void (*SetError)(void *error_ctx,
+                                                  const char *kind,
+                                                  const char *message)) {
+  try {
+    phi::IntArray shape(prototype->shape, prototype->ndim);
+    phi::Place place(paddle::framework::DLDeviceToPlace(prototype->device));
+    phi::DataType dtype =
+        paddle::framework::DLDataTypeToPhiDataType(prototype->dtype);
+    paddle::Tensor tensor = paddle::empty(shape, dtype, place);
+    std::shared_ptr<phi::DenseTensor> dense_tensor =
+        std::static_pointer_cast<phi::DenseTensor>(tensor.impl());
+    *out = paddle::framework::ToDLPackVersioned(*dense_tensor);
+    return 0;
+  } catch (const std::exception &e) {
+    SetError(error_ctx, "DLPackManagedTensorAllocator", e.what());
+    return -1;
+  }
+}
+
+int DLPackCurrentWorkStream(DLDeviceType device_type,
+                            int32_t device_id,
+                            void **out_stream) {
+  try {
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
+    defined(PADDLE_WITH_CUSTOM_DEVICE)
+    if (device_type == kDLCUDA || device_type == kDLROCM) {
+      *out_stream = platform::get_current_stream(device_id)->raw_stream();
+    }
+#endif
+    return 0;
+  } catch (const std::exception &e) {
+    PyErr_SetString(PyExc_RuntimeError, e.what());
+    return -1;
+  }
+}
+
+struct PaddleDLPackExchangeAPI : public ::DLPackExchangeAPI {
+  PaddleDLPackExchangeAPI() {
+    header.version.major = DLPACK_MAJOR_VERSION;
+    header.version.minor = DLPACK_MINOR_VERSION;
+    header.prev_api = nullptr;
+    managed_tensor_allocator = DLPackManagedTensorAllocator;
+    managed_tensor_from_py_object_no_sync =
+        DLPackManagedTensorFromPyObjectNoSync;
+    managed_tensor_to_py_object_no_sync = DLPackManagedTensorToPyObjectNoSync;
+    dltensor_from_py_object_no_sync = DLPackDLTensorFromPyObjectNoSync;
+    current_work_stream = DLPackCurrentWorkStream;
+  }
+
+  static const DLPackExchangeAPI *Instance() {
+    static PaddleDLPackExchangeAPI inst;
+    return &inst;
+  }
+};
+
 // NOTE: use to load file by Mmap
 enum MMapLoadModes {
   ALLOCATOR_MAPPED_SHARED = 1,
@@ -1773,6 +1875,10 @@ PYBIND11_MODULE(libpaddle, m) {
                        dl_device.device_id);
             });
 
+  m.def("dlpack_exchange_api_ptr", []() -> int64_t {
+    return reinterpret_cast<int64_t>(PaddleDLPackExchangeAPI::Instance());
+  });
+
   m.def("from_dlpack", [](py::object data) {
     if (PyCapsule_IsValid(data.ptr(),
                           DLPackTraits<DLManagedTensor>::capsule)) {
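On the Python side the new binding is deliberately thin: it returns the address of the static PaddleDLPackExchangeAPI singleton as an integer, which a C consumer reinterprets as a DLPackExchangeAPI pointer. A minimal sketch of what the binding guarantees (illustrative only; how a consumer walks the struct is defined by the DLPack exchange-API header, not this patch):

    import paddle  # noqa: F401  (loads libpaddle)
    from paddle.base import core

    # Address of the process-wide exchange-API struct.
    api_ptr = core.dlpack_exchange_api_ptr()
    assert isinstance(api_ptr, int) and api_ptr != 0
    # Static singleton: every call returns the same address.
    assert api_ptr == core.dlpack_exchange_api_ptr()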
diff --git a/python/paddle/base/dygraph/tensor_patch_methods.py b/python/paddle/base/dygraph/tensor_patch_methods.py
index e19d5e7f8405d1..f9545777153f21 100644
--- a/python/paddle/base/dygraph/tensor_patch_methods.py
+++ b/python/paddle/base/dygraph/tensor_patch_methods.py
@@ -1586,6 +1586,7 @@ def __tvm_ffi_env_stream__(self) -> int:
         ("__dlpack_device__", __dlpack_device__),
         ("get_device", get_device),
         ("__tvm_ffi_env_stream__", __tvm_ffi_env_stream__),
+        ("__c_dlpack_exchange_api__", core.dlpack_exchange_api_ptr()),
     ):
         setattr(core.eager.Tensor, method_name, method)
diff --git a/python/paddle/utils/dlpack.py b/python/paddle/utils/dlpack.py
index c1b3c21afaea86..68b44cc27f89ce 100644
--- a/python/paddle/utils/dlpack.py
+++ b/python/paddle/utils/dlpack.py
@@ -75,6 +75,7 @@ class DLDeviceType(enum.IntEnum):
     kDLWebGPU = (15,)
     kDLHexagon = (16,)
     kDLMAIA = (17,)
+    kDLTrn = (18,)
 
 
 def to_dlpack(x: Tensor) -> CapsuleType:
@@ -215,7 +216,7 @@ def from_dlpack(
 
     if hasattr(dlpack, "__dlpack__"):
         kwargs = {}
-        kwargs["max_version"] = (1, 1)
+        kwargs["max_version"] = (1, 2)
         if copy is not None:
             kwargs["copy"] = copy
diff --git a/python/unittest_py/requirements.txt b/python/unittest_py/requirements.txt
index ddfccc8090f240..0ccf6d98680f22 100644
--- a/python/unittest_py/requirements.txt
+++ b/python/unittest_py/requirements.txt
@@ -20,3 +20,4 @@ xdoctest==1.3.0
 ubelt==1.3.3  # just for xdoctest
 mypy==1.17.1
 soundfile
+apache-tvm-ffi==0.1.0b16
diff --git a/test/dygraph_to_static/test_tensor_attr_consistency.py b/test/dygraph_to_static/test_tensor_attr_consistency.py
index 86a4437a7c69ce..b68c2db87fe609 100644
--- a/test/dygraph_to_static/test_tensor_attr_consistency.py
+++ b/test/dygraph_to_static/test_tensor_attr_consistency.py
@@ -81,6 +81,7 @@
         '__dlpack__',
         "__dlpack_device__",
         "__tvm_ffi_env_stream__",
+        "__c_dlpack_exchange_api__",
     ]
 )
 STATIC_ONLY_TENSOR_ATTRS_ALLOW_LIST = OrderedSet(
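Taken together, the Python-side changes above attach the C exchange-API address to every eager Tensor as a class attribute and bump the DLPack version that from_dlpack() negotiates. A short sketch of the resulting behavior (attribute and function names are exactly those added above; the round-trip itself is ordinary DLPack usage):

    import paddle
    from paddle.base import core

    t = paddle.ones((2, 2))
    # __c_dlpack_exchange_api__ is set once, at patch time, to the integer
    # address returned by core.dlpack_exchange_api_ptr(), so they must match.
    assert t.__c_dlpack_exchange_api__ == core.dlpack_exchange_api_ptr()

    # from_dlpack() now advertises max_version=(1, 2) to producers that
    # implement __dlpack__(max_version=...), such as another Paddle tensor.
    u = paddle.utils.dlpack.from_dlpack(t)
    assert (u.numpy() == t.numpy()).all()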
diff --git a/test/legacy_test/test_tvm_ffi.py b/test/legacy_test/test_tvm_ffi.py
index aa6a91b4aa24de..ce1a955932ebe4 100644
--- a/test/legacy_test/test_tvm_ffi.py
+++ b/test/legacy_test/test_tvm_ffi.py
@@ -12,12 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
+import platform
 import unittest
+from typing import TYPE_CHECKING
+
+import numpy as np
+import tvm_ffi.cpp
 
 import paddle
 
+if TYPE_CHECKING:
+    from tvm_ffi import Module
+
 
-class TestTVMFFI(unittest.TestCase):
+class TestTVMFFIEnvStream(unittest.TestCase):
     def test_tvm_ffi_env_stream_for_gpu_tensor(self):
         if not paddle.is_compiled_with_cuda():
             return
@@ -34,5 +44,113 @@ def test_tvm_ffi_env_stream_for_cpu_tensor(self):
             tensor.__tvm_ffi_env_stream__()
 
 
+class TestCDLPackExchangeAPI(unittest.TestCase):
+    def test_c_dlpack_exchange_api_cpu(self):
+        cpp_source = r"""
+        void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
+          // implementation of a library function
+          TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+          DLDataType f32_dtype{kDLFloat, 32, 1};
+          TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
+          TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
+          TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
+          TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
+          for (int i = 0; i < x->shape[0]; ++i) {
+            static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+          }
+        }
+        """
+
+        mod: Module = tvm_ffi.cpp.load_inline(
+            name='mod', cpp_sources=cpp_source, functions='add_one_cpu'
+        )
+
+        x = paddle.full((3,), 1.0, dtype='float32').cpu()
+        y = paddle.zeros((3,), dtype='float32').cpu()
+        mod.add_one_cpu(x, y)
+        np.testing.assert_allclose(y.numpy(), [2.0, 2.0, 2.0])
+
+    def test_c_dlpack_exchange_api_gpu(self):
+        if not paddle.is_compiled_with_cuda():
+            return
+        if paddle.is_compiled_with_rocm():
+            # Skip on DCU because CUDA_HOME is not available
+            return
+        if platform.system() == "Windows":
+            # Temporarily skip this test case on Windows because of a
+            # TVM FFI compile bug
+            return
+        cpp_sources = r"""
+        void add_one_cuda(tvm::ffi::TensorView x, tvm::ffi::TensorView y);
+        """
+        cuda_sources = r"""
+        __global__ void AddOneKernel(float* x, float* y, int n) {
+          int idx = blockIdx.x * blockDim.x + threadIdx.x;
+          if (idx < n) {
+            y[idx] = x[idx] + 1;
+          }
+        }
+
+        void add_one_cuda(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
+          // implementation of a library function
+          TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+          DLDataType f32_dtype{kDLFloat, 32, 1};
+          TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
+          TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
+          TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
+          TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
+
+          int64_t n = x->shape[0];
+          int64_t nthread_per_block = 256;
+          int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block;
+          // Obtain the current stream from the environment by calling TVMFFIEnvGetStream
+          cudaStream_t stream = static_cast<cudaStream_t>(
+              TVMFFIEnvGetStream(x->device.device_type, x->device.device_id));
+          // launch the kernel
+          AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(
+              static_cast<float*>(x->data), static_cast<float*>(y->data), n);
+        }
+        """
+        mod: Module = tvm_ffi.cpp.load_inline(
+            name='mod',
+            cpp_sources=cpp_sources,
+            cuda_sources=cuda_sources,
+            functions=['add_one_cuda'],
+        )
+
+        x = paddle.full((3,), 1.0, dtype='float32').cuda()
+        y = paddle.zeros((3,), dtype='float32').cuda()
+        mod.add_one_cuda(x, y)
+        np.testing.assert_allclose(y.numpy(), [2.0, 2.0, 2.0])
+
+    def test_c_dlpack_exchange_api_alloc_tensor(self):
+        if platform.system() == "Windows":
+            # Temporarily skip this test case on Windows because returning an
+            # owned tensor created by TVMFFIEnvGetTensorAllocator causes a
+            # double-free error
+            return
+        cpp_source = r"""
+        inline tvm::ffi::Tensor alloc_tensor(tvm::ffi::Shape shape, DLDataType dtype, DLDevice device) {
+          return tvm::ffi::Tensor::FromDLPackAlloc(TVMFFIEnvGetTensorAllocator(), shape, dtype, device);
+        }
+
+        tvm::ffi::Tensor add_one_cpu(tvm::ffi::TensorView x) {
+          TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+          DLDataType f32_dtype{kDLFloat, 32, 1};
+          TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
+          tvm::ffi::Shape x_shape(x->shape, x->shape + x->ndim);
+          tvm::ffi::Tensor y = alloc_tensor(x_shape, f32_dtype, x->device);
+          for (int i = 0; i < x->shape[0]; ++i) {
+            static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+          }
+          return y;
+        }
+        """
+        mod: Module = tvm_ffi.cpp.load_inline(
+            name='mod', cpp_sources=cpp_source, functions=['add_one_cpu']
+        )
+        x = paddle.full((3,), 1.0, dtype='float32').cpu()
+        y = mod.add_one_cpu(x)
+        np.testing.assert_allclose(y.numpy(), [2.0, 2.0, 2.0])
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/third_party/dlpack b/third_party/dlpack
index 3ea601bb413074..93c8f2a3c774b8 160000
--- a/third_party/dlpack
+++ b/third_party/dlpack
@@ -1 +1 @@
-Subproject commit 3ea601bb413074c49a77c4ce3218bc08f8c4703c
+Subproject commit 93c8f2a3c774b84af6f652b1992c48164fae60fc