Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor macro #379

Merged
merged 21 commits into from
Dec 6, 2023
12 changes: 9 additions & 3 deletions macro/src/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod utility;

use self::{
error::Error,
operation::generate_operation,
utility::{sanitize_documentation, sanitize_snake_case_name},
};
pub use input::DialectInput;
Expand Down Expand Up @@ -64,9 +65,14 @@ fn generate_dialect_module(
.all_derived_definitions("Op")
.map(Operation::new)
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.filter(|operation| operation.dialect_name() == dialect_name)
.map(|operation| operation.to_tokens())
.iter()
.map(|operation| {
Ok::<_, Error>(if operation.dialect_name()? == dialect_name {
Some(generate_operation(operation)?)
} else {
None
})
})
.collect::<Result<Vec<_>, _>>()?;

let doc = format!(
Expand Down
26 changes: 3 additions & 23 deletions macro/src/dialect/error.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
mod ods;

pub use self::ods::OdsError;
use std::{
error,
fmt::{self, Display, Formatter},
Expand Down Expand Up @@ -78,26 +81,3 @@ impl From<FromUtf8Error> for Error {
Self::Utf8(error)
}
}

#[derive(Debug)]
pub enum OdsError {
ExpectedSuperClass(&'static str),
InvalidTrait,
UnexpectedSuperClass(&'static str),
}

impl Display for OdsError {
fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
match self {
Self::ExpectedSuperClass(class) => {
write!(formatter, "record should be a sub-class of {class}",)
}
Self::InvalidTrait => write!(formatter, "record is not a supported trait"),
Self::UnexpectedSuperClass(class) => {
write!(formatter, "record should not be a sub-class of {class}",)
}
}
}
}

impl error::Error for OdsError {}
27 changes: 27 additions & 0 deletions macro/src/dialect/error/ods.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use std::{
error::Error,
fmt::{self, Display, Formatter},
};

#[derive(Debug)]
pub enum OdsError {
ExpectedSuperClass(&'static str),
InvalidTrait,
UnexpectedSuperClass(&'static str),
}

impl Display for OdsError {
fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
match self {
Self::ExpectedSuperClass(class) => {
write!(formatter, "record should be a sub-class of {class}",)
}
Self::InvalidTrait => write!(formatter, "record is not a supported trait"),
Self::UnexpectedSuperClass(class) => {
write!(formatter, "record should not be a sub-class of {class}",)
}
}
}
}

impl Error for OdsError {}
230 changes: 123 additions & 107 deletions macro/src/dialect/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@ mod sequence_info;
mod variadic_kind;

use self::{
builder::OperationBuilder, element_kind::ElementKind, field_kind::FieldKind,
operation_field::OperationField, sequence_info::SequenceInfo, variadic_kind::VariadicKind,
builder::{generate_operation_builder, OperationBuilder},
element_kind::ElementKind,
field_kind::FieldKind,
operation_field::OperationField,
sequence_info::SequenceInfo,
variadic_kind::VariadicKind,
};
use super::utility::sanitize_documentation;
use crate::dialect::{
Expand All @@ -19,15 +23,70 @@ use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use tblgen::{error::WithLocation, record::Record};

#[derive(Clone, Debug)]
pub fn generate_operation(operation: &Operation) -> Result<TokenStream, Error> {
let summary = operation.summary()?;
let description = operation.description()?;
let class_name = format_ident!("{}", operation.class_name()?);
let name = &operation.full_name()?;
let accessors = operation
.fields()
.map(|field| field.accessors())
.collect::<Result<Vec<_>, _>>()?;

let builder = OperationBuilder::new(operation)?;
let builder_tokens = generate_operation_builder(&builder)?;
let builder_fn = builder.create_op_builder_fn()?;
let default_constructor = builder.create_default_constructor()?;

Ok(quote! {
#[doc = #summary]
#[doc = "\n\n"]
#[doc = #description]
pub struct #class_name<'c> {
operation: ::melior::ir::operation::Operation<'c>,
}

impl<'c> #class_name<'c> {
pub fn name() -> &'static str {
#name
}

pub fn operation(&self) -> &::melior::ir::operation::Operation<'c> {
&self.operation
}

#builder_fn

#(#accessors)*
}

#builder_tokens

#default_constructor

impl<'c> TryFrom<::melior::ir::operation::Operation<'c>> for #class_name<'c> {
type Error = ::melior::Error;

fn try_from(
operation: ::melior::ir::operation::Operation<'c>,
) -> Result<Self, Self::Error> {
// TODO Check an operation name.
Ok(Self { operation })
}
}

impl<'c> From<#class_name<'c>> for ::melior::ir::operation::Operation<'c> {
fn from(operation: #class_name<'c>) -> ::melior::ir::operation::Operation<'c> {
operation.operation
}
}
})
}

#[derive(Debug)]
pub struct Operation<'a> {
dialect_name: &'a str,
short_name: &'a str,
full_name: String,
class_name: &'a str,
summary: String,
definition: Record<'a>,
can_infer_type: bool,
description: String,
regions: Vec<OperationField<'a>>,
successors: Vec<OperationField<'a>>,
results: Vec<OperationField<'a>>,
Expand All @@ -38,7 +97,6 @@ pub struct Operation<'a> {

impl<'a> Operation<'a> {
pub fn new(definition: Record<'a>) -> Result<Self, Error> {
let dialect = definition.def_value("opDialect")?;
let traits = Self::collect_traits(definition)?;
let has_trait = |name| traits.iter().any(|r#trait| r#trait.has_name(name));

Expand All @@ -50,30 +108,7 @@ impl<'a> Operation<'a> {
has_trait("::mlir::OpTrait::AttrSizedResultSegments"),
)?;

let name = definition.name()?;
let class_name = if name.starts_with('_') {
name
} else if let Some(name) = name.split('_').nth(1) {
// Trim dialect prefix from name.
name
} else {
name
};
let short_name = definition.str_value("opName")?;

Ok(Self {
dialect_name: dialect.name()?,
short_name,
full_name: {
let dialect_name = dialect.string_value("name")?;

if dialect_name.is_empty() {
short_name.into()
} else {
format!("{dialect_name}.{short_name}")
}
},
class_name,
successors: Self::collect_successors(definition)?,
operands: Self::collect_operands(
&arguments,
Expand All @@ -89,26 +124,65 @@ impl<'a> Operation<'a> {
&& unfixed_result_count == 0
|| r#trait.has_name("::mlir::InferTypeOpInterface::Trait") && regions.is_empty()
}),
summary: {
let summary = definition.str_value("summary")?;

[
format!("[`{short_name}`]({class_name}) operation."),
if summary.is_empty() {
Default::default()
} else {
summary[0..1].to_uppercase() + &summary[1..] + "."
},
]
.join(" ")
},
description: sanitize_documentation(definition.str_value("description")?)?,
regions,
definition,
})
}

fn dialect(&self) -> Result<Record, Error> {
Ok(self.definition.def_value("opDialect")?)
}

pub fn dialect_name(&self) -> Result<&str, Error> {
Ok(self.dialect()?.name()?)
}

pub fn class_name(&self) -> Result<&str, Error> {
let name = self.definition.name()?;

Ok(if name.starts_with('_') {
name
} else if let Some(name) = name.split('_').nth(1) {
// Trim dialect prefix from name.
name
} else {
name
})
}

pub fn short_name(&self) -> Result<&str, Error> {
Ok(self.definition.str_value("opName")?)
}

pub fn full_name(&self) -> Result<String, Error> {
let dialect_name = self.dialect()?.string_value("name")?;
let short_name = self.short_name()?;

Ok(if dialect_name.is_empty() {
short_name.into()
} else {
format!("{dialect_name}.{short_name}")
})
}

pub fn dialect_name(&self) -> &str {
self.dialect_name
pub fn summary(&self) -> Result<String, Error> {
let short_name = self.short_name()?;
let class_name = self.class_name()?;
let summary = self.definition.str_value("summary")?;

Ok([
format!("[`{short_name}`]({class_name}) operation."),
if summary.is_empty() {
Default::default()
} else {
summary[0..1].to_uppercase() + &summary[1..] + "."
},
]
.join(" "))
}

pub fn description(&self) -> Result<String, Error> {
sanitize_documentation(self.definition.str_value("description")?)
}

pub fn fields(&self) -> impl Iterator<Item = &OperationField<'a>> + Clone {
Expand Down Expand Up @@ -334,62 +408,4 @@ impl<'a> Operation<'a> {
})
.collect()
}

pub fn to_tokens(&self) -> Result<TokenStream, Error> {
let class_name = format_ident!("{}", &self.class_name);
let name = &self.full_name;
let accessors = self
.fields()
.map(|field| field.accessors())
.collect::<Result<Vec<_>, _>>()?;
let builder = OperationBuilder::new(self)?;
let builder_tokens = builder.to_tokens()?;
let builder_fn = builder.create_op_builder_fn();
let default_constructor = builder.create_default_constructor()?;
let summary = &self.summary;
let description = &self.description;

Ok(quote! {
#[doc = #summary]
#[doc = "\n\n"]
#[doc = #description]
pub struct #class_name<'c> {
operation: ::melior::ir::operation::Operation<'c>,
}

impl<'c> #class_name<'c> {
pub fn name() -> &'static str {
#name
}

pub fn operation(&self) -> &::melior::ir::operation::Operation<'c> {
&self.operation
}

#builder_fn

#(#accessors)*
}

#builder_tokens

#default_constructor

impl<'c> TryFrom<::melior::ir::operation::Operation<'c>> for #class_name<'c> {
type Error = ::melior::Error;

fn try_from(
operation: ::melior::ir::operation::Operation<'c>,
) -> Result<Self, Self::Error> {
Ok(Self { operation })
}
}

impl<'c> From<#class_name<'c>> for ::melior::ir::operation::Operation<'c> {
fn from(operation: #class_name<'c>) -> ::melior::ir::operation::Operation<'c> {
operation.operation
}
}
})
}
}
Loading