Merged · Changes from 13 commits
2 changes: 1 addition & 1 deletion 3rdparty/dlpack
Submodule dlpack updated 1 file
+262 −3 include/dlpack/dlpack.h
89 changes: 66 additions & 23 deletions python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -464,18 +464,41 @@ def load_torch_c_dlpack_extension() -> Any:
{device});
}

void toDLPackNonOwningImpl(const Tensor& tensor, DLTensor& out) {
// Fill in the pre-allocated DLTensor struct with direct pointers
// This is a non-owning conversion - the caller owns the tensor
// and must keep it alive for the duration of DLTensor usage
out.data = tensor.data_ptr();
out.device = torchDeviceToDLDeviceForDLPackv1(tensor.device());
out.ndim = static_cast<int32_t>(tensor.dim());
out.dtype = getDLDataTypeForDLPackv1(tensor);
// sizes() and strides() return pointers to TensorImpl's stable storage
// which remains valid as long as the tensor is alive
out.shape = const_cast<int64_t*>(tensor.sizes().data());
out.strides = const_cast<int64_t*>(tensor.strides().data());
out.byte_offset = 0;
}

} // namespace
} // namespace at
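
> A minimal consumer-side sketch of the non-owning contract documented in `toDLPackNonOwningImpl`: the filled `DLTensor` only borrows from the source tensor, so the Python object must outlive every use of the view. The `fill` hook shape mirrors `TorchDLPackDLTensorFromPyObjectNoSync` below; the names here are illustrative, not part of the PR.

```cpp
#include <dlpack/dlpack.h>

// Illustrative signature matching TorchDLPackDLTensorFromPyObjectNoSync below.
using FillDLTensorFn = int (*)(void* py_obj, DLTensor* out);

void ConsumeView(void* py_tensor, FillDLTensorFn fill) {
  DLTensor view;  // caller-owned storage; the producer only fills it in
  if (fill(py_tensor, &view) != 0) {
    return;  // the producer has set a Python error
  }
  // view.data / view.shape / view.strides all point into the source
  // tensor, so py_tensor must stay alive (and unresized) while the
  // view is in use. There is no deleter; nothing needs to be freed.
}
```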

int TorchDLPackFromPyObject(void* py_obj, DLManagedTensorVersioned** out, void** env_stream) {
int TorchDLPackDLTensorFromPyObjectNoSync(void* py_obj, DLTensor* out) {
try {
// Use handle (non-owning) to avoid unnecessary refcount operations
py::handle handle(static_cast<PyObject*>(py_obj));
at::Tensor tensor = handle.cast<at::Tensor>();
at::toDLPackNonOwningImpl(tensor, *out);
return 0;
} catch (const std::exception& e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
return -1;
}
}

int TorchDLPackManagedTensorFromPyObjectNoSync(void* py_obj, DLManagedTensorVersioned** out) {
try {
py::handle handle(static_cast<PyObject*>(py_obj));
at::Tensor tensor = handle.cast<at::Tensor>();
#ifdef BUILD_WITH_CUDA
if (env_stream != nullptr && tensor.is_cuda()) {
*env_stream = at::cuda::getCurrentCUDAStream(tensor.device().index()).stream();
}
#endif
*out = at::toDLPackImpl<DLManagedTensorVersioned>(tensor);
return 0;
} catch (const std::exception& e) {
@@ -484,7 +507,7 @@ def load_torch_c_dlpack_extension() -> Any:
}
}

int TorchDLPackToPyObject(DLManagedTensorVersioned* src, void** py_obj_out) {
int TorchDLPackManagedTensorToPyObjectNoSync(DLManagedTensorVersioned* src, void** py_obj_out) {
try {
at::Tensor tensor = at::fromDLPackImpl<DLManagedTensorVersioned>(src, nullptr);
*py_obj_out = THPVariable_Wrap(tensor);
@@ -495,7 +518,7 @@ def load_torch_c_dlpack_extension() -> Any:
}
}

int TorchDLPackTensorAllocator(
int TorchDLPackManagedTensorAllocator(
DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx,
void (*SetError)(void* error_ctx, const char* kind, const char* message)
) {
@@ -508,21 +531,45 @@ def load_torch_c_dlpack_extension() -> Any:
*out = at::toDLPackImpl<DLManagedTensorVersioned>(tensor);
return 0;
} catch (const std::exception& e) {
SetError(error_ctx, "TorchDLPackTensorAllocator", e.what());
SetError(error_ctx, "TorchDLPackManagedTensorAllocator", e.what());
return -1;
}
}
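
> For reference, this is roughly how a caller would drive the allocator slot, under the assumption that the `SetError` callback is invoked exactly on failure and that the `DLPackManagedTensorAllocator` typedef matches the Cython declaration later in this diff; `ReportError` and `AllocLike` are illustrative names.

```cpp
#include <cstdio>
#include <dlpack/dlpack.h>

// Illustrative error sink matching the SetError callback signature above.
static void ReportError(void* /*error_ctx*/, const char* kind, const char* message) {
  std::fprintf(stderr, "[%s] %s\n", kind, message);
}

// Ask the producer for a fresh tensor with the prototype's shape/dtype/device.
DLManagedTensorVersioned* AllocLike(DLPackManagedTensorAllocator allocator,
                                    DLTensor* prototype) {
  DLManagedTensorVersioned* result = nullptr;
  if (allocator(prototype, &result, /*error_ctx=*/nullptr, ReportError) != 0) {
    return nullptr;  // the failure was reported through ReportError
  }
  return result;  // caller must eventually call result->deleter(result)
}
```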

int64_t TorchDLPackFromPyObjectPtr() {
return reinterpret_cast<int64_t>(TorchDLPackFromPyObject);
int TorchDLPackCurrentWorkStream(DLDeviceType device_type, int32_t device_id, void** out_stream) {
try {
#ifdef BUILD_WITH_CUDA
if (device_type != kDLCPU) {
*out_stream = at::cuda::getCurrentCUDAStream(device_id).stream();
}
#endif
return 0;
} catch (const std::exception& e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
return -1;
}
}
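
> A sketch of the intended consumer side of `current_work_stream`, assuming the slot may be NULL and that a NULL stream simply means CPU or default-stream semantics:

```cpp
#include <dlpack/dlpack.h>

// Ask the producer which stream its pending work is enqueued on, so the
// callee can order kernels after it instead of synchronizing the device.
void* QueryProducerStream(const DLPackExchangeAPI* api, DLDevice device) {
  void* stream = nullptr;  // stays NULL for CPU, matching the code above
  if (api->current_work_stream != nullptr) {
    if (api->current_work_stream(device.device_type, device.device_id,
                                 &stream) != 0) {
      return nullptr;  // a Python error has been set by the producer
    }
  }
  return stream;
}
```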

int64_t TorchDLPackToPyObjectPtr() {
return reinterpret_cast<int64_t>(TorchDLPackToPyObject);
}
struct TorchDLPackExchangeAPI : public DLPackExchangeAPI {
TorchDLPackExchangeAPI() {
header.version.major = DLPACK_MAJOR_VERSION;
header.version.minor = DLPACK_MINOR_VERSION;
header.prev_api = nullptr;
managed_tensor_allocator = TorchDLPackManagedTensorAllocator;
managed_tensor_from_py_object_no_sync = TorchDLPackManagedTensorFromPyObjectNoSync;
managed_tensor_to_py_object_no_sync = TorchDLPackManagedTensorToPyObjectNoSync;
dltensor_from_py_object_no_sync = TorchDLPackDLTensorFromPyObjectNoSync;
current_work_stream = TorchDLPackCurrentWorkStream;
}

static const DLPackExchangeAPI* Global() {
static TorchDLPackExchangeAPI inst;
return &inst;
}
};
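
> One plausible way a consumer could validate this struct before use: check the embedded version and, if needed, follow `header.prev_api`. Treating `prev_api` as a chain of older API revisions is an assumption here, not something this PR spells out.

```cpp
#include <dlpack/dlpack.h>

// Walk the (assumed) chain of exchange APIs until one matches our major version.
const DLPackExchangeAPI* SelectCompatibleAPI(const DLPackExchangeAPI* api) {
  while (api != nullptr) {
    if (api->header.version.major == DLPACK_MAJOR_VERSION) {
      return api;  // same major version; minor differences are compatible
    }
    api = reinterpret_cast<const DLPackExchangeAPI*>(api->header.prev_api);
  }
  return nullptr;  // no compatible revision found
}
```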

int64_t TorchDLPackTensorAllocatorPtr() {
return reinterpret_cast<int64_t>(TorchDLPackTensorAllocator);
int64_t TorchDLPackExchangeAPIPtr() {
return reinterpret_cast<int64_t>(TorchDLPackExchangeAPI::Global());
}
"""
try:
@@ -541,17 +588,13 @@ def load_torch_c_dlpack_extension() -> Any:
name="c_dlpack",
cpp_sources=cpp_source,
functions=[
"TorchDLPackFromPyObjectPtr",
"TorchDLPackToPyObjectPtr",
"TorchDLPackTensorAllocatorPtr",
"TorchDLPackExchangeAPIPtr",
],
extra_cflags=extra_cflags,
extra_include_paths=include_paths,
)
# set the dlpack related flags
setattr(torch.Tensor, "__c_dlpack_from_pyobject__", mod.TorchDLPackFromPyObjectPtr())
setattr(torch.Tensor, "__c_dlpack_to_pyobject__", mod.TorchDLPackToPyObjectPtr())
setattr(torch.Tensor, "__c_dlpack_tensor_allocator__", mod.TorchDLPackTensorAllocatorPtr())
# Set the DLPackExchangeAPI pointer on the class
setattr(torch.Tensor, "__c_dlpack_exchange_api__", mod.TorchDLPackExchangeAPIPtr())
return mod
except ImportError:
pass
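
> For context, a minimal sketch of how native code could read the integer attribute set above back into a usable pointer; `GetExchangeAPI` is an illustrative helper, not part of this PR.

```cpp
#include <Python.h>
#include <dlpack/dlpack.h>

// Recover the DLPackExchangeAPI* stored on the tensor's class as a plain int.
const DLPackExchangeAPI* GetExchangeAPI(PyObject* tensor) {
  PyObject* cls = reinterpret_cast<PyObject*>(Py_TYPE(tensor));
  PyObject* attr = PyObject_GetAttrString(cls, "__c_dlpack_exchange_api__");
  if (attr == nullptr) {
    PyErr_Clear();  // attribute missing: object does not speak this protocol
    return nullptr;
  }
  long long ptr = PyLong_AsLongLong(attr);
  Py_DECREF(attr);
  if (ptr == -1 && PyErr_Occurred()) return nullptr;
  return reinterpret_cast<const DLPackExchangeAPI*>(ptr);
}
```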
2 changes: 1 addition & 1 deletion python/tvm_ffi/core.pyi
@@ -485,7 +485,7 @@ def from_dlpack(
class DLTensorTestWrapper:
"""Wrapper of a Tensor that exposes DLPack protocol, only for testing purpose."""

__c_dlpack_from_pyobject__: int
__c_dlpack_exchange_api__: int
def __init__(self, tensor: Tensor) -> None: ...
def __tvm_ffi_env_stream__(self) -> int: ...
def __dlpack_device__(self) -> tuple[int, int]: ...
51 changes: 46 additions & 5 deletions python/tvm_ffi/cython/base.pxi
@@ -25,6 +25,9 @@ from cpython cimport pycapsule, PyCapsule_Destructor
from cpython cimport PyErr_SetNone

cdef extern from "dlpack/dlpack.h":
int DLPACK_MAJOR_VERSION
int DLPACK_MINOR_VERSION

cdef enum:
kDLCPU = 1,
kDLCUDA = 2,
@@ -77,6 +80,47 @@ cdef extern from "dlpack/dlpack.h":
void (*deleter)(DLManagedTensorVersioned* self)
uint64_t flags

# DLPack Exchange API function pointer types
ctypedef int (*DLPackManagedTensorAllocator)(
DLTensor* prototype,
DLManagedTensorVersioned** out,
void* error_ctx,
void (*SetError)(void* error_ctx, const char* kind, const char* message)
) noexcept

ctypedef int (*DLPackManagedTensorFromPyObjectNoSync)(
void* py_object,
DLManagedTensorVersioned** out
) noexcept

ctypedef int (*DLPackManagedTensorToPyObjectNoSync)(
DLManagedTensorVersioned* tensor,
void** out_py_object
) noexcept

ctypedef int (*DLPackCurrentWorkStream)(
int device_type,
int32_t device_id,
void** out_current_stream
) noexcept

ctypedef int (*DLPackDLTensorFromPyObjectNoSync)(
void* py_object,
DLTensor* out
) noexcept

ctypedef struct DLPackExchangeAPIHeader:
DLPackVersion version
DLPackExchangeAPIHeader* prev_api

ctypedef struct DLPackExchangeAPI:
DLPackExchangeAPIHeader header
DLPackManagedTensorAllocator managed_tensor_allocator
DLPackManagedTensorFromPyObjectNoSync managed_tensor_from_py_object_no_sync
DLPackManagedTensorToPyObjectNoSync managed_tensor_to_py_object_no_sync
DLPackDLTensorFromPyObjectNoSync dltensor_from_py_object_no_sync
DLPackCurrentWorkStream current_work_stream


# Cython binding for TVM FFI C API
cdef extern from "tvm/ffi/c_api.h":
@@ -285,14 +329,11 @@ cdef extern from "tvm_ffi_python_helpers.h":
int device_type
int device_id
TVMFFIStreamHandle stream
DLPackToPyObject c_dlpack_to_pyobject
DLPackTensorAllocator c_dlpack_tensor_allocator
const DLPackExchangeAPI* c_dlpack_exchange_api

ctypedef struct TVMFFIPyArgSetter:
int (*func)(TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, PyObject* py_arg, TVMFFIAny* out) except -1
DLPackFromPyObject c_dlpack_from_pyobject
DLPackToPyObject c_dlpack_to_pyobject
DLPackTensorAllocator c_dlpack_tensor_allocator
const DLPackExchangeAPI* c_dlpack_exchange_api

ctypedef int (*TVMFFIPyArgSetterFactory)(PyObject* value, TVMFFIPyArgSetter* out) except -1
# The main call function
99 changes: 60 additions & 39 deletions python/tvm_ffi/cython/function.pxi
@@ -142,35 +142,39 @@ cdef int TVMFFIPyArgSetterObject_(
return 0


cdef int TVMFFIPyArgSetterDLPackCExporter_(
cdef int TVMFFIPyArgSetterDLPackExchangeAPI_(
TVMFFIPyArgSetter* this, TVMFFIPyCallContext* ctx,
PyObject* arg, TVMFFIAny* out
) except -1:
cdef DLManagedTensorVersioned* temp_managed_tensor
cdef TVMFFIObjectHandle temp_chandle
cdef TVMFFIStreamHandle env_stream = NULL
cdef void* current_stream = NULL
cdef const DLPackExchangeAPI* api = this.c_dlpack_exchange_api

if this.c_dlpack_to_pyobject != NULL:
ctx.c_dlpack_to_pyobject = this.c_dlpack_to_pyobject
if this.c_dlpack_tensor_allocator != NULL:
ctx.c_dlpack_tensor_allocator = this.c_dlpack_tensor_allocator
# Set the exchange API in context
ctx.c_dlpack_exchange_api = api

if ctx.device_type != -1:
# already queried device, do not do it again, pass NULL to stream
if (this.c_dlpack_from_pyobject)(arg, &temp_managed_tensor, NULL) != 0:
return -1
else:
# query the environment stream
if (this.c_dlpack_from_pyobject)(arg, &temp_managed_tensor, &env_stream) != 0:
return -1
# If device is not CPU, we should set the device type and id
if temp_managed_tensor.dl_tensor.device.device_type != kDLCPU:
ctx.stream = env_stream
ctx.device_type = temp_managed_tensor.dl_tensor.device.device_type
ctx.device_id = temp_managed_tensor.dl_tensor.device.device_id
# run conversion
# Convert PyObject to DLPack using the struct's function pointer
if api.managed_tensor_from_py_object_no_sync(arg, &temp_managed_tensor) != 0:
return -1

# Query current stream from producer if device is not CPU
if temp_managed_tensor.dl_tensor.device.device_type != kDLCPU:
if ctx.device_type == -1 and api.current_work_stream != NULL:
# First time seeing a device, query the stream
if api.current_work_stream(
temp_managed_tensor.dl_tensor.device.device_type,
temp_managed_tensor.dl_tensor.device.device_id,
&current_stream
) == 0:
ctx.stream = <TVMFFIStreamHandle>current_stream
ctx.device_type = temp_managed_tensor.dl_tensor.device.device_type
ctx.device_id = temp_managed_tensor.dl_tensor.device.device_id

# Convert to TVM Tensor
if TVMFFITensorFromDLPackVersioned(temp_managed_tensor, 0, 0, &temp_chandle) != 0:
raise BufferError("Failed to convert DLManagedTensorVersioned to ffi.Tensor")

out.type_index = kTVMFFITensor
out.v_ptr = temp_chandle
TVMFFIPyPushTempFFIObject(ctx, temp_chandle)
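
> To summarize the Cython setter above, here is an equivalent hedged C++ rendering of the flow: convert the argument, record the producer's stream once per call, then import the managed tensor. `ImportTensor` stands in for `TVMFFITensorFromDLPackVersioned`, and the context plumbing is simplified relative to `TVMFFIPyCallContext`.

```cpp
#include <dlpack/dlpack.h>

using ImportTensorFn = int (*)(DLManagedTensorVersioned*, void** out_handle);

int PassTensorArg(const DLPackExchangeAPI* api, void* py_arg,
                  ImportTensorFn import_tensor, void** out_handle,
                  void** io_stream, int* io_device_type, int* io_device_id) {
  DLManagedTensorVersioned* managed = nullptr;
  if (api->managed_tensor_from_py_object_no_sync(py_arg, &managed) != 0) {
    return -1;  // producer raised a Python error
  }
  DLDevice dev = managed->dl_tensor.device;
  if (dev.device_type != kDLCPU) {
    // Query the stream only for the first device argument in this call.
    if (*io_device_type == -1 && api->current_work_stream != nullptr) {
      api->current_work_stream(dev.device_type, dev.device_id, io_stream);
    }
    *io_device_type = dev.device_type;
    *io_device_id = dev.device_id;
  }
  return import_tensor(managed, out_handle);
}
```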
@@ -179,15 +183,36 @@

cdef int TorchDLPackToPyObjectFallback_(
DLManagedTensorVersioned* dltensor, void** py_obj_out
) except -1:
) noexcept:
# a bit convoluted but ok as a fallback
cdef TVMFFIObjectHandle temp_chandle
TVMFFITensorFromDLPackVersioned(dltensor, 0, 0, &temp_chandle)
tensor = make_tensor_from_chandle(temp_chandle)
torch_tensor = torch.from_dlpack(tensor)
Py_INCREF(torch_tensor)
py_obj_out[0] = <void*>(<PyObject*>torch_tensor)
return 0
if TVMFFITensorFromDLPackVersioned(dltensor, 0, 0, &temp_chandle) != 0:
return -1
try:
tensor = make_tensor_from_chandle(temp_chandle)
torch_tensor = torch.from_dlpack(tensor)
Py_INCREF(torch_tensor)
py_obj_out[0] = <void*>(<PyObject*>torch_tensor)
return 0
except Exception:
return -1

cdef inline const DLPackExchangeAPI* GetTorchFallbackExchangeAPI() noexcept:
global _torch_fallback_exchange_api

_torch_fallback_exchange_api.header.version.major = DLPACK_MAJOR_VERSION
_torch_fallback_exchange_api.header.version.minor = DLPACK_MINOR_VERSION
_torch_fallback_exchange_api.header.prev_api = NULL
_torch_fallback_exchange_api.managed_tensor_allocator = NULL
_torch_fallback_exchange_api.managed_tensor_from_py_object_no_sync = NULL
_torch_fallback_exchange_api.managed_tensor_to_py_object_no_sync = TorchDLPackToPyObjectFallback_
_torch_fallback_exchange_api.dltensor_from_py_object_no_sync = NULL
_torch_fallback_exchange_api.current_work_stream = NULL

return &_torch_fallback_exchange_api

# Static storage for the fallback exchange API
cdef DLPackExchangeAPI _torch_fallback_exchange_api
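
> Because the fallback struct fills only `managed_tensor_to_py_object_no_sync` and leaves the other slots NULL, a consumer has to test a slot before calling through it; a minimal guard sketch (helper name is hypothetical):

```cpp
#include <dlpack/dlpack.h>

// Call the to-PyObject slot only if the producer actually provides it.
int ManagedToPyObject(const DLPackExchangeAPI* api,
                      DLManagedTensorVersioned* src, void** out_py_obj) {
  if (api == nullptr || api->managed_tensor_to_py_object_no_sync == nullptr) {
    return -1;  // capability not offered by this producer
  }
  return api->managed_tensor_to_py_object_no_sync(src, out_py_obj);
}
```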


cdef int TVMFFIPyArgSetterTorchFallback_(
Expand All @@ -202,7 +227,7 @@ cdef int TVMFFIPyArgSetterTorchFallback_(
out.type_index = kTVMFFITensor
out.v_ptr = (<Tensor>arg).chandle
temp_dltensor = TVMFFITensorGetDLTensorPtr((<Tensor>arg).chandle)
ctx.c_dlpack_to_pyobject = TorchDLPackToPyObjectFallback_
ctx.c_dlpack_exchange_api = GetTorchFallbackExchangeAPI()
# record the stream and device for torch context
if is_cuda and ctx.device_type != -1:
ctx.device_type = temp_dltensor.device.device_type
Expand Down Expand Up @@ -546,17 +571,13 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
out.func = TVMFFIPyArgSetterObjectRValueRef_
return 0
if os.environ.get("TVM_FFI_SKIP_c_dlpack_from_pyobject", "0") != "1":
# external tensors
if hasattr(arg, "__c_dlpack_from_pyobject__"):
out.func = TVMFFIPyArgSetterDLPackCExporter_
temp_ptr = arg.__c_dlpack_from_pyobject__
out.c_dlpack_from_pyobject = <DLPackFromPyObject>temp_ptr
if hasattr(arg, "__c_dlpack_to_pyobject__"):
temp_ptr = arg.__c_dlpack_to_pyobject__
out.c_dlpack_to_pyobject = <DLPackToPyObject>temp_ptr
if hasattr(arg, "__c_dlpack_tensor_allocator__"):
temp_ptr = arg.__c_dlpack_tensor_allocator__
out.c_dlpack_tensor_allocator = <DLPackTensorAllocator>temp_ptr
# Check for DLPackExchangeAPI struct (new approach)
# This is checked on the CLASS, not the instance
arg_class = type(arg)
if hasattr(arg_class, "__c_dlpack_exchange_api__"):
out.func = TVMFFIPyArgSetterDLPackExchangeAPI_
temp_ptr = arg_class.__c_dlpack_exchange_api__
out.c_dlpack_exchange_api = <const DLPackExchangeAPI*>(<long long>temp_ptr)
return 0
if torch is not None and isinstance(arg, torch.Tensor):
out.func = TVMFFIPyArgSetterTorchFallback_