diff --git a/include/Conversion/ConversionPasses.h b/include/Conversion/ConversionPasses.h index 34eeb441..2477bb3d 100644 --- a/include/Conversion/ConversionPasses.h +++ b/include/Conversion/ConversionPasses.h @@ -12,11 +12,11 @@ namespace mlir { -// Passes defined in GraphPasses.td +// Passes defined in GraphPasses.td. #define GEN_PASS_DECL #include "Conversion/ConversionPasses.h.inc" -// Conversion passes +// Conversion passes. std::unique_ptr createLowerArithToNeuraPass(); std::unique_ptr createLowerLlvmToNeuraPass(); diff --git a/include/NeuraDialect/CMakeLists.txt b/include/NeuraDialect/CMakeLists.txt index b829e584..1c9b30b5 100644 --- a/include/NeuraDialect/CMakeLists.txt +++ b/include/NeuraDialect/CMakeLists.txt @@ -1,3 +1,11 @@ +# Set TableGen include paths +set(MLIR_TABLEGEN_INCLUDES + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/include/NeuraDialect + ${CMAKE_CURRENT_BINARY_DIR}/include/NeuraDialect + ${MLIR_MAIN_INCLUDE_DIR} + ${MLIR_INCLUDE_DIR}) + add_mlir_dialect(Neura neura) set(LLVM_TARGET_DEFINITIONS NeuraPasses.td) diff --git a/include/NeuraDialect/Neura.td b/include/NeuraDialect/Neura.td index b2bc6543..8f2aa4af 100644 --- a/include/NeuraDialect/Neura.td +++ b/include/NeuraDialect/Neura.td @@ -4,5 +4,6 @@ include "NeuraDialect.td" include "NeuraOps.td" include "NeuraPasses.td" +include "NeuraTypes.td" #endif // GRAPH_TD \ No newline at end of file diff --git a/include/NeuraDialect/NeuraDialect.h b/include/NeuraDialect/NeuraDialect.h index f5562e4f..aa1dabec 100644 --- a/include/NeuraDialect/NeuraDialect.h +++ b/include/NeuraDialect/NeuraDialect.h @@ -1,17 +1,34 @@ -#ifndef NEURADIALECT_NEURADIALECT_H -#define NEURADIALECT_NEURADIALECT_H +#ifndef NEURA_DIALECT_H +#define NEURA_DIALECT_H #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" -// Defines the export macro. 
#ifdef _WIN32 - #define NEURA_DIALECT_EXPORT __declspec(dllexport) +#define NEURA_DIALECT_EXPORT __declspec(dllexport) #else - #define NEURA_DIALECT_EXPORT __attribute__((visibility("default"))) +#define NEURA_DIALECT_EXPORT __attribute__((visibility("default"))) #endif -// Includes generated TableGen headers. +namespace mlir { +namespace neura { + +// Forward declare before including generated code +class NeuraDialect; + +} // end namespace neura +} // end namespace mlir + +// Include the generated dialect declarations #include "NeuraDialect/NeuraDialect.h.inc" -#endif // NEURADIALECT_NEURADIALECT_H +namespace mlir { +namespace neura { + +// Declare additional methods for the generated dialect class +NEURA_DIALECT_EXPORT void registerNeuraDialect(); + +} // end namespace neura +} // end namespace mlir + +#endif // NEURA_DIALECT_H \ No newline at end of file diff --git a/include/NeuraDialect/NeuraDialect.td b/include/NeuraDialect/NeuraDialect.td index bd59a3e0..23b89bf4 100644 --- a/include/NeuraDialect/NeuraDialect.td +++ b/include/NeuraDialect/NeuraDialect.td @@ -1,13 +1,33 @@ // NeuraDialect.td - TableGen description of the dialect. -#ifndef NEURADIALECT_TD -#define NEURADIALECT_TD +#ifndef NEURA_DIALECT_TD +#define NEURA_DIALECT_TD include "mlir/IR/OpBase.td" include "mlir/IR/DialectBase.td" +include "mlir/IR/AttrTypeBase.td" +// First define the dialect def NeuraDialect : Dialect { let name = "neura"; let cppNamespace = "::mlir::neura"; + + let summary = "A dialect for the Neura compiler infrastructure."; + let description = [{ + This dialect is used for representing Neura operations and types. 
+ }]; + + let useDefaultTypePrinterParser = 0; + let useDefaultAttributePrinterParser = 0; + + let extraClassDeclaration = [{ + // Type parsing/printing + Type parseType(DialectAsmParser &parser) const override; + void printType(Type type, DialectAsmPrinter &printer) const override; + + // Attribute parsing/printing + Attribute parseAttribute(DialectAsmParser &parser, Type type) const override; + void printAttribute(Attribute attr, DialectAsmPrinter &printer) const override; +}]; } -#endif // NEURADIALECT_TD \ No newline at end of file +#endif // NEURA_DIALECT_TD \ No newline at end of file diff --git a/include/NeuraDialect/NeuraOps.h b/include/NeuraDialect/NeuraOps.h index 5d354bfb..6c66620f 100644 --- a/include/NeuraDialect/NeuraOps.h +++ b/include/NeuraDialect/NeuraOps.h @@ -7,9 +7,19 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Builders.h" -#define GET_OP_CLASSES +// First includes the interface declarations. +#define GET_OP_INTERFACE_CLASSES +#include "NeuraDialect/Neura.h.inc" +#undef GET_OP_INTERFACE_CLASSES + +// Then includes the op declarations. +#define GET_OP_DECLARATIONS #include "NeuraDialect/Neura.h.inc" +#undef GET_OP_DECLARATIONS -// Additional definitions or includes can go here. +// Finally includes the op definitions. +#define GET_OP_CLASSES +#include "NeuraDialect/Neura.h.inc" +#undef GET_OP_CLASSES #endif // NEURA_OPS_H diff --git a/include/NeuraDialect/NeuraOps.td b/include/NeuraDialect/NeuraOps.td index d15bafc0..01e54159 100644 --- a/include/NeuraDialect/NeuraOps.td +++ b/include/NeuraDialect/NeuraOps.td @@ -6,7 +6,10 @@ include "NeuraDialect/NeuraDialect.td" // Defines basic scalar operations. 
def Neura_ConstantOp : Op { - let arguments = (ins AnyAttr:$value); + let arguments = (ins + AnyAttr:$value, + OptionalAttr:$predicate // Add optional predicate attribute + ); let results = (outs AnyType:$result); // let assemblyFormat = "attr-dict `:` type($result)"; } @@ -15,9 +18,9 @@ def Neura_ConstantOp : Op { def Neura_AddOp : Op { let summary = "Integer addition operation"; let opName = "add"; - let arguments = (ins AnyInteger:$lhs, AnyInteger:$rhs); + let arguments = (ins AnyInteger:$lhs, AnyInteger:$rhs, Optional:$predicate); let results = (outs AnyInteger:$result); - // let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)"; + // let assemblyFormat = "$lhs `,` $rhs `,` $predicate attr-dict `:` type($result)"; let traits = [SameOperandsAndResultElementType]; } @@ -25,10 +28,10 @@ def Neura_AddOp : Op { def Neura_FAddOp : Op { let summary = "Floating addition operation"; let opName = "fadd"; - let arguments = (ins AnyFloat:$lhs, AnyFloat:$rhs); - let results = (outs AnyFloat:$result); - // let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)"; - let traits = [SameOperandsAndResultElementType]; + let arguments = (ins AnyType:$lhs, AnyType:$rhs, Optional:$predicate); + let results = (outs AnyType:$result); + // let assemblyFormat = "$lhs `,` $rhs `,` $predicate attr-dict `:` type($result)"; + //let traits = [SameOperandsAndResultElementType]; } // Defines a floating-point substraction operation. 
@@ -45,70 +48,100 @@ def Neura_FSubOp: Op { def Neura_FMulOp : Op { let summary = "Floating multiplication operation"; let opName = "fmul"; - let arguments = (ins AnyFloat:$lhs, AnyFloat:$rhs); - let results = (outs AnyFloat:$result); - // let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)"; - let traits = [SameOperandsAndResultElementType]; + let arguments = (ins AnyType:$lhs, AnyType:$rhs, Optional:$predicate); + let results = (outs AnyType:$result); + // let assemblyFormat = "$lhs `,` $rhs `,` $predicate attr-dict `:` type($result)"; + // let traits = [SameOperandsAndResultElementType]; } +// Defines a bitwise OR operation. def Neura_OrOp : Op { let summary = "Bitwise OR operation"; - let arguments = (ins AnySignlessInteger:$lhs, AnySignlessInteger:$rhs); + let arguments = (ins AnySignlessInteger:$lhs, AnySignlessInteger:$rhs, Optional:$predicate); let results = (outs AnySignlessInteger:$result); + // let assemblyFormat = "$lhs `,` $rhs `,` $predicate attr-dict `:` type($result)"; let traits = [SameOperandsAndResultElementType]; - // let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)"; -} - -// Defines a move operation for data communication. -def Neura_MovOp : Op { - let summary = "Move operation"; - let opName = "mov"; - let arguments = (ins AnyType:$lhs); - let results = (outs AnyType:$result); - // let assemblyFormat = "$lhs attr-dict `:` type($lhs) `->` type($result)"; - // let traits = [Pure]; } +// Defines an integer compare operation. 
def Neura_ICmpOp : Op { let summary = "Integer compare operation"; let opName = "icmp"; - let arguments = (ins AnyInteger:$lhs, AnyInteger:$rhs, - StrAttr:$predicate); + let arguments = (ins AnyType:$lhs, AnyType:$rhs, Optional:$predicate, + StrAttr:$cmpType); + let results = (outs AnyType:$result); + // let assemblyFormat = "$lhs `,` $rhs `,` $cmpTypeAttr `,` $cmp_type attr-dict `:` type($result)"; + // let traits = [SameOperandsAndResultElementType]; +} + +// Defines a floating-point compare operation. +def Neura_FCmpOp : Op { + let summary = "Floating-point compare operation"; + let opName = "fcmp"; + let arguments = (ins AnyFloat:$lhs, + AnyFloat:$rhs, + Optional:$predicate, + StrAttr:$cmpType); let results = (outs I1:$result); - // let assemblyFormat = "$lhs `,` $rhs `,` $predicate attr-dict `:` type($result)"; + // let assemblyFormat = "$lhs `,` $rhs `,` $cmpType attr-dict `:` type($result)"; // let traits = [SameOperandsAndResultElementType]; } +// Defines a load operation. def Neura_LoadOp : Op { - let arguments = (ins AnyType:$addr); + let arguments = (ins AnyType:$addr, Optional:$predicate); let results = (outs AnyType:$value); - // let assemblyFormat = "$addr attr-dict `:` type($value)"; + // let assemblyFormat = "$addr `,` $predicate attr-dict `:` type($value)"; } +// Defines a store operation. def Neura_StoreOp : Op { - let arguments = (ins AnyType:$value, AnyType:$addr); + let arguments = (ins AnyType:$value, AnyType:$addr, Optional:$predicate); let results = (outs); - // let assemblyFormat = "$value `,` $addr attr-dict"; + // let assemblyFormat = "$value `,` $addr `,` $predicate attr-dict"; } +// Defines a pointer computation operation. 
def Neura_GEP : Op { let summary = "Pointer computation using offset indices"; - let arguments = (ins AnyType:$base, Variadic:$indices); + let arguments = (ins AnyType:$base, Variadic:$indicesAndPredicate); let results = (outs AnyType:$result); - // let assemblyFormat = "$base `[` $indices `]` attr-dict"; + // let assemblyFormat = "$base `[` $indicesAndPredicate `]` `,` $predicate attr-dict"; } +// Defines a conditional branch operation. def Neura_CondBr : Op { - let arguments = (ins I1:$condition, + let arguments = (ins AnyType:$condition, + Optional:$predicate, Variadic:$trueArgs, Variadic:$falseArgs); let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest); - let assemblyFormat = "$condition `then` $trueArgs `:` type($trueArgs) `to` $trueDest `else` $falseArgs `:` type($falseArgs) `to` $falseDest attr-dict"; + let assemblyFormat = "$condition `:` type($condition) ($predicate^ `:` type($predicate))? `then` ($trueArgs^)? `:` type($trueArgs) `to` $trueDest `else` ($falseArgs^)? `:` type($falseArgs) `to` $falseDest attr-dict"; } +// Defines an unconditional branch operation. +def Neura_Br : Op { + let arguments = (ins Variadic:$args); + let successors = (successor AnySuccessor:$dest); + let assemblyFormat = "($args^)? `:` type($args) `to` $dest attr-dict"; +} + +def Neura_SelOp : Op { + let arguments = (ins AnyType:$ifTrue, AnyType:$ifFalse, I1:$cond); + let results = (outs AnyType:$result); + // let assemblyFormat = "$ifTrue `,` $ifFalse `,` $cond attr-dict `:` type($ifTrue)"; +} + +def Neura_NotOp : Op { + let arguments = (ins I1:$input); + let results = (outs I1:$output); + // let assemblyFormat = "$input attr-dict `:` type($output)"; +} + +// Defines a return operation. def Neura_ReturnOp : Op { let arguments = (ins Variadic:$values); - // let assemblyFormat = "($values^)? attr-dict"; + // let assemblyFormat = "($values^)? 
`,` $predicate attr-dict"; } // ---------------------------------------------------- @@ -127,9 +160,9 @@ def VectorOfAnyFloat : def Neura_VFMulOp : Op { let summary = "Vector floating multiplication operation"; let opName = "vfmul"; - let arguments = (ins VectorOfAnyFloat:$lhs, VectorOfAnyFloat:$rhs); + let arguments = (ins VectorOfAnyFloat:$lhs, VectorOfAnyFloat:$rhs, Optional:$predicate); let results = (outs VectorOfAnyFloat:$result); - // let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)"; + // let assemblyFormat = "$lhs `,` $rhs `,` $predicate attr-dict `:` type($result)"; let traits = [SameOperandsAndResultElementType]; } @@ -138,18 +171,85 @@ def Neura_VFMulOp : Op { def Neura_FAddFAddOp : Op { let summary = "Fused fadd(fadd(a, b), c)"; - let arguments = (ins AnyFloat:$a, AnyFloat:$b, AnyFloat:$c); + let arguments = (ins AnyFloat:$a, AnyFloat:$b, AnyFloat:$c, Optional:$predicate); let results = (outs AnyFloat:$result); - // let assemblyFormat = "$a `,` $b `,` $c attr-dict `:` type($result)"; + // let assemblyFormat = "$a `,` $b `,` $c `,` $predicate attr-dict `:` type($result)"; let traits = [SameOperandsAndResultElementType]; } def Neura_FMulFAddOp : Op { let summary = "Fused fadd(fmul(a, b), c)"; - let arguments = (ins AnyFloat:$a, AnyFloat:$b, AnyFloat:$c); + let arguments = (ins AnyFloat:$a, AnyFloat:$b, AnyFloat:$c, Optional:$predicate); let results = (outs AnyFloat:$result); - // let assemblyFormat = "$a `,` $b `,` $c attr-dict `:` type($result)"; + // let assemblyFormat = "$a `,` $b `,` $c `,` $predicate attr-dict `:` type($result)"; let traits = [SameOperandsAndResultElementType]; } +// ---------------------------------------------------- +// Defines move operations. 
+def Neura_DataMovOp : Op { + let summary = "Data movement operation"; + let arguments = (ins AnyType:$input); + let results = (outs AnyType:$output); + // let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)"; +} +// ---------------------------------------------------- +// Defines ctrl-related operations. + +// Phi operation for merging values in dataflow form +def Neura_PhiOp : Op { + let summary = "Phi node in dataflow form"; + let description = [{ + Merges values from different control paths in dataflow form. + Used with reserve and ctrl_mov to represent control flow. + + Example: + %v = neura.reserve : f32 // Create placeholder + %result = neura.phi %init, %v // Merge initial and loop-carried values + neura.ctrl_mov %next to %v // Connect next iteration + }]; + + let arguments = (ins AnyType:$init_val, AnyType:$loop_val); + let results = (outs AnyType:$result); + + // Explicitly specify types for operands in the assembly format + let assemblyFormat = "$init_val `:` type($init_val) `,` $loop_val `:` type($loop_val) attr-dict `:` type($result)"; +} + +// Control movement extending base move but with different signature. +def Neura_CtrlMovOp : Op { + let summary = "Control movement operation"; + let description = [{ + Connects a value to a reserved placeholder in the dataflow. + Used to establish control flow dependencies. + + Example: + ctrl_mov %value to %placeholder : f32 // Connect value to placeholder + }]; + + // Add type constraints for both operands + let arguments = (ins AnyType:$value, AnyType:$target); + let results = (outs); + + // Correct assembly format - types must be space-separated + let assemblyFormat = "$value `->` $target attr-dict `:` type($value) type($target)"; +} + +// Reserve operation for control flow values. +def Neura_ReserveOp : Op { + let summary = "Creates a placeholder for control flow values"; + let description = [{ + Creates a placeholder value that will be connected via ctrl_mov. 
+ Used to represent control flow dependencies in dataflow form. + + Example: + %v = neura.reserve : f32 // Create placeholder + %result = neura.phi %init, %v // Use in phi node + neura.ctrl_mov %next to %v // Connect value + }]; + + let arguments = (ins); + let results = (outs AnyType:$result); + let assemblyFormat = "attr-dict `:` type($result)"; +} diff --git a/include/NeuraDialect/NeuraPasses.h b/include/NeuraDialect/NeuraPasses.h index 764e94cb..9cdeef7f 100644 --- a/include/NeuraDialect/NeuraPasses.h +++ b/include/NeuraDialect/NeuraPasses.h @@ -16,9 +16,12 @@ namespace neura { // Passes defined in GraphPasses.td #define GEN_PASS_DECL #include "NeuraDialect/NeuraPasses.h.inc" -std::unique_ptr createInsertMovPass(); +std::unique_ptr createInsertDataMovPass(); +std::unique_ptr createInsertCtrlMovPass(); std::unique_ptr createFusePatternsPass(); std::unique_ptr createAssignAcceleratorPass(); +std::unique_ptr createTransformCtrlToDataFlowPass(); +std::unique_ptr createLeveragePredicatedValuePass(); #define GEN_PASS_REGISTRATION #include "NeuraDialect/NeuraPasses.h.inc" diff --git a/include/NeuraDialect/NeuraPasses.td b/include/NeuraDialect/NeuraPasses.td index db5d2916..f4ea76a7 100644 --- a/include/NeuraDialect/NeuraPasses.td +++ b/include/NeuraDialect/NeuraPasses.td @@ -20,11 +20,34 @@ def FusePatterns : Pass<"fuse-patterns", "ModuleOp"> { let constructor = "neura::createFusePatternsPass()"; } -def InsertMov : Pass<"insert-mov", "ModuleOp"> { - let summary = "Inserts move operations in the Neura dialect"; +def InsertDataMov : Pass<"insert-data-mov", "ModuleOp"> { + let summary = "Inserts data move operations in the Neura dialect"; let description = - [{Insert neura.mov before and after all neura dialect operations.}]; - let constructor = "neura::createInsertMovPass()"; + [{Insert neura.data_mov before all neura dialect operations.}]; + let constructor = "neura::createInsertDataMovPass()"; +} + +def InsertCtrlMov : Pass<"insert-ctrl-mov", "ModuleOp"> { + let 
summary = "Inserts ctrl move operations in the Neura dialect"; + let description = + [{Insert neura.ctrl_mov before all neura dialect operations.}]; + let constructor = "neura::createInsertCtrlMovPass()"; +} + +def TransformCtrlToDataFlow : Pass<"transform-ctrl-to-data-flow", "ModuleOp"> { + let summary = "Inserts ctrl move operations in the Neura dialect"; + let description = + [{Transform ctrl to predicate-based data flow.}]; + let constructor = "neura::createTransformCtrlToDataFlowPass()"; +} + +def LeveragePredicatedValue : Pass<"leverage-predicated-value", "ModuleOp"> { + let summary = "Convert values to predicated values in Neura dialect"; + let description = [{ + This pass converts regular values to predicated values in Neura dialect operations. + Each value is wrapped in a predicated value type with a default true predicate. + }]; + let constructor = "neura::createLeveragePredicatedValuePass()"; } #endif // NEURA_PASSES_TD \ No newline at end of file diff --git a/include/NeuraDialect/NeuraTypes.h b/include/NeuraDialect/NeuraTypes.h new file mode 100644 index 00000000..e3f10fa7 --- /dev/null +++ b/include/NeuraDialect/NeuraTypes.h @@ -0,0 +1,81 @@ +#ifndef NEURA_TYPES_H +#define NEURA_TYPES_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/TypeSupport.h" +#include "NeuraDialect/NeuraDialect.h" + +namespace mlir { +namespace neura { + +namespace detail { +// Storage class for predicated value type. +struct PredicatedValueStorage : public mlir::TypeStorage { + using KeyTy = std::pair; // valueType and predicateType + + PredicatedValueStorage(Type valueType, IntegerType predicateType) + : valueType(valueType), predicateType(predicateType) {} + + // Required storage class methods. 
+ bool operator==(const KeyTy& key) const { + return key.first == valueType && key.second == predicateType; + } + + static PredicatedValueStorage* construct(mlir::TypeStorageAllocator& allocator, + const KeyTy& key) { + // Allocate the storage instance and construct it + return new (allocator.allocate()) + PredicatedValueStorage(key.first, key.second); + } + + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_combine(key.first, key.second); + } + + Type valueType; // The type being predicated + IntegerType predicateType; // The predicate type (usually i1) +}; +} // namespace detail + +class PredicatedValue : public mlir::Type::TypeBase< + PredicatedValue, mlir::Type, + detail::PredicatedValueStorage> { + public: + using Base = mlir::Type::TypeBase; + static constexpr llvm::StringLiteral name = "data"; + + using Base::Base; + + // Static method to create a PredicatedValue instance. + static PredicatedValue get(MLIRContext* context, Type valueType, + IntegerType predicateType) { + return Base::get(context, valueType, predicateType); + } + + // Overload verify that takes two separate parameters. 
+ static LogicalResult verify(function_ref emitError, + Type valueType, IntegerType predicateType) { + return verify(emitError, std::make_pair(valueType, predicateType)); + } + + // New overload verify that accepts the KeyTy as expected by MLIR + static LogicalResult verify(function_ref emitError, + const detail::PredicatedValueStorage::KeyTy &key) { + if (!key.second.isInteger(1)) + return emitError() << "predicate must be i1 type"; + return success(); + } + + Type getValueType() const { return getImpl()->valueType; } + IntegerType getPredicateType() const { return getImpl()->predicateType; } + + static Type parse(AsmParser& parser); + void print(AsmPrinter& printer) const; +}; + +} // namespace neura +} // namespace mlir + +#endif // NEURA_TYPES_H \ No newline at end of file diff --git a/include/NeuraDialect/NeuraTypes.td b/include/NeuraDialect/NeuraTypes.td new file mode 100644 index 00000000..e17e6c3e --- /dev/null +++ b/include/NeuraDialect/NeuraTypes.td @@ -0,0 +1,44 @@ +#ifndef NEURA_TYPES_TD +#define NEURA_TYPES_TD + +include "mlir/IR/AttrTypeBase.td" +include "NeuraDialect.td" + +// Define the Neura dialect types +// class Neura_Type traits = []> : +// TypeDef; +class Neura_Type : TypeDef { + let mnemonic = typeMnemonic; // Changed parameter name to avoid self-assignment +} + +// Predicated value type - a value with an optional predicate +def Neura_PredicatedValue : Neura_Type<"PredicatedValue", "data"> { + let summary = "A value with an optional predicate"; + let description = [{ + Represents a value that may be conditional based on a predicate. 
+ Contains: + - A base value type (integer, float, or vector) + - An i1 predicate indicating validity + + Examples: + !neura.data // Predicated float + !neura.data // Predicated integer + !neura.data // Predicated vector + }]; + + let parameters = (ins + "Type":$valueType, + "IntegerType":$predicateType + ); + + // Verify predicate is i1 + let genVerifyDecl = 1; + + // Fix: Update assembly format syntax with proper semicolon + // let assemblyFormat = "<$valueType,$predicateType>"; + let hasCustomAssemblyFormat = 1; + // let mnemonic = "data"; + // let cppNamespace = "::mlir::neura"; +} + +#endif // NEURA_TYPES_TD \ No newline at end of file diff --git a/lib/Conversion/ArithToNeura/ArithToNeuraPass.cpp b/lib/Conversion/ArithToNeura/ArithToNeuraPass.cpp index 37679077..ab952519 100644 --- a/lib/Conversion/ArithToNeura/ArithToNeuraPass.cpp +++ b/lib/Conversion/ArithToNeura/ArithToNeuraPass.cpp @@ -28,6 +28,21 @@ using namespace mlir::neura; namespace{ +struct ArithFAddToNeuraFAdd : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::AddFOp op, + PatternRewriter &rewriter) const override { + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + Type resultType = op.getType(); + + // Optional predicate: default to 'none' + rewriter.replaceOpWithNewOp(op, resultType, lhs, rhs, Value()); + return success(); + } +}; + struct LowerArithToNeuraPass : public PassWrapper> { @@ -45,6 +60,7 @@ struct LowerArithToNeuraPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); mlir::neura::arith2neura::populateWithGenerated(patterns); + patterns.add(&getContext()); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); } diff --git a/lib/Conversion/ArithToNeura/ArithToNeuraPatterns.td b/lib/Conversion/ArithToNeura/ArithToNeuraPatterns.td index 785f5b47..7715f90f 100644 --- a/lib/Conversion/ArithToNeura/ArithToNeuraPatterns.td +++ 
b/lib/Conversion/ArithToNeura/ArithToNeuraPatterns.td @@ -2,9 +2,3 @@ include "mlir/IR/OpBase.td" include "mlir/IR/PatternBase.td" include "mlir/Dialect/Arith/IR/ArithOps.td" include "NeuraDialect/NeuraOps.td" - -def : Pat< - (Arith_AddFOp $lhs, $rhs, $_fastmath), - (Neura_FAddOp $lhs, $rhs) ->; - diff --git a/lib/Conversion/LlvmToNeura/LlvmToNeuraPass.cpp b/lib/Conversion/LlvmToNeura/LlvmToNeuraPass.cpp index 403c8f60..c9c2fe23 100644 --- a/lib/Conversion/LlvmToNeura/LlvmToNeuraPass.cpp +++ b/lib/Conversion/LlvmToNeura/LlvmToNeuraPass.cpp @@ -38,7 +38,36 @@ struct LlvmAddToNeuraAdd : public OpRewritePattern { LogicalResult matchAndRewrite(mlir::LLVM::AddOp op, PatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, op.getType(), op.getLhs(), op.getRhs()); + rewriter.replaceOpWithNewOp(op, op.getType(), op.getLhs(), op.getRhs(), Value()); + return success(); + } +}; + +struct LlvmFAddToNeuraFAdd : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::LLVM::FAddOp op, + PatternRewriter &rewriter) const override { + Value lhs = op->getOperand(0); + Value rhs = op->getOperand(1); + Type result_type = op->getResult(0).getType(); + + // Only matches scalar float. 
+ if (!mlir::isa(result_type)) + return failure(); + + // Optional predicate: default to 'none' + rewriter.replaceOpWithNewOp(op, result_type, lhs, rhs, Value()); + return success(); + } +}; + +struct LlvmOrToNeuraOr : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::LLVM::OrOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getType(), op.getLhs(), op.getRhs(), Value()); return success(); } }; @@ -56,7 +85,7 @@ struct LlvmFMulToNeuraFMul : public OpRewritePattern { if (!mlir::isa(result_type)) return failure(); - rewriter.replaceOpWithNewOp(op, result_type, lhs, rhs); + rewriter.replaceOpWithNewOp(op, result_type, lhs, rhs, Value()); return success(); } }; @@ -75,7 +104,7 @@ struct LlvmVFMulToNeuraVFMul: public OpRewritePattern { if (!vecTy || !mlir::isa(vecTy.getElementType())) return failure(); - rewriter.replaceOpWithNewOp(op, result_type, lhs, rhs); + rewriter.replaceOpWithNewOp(op, result_type, lhs, rhs, Value()); return success(); } }; @@ -91,7 +120,25 @@ struct LlvmICmpToNeuraICmp : public OpRewritePattern { auto resultType = op.getType(); rewriter.replaceOpWithNewOp( - op, resultType, lhs, rhs, rewriter.getStringAttr(LLVM::stringifyICmpPredicate(pred))); + op, resultType, lhs, rhs, Value(), + rewriter.getStringAttr(LLVM::stringifyICmpPredicate(pred))); + return success(); + } +}; + +struct LlvmFCmpToNeuraFCmp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LLVM::FCmpOp op, + PatternRewriter &rewriter) const override { + auto pred = op.getPredicate(); + auto lhs = op.getLhs(); + auto rhs = op.getRhs(); + auto resultType = op.getType(); + + rewriter.replaceOpWithNewOp( + op, resultType, lhs, rhs, Value(), + rewriter.getStringAttr(LLVM::stringifyFCmpPredicate(pred))); return success(); } }; @@ -108,8 +155,12 @@ struct LlvmGEPToNeuraGEP : public OpRewritePattern { if (auto val = gepIndex.dyn_cast()) { 
indexValues.push_back(val); } else if (auto intAttr = gepIndex.dyn_cast()) { - auto cst = rewriter.create( - op.getLoc(), rewriter.getIndexType(), intAttr); + // Create constant operation state manually + OperationState state(op.getLoc(), neura::ConstantOp::getOperationName()); + state.addAttribute("value", intAttr); + state.addAttribute("predicate", rewriter.getBoolAttr(true)); + state.addTypes(rewriter.getIndexType()); + Value cst = rewriter.create(state)->getResult(0); indexValues.push_back(cst); } else { return op.emitOpError("Unsupported GEP index kind"); @@ -128,7 +179,7 @@ struct LlvmLoadToNeuraLoad : public OpRewritePattern { PatternRewriter &rewriter) const override { Value ptr = op.getAddr(); // getPointer() is deprecated Type resultType = op.getResult().getType(); - rewriter.replaceOpWithNewOp(op, resultType, ptr); + rewriter.replaceOpWithNewOp(op, resultType, ptr, Value()); return success(); } }; @@ -140,7 +191,7 @@ struct LlvmStoreToNeuraStore : public OpRewritePattern { PatternRewriter &rewriter) const override { Value value = op.getValue(); Value addr = op.getAddr(); // getPointer() is deprecated - rewriter.replaceOpWithNewOp(op, value, addr); + rewriter.replaceOpWithNewOp(op, value, addr, Value()); return success(); } }; @@ -161,6 +212,7 @@ struct LlvmCondBrToNeuraCondBr : public OpRewritePattern { auto newOp = rewriter.create( op.getLoc(), // Location op.getCondition(), // Condition + Value(), // Optional predicate, default to 'none' trueOperands, // True destination operands falseOperands, // False destination operands trueDest, // True destination block @@ -174,6 +226,23 @@ struct LlvmCondBrToNeuraCondBr : public OpRewritePattern { } }; +struct LlvmBrToNeuraBr : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::LLVM::BrOp op, + PatternRewriter &rewriter) const override { + // Get the destination block and its operands + Block *dest = op.getDest(); + ValueRange destOperands = 
op.getDestOperands(); + + // Create the new Neura_Br operation + rewriter.replaceOpWithNewOp( + op, destOperands, dest); + + return success(); + } +}; + struct LlvmReturnToNeuraReturn : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -184,6 +253,36 @@ struct LlvmReturnToNeuraReturn : public OpRewritePattern { } }; +struct FuncReturnToNeuraReturn : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(func::ReturnOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getOperands()); + return success(); + } +}; + +struct LlvmConstantToNeuraConstant : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LLVM::ConstantOp op, + PatternRewriter &rewriter) const override { + auto attr = op.getValue(); + + // Create operation state manually + OperationState state(op.getLoc(), neura::ConstantOp::getOperationName()); + state.addAttribute("value", attr); + state.addAttribute("predicate", rewriter.getBoolAttr(true)); + state.addTypes(op.getType()); + + // Create the operation and replace + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; + struct LowerLlvmToNeuraPass : public PassWrapper> { @@ -202,15 +301,21 @@ struct LowerLlvmToNeuraPass RewritePatternSet patterns(&getContext()); // Adds DRR patterns. 
mlir::neura::llvm2neura::populateWithGenerated(patterns); + patterns.add(&getContext()); patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); + patterns.add(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); + patterns.add(&getContext()); patterns.add(&getContext()); + patterns.add(&getContext()); FrozenRewritePatternSet frozen(std::move(patterns)); diff --git a/lib/Conversion/LlvmToNeura/LlvmToNeuraPatterns.td b/lib/Conversion/LlvmToNeura/LlvmToNeuraPatterns.td index 6004cf96..e01ff728 100644 --- a/lib/Conversion/LlvmToNeura/LlvmToNeuraPatterns.td +++ b/lib/Conversion/LlvmToNeura/LlvmToNeuraPatterns.td @@ -4,23 +4,8 @@ include "mlir/Dialect/LLVMIR/LLVMOps.td" include "NeuraDialect/NeuraOps.td" // Floating point binary operations. -def : Pat< - (LLVM_FAddOp $lhs, $rhs, $_fastmath), - (Neura_FAddOp $lhs, $rhs) ->; - def : Pat< (LLVM_FSubOp $lhs, $rhs, $_fastmath), (Neura_FSubOp $lhs, $rhs) >; -def : Pat< - (LLVM_ConstantOp $value), - (Neura_ConstantOp $value) ->; - -def : Pat< - (LLVM_OrOp $lhs, $rhs), - (Neura_OrOp $lhs, $rhs) ->; - diff --git a/lib/NeuraDialect/CMakeLists.txt b/lib/NeuraDialect/CMakeLists.txt index e906523a..50532491 100644 --- a/lib/NeuraDialect/CMakeLists.txt +++ b/lib/NeuraDialect/CMakeLists.txt @@ -1,13 +1,34 @@ +# Set include paths for TableGen +set(MLIR_TABLEGEN_INCLUDES + "-I${PROJECT_SOURCE_DIR}/include" + "-I${PROJECT_SOURCE_DIR}/include/NeuraDialect" + "-I${CMAKE_CURRENT_BINARY_DIR}/include/NeuraDialect") + +# Generate TableGen files +set(LLVM_TARGET_DEFINITIONS ${PROJECT_SOURCE_DIR}/include/NeuraDialect/Neura.td) +mlir_tablegen(Neura.h.inc -gen-op-decls ${MLIR_TABLEGEN_INCLUDES}) +mlir_tablegen(Neura.cpp.inc -gen-op-defs ${MLIR_TABLEGEN_INCLUDES}) +mlir_tablegen(NeuraDialect.h.inc -gen-dialect-decls ${MLIR_TABLEGEN_INCLUDES}) 
+mlir_tablegen(NeuraDialect.cpp.inc -gen-dialect-defs ${MLIR_TABLEGEN_INCLUDES}) +mlir_tablegen(NeuraTypes.h.inc -gen-typedef-decls ${MLIR_TABLEGEN_INCLUDES}) +mlir_tablegen(NeuraTypes.cpp.inc -gen-typedef-defs ${MLIR_TABLEGEN_INCLUDES}) +add_public_tablegen_target(MLIRNeuraDialectIncGen) + +# Add the dialect library add_mlir_dialect_library(MLIRNeura Neura.cpp + NeuraTypes.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/NeuraDialect DEPENDS - MLIRNeuraIncGen + MLIRNeuraDialectIncGen LINK_LIBS PUBLIC MLIRIR MLIRSupport + MLIRInferTypeOpInterface ) -add_subdirectory(Transforms) -# add_subdirectory(Conversion) \ No newline at end of file +add_subdirectory(Transforms) \ No newline at end of file diff --git a/lib/NeuraDialect/Neura.cpp b/lib/NeuraDialect/Neura.cpp index a9d80ce1..fb19161d 100644 --- a/lib/NeuraDialect/Neura.cpp +++ b/lib/NeuraDialect/Neura.cpp @@ -1,24 +1,92 @@ +// #include "NeuraDialect/NeuraDialect.h" +// #include "NeuraDialect/NeuraOps.h" +// #include "NeuraDialect/NeuraTypes.h" + +// using namespace mlir; +// using namespace mlir::neura; + +// // Include the generated operation classes first +// #define GET_OP_CLASSES +// #include "NeuraDialect/Neura.cpp.inc" + +// // Include the generated definitions +// #include "NeuraDialect/NeuraDialect.cpp.inc" + +// // void mlir::neura::registerNeuraDialect() { +// // registerDialect(); +// // } + +// // Add any additional method implementations needed +// void NeuraDialect::initialize() { +// addOperations< +// #define GET_OP_LIST +// #include "NeuraDialect/Neura.cpp.inc" +// >(); + +// addTypes(); +// } + +#include "mlir/IR/DialectImplementation.h" // Required for AsmPrinter/Parser #include "NeuraDialect/NeuraDialect.h" #include "NeuraDialect/NeuraOps.h" - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/DialectRegistry.h" -#include "mlir/IR/OpImplementation.h" +#include "NeuraDialect/NeuraTypes.h" using namespace mlir; using namespace mlir::neura; 
+// Include the generated dialect definitions (type ID + constructor/destructor) #include "NeuraDialect/NeuraDialect.cpp.inc" +// Include the generated operation classes first +#define GET_OP_CLASSES +#include "NeuraDialect/Neura.cpp.inc" + +// NeuraDialect::NeuraDialect(MLIRContext *context) +// : Dialect(getDialectNamespace(), context, TypeID::get()) { +// initialize(); +// } + void NeuraDialect::initialize() { addOperations< - #define GET_OP_LIST - #include "NeuraDialect/Neura.cpp.inc" +#define GET_OP_LIST +#include "NeuraDialect/Neura.cpp.inc" >(); + + addTypes(); } -#define GET_OP_CLASSES -#include "NeuraDialect/Neura.cpp.inc" +// Type parsing/printing +Type NeuraDialect::parseType(DialectAsmParser &parser) const { + StringRef keyword; + if (parser.parseKeyword(&keyword)) + return Type(); + + if (keyword == PredicatedValue::name) { + return PredicatedValue::parse(parser); + } + + parser.emitError(parser.getNameLoc()) << "unknown Neura type: " << keyword; + return Type(); +} + +void NeuraDialect::printType(Type type, DialectAsmPrinter &printer) const { + if (auto predType = dyn_cast(type)) { + printer << PredicatedValue::name; + predType.print(printer); + return; + } + llvm_unreachable("Unknown Neura type"); +} +// Attribute parsing/printing +Attribute NeuraDialect::parseAttribute(DialectAsmParser &parser, Type type) const { + // Currently no custom attributes to parse + parser.emitError(parser.getNameLoc()) << "unknown Neura attribute"; + return Attribute(); +} + +void NeuraDialect::printAttribute(Attribute attr, DialectAsmPrinter &printer) const { + // Currently no custom attributes to print + llvm_unreachable("Unknown Neura attribute"); +} diff --git a/lib/NeuraDialect/NeuraTypes.cpp b/lib/NeuraDialect/NeuraTypes.cpp new file mode 100644 index 00000000..a51416b9 --- /dev/null +++ b/lib/NeuraDialect/NeuraTypes.cpp @@ -0,0 +1,33 @@ +#include "NeuraDialect/NeuraTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" + +using 
namespace mlir; +using namespace mlir::neura; + +Type PredicatedValue::parse(AsmParser &parser) { + // Parse: !neura.predicated + Type valueType, predicateType; + + if (parser.parseLess() || + parser.parseType(valueType) || + parser.parseComma() || + parser.parseType(predicateType) || + parser.parseGreater()) + return Type(); + + // Verify predicate is i1 + auto intType = mlir::dyn_cast(predicateType); + if (!intType || !intType.isInteger(1)) { + parser.emitError(parser.getNameLoc()) + << "predicate type must be i1, got " << predicateType; + return Type(); + } + + return get(parser.getContext(), valueType, intType); +} + +void PredicatedValue::print(AsmPrinter &printer) const { + printer << "<" << getValueType() << ", " + << getPredicateType() << ">"; +} \ No newline at end of file diff --git a/lib/NeuraDialect/Transforms/AssignAcceleratorPass.cpp b/lib/NeuraDialect/Transforms/AssignAcceleratorPass.cpp index c675865a..1e986408 100644 --- a/lib/NeuraDialect/Transforms/AssignAcceleratorPass.cpp +++ b/lib/NeuraDialect/Transforms/AssignAcceleratorPass.cpp @@ -34,7 +34,7 @@ struct AssignAcceleratorPass : public PassWrapper createAssignAcceleratorPass() { diff --git a/lib/NeuraDialect/Transforms/CMakeLists.txt b/lib/NeuraDialect/Transforms/CMakeLists.txt index 0dc9bb7f..b5dbb4f9 100644 --- a/lib/NeuraDialect/Transforms/CMakeLists.txt +++ b/lib/NeuraDialect/Transforms/CMakeLists.txt @@ -2,9 +2,12 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) add_mlir_library( MLIRNeuraTransforms - InsertMovPass.cpp + InsertDataMovPass.cpp + InsertCtrlMovPass.cpp FusePatternsPass.cpp AssignAcceleratorPass.cpp + TransformCtrlToDataFlowPass.cpp + LeveragePredicatedValuePass.cpp DEPENDS MLIRNeuraTransformsIncGen @@ -15,9 +18,5 @@ add_mlir_library( MLIRSupport MLIRTransforms MLIRNeura - # MLIRNeuraArithToNeuraPass - # MLIRNeuraLlvmToNeuraPass ${dialect_libs} -) -# add_subdirectory(ArithToNeura) -# add_subdirectory(LlvmToNeura) \ No newline at end of file +) \ No newline at 
end of file diff --git a/lib/NeuraDialect/Transforms/FusePatternsPass.cpp b/lib/NeuraDialect/Transforms/FusePatternsPass.cpp index b4246efb..b8636ebb 100644 --- a/lib/NeuraDialect/Transforms/FusePatternsPass.cpp +++ b/lib/NeuraDialect/Transforms/FusePatternsPass.cpp @@ -47,7 +47,7 @@ struct FuseFAddFAddPattern : public OpRewritePattern { Type type = second.getType(); auto fused = rewriter.create( - loc, type, first.getLhs(), first.getRhs(), tail); + loc, type, first.getLhs(), first.getRhs(), tail, Value()); rewriter.replaceOp(second, fused.getResult()); rewriter.eraseOp(first); @@ -88,7 +88,7 @@ struct FuseFMulFAddPattern : public OpRewritePattern { Type type = add.getType(); auto fused = rewriter.create( - loc, type, fmul.getLhs(), fmul.getRhs(), other); + loc, type, fmul.getLhs(), fmul.getRhs(), other, Value()); rewriter.replaceOp(add, fused.getResult()); rewriter.eraseOp(fmul); diff --git a/lib/NeuraDialect/Transforms/InsertMovPass.cpp b/lib/NeuraDialect/Transforms/InsertCtrlMovPass.cpp similarity index 71% rename from lib/NeuraDialect/Transforms/InsertMovPass.cpp rename to lib/NeuraDialect/Transforms/InsertCtrlMovPass.cpp index 3cd1afab..9f9d6c6c 100644 --- a/lib/NeuraDialect/Transforms/InsertMovPass.cpp +++ b/lib/NeuraDialect/Transforms/InsertCtrlMovPass.cpp @@ -9,23 +9,23 @@ using namespace mlir; -#define GEN_PASS_DEF_INSERTMOV +#define GEN_PASS_DEF_InsertCtrlMov #include "NeuraDialect/NeuraPasses.h.inc" namespace { -struct InsertMovForNeuraOps : public RewritePattern { - InsertMovForNeuraOps(MLIRContext *context) +struct InsertCtrlMovForNeuraOps : public RewritePattern { + InsertCtrlMovForNeuraOps(MLIRContext *context) : RewritePattern(/*matchAnyOpTypeTag=*/MatchAnyOpTypeTag(), /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (op->getDialect()->getNamespace() != "neura" || - isa(op)) { + isa(op)) { return failure(); } // Skips ops that already being inserted mov on the operands. 
bool allInputsAreMov = llvm::all_of(op->getOperands(), [](Value v) { - return isa_and_nonnull(v.getDefiningOp()); + return isa_and_nonnull(v.getDefiningOp()); }); if (allInputsAreMov) { return failure(); @@ -33,7 +33,7 @@ struct InsertMovForNeuraOps : public RewritePattern { // Makes sure none of the operand has being processed. bool hasAnyMovInput = llvm::any_of(op->getOperands(), [](Value v) { - return isa_and_nonnull(v.getDefiningOp()); + return isa_and_nonnull(v.getDefiningOp()); }); assert(!hasAnyMovInput && "Unexpected: operand already wrapped in neura.mov"); @@ -41,10 +41,10 @@ struct InsertMovForNeuraOps : public RewritePattern { // Wraps operands in mov. SmallVector newOperands; - for (Value operand : op->getOperands()) { - auto mov = rewriter.create(loc, operand.getType(), operand); - newOperands.push_back(mov); - } + // for (Value operand : op->getOperands()) { + // auto mov = rewriter.create(loc, operand.getType(), operand); + // newOperands.push_back(mov); + // } // Clones op with new operands. 
OperationState state(loc, op->getName()); @@ -58,13 +58,13 @@ struct InsertMovForNeuraOps : public RewritePattern { } }; -struct InsertMovPass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertMovPass) +struct InsertCtrlMovPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertCtrlMovPass) - StringRef getArgument() const override { return "insert-mov"; } + StringRef getArgument() const override { return "insert-ctrl-mov"; } StringRef getDescription() const override { - return "Insert neura.mov before and after all neura dialect operations."; + return "Insert neura.ctrl_mov before all neura dialect operations."; } void getDependentDialects(DialectRegistry ®istry) const override { @@ -73,7 +73,7 @@ struct InsertMovPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); + patterns.add(&getContext()); FrozenRewritePatternSet frozen(std::move(patterns)); ModuleOp module_op = getOperation(); @@ -96,8 +96,8 @@ struct InsertMovPass namespace mlir { namespace neura { -std::unique_ptr createInsertMovPass() { - return std::make_unique(); +std::unique_ptr createInsertCtrlMovPass() { + return std::make_unique(); } } // namespace neura diff --git a/lib/NeuraDialect/Transforms/InsertDataMovPass.cpp b/lib/NeuraDialect/Transforms/InsertDataMovPass.cpp new file mode 100644 index 00000000..15ac86a8 --- /dev/null +++ b/lib/NeuraDialect/Transforms/InsertDataMovPass.cpp @@ -0,0 +1,111 @@ +#include "NeuraDialect/NeuraDialect.h" +#include "NeuraDialect/NeuraOps.h" +#include "NeuraDialect/NeuraPasses.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +#define GEN_PASS_DEF_InsertDataMov +#include "NeuraDialect/NeuraPasses.h.inc" + +namespace { +struct InsertDataMovForNeuraOps : public 
RewritePattern { + InsertDataMovForNeuraOps(MLIRContext *context) + : RewritePattern(/*matchAnyOpTypeTag=*/MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + + LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { + if (op->getDialect()->getNamespace() != "neura" || + isa(op)) { + return failure(); + } + + // Skips ops that already being inserted mov on the operands. + bool allInputsAreMov = llvm::all_of(op->getOperands(), [](Value v) { + return isa_and_nonnull(v.getDefiningOp()); + }); + if (allInputsAreMov) { + return failure(); + } + + // Makes sure none of the operand has being processed. + bool hasAnyMovInput = llvm::any_of(op->getOperands(), [](Value v) { + return isa_and_nonnull(v.getDefiningOp()); + }); + assert(!hasAnyMovInput && "Unexpected: operand already wrapped in neura.mov"); + + Location loc = op->getLoc(); + + // Wraps operands in mov. + SmallVector newOperands; + for (Value operand : op->getOperands()) { + auto mov = rewriter.create(loc, operand.getType(), operand); + newOperands.push_back(mov); + } + + // Clones op with new operands. + OperationState state(loc, op->getName()); + state.addOperands(newOperands); + state.addTypes(op->getResultTypes()); + state.addAttributes(op->getAttrs()); + + // Copies successors for terminator operations. 
+ if (op->hasTrait()) { + for (Block *successor : op->getSuccessors()) { + state.addSuccessors(successor); + } + } + + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; + +struct InsertDataMovPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertDataMovPass) + + StringRef getArgument() const override { return "insert-data-mov"; } + StringRef getDescription() const override { + return "Insert neura.data_mov before all neura dialect operations."; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + FrozenRewritePatternSet frozen(std::move(patterns)); + + ModuleOp module_op = getOperation(); + + // Applies to every region inside the module (regardless of func type, + // e.g., mlir func or llvm func). + module_op.walk([&](Operation *op) { + if (!op->getRegions().empty()) { + for (Region ®ion : op->getRegions()) { + if (failed(applyPatternsAndFoldGreedily(region, frozen))) { + signalPassFailure(); + } + } + } + }); + } +}; +} // namespace + +namespace mlir { +namespace neura { + +std::unique_ptr createInsertDataMovPass() { + return std::make_unique(); +} + +} // namespace neura +} // namespace mlir diff --git a/lib/NeuraDialect/Transforms/LeveragePredicatedValuePass.cpp b/lib/NeuraDialect/Transforms/LeveragePredicatedValuePass.cpp new file mode 100644 index 00000000..74f8b93a --- /dev/null +++ b/lib/NeuraDialect/Transforms/LeveragePredicatedValuePass.cpp @@ -0,0 +1,174 @@ +#include "NeuraDialect/NeuraDialect.h" +#include "NeuraDialect/NeuraOps.h" +#include "NeuraDialect/NeuraTypes.h" +#include "NeuraDialect/NeuraPasses.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + 
+#define GEN_PASS_DEF_LeveragePredicatedValue +#include "NeuraDialect/NeuraPasses.h.inc" + +namespace { +struct applyPredicatedDataType : public RewritePattern { + applyPredicatedDataType(MLIRContext *context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + + LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { + llvm::errs() << "Processing op: " << *op << "\n"; + + // Skips if not a Neura op or already using predicated values. + if (op->getDialect()->getNamespace() != "neura") { + llvm::errs() << "Skipping non-Neura op\n"; + return failure(); + } + + if (llvm::any_of(op->getResultTypes(), + [](Type t) { return mlir::isa(t); })) { + llvm::errs() << "Skipping already predicated op\n"; + return failure(); + } + + // Converts result types to predicated form. + SmallVector newResults; + for (Type t : op->getResultTypes()) { + auto predicatedTy = mlir::neura::PredicatedValue::get( + op->getContext(), + t, + rewriter.getI1Type()); + newResults.push_back(predicatedTy); + } + + // Clones the operation with new result types. + OperationState state(op->getLoc(), op->getName()); + state.addOperands(op->getOperands()); + state.addTypes(newResults); + state.addAttributes(op->getAttrs()); + Operation *newOp = rewriter.create(state); + + // Replaces the old op with the new one. 
+ rewriter.replaceOp(op, newOp->getResults()); + llvm::errs() << "Converted op to predicated form: " << *newOp << "\n"; + if (!newResults.empty()) { + assert(false); + } + return success(); + } +}; + +struct LeveragePredicatedValuePass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LeveragePredicatedValuePass) + + StringRef getArgument() const override { return "leverage-predicated-value"; } + StringRef getDescription() const override { + return "Convert values to predicated values in Neura dialect operations."; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + + // Processes each function. + module.walk([&](func::FuncOp func) { + // Get operations in topological order (operands before users) + SmallVector orderedOps; + getOperationsInTopologicalOrder(func, orderedOps); + + // Processes each operation in order. + for (Operation *op : orderedOps) { + if (failed(applyPredicatedDataType(op))) { + llvm::errs() << "Failed to convert op to predicated form: " << *op << "\n"; + signalPassFailure(); + return; + } + } + }); + } + +private: + // Gets operations in topological order. + void getOperationsInTopologicalOrder(func::FuncOp func, + SmallVector &ordered) { + DenseSet visited; + func.walk([&](Operation *op) { + // Uses standard DFS to build topological order. + if (visited.contains(op)) + return; + + // Visits operands first. + for (Value operand : op->getOperands()) { + if (auto defOp = operand.getDefiningOp()) { + if (!visited.contains(defOp)) { + visited.insert(defOp); + ordered.push_back(defOp); + } + } + } + + // Then visits current op. + if (!visited.contains(op)) { + visited.insert(op); + ordered.push_back(op); + } + }); + } + + // Converts a single operation to use predicated values. 
+ LogicalResult applyPredicatedDataType(Operation *op) { + llvm::errs() << "Processing op: " << *op << "\n"; + + // Skips if not a Neura op. + if (op->getDialect()->getNamespace() != "neura") { + llvm::errs() << "Skipping non-Neura op\n"; + return success(); + } + + // Skips if no results or already predicated. + if (op->getNumResults() == 0 || + llvm::any_of(op->getResultTypes(), + [](Type t) { return mlir::isa(t); })) { + return success(); + } + + // Converts result types to predicated form. + OpBuilder builder(op); + SmallVector newResults; + for (Type t : op->getResultTypes()) { + auto predicatedTy = mlir::neura::PredicatedValue::get( + op->getContext(), + t, + builder.getI1Type()); + newResults.push_back(predicatedTy); + } + + // Clones with new result types. + OperationState state(op->getLoc(), op->getName()); + state.addOperands(op->getOperands()); + state.addTypes(newResults); + state.addAttributes(op->getAttrs()); + Operation *newOp = builder.create(state); + + // Replaces old op. 
+ op->replaceAllUsesWith(newOp); + op->erase(); + return success(); + } +}; +} // namespace + +namespace mlir { +namespace neura { + +std::unique_ptr createLeveragePredicatedValuePass() { + return std::make_unique(); +} + +} // namespace neura +} // namespace mlir \ No newline at end of file diff --git a/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp b/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp new file mode 100644 index 00000000..25461697 --- /dev/null +++ b/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp @@ -0,0 +1,252 @@ +#include "Common/AcceleratorAttrs.h" +#include "NeuraDialect/NeuraDialect.h" +#include "NeuraDialect/NeuraOps.h" +#include "NeuraDialect/NeuraPasses.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" + +using namespace mlir; + +#define GEN_PASS_DEF_TransformCtrlToDataFlow +#include "NeuraDialect/NeuraPasses.h.inc" + +// Processes a block recursively, cloning its operations into the entry block. +void processBlockRecursively(Block *block, Block &entry_block, Value predicate, OpBuilder &builder, + SmallVector &results, DenseSet &visited_blocks, + DenseMap &arg_mapping, + DenseMap &value_mapping) { + // Checks if the block has already been visited. + if (visited_blocks.contains(block)) { + llvm::errs() << "Skipping already visited block:\n"; + block->dump(); + return; + } + + // Marks the block as visited. + visited_blocks.insert(block); + + llvm::errs() << "Processing block:\n"; + block->dump(); + + // Handle. block arguments first. + for (BlockArgument arg : block->getArguments()) { + llvm::errs() << "Processing block argument: " << arg << "\n"; + + // Checks if we already have a mapping for this argument. 
+ if (auto mapped = arg_mapping.lookup(arg)) { + llvm::errs() << "Found existing mapping for argument\n"; + continue; + } + + builder.setInsertionPointToEnd(&entry_block); + // Creates a new constant operation with zero value and true predicate. + OperationState state(arg.getLoc(), neura::ConstantOp::getOperationName()); + state.addAttribute("value", builder.getZeroAttr(arg.getType())); + state.addAttribute("predicate", builder.getBoolAttr(true)); + state.addTypes(arg.getType()); + Value false_val = builder.create(state)->getResult(0); + + llvm::errs() << "Creating false_val: \n"; + false_val.dump(); + auto sel = builder.create( + arg.getLoc(), arg.getType(), arg, false_val, predicate); + + llvm::errs() << "Created sel operation for argument:\n"; + sel->dump(); + + // Stores mapping. + arg_mapping.try_emplace(arg, sel.getResult()); + value_mapping[arg] = sel.getResult(); + results.push_back(sel.getResult()); + } + + // Processes operations. + SmallVector ops_to_process; + for (Operation &op : *block) { + ops_to_process.push_back(&op); + } + + for (Operation *op : ops_to_process) { + llvm::errs() << "Processing operation:\n"; + op->dump(); + + if (op->hasTrait()) { + if (auto br = dyn_cast(op)) { + llvm::errs() << "Found unconditional branch\n"; + for (Value operand : br.getOperands()) { + if (auto mapped = value_mapping.lookup(operand)) { + results.push_back(mapped); + } else { + results.push_back(operand); + } + } + } else if (auto cond_br = dyn_cast(op)) { + llvm::errs() << "Found conditional branch\n"; + Value cond = cond_br.getCondition(); + auto not_cond = builder.create(cond_br.getLoc(), cond.getType(), cond); + + SmallVector true_results, false_results; + processBlockRecursively(cond_br.getTrueDest(), entry_block, cond, + builder, true_results, visited_blocks, arg_mapping, value_mapping); + processBlockRecursively(cond_br.getFalseDest(), entry_block, not_cond.getResult(), + builder, false_results, visited_blocks, arg_mapping, value_mapping); + + 
builder.setInsertionPointToEnd(&entry_block); + for (auto [true_result, false_result] : llvm::zip(true_results, false_results)) { + auto sel = builder.create( + op->getLoc(), true_result.getType(), true_result, false_result, cond); + value_mapping[sel.getResult()] = sel.getResult(); + results.push_back(sel.getResult()); + } + } else if (auto ret = dyn_cast(op)) { + llvm::errs() << "Found Return\n"; + for (Value operand : ret.getOperands()) { + if (auto mapped = value_mapping.lookup(operand)) { + results.push_back(mapped); + } else { + results.push_back(operand); + } + } + } else { + // Handle other terminators if needed + llvm::errs() << "Found unexpected terminator operation:\n"; + op->dump(); + assert(false && "Unexpected terminator operation in block"); + } + } + + builder.setInsertionPointToEnd(&entry_block); + Operation *cloned_op = builder.clone(*op); + + // Replaces operands with mapped values. + for (unsigned i = 0; i < cloned_op->getNumOperands(); ++i) { + Value operand = cloned_op->getOperand(i); + if (auto mapped = value_mapping.lookup(operand)) { + cloned_op->setOperand(i, mapped); + } + } + + if (!cloned_op->hasTrait()) { + cloned_op->insertOperands(cloned_op->getNumOperands(), predicate); + } + + // Stores mappings and results. 
+ for (unsigned i = 0; i < op->getNumResults(); ++i) { + Value orig_result = op->getResult(i); + Value new_result = cloned_op->getResult(i); + value_mapping[orig_result] = new_result; + results.push_back(new_result); + } + } + llvm::errs() << "[cheng] after processing entry_block:\n"; + entry_block.dump(); +} + +namespace { +struct TransformCtrlToDataFlowPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TransformCtrlToDataFlowPass) + + StringRef getArgument() const override { return "transform-ctrl-to-data-flow"; } + StringRef getDescription() const override { + return "Flattens control flow into predicated linear SSA for Neura dialect."; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + + module.walk([&](func::FuncOp func) { + llvm::errs() << "Processing function: "; + func.dump(); + + if (!func->hasAttr(mlir::accel::kAcceleratorAttr)) + return; + + auto target = func->getAttrOfType(mlir::accel::kAcceleratorAttr); + if (!target || target.getValue() != mlir::accel::kNeuraTarget) + return; + + Block &entry_block = func.getBody().front(); + llvm::errs() << "Entry block before processing:\n"; + entry_block.dump(); + + OpBuilder builder(&entry_block, entry_block.begin()); + + // Check for terminator + Operation *terminator = nullptr; + if (!entry_block.empty()) { + terminator = &entry_block.back(); + } + + auto cond_br = dyn_cast_or_null(terminator); + if (!cond_br) { + llvm::errs() << "No conditional branch found in entry block\n"; + return; + } + + // Get condition and create not condition + Location loc = cond_br.getLoc(); + Value cond = cond_br.getCondition(); + builder.setInsertionPoint(cond_br); + auto not_cond = builder.create(loc, cond.getType(), cond); + + // Processes branches. 
+ DenseMap arg_mapping; + DenseMap value_mapping; + DenseSet visited_blocks; + SmallVector true_results, false_results; + + processBlockRecursively(cond_br.getTrueDest(), entry_block, cond, + builder, true_results, visited_blocks, arg_mapping, value_mapping); + processBlockRecursively(cond_br.getFalseDest(), entry_block, not_cond.getResult(), + builder, false_results, visited_blocks, arg_mapping, value_mapping); + + llvm::errs() << "Entry block after processing:\n"; + entry_block.dump(); + + // Creates final return operation. + if (!true_results.empty() && !false_results.empty()) { + builder.setInsertionPoint(cond_br); + auto sel = builder.create( + loc, true_results[0].getType(), true_results[0], false_results[0], cond); + builder.create(loc, sel.getResult()); + } + + // Replaces all uses with mapped values. + for (auto &[orig, mapped] : value_mapping) { + orig.replaceAllUsesWith(mapped); + } + + // Erases the conditional branch. + cond_br->erase(); + + // Finally erases all other blocks. 
+ SmallVector blocks_to_erase; + for (Block &block : llvm::make_early_inc_range(func.getBody())) { + if (&block != &entry_block) { + blocks_to_erase.push_back(&block); + } + } + + for (Block *block : blocks_to_erase) { + block->dropAllReferences(); + block->erase(); + } + + llvm::errs() << "Function after transformation:\n"; + func.dump(); + }); + } +}; +} // namespace + +namespace mlir::neura { +std::unique_ptr createTransformCtrlToDataFlowPass() { + return std::make_unique(); +} +} // namespace mlir::neura \ No newline at end of file diff --git a/test/neura/arith_add.mlir b/test/neura/arith_add.mlir index 5c605b77..86ecefa7 100644 --- a/test/neura/arith_add.mlir +++ b/test/neura/arith_add.mlir @@ -1,10 +1,10 @@ -// RUN: mlir-neura-opt --lower-arith-to-neura --insert-mov %s | FileCheck %s +// RUN: mlir-neura-opt --lower-arith-to-neura --insert-data-mov %s | FileCheck %s func.func @test(%a: f32) -> f32 { %b = arith.constant 2.0 : f32 %res = arith.addf %a, %b : f32 - // CHECK: neura.mov - // CHECK: neura.mov + // CHECK: neura.data_mov + // CHECK: neura.data_mov // CHECK: neura.fadd return %res : f32 } diff --git a/test/neura/fadd_fadd.mlir b/test/neura/fadd_fadd.mlir index fe5aa9de..da1aef44 100644 --- a/test/neura/fadd_fadd.mlir +++ b/test/neura/fadd_fadd.mlir @@ -1,5 +1,5 @@ // Applies pattern fusion before mov insertion. -// RUN: mlir-neura-opt --lower-arith-to-neura --fuse-patterns --insert-mov %s | FileCheck %s +// RUN: mlir-neura-opt --lower-arith-to-neura --fuse-patterns --insert-data-mov %s | FileCheck %s func.func @test(%a: f32, %b: f32) -> f32 { %c = arith.constant 2.0 : f32 diff --git a/test/neura/for_loop/test.mlir b/test/neura/for_loop/test.mlir index 1c72782d..b3c00bd5 100644 --- a/test/neura/for_loop/test.mlir +++ b/test/neura/for_loop/test.mlir @@ -1,5 +1,5 @@ // Compiles the original kernel to mlir, then lower back to llvm, eventually binary. 
-// RUN: clang++ -S -emit-llvm -O2 -o %t-kernel.ll kernel.cpp +// RUN: clang++ -S -emit-llvm -O1 -o %t-kernel.ll kernel.cpp // RUN: mlir-translate --import-llvm %t-kernel.ll -o %t-kernel.mlir // TODO: Enable --insert-mov once the backward ctrl flow mov is supported. @@ -7,6 +7,7 @@ // RUN: mlir-neura-opt \ // RUN: --assign-accelerator \ // RUN: --lower-llvm-to-neura \ +// RN: --transform-ctrl-to-data-flow \ // RUN: --fuse-patterns \ // RN: --insert-mov \ // RUN: %t-kernel.mlir | FileCheck %s diff --git a/test/neura/interpreter/add.mlir b/test/neura/interpreter/add.mlir new file mode 100644 index 00000000..836c625d --- /dev/null +++ b/test/neura/interpreter/add.mlir @@ -0,0 +1,13 @@ +// RUN: neura-interpreter %s | FileCheck %s + +module { + func.func @test() -> f32 { + %arg0 = "neura.constant"() <{value = 9.0 : f32}> : () -> f32 + %cst = "neura.constant"() <{value = 2.0 : f32}> : () -> f32 + %0 = "neura.data_mov"(%arg0) : (f32) -> f32 + %1 = "neura.data_mov"(%cst) : (f32) -> f32 + %2 = "neura.fadd"(%0, %1) : (f32, f32) -> f32 + return %2 : f32 + // CHECK: 11.0 + } +} diff --git a/test/neura/interpreter/interpreter.mlir b/test/neura/interpreter/interpreter.mlir index c5e163e7..0160fa01 100644 --- a/test/neura/interpreter/interpreter.mlir +++ b/test/neura/interpreter/interpreter.mlir @@ -4,8 +4,8 @@ module { func.func @test() -> f32 { %arg0 = arith.constant 9.0 : f32 %cst = arith.constant 2.0 : f32 - %0 = "neura.mov"(%arg0) : (f32) -> f32 - %1 = "neura.mov"(%cst) : (f32) -> f32 + %0 = "neura.data_mov"(%arg0) : (f32) -> f32 + %1 = "neura.data_mov"(%cst) : (f32) -> f32 %2 = "neura.fadd"(%0, %1) : (f32, f32) -> f32 return %2 : f32 // CHECK: 11.0 @@ -14,8 +14,8 @@ module { func.func @test_sub() -> f32 { %arg0 = arith.constant 9.0 : f32 %cst = arith.constant 2.0 : f32 - %0 = "neura.mov"(%arg0) : (f32) -> f32 - %1 = "neura.mov"(%cst) : (f32) -> f32 + %0 = "neura.data_mov"(%arg0) : (f32) -> f32 + %1 = "neura.data_mov"(%cst) : (f32) -> f32 %2 = "neura.fsub"(%0, %1) : (f32, 
f32) -> f32 return %2 : f32 // CHECK: 7.0 diff --git a/test/neura/interpreter/lower_and_interpreter.mlir b/test/neura/interpreter/lower_and_interpret.mlir similarity index 93% rename from test/neura/interpreter/lower_and_interpreter.mlir rename to test/neura/interpreter/lower_and_interpret.mlir index c76db420..4f9c0371 100644 --- a/test/neura/interpreter/lower_and_interpreter.mlir +++ b/test/neura/interpreter/lower_and_interpret.mlir @@ -19,7 +19,7 @@ // RUN: %t-out.bin > %t-dumped_output.txt -// RUN: mlir-neura-opt --lower-arith-to-neura --insert-mov %s \ +// RUN: mlir-neura-opt --lower-arith-to-neura --insert-data-mov %s \ // RUN: -o %t-neura.mlir // RUN: neura-interpreter %t-neura.mlir >> %t-dumped_output.txt diff --git a/test/neura/interpreter/lower_and_interpret_subf.mlir b/test/neura/interpreter/lower_and_interpret_subf.mlir index 47475f70..4c7d7d01 100644 --- a/test/neura/interpreter/lower_and_interpret_subf.mlir +++ b/test/neura/interpreter/lower_and_interpret_subf.mlir @@ -19,7 +19,7 @@ // RUN: %t-out.bin > %t-dumped_output.txt -// RUN: mlir-neura-opt --lower-arith-to-neura --insert-mov %s \ +// RUN: mlir-neura-opt --lower-arith-to-neura --insert-data-mov %s \ // RUN: -o %t-neura.mlir // RUN: neura-interpreter %t-neura.mlir >> %t-dumped_output.txt diff --git a/test/neura/interpreter/predicated_data.mlir b/test/neura/interpreter/predicated_data.mlir new file mode 100644 index 00000000..42e635cf --- /dev/null +++ b/test/neura/interpreter/predicated_data.mlir @@ -0,0 +1,11 @@ +// RUN: neura-interpreter %s | FileCheck %s + +module { + func.func @test() -> !neura.data { + %arg0 = "neura.constant"() <{value = 9.0 : f32, predicate = true}> : () -> !neura.data + %cst = "neura.constant"() <{value = 2.0 : f32, predicate = false}> : () -> !neura.data + %res = "neura.fadd"(%arg0, %cst) : (!neura.data, !neura.data) -> !neura.data + return %res : !neura.data + // CHECK: Output: 11.000000 (predicate=false) + } +} diff --git a/test/neura/llvm_add.mlir 
b/test/neura/llvm_add.mlir index e761bcbf..d17fafe2 100644 --- a/test/neura/llvm_add.mlir +++ b/test/neura/llvm_add.mlir @@ -1,10 +1,10 @@ -// RUN: mlir-neura-opt --assign-accelerator --lower-llvm-to-neura --insert-mov %s | FileCheck %s +// RUN: mlir-neura-opt --assign-accelerator --lower-llvm-to-neura --insert-data-mov %s | FileCheck %s func.func @test(%a: f32) -> f32 { %b = llvm.mlir.constant(2.0 : f32) : f32 %res = llvm.fadd %a, %b : f32 - // CHECK: [[LHS:%.*]] = "neura.mov"(%{{.*}}) : (f32) -> f32 - // CHECK: [[RHS:%.*]] = "neura.mov"(%{{.*}}) : (f32) -> f32 + // CHECK: [[LHS:%.*]] = "neura.data_mov"(%{{.*}}) : (f32) -> f32 + // CHECK: [[RHS:%.*]] = "neura.data_mov"(%{{.*}}) : (f32) -> f32 // CHECK: [[RES:%.*]] = "neura.fadd"([[LHS]], [[RHS]]) return %res : f32 } diff --git a/test/neura/llvm_sub.mlir b/test/neura/llvm_sub.mlir index 0003cc34..1cf1fbf4 100644 --- a/test/neura/llvm_sub.mlir +++ b/test/neura/llvm_sub.mlir @@ -1,10 +1,10 @@ -// RUN: mlir-neura-opt --assign-accelerator --lower-llvm-to-neura --insert-mov %s | FileCheck %s +// RUN: mlir-neura-opt --assign-accelerator --lower-llvm-to-neura --insert-data-mov %s | FileCheck %s func.func @test(%a: f32) -> f32 { %b = llvm.mlir.constant(2.0 : f32) : f32 %res = llvm.fsub %a, %b : f32 - // CHECK: [[LHS:%.*]] = "neura.mov"(%{{.*}}) : (f32) -> f32 - // CHECK: [[RHS:%.*]] = "neura.mov"(%{{.*}}) : (f32) -> f32 + // CHECK: [[LHS:%.*]] = "neura.data_mov"(%{{.*}}) : (f32) -> f32 + // CHECK: [[RHS:%.*]] = "neura.data_mov"(%{{.*}}) : (f32) -> f32 // CHECK: [[RES:%.*]] = "neura.fsub"([[LHS]], [[RHS]]) return %res : f32 } \ No newline at end of file diff --git a/tools/mlir-neura-opt/CMakeLists.txt b/tools/mlir-neura-opt/CMakeLists.txt index 1867af58..70c06a51 100644 --- a/tools/mlir-neura-opt/CMakeLists.txt +++ b/tools/mlir-neura-opt/CMakeLists.txt @@ -10,7 +10,9 @@ set(LIBS MLIRTransforms MLIROptLib MLIRPass - MLIRIR MLIRParser MLIRSupport + MLIRIR + MLIRParser + MLIRSupport ) target_link_libraries(mlir-neura-opt 
PRIVATE ${LIBS}) \ No newline at end of file diff --git a/tools/neura-interpreter/CMakeLists.txt b/tools/neura-interpreter/CMakeLists.txt index 0e2db039..3447fe99 100644 --- a/tools/neura-interpreter/CMakeLists.txt +++ b/tools/neura-interpreter/CMakeLists.txt @@ -10,7 +10,10 @@ set(LIBS MLIRTransforms MLIROptLib MLIRPass - MLIRIR MLIRParser MLIRSupport + MLIRIR + MLIRParser + MLIRSupport + MLIRInferTypeOpInterface ) target_link_libraries(neura-interpreter PRIVATE ${LIBS}) diff --git a/tools/neura-interpreter/neura-interpreter.cpp b/tools/neura-interpreter/neura-interpreter.cpp index c12b58d0..53cf0eb8 100644 --- a/tools/neura-interpreter/neura-interpreter.cpp +++ b/tools/neura-interpreter/neura-interpreter.cpp @@ -19,6 +19,12 @@ using namespace mlir; +// Data structure to hold both value and predicate. +struct PredicatedData { + float value; + bool predicate; +}; + int main(int argc, char **argv) { if (argc < 2) { llvm::errs() << "Usage: neura-interpreter \n"; @@ -46,7 +52,8 @@ int main(int argc, char **argv) { return 1; } - llvm::DenseMap valueMap; + // Changes map to store PredicatedData instead of just float. 
+ llvm::DenseMap valueMap; for (auto func : module->getOps()) { Block &block = func.getBody().front(); @@ -54,32 +61,69 @@ int main(int argc, char **argv) { for (Operation &op : block.getOperations()) { if (auto constOp = dyn_cast(op)) { auto attr = constOp.getValue(); - - float val = 0.0f; + PredicatedData val{0.0f, true}; // arith constants always have true predicate if (auto floatAttr = llvm::dyn_cast(attr)) { - val = floatAttr.getValueAsDouble(); // or .convertToFloat() + val.value = floatAttr.getValueAsDouble(); } else if (auto intAttr = llvm::dyn_cast(attr)) { - val = static_cast(intAttr.getInt()); // interpret integer as float + val.value = static_cast(intAttr.getInt()); } else { llvm::errs() << "Unsupported constant type in arith.constant\n"; return 1; } valueMap[constOp.getResult()] = val; - } else if (auto movOp = dyn_cast(op)) { + } else if (auto constOp = dyn_cast(op)) { + auto attr = constOp.getValue(); + // Initializes PredicatedData with default values. + PredicatedData val{0.0f, true}; + + // Handles value attribute. + if (auto floatAttr = llvm::dyn_cast(attr)) { + val.value = floatAttr.getValueAsDouble(); + } else if (auto intAttr = llvm::dyn_cast(attr)) { + val.value = static_cast(intAttr.getInt()); + } else { + llvm::errs() << "Unsupported constant type in neura.constant\n"; + return 1; + } + + // Tries getting predicate attribute. + if (auto predAttr = constOp->getAttrOfType("predicate")) { + val.predicate = predAttr.getValue(); + } + + valueMap[constOp.getResult()] = val; + + } else if (auto movOp = dyn_cast(op)) { valueMap[movOp.getResult()] = valueMap[movOp.getOperand()]; + } else if (auto faddOp = dyn_cast(op)) { - float lhs = valueMap[faddOp.getLhs()]; - float rhs = valueMap[faddOp.getRhs()]; - valueMap[faddOp.getResult()] = lhs + rhs; + auto lhs = valueMap[faddOp.getLhs()]; + auto rhs = valueMap[faddOp.getRhs()]; + + // Always performs addition, but combines predicate. 
+ PredicatedData result; + result.value = lhs.value + rhs.value; + result.predicate = lhs.predicate && rhs.predicate; + + valueMap[faddOp.getResult()] = result; } else if (auto fsubOp = dyn_cast(op)) { - float lhs = valueMap[fsubOp.getLhs()]; - float rhs = valueMap[fsubOp.getRhs()]; - valueMap[fsubOp.getResult()] = lhs - rhs; + auto lhs = valueMap[fsubOp.getLhs()]; + auto rhs = valueMap[fsubOp.getRhs()]; + + // Always performs subtraction, but combines predicate. + PredicatedData result; + result.value = lhs.value - rhs.value; + result.predicate = lhs.predicate && rhs.predicate; + valueMap[fsubOp.getResult()] = result; } else if (auto retOp = dyn_cast(op)) { - float result = valueMap[retOp.getOperand(0)]; - llvm::outs() << "[neura-interpreter] Output: " << llvm::format("%.6f", result) << "\n"; + auto result = valueMap[retOp.getOperand(0)]; + llvm::outs() << "[neura-interpreter] Output: " << llvm::format("%.6f", result.value); + if (!result.predicate) { + llvm::outs() << " (predicate=false)"; + } + llvm::outs() << "\n"; } else { llvm::errs() << "Unhandled op: "; op.print(llvm::errs());