Skip to content

expand func op pass #500

@andrewcaiuuu

Description

@andrewcaiuuu

Hello, I am trying to write a pass to give each compute op a func op, it is generating IRs that fail to lower.
The pass looks like this, is there anything obviously incorrect?

from ..ast import ast
from .pass_manager import Pass
from hcl_mlir.exceptions import *

class ExpandFunc(Pass):
    """ Convert all funcop into nested funcop """
    def __init__(self):
        super().__init__("expand_func")
        self._ast = None
        self.subfuncs = []

    def visit(self, op):
        if isinstance(op, ast.FuncOp) and op.name == "top":
            self.expand_func(op)
            # print("SUBFUNCS: ", self.subfuncs)
            # print("ORIGINAL BODY: ", op.body)
            op.body = []
            for subfunc in self.subfuncs:
                call_op = ast.CallOp(subfunc.name, subfunc.args, subfunc.return_tensors, subfunc.loc)
                op.body.append(call_op)
                
    
    def apply(self, _ast):
        """Pass entry point"""
        self._ast = _ast
        for op in _ast.region:
            self.visit(op)
        return _ast

    def expand_func(self, scope):
        i = 0
        for op in scope.body:
            # print("EXPAND_FUNC GOT OP: ", op)
            if isinstance(op, ast.ComputeOp):
                lower_func_op = ast.FuncOp(f"sub_func{i}", op.input_tensors, [op], op.loc)
                lower_func_op.level = 1
                self.update_level(lower_func_op)
                self._ast.region.insert(1, lower_func_op)
                self.subfuncs.append(lower_func_op)
                i += 1
        return

This is the module used for testing right now:

A = hcl.placeholder((10,), "A")
def kernel(A):
    B = hcl.compute((10,), lambda x: A[x])
    return B

s = hcl.create_schedule([A], kernel)
print(hcl.lower(s))

which results in this IR, the bolded line is wrong:


#map0 = affine_map<(d0) -> (d0)>  
#map1 = affine_map<() -> (0)>  
#map2 = affine_map<() -> (10)>  
"builtin.module"() ({  
  "func.func"() ({  
  ^bb0(%arg0: memref<10xi32>):  
    "func.call"(%arg0) {callee = @sub_func0} : (memref<10xi32>) -> ()  
    %0 = "memref.alloc"() {name = "tensor_1", operand_segment_sizes = dense<0> : vector<2xi32>} : () -> memref<10xi32>
    "func.return"(%0) : (memref<10xi32>) -> ()
  }) {function_type = (memref<10xi32>) -> memref<10xi32>, itypes = "s", otypes = "s", sym_name = "top"} : () -> ()
  "func.func"() ({
  ^bb0(%arg0: memref<10xi32>):
    "affine.for"() ({
    ^bb0(%arg1: index):
      %0 = "affine.load"(%arg0, %arg1) {from = "A", map = #map0} : (memref<10xi32>, index) -> i32
     "affine.store"(%0, %0, %arg1) {map = #map0, to = "tensor_1"} : (i32, memref<10xi32>, index) -> ()  
      "affine.yield"() : () -> ()  
    }) {loop_name = "x", lower_bound = #map1, op_name = "tensor_1", step = 1 : i32, upper_bound = #map2} : () -> ()
    "func.return"() : () -> ()
  }) {function_type = (memref<10xi32>) -> (), itypes = "s", otypes = "", sym_name = "sub_func0"} : () -> ()
}) : () -> () 

it passes when changed to:


#map0 = affine_map<(d0) -> (d0)>  
#map1 = affine_map<() -> (0)>  
#map2 = affine_map<() -> (10)>  
"builtin.module"() ({  
  "func.func"() ({  
  ^bb0(%arg0: memref<10xi32>):  
    "func.call"(%arg0) {callee = @sub_func0} : (memref<10xi32>) -> ()  
    %0 = "memref.alloc"() {name = "tensor_1", operand_segment_sizes = dense<0> : vector<2xi32>} : () -> memref<10xi32>
    "func.return"(%0) : (memref<10xi32>) -> ()
  }) {function_type = (memref<10xi32>) -> memref<10xi32>, itypes = "s", otypes = "s", sym_name = "top"} : () -> ()
  "func.func"() ({
  ^bb0(%arg0: memref<10xi32>):
    "affine.for"() ({
    ^bb0(%arg1: index):
      %0 = "affine.load"(%arg0, %arg1) {from = "A", map = #map0} : (memref<10xi32>, index) -> i32
      "affine.store"(%0, %arg0, %arg1) {map = #map0, to = "A"} : (i32, memref<10xi32>, index) -> () 
      "affine.yield"() : () -> ()  
    }) {loop_name = "x", lower_bound = #map1, op_name = "tensor_1", step = 1 : i32, upper_bound = #map2} : () -> ()
    "func.return"() : () -> ()
  }) {function_type = (memref<10xi32>) -> (), itypes = "s", otypes = "", sym_name = "sub_func0"} : () -> ()
}) : () -> () 

I have my code in a fork linked here
https://github.com/andrewcaiuuu/heterocl
The module shown above is in the playground folder named simple.py

Thank you very much.

Metadata

Metadata

Assignees

Labels

ASTIssues related to HeteroCL AST

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions