Skip to content

Commit

Permalink
func dialect (#100)
Browse files Browse the repository at this point in the history
  • Loading branch information
raviqqe committed May 8, 2023
1 parent 61e4c7b commit 7f939e7
Show file tree
Hide file tree
Showing 10 changed files with 164 additions and 37 deletions.
8 changes: 7 additions & 1 deletion .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
"bfloat",
"canonicalize",
"canonicalizer",
"divf",
"divi",
"femtomc",
"hasher",
"indoc",
Expand All @@ -14,9 +16,13 @@
"melior",
"memref",
"mlir",
"mulf",
"muli",
"rustc",
"sccp",
"spirv",
"stdc"
"stdc",
"subf",
"subi"
]
}
2 changes: 2 additions & 0 deletions src/dialect.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! Dialect handles, instances, and registry.

pub mod arith;
pub mod func;
mod handle;
pub mod llvm;
mod registry;
Expand Down
45 changes: 45 additions & 0 deletions src/dialect/arith.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use crate::ir::{operation::Builder, Location, Operation, Value};

pub fn addi<'c>(lhs: Value, rhs: Value, location: Location<'c>) -> Operation<'c> {
binary_operator("arith.addi", lhs, rhs, location)
}

pub fn subi<'c>(lhs: Value, rhs: Value, location: Location<'c>) -> Operation<'c> {
binary_operator("arith.subi", lhs, rhs, location)
}

pub fn muli<'c>(lhs: Value, rhs: Value, location: Location<'c>) -> Operation<'c> {
binary_operator("arith.muli", lhs, rhs, location)
}

pub fn divi<'c>(lhs: Value, rhs: Value, location: Location<'c>) -> Operation<'c> {
binary_operator("arith.divi", lhs, rhs, location)
}

pub fn addf<'c>(lhs: Value, rhs: Value, location: Location<'c>) -> Operation<'c> {
binary_operator("arith.addf", lhs, rhs, location)
}

pub fn subf<'c>(lhs: Value, rhs: Value, location: Location<'c>) -> Operation<'c> {
binary_operator("arith.subf", lhs, rhs, location)
}

pub fn mulf<'c>(lhs: Value, rhs: Value, location: Location<'c>) -> Operation<'c> {
binary_operator("arith.mulf", lhs, rhs, location)
}

pub fn divf<'c>(lhs: Value, rhs: Value, location: Location<'c>) -> Operation<'c> {
binary_operator("arith.divf", lhs, rhs, location)
}

fn binary_operator<'c>(
name: &str,
lhs: Value,
rhs: Value,
location: Location<'c>,
) -> Operation<'c> {
Builder::new(name, location)
.add_operands(&[lhs, rhs])
.enable_result_type_inference()
.build()
}
76 changes: 76 additions & 0 deletions src/dialect/func.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
use crate::{
ir::{operation::Builder, Attribute, Identifier, Location, Operation, Region, Value},
Context,
};

pub fn func<'c>(
context: &'c Context,
name: Attribute<'c>,
r#type: Attribute<'c>,
region: Region,
location: Location<'c>,
) -> Operation<'c> {
Builder::new("func.func", location)
.add_attributes(&[
(Identifier::new(context, "sym_name"), name),
(Identifier::new(context, "function_type"), r#type),
])
.add_regions(vec![region])
.build()
}

pub fn r#return<'c>(operands: &[Value], location: Location<'c>) -> Operation<'c> {
Builder::new("func.return", location)
.add_operands(operands)
.build()
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{
dialect::arith::addi,
ir::{Attribute, Block, Module, Type},
test::load_all_dialects,
Context,
};

#[test]
fn run_on_function_in_nested_module() {
let context = Context::new();
load_all_dialects(&context);

let location = Location::unknown(&context);
let module = Module::new(location);

let integer_type = Type::integer(&context, 64);

let function = {
let region = Region::new();
let block = Block::new(&[(integer_type, location), (integer_type, location)]);

let sum = block.append_operation(addi(
block.argument(0).unwrap().into(),
block.argument(1).unwrap().into(),
location,
));

block.append_operation(r#return(&[sum.result(0).unwrap().into()], location));

region.append_block(block);

func(
&context,
Attribute::parse(&context, "\"add\"").unwrap(),
Attribute::parse(&context, "(i64, i64) -> i64").unwrap(),
region,
Location::unknown(&context),
)
};

module.body().append_operation(function);

assert!(module.as_operation().verify());
insta::assert_display_snapshot!(module.as_operation());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
source: src/dialect/func.rs
expression: module.as_operation()
---
module {
func.func @add(%arg0: i64, %arg1: i64) -> i64 {
%0 = arith.addi %arg0, %arg1 : i64
return %0 : i64
}
}

9 changes: 2 additions & 7 deletions src/ir/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,8 @@ impl<'a> Debug for BlockRef<'a> {
mod tests {
use super::*;
use crate::{
dialect,
ir::{operation, Module, Region, ValueLike},
utility::register_all_dialects,
test::load_all_dialects,
};
use pretty_assertions::assert_eq;

Expand Down Expand Up @@ -366,12 +365,8 @@ mod tests {

#[test]
fn terminator() {
let registry = dialect::Registry::new();
register_all_dialects(&registry);

let context = Context::new();
context.append_dialect_registry(&registry);
context.load_all_available_dialects();
load_all_dialects(&context);

let block = Block::new(&[]);

Expand Down
8 changes: 2 additions & 6 deletions src/ir/operation/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ impl<'c> Builder<'c> {
#[cfg(test)]
mod tests {
use super::*;
use crate::{context::Context, dialect, ir::Block, utility::register_all_dialects};
use crate::{context::Context, ir::Block, test::load_all_dialects};

#[test]
fn new() {
Expand Down Expand Up @@ -174,12 +174,8 @@ mod tests {

#[test]
fn enable_result_type_inference() {
let registry = dialect::Registry::new();
register_all_dialects(&registry);

let context = Context::new();
context.append_dialect_registry(&registry);
context.load_all_available_dialects();
load_all_dialects(&context);

let location = Location::unknown(&context);
let r#type = Type::index(&context);
Expand Down
16 changes: 6 additions & 10 deletions src/ir/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,8 @@ mod tests {
use super::*;
use crate::{
context::Context,
dialect,
ir::{operation, Attribute, Block, Identifier, Location},
utility::register_all_dialects,
test::load_all_dialects,
};

#[test]
Expand Down Expand Up @@ -191,7 +190,7 @@ mod tests {
#[test]
fn display() {
let context = Context::new();
context.load_all_available_dialects();

let location = Location::unknown(&context);
let index_type = Type::parse(&context, "index").unwrap();

Expand All @@ -211,12 +210,8 @@ mod tests {

#[test]
fn display_with_dialect_loaded() {
let registry = dialect::Registry::new();
register_all_dialects(&registry);

let context = Context::new();
context.append_dialect_registry(&registry);
context.load_all_available_dialects();
load_all_dialects(&context);

let location = Location::unknown(&context);
let index_type = Type::parse(&context, "index").unwrap();
Expand All @@ -238,7 +233,8 @@ mod tests {
#[test]
fn debug() {
let context = Context::new();
context.load_all_available_dialects();
load_all_dialects(&context);

let location = Location::unknown(&context);
let index_type = Type::parse(&context, "index").unwrap();

Expand All @@ -252,7 +248,7 @@ mod tests {

assert_eq!(
format!("{:?}", Value::from(operation.result(0).unwrap())),
"Value(\n%0 = \"arith.constant\"() {value = 0 : index} : () -> index\n)"
"Value(\n%c0 = arith.constant 0 : index\n)"
);
}
}
18 changes: 5 additions & 13 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ pub mod ir;
mod logical_result;
pub mod pass;
mod string_ref;
#[cfg(test)]
mod test;
pub mod utility;

pub use self::{
Expand All @@ -118,7 +120,7 @@ mod tests {
context::Context,
dialect,
ir::{operation, Attribute, Block, Identifier, Location, Module, Region, Type},
utility::register_all_dialects,
test::load_all_dialects,
};

#[test]
Expand All @@ -143,12 +145,8 @@ mod tests {

#[test]
fn build_add() {
let registry = dialect::Registry::new();
register_all_dialects(&registry);

let context = Context::new();
context.append_dialect_registry(&registry);
context.get_or_load_dialect("func");
load_all_dialects(&context);

let location = Location::unknown(&context);
let module = Module::new(location);
Expand Down Expand Up @@ -200,14 +198,8 @@ mod tests {

#[test]
fn build_sum() {
let registry = dialect::Registry::new();
register_all_dialects(&registry);

let context = Context::new();
context.append_dialect_registry(&registry);
context.get_or_load_dialect("func");
context.get_or_load_dialect("memref");
context.get_or_load_dialect("scf");
load_all_dialects(&context);

let location = Location::unknown(&context);
let module = Module::new(location);
Expand Down
8 changes: 8 additions & 0 deletions src/test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
use crate::{dialect::Registry, utility::register_all_dialects, Context};

pub fn load_all_dialects(context: &Context) {
let registry = Registry::new();
register_all_dialects(&registry);
context.append_dialect_registry(&registry);
context.load_all_available_dialects();
}

0 comments on commit 7f939e7

Please sign in to comment.