diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi index d57d020e..afadcd11 100644 --- a/python/tvm_ffi/core.pyi +++ b/python/tvm_ffi/core.pyi @@ -74,9 +74,9 @@ class PyNativeObject: ) -> None: ... def _set_class_object(cls: type) -> None: ... -def _register_object_by_index(index: int, cls: type) -> None: ... +def _register_object_by_index(type_index: int, type_cls: type) -> TypeInfo: ... def _object_type_key_to_index(type_key: str) -> int | None: ... -def _add_class_attrs_by_reflection(type_index: int, cls: type) -> type: ... +def _lookup_type_info_from_type_key(type_key: str) -> TypeInfo: ... class Error(Object): """Base class for FFI errors.""" @@ -225,3 +225,38 @@ class Bytes(bytes, PyNativeObject): # pylint: disable=no-self-argument def __from_tvm_ffi_object__(cls, obj: Any) -> Bytes: ... + +# --------------------------------------------------------------------------- +# Type reflection metadata (from cython/type_info.pxi) +# --------------------------------------------------------------------------- + +class TypeField: + """Description of a single reflected field on an FFI-backed type.""" + + name: str + doc: str | None + size: int + offset: int + frozen: bool + getter: Any + setter: Any + + def as_property(self, cls: type) -> property: ... + +class TypeMethod: + """Description of a single reflected method on an FFI-backed type.""" + + name: str + doc: str | None + func: Any + is_static: bool + +class TypeInfo: + """Aggregated type information required to build a proxy class.""" + + type_cls: type | None + type_index: int + type_key: str + fields: list[TypeField] + methods: list[TypeMethod] + parent_type_info: TypeInfo | None diff --git a/python/tvm_ffi/cython/core.pyx b/python/tvm_ffi/cython/core.pyx index b24a83da..ca3a0ce0 100644 --- a/python/tvm_ffi/cython/core.pyx +++ b/python/tvm_ffi/cython/core.pyx @@ -14,9 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - - include "./base.pxi" +include "./type_info.pxi" include "./dtype.pxi" include "./device.pxi" include "./object.pxi" diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi index c4662ca1..f4503ecb 100644 --- a/python/tvm_ffi/cython/function.pxi +++ b/python/tvm_ffi/cython/function.pxi @@ -620,109 +620,6 @@ cdef class Function(Object): _register_object_by_index(kTVMFFIFunction, Function) -cdef class FieldGetter: - cdef TVMFFIFieldGetter getter - cdef int64_t offset - - def __call__(self, Object obj): - cdef TVMFFIAny result - cdef int c_api_ret_code - cdef void* field_ptr = ((obj).chandle) + self.offset - result.type_index = kTVMFFINone - result.v_int64 = 0 - c_api_ret_code = self.getter(field_ptr, &result) - CHECK_CALL(c_api_ret_code) - return make_ret(result) - - -cdef class FieldSetter: - cdef TVMFFIFieldSetter setter - cdef int64_t offset - - def __call__(self, Object obj, value): - cdef int c_api_ret_code - cdef void* field_ptr = ((obj).chandle) + self.offset - TVMFFIPyCallFieldSetter( - TVMFFIPyArgSetterFactory_, - self.setter, - field_ptr, - value, - &c_api_ret_code - ) - # NOTE: logic is same as check_call - # directly inline here to simplify traceback - if c_api_ret_code == 0: - return - elif c_api_ret_code == -2: - raise_existing_error() - raise move_from_last_error().py_error() - - -cdef _get_method_from_method_info(const TVMFFIMethodInfo* method): - cdef TVMFFIAny result - CHECK_CALL(TVMFFIAnyViewToOwnedAny(&(method.method), &result)) - return make_ret(result) - - -def _member_method_wrapper(method_func): - def wrapper(self, *args): - return method_func(self, *args) - return wrapper - - -def _add_class_attrs_by_reflection(int type_index, object cls): - """Decorate the class attrs by reflection""" - cdef const TVMFFITypeInfo* info = TVMFFIGetTypeInfo(type_index) - cdef const TVMFFIFieldInfo* field - cdef const TVMFFIMethodInfo* method - cdef int num_fields = info.num_fields - cdef int num_methods = info.num_methods - - for i in range(num_fields): - # attach fields to the class - field = &(info.fields[i]) - getter = FieldGetter.__new__(FieldGetter) - (getter).getter = field.getter - (getter).offset = field.offset - setter = FieldSetter.__new__(FieldSetter) - (setter).setter = field.setter - (setter).offset = field.offset - if (field.flags & kTVMFFIFieldFlagBitMaskWritable) == 0: - setter = None - doc = bytearray_to_str(&field.doc) if field.doc.size != 0 else None - name = bytearray_to_str(&field.name) - if hasattr(cls, name): - # skip already defined attributes - continue - setattr(cls, name, property(getter, setter, doc=doc)) - - for i in range(num_methods): - # attach methods to the class - method = &(info.methods[i]) - name = bytearray_to_str(&method.name) - doc = bytearray_to_str(&method.doc) if method.doc.size != 0 else None - method_func = _get_method_from_method_info(method) - - if method.flags & kTVMFFIFieldFlagBitMaskIsStaticMethod: - method_pyfunc = staticmethod(method_func) - else: - # must call into another method instead of direct capture - # to avoid the same method_func variable being used - # across multiple loop iterations - method_pyfunc = _member_method_wrapper(method_func) - - if doc is not None: - method_pyfunc.__doc__ = doc - method_pyfunc.__name__ = name - - if hasattr(cls, name): - # skip already defined attributes - continue - setattr(cls, name, method_pyfunc) - - return cls - - def _register_global_func(name, pyfunc, override): cdef TVMFFIObjectHandle chandle cdef int c_api_ret_code diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi index 08cfd546..3d0e33ee 100644 --- a/python/tvm_ffi/cython/object.pxi +++ b/python/tvm_ffi/cython/object.pxi @@ -227,18 +227,6 @@ class PyNativeObject: self.__tvm_ffi_object__ = obj -"""Maps object type index to its constructor""" -cdef list OBJECT_TYPE = [] - - -def _register_object_by_index(int index, object cls): - """register object class""" - global OBJECT_TYPE - while len(OBJECT_TYPE) <= index: - OBJECT_TYPE.append(None) - OBJECT_TYPE[index] = cls - - def _object_type_key_to_index(str type_key): """get the type index of object class""" cdef int32_t tidx @@ -265,13 +253,13 @@ cdef inline object make_ret_opaque_object(TVMFFIAny result): cdef inline object make_ret_object(TVMFFIAny result): - global OBJECT_TYPE + global TYPE_INDEX_TO_INFO cdef int32_t tindex cdef object cls tindex = result.type_index - if tindex < len(OBJECT_TYPE): - cls = OBJECT_TYPE[tindex] + if tindex < len(TYPE_INDEX_TO_INFO): + cls = TYPE_INDEX_TO_INFO[tindex].type_cls if cls is not None: if issubclass(cls, PyNativeObject): obj = Object.__new__(Object) @@ -290,4 +278,86 @@ cdef inline object make_ret_object(TVMFFIAny result): return obj +cdef _get_method_from_method_info(const TVMFFIMethodInfo* method): + cdef TVMFFIAny result + CHECK_CALL(TVMFFIAnyViewToOwnedAny(&(method.method), &result)) + return make_ret(result) + + +def _type_info_create_from_type_key(object type_cls, str type_key): + cdef const TVMFFIFieldInfo* field + cdef const TVMFFIMethodInfo* method + cdef const TVMFFITypeInfo* info + cdef int32_t type_index + cdef object fields = [] + cdef object methods = [] + cdef FieldGetter getter + cdef FieldSetter setter + + if TVMFFITypeKeyToIndex(ByteArrayArg(c_str(type_key)).cptr(), &type_index) != 0: + raise ValueError(f"Cannot find type key: {type_key}") + info = TVMFFIGetTypeInfo(type_index) + for i in range(info.num_fields): + field = &(info.fields[i]) + getter = FieldGetter.__new__(FieldGetter) + (getter).getter = field.getter + (getter).offset = field.offset + setter = FieldSetter.__new__(FieldSetter) + (setter).setter = field.setter + (setter).offset = field.offset + fields.append( + TypeField( + name=bytearray_to_str(&field.name), + doc=bytearray_to_str(&field.doc) if field.doc.size != 0 else None, + size=field.size, + offset=field.offset, + frozen=(field.flags & kTVMFFIFieldFlagBitMaskWritable) == 0, + getter=getter, + setter=setter, + ) + ) + + for i in range(info.num_methods): + method = &(info.methods[i]) + methods.append( + TypeMethod( + name=bytearray_to_str(&method.name), + doc=bytearray_to_str(&method.doc) if method.doc.size != 0 else None, + func=_get_method_from_method_info(method), + is_static=(method.flags & kTVMFFIFieldFlagBitMaskIsStaticMethod) != 0, + ) + ) + + return TypeInfo( + type_cls=type_cls, + type_index=type_index, + type_key=bytearray_to_str(&info.type_key), + fields=fields, + methods=methods, + parent_type_info=None, + ) + + +def _register_object_by_index(int type_index, object type_cls): + global TYPE_INDEX_TO_INFO, TYPE_KEY_TO_INFO + cdef str type_key = _type_index_to_key(type_index) + cdef object info = _type_info_create_from_type_key(type_cls, type_key) + if (extra := type_index + 1 - len(TYPE_INDEX_TO_INFO)) > 0: + TYPE_INDEX_TO_INFO.extend([None] * extra) + TYPE_INDEX_TO_INFO[type_index] = info + TYPE_KEY_TO_INFO[type_key] = info + return info + + +def _lookup_type_info_from_type_key(type_key: str) -> TypeInfo: + if info := TYPE_KEY_TO_INFO.get(type_key, None): + return info + info = _type_info_create_from_type_key(None, type_key) + TYPE_KEY_TO_INFO[type_key] = info + return info + + +cdef list TYPE_INDEX_TO_INFO = [] +cdef dict TYPE_KEY_TO_INFO = {} + _set_class_object(Object) diff --git a/python/tvm_ffi/cython/type_info.pxi b/python/tvm_ffi/cython/type_info.pxi new file mode 100644 index 00000000..2abb2040 --- /dev/null +++ b/python/tvm_ffi/cython/type_info.pxi @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import dataclasses + + +cdef class FieldGetter: + cdef dict __dict__ + cdef TVMFFIFieldGetter getter + cdef int64_t offset + + def __call__(self, Object obj): + cdef TVMFFIAny result + cdef int c_api_ret_code + cdef void* field_ptr = ((obj).chandle) + self.offset + result.type_index = kTVMFFINone + result.v_int64 = 0 + c_api_ret_code = self.getter(field_ptr, &result) + CHECK_CALL(c_api_ret_code) + return make_ret(result) + + +cdef class FieldSetter: + cdef dict __dict__ + cdef TVMFFIFieldSetter setter + cdef int64_t offset + + def __call__(self, Object obj, value): + cdef int c_api_ret_code + cdef void* field_ptr = ((obj).chandle) + self.offset + TVMFFIPyCallFieldSetter( + TVMFFIPyArgSetterFactory_, + self.setter, + field_ptr, + value, + &c_api_ret_code + ) + # NOTE: logic is same as check_call + # directly inline here to simplify traceback + if c_api_ret_code == 0: + return + elif c_api_ret_code == -2: + raise_existing_error() + raise move_from_last_error().py_error() + + +@dataclasses.dataclass(eq=False) +class TypeField: + """Description of a single reflected field on an FFI-backed type.""" + + name: str + doc: str | None + size: int + offset: int + frozen: bool + getter: FieldGetter + setter: FieldSetter + + def __post_init__(self): + assert self.setter is not None + assert self.getter is not None + + def as_property(self, cls: type) -> property: + """Create a Python ``property`` object for this field on ``cls``.""" + name = self.name + fget = self.getter + fset = self.setter + fget.__name__ = fset.__name__ = name + fget.__module__ = fset.__module__ = cls.__module__ + fget.__qualname__ = fset.__qualname__ = f"{cls.__qualname__}.{name}" # type: ignore[attr-defined] + fget.__doc__ = fset.__doc__ = f"Property `{name}` of class `{cls.__qualname__}`" # type: ignore[attr-defined] + + return property( + fget=fget if self.getter is not None else None, + fset=fset if (not self.frozen) and self.setter is not None else None, + doc=f"{cls.__module__}.{cls.__qualname__}.{name}", + ) + + +@dataclasses.dataclass(eq=False) +class TypeMethod: + """Description of a single reflected method on an FFI-backed type.""" + + name: str + doc: str | None + func: object + is_static: bool + + +@dataclasses.dataclass(eq=False) +class TypeInfo: + """Aggregated type information required to build a proxy class.""" + + type_cls: type | None + type_index: int + type_key: str + fields: list[TypeField] + methods: list[TypeMethod] + parent_type_info: TypeInfo | None diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py index 1f4e340c..5a540fb6 100644 --- a/python/tvm_ffi/registry.py +++ b/python/tvm_ffi/registry.py @@ -20,6 +20,7 @@ from typing import Any, Callable, Optional from . import core +from .core import TypeInfo # whether we simplify skip unknown objects regtistration _SKIP_UNKNOWN_OBJECTS = False @@ -54,8 +55,8 @@ def register(cls: type) -> type: if _SKIP_UNKNOWN_OBJECTS: return cls raise ValueError(f"Cannot find object type index for {object_name}") - core._add_class_attrs_by_reflection(type_index, cls) - core._register_object_by_index(type_index, cls) + info = core._register_object_by_index(type_index, cls) + _add_class_attrs(type_cls=cls, type_info=info) return cls if isinstance(type_key, str): @@ -228,6 +229,46 @@ def init_ffi_api(namespace: str, target_module_name: Optional[str] = None) -> No setattr(target_module, f.__name__, f) +def _member_method_wrapper(method_func: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(self: Any, *args: Any) -> Any: + return method_func(self, *args) + + return wrapper + + +def _add_class_attrs(type_cls: type, type_info: TypeInfo) -> type: + for field in type_info.fields: + getter = field.getter + setter = field.setter if not field.frozen else None + doc = field.doc if field.doc else None + name = field.name + if hasattr(type_cls, name): + # skip already defined attributes + continue + setattr(type_cls, name, property(getter, setter, doc=doc)) + for method in type_info.methods: + name = method.name + doc = method.doc if method.doc else None + method_func = method.func + if method.is_static: + method_pyfunc = staticmethod(method_func) + else: + # must call into another method instead of direct capture + # to avoid the same method_func variable being used + # across multiple loop iterations + method_pyfunc = _member_method_wrapper(method_func) + + if doc is not None: + method_pyfunc.__doc__ = doc + method_pyfunc.__name__ = name + + if hasattr(type_cls, name): + # skip already defined attributes + continue + setattr(type_cls, name, method_pyfunc) + return type_cls + + __all__ = [ "get_global_func", "init_ffi_api",