Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Full pass API #116

Merged
merged 12 commits into from
May 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 0 additions & 33 deletions macro/src/conversion_passes.rs

This file was deleted.

17 changes: 15 additions & 2 deletions macro/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mod conversion_passes;
mod parse;
mod pass;
mod type_check_functions;

use parse::IdentifierList;
Expand All @@ -19,7 +19,20 @@ pub fn type_check_functions(stream: TokenStream) -> TokenStream {
pub fn conversion_passes(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);

convert_result(conversion_passes::generate(identifiers.identifiers()))
convert_result(pass::generate(identifiers.identifiers(), |mut name| {
name = name.strip_prefix("Conversion").unwrap_or(name);
name = name.strip_prefix("Convert").unwrap_or(name);
name.strip_suffix("ConversionPass").unwrap_or(name).into()
}))
}

#[proc_macro]
pub fn transform_passes(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);

convert_result(pass::generate(identifiers.identifiers(), |name| {
name.strip_prefix("Transforms").unwrap_or(name).into()
}))
}

fn convert_result(result: Result<TokenStream, Box<dyn Error>>) -> TokenStream {
Expand Down
48 changes: 48 additions & 0 deletions macro/src/pass.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use convert_case::{Case, Casing};
use proc_macro::TokenStream;
use proc_macro2::Ident;
use quote::quote;
use std::error::Error;

pub fn generate(
identifiers: &[Ident],
extract_pass_name: impl Fn(&str) -> String,
) -> Result<TokenStream, Box<dyn Error>> {
let mut stream = TokenStream::new();

for identifier in identifiers {
let name = extract_pass_name(identifier.to_string().strip_prefix("mlirCreate").unwrap());

let function_name = Ident::new(&name.to_case(Case::Snake), identifier.span());
let document = format!(" Creates a pass of `{}`.", name);

stream.extend(TokenStream::from(quote! {
#[doc = #document]
pub fn #function_name() -> crate::pass::Pass {
crate::pass::Pass::__private_from_raw_fn(mlir_sys::#identifier)
}
}));
}

for identifier in identifiers {
let name = identifier.to_string();
let name = name.strip_prefix("mlirCreate").unwrap();

let foreign_function_name =
Ident::new(&("mlirRegister".to_owned() + name), identifier.span());
let function_name = Ident::new(
&("register_".to_owned() + &extract_pass_name(name).to_case(Case::Snake)),
identifier.span(),
);
let document = format!(" Registers a pass of `{}`.", name);

stream.extend(TokenStream::from(quote! {
#[doc = #document]
pub fn #function_name() {
unsafe { mlir_sys::#foreign_function_name() }
}
}));
}

Ok(stream)
}
2 changes: 1 addition & 1 deletion melior/src/execution_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ mod tests {
.unwrap();

let pass_manager = pass::Manager::new(&context);
pass_manager.add_pass(pass::conversion::convert_func_to_llvm());
pass_manager.add_pass(pass::conversion::func_to_llvm());

pass_manager
.nested_under("func.func")
Expand Down
16 changes: 8 additions & 8 deletions melior/src/pass/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ mod tests {
use crate::{
dialect,
ir::{Location, Module},
pass::{self, transform::register_print_operation_stats},
pass::{self, transform::register_print_op_stats},
utility::{parse_pass_pipeline, register_all_dialects},
};
use indoc::indoc;
Expand All @@ -105,7 +105,7 @@ mod tests {
fn add_pass() {
let context = Context::new();

Manager::new(&context).add_pass(pass::conversion::convert_func_to_llvm());
Manager::new(&context).add_pass(pass::conversion::func_to_llvm());
}

#[test]
Expand All @@ -128,7 +128,7 @@ mod tests {
let context = Context::new();
let manager = Manager::new(&context);

manager.add_pass(pass::conversion::convert_func_to_llvm());
manager.add_pass(pass::conversion::func_to_llvm());
manager
.run(&mut Module::new(Location::unknown(&context)))
.unwrap();
Expand All @@ -153,7 +153,7 @@ mod tests {
.unwrap();

let manager = Manager::new(&context);
manager.add_pass(pass::transform::print_operation_stats());
manager.add_pass(pass::transform::print_op_stats());

assert_eq!(manager.run(&mut module), Ok(()));
}
Expand Down Expand Up @@ -186,15 +186,15 @@ mod tests {
let manager = Manager::new(&context);
manager
.nested_under("func.func")
.add_pass(pass::transform::print_operation_stats());
.add_pass(pass::transform::print_op_stats());

assert_eq!(manager.run(&mut module), Ok(()));

let manager = Manager::new(&context);
manager
.nested_under("builtin.module")
.nested_under("func.func")
.add_pass(pass::transform::print_operation_stats());
.add_pass(pass::transform::print_op_stats());

assert_eq!(manager.run(&mut module), Ok(()));
}
Expand All @@ -205,7 +205,7 @@ mod tests {
let manager = Manager::new(&context);
let function_manager = manager.nested_under("func.func");

function_manager.add_pass(pass::transform::print_operation_stats());
function_manager.add_pass(pass::transform::print_op_stats());

assert_eq!(
manager.as_operation_pass_manager().to_string(),
Expand All @@ -229,7 +229,7 @@ mod tests {
)
.unwrap_err());

register_print_operation_stats();
register_print_op_stats();

assert_eq!(
parse_pass_pipeline(
Expand Down
106 changes: 16 additions & 90 deletions melior/src/pass/transform.rs
Original file line number Diff line number Diff line change
@@ -1,92 +1,18 @@
//! General transformation passes.

use super::Pass;
use mlir_sys::{
mlirCreateTransformsCSE, mlirCreateTransformsCanonicalizer, mlirCreateTransformsInliner,
mlirCreateTransformsPrintOpStats, mlirCreateTransformsSCCP, mlirCreateTransformsStripDebugInfo,
mlirCreateTransformsSymbolDCE, mlirCreateTransformsSymbolPrivatize, mlirRegisterTransformsCSE,
mlirRegisterTransformsCanonicalizer, mlirRegisterTransformsInliner,
mlirRegisterTransformsPrintOpStats, mlirRegisterTransformsSCCP,
mlirRegisterTransformsStripDebugInfo, mlirRegisterTransformsSymbolDCE,
mlirRegisterTransformsSymbolPrivatize,
};

/// Creates a pass to canonicalize IR.
pub fn canonicalizer() -> Pass {
Pass::from_raw_fn(mlirCreateTransformsCanonicalizer)
}

/// Registers a pass to canonicalize IR.
pub fn register_canonicalizer() {
unsafe { mlirRegisterTransformsCanonicalizer() }
}

/// Creates a pass to eliminate common sub-expressions.
pub fn cse() -> Pass {
Pass::from_raw_fn(mlirCreateTransformsCSE)
}

/// Registers a pass to print operation stats.
pub fn register_cse() {
unsafe { mlirRegisterTransformsCSE() }
}

/// Creates a pass to inline function calls.
pub fn inliner() -> Pass {
Pass::from_raw_fn(mlirCreateTransformsInliner)
}

/// Registers a pass to inline function calls.
pub fn register_inliner() {
unsafe { mlirRegisterTransformsInliner() }
}

/// Creates a pass to propagate constants.
pub fn sccp() -> Pass {
Pass::from_raw_fn(mlirCreateTransformsSCCP)
}

/// Registers a pass to propagate constants.
pub fn register_sccp() {
unsafe { mlirRegisterTransformsSCCP() }
}

/// Creates a pass to strip debug information.
pub fn strip_debug_info() -> Pass {
Pass::from_raw_fn(mlirCreateTransformsStripDebugInfo)
}

/// Registers a pass to strip debug information.
pub fn register_strip_debug_info() {
unsafe { mlirRegisterTransformsStripDebugInfo() }
}

/// Creates a pass to eliminate dead symbols.
pub fn symbol_dce() -> Pass {
Pass::from_raw_fn(mlirCreateTransformsSymbolDCE)
}

/// Registers a pass to eliminate dead symbols.
pub fn register_symbol_dce() {
unsafe { mlirRegisterTransformsSymbolDCE() }
}

/// Creates a pass to mark all top-level symbols private.
pub fn symbol_privatize() -> Pass {
Pass::from_raw_fn(mlirCreateTransformsSymbolPrivatize)
}

/// Registers a pass to mark all top-level symbols private.
pub fn register_symbol_privatize() {
unsafe { mlirRegisterTransformsSymbolPrivatize() }
}

/// Creates a pass to print operation statistics.
pub fn print_operation_stats() -> Pass {
Pass::from_raw_fn(mlirCreateTransformsPrintOpStats)
}

/// Registers a pass to print operation stats.
pub fn register_print_operation_stats() {
unsafe { mlirRegisterTransformsPrintOpStats() }
}
melior_macro::transform_passes!(
mlirCreateTransformsCSE,
mlirCreateTransformsCanonicalizer,
mlirCreateTransformsControlFlowSink,
mlirCreateTransformsGenerateRuntimeVerification,
mlirCreateTransformsInliner,
mlirCreateTransformsLocationSnapshot,
mlirCreateTransformsLoopInvariantCodeMotion,
mlirCreateTransformsPrintOpStats,
mlirCreateTransformsSCCP,
mlirCreateTransformsStripDebugInfo,
mlirCreateTransformsSymbolDCE,
mlirCreateTransformsSymbolPrivatize,
mlirCreateTransformsTopologicalSort,
mlirCreateTransformsViewOpGraph,
);