Commit a0121da ("update")
1 parent: 0be336e

1 file changed: python/tvm_ffi/_optional_torch_c_dlpack.py (60 additions, 19 deletions)
@@ -467,15 +467,10 @@ def load_torch_c_dlpack_extension() -> Any:
 } // namespace
 } // namespace at
 
-int TorchDLPackFromPyObject(void* py_obj, DLManagedTensorVersioned** out, void** env_stream) {
+int TorchDLPackFromPyObject(void* py_obj, DLManagedTensorVersioned** out) {
   try {
     py::handle handle(static_cast<PyObject*>(py_obj));
     at::Tensor tensor = handle.cast<at::Tensor>();
-#ifdef BUILD_WITH_CUDA
-    if (env_stream != nullptr && tensor.is_cuda()) {
-      *env_stream = at::cuda::getCurrentCUDAStream(tensor.device().index()).stream();
-    }
-#endif
     *out = at::toDLPackImpl<DLManagedTensorVersioned>(tensor);
     return 0;
   } catch (const std::exception& e) {
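
Note: the hunk above drops the env_stream out-parameter from TorchDLPackFromPyObject; the current-stream query moves into the separate TorchCurrentWorkStream callback introduced in the next hunk. As a rough illustration of the split (not part of the commit; the ExportWithStream helper is hypothetical, and the sketch assumes the DLPack v1 DLManagedTensorVersioned/DLDevice layouts), the two calls could be combined along these lines:

// Illustrative sketch only, not part of this commit; helper name is made up.
// Convert the tensor PyObject first, then query the device's current work
// stream through the new dedicated callback.
int ExportWithStream(void* py_obj, DLManagedTensorVersioned** out, void** out_stream) {
  if (TorchDLPackFromPyObject(py_obj, out) != 0) {
    return -1;  // a Python exception has already been set by the callee
  }
  *out_stream = nullptr;  // remains null for CPU tensors or non-CUDA builds
  const DLDevice& dev = (*out)->dl_tensor.device;
  return TorchCurrentWorkStream(dev.device_type, dev.device_id, out_stream);
}
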
@@ -513,16 +508,66 @@ def load_torch_c_dlpack_extension() -> Any:
   }
 }
 
-int64_t TorchDLPackFromPyObjectPtr() {
-  return reinterpret_cast<int64_t>(TorchDLPackFromPyObject);
+int TorchDLTensorFromPyObject(void* py_obj, DLTensor* out) {
+  try {
+    // Use handle (non-owning) to avoid unnecessary refcount operations
+    py::handle handle(static_cast<PyObject*>(py_obj));
+    const at::Tensor& tensor = handle.cast<const at::Tensor&>();
+
+    // Fill in the pre-allocated DLTensor struct with direct pointers
+    // This is a non-owning conversion - the original PyObject owns the data
+    // and is kept alive by the caller for the duration of this call
+    out->data = tensor.data_ptr();
+    out->device = torchDeviceToDLDeviceForDLPackv1(tensor.device());
+    out->ndim = static_cast<int32_t>(tensor.dim());
+    out->dtype = getDLDataTypeForDLPackv1(tensor);
+    // sizes() and strides() return pointers to TensorImpl's stable storage
+    // which remains valid as long as the original PyObject is alive
+    out->shape = const_cast<int64_t*>(tensor.sizes().data());
+    out->strides = const_cast<int64_t*>(tensor.strides().data());
+    out->byte_offset = 0;
+
+    return 0;
+  } catch (const std::exception& e) {
+    PyErr_SetString(PyExc_RuntimeError, e.what());
+    return -1;
+  }
 }
 
-int64_t TorchDLPackToPyObjectPtr() {
-  return reinterpret_cast<int64_t>(TorchDLPackToPyObject);
+int TorchCurrentWorkStream(DLDeviceType device_type, int32_t device_id, void** out_stream) {
+  try {
+#ifdef BUILD_WITH_CUDA
+    if (device_type != kDLCPU) {
+      *out_stream = at::cuda::getCurrentCUDAStream(device_id).stream();
+    }
+#endif
+    return 0;
+  } catch (const std::exception& e) {
+    PyErr_SetString(PyExc_RuntimeError, e.what());
+    return -1;
+  }
 }
 
-int64_t TorchDLPackTensorAllocatorPtr() {
-  return reinterpret_cast<int64_t>(TorchDLPackTensorAllocator);
+struct TorchDLPackExchangeAPI : public DLPackExchangeAPI {
+  TorchDLPackExchangeAPI() {
+    version.major = DLPACK_MAJOR_VERSION;
+    version.minor = DLPACK_MINOR_VERSION;
+    prev_version_api = nullptr;
+    managed_tensor_allocator = TorchDLPackTensorAllocator;
+    managed_tensor_from_py_object_no_sync = TorchDLPackFromPyObject;
+    managed_tensor_to_py_object_no_sync = TorchDLPackToPyObject;
+    dltensor_from_py_object_no_sync = TorchDLTensorFromPyObject;
+    current_work_stream = TorchCurrentWorkStream;
+  }
+
+  static const DLPackExchangeAPI* Global() {
+    static TorchDLPackExchangeAPI inst;
+    return &inst;
+  }
+};
+
+int64_t TorchDLPackExchangeAPIPtr() {
+  return reinterpret_cast<int64_t>(TorchDLPackExchangeAPI::Global());
 }
 """
     try:
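
Note: the hunk above replaces the three per-function pointer getters with a single DLPackExchangeAPI table exposed through TorchDLPackExchangeAPIPtr. A minimal sketch of how a consumer holding that pointer might use the new non-owning DLTensor path (not part of the commit; BorrowView and api_ptr are hypothetical names, member names mirror the constructor above, and the header that defines DLPackExchangeAPI is assumed to be on the consumer's include path):

// Illustrative sketch only, not part of this commit; helper name is made up.
int BorrowView(int64_t api_ptr, void* py_obj, DLTensor* view, void** stream) {
  const auto* api = reinterpret_cast<const DLPackExchangeAPI*>(api_ptr);
  if (api->dltensor_from_py_object_no_sync(py_obj, view) != 0) {
    return -1;  // a Python exception has been set
  }
  *stream = nullptr;
  // view->data, shape and strides remain valid only while py_obj stays alive.
  return api->current_work_stream(view->device.device_type, view->device.device_id, stream);
}
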
@@ -541,17 +586,13 @@ def load_torch_c_dlpack_extension() -> Any:
             name="c_dlpack",
             cpp_sources=cpp_source,
             functions=[
-                "TorchDLPackFromPyObjectPtr",
-                "TorchDLPackToPyObjectPtr",
-                "TorchDLPackTensorAllocatorPtr",
+                "TorchDLPackExchangeAPIPtr",
            ],
             extra_cflags=extra_cflags,
             extra_include_paths=include_paths,
         )
-        # set the dlpack related flags
-        setattr(torch.Tensor, "__c_dlpack_from_pyobject__", mod.TorchDLPackFromPyObjectPtr())
-        setattr(torch.Tensor, "__c_dlpack_to_pyobject__", mod.TorchDLPackToPyObjectPtr())
-        setattr(torch.Tensor, "__c_dlpack_tensor_allocator__", mod.TorchDLPackTensorAllocatorPtr())
+        # Set the DLPackExchangeAPI pointer on the class
+        setattr(torch.Tensor, "__c_dlpack_exchange_api__", mod.TorchDLPackExchangeAPIPtr())
         return mod
     except ImportError:
         pass
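
Note: on the Python side, the commit collapses the three __c_dlpack_* attributes into a single __c_dlpack_exchange_api__ integer holding the table's address. A native consumer could recover the table from a tensor instance roughly as follows (not part of the commit; GetExchangeAPI is a hypothetical helper using the plain CPython API, with <Python.h> and <cstdint> assumed and error handling abbreviated):

// Illustrative sketch only, not part of this commit; helper name is made up.
const DLPackExchangeAPI* GetExchangeAPI(PyObject* py_tensor) {
  PyObject* attr = PyObject_GetAttrString(py_tensor, "__c_dlpack_exchange_api__");
  if (attr == nullptr) return nullptr;
  const auto* api = reinterpret_cast<const DLPackExchangeAPI*>(
      static_cast<intptr_t>(PyLong_AsLongLong(attr)));
  Py_DECREF(attr);
  return PyErr_Occurred() ? nullptr : api;
}
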
