diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 1d04c21fc3..a0f0bdc0ac 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -707,6 +707,15 @@ def _gt_auto_process_dataflow_inside_maps( # before or after `LoopBlocking`. In cases where the condition is `False` # most of the times calling it before is better, but if the condition is # `True` then this order is better. Solve that issue. + # NOTE: The transformation is currently only able to handle dataflow that is + # _directly_ enclosed by a Map. Thus the order in which they (multiple blocks + # in the same Map) are processed matters. Think of a chain of `if` blocks that + # can be perfectly nested. If the last one is handled first, then all others + # can not be processed anymore. This means it is important to set + # `ignore_upstream_blocks` to `False`, thus the transformation will not apply + # if the dataflow that should be relocated into an `if` block contains again + # `if` blocks that can be relocated. It would be more efficient to process + # them in the right order from the beginning. 
sdfg.apply_transformations_repeated( gtx_transformations.MoveDataflowIntoIfBody( ignore_upstream_blocks=False, diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index 55497fc11a..1db9047199 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -147,6 +147,7 @@ def can_be_applied( if_block=if_block, raw_relocatable_dataflow=raw_relocatable_dataflow, non_relocatable_dataflow=non_relocatable_dataflow, + enclosing_map=enclosing_map, ) # If no branch has something to inline then we are done. @@ -204,6 +205,7 @@ def apply( if_block=if_block, raw_relocatable_dataflow=raw_relocatable_dataflow, non_relocatable_dataflow=non_relocatable_dataflow, + enclosing_map=enclosing_map, ) # Finally relocate the dataflow @@ -551,6 +553,7 @@ def _has_if_block_relocatable_dataflow( if_block=upstream_if_block, raw_relocatable_dataflow=raw_relocatable_dataflow, non_relocatable_dataflow=non_relocatable_dataflow, + enclosing_map=enclosing_map, ) if all(len(rel_df) == 0 for rel_df in filtered_relocatable_dataflow.values()): return False @@ -564,6 +567,7 @@ def _filter_relocatable_dataflow( if_block: dace_nodes.NestedSDFG, raw_relocatable_dataflow: dict[str, set[dace_nodes.Node]], non_relocatable_dataflow: dict[str, set[dace_nodes.Node]], + enclosing_map: dace_nodes.MapEntry, ) -> dict[str, set[dace_nodes.Node]]: """Partition the dependencies. @@ -581,6 +585,8 @@ def _filter_relocatable_dataflow( that can be relocated, not yet filtered. non_relocatable_dataflow: The connectors and their associated dataflow that can not be relocated. + enclosing_map: The limiting node, i.e. the MapEntry of the Map where + `if_block` is located in. """ # Remove the parts of the dataflow that is unrelocatable. 
@@ -592,8 +598,9 @@ def _filter_relocatable_dataflow( for conn_name, rel_df in raw_relocatable_dataflow.items() } - # Now we determine the nodes that are in more than one sets. - # These sets must be removed, from the individual sets. + # Relocating nodes that are in more than one set is difficult. In the most + # common case of just two branches, this anyway means they have to be + # executed in any case. Thus we remove them now. known_nodes: set[dace_nodes.Node] = set() multiple_df_nodes: set[dace_nodes.Node] = set() for rel_df in relocatable_dataflow.values(): @@ -606,35 +613,73 @@ def _filter_relocatable_dataflow( for conn_name, rel_df in relocatable_dataflow.items() } - # However, not all dataflow can be moved inside the branch. For example if - # something is used outside the dataflow, that is moved inside the `if`, - # then we can not relocate it. # TODO(phimuell): If we operate outside of a Map we also have to make sure that # the data is single use data, is not an AccessNode that refers to global # memory nor is a source AccessNode. def filter_nodes( - branch_nodes: set[dace_nodes.Node], - sdfg: dace.SDFG, - state: dace.SDFGState, + nodes_proposed_for_reloc: set[dace_nodes.Node], ) -> set[dace_nodes.Node]: - # For this to work the `if_block` must be considered part, we remove it later. - branch_nodes.add(if_block) has_been_updated = True while has_been_updated: has_been_updated = False - for node in list(branch_nodes): - if node is if_block: + + for reloc_node in list(nodes_proposed_for_reloc): + # The node was already handled in a previous iteration. + if reloc_node not in nodes_proposed_for_reloc: + continue + + assert ( + state.in_degree(reloc_node) > 0 + ) # Because we are currently always inside a Map + + # If the node is needed by anything that is not also moved + # into the `if` body, then it has to remain outside. For that we + # have to pretend that `if_block` is also relocated. 
+ if any( + oedge.dst not in nodes_proposed_for_reloc + for oedge in state.out_edges(reloc_node) + if oedge.dst is not if_block + ): + nodes_proposed_for_reloc.remove(reloc_node) + has_been_updated = True continue - if any(oedge.dst not in branch_nodes for oedge in state.out_edges(node)): - branch_nodes.remove(node) + + # We do not look at all incoming nodes, but have to ignore some of them. + # We ignore `enclosing_map` because it acts as boundary, and the node + # on the other side of it is mapped into the `if` body anyway. We + # ignore the AccessNodes because they will either be relocated into + # the `if` body or be mapped (remain outside but made accessible + # inside), thus their relocation state is of no concern for + # `reloc_node`. + non_mappable_incoming_nodes: set[dace_nodes.Node] = { + iedge.src + for iedge in state.in_edges(reloc_node) + if not ( + (iedge.src is enclosing_map) + or isinstance(iedge.src, dace_nodes.AccessNode) + ) + } + if non_mappable_incoming_nodes.issubset(nodes_proposed_for_reloc): + # All nodes that can not be mapped into the `if` body are + # currently scheduled to be relocated, thus there is no + # problem. + pass + + else: + # Only some of the non mappable nodes are selected to be + # moved inside the `if` body. This means that `reloc_node` + # can also not be moved because of its input dependencies. + # Since we can not relocate `reloc_node` this also implies + # that none of its inputs can. Thus we remove them from + # `nodes_proposed_for_reloc`. 
+ nodes_proposed_for_reloc.difference_update(non_mappable_incoming_nodes) + nodes_proposed_for_reloc.remove(reloc_node) has_been_updated = True - assert if_block in branch_nodes - branch_nodes.remove(if_block) - return branch_nodes + + return nodes_proposed_for_reloc return { - conn_name: filter_nodes(rel_df, sdfg, state) - for conn_name, rel_df in relocatable_dataflow.items() + conn_name: filter_nodes(rel_df) for conn_name, rel_df in relocatable_dataflow.items() } def _partition_if_block( diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py index 45c1620108..03aba6599e 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py @@ -1047,3 +1047,155 @@ def test_if_mover_symbolic_tasklet(): assert if_block.sdfg.symbols["__i"] in {dace.int32, dace.int64} assert if_block.symbol_mapping.keys() == expected_symb.union(["__i"]) assert all(str(sym) == str(symval) for sym, symval in if_block.symbol_mapping.items()) + + +def test_if_mover_access_node_between(): + """ + Essentially tests the following situation: + ```python + a = foo(...) + b = bar(...) + c = baz(...) + bb = a if c else b + cc = baz2(d, ...) + aa = foo2(...) 
+ e = aa if cc else bb + ``` + """ + sdfg = dace.SDFG(util.unique_name("if_mover_chain_of_blocks")) + state = sdfg.add_state(is_start_block=True) + + # Inputs + input_names = ["a", "b", "c", "d", "e", "f"] + for name in input_names: + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + + # Temporaries + temporary_names = ["a1", "b1", "c1", "a2", "b2", "c2"] + for name in temporary_names: + sdfg.add_scalar( + name, dtype=dace.bool_ if name.startswith("c") else dace.float64, transient=True + ) + + a1, b1, c1, a2, b2, c2 = (state.add_access(name) for name in temporary_names) + me, mx = state.add_map("comp", ndrange={"__i": "0:10"}) + + # First branch of top `if_block` + tasklet_a1 = state.add_tasklet( + "tasklet_a1", inputs={"__in"}, outputs={"__out"}, code="__out = math.sin(__in)" + ) + state.add_edge(state.add_access("a"), None, me, "IN_a", dace.Memlet("a[0:10]")) + state.add_edge(me, "OUT_a", tasklet_a1, "__in", dace.Memlet("a[__i]")) + state.add_edge(tasklet_a1, "__out", a1, None, dace.Memlet("a1[0]")) + + # Second branch of the top `if_block` + tasklet_b1 = state.add_tasklet( + "tasklet_b1", inputs={"__in"}, outputs={"__out"}, code="__out = math.cos(__in)" + ) + state.add_edge(state.add_access("b"), None, me, "IN_b", dace.Memlet("b[0:10]")) + state.add_edge(me, "OUT_b", tasklet_b1, "__in", dace.Memlet("b[__i]")) + state.add_edge(tasklet_b1, "__out", b1, None, dace.Memlet("b1[0]")) + + # The condition of the top `if_block` + tasklet_c1 = state.add_tasklet( + "tasklet_c1", inputs={"__in"}, outputs={"__out"}, code="__out = __in < 0.5" + ) + state.add_edge(state.add_access("c"), None, me, "IN_c", dace.Memlet("c[0:10]")) + state.add_edge(me, "OUT_c", tasklet_c1, "__in", dace.Memlet("c[__i]")) + state.add_edge(tasklet_c1, "__out", c1, None, dace.Memlet("c1[0]")) + + # Create the top `if_block` + top_if_block = _make_if_block(state, sdfg) + state.add_edge(a1, None, top_if_block, "__arg1", dace.Memlet("a1[0]")) + state.add_edge(b1, None, 
top_if_block, "__arg2", dace.Memlet("b1[0]")) + state.add_edge(c1, None, top_if_block, "__cond", dace.Memlet("c1[0]")) + state.add_edge(top_if_block, "__output", a2, None, dace.Memlet("a2[0]")) + + # The first branch of the lower/second `if_block` + tasklet_b2 = state.add_tasklet( + "tasklet_b2", inputs={"__in"}, outputs={"__out"}, code="__out = math.atan(__in)" + ) + state.add_edge(state.add_access("e"), None, me, "IN_e", dace.Memlet("e[0:10]")) + state.add_edge(me, "OUT_e", tasklet_b2, "__in", dace.Memlet("e[__i]")) + state.add_edge(tasklet_b2, "__out", b2, None, dace.Memlet("b2[0]")) + + # Condition branch of the second `if_block`. + tasklet_c2 = state.add_tasklet( + "tasklet_c2", inputs={"__in"}, outputs={"__out"}, code="__out = __in < 0.5" + ) + state.add_edge(state.add_access("d"), None, me, "IN_d", dace.Memlet("d[0:10]")) + state.add_edge(me, "OUT_d", tasklet_c2, "__in", dace.Memlet("d[__i]")) + state.add_edge(tasklet_c2, "__out", c2, None, dace.Memlet("c2[0]")) + + # Create the second `if_block` + bot_if_block = _make_if_block(state, sdfg) + state.add_edge(a2, None, bot_if_block, "__arg1", dace.Memlet("a2[0]")) + state.add_edge(b2, None, bot_if_block, "__arg2", dace.Memlet("b2[0]")) + state.add_edge(c2, None, bot_if_block, "__cond", dace.Memlet("c2[0]")) + + # Generate the output + state.add_edge(bot_if_block, "__output", mx, "IN_f", dace.Memlet("f[__i]")) + state.add_edge(mx, "OUT_f", state.add_access("f"), None, dace.Memlet("f[0:10]")) + + # Now add the connectors to the Map* + for iname in input_names: + mq = mx if iname == "f" else me + mq.add_in_connector(f"IN_{iname}") + mq.add_out_connector(f"OUT_{iname}") + sdfg.validate() + + # We can not process the bottom block, because this would also inline the top + # block that in turn has dataflow that could be relocated. + _perform_test( + sdfg, + explected_applies=0, + if_block=bot_if_block, + ) + + # But we are able to process them that way, starting from the bottom. 
+ _perform_test( + sdfg, + explected_applies=2, + ) + + expected_top_level_data: set[str] = {"a", "b", "c", "d", "e", "f", "c2"} + assert set(dnode.data for dnode in state.data_nodes()) == expected_top_level_data + assert sdfg.arrays.keys() == expected_top_level_data + assert set(tlet for tlet in state.nodes() if isinstance(tlet, dace_nodes.Tasklet)) == { + tasklet_c2 + } + assert set( + if_block for if_block in state.nodes() if isinstance(if_block, dace_nodes.NestedSDFG) + ) == {bot_if_block} + + expected_bot_if_block_data: set[str] = { + "a", + "b", + "c", + "e", + "c1", + "a2", + "b2", + "__arg1", + "__arg2", + "__output", + "__cond", + } + assert set(bot_if_block.sdfg.arrays.keys()) == expected_bot_if_block_data + + expected_top_if_block_data: set[str] = { + "a", + "b", + "a1", + "b1", + "__arg1", + "__arg2", + "__output", + "__cond", + } + assert set(top_if_block.sdfg.arrays.keys()) == expected_top_if_block_data