Merged
2 changes: 1 addition & 1 deletion 3rdparty/dlpack
Submodule dlpack updated 1 file
+262 −3 include/dlpack/dlpack.h
89 changes: 66 additions & 23 deletions python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -464,18 +464,41 @@ def load_torch_c_dlpack_extension() -> Any:
{device});
}

void toDLPackNonOwningImpl(const Tensor& tensor, DLTensor& out) {
// Fill in the pre-allocated DLTensor struct with direct pointers
// This is a non-owning conversion - the caller owns the tensor
// and must keep it alive for the duration of DLTensor usage
out.data = tensor.data_ptr();
out.device = torchDeviceToDLDeviceForDLPackv1(tensor.device());
out.ndim = static_cast<int32_t>(tensor.dim());
out.dtype = getDLDataTypeForDLPackv1(tensor);
// sizes() and strides() return pointers to TensorImpl's stable storage
// which remains valid as long as the tensor is alive
out.shape = const_cast<int64_t*>(tensor.sizes().data());
out.strides = const_cast<int64_t*>(tensor.strides().data());
out.byte_offset = 0;
}

} // namespace
} // namespace at

int TorchDLPackFromPyObject(void* py_obj, DLManagedTensorVersioned** out, void** env_stream) {
int TorchDLPackDLTensorFromPyObjectNoSync(void* py_obj, DLTensor* out) {
try {
// Use handle (non-owning) to avoid unnecessary refcount operations
py::handle handle(static_cast<PyObject*>(py_obj));
at::Tensor tensor = handle.cast<at::Tensor>();
at::toDLPackNonOwningImpl(tensor, *out);
return 0;
} catch (const std::exception& e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
return -1;
}
}

int TorchDLPackManagedTensorFromPyObjectNoSync(void* py_obj, DLManagedTensorVersioned** out) {
try {
py::handle handle(static_cast<PyObject*>(py_obj));
at::Tensor tensor = handle.cast<at::Tensor>();
#ifdef BUILD_WITH_CUDA
if (env_stream != nullptr && tensor.is_cuda()) {
*env_stream = at::cuda::getCurrentCUDAStream(tensor.device().index()).stream();
}
#endif
*out = at::toDLPackImpl<DLManagedTensorVersioned>(tensor);
return 0;
} catch (const std::exception& e) {
@@ -484,7 +507,7 @@ def load_torch_c_dlpack_extension() -> Any:
}
}

int TorchDLPackToPyObject(DLManagedTensorVersioned* src, void** py_obj_out) {
int TorchDLPackManagedTensorToPyObjectNoSync(DLManagedTensorVersioned* src, void** py_obj_out) {
try {
at::Tensor tensor = at::fromDLPackImpl<DLManagedTensorVersioned>(src, nullptr);
*py_obj_out = THPVariable_Wrap(tensor);
@@ -495,7 +518,7 @@ def load_torch_c_dlpack_extension() -> Any:
}
}

int TorchDLPackTensorAllocator(
int TorchDLPackManagedTensorAllocator(
DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx,
void (*SetError)(void* error_ctx, const char* kind, const char* message)
) {
@@ -508,21 +531,45 @@ def load_torch_c_dlpack_extension() -> Any:
*out = at::toDLPackImpl<DLManagedTensorVersioned>(tensor);
return 0;
} catch (const std::exception& e) {
SetError(error_ctx, "TorchDLPackTensorAllocator", e.what());
SetError(error_ctx, "TorchDLPackManagedTensorAllocator", e.what());
return -1;
}
}

int64_t TorchDLPackFromPyObjectPtr() {
return reinterpret_cast<int64_t>(TorchDLPackFromPyObject);
int TorchDLPackCurrentWorkStream(DLDeviceType device_type, int32_t device_id, void** out_stream) {
try {
#ifdef BUILD_WITH_CUDA
if (device_type == kDLCUDA || device_type == kDLROCM) {
*out_stream = at::cuda::getCurrentCUDAStream(device_id).stream();
}
#endif
return 0;
} catch (const std::exception& e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
return -1;
}
}

int64_t TorchDLPackToPyObjectPtr() {
return reinterpret_cast<int64_t>(TorchDLPackToPyObject);
}
struct TorchDLPackExchangeAPI : public DLPackExchangeAPI {
TorchDLPackExchangeAPI() {
header.version.major = DLPACK_MAJOR_VERSION;
header.version.minor = DLPACK_MINOR_VERSION;
header.prev_api = nullptr;
managed_tensor_allocator = TorchDLPackManagedTensorAllocator;
managed_tensor_from_py_object_no_sync = TorchDLPackManagedTensorFromPyObjectNoSync;
managed_tensor_to_py_object_no_sync = TorchDLPackManagedTensorToPyObjectNoSync;
dltensor_from_py_object_no_sync = TorchDLPackDLTensorFromPyObjectNoSync;
current_work_stream = TorchDLPackCurrentWorkStream;
}

static const DLPackExchangeAPI* Global() {
static TorchDLPackExchangeAPI inst;
return &inst;
}
};

int64_t TorchDLPackTensorAllocatorPtr() {
return reinterpret_cast<int64_t>(TorchDLPackTensorAllocator);
int64_t TorchDLPackExchangeAPIPtr() {
return reinterpret_cast<int64_t>(TorchDLPackExchangeAPI::Global());
}
"""
try:
@@ -541,17 +588,13 @@ def load_torch_c_dlpack_extension() -> Any:
name="c_dlpack",
cpp_sources=cpp_source,
functions=[
"TorchDLPackFromPyObjectPtr",
"TorchDLPackToPyObjectPtr",
"TorchDLPackTensorAllocatorPtr",
"TorchDLPackExchangeAPIPtr",
],
extra_cflags=extra_cflags,
extra_include_paths=include_paths,
)
# set the dlpack related flags
setattr(torch.Tensor, "__c_dlpack_from_pyobject__", mod.TorchDLPackFromPyObjectPtr())
setattr(torch.Tensor, "__c_dlpack_to_pyobject__", mod.TorchDLPackToPyObjectPtr())
setattr(torch.Tensor, "__c_dlpack_tensor_allocator__", mod.TorchDLPackTensorAllocatorPtr())
# Set the DLPackExchangeAPI pointer on the class
setattr(torch.Tensor, "__c_dlpack_exchange_api__", mod.TorchDLPackExchangeAPIPtr())
return mod
except ImportError:
pass
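As a quick sanity check (illustrative only, not part of this PR), the single `__c_dlpack_exchange_api__` attribute set above can be read back from Python and its version header inspected with ctypes. This sketch assumes the optional extension compiles and that the struct begins with a DLPackVersion pair of uint32 fields, as declared in dlpack.h:

import ctypes

import torch
from tvm_ffi._optional_torch_c_dlpack import load_torch_c_dlpack_extension

load_torch_c_dlpack_extension()  # builds the C++ extension and sets the class attribute
api_ptr = getattr(torch.Tensor, "__c_dlpack_exchange_api__", 0)
if api_ptr:
    # The table starts with DLPackExchangeAPIHeader, whose first field is the
    # DLPackVersion {uint32 major, uint32 minor} pair (layout assumed from dlpack.h).
    major, minor = (ctypes.c_uint32 * 2).from_address(api_ptr)
    print(f"torch exports DLPack exchange API v{major}.{minor}")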
2 changes: 1 addition & 1 deletion python/tvm_ffi/core.pyi
@@ -485,7 +485,7 @@ def from_dlpack(
class DLTensorTestWrapper:
"""Wrapper of a Tensor that exposes DLPack protocol, only for testing purpose."""

__c_dlpack_from_pyobject__: int
__c_dlpack_exchange_api__: int
def __init__(self, tensor: Tensor) -> None: ...
def __tvm_ffi_env_stream__(self) -> int: ...
def __dlpack_device__(self) -> tuple[int, int]: ...
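The stub now advertises the single `__c_dlpack_exchange_api__` pointer instead of the old per-function `__c_dlpack_from_pyobject__` attribute. A hedged sketch of feature detection on the consumer side (the helper name is hypothetical and not part of this PR):

def supports_c_dlpack_exchange(obj: object) -> bool:
    # The exchange API is advertised on the type as the integer address of a
    # process-lifetime DLPackExchangeAPI table; 0 or a missing attribute means
    # the object only supports the regular __dlpack__ protocol.
    return getattr(type(obj), "__c_dlpack_exchange_api__", 0) != 0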
53 changes: 47 additions & 6 deletions python/tvm_ffi/cython/base.pxi
@@ -25,6 +25,9 @@ from cpython cimport pycapsule, PyCapsule_Destructor
from cpython cimport PyErr_SetNone

cdef extern from "dlpack/dlpack.h":
int DLPACK_MAJOR_VERSION
int DLPACK_MINOR_VERSION

cdef enum:
kDLCPU = 1,
kDLCUDA = 2,
@@ -77,6 +80,47 @@ cdef extern from "dlpack/dlpack.h":
void (*deleter)(DLManagedTensorVersioned* self)
uint64_t flags

# DLPack Exchange API function pointer types
ctypedef int (*DLPackManagedTensorAllocator)(
DLTensor* prototype,
DLManagedTensorVersioned** out,
void* error_ctx,
void (*SetError)(void* error_ctx, const char* kind, const char* message)
) noexcept

ctypedef int (*DLPackManagedTensorFromPyObjectNoSync)(
void* py_object,
DLManagedTensorVersioned** out
) noexcept

ctypedef int (*DLPackManagedTensorToPyObjectNoSync)(
DLManagedTensorVersioned* tensor,
void** out_py_object
) noexcept

ctypedef int (*DLPackCurrentWorkStream)(
int device_type,
int32_t device_id,
void** out_current_stream
) noexcept

ctypedef int (*DLPackDLTensorFromPyObjectNoSync)(
void* py_object,
DLTensor* out
) noexcept

ctypedef struct DLPackExchangeAPIHeader:
DLPackVersion version
DLPackExchangeAPIHeader* prev_api

ctypedef struct DLPackExchangeAPI:
DLPackExchangeAPIHeader header
DLPackManagedTensorAllocator managed_tensor_allocator
DLPackManagedTensorFromPyObjectNoSync managed_tensor_from_py_object_no_sync
DLPackManagedTensorToPyObjectNoSync managed_tensor_to_py_object_no_sync
DLPackDLTensorFromPyObjectNoSync dltensor_from_py_object_no_sync
DLPackCurrentWorkStream current_work_stream


# Cython binding for TVM FFI C API
cdef extern from "tvm/ffi/c_api.h":
@@ -285,14 +329,11 @@ cdef extern from "tvm_ffi_python_helpers.h":
int device_type
int device_id
TVMFFIStreamHandle stream
DLPackToPyObject c_dlpack_to_pyobject
DLPackTensorAllocator c_dlpack_tensor_allocator
const DLPackExchangeAPI* c_dlpack_exchange_api

ctypedef struct TVMFFIPyArgSetter:
int (*func)(TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, PyObject* py_arg, TVMFFIAny* out) except -1
DLPackFromPyObject c_dlpack_from_pyobject
DLPackToPyObject c_dlpack_to_pyobject
DLPackTensorAllocator c_dlpack_tensor_allocator
const DLPackExchangeAPI* c_dlpack_exchange_api

ctypedef int (*TVMFFIPyArgSetterFactory)(PyObject* value, TVMFFIPyArgSetter* out) except -1
# The main call function
@@ -303,7 +344,7 @@ cdef extern from "tvm_ffi_python_helpers.h":
TVMFFIAny* result,
int* c_api_ret_code,
int release_gil,
DLPackToPyObject* out_dlpack_importer
const DLPackExchangeAPI** out_ctx_dlpack_api
) except -1

int TVMFFIPyConstructorCall(
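For reference, a ctypes mirror of the DLPackExchangeAPI table declared above (illustrative sketch: pointer-typed arguments are collapsed to void* and the layout is assumed to match the Cython declarations; real callers would also mirror DLTensor and DLManagedTensorVersioned):

import ctypes

class DLPackVersion(ctypes.Structure):
    _fields_ = [("major", ctypes.c_uint32), ("minor", ctypes.c_uint32)]

class DLPackExchangeAPIHeader(ctypes.Structure):
    pass

DLPackExchangeAPIHeader._fields_ = [
    ("version", DLPackVersion),
    ("prev_api", ctypes.POINTER(DLPackExchangeAPIHeader)),
]

# Function-pointer slots, with opaque void* standing in for the DLPack structs.
ManagedTensorAllocator = ctypes.CFUNCTYPE(
    ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p)
ManagedTensorFromPyObjectNoSync = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p)
ManagedTensorToPyObjectNoSync = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p)
DLTensorFromPyObjectNoSync = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p)
CurrentWorkStream = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int, ctypes.c_int32, ctypes.c_void_p)

class DLPackExchangeAPI(ctypes.Structure):
    # Field order mirrors the C struct: header first, then the five slots.
    _fields_ = [
        ("header", DLPackExchangeAPIHeader),
        ("managed_tensor_allocator", ManagedTensorAllocator),
        ("managed_tensor_from_py_object_no_sync", ManagedTensorFromPyObjectNoSync),
        ("managed_tensor_to_py_object_no_sync", ManagedTensorToPyObjectNoSync),
        ("dltensor_from_py_object_no_sync", DLTensorFromPyObjectNoSync),
        ("current_work_stream", CurrentWorkStream),
    ]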