diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py b/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py index 4663031edb..83fbbd08c8 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py @@ -314,6 +314,7 @@ def can_be_applied( return False # This avoids that we have to modify the subsets in a fancy way. + # TODO(phimuell): Lift this limitation. if len(a1_desc.shape) != len(a2_desc.shape): return False diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py index 78f3ea73e1..0561bde568 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py @@ -11,7 +11,7 @@ from typing import Any, Container, Optional, Sequence, TypeVar, Union import dace -from dace import data as dace_data, subsets as dace_sbs, symbolic as dace_sym +from dace import data as dace_data, libraries as dace_lib, subsets as dace_sbs, symbolic as dace_sym from dace.sdfg import graph as dace_graph, nodes as dace_nodes from dace.transformation import pass_pipeline as dace_ppl from dace.transformation.passes import analysis as dace_analysis @@ -405,9 +405,9 @@ def reroute_edge( """ current_memlet: dace.Memlet = current_edge.data if is_producer_edge: - # NOTE: See the note in `_reconfigure_dataflow()` why it is not save to - # use the `get_{dst, src}_subset()` function, although it would be more - # appropriate. + # NOTE: See the note in `reconfigure_dataflow_after_rerouting()` why it is not + # safe to use the `get_{dst, src}_subset()` function, although it would be + # more appropriate. 
assert current_edge.dst is old_node current_subset: dace_sbs.Range = current_memlet.dst_subset new_src = current_edge.src @@ -503,6 +503,10 @@ def reconfigure_dataflow_after_rerouting( old_node: The old that was involved in the old, rerouted, edge. new_node: The new node that should be used instead of `old_node`. """ + + # NOTE: The base assumption of this function is that the subset on the side of + # `new_node` is already correct and we have to adjust the subset on the side + # of `other_node`. other_node = new_edge.src if is_producer_edge else new_edge.dst if isinstance(other_node, dace_nodes.AccessNode): @@ -565,6 +569,21 @@ def reconfigure_dataflow_after_rerouting( # the full array, but essentially slice a bit. pass + elif isinstance(other_node, dace_lib.standard.Reduce): + # For now we only handle the case that the reduction node is writing into + # `new_node`, before the data was written into `old_node`. In that case + # there is nothing to do, we just do some checks. + # TODO(phimuell): Think about how to handle the other case or how to extend + # to other library nodes. + + if not is_producer_edge: + raise ValueError("Reduction nodes are only supported as output.") + assert isinstance(new_node, dace_nodes.AccessNode) + + # The subset at the reduction node needs to be `None`, which means undefined. + other_subset = new_edge.data.src_subset if is_producer_edge else new_edge.data.dst_subset + assert other_subset is None + else: # As we encounter them we should handle them case by case.
raise NotImplementedError( diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_copy_chain_remover.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_copy_chain_remover.py index ac7321b00d..e6593e3e27 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_copy_chain_remover.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_copy_chain_remover.py @@ -13,6 +13,8 @@ import numpy as np dace = pytest.importorskip("dace") + +from dace import libraries as dace_libnode from dace.sdfg import nodes as dace_nodes from gt4py.next.program_processors.runners.dace import ( @@ -355,6 +357,105 @@ def _make_a1_has_output_sdfg() -> dace.SDFG: return sdfg +def _make_copy_chain_with_reduction_node( + output_an_array: bool, +) -> tuple[ + dace.SDFG, + dace.SDFGState, + dace_libnode.standard.Reduce, + dace_nodes.AccessNode, + dace_libnode.standard.Reduce, + dace_nodes.AccessNode, + dace_nodes.AccessNode, +]: + sdfg = dace.SDFG(util.unique_name("copy_chain_remover_with_reduction_sdfg")) + state = sdfg.add_state(is_start_block=True) + + if output_an_array: + input_shape = (10, 3, 20) + output_shape = (10, 2, 20) + reduce_axes = [1] + # Actually we should use `(10, 20)` as shape, but then the transformation + # does not apply, this is a limitation in the transformation. + acc_shape = (10, 1, 20) + else: + input_shape = (3,) + acc_shape = () # Is a scalar.
+ output_shape = (2,) + reduce_axes = None + + for i in range(2): + sdfg.add_array( + f"data{i}", + shape=input_shape, + dtype=dace.float64, + transient=False, + ) + if output_an_array: + sdfg.add_array( + f"acc{i}", + shape=acc_shape, + dtype=dace.float64, + transient=True, + ) + else: + sdfg.add_scalar( + f"acc{i}", + dtype=dace.float64, + transient=True, + ) + sdfg.add_array( + "output", + shape=output_shape, + dtype=dace.float64, + transient=False, + ) + + accumulators: list[dace_nodes.AccessNode] = [] + reducers: list[dace_libnode.standard.Reduce] = [] + for i in range(2): + data_ac = state.add_access(f"data{i}") + acc_ac = state.add_access(f"acc{i}") + reduce_node = state.add_reduce( + wcr="lambda x, y: x + y", + axes=reduce_axes, + identity=0.0, + ) + state.add_nedge( + data_ac, + reduce_node, + dace.Memlet(f"{data_ac.data}[" + ", ".join(f"0:{s}" for s in input_shape) + "]"), + ) + if output_an_array: + state.add_nedge( + reduce_node, + acc_ac, + dace.Memlet(f"{acc_ac.data}[0:{acc_shape[0]}, 0, 0:{acc_shape[-1]}]"), + ) + else: + state.add_nedge(reduce_node, acc_ac, dace.Memlet(f"{acc_ac.data}[0]")) + accumulators.append(acc_ac) + reducers.append(reduce_node) + + red0, red1 = reducers + output_ac = state.add_access("output") + + for i, acc in enumerate(accumulators): + if output_an_array: + state.add_nedge( + acc, + output_ac, + dace.Memlet( + f"{acc.data}[0:{output_shape[0]}, 0, 0:{output_shape[-1]}] -> [0:{output_shape[0]}, {i}, 0:{output_shape[-1]}]" + ), + ) + else: + state.add_nedge(acc, output_ac, dace.Memlet(f"{acc.data}[0] -> [{i}]")) + sdfg.validate() + + return sdfg, state, red0, accumulators[0], red1, accumulators[1], output_ac + + def test_simple_linear_chain(): sdfg = _make_simple_linear_chain_sdfg() @@ -567,3 +668,102 @@ def inner_ref(i0, o0): # Now run the transformed SDFG to see if the same output is generated. 
util.compile_and_run_sdfg(sdfg, **res) assert all(np.allclose(ref[name], res[name]) for name in ref.keys()) + + +@pytest.mark.parametrize("output_an_array", [False, True]) +def test_copy_chain_remover_with_reduction(output_an_array: bool): + sdfg, state, red0, acc0, red1, acc1, output = _make_copy_chain_with_reduction_node( + output_an_array + ) + + def apply_to(a1, a2): + candidate = { + gtx_transformations.CopyChainRemover.node_a1: a1, + gtx_transformations.CopyChainRemover.node_a2: a2, + } + copy_chain_remover = gtx_transformations.CopyChainRemover( + single_use_data={sdfg: {acc0.data, acc1.data}}, + ) + copy_chain_remover.setup_match( + sdfg=sdfg, + cfg_id=state.parent_graph.cfg_id, + state_id=state.block_id, + subgraph=candidate, + expr_index=0, + override=True, + ) + assert copy_chain_remover.can_be_applied(state, 0, sdfg, permissive=False) + copy_chain_remover.apply(state, sdfg) + + assert sdfg.number_of_nodes() == 1 + assert state.number_of_nodes() == 7 + + assert all(e.dst is acc0 for e in state.out_edges(red0)) + assert state.in_degree(acc0) == 1 + assert state.out_degree(acc0) == 1 + + assert all(e.dst is acc1 for e in state.out_edges(red1)) + assert state.in_degree(acc1) == 1 + assert state.out_degree(acc1) == 1 + + assert all(e.src in [acc0, acc1] for e in state.in_edges(output)) + assert state.out_degree(output) == 0 + + # Now remove the `acc0` intermediate. + apply_to(a1=acc0, a2=output) + + # The accumulator `acc0` has been removed. 
+ sdfg.validate() + assert state.number_of_nodes() == 6 + + access_nodes: list[dace_nodes.AccessNode] = util.count_nodes(sdfg, dace_nodes.AccessNode, True) + assert len(access_nodes) == 4 + assert acc0 not in access_nodes + assert acc1 in access_nodes + assert output in access_nodes + + assert state.out_degree(red0) == 1 + red0_oedge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet] = next( + iter(state.out_edges(red0)) + ) + assert red0_oedge.dst is output + red0_oedge_mlet: dace.Memlet = red0_oedge.data + assert red0_oedge_mlet.src_subset is None + + if output_an_array: + assert len(red0_oedge_mlet.dst_subset) == 3 + assert red0_oedge_mlet.dst_subset == dace.subsets.Range.from_string("0:10, 0, 0:20") + + else: + assert len(red0_oedge_mlet.dst_subset) == 1 + assert red0_oedge_mlet.dst_subset == dace.subsets.Range.from_string("0") + + assert state.out_degree(red1) == 1 + assert all(e.dst is acc1 for e in state.out_edges(red1)) + + # Now the accumulator `acc1` will be removed. + apply_to(a1=acc1, a2=output) + + sdfg.validate() + assert state.number_of_nodes() == 5 + + access_nodes = util.count_nodes(sdfg, dace_nodes.AccessNode, True) + assert len(access_nodes) == 3 + assert acc0 not in access_nodes + assert acc1 not in access_nodes + assert output in access_nodes + + assert state.out_degree(red1) == 1 + red1_oedge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet] = next( + iter(state.out_edges(red1)) + ) + assert red1_oedge.dst is output + red1_oedge_mlet: dace.Memlet = red1_oedge.data + assert red1_oedge_mlet.src_subset is None + + if output_an_array: + assert len(red1_oedge_mlet.dst_subset) == 3 + assert red1_oedge_mlet.dst_subset == dace.subsets.Range.from_string("0:10, 1, 0:20") + else: + assert len(red1_oedge_mlet.dst_subset) == 1 + assert red1_oedge_mlet.dst_subset == dace.subsets.Range.from_string("1") diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_rerouting.py
b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_rerouting.py new file mode 100644 index 0000000000..1cfe4b3433 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_rerouting.py @@ -0,0 +1,148 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import pytest +import copy +import numpy as np + +dace = pytest.importorskip("dace") + +from dace import libraries as dace_libnode +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace import ( + transformations as gtx_transformations, +) + +from . import util + +import dace + + +def _make_tasklet_in_map() -> tuple[ + dace.SDFG, + dace.SDFGState, + dace_nodes.AccessNode, + dace_nodes.AccessNode, + dace_nodes.AccessNode, + dace_nodes.AccessNode, + dace_nodes.MapEntry, +]: + sdfg = dace.SDFG(util.unique_name("tasklet_in_map")) + state = sdfg.add_state(is_start_block=True) + + sdfg.add_array( + "a", + shape=(20,), + dtype=dace.float64, + transient=False, + ) + sdfg.add_array( + "b", + shape=(10,), + dtype=dace.float64, + transient=True, + ) + sdfg.add_array( + "c", + shape=(10,), + dtype=dace.float64, + transient=True, + ) + sdfg.add_array( + "d", + shape=(20,), + dtype=dace.float64, + transient=False, + ) + + a, b, c, d = (state.add_access(name) for name in "abcd") + + _, me, _ = state.add_mapped_tasklet( + "computation", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("b[__i]")}, + code="__out = __in + 3.0", + outputs={"__out": dace.Memlet("c[__i]")}, + input_nodes={b}, + output_nodes={c}, + external_edges=True, + ) + + state.add_nedge(a, b, dace.Memlet("a[3:13] -> [0:10]")) + state.add_nedge(c, d, dace.Memlet("c[0:10] -> [2:12]")) + sdfg.validate() + + return 
sdfg, state, a, b, c, d, me + + +def test_reconfigure_tasklet_in_map(): + sdfg, state, a, b, c, d, me = _make_tasklet_in_map() + + assert state.out_degree(b) == 1 + old_edge = next(iter(state.out_edges(b))) + + new_edge = gtx_transformations.utils.reroute_edge( + is_producer_edge=False, + current_edge=old_edge, + ss_offset=[3], + state=state, + sdfg=sdfg, + old_node=b, + new_node=a, + ) + + # Currently the SDFG is invalid, because the old edge has not been removed. + # We thus check if it is invalid in the correct way. + assert state.out_degree(a) == 2 + assert state.in_degree(me) == 2 + assert {b, me} == {e.dst for e in state.out_edges(a)} + assert len({e.dst_conn for e in state.in_edges(me)}) == 1 + assert new_edge.data.dst_subset == dace.subsets.Range.from_string("3:13") + + # The edge on the inside has not been updated yet. + me_oedge = next(iter(state.out_edges(me))) + assert me_oedge.data.src_subset == dace.subsets.Range.from_string("__i") + assert me_oedge.data.data == "b" + + # Now let's propagate the change into the Map. + gtx_transformations.utils.reconfigure_dataflow_after_rerouting( + is_producer_edge=False, + new_edge=new_edge, + ss_offset=[3], + state=state, + sdfg=sdfg, + old_node=b, + new_node=a, + ) + + # Now the edge inside the Map has changed. + assert me_oedge.data.src_subset == dace.subsets.Range.from_string("__i") + assert me_oedge.data.data == "b" + + +def test_reconfigure_reduction_in_map(): + pass + + +def test_reconfigure_tasklet(): + pass + + +def test_reconfigure_reduction(): + pass + + +def test_reconfigure_access_node(): + pass + + +def test_reconfigure_access_node_in_map(): + pass