diff --git a/python/tvm_ffi/dataclasses/__init__.py b/python/tvm_ffi/dataclasses/__init__.py index 31854130..bfb44049 100644 --- a/python/tvm_ffi/dataclasses/__init__.py +++ b/python/tvm_ffi/dataclasses/__init__.py @@ -19,6 +19,6 @@ from dataclasses import MISSING from .c_class import c_class -from .field import Field, field +from .field import KW_ONLY, Field, field -__all__ = ["MISSING", "Field", "c_class", "field"] +__all__ = ["KW_ONLY", "MISSING", "Field", "c_class", "field"] diff --git a/python/tvm_ffi/dataclasses/_utils.py b/python/tvm_ffi/dataclasses/_utils.py index 5ed4e963..7c0afb4f 100644 --- a/python/tvm_ffi/dataclasses/_utils.py +++ b/python/tvm_ffi/dataclasses/_utils.py @@ -79,7 +79,13 @@ def _add_method(name: str, func: Callable[..., Any]) -> None: return cast(Type[_InputClsType], new_cls) -def fill_dataclass_field(type_cls: type, type_field: TypeField) -> None: +def fill_dataclass_field( + type_cls: type, + type_field: TypeField, + *, + class_kw_only: bool = False, + kw_only_from_sentinel: bool = False, +) -> None: from .field import Field, field # noqa: PLC0415 field_name = type_field.name @@ -94,6 +100,14 @@ def fill_dataclass_field(type_cls: type, type_field: TypeField) -> None: raise ValueError(f"Cannot recognize field: {type_field.name}: {rhs}") assert isinstance(rhs, Field) rhs.name = type_field.name + + # Resolve kw_only: field-level > KW_ONLY sentinel > class-level + if rhs.kw_only is MISSING: + if kw_only_from_sentinel: + rhs.kw_only = True + else: + rhs.kw_only = class_kw_only + type_field.dataclass_field = rhs @@ -148,47 +162,56 @@ def method_repr(type_cls: type, type_info: TypeInfo) -> Callable[..., str]: return __repr__ -def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]: +def method_init(_type_cls: type, type_info: TypeInfo) -> Callable[..., None]: """Generate an ``__init__`` that forwards to the FFI constructor. The generated initializer has a proper Python signature built from the - reflected field list, supporting default values and ``__post_init__``. + reflected field list, supporting default values, keyword-only args, and ``__post_init__``. """ # Step 0. Collect all fields from the type hierarchy fields = _get_all_fields(type_info) # sanity check - for type_method in type_info.methods: - if type_method.name == "__ffi_init__": - break - else: + if not any(m.name == "__ffi_init__" for m in type_info.methods): raise ValueError(f"Cannot find constructor method: `{type_info.type_key}.__ffi_init__`") # Step 1. Split args into sections and register default factories - args_no_defaults: list[str] = [] - args_with_defaults: list[str] = [] + pos_no_defaults: list[str] = [] + pos_with_defaults: list[str] = [] + kw_no_defaults: list[str] = [] + kw_with_defaults: list[str] = [] fields_with_defaults: list[tuple[str, bool]] = [] ffi_arg_order: list[str] = [] - exec_globals = {"MISSING": MISSING} + exec_globals: dict[str, Any] = {"MISSING": MISSING} + for field in fields: assert field.name is not None assert field.dataclass_field is not None dataclass_field = field.dataclass_field - has_default_factory = (default_factory := dataclass_field.default_factory) is not MISSING + has_default = (default_factory := dataclass_field.default_factory) is not MISSING + is_kw_only = dataclass_field.kw_only is True + if dataclass_field.init: ffi_arg_order.append(field.name) - if has_default_factory: - args_with_defaults.append(field.name) + if has_default: + (kw_with_defaults if is_kw_only else pos_with_defaults).append(field.name) fields_with_defaults.append((field.name, True)) exec_globals[f"_default_factory_{field.name}"] = default_factory else: - args_no_defaults.append(field.name) - elif has_default_factory: + (kw_no_defaults if is_kw_only else pos_no_defaults).append(field.name) + elif has_default: ffi_arg_order.append(field.name) fields_with_defaults.append((field.name, False)) exec_globals[f"_default_factory_{field.name}"] = default_factory + # Step 2. Build signature args: list[str] = ["self"] - args.extend(args_no_defaults) - args.extend(f"{name}=MISSING" for name in args_with_defaults) + args.extend(pos_no_defaults) + args.extend(f"{name}=MISSING" for name in pos_with_defaults) + if kw_no_defaults or kw_with_defaults: + args.append("*") + args.extend(kw_no_defaults) + args.extend(f"{name}=MISSING" for name in kw_with_defaults) + + # Step 3. Build body body_lines: list[str] = [] for field_name, is_init in fields_with_defaults: if is_init: @@ -208,6 +231,7 @@ def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]: " fn_post_init()", ] ) + source_lines = [f"def __init__({', '.join(args)}):"] source_lines.extend(f" {line}" for line in body_lines) source_lines.append(" ...") diff --git a/python/tvm_ffi/dataclasses/c_class.py b/python/tvm_ffi/dataclasses/c_class.py index 8171b1b5..8dd5e5ae 100644 --- a/python/tvm_ffi/dataclasses/c_class.py +++ b/python/tvm_ffi/dataclasses/c_class.py @@ -34,14 +34,14 @@ from ..core import TypeField, TypeInfo, _lookup_or_register_type_info_from_type_key, _set_type_cls from . import _utils -from .field import field +from .field import KW_ONLY, field _InputClsType = TypeVar("_InputClsType") -@dataclass_transform(field_specifiers=(field,)) +@dataclass_transform(field_specifiers=(field,), kw_only_default=False) def c_class( - type_key: str, init: bool = True, repr: bool = True + type_key: str, init: bool = True, kw_only: bool = False, repr: bool = True ) -> Callable[[Type[_InputClsType]], Type[_InputClsType]]: # noqa: UP006 """(Experimental) Create a dataclass-like proxy for a C++ class registered with TVM FFI. @@ -71,6 +71,12 @@ def c_class( signature. The generated initializer calls the C++ ``__init__`` function registered with ``ObjectDef`` and invokes ``__post_init__`` if it exists on the Python class. + + kw_only + If ``True``, all fields become keyword-only parameters in the generated + ``__init__``. Individual fields can override this by setting + ``kw_only=False`` in :func:`field`. Additionally, a ``KW_ONLY`` sentinel + annotation can be used to mark all subsequent fields as keyword-only. repr If ``True`` and the Python class does not define ``__repr__``, a representation method is auto-generated that includes all fields with @@ -129,9 +135,15 @@ def decorator(super_type_cls: Type[_InputClsType]) -> Type[_InputClsType]: # no type_info: TypeInfo = _lookup_or_register_type_info_from_type_key(type_key) assert type_info.parent_type_info is not None # Step 2. Reflect all the fields of the type - type_info.fields = _inspect_c_class_fields(super_type_cls, type_info) - for type_field in type_info.fields: - _utils.fill_dataclass_field(super_type_cls, type_field) + type_info.fields, kw_only_start_idx = _inspect_c_class_fields(super_type_cls, type_info) + for idx, type_field in enumerate(type_info.fields): + kw_only_from_sentinel = kw_only_start_idx is not None and idx >= kw_only_start_idx + _utils.fill_dataclass_field( + super_type_cls, + type_field, + class_kw_only=kw_only, + kw_only_from_sentinel=kw_only_from_sentinel, + ) # Step 3. Create the proxy class with the fields as properties fn_init = _utils.method_init(super_type_cls, type_info) if init else None fn_repr = _utils.method_repr(super_type_cls, type_info) if repr else None @@ -146,7 +158,9 @@ def decorator(super_type_cls: Type[_InputClsType]) -> Type[_InputClsType]: # no return decorator -def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) -> list[TypeField]: +def _inspect_c_class_fields( + type_cls: type, type_info: TypeInfo +) -> tuple[list[TypeField], int | None]: if sys.version_info >= (3, 9): type_hints_resolved = get_type_hints(type_cls, include_extras=True) else: @@ -159,7 +173,24 @@ def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) -> list[TypeFie ClassVar, InitVar, ] + and type_hints_resolved[name] is not KW_ONLY } + + # Detect KW_ONLY sentinel position + kw_only_start_idx: int | None = None + field_count = 0 + for name in getattr(type_cls, "__annotations__", {}).keys(): + resolved_type = type_hints_resolved.get(name) + if resolved_type is None: + continue + if get_origin(resolved_type) in [ClassVar, InitVar]: + continue + if resolved_type is KW_ONLY: + if kw_only_start_idx is not None: + raise ValueError(f"KW_ONLY may only be used once per class: {type_cls}") + kw_only_start_idx = field_count + continue + field_count += 1 del type_hints_resolved type_fields_cxx: dict[str, TypeField] = {f.name: f for f in type_info.fields} @@ -178,4 +209,4 @@ def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) -> list[TypeFie raise ValueError( f"Missing fields in `{type_cls}`: {extra_fields}. Defined in C++ but not in Python" ) - return type_fields + return type_fields, kw_only_start_idx diff --git a/python/tvm_ffi/dataclasses/field.py b/python/tvm_ffi/dataclasses/field.py index d0e27b1f..a395e501 100644 --- a/python/tvm_ffi/dataclasses/field.py +++ b/python/tvm_ffi/dataclasses/field.py @@ -21,7 +21,17 @@ from dataclasses import _MISSING_TYPE, MISSING from typing import Any, Callable, TypeVar, cast +try: + from dataclasses import KW_ONLY # type: ignore[attr-defined] +except ImportError: + # Python < 3.10: define our own KW_ONLY sentinel + class _KW_ONLY_Sentinel: + __slots__ = () + + KW_ONLY = _KW_ONLY_Sentinel() + _FieldValue = TypeVar("_FieldValue") +_KW_ONLY_TYPE = type(KW_ONLY) class Field: @@ -37,7 +47,7 @@ class Field: way the decorator understands. """ - __slots__ = ("default_factory", "init", "name", "repr") + __slots__ = ("default_factory", "init", "kw_only", "name", "repr") def __init__( self, @@ -46,12 +56,14 @@ def __init__( default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING, init: bool = True, repr: bool = True, + kw_only: bool | _MISSING_TYPE = MISSING, ) -> None: """Do not call directly; use :func:`field` instead.""" self.name = name self.default_factory = default_factory self.init = init self.repr = repr + self.kw_only = kw_only def field( @@ -60,6 +72,7 @@ def field( default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING, # type: ignore[assignment] init: bool = True, repr: bool = True, + kw_only: bool | _MISSING_TYPE = MISSING, # type: ignore[assignment] ) -> _FieldValue: """(Experimental) Declare a dataclass-style field on a :func:`c_class` proxy. @@ -84,6 +97,10 @@ def field( repr If ``True`` the field is included in the generated ``__repr__``. If ``False`` the field is omitted from the ``__repr__`` output. + kw_only + If ``True``, the field is a keyword-only argument in ``__init__``. + If ``MISSING``, inherits from the class-level ``kw_only`` setting or + from a preceding ``KW_ONLY`` sentinel annotation. Note ---- @@ -124,6 +141,18 @@ class PyBase: obj = PyBase(v_i64=4) obj.v_i32 # -> 16 + Use ``kw_only=True`` to make a field keyword-only: + + .. code-block:: python + + @c_class("testing.TestCxxClassBase") + class PyBase: + v_i64: int + v_i32: int = field(kw_only=True) + + + obj = PyBase(4, v_i32=8) # v_i32 must be keyword + """ if default is not MISSING and default_factory is not MISSING: raise ValueError("Cannot specify both `default` and `default_factory`") @@ -131,9 +160,11 @@ class PyBase: raise TypeError("`init` must be a bool") if not isinstance(repr, bool): raise TypeError("`repr` must be a bool") + if kw_only is not MISSING and not isinstance(kw_only, bool): + raise TypeError(f"`kw_only` must be a bool, got {type(kw_only).__name__!r}") if default is not MISSING: default_factory = _make_default_factory(default) - ret = Field(default_factory=default_factory, init=init, repr=repr) + ret = Field(default_factory=default_factory, init=init, repr=repr, kw_only=kw_only) return cast(_FieldValue, ret) diff --git a/python/tvm_ffi/testing/__init__.py b/python/tvm_ffi/testing/__init__.py index cd357364..af222103 100644 --- a/python/tvm_ffi/testing/__init__.py +++ b/python/tvm_ffi/testing/__init__.py @@ -25,6 +25,7 @@ _TestCxxClassDerived, _TestCxxClassDerivedDerived, _TestCxxInitSubset, + _TestCxxKwOnly, add_one, create_object, make_unregistered_object, diff --git a/python/tvm_ffi/testing/testing.py b/python/tvm_ffi/testing/testing.py index b905b5b1..0ffeb49a 100644 --- a/python/tvm_ffi/testing/testing.py +++ b/python/tvm_ffi/testing/testing.py @@ -178,3 +178,11 @@ class _TestCxxInitSubset: required_field: int optional_field: int = field(init=False) note: str = field(default_factory=lambda: "py-default", init=False) + + +@c_class("testing.TestCxxKwOnly", kw_only=True) +class _TestCxxKwOnly: + x: int + y: int + z: int + w: int = 100 diff --git a/src/ffi/testing/testing.cc b/src/ffi/testing/testing.cc index 7ee6ffd7..0df7f1ea 100644 --- a/src/ffi/testing/testing.cc +++ b/src/ffi/testing/testing.cc @@ -147,6 +147,19 @@ class TestCxxInitSubsetObj : public Object { TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxInitSubset", TestCxxInitSubsetObj, Object); }; +class TestCxxKwOnly : public Object { + public: + int64_t x; + int64_t y; + int64_t z; + int64_t w; + + TestCxxKwOnly(int64_t x, int64_t y, int64_t z, int64_t w) : x(x), y(y), z(z), w(w) {} + + static constexpr bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxKwOnly", TestCxxKwOnly, Object); +}; + class TestUnregisteredBaseObject : public Object { public: int64_t v1; @@ -229,6 +242,13 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def_rw("optional_field", &TestCxxInitSubsetObj::optional_field) .def_rw("note", &TestCxxInitSubsetObj::note); + refl::ObjectDef() + .def(refl::init()) + .def_rw("x", &TestCxxKwOnly::x) + .def_rw("y", &TestCxxKwOnly::y) + .def_rw("z", &TestCxxKwOnly::z) + .def_rw("w", &TestCxxKwOnly::w); + refl::ObjectDef() .def(refl::init(), "Constructor of TestUnregisteredBaseObject") .def_ro("v1", &TestUnregisteredBaseObject::v1) diff --git a/tests/python/test_dataclasses_c_class.py b/tests/python/test_dataclasses_c_class.py index 5361cb64..3a757d08 100644 --- a/tests/python/test_dataclasses_c_class.py +++ b/tests/python/test_dataclasses_c_class.py @@ -15,12 +15,17 @@ # specific language governing permissions and limitations # under the License. import inspect +from dataclasses import MISSING +import pytest +from tvm_ffi.dataclasses import KW_ONLY, field +from tvm_ffi.dataclasses.field import _KW_ONLY_TYPE, Field from tvm_ffi.testing import ( _TestCxxClassBase, _TestCxxClassDerived, _TestCxxClassDerivedDerived, _TestCxxInitSubset, + _TestCxxKwOnly, ) @@ -129,3 +134,53 @@ def test_cxx_class_repr_derived_derived() -> None: assert "v_i32=456" in repr_str assert "v_str='hello'" in repr_str or 'v_str="hello"' in repr_str assert "v_bool=True" in repr_str + + +def test_kw_only_class_level_signature() -> None: + sig = inspect.signature(_TestCxxKwOnly.__init__) + params = sig.parameters + assert params["x"].kind == inspect.Parameter.KEYWORD_ONLY + assert params["y"].kind == inspect.Parameter.KEYWORD_ONLY + assert params["z"].kind == inspect.Parameter.KEYWORD_ONLY + assert params["w"].kind == inspect.Parameter.KEYWORD_ONLY + + +def test_kw_only_class_level_call() -> None: + obj = _TestCxxKwOnly(x=1, y=2, z=3, w=4) + assert obj.x == 1 + assert obj.y == 2 + assert obj.z == 3 + assert obj.w == 4 + + +def test_kw_only_class_level_with_default() -> None: + obj = _TestCxxKwOnly(x=1, y=2, z=3) + assert obj.w == 100 + + +def test_kw_only_class_level_rejects_positional() -> None: + with pytest.raises(TypeError, match="positional"): + _TestCxxKwOnly(1, 2, 3, 4) # type: ignore[misc] + + +def test_field_kw_only_parameter() -> None: + f1: Field = field(kw_only=True) + assert isinstance(f1, Field) + assert f1.kw_only is True + + f2: Field = field(kw_only=False) + assert f2.kw_only is False + + f3: Field = field() + assert f3.kw_only is MISSING + + +def test_field_kw_only_with_default() -> None: + f = field(default=42, kw_only=True) + assert isinstance(f, Field) + assert f.kw_only is True + assert f.default_factory() == 42 + + +def test_kw_only_sentinel_exists() -> None: + assert isinstance(KW_ONLY, _KW_ONLY_TYPE)