Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/tvm_ffi/dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@
from dataclasses import MISSING

from .c_class import c_class
from .field import Field, field
from .field import KW_ONLY, Field, field

__all__ = ["MISSING", "Field", "c_class", "field"]
__all__ = ["KW_ONLY", "MISSING", "Field", "c_class", "field"]
58 changes: 41 additions & 17 deletions python/tvm_ffi/dataclasses/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,13 @@ def _add_method(name: str, func: Callable[..., Any]) -> None:
return cast(Type[_InputClsType], new_cls)


def fill_dataclass_field(type_cls: type, type_field: TypeField) -> None:
def fill_dataclass_field(
type_cls: type,
type_field: TypeField,
*,
class_kw_only: bool = False,
kw_only_from_sentinel: bool = False,
) -> None:
from .field import Field, field # noqa: PLC0415

field_name = type_field.name
Expand All @@ -94,6 +100,14 @@ def fill_dataclass_field(type_cls: type, type_field: TypeField) -> None:
raise ValueError(f"Cannot recognize field: {type_field.name}: {rhs}")
assert isinstance(rhs, Field)
rhs.name = type_field.name

# Resolve kw_only: field-level > KW_ONLY sentinel > class-level
if rhs.kw_only is MISSING:
if kw_only_from_sentinel:
rhs.kw_only = True
else:
rhs.kw_only = class_kw_only

type_field.dataclass_field = rhs


Expand Down Expand Up @@ -148,47 +162,56 @@ def method_repr(type_cls: type, type_info: TypeInfo) -> Callable[..., str]:
return __repr__


def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
def method_init(_type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
"""Generate an ``__init__`` that forwards to the FFI constructor.

The generated initializer has a proper Python signature built from the
reflected field list, supporting default values and ``__post_init__``.
reflected field list, supporting default values, keyword-only args, and ``__post_init__``.
"""
# Step 0. Collect all fields from the type hierarchy
fields = _get_all_fields(type_info)
# sanity check
for type_method in type_info.methods:
if type_method.name == "__ffi_init__":
break
else:
if not any(m.name == "__ffi_init__" for m in type_info.methods):
raise ValueError(f"Cannot find constructor method: `{type_info.type_key}.__ffi_init__`")
# Step 1. Split args into sections and register default factories
args_no_defaults: list[str] = []
args_with_defaults: list[str] = []
pos_no_defaults: list[str] = []
pos_with_defaults: list[str] = []
kw_no_defaults: list[str] = []
kw_with_defaults: list[str] = []
fields_with_defaults: list[tuple[str, bool]] = []
ffi_arg_order: list[str] = []
exec_globals = {"MISSING": MISSING}
exec_globals: dict[str, Any] = {"MISSING": MISSING}

for field in fields:
assert field.name is not None
assert field.dataclass_field is not None
dataclass_field = field.dataclass_field
has_default_factory = (default_factory := dataclass_field.default_factory) is not MISSING
has_default = (default_factory := dataclass_field.default_factory) is not MISSING
is_kw_only = dataclass_field.kw_only is True

if dataclass_field.init:
ffi_arg_order.append(field.name)
if has_default_factory:
args_with_defaults.append(field.name)
if has_default:
(kw_with_defaults if is_kw_only else pos_with_defaults).append(field.name)
fields_with_defaults.append((field.name, True))
exec_globals[f"_default_factory_{field.name}"] = default_factory
else:
args_no_defaults.append(field.name)
elif has_default_factory:
(kw_no_defaults if is_kw_only else pos_no_defaults).append(field.name)
elif has_default:
ffi_arg_order.append(field.name)
fields_with_defaults.append((field.name, False))
exec_globals[f"_default_factory_{field.name}"] = default_factory

# Step 2. Build signature
args: list[str] = ["self"]
args.extend(args_no_defaults)
args.extend(f"{name}=MISSING" for name in args_with_defaults)
args.extend(pos_no_defaults)
args.extend(f"{name}=MISSING" for name in pos_with_defaults)
if kw_no_defaults or kw_with_defaults:
args.append("*")
args.extend(kw_no_defaults)
args.extend(f"{name}=MISSING" for name in kw_with_defaults)

# Step 3. Build body
body_lines: list[str] = []
for field_name, is_init in fields_with_defaults:
if is_init:
Expand All @@ -208,6 +231,7 @@ def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
" fn_post_init()",
]
)

source_lines = [f"def __init__({', '.join(args)}):"]
source_lines.extend(f" {line}" for line in body_lines)
source_lines.append(" ...")
Expand Down
47 changes: 39 additions & 8 deletions python/tvm_ffi/dataclasses/c_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@

from ..core import TypeField, TypeInfo, _lookup_or_register_type_info_from_type_key, _set_type_cls
from . import _utils
from .field import field
from .field import KW_ONLY, field

_InputClsType = TypeVar("_InputClsType")


@dataclass_transform(field_specifiers=(field,))
@dataclass_transform(field_specifiers=(field,), kw_only_default=False)
def c_class(
type_key: str, init: bool = True, repr: bool = True
type_key: str, init: bool = True, kw_only: bool = False, repr: bool = True
) -> Callable[[Type[_InputClsType]], Type[_InputClsType]]: # noqa: UP006
"""(Experimental) Create a dataclass-like proxy for a C++ class registered with TVM FFI.

Expand Down Expand Up @@ -71,6 +71,12 @@ def c_class(
signature. The generated initializer calls the C++ ``__init__``
function registered with ``ObjectDef`` and invokes ``__post_init__`` if
it exists on the Python class.

kw_only
If ``True``, all fields become keyword-only parameters in the generated
``__init__``. Individual fields can override this by setting
``kw_only=False`` in :func:`field`. Additionally, a ``KW_ONLY`` sentinel
annotation can be used to mark all subsequent fields as keyword-only.
repr
If ``True`` and the Python class does not define ``__repr__``, a
representation method is auto-generated that includes all fields with
Expand Down Expand Up @@ -129,9 +135,15 @@ def decorator(super_type_cls: Type[_InputClsType]) -> Type[_InputClsType]: # no
type_info: TypeInfo = _lookup_or_register_type_info_from_type_key(type_key)
assert type_info.parent_type_info is not None
# Step 2. Reflect all the fields of the type
type_info.fields = _inspect_c_class_fields(super_type_cls, type_info)
for type_field in type_info.fields:
_utils.fill_dataclass_field(super_type_cls, type_field)
type_info.fields, kw_only_start_idx = _inspect_c_class_fields(super_type_cls, type_info)
for idx, type_field in enumerate(type_info.fields):
kw_only_from_sentinel = kw_only_start_idx is not None and idx >= kw_only_start_idx
_utils.fill_dataclass_field(
super_type_cls,
type_field,
class_kw_only=kw_only,
kw_only_from_sentinel=kw_only_from_sentinel,
)
# Step 3. Create the proxy class with the fields as properties
fn_init = _utils.method_init(super_type_cls, type_info) if init else None
fn_repr = _utils.method_repr(super_type_cls, type_info) if repr else None
Expand All @@ -146,7 +158,9 @@ def decorator(super_type_cls: Type[_InputClsType]) -> Type[_InputClsType]: # no
return decorator


def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) -> list[TypeField]:
def _inspect_c_class_fields(
type_cls: type, type_info: TypeInfo
) -> tuple[list[TypeField], int | None]:
if sys.version_info >= (3, 9):
type_hints_resolved = get_type_hints(type_cls, include_extras=True)
else:
Expand All @@ -159,7 +173,24 @@ def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) -> list[TypeFie
ClassVar,
InitVar,
]
and type_hints_resolved[name] is not KW_ONLY
}

# Detect KW_ONLY sentinel position
kw_only_start_idx: int | None = None
field_count = 0
for name in getattr(type_cls, "__annotations__", {}).keys():
resolved_type = type_hints_resolved.get(name)
if resolved_type is None:
continue
if get_origin(resolved_type) in [ClassVar, InitVar]:
continue
if resolved_type is KW_ONLY:
if kw_only_start_idx is not None:
raise ValueError(f"KW_ONLY may only be used once per class: {type_cls}")
kw_only_start_idx = field_count
continue
field_count += 1
del type_hints_resolved

type_fields_cxx: dict[str, TypeField] = {f.name: f for f in type_info.fields}
Expand All @@ -178,4 +209,4 @@ def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) -> list[TypeFie
raise ValueError(
f"Missing fields in `{type_cls}`: {extra_fields}. Defined in C++ but not in Python"
)
return type_fields
return type_fields, kw_only_start_idx
28 changes: 25 additions & 3 deletions python/tvm_ffi/dataclasses/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@

from __future__ import annotations

from dataclasses import _MISSING_TYPE, MISSING
from dataclasses import _MISSING_TYPE, KW_ONLY, MISSING # type: ignore[attr-defined]
from typing import Any, Callable, TypeVar, cast

_FieldValue = TypeVar("_FieldValue")
_KW_ONLY_TYPE = type(KW_ONLY)


class Field:
Expand All @@ -37,7 +38,7 @@ class Field:
way the decorator understands.
"""

__slots__ = ("default_factory", "init", "name", "repr")
__slots__ = ("default_factory", "init", "kw_only", "name", "repr")

def __init__(
self,
Expand All @@ -46,12 +47,14 @@ def __init__(
default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING,
init: bool = True,
repr: bool = True,
kw_only: bool | _MISSING_TYPE = MISSING,
) -> None:
"""Do not call directly; use :func:`field` instead."""
self.name = name
self.default_factory = default_factory
self.init = init
self.repr = repr
self.kw_only = kw_only


def field(
Expand All @@ -60,6 +63,7 @@ def field(
default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING, # type: ignore[assignment]
init: bool = True,
repr: bool = True,
kw_only: bool | _MISSING_TYPE = MISSING, # type: ignore[assignment]
) -> _FieldValue:
"""(Experimental) Declare a dataclass-style field on a :func:`c_class` proxy.

Expand All @@ -84,6 +88,10 @@ def field(
repr
If ``True`` the field is included in the generated ``__repr__``.
If ``False`` the field is omitted from the ``__repr__`` output.
kw_only
If ``True``, the field is a keyword-only argument in ``__init__``.
If ``MISSING``, inherits from the class-level ``kw_only`` setting or
from a preceding ``KW_ONLY`` sentinel annotation.

Note
----
Expand Down Expand Up @@ -124,16 +132,30 @@ class PyBase:
obj = PyBase(v_i64=4)
obj.v_i32 # -> 16

Use ``kw_only=True`` to make a field keyword-only:

.. code-block:: python

@c_class("testing.TestCxxClassBase")
class PyBase:
v_i64: int
v_i32: int = field(kw_only=True)


obj = PyBase(4, v_i32=8) # v_i32 must be keyword

"""
if default is not MISSING and default_factory is not MISSING:
raise ValueError("Cannot specify both `default` and `default_factory`")
if not isinstance(init, bool):
raise TypeError("`init` must be a bool")
if not isinstance(repr, bool):
raise TypeError("`repr` must be a bool")
if kw_only is not MISSING and not isinstance(kw_only, bool):
raise TypeError(f"`kw_only` must be a bool, got {type(kw_only).__name__!r}")
if default is not MISSING:
default_factory = _make_default_factory(default)
ret = Field(default_factory=default_factory, init=init, repr=repr)
ret = Field(default_factory=default_factory, init=init, repr=repr, kw_only=kw_only)
return cast(_FieldValue, ret)


Expand Down
1 change: 1 addition & 0 deletions python/tvm_ffi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
_TestCxxClassDerived,
_TestCxxClassDerivedDerived,
_TestCxxInitSubset,
_TestCxxKwOnly,
add_one,
create_object,
make_unregistered_object,
Expand Down
8 changes: 8 additions & 0 deletions python/tvm_ffi/testing/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,11 @@ class _TestCxxInitSubset:
required_field: int
optional_field: int = field(init=False)
note: str = field(default_factory=lambda: "py-default", init=False)


@c_class("testing.TestCxxKwOnly", kw_only=True)
class _TestCxxKwOnly:
x: int
y: int
z: int
w: int = 100
20 changes: 20 additions & 0 deletions src/ffi/testing/testing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,19 @@ class TestCxxInitSubsetObj : public Object {
TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxInitSubset", TestCxxInitSubsetObj, Object);
};

class TestCxxKwOnly : public Object {
public:
int64_t x;
int64_t y;
int64_t z;
int64_t w;

TestCxxKwOnly(int64_t x, int64_t y, int64_t z, int64_t w) : x(x), y(y), z(z), w(w) {}

static constexpr bool _type_mutable = true;
TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxKwOnly", TestCxxKwOnly, Object);
};

class TestUnregisteredBaseObject : public Object {
public:
int64_t v1;
Expand Down Expand Up @@ -229,6 +242,13 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def_rw("optional_field", &TestCxxInitSubsetObj::optional_field)
.def_rw("note", &TestCxxInitSubsetObj::note);

refl::ObjectDef<TestCxxKwOnly>()
.def(refl::init<int64_t, int64_t, int64_t, int64_t>())
.def_rw("x", &TestCxxKwOnly::x)
.def_rw("y", &TestCxxKwOnly::y)
.def_rw("z", &TestCxxKwOnly::z)
.def_rw("w", &TestCxxKwOnly::w);

refl::ObjectDef<TestUnregisteredBaseObject>()
.def(refl::init<int64_t>(), "Constructor of TestUnregisteredBaseObject")
.def_ro("v1", &TestUnregisteredBaseObject::v1)
Expand Down
Loading
Loading