Skip to content

Commit

Permalink
add some more helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Jan 29, 2025
1 parent 2a6787f commit f0c02ac
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 16 deletions.
17 changes: 9 additions & 8 deletions projects/eudsl-llvmpy/eudsl-llvmpy-generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,14 +287,15 @@ class LLVMMatchType(Generic[_T]):
):
arg_types = []
ret_types = []
for p in intr.get_values().ParamTypes.value:
p_s = p.as_string
for p in intr.get_values().ParamTypes.get_value():
p_s = p.get_as_string()
if p_s.startswith("anon"):
p_s = p.type.as_string
p_s = p.get_type().get_as_string()
pdv = p.get_def().get_values()
if p_s == "LLVMMatchType":
p_s += f"[{p.def_.values.Number.value.value}]"
p_s += f"[{pdv.Number.get_value()}]"
elif p_s == "LLVMQualPointerType":
_, addr_space = p.def_.values.Sig.value.values
kind, addr_space = pdv.Sig.get_value()
p_s += f"[{addr_space}]"
else:
raise NotImplemented(f"unsupported {p_s=}")
Expand All @@ -307,8 +308,8 @@ class LLVMMatchType(Generic[_T]):
p_s = "pointer"

arg_types.append(p_s)
for p in intr.get_values().RetTypes.value:
ret_types.append(p.as_string)
for p in intr.get_values().RetTypes.get_value():
ret_types.append(p.get_as_string())

ret_str = ""
if len(ret_types):
Expand Down Expand Up @@ -383,5 +384,5 @@ def generate_nb_bindings(header_root: Path, output_root: Path):
parser.add_argument("llvmpy_module_dir", type=Path)
args = parser.parse_args()

generate_nb_bindings(args.llvm_include_root / "llvm-c", args.output_root)
# generate_nb_bindings(args.llvm_include_root / "llvm-c", args.output_root)
generate_amdgcn_intrinsics(args.llvm_include_root, args.llvmpy_module_dir)
52 changes: 51 additions & 1 deletion projects/eudsl-tblgen/src/eudsl_tblgen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Copyright (c) 2024.
from typing import List, Optional

from .eudsl_tblgen_ext import *

Expand Down Expand Up @@ -39,7 +40,56 @@ def get_requested_op_definitions(records, op_inc_filter=None, op_exc_filter=None
# Unless there is an exclude filter and it matches.
if op_exc_filter and exclude_regex.match(get_operation_name(def_record)):
continue
def_record.dump()
defs.append(def_record)

return defs


def collect_all_defs(
record_keeper: RecordKeeper,
selected_dialect: Optional[str] = None,
) -> List[AttrOrTypeDef]:
records = record_keeper.get_defs()
records = [records[d] for d in records]
# Nothing to do if no defs were found.
if not records:
return []

defs = [
AttrOrTypeDef(rec)
for rec in records
if rec.get_value("builders") and rec.get_value("parameters")
]
result_defs = []

if not selected_dialect:
# If a dialect was not specified, ensure that all found defs belong to the same dialect.
dialects = {definition.get_dialect().get_name() for definition in defs}
if len(dialects) > 1:
raise RuntimeError(
"Defs belong to more than one dialect. Must select one via '--(attr|type)defs-dialect'"
)
result_defs.extend(defs)
else:
# Otherwise, generate the defs that belong to the selected dialect.
dialect_defs = [
definition
for definition in defs
if definition.get_dialect().get_name() == selected_dialect
]
result_defs.extend(dialect_defs)

return result_defs


def get_all_type_constraints(records: RecordKeeper) -> List[Constraint]:
result = []
for record in records.get_all_derived_definitions_if_defined("TypeConstraint"):
# Ignore constraints defined outside of the top-level file.
constr = Constraint(record)
# Generate C++ function only if "cppFunctionName" is set.
if not constr.get_cpp_function_name():
continue
result.append(constr)
return result

15 changes: 12 additions & 3 deletions projects/eudsl-tblgen/src/eudsl_tblgen_ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,8 @@ NB_MODULE(eudsl_tblgen_ext, m) {
nb::rv_policy::reference_internal)
.def("get_name_init_as_string", &llvm::Record::getNameInitAsString)
.def("set_name", &llvm::Record::setName, "name"_a)
.def("get_loc", &llvm::Record::getLoc)
.def("get_loc", eudsl::coerceReturn<std::vector<llvm::SMLoc>>(
&llvm::Record::getLoc, nb::const_))
.def("append_loc", &llvm::Record::appendLoc, "loc"_a)
.def("get_forward_declaration_locs",
&llvm::Record::getForwardDeclarationLocs)
Expand Down Expand Up @@ -1088,8 +1089,8 @@ NB_MODULE(eudsl_tblgen_ext, m) {
const std::vector<std::string> &macroNames,
bool noWarnOnUnusedTemplateArgs) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename,
/*IsText=*/true);
llvm::MemoryBuffer::getFile(inputFilename,
/*IsText=*/true);
if (std::error_code EC = fileOrErr.getError())
throw std::runtime_error("Could not open input file '" +
inputFilename + "': " + EC.message() +
Expand Down Expand Up @@ -1145,6 +1146,13 @@ NB_MODULE(eudsl_tblgen_ext, m) {
-> std::vector<const llvm::Record *> {
return self.getAllDerivedDefinitions(className);
},
"class_name"_a, nb::rv_policy::reference_internal)
.def(
"get_all_derived_definitions_if_defined",
[](llvm::RecordKeeper &self, const std::string &className)
-> std::vector<const llvm::Record *> {
return self.getAllDerivedDefinitionsIfDefined(className);
},
"class_name"_a, nb::rv_policy::reference_internal);

nb::class_<llvm::raw_ostream>(m, "raw_ostream");
Expand Down Expand Up @@ -1239,6 +1247,7 @@ NB_MODULE(eudsl_tblgen_ext, m) {
.def("get_kind", &mlir::tblgen::Constraint::getKind)
.def("get_def", &mlir::tblgen::Constraint::getDef,
nb::rv_policy::reference_internal);

nb::enum_<mlir::tblgen::Constraint::Kind>(mlir_tblgen_Constraint, "Kind")
.value("CK_Attr", mlir::tblgen::Constraint::CK_Attr)
.value("CK_Region", mlir::tblgen::Constraint::CK_Region)
Expand Down
4 changes: 4 additions & 0 deletions projects/eudsl-tblgen/tests/td/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -918,4 +918,8 @@ def SignlessIntegerOrFloatLike : TypeConstraint<Or<[
SignlessIntegerLike.predicate, FloatLike.predicate]>,
"signless-integer-like or floating-point-like">;

def DummyConstraint : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
let cppFunctionName = "isValidDummy";
}

#endif // COMMON_TYPE_CONSTRAINTS_TD
24 changes: 20 additions & 4 deletions projects/eudsl-tblgen/tests/test_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
from pathlib import Path

import pytest
from eudsl_tblgen import RecordKeeper, get_requested_op_definitions
from eudsl_tblgen import (
RecordKeeper,
get_requested_op_definitions,
get_all_type_constraints,
collect_all_defs,
)


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -174,7 +179,7 @@ def test_init_complex(record_keeper_test_dialect):

assert (
repr(op.get_values())
== "RecordValues(opDialect=Test_Dialect, opName=types, cppNamespace=test, summary=, description=, opDocGroup=?, arguments=(ins I32:$a, SI64:$b, UI8:$c, Index:$d, F32:$e, NoneType:$f, anonymous_347), results=(outs), regions=(region), successors=(successor), builders=?, skipDefaultBuilders=0, assemblyFormat=?, hasCustomAssemblyFormat=0, hasVerifier=0, hasRegionVerifier=0, hasCanonicalizer=0, hasCanonicalizeMethod=0, hasFolder=0, useCustomPropertiesEncoding=0, traits=[], extraClassDeclaration=?, extraClassDefinition=?)"
== "RecordValues(opDialect=Test_Dialect, opName=types, cppNamespace=test, summary=, description=, opDocGroup=?, arguments=(ins I32:$a, SI64:$b, UI8:$c, Index:$d, F32:$e, NoneType:$f, anonymous_348), results=(outs), regions=(region), successors=(successor), builders=?, skipDefaultBuilders=0, assemblyFormat=?, hasCustomAssemblyFormat=0, hasVerifier=0, hasRegionVerifier=0, hasCanonicalizer=0, hasCanonicalizeMethod=0, hasFolder=0, useCustomPropertiesEncoding=0, traits=[], extraClassDeclaration=?, extraClassDefinition=?)"
)

arguments = op.get_values().arguments
Expand All @@ -193,7 +198,7 @@ def test_init_complex(record_keeper_test_dialect):
assert str(arguments.get_value()[5]) == "NoneType"

attr = record_keeper_test_dialect.get_defs()["Test_TestAttr"]
assert str(attr.get_values().predicate) == "anonymous_334"
assert str(attr.get_values().predicate) == "anonymous_335"
assert str(attr.get_values().storageType) == "test::TestAttr"
assert str(attr.get_values().returnType) == "test::TestAttr"
assert (
Expand Down Expand Up @@ -228,4 +233,15 @@ def test_init_complex(record_keeper_test_dialect):

def test_mlir_tblgen(record_keeper_test_dialect):
for op in get_requested_op_definitions(record_keeper_test_dialect):
op.dump()
print(op.get_name())
for constraint in get_all_type_constraints(record_keeper_test_dialect):
print(constraint.get_def_name())
print(constraint.get_summary())

all_defs = collect_all_defs(record_keeper_test_dialect)
for d in all_defs:
print(d.get_name())

all_defs = collect_all_defs(record_keeper_test_dialect, "test")
for d in all_defs:
print(d.get_name())

0 comments on commit f0c02ac

Please sign in to comment.