Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ url = 'https://gridtools.github.io/pypi/'
[tool.uv.sources]
atlas4py = {index = "test.pypi"}
dace = [
{git = "https://github.com/GridTools/dace", branch = "romanc/stree-roundtrip", group = "dace-cartesian"},
{git = "https://github.com/romanc/dace", branch = "romanc/stree-v2", group = "dace-cartesian"},
{index = "gridtools", group = "dace-next"}
]

Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def freeze_origin_domain_sdfg(
if node.has_writes(inner_state):
outputs.add(node.data)

nsdfg = state.add_nested_sdfg(inner_sdfg, None, inputs, outputs)
nsdfg = state.add_nested_sdfg(inner_sdfg, inputs, outputs)

_sdfg_add_arrays_and_edges(
field_info, wrapper_sdfg, state, inner_sdfg, nsdfg, inputs, outputs, origin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,5 @@ def visit_MapScope(self, node: tn.MapScope):
self._push_K_loop_in_IJ(node)

def visit_ForScope(self, node: tn.ForScope):
if node.header.itervar.startswith("__k"):
if node.loop.loop_variable.startswith("__k"):
self._push_K_loop_in_IJ(node)
103 changes: 22 additions & 81 deletions src/gt4py/cartesian/gtc/dace/treeir_to_stree.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from dataclasses import dataclass
from types import TracebackType

from dace import __version__ as dace_version, dtypes, nodes, sdfg, subsets
from dace.codegen import control_flow as dcf
from dace import nodes, subsets
from dace.properties import CodeBlock
from dace.sdfg.analysis.schedule_tree import treenodes as tn
from dace.sdfg.state import LoopRegion

from gt4py import eve
from gt4py.cartesian.gtc import common
Expand Down Expand Up @@ -67,10 +67,6 @@ def visit_Tasklet(self, node: tir.Tasklet, ctx: Context) -> None:
ctx.current_scope.children.append(tasklet)

def visit_HorizontalLoop(self, node: tir.HorizontalLoop, ctx: Context) -> None:
# Define axis iteration symbols
for axis in tir.Axis.dims_horizontal():
ctx.tree.symbols[axis.iteration_symbol()] = dtypes.int32

dace_map = nodes.Map(
label=f"horizontal_loop_{id(node)}",
params=[tir.Axis.J.iteration_symbol(), tir.Axis.I.iteration_symbol()],
Expand All @@ -89,12 +85,9 @@ def visit_HorizontalLoop(self, node: tir.HorizontalLoop, ctx: Context) -> None:
self.visit(node.children, ctx=ctx)

def visit_VerticalLoop(self, node: tir.VerticalLoop, ctx: Context) -> None:
# In any case, define the iteration symbol
ctx.tree.symbols[node.iteration_variable] = dtypes.int32

# For serial loops, create a ForScope and add it to the tree
if node.loop_order != common.LoopOrder.PARALLEL:
for_scope = tn.ForScope(header=_for_scope_header(node), children=[])
for_scope = tn.ForScope(loop=_loop_region_for(node), children=[])

with ContextPushPop(ctx, for_scope):
self.visit(node.children, ctx=ctx)
Expand All @@ -118,15 +111,15 @@ def visit_VerticalLoop(self, node: tir.VerticalLoop, ctx: Context) -> None:

def visit_IfElse(self, node: tir.IfElse, ctx: Context) -> None:
if_scope = tn.IfScope(
condition=tn.CodeBlock(node.if_condition_code),
condition=CodeBlock(node.if_condition_code),
children=[],
)

with ContextPushPop(ctx, if_scope):
self.visit(node.children, ctx=ctx)

def visit_While(self, node: tir.While, ctx: Context) -> None:
while_scope = tn.WhileScope(children=[], header=_while_scope_header(node))
while_scope = tn.WhileScope(loop=_loop_region_while(node), children=[])

with ContextPushPop(ctx, while_scope):
self.visit(node.children, ctx=ctx)
Expand All @@ -147,83 +140,31 @@ def visit_TreeRoot(self, node: tir.TreeRoot) -> tn.ScheduleTreeRoot:
return ctx.tree


def _for_scope_header(node: tir.VerticalLoop) -> dcf.ForScope:
"""Header for the tn.ForScope re-using DaCe codegen ForScope.

Only setup the required data, default or mock the rest.

TODO: In DaCe 2.x this will be replaced by an SDFG concept which should
be closer and required less mockup.
def _loop_region_for(node: tir.VerticalLoop) -> LoopRegion:
"""
if not dace_version.startswith("1."):
raise NotImplementedError("DaCe 2.x detected - please fix below code")
if node.loop_order == common.LoopOrder.PARALLEL:
raise ValueError("Parallel vertical loops should be translated to maps instead.")
Translates a vertical loop into a Dace LoopRegion to be used in `tn.ForScope`.

:param node: Vertical loop to translate
:return: DaCe LoopRegion to use in `tn.ForScope`
"""
plus_minus = "+" if node.loop_order == common.LoopOrder.FORWARD else "-"
comparison = "<" if node.loop_order == common.LoopOrder.FORWARD else ">="
iteration_var = node.iteration_variable

for_scope = dcf.ForScope(
condition=CodeBlock(
code=f"{iteration_var} {comparison} {node.bounds_k.end}",
language=dtypes.Language.Python,
),
itervar=iteration_var,
init=node.bounds_k.start,
update=f"{iteration_var} {plus_minus} 1",
# Unused
parent=None, # not Tree parent, CF parent
dispatch_state=lambda _state: "",
last_block=False,
guard=sdfg.SDFGState(),
body=dcf.GeneralBlock(
lambda _state: "",
None,
True,
None,
[],
[],
[],
[],
[],
False,
),
init_edges=[],
return LoopRegion(
label=f"vertical_loop_{id(node)}",
loop_var=iteration_var,
initialize_expr=CodeBlock(f"{iteration_var} = {node.bounds_k.start}"),
condition_expr=CodeBlock(f"{iteration_var} {comparison} {node.bounds_k.end}"),
update_expr=CodeBlock(f"{iteration_var} = {iteration_var} {plus_minus} 1"),
)
# Kill the loop_range test for memlet propagation check going in
dcf.ForScope.loop_range = lambda self: None
return for_scope


def _while_scope_header(node: tir.While) -> dcf.WhileScope:
"""Header for the tn.WhileScope re-using DaCe codegen WhileScope.

Only setup the required data, default or mock the rest.
def _loop_region_while(node: tir.While) -> LoopRegion:
"""
Translates a while loop into a Dace LoopRegion to be used in `tn.WhileScope`.

TODO: In DaCe 2.x this will be replaced by an SDFG concept which should
be closer and required less mockup.
:param node: While loop to translate
:return: DaCe LoopRegion to use in `tn.WhileScope`
"""
if not dace_version.startswith("1."):
raise NotImplementedError("DaCe 2.x detected - please fix below code")

return dcf.WhileScope(
test=CodeBlock(node.condition_code),
# Unused
guard=sdfg.SDFGState(),
dispatch_state=lambda _state: "",
parent=None,
body=dcf.GeneralBlock(
lambda _state: "",
None,
True,
None,
[],
[],
[],
[],
[],
False,
),
last_block=False,
)
return LoopRegion(label=f"while_loop_{id(node)}", condition_expr=CodeBlock(node.condition_code))
46 changes: 22 additions & 24 deletions tests/cartesian_tests/unit_tests/backend_tests/test_dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from dace import nodes
from dace import sdfg as dace_sdfg
from dace.sdfg.state import LoopRegion
import dace.sdfg.analysis.schedule_tree.treenodes as tn

from gt4py.cartesian import backend
Expand Down Expand Up @@ -109,21 +110,22 @@ def test_dace_cpu_kfirst_loop_structure():
manager = SDFGManager(builder)

sdfg = manager.sdfg_via_schedule_tree()
assert len(list(sdfg.states())) == 1, "expect one state"
state = sdfg.states()[0]

# Expect IJ Map and For loop construct (Nested SDFG, four guard states)
assert [node.map.params for node in state.nodes() if isinstance(node, nodes.MapEntry)] == [
["__i", "__j"]
]
for_nested_nodes = [
node.sdfg.nodes() for node in state.nodes() if isinstance(node, nodes.NestedSDFG)
]
assert [isinstance(node, dace_sdfg.SDFGState) for node in for_nested_nodes[0]] == [
True,
True,
True,
True,
]
# Expect a Map for IJ outside
map_entry_nodes = [node for node in state.nodes() if isinstance(node, nodes.MapEntry)]
assert len(map_entry_nodes) == 1, "expect one MapEntry node"
assert map_entry_nodes[0].map.params == ["__i", "__j"]

# Expect LoopRegion for K inside map
nsdfg_nodes = [node for node in state.nodes() if isinstance(node, nodes.NestedSDFG)]
assert len(nsdfg_nodes) == 1
for_nested_nodes = nsdfg_nodes[0].sdfg.nodes()
assert len(for_nested_nodes) == 1
loop_region = for_nested_nodes[0]
assert isinstance(loop_region, LoopRegion)
assert loop_region.loop_variable.startswith("__k")


def test_dace_cpu_KJI_loop_structure():
Expand All @@ -141,17 +143,13 @@ def test_dace_cpu_KJI_loop_structure():
manager = SDFGManager(builder)

sdfg = manager.sdfg_via_schedule_tree()
state = sdfg.states()[0]

# Expect top-level for loop guards (4)
assert [isinstance(node, dace_sdfg.SDFGState) for node in sdfg.states()] == [
True,
True,
True,
True,
]
# Expect LoopRegion for K outside
loop_region: LoopRegion = list(sdfg.all_control_flow_blocks())[0]
assert loop_region.loop_variable.startswith("__k")

# Expect JI Map and in loop_body state (#2)
assert [
node.map.params for node in sdfg.states()[2].nodes() if isinstance(node, nodes.MapEntry)
] == [["__j", "__i"]]
state = loop_region.start_block
assert [node.map.params for node in state.nodes() if isinstance(node, nodes.MapEntry)] == [
["__j", "__i"]
]
Loading