@@ -142,6 +142,48 @@ cdef int TVMFFIPyArgSetterObject_(
142142 return 0
143143
144144
cdef int TVMFFIPyArgSetterDLPackExchangeAPI_(
    TVMFFIPyArgSetter* this, TVMFFIPyCallContext* ctx,
    PyObject* arg, TVMFFIAny* out
) except -1:
    # Convert `arg` into a TVM FFI Tensor using the DLPackExchangeAPI
    # struct stored on this setter, writing the result into `out`
    # (out.type_index = kTVMFFITensor, out.v_ptr = the new handle).
    #
    # Side effects on ctx:
    #   - publishes the API's managed-tensor allocator and ToPyObject
    #     converter (when present) for use later in the call
    #   - on the first non-CPU tensor seen (ctx.device_type == -1),
    #     records the producer's current work stream plus device type/id
    #   - pushes the temporary FFI object so it is released after the call
    #
    # Returns 0 on success; returns -1 (or raises BufferError) on failure.
    cdef DLManagedTensorVersioned* temp_managed_tensor
    cdef TVMFFIObjectHandle temp_chandle
    cdef void* current_stream = NULL
    cdef const DLPackExchangeAPI* api = <const DLPackExchangeAPI*>this.c_dlpack_exchange_api

    # Set allocator and ToPyObject converter in context if available.
    if api.managed_tensor_allocator != NULL:
        ctx.c_dlpack_tensor_allocator = api.managed_tensor_allocator
    if api.managed_tensor_to_py_object_no_sync != NULL:
        ctx.c_dlpack_to_pyobject = <DLPackToPyObject>api.managed_tensor_to_py_object_no_sync

    # Convert PyObject -> DLManagedTensorVersioned via the API's function
    # pointer; a nonzero status means the producer already set a Python error.
    if api.managed_tensor_from_py_object_no_sync(arg, &temp_managed_tensor) != 0:
        return -1

    # Query the current stream from the producer if the device is not CPU.
    if temp_managed_tensor.dl_tensor.device.device_type != kDLCPU:
        if ctx.device_type == -1 and api.current_work_stream != NULL:
            # First time seeing a device: query the stream and remember
            # which device this call is bound to.
            if api.current_work_stream(
                temp_managed_tensor.dl_tensor.device.device_type,
                temp_managed_tensor.dl_tensor.device.device_id,
                &current_stream
            ) == 0:
                ctx.stream = <TVMFFIStreamHandle>current_stream
            ctx.device_type = temp_managed_tensor.dl_tensor.device.device_type
            ctx.device_id = temp_managed_tensor.dl_tensor.device.device_id

    # Convert to a TVM Tensor; on success ownership of the managed tensor
    # transfers to the FFI object.
    # NOTE(review): on failure, ownership of temp_managed_tensor is unclear —
    # if TVMFFITensorFromDLPackVersioned does not invoke its deleter, this
    # path leaks the managed tensor; confirm against the C API contract.
    if TVMFFITensorFromDLPackVersioned(temp_managed_tensor, 0, 0, &temp_chandle) != 0:
        raise BufferError("Failed to convert DLManagedTensorVersioned to ffi.Tensor")

    out.type_index = kTVMFFITensor
    out.v_ptr = temp_chandle
    # Keep the new FFI object alive until the surrounding call completes.
    TVMFFIPyPushTempFFIObject(ctx, temp_chandle)
    return 0
185+
186+
145187cdef int TVMFFIPyArgSetterDLPackCExporter_(
146188 TVMFFIPyArgSetter* this, TVMFFIPyCallContext* ctx,
147189 PyObject* arg, TVMFFIAny* out
@@ -546,17 +588,13 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
546588 out.func = TVMFFIPyArgSetterObjectRValueRef_
547589 return 0
548590 if os.environ.get(" TVM_FFI_SKIP_c_dlpack_from_pyobject" , " 0" ) != " 1" :
549- # external tensors
550- if hasattr (arg, " __c_dlpack_from_pyobject__" ):
551- out.func = TVMFFIPyArgSetterDLPackCExporter_
552- temp_ptr = arg.__c_dlpack_from_pyobject__
553- out.c_dlpack_from_pyobject = < DLPackFromPyObject> temp_ptr
554- if hasattr (arg, " __c_dlpack_to_pyobject__" ):
555- temp_ptr = arg.__c_dlpack_to_pyobject__
556- out.c_dlpack_to_pyobject = < DLPackToPyObject> temp_ptr
557- if hasattr (arg, " __c_dlpack_tensor_allocator__" ):
558- temp_ptr = arg.__c_dlpack_tensor_allocator__
559- out.c_dlpack_tensor_allocator = < DLPackTensorAllocator> temp_ptr
591+ # Check for DLPackExchangeAPI struct (new approach)
592+ # This is checked on the CLASS, not the instance
593+ arg_class = type (arg)
594+ if hasattr (arg_class, " __c_dlpack_exchange_api__" ):
595+ out.func = TVMFFIPyArgSetterDLPackExchangeAPI_
596+ temp_ptr = arg_class.__c_dlpack_exchange_api__
597+ out.c_dlpack_exchange_api = < const void * > temp_ptr
560598 return 0
561599 if torch is not None and isinstance (arg, torch.Tensor):
562600 out.func = TVMFFIPyArgSetterTorchFallback_
0 commit comments