Skip to content

Commit

Permalink
Type check functions (#109)
Browse files Browse the repository at this point in the history
  • Loading branch information
raviqqe committed May 8, 2023
1 parent 8c13f07 commit 64630aa
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 53 deletions.
2 changes: 2 additions & 0 deletions .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"addf",
"addi",
"amdgpu",
"apdl",
"bfloat",
"bufferization",
"canonicalize",
Expand Down Expand Up @@ -32,6 +33,7 @@
"subf",
"subi",
"tosa",
"unranked",
"vulkan"
]
}
8 changes: 8 additions & 0 deletions macro/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
mod conversion_passes;
mod parse;
mod type_check_functions;

use parse::IdentifierList;
use proc_macro::TokenStream;
use quote::quote;
use std::error::Error;
use syn::parse_macro_input;

#[proc_macro]
pub fn type_check_functions(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);

convert_result(type_check_functions::generate(identifiers.identifiers()))
}

#[proc_macro]
pub fn conversion_passes(stream: TokenStream) -> TokenStream {
let identifiers = parse_macro_input!(stream as IdentifierList);
Expand Down
29 changes: 29 additions & 0 deletions macro/src/type_check_functions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use convert_case::{Case, Casing};
use proc_macro::TokenStream;
use proc_macro2::Ident;
use quote::quote;
use std::error::Error;

pub fn generate(identifiers: &[Ident]) -> Result<TokenStream, Box<dyn Error>> {
let mut stream = TokenStream::new();

for identifier in identifiers {
let name = identifier
.to_string()
.strip_prefix("mlirTypeIsA")
.unwrap()
.to_case(Case::Snake);

let function_name = Ident::new(&format!("is_{}", &name), identifier.span());
let document = format!(" Returns `true` if a type is `{}`.", name);

stream.extend(TokenStream::from(quote! {
#[doc = #document]
fn #function_name(&self) -> bool {
unsafe { mlir_sys::#identifier(self.to_raw()) }
}
}));
}

Ok(stream)
}
90 changes: 37 additions & 53 deletions melior/src/ir/type/type_like.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use super::Id;
use crate::context::ContextRef;
use mlir_sys::{
mlirIntegerTypeGetWidth, mlirTypeDump, mlirTypeGetContext, mlirTypeGetTypeID, mlirTypeIsABF16,
mlirTypeIsAF16, mlirTypeIsAF32, mlirTypeIsAF64, mlirTypeIsAFunction, mlirTypeIsAIndex,
mlirTypeIsAInteger, mlirTypeIsAMemRef, mlirTypeIsATuple, mlirTypeIsAVector, MlirType,
mlirIntegerTypeGetWidth, mlirTypeDump, mlirTypeGetContext, mlirTypeGetTypeID, MlirType,
};

/// Trait for type-like types.
Expand All @@ -21,11 +19,6 @@ pub trait TypeLike<'c> {
unsafe { Id::from_raw(mlirTypeGetTypeID(self.to_raw())) }
}

/// Returns `true` if a type is integer.
fn is_integer(&self) -> bool {
unsafe { mlirTypeIsAInteger(self.to_raw()) }
}

/// Gets a bit width of an integer type.
fn get_width(&self) -> Option<usize> {
if self.is_integer() {
Expand All @@ -35,55 +28,46 @@ pub trait TypeLike<'c> {
}
}

/// Returns `true` if a type is index.
fn is_index(&self) -> bool {
unsafe { mlirTypeIsAIndex(self.to_raw()) }
}

/// Returns `true` if a type is bfloat16.
fn is_bfloat16(&self) -> bool {
unsafe { mlirTypeIsABF16(self.to_raw()) }
}

/// Returns `true` if a type is float16.
fn is_float16(&self) -> bool {
unsafe { mlirTypeIsAF16(self.to_raw()) }
}

/// Returns `true` if a type is float32.
fn is_float32(&self) -> bool {
unsafe { mlirTypeIsAF32(self.to_raw()) }
}

/// Returns `true` if a type is float64.
fn is_float64(&self) -> bool {
unsafe { mlirTypeIsAF64(self.to_raw()) }
}

/// Returns `true` if a type is a function.
fn is_function(&self) -> bool {
unsafe { mlirTypeIsAFunction(self.to_raw()) }
}

/// Returns `true` if a type is a memory reference.
fn is_mem_ref(&self) -> bool {
unsafe { mlirTypeIsAMemRef(self.to_raw()) }
}

/// Returns `true` if a type is a tuple.
fn is_tuple(&self) -> bool {
unsafe { mlirTypeIsATuple(self.to_raw()) }
}

/// Returns `true` if a type is a vector.
fn is_vector(&self) -> bool {
unsafe { mlirTypeIsAVector(self.to_raw()) }
}

/// Dumps a type.
fn dump(&self) {
unsafe { mlirTypeDump(self.to_raw()) }
}

melior_macro::type_check_functions!(
mlirTypeIsAAnyQuantizedType,
mlirTypeIsABF16,
mlirTypeIsACalibratedQuantizedType,
mlirTypeIsAComplex,
mlirTypeIsAF16,
mlirTypeIsAF32,
mlirTypeIsAF64,
mlirTypeIsAFloat8E4M3FN,
mlirTypeIsAFloat8E5M2,
mlirTypeIsAFunction,
mlirTypeIsAIndex,
mlirTypeIsAInteger,
mlirTypeIsAMemRef,
mlirTypeIsANone,
mlirTypeIsAOpaque,
mlirTypeIsAPDLAttributeType,
mlirTypeIsAPDLOperationType,
mlirTypeIsAPDLRangeType,
mlirTypeIsAPDLType,
mlirTypeIsAPDLTypeType,
mlirTypeIsAPDLValueType,
mlirTypeIsAQuantizedType,
mlirTypeIsARankedTensor,
mlirTypeIsAShaped,
mlirTypeIsATensor,
mlirTypeIsATransformAnyOpType,
mlirTypeIsATransformOperationType,
mlirTypeIsATuple,
mlirTypeIsAUniformQuantizedPerAxisType,
mlirTypeIsAUniformQuantizedType,
mlirTypeIsAUnrankedMemRef,
mlirTypeIsAUnrankedTensor,
mlirTypeIsAVector,
);
}

#[cfg(test)]
Expand Down

0 comments on commit 64630aa

Please sign in to comment.