Skip to content

Commit

Permalink
Define internal macros (#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
raviqqe committed May 8, 2023
1 parent 4a2a450 commit b3ab31f
Show file tree
Hide file tree
Showing 13 changed files with 215 additions and 47 deletions.
11 changes: 10 additions & 1 deletion .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,36 @@
"words": [
"addf",
"addi",
"amdgpu",
"bfloat",
"bufferization",
"canonicalize",
"canonicalizer",
"divf",
"divi",
"femtomc",
"funcs",
"hasher",
"indoc",
"insta",
"interp",
"libm",
"linalg",
"melior",
"memref",
"mlir",
"mulf",
"muli",
"nvgpu",
"nvvm",
"rocdl",
"rustc",
"sccp",
"spirv",
"stdc",
"subf",
"subi"
"subi",
"tosa",
"vulkan"
]
}
28 changes: 28 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[workspace]
members = ["melior"]
members = ["macro", "melior"]

[profile.release]
lto = true
23 changes: 23 additions & 0 deletions macro/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
[package]
name = "melior-macro"
description = "Internal macros for Melior"
version = "0.1.0"
edition = "2021"
license = "Apache-2.0"
repository = "https://github.com/raviqqe/melior"
documentation = "https://raviqqe.github.io/melior/melior/"
readme = "../README.md"
keywords = ["mlir", "llvm"]

[lib]
proc-macro = true

[dependencies]
convert_case = "0.6.0"
proc-macro2 = "1"
quote = "1"
syn = { version = "2", features = ["full"] }

[dev-dependencies]
melior = { path = "../melior" }
mlir-sys = "0.2"
33 changes: 33 additions & 0 deletions macro/src/conversion_passes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use convert_case::{Case, Casing};
use proc_macro::TokenStream;
use proc_macro2::Ident;
use quote::quote;
use std::error::Error;

pub fn generate(identifiers: &[Ident]) -> Result<TokenStream, Box<dyn Error>> {
let mut stream = TokenStream::new();

for identifier in identifiers {
let mut name = identifier.to_string();

if let Some(other) = name.strip_prefix("mlirCreateConversion") {
name = other.into();
}

if let Some(other) = name.strip_suffix("ConversionPass") {
name = other.into();
}

let function_name = Ident::new(&name.to_case(Case::Snake), identifier.span());
let document = format!(" Creates a pass of `{}`.", name);

stream.extend(TokenStream::from(quote! {
#[doc = #document]
pub fn #function_name() -> crate::pass::Pass {
crate::pass::Pass::__private_from_raw_fn(mlir_sys::#identifier)
}
}));
}

Ok(stream)
}
23 changes: 23 additions & 0 deletions macro/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
mod conversion_passes;
mod parse;

use parse::IdentifierList;
use proc_macro::TokenStream;
use quote::quote;
use std::error::Error;
use syn::parse_macro_input;

#[proc_macro]
pub fn conversion_passes(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);

convert_result(conversion_passes::generate(identifiers.identifiers()))
}

fn convert_result(result: Result<TokenStream, Box<dyn Error>>) -> TokenStream {
result.unwrap_or_else(|error| {
let message = error.to_string();

quote! { compile_error!(#message) }.into()
})
}
26 changes: 26 additions & 0 deletions macro/src/parse.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use proc_macro2::Ident;
use syn::{
parse::{Parse, ParseStream},
punctuated::Punctuated,
Result, Token,
};

pub struct IdentifierList {
identifiers: Vec<Ident>,
}

impl IdentifierList {
pub fn identifiers(&self) -> &[Ident] {
&self.identifiers
}
}

impl Parse for IdentifierList {
fn parse(input: ParseStream) -> Result<Self> {
Ok(Self {
identifiers: Punctuated::<Ident, Token![,]>::parse_terminated(input)?
.into_iter()
.collect(),
})
}
}
1 change: 1 addition & 0 deletions melior/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ readme = "../README.md"
keywords = ["mlir", "llvm"]

[dependencies]
melior-macro = { version = "0.1", path = "../macro" }
mlir-sys = "0.2"
once_cell = "1"

Expand Down
2 changes: 1 addition & 1 deletion melior/src/execution_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ mod tests {

pass_manager
.nested_under("func.func")
.add_pass(pass::conversion::convert_arithmetic_to_llvm());
.add_pass(pass::conversion::arith_to_llvm());

assert_eq!(pass_manager.run(&mut module), Ok(()));

Expand Down
4 changes: 4 additions & 0 deletions melior/src/pass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,8 @@ impl Pass {
pub(crate) unsafe fn to_raw(&self) -> MlirPass {
self.raw
}

pub fn __private_from_raw_fn(create_raw: unsafe extern "C" fn() -> MlirPass) -> Self {
Self::from_raw_fn(create_raw)
}
}
100 changes: 57 additions & 43 deletions melior/src/pass/conversion.rs
Original file line number Diff line number Diff line change
@@ -1,46 +1,60 @@
//! Dialect conversion passes.

use super::Pass;
use mlir_sys::{
mlirCreateConversionArithToLLVMConversionPass, mlirCreateConversionConvertControlFlowToLLVM,
mlirCreateConversionConvertControlFlowToSPIRV, mlirCreateConversionConvertFuncToLLVM,
mlirCreateConversionConvertMathToLLVM, mlirCreateConversionConvertMathToLibm,
melior_macro::conversion_passes!(
mlirCreateConversionArithToLLVMConversionPass,
mlirCreateConversionConvertAffineForToGPU,
mlirCreateConversionConvertAffineToStandard,
mlirCreateConversionConvertAMDGPUToROCDL,
mlirCreateConversionConvertArithToSPIRV,
mlirCreateConversionConvertArmNeon2dToIntr,
mlirCreateConversionConvertAsyncToLLVM,
mlirCreateConversionConvertBufferizationToMemRef,
mlirCreateConversionConvertComplexToLibm,
mlirCreateConversionConvertComplexToLLVM,
mlirCreateConversionConvertComplexToStandard,
mlirCreateConversionConvertControlFlowToLLVM,
mlirCreateConversionConvertControlFlowToSPIRV,
mlirCreateConversionConvertFuncToLLVM,
mlirCreateConversionConvertFuncToSPIRV,
mlirCreateConversionConvertGpuLaunchFuncToVulkanLaunchFunc,
mlirCreateConversionConvertGpuOpsToNVVMOps,
mlirCreateConversionConvertGpuOpsToROCDLOps,
mlirCreateConversionConvertGPUToSPIRV,
mlirCreateConversionConvertIndexToLLVMPass,
mlirCreateConversionConvertLinalgToLLVM,
mlirCreateConversionConvertLinalgToStandard,
mlirCreateConversionConvertMathToFuncs,
mlirCreateConversionConvertMathToLibm,
mlirCreateConversionConvertMathToLLVM,
mlirCreateConversionConvertMathToSPIRV,
};

// TODO Unify a naming convention.

/// Creates a pass to convert the `arith` dialect to the `llvm` dialect.
pub fn convert_arithmetic_to_llvm() -> Pass {
Pass::from_raw_fn(mlirCreateConversionArithToLLVMConversionPass)
}

/// Creates a pass to convert the `cf` dialect to the `llvm` dialect.
pub fn convert_scf_to_llvm() -> Pass {
Pass::from_raw_fn(mlirCreateConversionConvertControlFlowToLLVM)
}

/// Creates a pass to convert the `func` dialect to the `llvm` dialect.
pub fn convert_func_to_llvm() -> Pass {
Pass::from_raw_fn(mlirCreateConversionConvertFuncToLLVM)
}

/// Creates a pass to convert the `math` dialect to the `llvm` dialect.
pub fn convert_math_to_llvm() -> Pass {
Pass::from_raw_fn(mlirCreateConversionConvertMathToLLVM)
}

/// Creates a pass to convert the `cf` dialect to the `spirv` dialect.
pub fn convert_scf_to_spirv() -> Pass {
Pass::from_raw_fn(mlirCreateConversionConvertControlFlowToSPIRV)
}

/// Creates a pass to convert the `math` dialect to the `spirv` dialect.
pub fn convert_math_to_spirv() -> Pass {
Pass::from_raw_fn(mlirCreateConversionConvertMathToSPIRV)
}

/// Creates a pass to convert the `math` dialect to the `libm` dialect.
pub fn convert_math_to_libm() -> Pass {
Pass::from_raw_fn(mlirCreateConversionConvertMathToLibm)
}
mlirCreateConversionConvertMemRefToSPIRV,
mlirCreateConversionConvertNVGPUToNVVM,
mlirCreateConversionConvertOpenACCToLLVM,
mlirCreateConversionConvertOpenACCToSCF,
mlirCreateConversionConvertOpenMPToLLVM,
mlirCreateConversionConvertParallelLoopToGpu,
mlirCreateConversionConvertPDLToPDLInterp,
mlirCreateConversionConvertSCFToOpenMP,
mlirCreateConversionConvertShapeConstraints,
mlirCreateConversionConvertShapeToStandard,
mlirCreateConversionConvertSPIRVToLLVM,
mlirCreateConversionConvertTensorToLinalg,
mlirCreateConversionConvertTensorToSPIRV,
mlirCreateConversionConvertVectorToGPU,
mlirCreateConversionConvertVectorToLLVM,
mlirCreateConversionConvertVectorToSCF,
mlirCreateConversionConvertVectorToSPIRV,
mlirCreateConversionConvertVulkanLaunchFuncToVulkanCalls,
mlirCreateConversionGpuToLLVMConversionPass,
mlirCreateConversionLowerHostCodeToLLVM,
mlirCreateConversionMapMemRefStorageClass,
mlirCreateConversionMemRefToLLVMConversionPass,
mlirCreateConversionReconcileUnrealizedCasts,
mlirCreateConversionSCFToControlFlow,
mlirCreateConversionSCFToSPIRV,
mlirCreateConversionTosaToArith,
mlirCreateConversionTosaToLinalg,
mlirCreateConversionTosaToLinalgNamed,
mlirCreateConversionTosaToSCF,
mlirCreateConversionTosaToTensor,
);
7 changes: 7 additions & 0 deletions tools/all_api.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/bin/sh

set -e

. $(dirname $0)/utility.sh

all_api
2 changes: 1 addition & 1 deletion tools/utility.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ implemented_api() {
}

all_api() {
cat $(find $(brew --prefix llvm)/include/mlir-c -type f) | filter_api
cat $(find $(brew --prefix llvm)/include -type f) | filter_api
}

0 comments on commit b3ab31f

Please sign in to comment.