diff --git a/3rdparty/dlpack b/3rdparty/dlpack
index addbc8b3..11173661 160000
--- a/3rdparty/dlpack
+++ b/3rdparty/dlpack
@@ -1 +1 @@
-Subproject commit addbc8b3d9449691d01827ac4a0e0d035cf8ea40
+Subproject commit 111736618e8d1028b23605f76dcaa6a38cfea809
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index c1bb1ef0..94bb3d7a 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -464,18 +464,41 @@ def load_torch_c_dlpack_extension() -> Any:
       {device});
 }
 
+void toDLPackNonOwningImpl(const Tensor& tensor, DLTensor& out) {
+  // Fill in the pre-allocated DLTensor struct with direct pointers
+  // This is a non-owning conversion - the caller owns the tensor
+  // and must keep it alive for the duration of DLTensor usage
+  out.data = tensor.data_ptr();
+  out.device = torchDeviceToDLDeviceForDLPackv1(tensor.device());
+  out.ndim = static_cast<int32_t>(tensor.dim());
+  out.dtype = getDLDataTypeForDLPackv1(tensor);
+  // sizes() and strides() return pointers to TensorImpl's stable storage
+  // which remains valid as long as the tensor is alive
+  out.shape = const_cast<int64_t*>(tensor.sizes().data());
+  out.strides = const_cast<int64_t*>(tensor.strides().data());
+  out.byte_offset = 0;
+}
+
 } // namespace
 } // namespace at
 
-int TorchDLPackFromPyObject(void* py_obj, DLManagedTensorVersioned** out, void** env_stream) {
+int TorchDLPackDLTensorFromPyObjectNoSync(void* py_obj, DLTensor* out) {
+  try {
+    // Use handle (non-owning) to avoid unnecessary refcount operations
+    py::handle handle(static_cast<PyObject*>(py_obj));
+    at::Tensor tensor = handle.cast<at::Tensor>();
+    at::toDLPackNonOwningImpl(tensor, *out);
+    return 0;
+  } catch (const std::exception& e) {
+    PyErr_SetString(PyExc_RuntimeError, e.what());
+    return -1;
+  }
+}
+
+int TorchDLPackManagedTensorFromPyObjectNoSync(void* py_obj, DLManagedTensorVersioned** out) {
   try {
     py::handle handle(static_cast<PyObject*>(py_obj));
     at::Tensor tensor = handle.cast<at::Tensor>();
-#ifdef BUILD_WITH_CUDA
-    if (env_stream != nullptr && tensor.is_cuda()) {
-      *env_stream = at::cuda::getCurrentCUDAStream(tensor.device().index()).stream();
-    }
-#endif
     *out = at::toDLPackImpl<DLManagedTensorVersioned>(tensor);
     return 0;
   } catch (const std::exception& e) {
@@ -484,7 +507,7 @@ def load_torch_c_dlpack_extension() -> Any:
   }
 }
 
-int TorchDLPackToPyObject(DLManagedTensorVersioned* src, void** py_obj_out) {
+int TorchDLPackManagedTensorToPyObjectNoSync(DLManagedTensorVersioned* src, void** py_obj_out) {
   try {
     at::Tensor tensor = at::fromDLPackImpl<DLManagedTensorVersioned>(src, nullptr);
     *py_obj_out = THPVariable_Wrap(tensor);
@@ -495,7 +518,7 @@ def load_torch_c_dlpack_extension() -> Any:
   }
 }
 
-int TorchDLPackTensorAllocator(
+int TorchDLPackManagedTensorAllocator(
     DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx,
     void (*SetError)(void* error_ctx, const char* kind, const char* message)
 ) {
@@ -508,21 +531,45 @@ def load_torch_c_dlpack_extension() -> Any:
     *out = at::toDLPackImpl<DLManagedTensorVersioned>(tensor);
     return 0;
   } catch (const std::exception& e) {
-    SetError(error_ctx, "TorchDLPackTensorAllocator", e.what());
+    SetError(error_ctx, "TorchDLPackManagedTensorAllocator", e.what());
     return -1;
   }
 }
 
-int64_t TorchDLPackFromPyObjectPtr() {
-  return reinterpret_cast<int64_t>(TorchDLPackFromPyObject);
+int TorchDLPackCurrentWorkStream(DLDeviceType device_type, int32_t device_id, void** out_stream) {
+  try {
+#ifdef BUILD_WITH_CUDA
+    if (device_type == kDLCUDA || device_type == kDLROCM) {
+      *out_stream = at::cuda::getCurrentCUDAStream(device_id).stream();
+    }
+#endif
+    return 0;
+  } catch (const std::exception& e) {
+    PyErr_SetString(PyExc_RuntimeError, e.what());
+    return -1;
+  }
 }
 
-int64_t TorchDLPackToPyObjectPtr() {
-  return reinterpret_cast<int64_t>(TorchDLPackToPyObject);
-}
+struct TorchDLPackExchangeAPI : public DLPackExchangeAPI {
+  TorchDLPackExchangeAPI() {
+    header.version.major = DLPACK_MAJOR_VERSION;
+    header.version.minor = DLPACK_MINOR_VERSION;
+    header.prev_api = nullptr;
+    managed_tensor_allocator = TorchDLPackManagedTensorAllocator;
+    managed_tensor_from_py_object_no_sync = TorchDLPackManagedTensorFromPyObjectNoSync;
+    managed_tensor_to_py_object_no_sync = TorchDLPackManagedTensorToPyObjectNoSync;
+    dltensor_from_py_object_no_sync = TorchDLPackDLTensorFromPyObjectNoSync;
+    current_work_stream = TorchDLPackCurrentWorkStream;
+  }
+
+  static const DLPackExchangeAPI* Global() {
+    static TorchDLPackExchangeAPI inst;
+    return &inst;
+  }
+};
 
-int64_t TorchDLPackTensorAllocatorPtr() {
-  return reinterpret_cast<int64_t>(TorchDLPackTensorAllocator);
+int64_t TorchDLPackExchangeAPIPtr() {
+  return reinterpret_cast<int64_t>(TorchDLPackExchangeAPI::Global());
 }
 """
     try:
@@ -541,17 +588,13 @@ def load_torch_c_dlpack_extension() -> Any:
             name="c_dlpack",
             cpp_sources=cpp_source,
             functions=[
-                "TorchDLPackFromPyObjectPtr",
-                "TorchDLPackToPyObjectPtr",
-                "TorchDLPackTensorAllocatorPtr",
+                "TorchDLPackExchangeAPIPtr",
             ],
             extra_cflags=extra_cflags,
             extra_include_paths=include_paths,
         )
-        # set the dlpack related flags
-        setattr(torch.Tensor, "__c_dlpack_from_pyobject__", mod.TorchDLPackFromPyObjectPtr())
-        setattr(torch.Tensor, "__c_dlpack_to_pyobject__", mod.TorchDLPackToPyObjectPtr())
-        setattr(torch.Tensor, "__c_dlpack_tensor_allocator__", mod.TorchDLPackTensorAllocatorPtr())
+        # Set the DLPackExchangeAPI pointer on the class
+        setattr(torch.Tensor, "__c_dlpack_exchange_api__", mod.TorchDLPackExchangeAPIPtr())
         return mod
     except ImportError:
         pass
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index 45a7d28c..787608cd 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -485,7 +485,7 @@ def from_dlpack(
 class DLTensorTestWrapper:
     """Wrapper of a Tensor that exposes DLPack protocol, only for testing purpose."""
 
-    __c_dlpack_from_pyobject__: int
+    __c_dlpack_exchange_api__: int
     def __init__(self, tensor: Tensor) -> None: ...
     def __tvm_ffi_env_stream__(self) -> int: ...
     def __dlpack_device__(self) -> tuple[int, int]: ...
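Illustrative consumer-side sketch (not part of the patch): the integer stored in `__c_dlpack_exchange_api__` is the address of a static `DLPackExchangeAPI` struct (see the Cython declaration in base.pxi below), so a C/C++ caller can recover it from the class and go through the `*_no_sync` entry points directly. The helper name `ImportTensorNoSync`, the error handling, and the exact struct layout in the updated dlpack.h are assumptions here, not part of this change.

// Hedged example only; assumes the DLPackExchangeAPI layout used elsewhere in this diff.
#include <Python.h>
#include <dlpack/dlpack.h>

int ImportTensorNoSync(PyObject* py_tensor, DLManagedTensorVersioned** out, void** out_stream) {
  // The attribute is exposed on the class, not the instance.
  PyObject* api_attr = PyObject_GetAttrString(
      reinterpret_cast<PyObject*>(Py_TYPE(py_tensor)), "__c_dlpack_exchange_api__");
  if (api_attr == nullptr) return -1;  // class does not advertise the exchange API
  const DLPackExchangeAPI* api =
      reinterpret_cast<const DLPackExchangeAPI*>(PyLong_AsLongLong(api_attr));
  Py_DECREF(api_attr);
  if (api == nullptr || api->managed_tensor_from_py_object_no_sync == nullptr) return -1;
  // The *_no_sync converters do not synchronize; the caller queries the
  // producer's current stream separately and orders work itself.
  if (api->managed_tensor_from_py_object_no_sync(py_tensor, out) != 0) return -1;
  DLDevice dev = (*out)->dl_tensor.device;
  if (out_stream != nullptr && dev.device_type != kDLCPU && api->current_work_stream != nullptr) {
    api->current_work_stream(dev.device_type, dev.device_id, out_stream);
  }
  return 0;
}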
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index a8b4212e..6cad7947 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -25,6 +25,9 @@ from cpython cimport pycapsule, PyCapsule_Destructor
 from cpython cimport PyErr_SetNone
 
 cdef extern from "dlpack/dlpack.h":
+    int DLPACK_MAJOR_VERSION
+    int DLPACK_MINOR_VERSION
+
     cdef enum:
         kDLCPU = 1,
         kDLCUDA = 2,
@@ -77,6 +80,47 @@ cdef extern from "dlpack/dlpack.h":
         void (*deleter)(DLManagedTensorVersioned* self)
         uint64_t flags
 
+    # DLPack Exchange API function pointer types
+    ctypedef int (*DLPackManagedTensorAllocator)(
+        DLTensor* prototype,
+        DLManagedTensorVersioned** out,
+        void* error_ctx,
+        void (*SetError)(void* error_ctx, const char* kind, const char* message)
+    ) noexcept
+
+    ctypedef int (*DLPackManagedTensorFromPyObjectNoSync)(
+        void* py_object,
+        DLManagedTensorVersioned** out
+    ) noexcept
+
+    ctypedef int (*DLPackManagedTensorToPyObjectNoSync)(
+        DLManagedTensorVersioned* tensor,
+        void** out_py_object
+    ) noexcept
+
+    ctypedef int (*DLPackCurrentWorkStream)(
+        int device_type,
+        int32_t device_id,
+        void** out_current_stream
+    ) noexcept
+
+    ctypedef int (*DLPackDLTensorFromPyObjectNoSync)(
+        void* py_object,
+        DLTensor* out
+    ) noexcept
+
+    ctypedef struct DLPackExchangeAPIHeader:
+        DLPackVersion version
+        DLPackExchangeAPIHeader* prev_api
+
+    ctypedef struct DLPackExchangeAPI:
+        DLPackExchangeAPIHeader header
+        DLPackManagedTensorAllocator managed_tensor_allocator
+        DLPackManagedTensorFromPyObjectNoSync managed_tensor_from_py_object_no_sync
+        DLPackManagedTensorToPyObjectNoSync managed_tensor_to_py_object_no_sync
+        DLPackDLTensorFromPyObjectNoSync dltensor_from_py_object_no_sync
+        DLPackCurrentWorkStream current_work_stream
+
 
 # Cython binding for TVM FFI C API
 cdef extern from "tvm/ffi/c_api.h":
@@ -285,14 +329,11 @@ cdef extern from "tvm_ffi_python_helpers.h":
         int device_type
         int device_id
         TVMFFIStreamHandle stream
-        DLPackToPyObject c_dlpack_to_pyobject
-        DLPackTensorAllocator c_dlpack_tensor_allocator
+        const DLPackExchangeAPI* c_dlpack_exchange_api
 
     ctypedef struct TVMFFIPyArgSetter:
         int (*func)(TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
                     PyObject* py_arg, TVMFFIAny* out) except -1
-        DLPackFromPyObject c_dlpack_from_pyobject
-        DLPackToPyObject c_dlpack_to_pyobject
-        DLPackTensorAllocator c_dlpack_tensor_allocator
+        const DLPackExchangeAPI* c_dlpack_exchange_api
 
     ctypedef int (*TVMFFIPyArgSetterFactory)(PyObject* value, TVMFFIPyArgSetter* out) except -1
 
     # The main call function
@@ -303,7 +344,7 @@ cdef extern from "tvm_ffi_python_helpers.h":
         TVMFFIAny* result,
         int* c_api_ret_code,
         int release_gil,
-        DLPackToPyObject* out_dlpack_importer
+        const DLPackExchangeAPI** out_ctx_dlpack_api
     ) except -1
 
     int TVMFFIPyConstructorCall(
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index 2fa75fbe..33ed69bf 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -55,13 +55,13 @@ cdef inline object make_ret_small_bytes(TVMFFIAny result):
     return bytearray_to_bytes(&bytes)
 
 
-cdef inline object make_ret(TVMFFIAny result, DLPackToPyObject c_dlpack_to_pyobject = NULL):
+cdef inline object make_ret(TVMFFIAny result, const DLPackExchangeAPI* c_ctx_dlpack_api = NULL):
     """convert result to return value."""
     cdef int32_t type_index
     type_index = result.type_index
     if type_index == kTVMFFITensor:
         # specially handle Tensor as it needs a special dltensor field
-        return make_tensor_from_any(result, c_dlpack_to_pyobject)
+        return make_tensor_from_any(result, c_ctx_dlpack_api)
     elif type_index == kTVMFFIOpaquePyObject:
         return make_ret_opaque_object(result)
     elif type_index >= kTVMFFIStaticObjectBegin:
@@ -142,35 +142,39 @@ cdef int TVMFFIPyArgSetterObject_(
     return 0
 
 
-cdef int TVMFFIPyArgSetterDLPackCExporter_(
+cdef int TVMFFIPyArgSetterDLPackExchangeAPI_(
     TVMFFIPyArgSetter* this, TVMFFIPyCallContext* ctx, PyObject* arg, TVMFFIAny* out
 ) except -1:
     cdef DLManagedTensorVersioned* temp_managed_tensor
     cdef TVMFFIObjectHandle temp_chandle
-    cdef TVMFFIStreamHandle env_stream = NULL
+    cdef void* current_stream = NULL
+    cdef const DLPackExchangeAPI* api = this.c_dlpack_exchange_api
 
-    if this.c_dlpack_to_pyobject != NULL:
-        ctx.c_dlpack_to_pyobject = this.c_dlpack_to_pyobject
-    if this.c_dlpack_tensor_allocator != NULL:
-        ctx.c_dlpack_tensor_allocator = this.c_dlpack_tensor_allocator
+    # Set the exchange API in context
+    ctx.c_dlpack_exchange_api = api
 
-    if ctx.device_type != -1:
-        # already queried device, do not do it again, pass NULL to stream
-        if (this.c_dlpack_from_pyobject)(arg, &temp_managed_tensor, NULL) != 0:
-            return -1
-    else:
-        # query string on the envrionment stream
-        if (this.c_dlpack_from_pyobject)(arg, &temp_managed_tensor, &env_stream) != 0:
-            return -1
-        # If device is not CPU, we should set the device type and id
-        if temp_managed_tensor.dl_tensor.device.device_type != kDLCPU:
-            ctx.stream = env_stream
-            ctx.device_type = temp_managed_tensor.dl_tensor.device.device_type
-            ctx.device_id = temp_managed_tensor.dl_tensor.device.device_id
-    # run conversion
+    # Convert PyObject to DLPack using the struct's function pointer
+    if api.managed_tensor_from_py_object_no_sync(arg, &temp_managed_tensor) != 0:
+        return -1
+
+    # Query current stream from producer if device is not CPU
+    if temp_managed_tensor.dl_tensor.device.device_type != kDLCPU:
+        if ctx.device_type == -1 and api.current_work_stream != NULL:
+            # First time seeing a device, query the stream
+            if api.current_work_stream(
+                temp_managed_tensor.dl_tensor.device.device_type,
+                temp_managed_tensor.dl_tensor.device.device_id,
+                &current_stream
+            ) == 0:
+                ctx.stream = current_stream
+                ctx.device_type = temp_managed_tensor.dl_tensor.device.device_type
+                ctx.device_id = temp_managed_tensor.dl_tensor.device.device_id
+
+    # Convert to TVM Tensor
     if TVMFFITensorFromDLPackVersioned(temp_managed_tensor, 0, 0, &temp_chandle) != 0:
         raise BufferError("Failed to convert DLManagedTensorVersioned to ffi.Tensor")
+
     out.type_index = kTVMFFITensor
     out.v_ptr = temp_chandle
     TVMFFIPyPushTempFFIObject(ctx, temp_chandle)
@@ -179,15 +183,36 @@ cdef int TorchDLPackToPyObjectFallback_(
     DLManagedTensorVersioned* dltensor, void** py_obj_out
-) except -1:
+) noexcept:
     # a bit convoluted but ok as a fallback
     cdef TVMFFIObjectHandle temp_chandle
-    TVMFFITensorFromDLPackVersioned(dltensor, 0, 0, &temp_chandle)
-    tensor = make_tensor_from_chandle(temp_chandle)
-    torch_tensor = torch.from_dlpack(tensor)
-    Py_INCREF(torch_tensor)
-    py_obj_out[0] = <void*>(torch_tensor)
-    return 0
+    if TVMFFITensorFromDLPackVersioned(dltensor, 0, 0, &temp_chandle) != 0:
+        return -1
+    try:
+        tensor = make_tensor_from_chandle(temp_chandle)
+        torch_tensor = torch.from_dlpack(tensor)
+        Py_INCREF(torch_tensor)
+        py_obj_out[0] = <void*>(torch_tensor)
+        return 0
+    except Exception:
+        return -1
+
+
+cdef inline const DLPackExchangeAPI* GetTorchFallbackExchangeAPI() noexcept:
+    global _torch_fallback_exchange_api
+
+    _torch_fallback_exchange_api.header.version.major = DLPACK_MAJOR_VERSION
+    _torch_fallback_exchange_api.header.version.minor = DLPACK_MINOR_VERSION
+    _torch_fallback_exchange_api.header.prev_api = NULL
+    _torch_fallback_exchange_api.managed_tensor_allocator = NULL
+    _torch_fallback_exchange_api.managed_tensor_from_py_object_no_sync = NULL
+    _torch_fallback_exchange_api.managed_tensor_to_py_object_no_sync = TorchDLPackToPyObjectFallback_
+    _torch_fallback_exchange_api.dltensor_from_py_object_no_sync = NULL
+    _torch_fallback_exchange_api.current_work_stream = NULL
+
+    return &_torch_fallback_exchange_api
+
+# Static storage for the fallback exchange API
+cdef DLPackExchangeAPI _torch_fallback_exchange_api
 
 
 cdef int TVMFFIPyArgSetterTorchFallback_(
@@ -202,7 +227,7 @@ cdef int TVMFFIPyArgSetterTorchFallback_(
     out.type_index = kTVMFFITensor
     out.v_ptr = (<Tensor>arg).chandle
     temp_dltensor = TVMFFITensorGetDLTensorPtr((<Tensor>arg).chandle)
-    ctx.c_dlpack_to_pyobject = TorchDLPackToPyObjectFallback_
+    ctx.c_dlpack_exchange_api = GetTorchFallbackExchangeAPI()
     # record the stream and device for torch context
     if is_cuda and ctx.device_type != -1:
         ctx.device_type = temp_dltensor.device.device_type
@@ -546,17 +571,13 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
         out.func = TVMFFIPyArgSetterObjectRValueRef_
         return 0
     if os.environ.get("TVM_FFI_SKIP_c_dlpack_from_pyobject", "0") != "1":
-        # external tensors
-        if hasattr(arg, "__c_dlpack_from_pyobject__"):
-            out.func = TVMFFIPyArgSetterDLPackCExporter_
-            temp_ptr = arg.__c_dlpack_from_pyobject__
-            out.c_dlpack_from_pyobject = <DLPackFromPyObject>temp_ptr
-            if hasattr(arg, "__c_dlpack_to_pyobject__"):
-                temp_ptr = arg.__c_dlpack_to_pyobject__
-                out.c_dlpack_to_pyobject = <DLPackToPyObject>temp_ptr
-            if hasattr(arg, "__c_dlpack_tensor_allocator__"):
-                temp_ptr = arg.__c_dlpack_tensor_allocator__
-                out.c_dlpack_tensor_allocator = <DLPackTensorAllocator>temp_ptr
+        # Check for DLPackExchangeAPI struct (new approach)
+        # This is checked on the CLASS, not the instance
+        arg_class = type(arg)
+        if hasattr(arg_class, "__c_dlpack_exchange_api__"):
+            out.func = TVMFFIPyArgSetterDLPackExchangeAPI_
+            temp_ptr = arg_class.__c_dlpack_exchange_api__
+            out.c_dlpack_exchange_api = <const DLPackExchangeAPI*>(temp_ptr)
             return 0
     if torch is not None and isinstance(arg, torch.Tensor):
         out.func = TVMFFIPyArgSetterTorchFallback_
@@ -657,7 +678,7 @@ cdef class Function(Object):
     def __call__(self, *args):
         cdef TVMFFIAny result
         cdef int c_api_ret_code
-        cdef DLPackToPyObject c_dlpack_to_pyobject = NULL
+        cdef const DLPackExchangeAPI* c_ctx_dlpack_api = NULL
         # IMPORTANT: caller need to initialize result->type_index to kTVMFFINone
         result.type_index = kTVMFFINone
         result.v_int64 = 0
@@ -667,12 +688,12 @@ cdef class Function(Object):
             &result,
             &c_api_ret_code,
             self.release_gil,
-            &c_dlpack_to_pyobject
+            &c_ctx_dlpack_api
         )
         # NOTE: logic is same as check_call
         # directly inline here to simplify the resulting trace
         if c_api_ret_code == 0:
-            return make_ret(result, c_dlpack_to_pyobject)
+            return make_ret(result, c_ctx_dlpack_api)
         elif c_api_ret_code == -2:
             raise_existing_error()
         raise move_from_last_error().py_error()
diff --git a/python/tvm_ffi/cython/tensor.pxi b/python/tvm_ffi/cython/tensor.pxi
index 74b065b5..4ebc5156 100644
--- a/python/tvm_ffi/cython/tensor.pxi
+++ b/python/tvm_ffi/cython/tensor.pxi
@@ -275,33 +275,74 @@ _set_class_tensor(Tensor)
 _register_object_by_index(kTVMFFITensor, Tensor)
 
 
-cdef int _dltensor_test_wrapper_c_dlpack_from_pyobject(
-    void* obj, DLManagedTensorVersioned** out, TVMFFIStreamHandle* env_stream
+cdef int _dltensor_test_wrapper_from_pyobject(
+    void* obj, DLManagedTensorVersioned** out
 ) except -1:
+    """DLPackExchangeAPI: managed_tensor_from_py_object_no_sync"""
     cdef PyObject* py_obj = <PyObject*>obj
     cdef DLTensorTestWrapper wrapper = <DLTensorTestWrapper>py_obj
-    cdef TVMFFIStreamHandle current_stream
-    cdef DLManagedTensorVersioned* temp_managed_tensor
-    if env_stream != NULL:
-        env_stream[0] = TVMFFIEnvGetStream(
-            wrapper.tensor.cdltensor.device.device_type,
-            wrapper.tensor.cdltensor.device.device_id
-        )
     return TVMFFITensorToDLPackVersioned(wrapper.tensor.chandle, out)
 
-def _dltensor_test_wrapper_c_dlpack_from_pyobject_as_intptr():
-    cdef DLPackFromPyObject converter_func = _dltensor_test_wrapper_c_dlpack_from_pyobject
-    cdef void* temp_ptr = <void*>converter_func
-    cdef long long temp_int_ptr = <long long>temp_ptr
-    return temp_int_ptr
+
+cdef int _dltensor_test_wrapper_to_pyobject(
+    DLManagedTensorVersioned* tensor, void** out_py_object
+) except -1:
+    """DLPackExchangeAPI: managed_tensor_to_py_object_no_sync"""
+    cdef TVMFFIObjectHandle temp_chandle
+    if TVMFFITensorFromDLPackVersioned(tensor, 0, 0, &temp_chandle) != 0:
+        return -1
+    py_tensor = make_tensor_from_chandle(temp_chandle)
+    Py_INCREF(py_tensor)
+    out_py_object[0] = <void*>(py_tensor)
+    return 0
+
+
+cdef int _dltensor_test_wrapper_current_work_stream(
+    int device_type, int32_t device_id, void** out_stream
+) except -1:
+    """DLPackExchangeAPI: current_work_stream"""
+    if device_type != kDLCPU:
+        out_stream[0] = TVMFFIEnvGetStream(device_type, device_id)
+    return 0
+
+
+# Module-level static DLPackExchangeAPI for DLTensorTestWrapper
+cdef DLPackExchangeAPI _dltensor_test_wrapper_static_api
+
+cdef const DLPackExchangeAPI* _dltensor_test_wrapper_get_exchange_api() noexcept:
+    """Get the static DLPackExchangeAPI instance for DLTensorTestWrapper."""
+    global _dltensor_test_wrapper_static_api
+
+    # Initialize header using macros from dlpack.h
+    _dltensor_test_wrapper_static_api.header.version.major = DLPACK_MAJOR_VERSION
+    _dltensor_test_wrapper_static_api.header.version.minor = DLPACK_MINOR_VERSION
+    _dltensor_test_wrapper_static_api.header.prev_api = NULL
+
+    # Initialize function pointers
+    _dltensor_test_wrapper_static_api.managed_tensor_allocator = NULL
+    _dltensor_test_wrapper_static_api.managed_tensor_from_py_object_no_sync = (
+        _dltensor_test_wrapper_from_pyobject
+    )
+    _dltensor_test_wrapper_static_api.managed_tensor_to_py_object_no_sync = (
+        _dltensor_test_wrapper_to_pyobject
+    )
+    _dltensor_test_wrapper_static_api.dltensor_from_py_object_no_sync = NULL
+    _dltensor_test_wrapper_static_api.current_work_stream = (
+        _dltensor_test_wrapper_current_work_stream
+    )
+
+    return &_dltensor_test_wrapper_static_api
+
+
+def _dltensor_test_wrapper_exchange_api_ptr():
+    """Return the pointer to the DLPackExchangeAPI struct as an integer."""
+    return <long long>_dltensor_test_wrapper_get_exchange_api()
 
 
 cdef class DLTensorTestWrapper:
     """Wrapper of a Tensor that exposes DLPack protocol, only for testing purpose.
     """
 
-    __c_dlpack_from_pyobject__ = _dltensor_test_wrapper_c_dlpack_from_pyobject_as_intptr()
+    __c_dlpack_exchange_api__ = _dltensor_test_wrapper_exchange_api_ptr()
 
     cdef Tensor tensor
     cdef dict __dict__
@@ -334,19 +375,21 @@ cdef inline object make_ret_dltensor(TVMFFIAny result):
     return tensor
 
 
-cdef inline object make_tensor_from_chandle(TVMFFIObjectHandle chandle, DLPackToPyObject c_dlpack_to_pyobject = NULL):
+cdef inline object make_tensor_from_chandle(
+    TVMFFIObjectHandle chandle, const DLPackExchangeAPI* c_ctx_dlpack_api = NULL
+):
     # TODO: Implement
     cdef Tensor tensor
     cdef void* py_obj
     cdef DLManagedTensorVersioned* dlpack
-    if c_dlpack_to_pyobject != NULL:
+    if c_ctx_dlpack_api != NULL and c_ctx_dlpack_api.managed_tensor_to_py_object_no_sync != NULL:
        # try convert and import into the environment array if possible
         if TVMFFITensorToDLPackVersioned(chandle, &dlpack) == 0:
             try:
                 # note that py_obj already holds an extra reference to the tensor
                 # so we need to decref it after the conversion
-                c_dlpack_to_pyobject(dlpack, &py_obj)
+                c_ctx_dlpack_api.managed_tensor_to_py_object_no_sync(dlpack, &py_obj)
                 tensor = <object>(py_obj)
                 Py_DECREF(tensor)
                 # decref original handle to prevent leak.
@@ -365,5 +408,5 @@ cdef inline object make_tensor_from_chandle(TVMFFIObjectHandle chandle, DLPackTo
     return tensor
 
 
-cdef inline object make_tensor_from_any(TVMFFIAny any, DLPackToPyObject c_dlpack_to_pyobject):
-    return make_tensor_from_chandle(any.v_ptr, c_dlpack_to_pyobject)
+cdef inline object make_tensor_from_any(TVMFFIAny any, const DLPackExchangeAPI* c_ctx_dlpack_api):
+    return make_tensor_from_chandle(any.v_ptr, c_ctx_dlpack_api)
diff --git a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
index 3e204bc4..f44329d8 100644
--- a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
+++ b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
@@ -89,10 +89,8 @@ struct TVMFFIPyCallContext {
   void** temp_py_objects = nullptr;
   /*! \brief the number of temporary arguments */
   int num_temp_py_objects = 0;
-  /*! \brief the DLPack exporter, if any */
-  DLPackToPyObject c_dlpack_to_pyobject{nullptr};
-  /*! \brief the DLPack allocator, if any */
-  DLPackTensorAllocator c_dlpack_tensor_allocator{nullptr};
+  /*! \brief the DLPack exchange API, if any */
+  const DLPackExchangeAPI* c_dlpack_exchange_api{nullptr};
 };
 
 /*! \brief Argument setter for a given python argument. */
@@ -108,17 +106,10 @@ struct TVMFFIPyArgSetter {
   int (*func)(TVMFFIPyArgSetter* self, TVMFFIPyCallContext* call_ctx, PyObject* arg,
               TVMFFIAny* out);
   /*!
-   * \brief Optional DLPack exporter for for setters that leverages DLPack protocol.
+   * \brief Optional DLPackExchangeAPI struct pointer.
+   * This is the new struct-based approach that bundles all DLPack exchange functions.
    */
-  DLPackFromPyObject c_dlpack_from_pyobject{nullptr};
-  /*!
-   * \brief Optional DLPack importer for for setters that leverages DLPack protocol.
-   */
-  DLPackToPyObject c_dlpack_to_pyobject{nullptr};
-  /*!
-   * \brief Optional DLPack allocator for for setters that leverages DLPack protocol.
-   */
-  DLPackTensorAllocator c_dlpack_tensor_allocator{nullptr};
+  const DLPackExchangeAPI* c_dlpack_exchange_api{nullptr};
   /*!
    * \brief Invoke the setter.
    * \param call_ctx The call context.
@@ -273,13 +264,14 @@ class TVMFFIPyCallManager {
   * \param result The result of the function
   * \param c_api_ret_code The return code of the C-call
   * \param release_gil Whether to release the GIL
-  * \param optional_out_dlpack_importer The DLPack importer to be used for the result
+  * \param optional_out_ctx_dlpack_api The DLPack exchange API to be used for the result
   * \return 0 on when there is no python error, -1 on python error
   * \note When an error happens on FFI side, we should return 0 and set c_api_ret_code
   */
  TVM_FFI_INLINE int FuncCall(TVMFFIPyArgSetterFactory setter_factory, void* func_handle,
                              PyObject* py_arg_tuple, TVMFFIAny* result, int* c_api_ret_code,
-                              bool release_gil, DLPackToPyObject* optional_out_dlpack_importer) {
+                              bool release_gil,
+                              const DLPackExchangeAPI** optional_out_ctx_dlpack_api) {
    int64_t num_args = PyTuple_Size(py_arg_tuple);
    if (num_args == -1) return -1;
    try {
@@ -300,9 +292,10 @@ class TVMFFIPyCallManager {
        // setting failed, directly return
        if (c_api_ret_code[0] != 0) return 0;
      }
-      if (ctx.c_dlpack_tensor_allocator != nullptr) {
-        c_api_ret_code[0] =
-            TVMFFIEnvSetTensorAllocator(ctx.c_dlpack_tensor_allocator, 0, &prev_tensor_allocator);
+      if (ctx.c_dlpack_exchange_api != nullptr &&
+          ctx.c_dlpack_exchange_api->managed_tensor_allocator != nullptr) {
+        c_api_ret_code[0] = TVMFFIEnvSetTensorAllocator(
+            ctx.c_dlpack_exchange_api->managed_tensor_allocator, 0, &prev_tensor_allocator);
        if (c_api_ret_code[0] != 0) return 0;
      }
      // call the function
@@ -323,12 +316,13 @@ class TVMFFIPyCallManager {
          return -1;
        }
      }
-      if (prev_tensor_allocator != ctx.c_dlpack_tensor_allocator) {
+      if (ctx.c_dlpack_exchange_api != nullptr &&
+          prev_tensor_allocator != ctx.c_dlpack_exchange_api->managed_tensor_allocator) {
        c_api_ret_code[0] = TVMFFIEnvSetTensorAllocator(prev_tensor_allocator, 0, nullptr);
        if (c_api_ret_code[0] != 0) return 0;
      }
-      if (optional_out_dlpack_importer != nullptr && ctx.c_dlpack_to_pyobject != nullptr) {
-        *optional_out_dlpack_importer = ctx.c_dlpack_to_pyobject;
+      if (optional_out_ctx_dlpack_api != nullptr && ctx.c_dlpack_exchange_api != nullptr) {
+        *optional_out_ctx_dlpack_api = ctx.c_dlpack_exchange_api;
      }
      return 0;
    } catch (const std::exception& ex) {
@@ -379,13 +373,9 @@ class TVMFFIPyCallManager {
        parent_ctx->device_id = ctx.device_id;
        parent_ctx->stream = ctx.stream;
      }
-      // DLPack allocator
-      if (parent_ctx->c_dlpack_tensor_allocator == nullptr) {
-        parent_ctx->c_dlpack_tensor_allocator = ctx.c_dlpack_tensor_allocator;
-      }
-      // DLPack importer
-      if (parent_ctx->c_dlpack_to_pyobject == nullptr) {
-        parent_ctx->c_dlpack_to_pyobject = ctx.c_dlpack_to_pyobject;
+      // DLPack exchange API
+      if (parent_ctx->c_dlpack_exchange_api == nullptr) {
+        parent_ctx->c_dlpack_exchange_api = ctx.c_dlpack_exchange_api;
      }
    }
    return 0;
@@ -490,16 +480,16 @@ class TVMFFIPyCallManager {
 * \param result The result of the function
 * \param c_api_ret_code The return code of the function
 * \param release_gil Whether to release the GIL
- * \param out_dlpack_exporter The DLPack exporter to be used for the result
+ * \param out_ctx_dlpack_api The DLPack exchange API to be used for the result
 * \return 0 on success, nonzero on failure
 */
TVM_FFI_INLINE int TVMFFIPyFuncCall(TVMFFIPyArgSetterFactory setter_factory, void* func_handle,
                                    PyObject* py_arg_tuple, TVMFFIAny* result,
                                    int* c_api_ret_code, bool release_gil = true,
-                                    DLPackToPyObject* out_dlpack_importer = nullptr) {
+                                    const DLPackExchangeAPI** out_ctx_dlpack_api = nullptr) {
  return TVMFFIPyCallManager::ThreadLocal()->FuncCall(setter_factory, func_handle, py_arg_tuple,
                                                      result, c_api_ret_code, release_gil,
-                                                      out_dlpack_importer);
+                                                      out_ctx_dlpack_api);
}
 
/*!
diff --git a/tests/python/test_load_inline.py b/tests/python/test_load_inline.py
index b30471bd..795e6910 100644
--- a/tests/python/test_load_inline.py
+++ b/tests/python/test_load_inline.py
@@ -213,8 +213,8 @@ def test_load_inline_cuda() -> None:
 @pytest.mark.skipif(torch is None, reason="Requires torch")
 def test_load_inline_with_env_tensor_allocator() -> None:
     assert torch is not None
-    if not hasattr(torch.Tensor, "__c_dlpack_tensor_allocator__"):
-        pytest.skip("Torch does not support __c_dlpack_tensor_allocator__")
+    if not hasattr(torch.Tensor, "__c_dlpack_exchange_api__"):
+        pytest.skip("Torch does not support __c_dlpack_exchange_api__")
     mod: Module = tvm_ffi.cpp.load_inline(
         name="hello",
         cpp_sources=r"""