Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 70 additions & 2 deletions src/kirin/dialects/lowering/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,17 @@
def lower_FunctionDef(
self, state: lowering.State[ast.AST], node: ast.FunctionDef
) -> lowering.Result:

frame = state.current_frame
slots = tuple(arg.arg for arg in node.args.args)
self.assert_simple_arguments(node.args)
signature = func.Signature(
inputs=tuple(
self.get_hint(state, arg.annotation) for arg in node.args.args
),
output=self.get_hint(state, node.returns),
)
frame = state.current_frame

slots = tuple(arg.arg for arg in node.args.args)
entries: dict[str, ir.SSAValue] = {}
entr_block = ir.Block()
fn_self = entr_block.args.append_from(
Expand Down Expand Up @@ -109,6 +110,73 @@
# NOTE: Python automatically assigns the lambda to the name
frame.defs[node.name] = lambda_stmt.result

def lower_Lambda(
self, state: lowering.State[ast.AST], node: ast.Lambda
) -> lowering.Result:

frame = state.current_frame
slots = tuple(arg.arg for arg in node.args.args)
self.assert_simple_arguments(node.args)
signature = func.Signature(
inputs=tuple(
self.get_hint(state, arg.annotation) for arg in node.args.args
),
output=types.Any,
)
node_name = f"lambda_0x{id(node)}"

entries: dict[str, ir.SSAValue] = {}
entr_block = ir.Block()
fn_self = entr_block.args.append_from(
types.MethodType[list(signature.inputs), signature.output],
node_name + "_self",
)
entries[node_name] = fn_self
for arg, type in zip(node.args.args, signature.inputs):
entries[arg.arg] = entr_block.args.append_from(type, arg.arg)

def callback(frame: lowering.Frame, value: ir.SSAValue):
first_stmt = entr_block.first_stmt
stmt = func.GetField(obj=fn_self, field=len(frame.captures) - 1)
if value.name:
stmt.result.name = value.name
stmt.result.type = value.type
stmt.source = state.source
if first_stmt:
stmt.insert_before(first_stmt)
else:
entr_block.stmts.append(stmt)
return stmt.result

with state.frame(
[node.body], entr_block=entr_block, capture_callback=callback
) as func_frame:
func_frame.defs.update(entries)
func_frame.exhaust()

last_stmt = func_frame.curr_region.blocks[0].last_stmt
rtrn_stmt = func.Return(last_stmt.result)

Check failure on line 158 in src/kirin/dialects/lowering/func.py

View workflow job for this annotation

GitHub Actions / pyright

"result" is not a known attribute of "None" (reportOptionalMemberAccess)

Check failure on line 158 in src/kirin/dialects/lowering/func.py

View workflow job for this annotation

GitHub Actions / pyright

Cannot access attribute "result" for class "Statement"   Attribute "result" is unknown (reportAttributeAccessIssue)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add comment if you are certain the last_stmt.result won't be None? and make pyright happy about it using ignore? Or we should add a check here.

func_frame.curr_block.stmts.append(rtrn_stmt)

first_stmt = func_frame.curr_region.blocks[0].first_stmt
if first_stmt is None:
raise lowering.BuildError("empty lambda body")

func_frame.curr_region.blocks[1].delete()

lambda_stmt = func.Lambda(
tuple(value for value in func_frame.captures.values()),
sym_name=node_name,
slots=slots,
signature=signature,
body=func_frame.curr_region,
)

lambda_stmt.result.name = node_name
frame.push(lambda_stmt)
frame.defs[node_name] = lambda_stmt.result
return lambda_stmt.result

def assert_simple_arguments(self, node: ast.arguments) -> None:
if node.kwonlyargs:
raise lowering.BuildError("keyword-only arguments are not supported")
Expand Down
2 changes: 2 additions & 0 deletions src/kirin/dialects/py/assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,11 @@ def lower_Assign(self, state: lowering.State, node: ast.Assign) -> lowering.Resu
case ast.Assign(
targets=[ast.Name(lhs_name, ast.Store())], value=ast.Name(_, ast.Load())
):

stmt = Alias(
value=result.data[0], target=ir.PyAttr(lhs_name)
) # NOTE: this is guaranteed to be one result

stmt.result.name = lhs_name
current_frame.defs[lhs_name] = current_frame.push(stmt).result
case _:
Expand Down
64 changes: 64 additions & 0 deletions test/lowering/test_lambda_comp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from kirin import ir
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we remove the comp this is not about lambda in list comprehension?

from kirin.prelude import basic
from kirin.dialects import ilist


def test_lambda_comp_with_closure():
@basic(fold=False)
def main(z, r):
return (lambda x: x + z)(r)

assert main(3, 4) == 7


def test_lambda_comp():
@basic(fold=False)
def main(z):
return lambda x: x + z

x = main(3)
assert isinstance(x, ir.Method)
assert x(4) == 7


def test_invoke_from_lambda_comp():

@basic
def foo(a):
return a * 2

@basic(fold=False)
def main(z):
return lambda x: x + foo(z)

x = main(3)

assert isinstance(x, ir.Method)
assert x(4) == 10


def test_lambda_in_lambda():

@basic(fold=False)
def main(z):

def my_foo(a):
return lambda x: x * a

return my_foo(z)

x = main(3)

assert isinstance(x, ir.Method)
assert x(4) == 12


def test_ilist_map():

@basic(fold=False)
def main(z):
return ilist.map(lambda x: x + z, ilist.range(10))

x = main(3)
assert len(x) == 10
assert x.data == [3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
Loading