diff --git a/external/dace b/external/dace index 4a9f4602..aa1e4f4a 160000 --- a/external/dace +++ b/external/dace @@ -1 +1 @@ -Subproject commit 4a9f46027147a52e2b0ac9eedeb101c3ab27d0bf +Subproject commit aa1e4f4a872e3d96de2e6435e926c048a89241c2 diff --git a/external/gt4py b/external/gt4py index 014e7190..960e24f6 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 014e719096e7889ab07c5cc86a13c65dab4cbb9d +Subproject commit 960e24f6f71dc68ab6880e0af7ab9425e8be62fc diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index 2be2b940..3e85050b 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -2,6 +2,7 @@ import enum import os +from pathlib import Path from typing import TYPE_CHECKING, Any, Self import dace.config @@ -326,11 +327,9 @@ def __init__( ) # Attempt to kill the dace.conf to avoid confusion - if dace.config.Config._cfg_filename: - try: - os.remove(dace.config.Config._cfg_filename) - except OSError: - pass + dace_conf_to_kill = dace.config.Config.cfg_filename() + if dace_conf_to_kill is not None: + Path(dace_conf_to_kill).unlink(missing_ok=True) self._backend = backend self.tile_resolution = [tile_nx, tile_nx, tile_nz] diff --git a/ndsl/dsl/dace/labeler.py b/ndsl/dsl/dace/labeler.py index cc673999..a70e19bf 100644 --- a/ndsl/dsl/dace/labeler.py +++ b/ndsl/dsl/dace/labeler.py @@ -45,7 +45,7 @@ def set_label( if isinstance(sdfg, dace.CompiledSDFG): return - for state in sdfg.states(): + for state in sdfg.nodes(): if sdfg.in_edges(state) == []: # With the topmost SDFG we have to skip over the # "init" state diff --git a/ndsl/dsl/dace/replacements.py b/ndsl/dsl/dace/replacements.py index ca7bed0c..79ed627f 100644 --- a/ndsl/dsl/dace/replacements.py +++ b/ndsl/dsl/dace/replacements.py @@ -5,11 +5,8 @@ from dace import SDFG, SDFGState, dtypes from dace.frontend.common import op_repository as oprepo from dace.frontend.python.newast import ProgramVisitor -from dace.frontend.python.replacements import ( - UfuncInput, - UfuncOutput, - _datatype_converter, -) +from dace.frontend.python.replacements import UfuncInput, UfuncOutput +from dace.frontend.python.replacements.array_manipulation import _datatype_converter from ndsl.dsl.typing import Float, Int diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index c5571be4..e0e7393e 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -5,10 +5,9 @@ from typing import Any import dace -import dace.sdfg.analysis.schedule_tree.treenodes as stree from dace.properties import CodeBlock +from dace.sdfg.analysis.schedule_tree import treenodes as tn -from ndsl import ndsl_log from ndsl.dsl.dace.stree.optimizations.memlet_helpers import ( AxisIterator, no_data_dependencies_on_cartesian_axis, @@ -18,25 +17,26 @@ list_index, swap_node_position_in_tree, ) +from ndsl.logging import ndsl_log # Buggy passes that should work PUSH_IFSCOPE_DOWNWARD = False # Crashing the overall stree - bad algorithmics -def _is_axis_map(node: stree.MapScope, axis: AxisIterator) -> bool: +def _is_axis_map(node: tn.MapScope, axis: AxisIterator) -> bool: """Returns true if node is a map over the given axis.""" map_parameter = node.node.params return len(map_parameter) == 1 and map_parameter[0].startswith(axis.as_str()) -def _is_axis_for(node: stree.ForScope, axis: AxisIterator) -> bool: - return node.header.itervar.startswith(axis.as_str()) +def _is_axis_for(node: tn.ForScope, axis: AxisIterator) -> bool: + return node.loop.loop_variable.startswith(axis.as_str()) def _both_same_single_axis_maps( - first: stree.MapScope, - second: stree.MapScope, + first: tn.MapScope, + second: tn.MapScope, axis: AxisIterator, ) -> bool: return ( @@ -47,8 +47,8 @@ def _both_same_single_axis_maps( def _can_merge_axis_maps( - first: stree.MapScope, - second: stree.MapScope, + first: tn.MapScope, + second: tn.MapScope, axis: AxisIterator, ) -> bool: return _both_same_single_axis_maps( @@ -60,7 +60,7 @@ def _can_merge_axis_maps( ) -class InsertOvercomputationGuard(stree.ScheduleNodeTransformer): +class InsertOvercomputationGuard(tn.ScheduleNodeTransformer): def __init__( self, axis_as_string: str, @@ -85,13 +85,13 @@ def _execution_condition(self) -> CodeBlock: f"and ({self._axis_as_string} - {start}) % {step} == 0" ) - def visit_MapScope(self, node: stree.MapScope) -> stree.MapScope: + def visit_MapScope(self, node: tn.MapScope) -> tn.MapScope: all_children_are_maps = all( - [isinstance(child, stree.MapScope) for child in node.children] + [isinstance(child, tn.MapScope) for child in node.children] ) if not all_children_are_maps: if self._merged_range != self._original_range: - if_scope = stree.IfScope( + if_scope = tn.IfScope( condition=self._execution_condition(), children=node.children ) # Re-parent to IF @@ -105,15 +105,13 @@ def visit_MapScope(self, node: stree.MapScope) -> stree.MapScope: def _get_next_node( - nodes: list[stree.ScheduleTreeNode], - node: stree.ScheduleTreeNode, -) -> stree.ScheduleTreeNode: + nodes: list[tn.ScheduleTreeNode], + node: tn.ScheduleTreeNode, +) -> tn.ScheduleTreeNode: return nodes[list_index(nodes, node) + 1] -def _last_node( - nodes: list[stree.ScheduleTreeNode], node: stree.ScheduleTreeNode -) -> bool: +def _last_node(nodes: list[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode) -> bool: return list_index(nodes, node) >= len(nodes) - 1 @@ -124,13 +122,13 @@ def _sanitize_axis(axis: AxisIterator, name_to_normalize: str) -> str: return re.sub(pattern, axis_clean, name_to_normalize) -class NormalizeAxisSymbol(stree.ScheduleNodeVisitor): +class NormalizeAxisSymbol(tn.ScheduleNodeVisitor): def __init__(self, axis: AxisIterator) -> None: self.axis = axis def visit_MapScope( self, - map_scope: stree.MapScope, + map_scope: tn.MapScope, axis_replacements: dict[str, str] | None = None, **kwargs: Any, ) -> None: @@ -147,7 +145,7 @@ def visit_MapScope( def visit_TaskletNode( self, - node: stree.TaskletNode, + node: tn.TaskletNode, axis_replacements: dict[str, str] | None = None, **kwargs: Any, ) -> None: @@ -159,7 +157,7 @@ def visit_TaskletNode( memlets.replace(axis_replacements) -class CartesianAxisMerge(stree.ScheduleNodeTransformer): +class CartesianAxisMerge(tn.ScheduleNodeTransformer): """Merge a cartesian axis if they are contiguous in code-flow. Can do: @@ -188,8 +186,8 @@ def __str__(self) -> str: def _merge_node( self, - node: stree.ScheduleTreeNode, - nodes: list[stree.ScheduleTreeNode], + node: tn.ScheduleTreeNode, + nodes: list[tn.ScheduleTreeNode], ) -> int: """Direct code to the correct resolver for the node (e.g. visitor) @@ -198,26 +196,28 @@ def _merge_node( behavior (e.g. IfScope before ControlFlowScope) """ - if isinstance(node, stree.MapScope): + if isinstance(node, tn.MapScope): return self._map_overcompute_merge(node, nodes) - elif PUSH_IFSCOPE_DOWNWARD and isinstance(node, stree.IfScope): + + if PUSH_IFSCOPE_DOWNWARD and isinstance(node, tn.IfScope): return self._push_ifelse_down(node, nodes) - elif isinstance(node, stree.ForScope): + + if isinstance(node, tn.ForScope): return self._for_merge(node, nodes) - elif isinstance(node, stree.TaskletNode): + + if isinstance(node, tn.TaskletNode): return self._push_tasklet_down(node, nodes) - elif isinstance(node, stree.ControlFlowScope): + + if isinstance(node, tn.ControlFlowScope): return self._default_control_flow(node) - else: - ndsl_log.debug( - f" (╯°□°)╯︵ ┻━┻: can't merge {type(node)}. Recursion ends." - ) + + ndsl_log.debug(f" (╯°□°)╯︵ ┻━┻: can't merge {type(node)}. Recursion ends.") return 0 def _for_merge( self, - the_for_scope: stree.ForScope, - nodes: list[stree.ScheduleTreeNode], + the_for_scope: tn.ForScope, + nodes: list[tn.ScheduleTreeNode], ) -> int: merged = 0 @@ -229,7 +229,7 @@ def _for_merge( # Non-cartesian for - can be pushed down if everything merged below if ( len(the_for_scope.children) == 1 - and isinstance(the_for_scope.children[0], stree.MapScope) + and isinstance(the_for_scope.children[0], tn.MapScope) and _is_axis_map(the_for_scope.children[0], self.axis) ): swap_node_position_in_tree(the_for_scope, the_for_scope.children[0]) @@ -239,7 +239,7 @@ def _for_merge( def _default_control_flow( self, - the_control_flow: stree.ControlFlowScope, + the_control_flow: tn.ControlFlowScope, ) -> int: if len(the_control_flow.children) != 0: return self._merge(the_control_flow) @@ -248,8 +248,8 @@ def _default_control_flow( def _push_tasklet_down( self, - the_tasklet: stree.TaskletNode, - nodes: list[stree.ScheduleTreeNode], + the_tasklet: tn.TaskletNode, + nodes: list[tn.ScheduleTreeNode], ) -> int: """Push tasklet into a consecutive map.""" in_memlets = the_tasklet.input_memlets() @@ -269,7 +269,7 @@ def _push_tasklet_down( # Attempt to push the tasklet in the next map next_node = nodes[next_index + 1] - if isinstance(next_node, stree.MapScope): + if isinstance(next_node, tn.MapScope): next_node.children.insert(0, the_tasklet) the_tasklet.parent = next_node nodes.remove(the_tasklet) @@ -279,8 +279,8 @@ def _push_tasklet_down( def _push_ifelse_down( self, - the_if: stree.IfScope, - nodes: list[stree.ScheduleTreeNode], + the_if: tn.IfScope, + nodes: list[tn.ScheduleTreeNode], ) -> int: merged = 0 @@ -291,8 +291,8 @@ def _push_ifelse_down( for else_index in range(if_index + 1, len(nodes)): else_node = nodes[else_index] if else_index < len(nodes) and ( - isinstance(else_node, stree.ElseScope) - or isinstance(else_node, stree.ElifScope) + isinstance(else_node, tn.ElseScope) + or isinstance(else_node, tn.ElifScope) ): merged += self._merge_node(else_node, else_node.children) else: @@ -302,17 +302,17 @@ def _push_ifelse_down( # Gather all first maps - if they do not exists, get out all_maps = [] - if isinstance(the_if.children[0], stree.MapScope): + if isinstance(the_if.children[0], tn.MapScope): all_maps.append(the_if.children[0]) else: return merged for else_index in range(if_index + 1, len(nodes)): else_node = nodes[else_index] if else_index < len(nodes) and ( - isinstance(else_node, stree.ElseScope) - or isinstance(else_node, stree.ElifScope) + isinstance(else_node, tn.ElseScope) + or isinstance(else_node, tn.ElifScope) ): - if isinstance(else_node.children[0], stree.MapScope): + if isinstance(else_node.children[0], tn.MapScope): all_maps.append(else_node.children[0]) else: return merged @@ -337,8 +337,8 @@ def _push_ifelse_down( # Swap ELIF/ELSE & maps for else_index in range(if_index + 1, len(nodes)): if else_index < len(nodes) and ( - isinstance(nodes[else_index], stree.ElseScope) - or isinstance(nodes[else_index], stree.ElifScope) + isinstance(nodes[else_index], tn.ElseScope) + or isinstance(nodes[else_index], tn.ElifScope) ): swap_node_position_in_tree( nodes[else_index], nodes[else_index].children[0] @@ -347,15 +347,15 @@ def _push_ifelse_down( break # Merge the Maps - assert isinstance(nodes[if_index], stree.MapScope) + assert isinstance(nodes[if_index], tn.MapScope) merged += self._map_overcompute_merge(nodes[if_index], nodes) return merged def _map_overcompute_merge( self, - the_map: stree.MapScope, - nodes: list[stree.ScheduleTreeNode], + the_map: tn.MapScope, + nodes: list[tn.ScheduleTreeNode], ) -> int: # End of nodes OR # Not the right axis @@ -369,7 +369,7 @@ def _map_overcompute_merge( next_node = _get_next_node(nodes, the_map) # Next node is not a MapScope - no merge - if not isinstance(next_node, stree.MapScope): + if not isinstance(next_node, tn.MapScope): return 0 # Attempt to merge consecutive maps @@ -401,7 +401,7 @@ def _map_overcompute_merge( merged_range=merged_range, original_range=second_range, ).visit(next_node) - merged_children: list[stree.MapScope] = [ + merged_children: list[tn.MapScope] = [ *first_map.children, *second_map.children, ] @@ -419,7 +419,7 @@ def _map_overcompute_merge( return 1 - def _merge(self, node: stree.ScheduleTreeRoot | stree.ScheduleTreeScope) -> int: + def _merge(self, node: tn.ScheduleTreeRoot | tn.ScheduleTreeScope) -> int: merged = 0 if __debug__: @@ -436,7 +436,7 @@ def _merge(self, node: stree.ScheduleTreeRoot | stree.ScheduleTreeScope) -> int: return merged - def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: """Merge as many maps as possible. The algorithm works as follows: @@ -479,5 +479,5 @@ def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: passes_apply += 1 ndsl_log.debug( - f"🚀 Cartesian Axis Merge ({self.axis.name}): {overall_merged} map merged in {passes_apply} passes" + f"🚀 {self}: {overall_merged} maps merged in {passes_apply} passes" ) diff --git a/ndsl/dsl/dace/stree/optimizations/clean_tree.py b/ndsl/dsl/dace/stree/optimizations/clean_tree.py index 41056ec4..0da456de 100644 --- a/ndsl/dsl/dace/stree/optimizations/clean_tree.py +++ b/ndsl/dsl/dace/stree/optimizations/clean_tree.py @@ -1,69 +1,66 @@ from __future__ import annotations -import dace.sdfg.analysis.schedule_tree.treenodes as stree +from dace.sdfg.analysis.schedule_tree import treenodes as tn -from ndsl import ndsl_log +from ndsl.logging import ndsl_log -class CleanUpScheduleTree(stree.ScheduleNodeTransformer): - """Clean up unused nodes, or nodes barrying further optimizations.""" +class CleanUpScheduleTree(tn.ScheduleNodeTransformer): + """Remove `StateBoundary` nodes from children of ScheduleTreeScopes.""" def __init__(self) -> None: - self.cleaned_state_boundaries = 0 + self._removed_state_boundaries = 0 def __str__(self) -> str: return "CleanUpScheduleTree" - def _remove_state_boundaries_from_my_childs( - self, node: stree.ScheduleTreeScope + def _remove_state_boundaries_from_children( + self, node: tn.ScheduleTreeScope ) -> None: to_remove = [ - child - for child in node.children - if isinstance(child, stree.StateBoundaryNode) + child for child in node.children if isinstance(child, tn.StateBoundaryNode) ] - for to_remove_child in to_remove: - self.cleaned_state_boundaries += 1 - node.children.remove(to_remove_child) + for boundary in to_remove: + self._removed_state_boundaries += 1 + node.children.remove(boundary) + + def visit_WhileScope(self, node: tn.WhileScope) -> tn.WhileScope: + self._remove_state_boundaries_from_children(node) - def visit_WhileScope(self, node: stree.WhileScope) -> stree.WhileScope: - self._remove_state_boundaries_from_my_childs(node) for child in node.children: self.visit(child) return node - def visit_ForScope(self, node: stree.ForScope) -> stree.ForScope: - self._remove_state_boundaries_from_my_childs(node) - - # We might have inherited a proper `loop_range` from the SDFG - # but the data (sdfg) it relies on is no longer valid. - node.header.loop_range = lambda: None + def visit_ForScope(self, node: tn.ForScope) -> tn.ForScope: + self._remove_state_boundaries_from_children(node) for child in node.children: self.visit(child) return node - def visit_MapScope(self, node: stree.MapScope) -> stree.MapScope: - self._remove_state_boundaries_from_my_childs(node) + def visit_MapScope(self, node: tn.MapScope) -> tn.MapScope: + self._remove_state_boundaries_from_children(node) + for child in node.children: self.visit(child) return node - def visit_IfScope(self, node: stree.IfScope) -> stree.IfScope: - self._remove_state_boundaries_from_my_childs(node) + def visit_IfScope(self, node: tn.IfScope) -> tn.IfScope: + self._remove_state_boundaries_from_children(node) for child in node.children: self.visit(child) return node - def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: - self._remove_state_boundaries_from_my_childs(node) + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: + self._removed_state_boundaries = 0 + + self._remove_state_boundaries_from_children(node) + for child in node.children: self.visit(child) - ndsl_log.debug( - f"Clean up StateBoundary : {self.cleaned_state_boundaries} nodes" - ) + ndsl_log.debug(f"{self}: removed {self._removed_state_boundaries} nodes") diff --git a/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py b/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py index 0626133e..75f68143 100644 --- a/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py +++ b/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py @@ -3,7 +3,7 @@ import dace.sdfg.analysis.schedule_tree.treenodes as stree from dace.memlet import Memlet -from ndsl import ndsl_log +from ndsl.logging import ndsl_log class AxisIterator(Enum): diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index 05224897..64650762 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -5,8 +5,8 @@ import dace.data import dace.sdfg.analysis.schedule_tree.treenodes as stree -from ndsl import ndsl_log from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator +from ndsl.logging import ndsl_log def _change_index_of_tuple( diff --git a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py b/ndsl/dsl/dace/stree/optimizations/tree_common_op.py index b965fc4a..a3dde7bb 100644 --- a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py +++ b/ndsl/dsl/dace/stree/optimizations/tree_common_op.py @@ -11,7 +11,7 @@ def swap_node_position_in_tree( top_children = top_node.parent.children top_level_parent = top_node.parent - # Swap childrens + # Swap children top_node.children = child_node.children child_node.children = [top_node] top_children.insert(list_index(top_children, top_node), child_node) diff --git a/tests/dsl/dace/stree/__init__.py b/tests/dsl/dace/stree/__init__.py new file mode 100644 index 00000000..2fa38d13 --- /dev/null +++ b/tests/dsl/dace/stree/__init__.py @@ -0,0 +1,7 @@ +from .sdfg_stree_tools import StreeOptimization, get_SDFG_and_purge + + +__all__ = [ + "StreeOptimization", + "get_SDFG_and_purge", +] diff --git a/tests/stree_optimizer/__init__.py b/tests/dsl/dace/stree/optimizations/__init__.py similarity index 100% rename from tests/stree_optimizer/__init__.py rename to tests/dsl/dace/stree/optimizations/__init__.py diff --git a/tests/dsl/dace/stree/optimizations/test_clean_tree.py b/tests/dsl/dace/stree/optimizations/test_clean_tree.py new file mode 100644 index 00000000..d5eee074 --- /dev/null +++ b/tests/dsl/dace/stree/optimizations/test_clean_tree.py @@ -0,0 +1,79 @@ +from dace import nodes +from dace.properties import CodeBlock +from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace.sdfg.state import LoopRegion +from dace.subsets import Range + +from ndsl.dsl.dace.stree.optimizations import CleanUpScheduleTree + + +def test_if_scope() -> None: + stree = tn.ScheduleTreeRoot( + name="tester", + children=[ + tn.IfScope( + condition=CodeBlock("True"), + children=[tn.StateBoundaryNode()], + ), + ], + ) + + cleaner = CleanUpScheduleTree() + cleaner.visit(stree) + + assert [type(node) for node in stree.children] == [tn.IfScope] + assert len(stree.children[0].children) == 0 + + +def test_for_scope() -> None: + stree = tn.ScheduleTreeRoot( + name="tester", + children=[ + tn.ForScope( + loop=LoopRegion("test"), + children=[tn.StateBoundaryNode()], + ), + ], + ) + + cleaner = CleanUpScheduleTree() + cleaner.visit(stree) + + assert [type(node) for node in stree.children] == [tn.ForScope] + assert len(stree.children[0].children) == 0 + + +def test_while_scope() -> None: + stree = tn.ScheduleTreeRoot( + name="tester", + children=[ + tn.WhileScope( + loop=LoopRegion("test"), + children=[tn.StateBoundaryNode()], + ), + ], + ) + + cleaner = CleanUpScheduleTree() + cleaner.visit(stree) + + assert [type(node) for node in stree.children] == [tn.WhileScope] + assert len(stree.children[0].children) == 0 + + +def test_map_scope() -> None: + stree = tn.ScheduleTreeRoot( + name="tester", + children=[ + tn.MapScope( + node=nodes.MapEntry(map=nodes.Map("asdf", ["i"], Range([]))), + children=[tn.StateBoundaryNode()], + ), + ], + ) + + cleaner = CleanUpScheduleTree() + cleaner.visit(stree) + + assert [type(node) for node in stree.children] == [tn.MapScope] + assert len(stree.children[0].children) == 0 diff --git a/tests/stree_optimizer/test_merge.py b/tests/dsl/dace/stree/optimizations/test_merge.py similarity index 53% rename from tests/stree_optimizer/test_merge.py rename to tests/dsl/dace/stree/optimizations/test_merge.py index 6f98c7e3..42996625 100644 --- a/tests/stree_optimizer/test_merge.py +++ b/tests/dsl/dace/stree/optimizations/test_merge.py @@ -1,12 +1,14 @@ +from typing import TypeAlias + import dace +import pytest from ndsl import QuantityFactory, StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile_orchestrated from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.dsl.gt4py import FORWARD, PARALLEL, K, computation, interval from ndsl.dsl.typing import FloatField - -from .sdfg_stree_tools import StreeOptimization, get_SDFG_and_purge +from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge def stencil(in_field: FloatField, out_field: FloatField) -> None: @@ -52,7 +54,7 @@ def __init__( "trivial_merge", "missing_merge_of_forscope_and_map", "overcompute_merge", - "block_merge_when_depandencies_is_found", + "block_merge_when_dependencies_are_found", "push_non_cartesian_for", ] for method in orchestratable_methods: @@ -98,7 +100,7 @@ def missing_merge_of_forscope_and_map( self.stencil_with_forward_K(in_field, out_field) self.stencil(in_field, out_field) - def block_merge_when_depandencies_is_found( + def block_merge_when_dependencies_are_found( self, in_field: FloatField, out_field: FloatField, @@ -124,19 +126,29 @@ def push_non_cartesian_for( self.stencil(in_field, out_field) -def test_stree_merge_maps_IJK() -> None: - domain = (3, 3, 4) - stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( - domain[0], domain[1], domain[2], 0, backend="dace:cpu_kfirst" - ) +Factories: TypeAlias = tuple[StencilFactory, QuantityFactory] + + +class TestStreeMergeMapsIJK: + @pytest.fixture + def factories(self) -> Factories: + domain = (3, 3, 4) + return get_factories_single_tile_orchestrated( + domain[0], domain[1], domain[2], 0, backend="dace:cpu_kfirst" + ) + + @pytest.fixture + def code(self, factories: Factories) -> OrchestratedCode: + return OrchestratedCode(*factories) + + def test_trivial_merge(self, code: OrchestratedCode, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - code = OrchestratedCode(stencil_factory, quantity_factory) - in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") - out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + with StreeOptimization(): + code.trivial_merge(in_qty, out_qty) - with StreeOptimization(): - # Trivial merge - code.trivial_merge(in_qty, out_qty) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) all_maps = [ (me, state) @@ -147,24 +159,40 @@ def test_stree_merge_maps_IJK() -> None: assert len(all_maps) == 3 assert (out_qty.field[:] == 2).all() - # Merge IJ - but do not merge K map & for (missing feature) - code.missing_merge_of_forscope_and_map(in_qty, out_qty) + def test_missing_merge_of_forscope_and_map( + self, code: OrchestratedCode, factories: Factories + ) -> None: + stencil_factory, quantity_factory = factories + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + with StreeOptimization(): + code.missing_merge_of_forscope_and_map(in_qty, out_qty) + sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ - (me, state) - for me, state in sdfg.all_nodes_recursive() - if isinstance(me, dace.nodes.MapEntry) + map_entry + for map_entry, _ in sdfg.all_nodes_recursive() + if isinstance(map_entry, dace.nodes.MapEntry) ] assert len(all_maps) == 4 # 2 IJ + 2 Ks - all_loop_guard_state = [ - (me, state) - for me, state in sdfg.all_nodes_recursive() - if isinstance(me, dace.SDFGState) and me.name.startswith("loop_guard") + all_loops = [ + loop + for loop, _ in sdfg.all_nodes_recursive() + if isinstance(loop, dace.sdfg.state.LoopRegion) ] - assert len(all_loop_guard_state) == 1 # 1 For loop + assert len(all_loops) == 1 # 1 For loop + + def test_overcompute_merge( + self, code: OrchestratedCode, factories: Factories + ) -> None: + stencil_factory, quantity_factory = factories + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + with StreeOptimization(): + code.overcompute_merge(in_qty, out_qty) - # Overcompute merge in K - we merge and introduce an If guard - code.overcompute_merge(in_qty, out_qty) sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ (me, state) @@ -174,8 +202,17 @@ def test_stree_merge_maps_IJK() -> None: # ⚠️ WE EXPECT A FAILURE TO MERGE K (because of index) ⚠️ assert len(all_maps) == 4 # Should be all dmerged = 3 - # Forbid merging when data dependancy is detected - code.block_merge_when_depandencies_is_found(in_qty, out_qty) + def test_block_merge_when_dependencies_are_found( + self, code: OrchestratedCode, factories: Factories + ) -> None: + stencil_factory, quantity_factory = factories + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + with StreeOptimization(): + # Forbid merging when data dependencies are detected + code.block_merge_when_dependencies_are_found(in_qty, out_qty) + sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ me.params[0] @@ -185,9 +222,18 @@ def test_stree_merge_maps_IJK() -> None: # ⚠️ WE EXPECT A FAILURE TO MERGE K (because of index) ⚠️ assert len(all_maps) == 5 # Should be 4 = 2 IJ + 2 Ks (un-merged) - # Push non-cartesian ForScope inwward, which allow to potentially - # merge cartesian maps - code.push_non_cartesian_for(in_qty, out_qty) + def test_push_non_cartesian_for( + self, code: OrchestratedCode, factories: Factories + ) -> None: + stencil_factory, quantity_factory = factories + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + with StreeOptimization(): + # Push non-cartesian ForScope inwards, which allow to potentially + # merge cartesian maps + code.push_non_cartesian_for(in_qty, out_qty) + sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ (me, state) @@ -203,19 +249,26 @@ def test_stree_merge_maps_IJK() -> None: assert len(all_loop_guard_state) == 1 # 1 For loop -def test_stree_merge_maps_KJI() -> None: - domain = (3, 3, 4) - stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( - domain[0], domain[1], domain[2], 0, backend="dace:cpu_KJI" - ) +class TestStreeMergeMapsKJI: + @pytest.fixture + def factories(self) -> Factories: + domain = (3, 3, 4) + return get_factories_single_tile_orchestrated( + domain[0], domain[1], domain[2], 0, backend="dace:cpu_KJI" + ) + + @pytest.fixture + def code(self, factories: Factories) -> OrchestratedCode: + return OrchestratedCode(*factories) + + def test_trivial_merge(self, code: OrchestratedCode, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - code = OrchestratedCode(stencil_factory, quantity_factory) - in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") - out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + with StreeOptimization(): + code.trivial_merge(in_qty, out_qty) - with StreeOptimization(): - # Trivial merge - code.trivial_merge(in_qty, out_qty) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) all_maps = [ (me, state) @@ -226,24 +279,45 @@ def test_stree_merge_maps_KJI() -> None: assert len(all_maps) == 3 assert (out_qty.field[:] == 2).all() - # K iterative loop - blocks all merges + def test_missing_merge_of_forscope_and_map( + self, code: OrchestratedCode, factories: Factories + ) -> None: + stencil_factory, quantity_factory = factories + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + code.missing_merge_of_forscope_and_map(in_qty, out_qty) sdfg = get_SDFG_and_purge(stencil_factory).sdfg + + with StreeOptimization(): + # K iterative loop - blocks all merges + code.missing_merge_of_forscope_and_map(in_qty, out_qty) + + sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ - (me, state) - for me, state in sdfg.all_nodes_recursive() - if isinstance(me, dace.nodes.MapEntry) + map_entry + for map_entry, _ in sdfg.all_nodes_recursive() + if isinstance(map_entry, dace.nodes.MapEntry) ] assert len(all_maps) == 8 # 2 KJI (all maps) + 1 for scope - all_loop_guard_state = [ - (me, state) - for me, state in sdfg.all_nodes_recursive() - if isinstance(me, dace.SDFGState) and me.name.startswith("loop_guard") + all_loops = [ + loop + for loop, _ in sdfg.all_nodes_recursive() + if isinstance(loop, dace.sdfg.state.LoopRegion) ] - assert len(all_loop_guard_state) == 1 # 1 For loop + assert len(all_loops) == 1 # 1 For loop + + def test_overcompute_merge( + self, code: OrchestratedCode, factories: Factories + ) -> None: + stencil_factory, quantity_factory = factories + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + with StreeOptimization(): + # Overcompute merge in K - we merge and introduce an If guard + code.overcompute_merge(in_qty, out_qty) - # Overcompute merge in K - we merge and introduce an If guard - code.overcompute_merge(in_qty, out_qty) sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ (me, state) @@ -253,8 +327,17 @@ def test_stree_merge_maps_KJI() -> None: # ⚠️ WE EXPECT A FAILURE TO MERGE K (because of index) ⚠️ assert len(all_maps) == 6 - # Forbid merging when data dependancy is detected - code.block_merge_when_depandencies_is_found(in_qty, out_qty) + def test_block_merge_when_dependencies_are_found( + self, code: OrchestratedCode, factories: Factories + ) -> None: + stencil_factory, quantity_factory = factories + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + with StreeOptimization(): + # Forbid merging when data dependencies are detected + code.block_merge_when_dependencies_are_found(in_qty, out_qty) + sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ (me, state) @@ -264,9 +347,18 @@ def test_stree_merge_maps_KJI() -> None: # ⚠️ WE EXPECT A FAILURE TO MERGE K (because of index) ⚠️ assert len(all_maps) == 9 - # Push non-cartesian ForScope inwward, which allow to potentially - # merge cartesian maps - code.push_non_cartesian_for(in_qty, out_qty) + def test_push_non_cartesian_for( + self, code: OrchestratedCode, factories: Factories + ) -> None: + stencil_factory, quantity_factory = factories + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + with StreeOptimization(): + # Push non-cartesian ForScope inwards, which allow to potentially + # merge cartesian maps + code.push_non_cartesian_for(in_qty, out_qty) + sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ (me, state) diff --git a/tests/stree_optimizer/test_pipeline.py b/tests/dsl/dace/stree/optimizations/test_pipeline.py similarity index 100% rename from tests/stree_optimizer/test_pipeline.py rename to tests/dsl/dace/stree/optimizations/test_pipeline.py diff --git a/tests/stree_optimizer/test_transient_refine.py b/tests/dsl/dace/stree/optimizations/test_transient_refine.py similarity index 98% rename from tests/stree_optimizer/test_transient_refine.py rename to tests/dsl/dace/stree/optimizations/test_transient_refine.py index f2228c2c..ffaee9db 100644 --- a/tests/stree_optimizer/test_transient_refine.py +++ b/tests/dsl/dace/stree/optimizations/test_transient_refine.py @@ -3,8 +3,7 @@ from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.dsl.gt4py import IJK, PARALLEL, Field, J, K, computation, interval from ndsl.dsl.typing import Float, FloatField - -from .sdfg_stree_tools import StreeOptimization, get_SDFG_and_purge +from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge DATADIM_SIZE = 8 diff --git a/tests/stree_optimizer/sdfg_stree_tools.py b/tests/dsl/dace/stree/sdfg_stree_tools.py similarity index 100% rename from tests/stree_optimizer/sdfg_stree_tools.py rename to tests/dsl/dace/stree/sdfg_stree_tools.py diff --git a/tests/stencils/test_stencils.py b/tests/stencils/test_stencils.py index 5c4a46d7..17d521e7 100644 --- a/tests/stencils/test_stencils.py +++ b/tests/stencils/test_stencils.py @@ -92,7 +92,7 @@ def __call__( self._column_min_ddim_stencil(data_ddim, min_value_ddim, min_index_ddim) -def test_column_operations(boilerplate): +def test_column_operations(boilerplate: tuple[StencilFactory, QuantityFactory]): stencil_factory, quantity_factory = boilerplate quantity_factory.add_data_dimensions({"ddim": 2}) data = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "n/a")