Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions python/tvm_ffi/core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ class PyNativeObject:
) -> None: ...

def _set_class_object(cls: type) -> None: ...
def _register_object_by_index(index: int, cls: type) -> None: ...
def _register_object_by_index(type_index: int, type_cls: type) -> TypeInfo: ...
def _object_type_key_to_index(type_key: str) -> int | None: ...
def _add_class_attrs_by_reflection(type_index: int, cls: type) -> type: ...
def _lookup_type_info_from_type_key(type_key: str) -> TypeInfo: ...

class Error(Object):
"""Base class for FFI errors."""
Expand Down Expand Up @@ -225,3 +225,38 @@ class Bytes(bytes, PyNativeObject):

# pylint: disable=no-self-argument
def __from_tvm_ffi_object__(cls, obj: Any) -> Bytes: ...

# ---------------------------------------------------------------------------
# Type reflection metadata (from cython/type_info.pxi)
# ---------------------------------------------------------------------------

class TypeField:
"""Description of a single reflected field on an FFI-backed type."""

name: str
doc: str | None
size: int
offset: int
frozen: bool
getter: Any
setter: Any

def as_property(self, cls: type) -> property: ...

class TypeMethod:
"""Description of a single reflected method on an FFI-backed type."""

name: str
doc: str | None
func: Any
is_static: bool

class TypeInfo:
"""Aggregated type information required to build a proxy class."""

type_cls: type | None
type_index: int
type_key: str
fields: list[TypeField]
methods: list[TypeMethod]
parent_type_info: TypeInfo | None
3 changes: 1 addition & 2 deletions python/tvm_ffi/cython/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


include "./base.pxi"
include "./type_info.pxi"
include "./dtype.pxi"
include "./device.pxi"
include "./object.pxi"
Expand Down
103 changes: 0 additions & 103 deletions python/tvm_ffi/cython/function.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -620,109 +620,6 @@ cdef class Function(Object):
_register_object_by_index(kTVMFFIFunction, Function)


cdef class FieldGetter:
cdef TVMFFIFieldGetter getter
cdef int64_t offset

def __call__(self, Object obj):
cdef TVMFFIAny result
cdef int c_api_ret_code
cdef void* field_ptr = (<char*>(<Object>obj).chandle) + self.offset
result.type_index = kTVMFFINone
result.v_int64 = 0
c_api_ret_code = self.getter(field_ptr, &result)
CHECK_CALL(c_api_ret_code)
return make_ret(result)


cdef class FieldSetter:
cdef TVMFFIFieldSetter setter
cdef int64_t offset

def __call__(self, Object obj, value):
cdef int c_api_ret_code
cdef void* field_ptr = (<char*>(<Object>obj).chandle) + self.offset
TVMFFIPyCallFieldSetter(
TVMFFIPyArgSetterFactory_,
self.setter,
field_ptr,
<PyObject*>value,
&c_api_ret_code
)
# NOTE: logic is same as check_call
# directly inline here to simplify traceback
if c_api_ret_code == 0:
return
elif c_api_ret_code == -2:
raise_existing_error()
raise move_from_last_error().py_error()


cdef _get_method_from_method_info(const TVMFFIMethodInfo* method):
cdef TVMFFIAny result
CHECK_CALL(TVMFFIAnyViewToOwnedAny(&(method.method), &result))
return make_ret(result)


def _member_method_wrapper(method_func):
def wrapper(self, *args):
return method_func(self, *args)
return wrapper


def _add_class_attrs_by_reflection(int type_index, object cls):
"""Decorate the class attrs by reflection"""
cdef const TVMFFITypeInfo* info = TVMFFIGetTypeInfo(type_index)
cdef const TVMFFIFieldInfo* field
cdef const TVMFFIMethodInfo* method
cdef int num_fields = info.num_fields
cdef int num_methods = info.num_methods

for i in range(num_fields):
# attach fields to the class
field = &(info.fields[i])
getter = FieldGetter.__new__(FieldGetter)
(<FieldGetter>getter).getter = field.getter
(<FieldGetter>getter).offset = field.offset
setter = FieldSetter.__new__(FieldSetter)
(<FieldSetter>setter).setter = field.setter
(<FieldSetter>setter).offset = field.offset
if (field.flags & kTVMFFIFieldFlagBitMaskWritable) == 0:
setter = None
doc = bytearray_to_str(&field.doc) if field.doc.size != 0 else None
name = bytearray_to_str(&field.name)
if hasattr(cls, name):
# skip already defined attributes
continue
setattr(cls, name, property(getter, setter, doc=doc))

for i in range(num_methods):
# attach methods to the class
method = &(info.methods[i])
name = bytearray_to_str(&method.name)
doc = bytearray_to_str(&method.doc) if method.doc.size != 0 else None
method_func = _get_method_from_method_info(method)

if method.flags & kTVMFFIFieldFlagBitMaskIsStaticMethod:
method_pyfunc = staticmethod(method_func)
else:
# must call into another method instead of direct capture
# to avoid the same method_func variable being used
# across multiple loop iterations
method_pyfunc = _member_method_wrapper(method_func)

if doc is not None:
method_pyfunc.__doc__ = doc
method_pyfunc.__name__ = name

if hasattr(cls, name):
# skip already defined attributes
continue
setattr(cls, name, method_pyfunc)

return cls


def _register_global_func(name, pyfunc, override):
cdef TVMFFIObjectHandle chandle
cdef int c_api_ret_code
Expand Down
100 changes: 85 additions & 15 deletions python/tvm_ffi/cython/object.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -227,18 +227,6 @@ class PyNativeObject:
self.__tvm_ffi_object__ = obj


"""Maps object type index to its constructor"""
cdef list OBJECT_TYPE = []


def _register_object_by_index(int index, object cls):
"""register object class"""
global OBJECT_TYPE
while len(OBJECT_TYPE) <= index:
OBJECT_TYPE.append(None)
OBJECT_TYPE[index] = cls


def _object_type_key_to_index(str type_key):
"""get the type index of object class"""
cdef int32_t tidx
Expand All @@ -265,13 +253,13 @@ cdef inline object make_ret_opaque_object(TVMFFIAny result):


cdef inline object make_ret_object(TVMFFIAny result):
global OBJECT_TYPE
global TYPE_INDEX_TO_INFO
cdef int32_t tindex
cdef object cls
tindex = result.type_index

if tindex < len(OBJECT_TYPE):
cls = OBJECT_TYPE[tindex]
if tindex < len(TYPE_INDEX_TO_INFO):
cls = TYPE_INDEX_TO_INFO[tindex].type_cls
if cls is not None:
if issubclass(cls, PyNativeObject):
obj = Object.__new__(Object)
Expand All @@ -290,4 +278,86 @@ cdef inline object make_ret_object(TVMFFIAny result):
return obj


cdef _get_method_from_method_info(const TVMFFIMethodInfo* method):
cdef TVMFFIAny result
CHECK_CALL(TVMFFIAnyViewToOwnedAny(&(method.method), &result))
return make_ret(result)


def _type_info_create_from_type_key(object type_cls, str type_key):
cdef const TVMFFIFieldInfo* field
cdef const TVMFFIMethodInfo* method
cdef const TVMFFITypeInfo* info
cdef int32_t type_index
cdef object fields = []
cdef object methods = []
cdef FieldGetter getter
cdef FieldSetter setter

if TVMFFITypeKeyToIndex(ByteArrayArg(c_str(type_key)).cptr(), &type_index) != 0:
raise ValueError(f"Cannot find type key: {type_key}")
info = TVMFFIGetTypeInfo(type_index)
for i in range(info.num_fields):
field = &(info.fields[i])
getter = FieldGetter.__new__(FieldGetter)
(<FieldGetter>getter).getter = field.getter
(<FieldGetter>getter).offset = field.offset
setter = FieldSetter.__new__(FieldSetter)
(<FieldSetter>setter).setter = field.setter
(<FieldSetter>setter).offset = field.offset
fields.append(
TypeField(
name=bytearray_to_str(&field.name),
doc=bytearray_to_str(&field.doc) if field.doc.size != 0 else None,
size=field.size,
offset=field.offset,
frozen=(field.flags & kTVMFFIFieldFlagBitMaskWritable) == 0,
getter=getter,
setter=setter,
)
)

for i in range(info.num_methods):
method = &(info.methods[i])
methods.append(
TypeMethod(
name=bytearray_to_str(&method.name),
doc=bytearray_to_str(&method.doc) if method.doc.size != 0 else None,
func=_get_method_from_method_info(method),
is_static=(method.flags & kTVMFFIFieldFlagBitMaskIsStaticMethod) != 0,
)
)

return TypeInfo(
type_cls=type_cls,
type_index=type_index,
type_key=bytearray_to_str(&info.type_key),
fields=fields,
methods=methods,
parent_type_info=None,
)


def _register_object_by_index(int type_index, object type_cls):
global TYPE_INDEX_TO_INFO, TYPE_KEY_TO_INFO
cdef str type_key = _type_index_to_key(type_index)
cdef object info = _type_info_create_from_type_key(type_cls, type_key)
if (extra := type_index + 1 - len(TYPE_INDEX_TO_INFO)) > 0:
TYPE_INDEX_TO_INFO.extend([None] * extra)
TYPE_INDEX_TO_INFO[type_index] = info
TYPE_KEY_TO_INFO[type_key] = info
return info


def _lookup_type_info_from_type_key(type_key: str) -> TypeInfo:
if info := TYPE_KEY_TO_INFO.get(type_key, None):
return info
info = _type_info_create_from_type_key(None, type_key)
TYPE_KEY_TO_INFO[type_key] = info
return info


cdef list TYPE_INDEX_TO_INFO = []
cdef dict TYPE_KEY_TO_INFO = {}

_set_class_object(Object)
Loading
Loading