Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/tvm_ffi/dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@
from dataclasses import MISSING

from .c_class import c_class
from .field import Field, field
from .field import KW_ONLY, Field, field

__all__ = ["MISSING", "Field", "c_class", "field"]
__all__ = ["KW_ONLY", "MISSING", "Field", "c_class", "field"]
132 changes: 89 additions & 43 deletions python/tvm_ffi/dataclasses/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,13 @@ def _add_method(name: str, func: Callable[..., Any]) -> None:
return cast(Type[_InputClsType], new_cls)


def fill_dataclass_field(type_cls: type, type_field: TypeField) -> None:
def fill_dataclass_field(
type_cls: type,
type_field: TypeField,
*,
class_kw_only: bool = False,
kw_only_from_sentinel: bool = False,
) -> None:
from .field import Field, field # noqa: PLC0415

field_name = type_field.name
Expand All @@ -95,55 +101,45 @@ def fill_dataclass_field(type_cls: type, type_field: TypeField) -> None:
raise ValueError(f"Cannot recognize field: {type_field.name}: {rhs}")
assert isinstance(rhs, Field)
rhs.name = type_field.name
type_field.dataclass_field = rhs

# Resolve kw_only: field-level > KW_ONLY sentinel > class-level
if rhs.kw_only is MISSING:
if kw_only_from_sentinel:
rhs.kw_only = True
else:
rhs.kw_only = class_kw_only

def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
"""Generate an ``__init__`` that forwards to the FFI constructor.
type_field.dataclass_field = rhs

The generated initializer has a proper Python signature built from the
reflected field list, supporting default values and ``__post_init__``.
"""
# Step 0. Collect all fields from the type hierarchy

def _collect_fields_from_hierarchy(type_info: TypeInfo) -> list[TypeField]:
fields: list[TypeField] = []
cur_type_info: TypeInfo | None = type_info
while cur_type_info is not None:
fields.extend(reversed(cur_type_info.fields))
cur_type_info = cur_type_info.parent_type_info
fields.reverse()
# sanity check
for type_method in type_info.methods:
if type_method.name == "__ffi_init__":
break
else:
raise ValueError(f"Cannot find constructor method: `{type_info.type_key}.__ffi_init__`")
# Step 1. Split args into sections and register default factories
args_no_defaults: list[str] = []
args_with_defaults: list[str] = []
fields_with_defaults: list[tuple[str, bool]] = []
ffi_arg_order: list[str] = []
exec_globals = {"MISSING": MISSING}
for field in fields:
assert field.name is not None
assert field.dataclass_field is not None
dataclass_field = field.dataclass_field
has_default_factory = (default_factory := dataclass_field.default_factory) is not MISSING
if dataclass_field.init:
ffi_arg_order.append(field.name)
if has_default_factory:
args_with_defaults.append(field.name)
fields_with_defaults.append((field.name, True))
exec_globals[f"_default_factory_{field.name}"] = default_factory
else:
args_no_defaults.append(field.name)
elif has_default_factory:
ffi_arg_order.append(field.name)
fields_with_defaults.append((field.name, False))
exec_globals[f"_default_factory_{field.name}"] = default_factory

return fields


def _build_init_source(
pos_no_defaults: list[str],
pos_with_defaults: list[str],
kw_no_defaults: list[str],
kw_with_defaults: list[str],
fields_with_defaults: list[tuple[str, bool]],
ffi_arg_order: list[str],
) -> str:
# Build signature
args: list[str] = ["self"]
args.extend(args_no_defaults)
args.extend(f"{name}=MISSING" for name in args_with_defaults)
args.extend(pos_no_defaults)
args.extend(f"{name}=MISSING" for name in pos_with_defaults)
if kw_no_defaults or kw_with_defaults:
args.append("*")
args.extend(kw_no_defaults)
args.extend(f"{name}=MISSING" for name in kw_with_defaults)

# Build body
body_lines: list[str] = []
for field_name, is_init in fields_with_defaults:
if is_init:
Expand All @@ -163,14 +159,64 @@ def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
" fn_post_init()",
]
)

source_lines = [f"def __init__({', '.join(args)}):"]
source_lines.extend(f" {line}" for line in body_lines)
source_lines.append(" ...")
source = "\n".join(source_lines)
return "\n".join(source_lines)


def method_init(_type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
"""Generate an ``__init__`` that forwards to the FFI constructor.

The generated initializer has a proper Python signature built from the
reflected field list, supporting default values, keyword-only args, and ``__post_init__``.
"""
fields = _collect_fields_from_hierarchy(type_info)
# sanity check
if not any(m.name == "__ffi_init__" for m in type_info.methods):
raise ValueError(f"Cannot find constructor method: `{type_info.type_key}.__ffi_init__`")

# Split args into sections and register default factories
pos_no_defaults: list[str] = []
pos_with_defaults: list[str] = []
kw_no_defaults: list[str] = []
kw_with_defaults: list[str] = []
fields_with_defaults: list[tuple[str, bool]] = []
ffi_arg_order: list[str] = []
exec_globals: dict[str, Any] = {"MISSING": MISSING}

for field in fields:
assert field.name is not None
assert field.dataclass_field is not None
dataclass_field = field.dataclass_field
has_default = (default_factory := dataclass_field.default_factory) is not MISSING
is_kw_only = dataclass_field.kw_only is True

if dataclass_field.init:
ffi_arg_order.append(field.name)
if has_default:
(kw_with_defaults if is_kw_only else pos_with_defaults).append(field.name)
fields_with_defaults.append((field.name, True))
exec_globals[f"_default_factory_{field.name}"] = default_factory
else:
(kw_no_defaults if is_kw_only else pos_no_defaults).append(field.name)
elif has_default:
ffi_arg_order.append(field.name)
fields_with_defaults.append((field.name, False))
exec_globals[f"_default_factory_{field.name}"] = default_factory

source = _build_init_source(
pos_no_defaults,
pos_with_defaults,
kw_no_defaults,
kw_with_defaults,
fields_with_defaults,
ffi_arg_order,
)
# Note: Code generation in this case is guaranteed to be safe,
# because the generated code does not contain any untrusted input.
# This is also a common practice used by `dataclasses` and `pydantic`.
namespace: dict[str, Any] = {}
exec(source, exec_globals, namespace)
__init__ = namespace["__init__"]
return __init__
return namespace["__init__"]
87 changes: 70 additions & 17 deletions python/tvm_ffi/dataclasses/c_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@

from ..core import TypeField, TypeInfo, _lookup_or_register_type_info_from_type_key, _set_type_cls
from . import _utils
from .field import field
from .field import _KW_ONLY_TYPE, field

_InputClsType = TypeVar("_InputClsType")


@dataclass_transform(field_specifiers=(field,))
@dataclass_transform(field_specifiers=(field,), kw_only_default=False)
def c_class(
type_key: str, init: bool = True
type_key: str, init: bool = True, kw_only: bool = False
) -> Callable[[Type[_InputClsType]], Type[_InputClsType]]: # noqa: UP006
"""(Experimental) Create a dataclass-like proxy for a C++ class registered with TVM FFI.

Expand Down Expand Up @@ -72,6 +72,12 @@ def c_class(
function registered with ``ObjectDef`` and invokes ``__post_init__`` if
it exists on the Python class.

kw_only
If ``True``, all fields become keyword-only parameters in the generated
``__init__``. Individual fields can override this by setting
``kw_only=False`` in :func:`field`. Additionally, a ``KW_ONLY`` sentinel
annotation can be used to mark all subsequent fields as keyword-only.

Returns
-------
Callable[[type], type]
Expand Down Expand Up @@ -115,6 +121,31 @@ class MyClass:
obj = MyClass(v_i64=4, v_i32=8)
obj.v_f64 = 3.14 # transparently forwards to the underlying C++ object

Use ``kw_only=True`` to make all fields keyword-only:

.. code-block:: python

@c_class("example.MyClass", kw_only=True)
class MyClass:
x: int
y: int


obj = MyClass(x=1, y=2) # all args must be keyword

Use ``KW_ONLY`` sentinel to make subsequent fields keyword-only:

.. code-block:: python

from tvm_ffi.dataclasses import KW_ONLY


@c_class("example.MyClass")
class MyClass:
x: int
_: KW_ONLY
y: int # keyword-only

"""

def decorator(super_type_cls: Type[_InputClsType]) -> Type[_InputClsType]: # noqa: UP006
Expand All @@ -124,9 +155,15 @@ def decorator(super_type_cls: Type[_InputClsType]) -> Type[_InputClsType]: # no
type_info: TypeInfo = _lookup_or_register_type_info_from_type_key(type_key)
assert type_info.parent_type_info is not None
# Step 2. Reflect all the fields of the type
type_info.fields = _inspect_c_class_fields(super_type_cls, type_info)
for type_field in type_info.fields:
_utils.fill_dataclass_field(super_type_cls, type_field)
type_info.fields, kw_only_start_idx = _inspect_c_class_fields(super_type_cls, type_info)
for idx, type_field in enumerate(type_info.fields):
kw_only_from_sentinel = kw_only_start_idx is not None and idx >= kw_only_start_idx
_utils.fill_dataclass_field(
super_type_cls,
type_field,
class_kw_only=kw_only,
kw_only_from_sentinel=kw_only_from_sentinel,
)
# Step 3. Create the proxy class with the fields as properties
fn_init = _utils.method_init(super_type_cls, type_info) if init else None
type_cls: Type[_InputClsType] = _utils.type_info_to_cls( # noqa: UP006
Expand All @@ -140,20 +177,36 @@ def decorator(super_type_cls: Type[_InputClsType]) -> Type[_InputClsType]: # no
return decorator


def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) -> list[TypeField]:
def _inspect_c_class_fields(
type_cls: type, type_info: TypeInfo
) -> tuple[list[TypeField], int | None]:
if sys.version_info >= (3, 9):
type_hints_resolved = get_type_hints(type_cls, include_extras=True)
else:
type_hints_resolved = get_type_hints(type_cls)
type_hints_py = {
name: type_hints_resolved[name]
for name in getattr(type_cls, "__annotations__", {}).keys()
if get_origin(type_hints_resolved[name])
not in [ # ignore non-field annotations
ClassVar,
InitVar,
]
}

# Filter out ClassVar, InitVar, and detect KW_ONLY
annotations = getattr(type_cls, "__annotations__", {})
type_hints_py: dict[str, type] = {}
kw_only_start_idx: int | None = None
field_count = 0

for name in annotations.keys():
resolved_type = type_hints_resolved.get(name)
if resolved_type is None:
continue
origin = get_origin(resolved_type)
# Skip ClassVar and InitVar
if origin in [ClassVar, InitVar]:
continue
# Detect KW_ONLY sentinel
if isinstance(resolved_type, _KW_ONLY_TYPE):
if kw_only_start_idx is not None:
raise ValueError(f"KW_ONLY may only be used once per class: {type_cls}")
kw_only_start_idx = field_count
continue
type_hints_py[name] = resolved_type
field_count += 1
del type_hints_resolved

type_fields_cxx: dict[str, TypeField] = {f.name: f for f in type_info.fields}
Expand All @@ -172,4 +225,4 @@ def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) -> list[TypeFie
raise ValueError(
f"Missing fields in `{type_cls}`: {extra_fields}. Defined in C++ but not in Python"
)
return type_fields
return type_fields, kw_only_start_idx
Loading