diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi index afadcd11..9ca9a181 100644 --- a/python/tvm_ffi/core.pyi +++ b/python/tvm_ffi/core.pyi @@ -45,6 +45,15 @@ class Object: def __ne__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... def __init_handle_by_constructor__(self, fconstructor: Function, *args: Any) -> None: ... + def __ffi_init__(self, *args: Any) -> None: + """Initialize the instance using the ` __init__` method registered on C++ side. + + Parameters + ---------- + args: list of objects + The arguments to the constructor + + """ def same_as(self, other: Any) -> bool: ... def _move(self) -> ObjectRValueRef: ... def __move_handle_from__(self, other: Object) -> None: ... @@ -240,6 +249,7 @@ class TypeField: frozen: bool getter: Any setter: Any + dataclass_field: Any | None def as_property(self, cls: type) -> property: ... diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi index 3d0e33ee..326a98bf 100644 --- a/python/tvm_ffi/cython/object.pxi +++ b/python/tvm_ffi/cython/object.pxi @@ -138,6 +138,16 @@ cdef class Object: (fconstructor).chandle, args, &chandle, NULL) self.chandle = chandle + def __ffi_init__(self, *args) -> None: + """Initialize the instance using the ` __init__` method registered on C++ side. + + Parameters + ---------- + args: list of objects + The arguments to the constructor + """ + self.__init_handle_by_constructor__(type(self).__c_ffi_init__, *args) + def same_as(self, other): """Check object identity. diff --git a/python/tvm_ffi/cython/type_info.pxi b/python/tvm_ffi/cython/type_info.pxi index bde25be2..a50a95fc 100644 --- a/python/tvm_ffi/cython/type_info.pxi +++ b/python/tvm_ffi/cython/type_info.pxi @@ -68,6 +68,7 @@ class TypeField: frozen: bool getter: FieldGetter setter: FieldSetter + dataclass_field: object | None = None def __post_init__(self): assert self.setter is not None diff --git a/python/tvm_ffi/dataclasses/__init__.py b/python/tvm_ffi/dataclasses/__init__.py new file mode 100644 index 00000000..31854130 --- /dev/null +++ b/python/tvm_ffi/dataclasses/__init__.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Experimental FFI interface that exposes C++ classes to Python in dataclass syntax.""" + +from dataclasses import MISSING + +from .c_class import c_class +from .field import Field, field + +__all__ = ["MISSING", "Field", "c_class", "field"] diff --git a/python/tvm_ffi/dataclasses/_utils.py b/python/tvm_ffi/dataclasses/_utils.py new file mode 100644 index 00000000..ef7c7e4c --- /dev/null +++ b/python/tvm_ffi/dataclasses/_utils.py @@ -0,0 +1,209 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Utilities for constructing Python proxies of FFI types.""" + +from __future__ import annotations + +import functools +import inspect +from dataclasses import MISSING +from typing import Any, Callable, NamedTuple, TypeVar + +from ..core import ( + Object, + TypeField, + TypeInfo, + _lookup_type_info_from_type_key, +) + +_InputClsType = TypeVar("_InputClsType") + + +def get_parent_type_info(type_cls: type) -> TypeInfo: + """Find the nearest ancestor with registered ``__tvm_ffi_type_info__``. + + If none are found, return the base ``ffi.Object`` type info. + """ + for base in type_cls.__bases__: + if (info := getattr(base, "__tvm_ffi_type_info__", None)) is not None: + return info + return _lookup_type_info_from_type_key("ffi.Object") + + +def type_info_to_cls( + type_info: TypeInfo, + cls: type[_InputClsType], + methods: dict[str, Callable[..., Any] | None], +) -> type[_InputClsType]: + assert type_info.type_cls is None, "Type class is already created" + # Step 1. Determine the base classes + cls_bases = cls.__bases__ + if cls_bases == (object,): + # If the class inherits from `object`, we need to set the base class to `Object` + cls_bases = (Object,) + + # Step 2. Define the new class attributes + attrs = dict(cls.__dict__) + attrs.pop("__dict__", None) + attrs.pop("__weakref__", None) + attrs["__slots__"] = () + attrs["__tvm_ffi_type_info__"] = type_info + + # Step 2. Add fields + for field in type_info.fields: + attrs[field.name] = field.as_property(cls) + + # Step 3. Add methods + def _add_method(name: str, func: Callable) -> None: + if name == "__ffi_init__": + name = "__c_ffi_init__" + if name in attrs: # already defined + return + func.__module__ = cls.__module__ + func.__name__ = name + func.__qualname__ = f"{cls.__qualname__}.{name}" + func.__doc__ = f"Method `{name}` of class `{cls.__qualname__}`" + attrs[name] = func + setattr(cls, name, func) + + for name, method in methods.items(): + if method is not None: + _add_method(name, method) + for method in type_info.methods: + _add_method(method.name, method.func) + + # Step 4. Create the new class + new_cls = type(cls.__name__, cls_bases, attrs) + new_cls.__module__ = cls.__module__ + new_cls = functools.wraps(cls, updated=())(new_cls) # type: ignore + return new_cls + + +def fill_dataclass_field(type_cls: type, type_field: TypeField) -> None: + from .field import Field, field # noqa: PLC0415 + + field_name = type_field.name + rhs: Any = getattr(type_cls, field_name, MISSING) + if rhs is MISSING: + rhs = field() + elif isinstance(rhs, Field): + pass + elif isinstance(rhs, (int, float, str, bool, type(None))): + rhs = field(default=rhs) + else: + raise ValueError(f"Cannot recognize field: {type_field.name}: {rhs}") + assert isinstance(rhs, Field) + rhs.name = type_field.name + type_field.dataclass_field = rhs + + +def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]: # noqa: PLR0915 + """Generate an ``__init__`` that forwards to the FFI constructor. + + The generated initializer has a proper Python signature built from the + reflected field list, supporting default values and ``__post_init__``. + """ + + class DefaultFactory(NamedTuple): + """Wrapper that marks a parameter as having a default factory.""" + + fn: Callable[[], Any] + + fields: list[TypeInfo] = [] + cur_type_info = type_info + while True: + fields.extend(reversed(cur_type_info.fields)) + cur_type_info = cur_type_info.parent_type_info + if cur_type_info is None: + break + fields.reverse() + del cur_type_info + + annotations: dict[str, Any] = {"return": None} + # Step 1. Split the parameters into two groups to ensure that + # those without defaults appear first in the signature. + params_without_defaults: list[inspect.Parameter] = [] + params_with_defaults: list[inspect.Parameter] = [] + ordering = [0] * len(fields) + for i, field in enumerate(fields): + assert field.name is not None + name: str = field.name + annotations[name] = Any # NOTE: We might be able to handle annotations better + assert field.dataclass_field is not None + default_factory = field.dataclass_field.default_factory + if default_factory is MISSING: + ordering[i] = len(params_without_defaults) + params_without_defaults.append( + inspect.Parameter(name=name, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD) + ) + else: + ordering[i] = -len(params_with_defaults) - 1 + params_with_defaults.append( + inspect.Parameter( + name=name, + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=DefaultFactory(fn=default_factory), + ) + ) + for i, order in enumerate(ordering): + if order < 0: + ordering[i] = len(params_without_defaults) - order - 1 + # Step 2. Create the signature object + sig = inspect.Signature(parameters=[*params_without_defaults, *params_with_defaults]) + signature_str = ( + f"{type_cls.__module__}.{type_cls.__qualname__}.__init__(" + + ", ".join(p.name for p in sig.parameters.values()) + + ")" + ) + + # Step 3. Create the `binding` method that reorders parameters + def touch_arg(x: Any) -> Any: + return x.fn() if isinstance(x, DefaultFactory) else x + + def bind_args(*args: Any, **kwargs: Any) -> tuple[Any, ...]: + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + args = bound.args + args = tuple(touch_arg(args[i]) for i in ordering) + return args + + for type_method in type_info.methods: + if type_method.name == "__ffi_init__": + break + else: + raise ValueError(f"Cannot find constructor method: `{type_info.type_key}.__ffi_init__`") + + def __init__(self: type, *args: Any, **kwargs: Any) -> None: + e = None + try: + args = bind_args(*args, **kwargs) + del kwargs + self.__ffi_init__(*args) + except Exception as _e: + e = TypeError(f"Error in `{signature_str}`: {_e}").with_traceback(_e.__traceback__) + if e is not None: + raise e + try: + fn_post_init = self.__post_init__ # type: ignore[attr-defined] + except AttributeError: + pass + else: + fn_post_init() + + __init__.__signature__ = sig # type: ignore[attr-defined] + __init__.__annotations__ = annotations + return __init__ diff --git a/python/tvm_ffi/dataclasses/c_class.py b/python/tvm_ffi/dataclasses/c_class.py new file mode 100644 index 00000000..7507b76c --- /dev/null +++ b/python/tvm_ffi/dataclasses/c_class.py @@ -0,0 +1,171 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Helpers for mirroring registered C++ FFI types with Python dataclass syntax. + +The :func:`c_class` decorator is the primary entry point. It inspects the +reflection metadata that the C++ runtime exposes via the TVM FFI registry and +turns it into Python ``dataclass``-style descriptors: annotated attributes become +properties that forward to the underlying C++ object, while an ``__init__`` +method is synthesized to call the FFI constructor when requested. +""" + +from collections.abc import Callable +from dataclasses import InitVar +from typing import ClassVar, TypeVar, get_origin, get_type_hints + +from ..core import TypeField, TypeInfo +from . import _utils, field + +try: + from typing import dataclass_transform +except ImportError: + from typing_extensions import dataclass_transform + + +_InputClsType = TypeVar("_InputClsType") + + +@dataclass_transform(field_specifiers=(field.field, field.Field)) +def c_class( + type_key: str, init: bool = True +) -> Callable[[type[_InputClsType]], type[_InputClsType]]: + """(Experimental) Create a dataclass-like proxy for a C++ class registered with TVM FFI. + + The decorator reads the reflection metadata that was registered on the C++ + side using ``tvm::ffi::reflection::ObjectDef`` and binds it to the annotated + attributes in the decorated Python class. Each field defined in C++ becomes + a property on the Python class, and optional default values can be provided + with :func:`tvm_ffi.dataclasses.field` in the same way as Python's native + ``dataclasses.field``. + + The intent is to offer a familiar dataclass authoring experience while still + exposing the underlying C++ object. The ``type_key`` of the C++ class must + match the string passed to :func:`c_class`, and inheritance relationships are + preserved—subclasses registered in C++ can subclass the Python proxy defined + for their parent. + + Parameters + ---------- + type_key : str + The reflection key that identifies the C++ type in the FFI registry, + e.g. ``"testing.MyClass"`` as registered in + ``src/ffi/extra/testing.cc``. + + init : bool, default True + If ``True`` and the Python class does not define ``__init__``, an + initializer is auto-generated that mirrors the reflected constructor + signature. The generated initializer calls the C++ ``__init__`` + function registered with ``ObjectDef`` and invokes ``__post_init__`` if + it exists on the Python class. + + Returns + ------- + Callable[[type], type] + A class decorator that materializes the final proxy class. + + Examples + -------- + Register the C++ type and its fields with TVM FFI: + + .. code-block:: c++ + + TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_static("__init__", [](int64_t v_i64, int32_t v_i32, + double v_f64, float v_f32) -> Any { + return ObjectRef(ffi::make_object( + v_i64, v_i32, v_f64, v_f32)); + }) + .def_rw("v_i64", &MyClass::v_i64) + .def_rw("v_i32", &MyClass::v_i32) + .def_rw("v_f64", &MyClass::v_f64) + .def_rw("v_f32", &MyClass::v_f32); + } + + Mirror the same structure in Python using dataclass-style annotations: + + .. code-block:: python + + from tvm_ffi.dataclasses import c_class, field + + @c_class("example.MyClass") + class MyClass: + v_i64: int + v_i32: int + v_f64: float = field(default=0.0) + v_f32: float = field(default_factory=lambda: 1.0) + + obj = MyClass(v_i64=4, v_i32=8) + obj.v_f64 = 3.14 # transparently forwards to the underlying C++ object + + """ + + def decorator(super_type_cls: type[_InputClsType]) -> type[_InputClsType]: + nonlocal init + init = init and "__init__" not in super_type_cls.__dict__ + # Step 1. Retrieve `type_info` from registry + type_info: TypeInfo = _utils._lookup_type_info_from_type_key(type_key) + assert type_info.parent_type_info is None, f"Already registered type: {type_key}" + type_info.parent_type_info = _utils.get_parent_type_info(super_type_cls) + # Step 2. Reflect all the fields of the type + type_info.fields = _inspect_c_class_fields(super_type_cls, type_info) + for type_field in type_info.fields: + _utils.fill_dataclass_field(super_type_cls, type_field) + # Step 3. Create the proxy class with the fields as properties + fn_init = _utils.method_init(super_type_cls, type_info) if init else None + type_cls: type[_InputClsType] = _utils.type_info_to_cls( + type_info=type_info, + cls=super_type_cls, + methods={"__init__": fn_init}, + ) + type_info.type_cls = type_cls + return type_cls + + return decorator + + +def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) -> list[TypeField]: + type_hints_resolved = get_type_hints(type_cls, include_extras=True) + type_hints_py = { + name: type_hints_resolved[name] + for name in getattr(type_cls, "__annotations__", {}).keys() + if get_origin(type_hints_resolved[name]) + not in [ # ignore non-field annotations + ClassVar, + InitVar, + ] + } + del type_hints_resolved + + type_fields_cxx: dict[str, TypeField] = {f.name: f for f in type_info.fields} + type_fields: list[TypeField] = [] + for field_name, _field_ty_py in type_hints_py.items(): + if field_name.startswith("__tvm_ffi"): # TVM's private fields - skip + continue + type_field: TypeField = type_fields_cxx.pop(field_name, None) + if type_field is None: + raise ValueError( + f"Extraneous field `{type_cls}.{field_name}`. Defined in Python but not in C++" + ) + type_fields.append(type_field) + if type_fields_cxx: + extra_fields = ", ".join(f"`{f.name}`" for f in type_fields_cxx.values()) + raise ValueError( + f"Missing fields in `{type_cls}`: {extra_fields}. Defined in C++ but not in Python" + ) + return type_fields diff --git a/python/tvm_ffi/dataclasses/field.py b/python/tvm_ffi/dataclasses/field.py new file mode 100644 index 00000000..00170e5e --- /dev/null +++ b/python/tvm_ffi/dataclasses/field.py @@ -0,0 +1,95 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Public helpers for describing dataclass-style defaults on FFI proxies.""" + +from __future__ import annotations + +from dataclasses import MISSING, dataclass +from typing import Any, Callable + + +@dataclass(kw_only=True) +class Field: + """(Experimental) Descriptor placeholder returned by :func:`tvm_ffi.dataclasses.field`. + + A ``Field`` mirrors the object returned by :func:`dataclasses.field`, but it + is understood by :func:`tvm_ffi.dataclasses.c_class`. The decorator inspects + the ``Field`` instances, records the ``default_factory`` and later replaces + the field with a property that forwards to the underlying C++ attribute. + + Users should not instantiate ``Field`` directly—use :func:`field` instead, + which guarantees that ``name`` and ``default_factory`` are populated in a + way the decorator understands. + """ + + name: str | None = None + default_factory: Callable[[], Any] + + +def field(*, default: Any = MISSING, default_factory: Any = MISSING) -> Field: + """(Experimental) Declare a dataclass-style field on a :func:`c_class` proxy. + + Use this helper exactly like :func:`dataclasses.field` when defining the + Python side of a C++ class. When :func:`c_class` processes the class body it + replaces the placeholder with a property and arranges for ``default`` or + ``default_factory`` to be respected by the synthesized ``__init__``. + + Parameters + ---------- + default : Any, optional + A literal default value that should populate the field when no argument + is given. The value is copied into a closure because TVM FFI does not + mutate the Python placeholder instance. + default_factory : Callable[[], Any], optional + A zero-argument callable that produces the default. This matches the + semantics of :func:`dataclasses.field` and is useful for mutable + defaults such as ``list`` or ``dict``. + + Returns + ------- + Field + A placeholder object that :func:`c_class` will consume during class + registration. + + Examples + -------- + ``field`` integrates with :func:`c_class` to express defaults the same way a + Python ``dataclass`` would:: + + @c_class("testing.TestCxxClassBase") + class PyBase: + v_i64: int + v_i32: int = field(default=16) + + obj = PyBase(v_i64=4) + obj.v_i32 # -> 16 + + """ + if default is not MISSING and default_factory is not MISSING: + raise ValueError("Cannot specify both `default` and `default_factory`") + if default is not MISSING: + default_factory = _make_default_factory(default) + return Field(default_factory=default_factory) + + +def _make_default_factory(value: Any) -> Callable[[], Any]: + """Make a default factory that returns the given value.""" + + def factory() -> Any: + return value + + return factory diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py index 5a540fb6..6bf08f64 100644 --- a/python/tvm_ffi/registry.py +++ b/python/tvm_ffi/registry.py @@ -248,6 +248,8 @@ def _add_class_attrs(type_cls: type, type_info: TypeInfo) -> type: setattr(type_cls, name, property(getter, setter, doc=doc)) for method in type_info.methods: name = method.name + if name == "__ffi_init__": + name = "__c_ffi_init__" doc = method.doc if method.doc else None method_func = method.func if method.is_static: diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py index 3215d8a0..825f9cfc 100644 --- a/python/tvm_ffi/testing.py +++ b/python/tvm_ffi/testing.py @@ -16,10 +16,11 @@ # under the License. """Testing utilities.""" -from typing import Any +from typing import Any, ClassVar from . import _ffi_api from .core import Object +from .dataclasses import c_class, field from .registry import register_object @@ -34,7 +35,7 @@ class TestIntPair(Object): def __init__(self, a: int, b: int) -> None: """Construct the object.""" - self.__init_handle_by_constructor__(TestIntPair.__ffi_init__, a, b) + self.__ffi_init__(a, b) @register_object("testing.TestObjectDerived") @@ -68,3 +69,26 @@ def create_object(type_key: str, **kwargs: Any) -> Object: args.append(k) args.append(v) return _ffi_api.MakeObjectFromPackedArgs(*args) + + +@c_class("testing.TestCxxClassBase") +class _TestCxxClassBase: + v_i64: int + v_i32: int + not_field_1 = 1 + not_field_2: ClassVar[int] = 2 + + def __init__(self, v_i64: int, v_i32: int) -> None: + self.__ffi_init__(v_i64 + 1, v_i32 + 2) + + +@c_class("testing.TestCxxClassDerived") +class _TestCxxClassDerived(_TestCxxClassBase): + v_f64: float + v_f32: float = 8 + + +@c_class("testing.TestCxxClassDerivedDerived") +class _TestCxxClassDerivedDerived(_TestCxxClassDerived): + v_str: str = field(default_factory=lambda: "default") + v_bool: bool diff --git a/src/ffi/extra/testing.cc b/src/ffi/extra/testing.cc index 9c3a0198..370a1c62 100644 --- a/src/ffi/extra/testing.cc +++ b/src/ffi/extra/testing.cc @@ -86,6 +86,41 @@ class TestObjectDerived : public TestObjectBase { TVM_FFI_DECLARE_OBJECT_INFO_FINAL("testing.TestObjectDerived", TestObjectDerived, TestObjectBase); }; +class TestCxxClassBase : public Object { + public: + int64_t v_i64; + int32_t v_i32; + + TestCxxClassBase(int64_t v_i64, int32_t v_i32) : v_i64(v_i64), v_i32(v_i32) {} + + static constexpr bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxClassBase", TestCxxClassBase, Object); +}; + +class TestCxxClassDerived : public TestCxxClassBase { + public: + double v_f64; + float v_f32; + + TestCxxClassDerived(int64_t v_i64, int32_t v_i32, double v_f64, float v_f32) + : TestCxxClassBase(v_i64, v_i32), v_f64(v_f64), v_f32(v_f32) {} + + TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxClassDerived", TestCxxClassDerived, TestCxxClassBase); +}; + +class TestCxxClassDerivedDerived : public TestCxxClassDerived { + public: + String v_str; + bool v_bool; + + TestCxxClassDerivedDerived(int64_t v_i64, int32_t v_i32, double v_f64, float v_f32, String v_str, + bool v_bool) + : TestCxxClassDerived(v_i64, v_i32, v_f64, v_f32), v_str(v_str), v_bool(v_bool) {} + + TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxClassDerivedDerived", TestCxxClassDerivedDerived, + TestCxxClassDerived); +}; + TVM_FFI_NO_INLINE void TestRaiseError(String kind, String msg) { // keep name and no liner for testing traceback throw ffi::Error(kind, msg, TVMFFITraceback(__FILE__, __LINE__, TVM_FFI_FUNC_SIG, 0)); @@ -110,6 +145,22 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def_ro("v_map", &TestObjectDerived::v_map) .def_ro("v_array", &TestObjectDerived::v_array); + refl::ObjectDef() + .def_static("__ffi_init__", refl::init) + .def_rw("v_i64", &TestCxxClassBase::v_i64) + .def_rw("v_i32", &TestCxxClassBase::v_i32); + + refl::ObjectDef() + .def_static("__ffi_init__", refl::init) + .def_rw("v_f64", &TestCxxClassDerived::v_f64) + .def_rw("v_f32", &TestCxxClassDerived::v_f32); + refl::ObjectDef() + .def_static( + "__ffi_init__", + refl::init) + .def_rw("v_str", &TestCxxClassDerivedDerived::v_str) + .def_rw("v_bool", &TestCxxClassDerivedDerived::v_bool); + refl::GlobalDef() .def("testing.test_raise_error", TestRaiseError) .def_packed("testing.nop", [](PackedArgs args, Any* ret) {}) diff --git a/tests/python/test_dataclasses_c_class.py b/tests/python/test_dataclasses_c_class.py new file mode 100644 index 00000000..a2fa80eb --- /dev/null +++ b/tests/python/test_dataclasses_c_class.py @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from tvm_ffi.testing import _TestCxxClassBase, _TestCxxClassDerived, _TestCxxClassDerivedDerived + + +def test_cxx_class_base() -> None: + obj = _TestCxxClassBase(v_i64=123, v_i32=456) + assert obj.v_i64 == 123 + 1 + assert obj.v_i32 == 456 + 2 + + +def test_cxx_class_derived() -> None: + obj = _TestCxxClassDerived(v_i64=123, v_i32=456, v_f64=4.00, v_f32=8.00) + assert obj.v_i64 == 123 + assert obj.v_i32 == 456 + assert obj.v_f64 == 4.00 + assert obj.v_f32 == 8.00 + + +def test_cxx_class_derived_default() -> None: + obj = _TestCxxClassDerived(v_i64=123, v_i32=456, v_f64=4.00) + assert obj.v_i64 == 123 + assert obj.v_i32 == 456 + assert obj.v_f64 == 4.00 + assert isinstance(obj.v_f32, float) and obj.v_f32 == 8.00 # default value + + +def test_cxx_class_derived_derived() -> None: + obj = _TestCxxClassDerivedDerived( + v_i64=123, + v_i32=456, + v_f64=4.00, + v_f32=8.00, + v_str="hello", + v_bool=True, + ) + assert obj.v_i64 == 123 + assert obj.v_i32 == 456 + assert obj.v_f64 == 4.00 + assert obj.v_f32 == 8.00 + assert obj.v_str == "hello" + assert obj.v_bool is True + + +def test_cxx_class_derived_derived_default() -> None: + obj = _TestCxxClassDerivedDerived(123, 456, 4, True) + assert obj.v_i64 == 123 + assert obj.v_i32 == 456 + assert isinstance(obj.v_f64, float) and obj.v_f64 == 4 + assert isinstance(obj.v_f32, float) and obj.v_f32 == 8 + assert obj.v_str == "default" + assert isinstance(obj.v_bool, bool) and obj.v_bool is True