Skip to content

Commit 5fabff1

Browse files
committed
update
1 parent a0121da commit 5fabff1

File tree

1 file changed

+49
-11
lines changed

1 file changed

+49
-11
lines changed

python/tvm_ffi/cython/function.pxi

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,48 @@ cdef int TVMFFIPyArgSetterObject_(
142142
return 0
143143

144144

145+
cdef int TVMFFIPyArgSetterDLPackExchangeAPI_(
146+
TVMFFIPyArgSetter* this, TVMFFIPyCallContext* ctx,
147+
PyObject* arg, TVMFFIAny* out
148+
) except -1:
149+
cdef DLManagedTensorVersioned* temp_managed_tensor
150+
cdef TVMFFIObjectHandle temp_chandle
151+
cdef void* current_stream = NULL
152+
cdef const DLPackExchangeAPI* api = <const DLPackExchangeAPI*>this.c_dlpack_exchange_api
153+
154+
# Set allocator and ToPyObject converter in context if available
155+
if api.managed_tensor_allocator != NULL:
156+
ctx.c_dlpack_tensor_allocator = api.managed_tensor_allocator
157+
if api.managed_tensor_to_py_object_no_sync != NULL:
158+
ctx.c_dlpack_to_pyobject = <DLPackToPyObject>api.managed_tensor_to_py_object_no_sync
159+
160+
# Convert PyObject to DLPack using the struct's function pointer
161+
if api.managed_tensor_from_py_object_no_sync(arg, &temp_managed_tensor) != 0:
162+
return -1
163+
164+
# Query current stream from producer if device is not CPU
165+
if temp_managed_tensor.dl_tensor.device.device_type != kDLCPU:
166+
if ctx.device_type == -1 and api.current_work_stream != NULL:
167+
# First time seeing a device, query the stream
168+
if api.current_work_stream(
169+
temp_managed_tensor.dl_tensor.device.device_type,
170+
temp_managed_tensor.dl_tensor.device.device_id,
171+
&current_stream
172+
) == 0:
173+
ctx.stream = <TVMFFIStreamHandle>current_stream
174+
ctx.device_type = temp_managed_tensor.dl_tensor.device.device_type
175+
ctx.device_id = temp_managed_tensor.dl_tensor.device.device_id
176+
177+
# Convert to TVM Tensor
178+
if TVMFFITensorFromDLPackVersioned(temp_managed_tensor, 0, 0, &temp_chandle) != 0:
179+
raise BufferError("Failed to convert DLManagedTensorVersioned to ffi.Tensor")
180+
181+
out.type_index = kTVMFFITensor
182+
out.v_ptr = temp_chandle
183+
TVMFFIPyPushTempFFIObject(ctx, temp_chandle)
184+
return 0
185+
186+
145187
cdef int TVMFFIPyArgSetterDLPackCExporter_(
146188
TVMFFIPyArgSetter* this, TVMFFIPyCallContext* ctx,
147189
PyObject* arg, TVMFFIAny* out
@@ -546,17 +588,13 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
546588
out.func = TVMFFIPyArgSetterObjectRValueRef_
547589
return 0
548590
if os.environ.get("TVM_FFI_SKIP_c_dlpack_from_pyobject", "0") != "1":
549-
# external tensors
550-
if hasattr(arg, "__c_dlpack_from_pyobject__"):
551-
out.func = TVMFFIPyArgSetterDLPackCExporter_
552-
temp_ptr = arg.__c_dlpack_from_pyobject__
553-
out.c_dlpack_from_pyobject = <DLPackFromPyObject>temp_ptr
554-
if hasattr(arg, "__c_dlpack_to_pyobject__"):
555-
temp_ptr = arg.__c_dlpack_to_pyobject__
556-
out.c_dlpack_to_pyobject = <DLPackToPyObject>temp_ptr
557-
if hasattr(arg, "__c_dlpack_tensor_allocator__"):
558-
temp_ptr = arg.__c_dlpack_tensor_allocator__
559-
out.c_dlpack_tensor_allocator = <DLPackTensorAllocator>temp_ptr
591+
# Check for DLPackExchangeAPI struct (new approach)
592+
# This is checked on the CLASS, not the instance
593+
arg_class = type(arg)
594+
if hasattr(arg_class, "__c_dlpack_exchange_api__"):
595+
out.func = TVMFFIPyArgSetterDLPackExchangeAPI_
596+
temp_ptr = arg_class.__c_dlpack_exchange_api__
597+
out.c_dlpack_exchange_api = <const void*>temp_ptr
560598
return 0
561599
if torch is not None and isinstance(arg, torch.Tensor):
562600
out.func = TVMFFIPyArgSetterTorchFallback_

0 commit comments

Comments
 (0)