Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,15 @@ def _gt_auto_process_dataflow_inside_maps(
# before or after `LoopBlocking`. In cases where the condition is `False`
# most of the times calling it before is better, but if the condition is
# `True` then this order is better. Solve that issue.
# NOTE: The transformation is currently only able to handle dataflow that is
# _directly_ enclosed by a Map. Thus the order in which they (multiple blocks
# in the same Map) are processed matter. Think of a chain of `if` blocks that
# can be perfectly nested. If the last one is handled first, then all the
# others cannot be processed anymore. This means it is important to set
# `ignore_upstream_blocks` to `False`, thus the transformation will not apply
# if the dataflow that should be relocated into an `if` block contains again
# `if` blocks that can be relocated. It would be more efficient to process
# them in the right order from the beginning.
sdfg.apply_transformations_repeated(
gtx_transformations.MoveDataflowIntoIfBody(
ignore_upstream_blocks=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def can_be_applied(
if_block=if_block,
raw_relocatable_dataflow=raw_relocatable_dataflow,
non_relocatable_dataflow=non_relocatable_dataflow,
enclosing_map=enclosing_map,
)

# If no branch has something to inline then we are done.
Expand Down Expand Up @@ -204,6 +205,7 @@ def apply(
if_block=if_block,
raw_relocatable_dataflow=raw_relocatable_dataflow,
non_relocatable_dataflow=non_relocatable_dataflow,
enclosing_map=enclosing_map,
)

# Finally relocate the dataflow
Expand Down Expand Up @@ -551,6 +553,7 @@ def _has_if_block_relocatable_dataflow(
if_block=upstream_if_block,
raw_relocatable_dataflow=raw_relocatable_dataflow,
non_relocatable_dataflow=non_relocatable_dataflow,
enclosing_map=enclosing_map,
)
if all(len(rel_df) == 0 for rel_df in filtered_relocatable_dataflow.values()):
return False
Expand All @@ -564,6 +567,7 @@ def _filter_relocatable_dataflow(
if_block: dace_nodes.NestedSDFG,
raw_relocatable_dataflow: dict[str, set[dace_nodes.Node]],
non_relocatable_dataflow: dict[str, set[dace_nodes.Node]],
enclosing_map: dace_nodes.MapEntry,
) -> dict[str, set[dace_nodes.Node]]:
"""Partition the dependencies.

Expand All @@ -581,6 +585,8 @@ def _filter_relocatable_dataflow(
that can be relocated, not yet filtered.
non_relocatable_dataflow: The connectors and their associated dataflow
that can not be relocated.
enclosing_map: The limiting node, i.e. the MapEntry of the Map where
`if_block` is located in.
"""

# Remove the parts of the dataflow that is unrelocatable.
Expand All @@ -592,8 +598,9 @@ def _filter_relocatable_dataflow(
for conn_name, rel_df in raw_relocatable_dataflow.items()
}

# Now we determine the nodes that are in more than one sets.
# These sets must be removed, from the individual sets.
# Relocating nodes that are in more than one set is difficult. In the most
# common case of just two branches, this anyway means they have to be
# executed in any case. Thus we remove them now.
known_nodes: set[dace_nodes.Node] = set()
multiple_df_nodes: set[dace_nodes.Node] = set()
for rel_df in relocatable_dataflow.values():
Expand All @@ -606,35 +613,73 @@ def _filter_relocatable_dataflow(
for conn_name, rel_df in relocatable_dataflow.items()
}

# However, not all dataflow can be moved inside the branch. For example if
# something is used outside the dataflow, that is moved inside the `if`,
# then we can not relocate it.
# TODO(phimuell): If we operate outside of a Map we also have to make sure that
# the data is single use data, is not an AccessNode that refers to global
# memory nor is a source AccessNode.
def filter_nodes(
branch_nodes: set[dace_nodes.Node],
sdfg: dace.SDFG,
state: dace.SDFGState,
nodes_proposed_for_reloc: set[dace_nodes.Node],
) -> set[dace_nodes.Node]:
# For this to work the `if_block` must be considered part, we remove it later.
branch_nodes.add(if_block)
has_been_updated = True
while has_been_updated:
has_been_updated = False
for node in list(branch_nodes):
if node is if_block:

for reloc_node in list(nodes_proposed_for_reloc):
# The node was already handled in a previous iteration.
if reloc_node not in nodes_proposed_for_reloc:
continue

assert (
state.in_degree(reloc_node) > 0
) # Because we are currently always inside a Map

# If the node is needed by anything that is not also moved
# into the `if` body, then it has to remain outside. For that we
# have to pretend that `if_block` is also relocated.
if any(
oedge.dst not in nodes_proposed_for_reloc
for oedge in state.out_edges(reloc_node)
if oedge.dst is not if_block
):
nodes_proposed_for_reloc.remove(reloc_node)
has_been_updated = True
continue
if any(oedge.dst not in branch_nodes for oedge in state.out_edges(node)):
branch_nodes.remove(node)

# We do not look at all incoming nodes, but have to ignore some of them.
# We ignore `enclosing_map` because it acts as a boundary, and the node
# on the other side of it is mapped into the `if` body anyway. We
# ignore the AccessNodes because they will either be relocated into
# the `if` body or be mapped (remain outside but made accessible
# inside), thus their relocation state is of no concern for
# `reloc_node`.
non_mappable_incoming_nodes: set[dace_nodes.Node] = {
iedge.src
for iedge in state.in_edges(reloc_node)
if not (
(iedge.src is enclosing_map)
or isinstance(iedge.src, dace_nodes.AccessNode)
)
}
if non_mappable_incoming_nodes.issubset(nodes_proposed_for_reloc):
# All nodes that cannot be mapped into the `if` body are
# currently scheduled to be relocated, thus there is no
# problem.
pass

else:
# Only some of the non mappable nodes are selected to be
# moved inside the `if` body. This means that `reloc_node`
# can also not be moved because of its input dependencies.
# Since we can not relocate `reloc_node` this also implies
# that none of its input can. Thus we remove them from
# `nodes_proposed_for_reloc`.
nodes_proposed_for_reloc.difference_update(non_mappable_incoming_nodes)
nodes_proposed_for_reloc.remove(reloc_node)
has_been_updated = True
assert if_block in branch_nodes
branch_nodes.remove(if_block)
return branch_nodes

return nodes_proposed_for_reloc

return {
conn_name: filter_nodes(rel_df, sdfg, state)
for conn_name, rel_df in relocatable_dataflow.items()
conn_name: filter_nodes(rel_df) for conn_name, rel_df in relocatable_dataflow.items()
}

def _partition_if_block(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1047,3 +1047,155 @@ def test_if_mover_symbolic_tasklet():
assert if_block.sdfg.symbols["__i"] in {dace.int32, dace.int64}
assert if_block.symbol_mapping.keys() == expected_symb.union(["__i"])
assert all(str(sym) == str(symval) for sym, symval in if_block.symbol_mapping.items())


def test_if_mover_access_node_between():
    """Test processing order when the output of one `if` block feeds another.

    The AccessNode `a2` sits between the two `if` blocks: it is the output of
    the top block and the first-branch input of the bottom block.

    Essentially tests the following situation:
    ```python
    a = foo(...)
    b = bar(...)
    c = baz(...)
    bb = a if c else b
    cc = baz2(d, ...)
    aa = foo2(...)
    e = aa if cc else bb
    ```
    """
    sdfg = dace.SDFG(util.unique_name("if_mover_chain_of_blocks"))
    state = sdfg.add_state(is_start_block=True)

    # Inputs: non-transient 1D arrays of 10 doubles; "f" is the only output.
    input_names = ["a", "b", "c", "d", "e", "f"]
    for name in input_names:
        sdfg.add_array(
            name,
            shape=(10,),
            dtype=dace.float64,
            transient=False,
        )

    # Temporaries: per-iteration scalars; the `c*` names hold the boolean
    # branch conditions, the others hold branch values.
    temporary_names = ["a1", "b1", "c1", "a2", "b2", "c2"]
    for name in temporary_names:
        sdfg.add_scalar(
            name, dtype=dace.bool_ if name.startswith("c") else dace.float64, transient=True
        )

    a1, b1, c1, a2, b2, c2 = (state.add_access(name) for name in temporary_names)
    me, mx = state.add_map("comp", ndrange={"__i": "0:10"})

    # First branch of top `if_block`: a1 = sin(a[__i])
    tasklet_a1 = state.add_tasklet(
        "tasklet_a1", inputs={"__in"}, outputs={"__out"}, code="__out = math.sin(__in)"
    )
    state.add_edge(state.add_access("a"), None, me, "IN_a", dace.Memlet("a[0:10]"))
    state.add_edge(me, "OUT_a", tasklet_a1, "__in", dace.Memlet("a[__i]"))
    state.add_edge(tasklet_a1, "__out", a1, None, dace.Memlet("a1[0]"))

    # Second branch of the top `if_block`: b1 = cos(b[__i])
    tasklet_b1 = state.add_tasklet(
        "tasklet_b1", inputs={"__in"}, outputs={"__out"}, code="__out = math.cos(__in)"
    )
    state.add_edge(state.add_access("b"), None, me, "IN_b", dace.Memlet("b[0:10]"))
    state.add_edge(me, "OUT_b", tasklet_b1, "__in", dace.Memlet("b[__i]"))
    state.add_edge(tasklet_b1, "__out", b1, None, dace.Memlet("b1[0]"))

    # The condition of the top `if_block`: c1 = c[__i] < 0.5
    tasklet_c1 = state.add_tasklet(
        "tasklet_c1", inputs={"__in"}, outputs={"__out"}, code="__out = __in < 0.5"
    )
    state.add_edge(state.add_access("c"), None, me, "IN_c", dace.Memlet("c[0:10]"))
    state.add_edge(me, "OUT_c", tasklet_c1, "__in", dace.Memlet("c[__i]"))
    state.add_edge(tasklet_c1, "__out", c1, None, dace.Memlet("c1[0]"))

    # Create the top `if_block`.
    # NOTE(review): `_make_if_block` presumably builds a NestedSDFG with the
    # `__arg1`/`__arg2`/`__cond`/`__output` connectors used below — defined
    # elsewhere in this test file.
    top_if_block = _make_if_block(state, sdfg)
    state.add_edge(a1, None, top_if_block, "__arg1", dace.Memlet("a1[0]"))
    state.add_edge(b1, None, top_if_block, "__arg2", dace.Memlet("b1[0]"))
    state.add_edge(c1, None, top_if_block, "__cond", dace.Memlet("c1[0]"))
    # `a2` is the AccessNode connecting the two `if` blocks.
    state.add_edge(top_if_block, "__output", a2, None, dace.Memlet("a2[0]"))

    # The first branch of the lower/second `if_block`: b2 = atan(e[__i])
    tasklet_b2 = state.add_tasklet(
        "tasklet_b2", inputs={"__in"}, outputs={"__out"}, code="__out = math.atan(__in)"
    )
    state.add_edge(state.add_access("e"), None, me, "IN_e", dace.Memlet("e[0:10]"))
    state.add_edge(me, "OUT_e", tasklet_b2, "__in", dace.Memlet("e[__i]"))
    state.add_edge(tasklet_b2, "__out", b2, None, dace.Memlet("b2[0]"))

    # Condition branch of the second `if_block`: c2 = d[__i] < 0.5
    tasklet_c2 = state.add_tasklet(
        "tasklet_c2", inputs={"__in"}, outputs={"__out"}, code="__out = __in < 0.5"
    )
    state.add_edge(state.add_access("d"), None, me, "IN_d", dace.Memlet("d[0:10]"))
    state.add_edge(me, "OUT_d", tasklet_c2, "__in", dace.Memlet("d[__i]"))
    state.add_edge(tasklet_c2, "__out", c2, None, dace.Memlet("c2[0]"))

    # Create the second `if_block`; its first branch consumes the top block's
    # output `a2`.
    bot_if_block = _make_if_block(state, sdfg)
    state.add_edge(a2, None, bot_if_block, "__arg1", dace.Memlet("a2[0]"))
    state.add_edge(b2, None, bot_if_block, "__arg2", dace.Memlet("b2[0]"))
    state.add_edge(c2, None, bot_if_block, "__cond", dace.Memlet("c2[0]"))

    # Generate the output
    state.add_edge(bot_if_block, "__output", mx, "IN_f", dace.Memlet("f[__i]"))
    state.add_edge(mx, "OUT_f", state.add_access("f"), None, dace.Memlet("f[0:10]"))

    # Now add the connectors to the Map entry/exit; only "f" goes through the
    # MapExit, everything else enters through the MapEntry.
    for iname in input_names:
        mq = mx if iname == "f" else me
        mq.add_in_connector(f"IN_{iname}")
        mq.add_out_connector(f"OUT_{iname}")
    sdfg.validate()

    # We can not process the bottom block, because this would also inline the top
    # block that in turn has dataflow that could be relocated.
    _perform_test(
        sdfg,
        explected_applies=0,
        if_block=bot_if_block,
    )

    # Without pinning a particular block, repeated application is able to pick
    # a valid processing order (top block before bottom block) and both blocks
    # are handled.
    _perform_test(
        sdfg,
        explected_applies=2,
    )

    # Only the bottom block's condition scalar `c2` remains top-level; all
    # other temporaries were relocated into the (nested) `if` bodies.
    expected_top_level_data: set[str] = {"a", "b", "c", "d", "e", "f", "c2"}
    assert set(dnode.data for dnode in state.data_nodes()) == expected_top_level_data
    assert sdfg.arrays.keys() == expected_top_level_data
    assert set(tlet for tlet in state.nodes() if isinstance(tlet, dace_nodes.Tasklet)) == {
        tasklet_c2
    }
    assert set(
        if_block for if_block in state.nodes() if isinstance(if_block, dace_nodes.NestedSDFG)
    ) == {bot_if_block}

    # The bottom block now contains the relocated dataflow (including the top
    # block's inputs and the chaining scalar `a2`).
    expected_bot_if_block_data: set[str] = {
        "a",
        "b",
        "c",
        "e",
        "c1",
        "a2",
        "b2",
        "__arg1",
        "__arg2",
        "__output",
        "__cond",
    }
    assert set(bot_if_block.sdfg.arrays.keys()) == expected_bot_if_block_data

    # The top block (now nested inside the bottom one) received its own
    # branch dataflow.
    expected_top_if_block_data: set[str] = {
        "a",
        "b",
        "a1",
        "b1",
        "__arg1",
        "__arg2",
        "__output",
        "__cond",
    }
    assert set(top_if_block.sdfg.arrays.keys()) == expected_top_if_block_data