Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion external/dace
Submodule dace updated 1044 files
2 changes: 1 addition & 1 deletion external/gt4py
Submodule gt4py updated 45 files
+7 −0 .github/workflows/daily-ci.yml
+9 −0 CHANGELOG.md
+10 −5 pyproject.toml
+1 −1 src/gt4py/__about__.py
+1 −31 src/gt4py/cartesian/backend/base.py
+1 −1 src/gt4py/cartesian/backend/dace_backend.py
+1 −1 src/gt4py/cartesian/gtc/dace/passes/push_vertical_map_down.py
+22 −81 src/gt4py/cartesian/gtc/dace/treeir_to_stree.py
+4 −0 src/gt4py/cartesian/gtc/numpy/npir.py
+5 −0 src/gt4py/cartesian/gtc/numpy/npir_codegen.py
+5 −3 src/gt4py/cartesian/gtc/numpy/oir_to_npir.py
+6 −0 src/gt4py/next/config.py
+2 −2 src/gt4py/next/embedded/operators.py
+2 −1 src/gt4py/next/ffront/past_passes/type_deduction.py
+3 −7 src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py
+124 −122 src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py
+12 −14 src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py
+9 −5 src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py
+2 −0 src/gt4py/next/program_processors/runners/dace/transformations/__init__.py
+36 −12 src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py
+2 −8 src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py
+7 −2 src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py
+66 −0 src/gt4py/next/program_processors/runners/dace/transformations/scan_loop_unrolling.py
+56 −13 src/gt4py/next/program_processors/runners/dace/transformations/simplify.py
+107 −8 src/gt4py/next/program_processors/runners/dace/transformations/split_access_nodes.py
+22 −6 src/gt4py/next/program_processors/runners/dace/transformations/splitting_tools.py
+0 −4 src/gt4py/next/program_processors/runners/dace/workflow/backend.py
+78 −30 src/gt4py/next/program_processors/runners/dace/workflow/bindings.py
+3 −3 src/gt4py/next/program_processors/runners/dace/workflow/common.py
+8 −2 src/gt4py/next/program_processors/runners/dace/workflow/compilation.py
+1 −1 src/gt4py/next/program_processors/runners/dace/workflow/decoration.py
+12 −0 src/gt4py/next/program_processors/runners/dace/workflow/translation.py
+2 −2 tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py
+22 −24 tests/cartesian_tests/unit_tests/backend_tests/test_dace_backend.py
+0 −1 tests/next_tests/definitions.py
+4 −4 tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py
+76 −0 tests/next_tests/integration_tests/feature_tests/ffront_tests/test_named_collections.py
+2 −2 tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py
+1 −11 tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py
+6 −1 tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py
+186 −49 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py
+93 −0 ...nit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_scan_loop_unrolling.py
+45 −1 .../unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_split_access_node.py
+64 −1 ...ts/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_splitting_tools.py
+90 −165 uv.lock
9 changes: 4 additions & 5 deletions ndsl/dsl/dace/dace_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import enum
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Self

import dace.config
Expand Down Expand Up @@ -326,11 +327,9 @@ def __init__(
)

# Attempt to kill the dace.conf to avoid confusion
if dace.config.Config._cfg_filename:
try:
os.remove(dace.config.Config._cfg_filename)
except OSError:
pass
dace_conf_to_kill = dace.config.Config.cfg_filename()
if dace_conf_to_kill is not None:
Path(dace_conf_to_kill).unlink(missing_ok=True)

self._backend = backend
self.tile_resolution = [tile_nx, tile_nx, tile_nz]
Expand Down
2 changes: 1 addition & 1 deletion ndsl/dsl/dace/labeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def set_label(
if isinstance(sdfg, dace.CompiledSDFG):
return

for state in sdfg.states():
for state in sdfg.nodes():
if sdfg.in_edges(state) == []:
# With the topmost SDFG we have to skip over the
# "init" state
Expand Down
7 changes: 2 additions & 5 deletions ndsl/dsl/dace/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
from dace import SDFG, SDFGState, dtypes
from dace.frontend.common import op_repository as oprepo
from dace.frontend.python.newast import ProgramVisitor
from dace.frontend.python.replacements import (
UfuncInput,
UfuncOutput,
_datatype_converter,
)
from dace.frontend.python.replacements import UfuncInput, UfuncOutput
from dace.frontend.python.replacements.array_manipulation import _datatype_converter

from ndsl.dsl.typing import Float, Int

Expand Down
118 changes: 59 additions & 59 deletions ndsl/dsl/dace/stree/optimizations/axis_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
from typing import Any

import dace
import dace.sdfg.analysis.schedule_tree.treenodes as stree
from dace.properties import CodeBlock
from dace.sdfg.analysis.schedule_tree import treenodes as tn

from ndsl import ndsl_log
from ndsl.dsl.dace.stree.optimizations.memlet_helpers import (
AxisIterator,
no_data_dependencies_on_cartesian_axis,
Expand All @@ -18,25 +17,26 @@
list_index,
swap_node_position_in_tree,
)
from ndsl.logging import ndsl_log


# Buggy passes that should work
PUSH_IFSCOPE_DOWNWARD = False # Crashing the overall stree - bad algorithmics


def _is_axis_map(node: stree.MapScope, axis: AxisIterator) -> bool:
def _is_axis_map(node: tn.MapScope, axis: AxisIterator) -> bool:
"""Returns true if node is a map over the given axis."""
map_parameter = node.node.params
return len(map_parameter) == 1 and map_parameter[0].startswith(axis.as_str())


def _is_axis_for(node: stree.ForScope, axis: AxisIterator) -> bool:
return node.header.itervar.startswith(axis.as_str())
def _is_axis_for(node: tn.ForScope, axis: AxisIterator) -> bool:
return node.loop.loop_variable.startswith(axis.as_str())


def _both_same_single_axis_maps(
first: stree.MapScope,
second: stree.MapScope,
first: tn.MapScope,
second: tn.MapScope,
axis: AxisIterator,
) -> bool:
return (
Expand All @@ -47,8 +47,8 @@ def _both_same_single_axis_maps(


def _can_merge_axis_maps(
first: stree.MapScope,
second: stree.MapScope,
first: tn.MapScope,
second: tn.MapScope,
axis: AxisIterator,
) -> bool:
return _both_same_single_axis_maps(
Expand All @@ -60,7 +60,7 @@ def _can_merge_axis_maps(
)


class InsertOvercomputationGuard(stree.ScheduleNodeTransformer):
class InsertOvercomputationGuard(tn.ScheduleNodeTransformer):
def __init__(
self,
axis_as_string: str,
Expand All @@ -85,13 +85,13 @@ def _execution_condition(self) -> CodeBlock:
f"and ({self._axis_as_string} - {start}) % {step} == 0"
)

def visit_MapScope(self, node: stree.MapScope) -> stree.MapScope:
def visit_MapScope(self, node: tn.MapScope) -> tn.MapScope:
all_children_are_maps = all(
[isinstance(child, stree.MapScope) for child in node.children]
[isinstance(child, tn.MapScope) for child in node.children]
)
if not all_children_are_maps:
if self._merged_range != self._original_range:
if_scope = stree.IfScope(
if_scope = tn.IfScope(
condition=self._execution_condition(), children=node.children
)
# Re-parent to IF
Expand All @@ -105,15 +105,13 @@ def visit_MapScope(self, node: stree.MapScope) -> stree.MapScope:


def _get_next_node(
nodes: list[stree.ScheduleTreeNode],
node: stree.ScheduleTreeNode,
) -> stree.ScheduleTreeNode:
nodes: list[tn.ScheduleTreeNode],
node: tn.ScheduleTreeNode,
) -> tn.ScheduleTreeNode:
return nodes[list_index(nodes, node) + 1]


def _last_node(
nodes: list[stree.ScheduleTreeNode], node: stree.ScheduleTreeNode
) -> bool:
def _last_node(nodes: list[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode) -> bool:
return list_index(nodes, node) >= len(nodes) - 1


Expand All @@ -124,13 +122,13 @@ def _sanitize_axis(axis: AxisIterator, name_to_normalize: str) -> str:
return re.sub(pattern, axis_clean, name_to_normalize)


class NormalizeAxisSymbol(stree.ScheduleNodeVisitor):
class NormalizeAxisSymbol(tn.ScheduleNodeVisitor):
def __init__(self, axis: AxisIterator) -> None:
self.axis = axis

def visit_MapScope(
self,
map_scope: stree.MapScope,
map_scope: tn.MapScope,
axis_replacements: dict[str, str] | None = None,
**kwargs: Any,
) -> None:
Expand All @@ -147,7 +145,7 @@ def visit_MapScope(

def visit_TaskletNode(
self,
node: stree.TaskletNode,
node: tn.TaskletNode,
axis_replacements: dict[str, str] | None = None,
**kwargs: Any,
) -> None:
Expand All @@ -159,7 +157,7 @@ def visit_TaskletNode(
memlets.replace(axis_replacements)


class CartesianAxisMerge(stree.ScheduleNodeTransformer):
class CartesianAxisMerge(tn.ScheduleNodeTransformer):
"""Merge a cartesian axis if they are contiguous in code-flow.

Can do:
Expand Down Expand Up @@ -188,8 +186,8 @@ def __str__(self) -> str:

def _merge_node(
self,
node: stree.ScheduleTreeNode,
nodes: list[stree.ScheduleTreeNode],
node: tn.ScheduleTreeNode,
nodes: list[tn.ScheduleTreeNode],
) -> int:
"""Direct code to the correct resolver for the node (e.g. visitor)

Expand All @@ -198,26 +196,28 @@ def _merge_node(
behavior (e.g. IfScope before ControlFlowScope)
"""

if isinstance(node, stree.MapScope):
if isinstance(node, tn.MapScope):
return self._map_overcompute_merge(node, nodes)
elif PUSH_IFSCOPE_DOWNWARD and isinstance(node, stree.IfScope):

if PUSH_IFSCOPE_DOWNWARD and isinstance(node, tn.IfScope):
return self._push_ifelse_down(node, nodes)
elif isinstance(node, stree.ForScope):

if isinstance(node, tn.ForScope):
return self._for_merge(node, nodes)
elif isinstance(node, stree.TaskletNode):

if isinstance(node, tn.TaskletNode):
return self._push_tasklet_down(node, nodes)
elif isinstance(node, stree.ControlFlowScope):

if isinstance(node, tn.ControlFlowScope):
return self._default_control_flow(node)
else:
ndsl_log.debug(
f" (╯°□°)╯︵ ┻━┻: can't merge {type(node)}. Recursion ends."
)

ndsl_log.debug(f" (╯°□°)╯︵ ┻━┻: can't merge {type(node)}. Recursion ends.")
return 0

def _for_merge(
self,
the_for_scope: stree.ForScope,
nodes: list[stree.ScheduleTreeNode],
the_for_scope: tn.ForScope,
nodes: list[tn.ScheduleTreeNode],
) -> int:
merged = 0

Expand All @@ -229,7 +229,7 @@ def _for_merge(
# Non-cartesian for - can be pushed down if everything merged below
if (
len(the_for_scope.children) == 1
and isinstance(the_for_scope.children[0], stree.MapScope)
and isinstance(the_for_scope.children[0], tn.MapScope)
and _is_axis_map(the_for_scope.children[0], self.axis)
):
swap_node_position_in_tree(the_for_scope, the_for_scope.children[0])
Expand All @@ -239,7 +239,7 @@ def _for_merge(

def _default_control_flow(
self,
the_control_flow: stree.ControlFlowScope,
the_control_flow: tn.ControlFlowScope,
) -> int:
if len(the_control_flow.children) != 0:
return self._merge(the_control_flow)
Expand All @@ -248,8 +248,8 @@ def _default_control_flow(

def _push_tasklet_down(
self,
the_tasklet: stree.TaskletNode,
nodes: list[stree.ScheduleTreeNode],
the_tasklet: tn.TaskletNode,
nodes: list[tn.ScheduleTreeNode],
) -> int:
"""Push tasklet into a consecutive map."""
in_memlets = the_tasklet.input_memlets()
Expand All @@ -269,7 +269,7 @@ def _push_tasklet_down(

# Attempt to push the tasklet in the next map
next_node = nodes[next_index + 1]
if isinstance(next_node, stree.MapScope):
if isinstance(next_node, tn.MapScope):
next_node.children.insert(0, the_tasklet)
the_tasklet.parent = next_node
nodes.remove(the_tasklet)
Expand All @@ -279,8 +279,8 @@ def _push_tasklet_down(

def _push_ifelse_down(
self,
the_if: stree.IfScope,
nodes: list[stree.ScheduleTreeNode],
the_if: tn.IfScope,
nodes: list[tn.ScheduleTreeNode],
) -> int:
merged = 0

Expand All @@ -291,8 +291,8 @@ def _push_ifelse_down(
for else_index in range(if_index + 1, len(nodes)):
else_node = nodes[else_index]
if else_index < len(nodes) and (
isinstance(else_node, stree.ElseScope)
or isinstance(else_node, stree.ElifScope)
isinstance(else_node, tn.ElseScope)
or isinstance(else_node, tn.ElifScope)
):
merged += self._merge_node(else_node, else_node.children)
else:
Expand All @@ -302,17 +302,17 @@ def _push_ifelse_down(

# Gather all first maps - if they do not exists, get out
all_maps = []
if isinstance(the_if.children[0], stree.MapScope):
if isinstance(the_if.children[0], tn.MapScope):
all_maps.append(the_if.children[0])
else:
return merged
for else_index in range(if_index + 1, len(nodes)):
else_node = nodes[else_index]
if else_index < len(nodes) and (
isinstance(else_node, stree.ElseScope)
or isinstance(else_node, stree.ElifScope)
isinstance(else_node, tn.ElseScope)
or isinstance(else_node, tn.ElifScope)
):
if isinstance(else_node.children[0], stree.MapScope):
if isinstance(else_node.children[0], tn.MapScope):
all_maps.append(else_node.children[0])
else:
return merged
Expand All @@ -337,8 +337,8 @@ def _push_ifelse_down(
# Swap ELIF/ELSE & maps
for else_index in range(if_index + 1, len(nodes)):
if else_index < len(nodes) and (
isinstance(nodes[else_index], stree.ElseScope)
or isinstance(nodes[else_index], stree.ElifScope)
isinstance(nodes[else_index], tn.ElseScope)
or isinstance(nodes[else_index], tn.ElifScope)
):
swap_node_position_in_tree(
nodes[else_index], nodes[else_index].children[0]
Expand All @@ -347,15 +347,15 @@ def _push_ifelse_down(
break

# Merge the Maps
assert isinstance(nodes[if_index], stree.MapScope)
assert isinstance(nodes[if_index], tn.MapScope)
merged += self._map_overcompute_merge(nodes[if_index], nodes)

return merged

def _map_overcompute_merge(
self,
the_map: stree.MapScope,
nodes: list[stree.ScheduleTreeNode],
the_map: tn.MapScope,
nodes: list[tn.ScheduleTreeNode],
) -> int:
# End of nodes OR
# Not the right axis
Expand All @@ -369,7 +369,7 @@ def _map_overcompute_merge(
next_node = _get_next_node(nodes, the_map)

# Next node is not a MapScope - no merge
if not isinstance(next_node, stree.MapScope):
if not isinstance(next_node, tn.MapScope):
return 0

# Attempt to merge consecutive maps
Expand Down Expand Up @@ -401,7 +401,7 @@ def _map_overcompute_merge(
merged_range=merged_range,
original_range=second_range,
).visit(next_node)
merged_children: list[stree.MapScope] = [
merged_children: list[tn.MapScope] = [
*first_map.children,
*second_map.children,
]
Expand All @@ -419,7 +419,7 @@ def _map_overcompute_merge(

return 1

def _merge(self, node: stree.ScheduleTreeRoot | stree.ScheduleTreeScope) -> int:
def _merge(self, node: tn.ScheduleTreeRoot | tn.ScheduleTreeScope) -> int:
merged = 0

if __debug__:
Expand All @@ -436,7 +436,7 @@ def _merge(self, node: stree.ScheduleTreeRoot | stree.ScheduleTreeScope) -> int:

return merged

def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None:
def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None:
"""Merge as many maps as possible.

The algorithm works as follows:
Expand Down Expand Up @@ -479,5 +479,5 @@ def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None:
passes_apply += 1

ndsl_log.debug(
f"🚀 Cartesian Axis Merge ({self.axis.name}): {overall_merged} map merged in {passes_apply} passes"
f"🚀 {self}: {overall_merged} maps merged in {passes_apply} passes"
)
Loading
Loading