@@ -467,15 +467,10 @@ def load_torch_c_dlpack_extension() -> Any:
467467} // namespace
468468} // namespace at
469469
470- int TorchDLPackFromPyObject(void* py_obj, DLManagedTensorVersioned** out, void** env_stream ) {
470+ int TorchDLPackFromPyObject(void* py_obj, DLManagedTensorVersioned** out) {
471471 try {
472472 py::handle handle(static_cast<PyObject*>(py_obj));
473473 at::Tensor tensor = handle.cast<at::Tensor>();
474- #ifdef BUILD_WITH_CUDA
475- if (env_stream != nullptr && tensor.is_cuda()) {
476- *env_stream = at::cuda::getCurrentCUDAStream(tensor.device().index()).stream();
477- }
478- #endif
479474 *out = at::toDLPackImpl<DLManagedTensorVersioned>(tensor);
480475 return 0;
481476 } catch (const std::exception& e) {
@@ -513,16 +508,66 @@ def load_torch_c_dlpack_extension() -> Any:
513508 }
514509}
515510
516- int64_t TorchDLPackFromPyObjectPtr() {
517- return reinterpret_cast<int64_t>(TorchDLPackFromPyObject);
511+ int TorchDLTensorFromPyObject(void* py_obj, DLTensor* out) {
512+ try {
513+ // Use handle (non-owning) to avoid unnecessary refcount operations
514+ py::handle handle(static_cast<PyObject*>(py_obj));
515+ const at::Tensor& tensor = handle.cast<const at::Tensor&>();
516+
517+ // Fill in the pre-allocated DLTensor struct with direct pointers
518+ // This is a non-owning conversion - the original PyObject owns the data
519+ // and is kept alive by the caller for the duration of this call
520+ out->data = tensor.data_ptr();
521+ out->device = torchDeviceToDLDeviceForDLPackv1(tensor.device());
522+ out->ndim = static_cast<int32_t>(tensor.dim());
523+ out->dtype = getDLDataTypeForDLPackv1(tensor);
524+ // sizes() and strides() return pointers to TensorImpl's stable storage
525+ // which remains valid as long as the original PyObject is alive
526+ out->shape = const_cast<int64_t*>(tensor.sizes().data());
527+ out->strides = const_cast<int64_t*>(tensor.strides().data());
528+ out->byte_offset = 0;
529+
530+ return 0;
531+ } catch (const std::exception& e) {
532+ PyErr_SetString(PyExc_RuntimeError, e.what());
533+ return -1;
534+ }
518535}
519536
520- int64_t TorchDLPackToPyObjectPtr() {
521- return reinterpret_cast<int64_t>(TorchDLPackToPyObject);
537+ int TorchCurrentWorkStream(DLDeviceType device_type, int32_t device_id, void** out_stream) {
538+ try {
539+ #ifdef BUILD_WITH_CUDA
540+ if (device_type != kDLCPU) {
541+ *out_stream = at::cuda::getCurrentCUDAStream(device_id).stream();
542+ }
543+ #endif
544+ return 0;
545+ } catch (const std::exception& e) {
546+ PyErr_SetString(PyExc_RuntimeError, e.what());
547+ return -1;
548+ }
522549}
523550
524- int64_t TorchDLPackTensorAllocatorPtr() {
525- return reinterpret_cast<int64_t>(TorchDLPackTensorAllocator);
551+ struct TorchDLPackExchangeAPI : public DLPackExchangeAPI {
552+ TorchDLPackExchangeAPI() {
553+ version.major = DLPACK_MAJOR_VERSION;
554+ version.minor = DLPACK_MINOR_VERSION;
555+ prev_version_api = nullptr;
556+ managed_tensor_allocator = TorchDLPackTensorAllocator;
557+ managed_tensor_from_py_object_no_sync = TorchDLPackFromPyObject;
558+ managed_tensor_to_py_object_no_sync = TorchDLPackToPyObject;
559+ dltensor_from_py_object_no_sync = TorchDLTensorFromPyObject;
560+ current_work_stream = TorchCurrentWorkStream;
561+ }
562+
563+ static const DLPackExchangeAPI* Global() {
564+ static TorchDLPackExchangeAPI inst;
565+ return &inst;
566+ }
567+ };
568+
569+ int64_t TorchDLPackExchangeAPIPtr() {
570+ return reinterpret_cast<int64_t>(TorchDLPackExchangeAPI::Global());
526571}
527572 """
528573 try :
@@ -541,17 +586,13 @@ def load_torch_c_dlpack_extension() -> Any:
541586 name = "c_dlpack" ,
542587 cpp_sources = cpp_source ,
543588 functions = [
544- "TorchDLPackFromPyObjectPtr" ,
545- "TorchDLPackToPyObjectPtr" ,
546- "TorchDLPackTensorAllocatorPtr" ,
589+ "TorchDLPackExchangeAPIPtr" ,
547590 ],
548591 extra_cflags = extra_cflags ,
549592 extra_include_paths = include_paths ,
550593 )
551- # set the dlpack related flags
552- setattr (torch .Tensor , "__c_dlpack_from_pyobject__" , mod .TorchDLPackFromPyObjectPtr ())
553- setattr (torch .Tensor , "__c_dlpack_to_pyobject__" , mod .TorchDLPackToPyObjectPtr ())
554- setattr (torch .Tensor , "__c_dlpack_tensor_allocator__" , mod .TorchDLPackTensorAllocatorPtr ())
594+ # Set the DLPackExchangeAPI pointer on the class
595+ setattr (torch .Tensor , "__c_dlpack_exchange_api__" , mod .TorchDLPackExchangeAPIPtr ())
555596 return mod
556597 except ImportError :
557598 pass
0 commit comments