Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,9 @@ def _gt_auto_process_dataflow_inside_maps(
# before or after `LoopBlocking`. In cases where the condition is `False`
# most of the times calling it before is better, but if the condition is
# `True` then this order is better. Solve that issue.
# TODO(phimuell): Because of the limitation that the transformation only works
# for dataflow that is directly enclosed by a Map, the order in which it is
# applied matters. Instead we have to run it in topological order.
sdfg.apply_transformations_repeated(
gtx_transformations.MoveDataflowIntoIfBody(
ignore_upstream_blocks=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def can_be_applied(
if_block=if_block,
raw_relocatable_dataflow=raw_relocatable_dataflow,
non_relocatable_dataflow=non_relocatable_dataflow,
enclosing_map=enclosing_map,
)

# If no branch has something to inline then we are done.
Expand Down Expand Up @@ -204,6 +205,7 @@ def apply(
if_block=if_block,
raw_relocatable_dataflow=raw_relocatable_dataflow,
non_relocatable_dataflow=non_relocatable_dataflow,
enclosing_map=enclosing_map,
)

# Finally relocate the dataflow
Expand Down Expand Up @@ -551,6 +553,7 @@ def _has_if_block_relocatable_dataflow(
if_block=upstream_if_block,
raw_relocatable_dataflow=raw_relocatable_dataflow,
non_relocatable_dataflow=non_relocatable_dataflow,
enclosing_map=enclosing_map,
)
if all(len(rel_df) == 0 for rel_df in filtered_relocatable_dataflow.values()):
return False
Expand All @@ -564,6 +567,7 @@ def _filter_relocatable_dataflow(
if_block: dace_nodes.NestedSDFG,
raw_relocatable_dataflow: dict[str, set[dace_nodes.Node]],
non_relocatable_dataflow: dict[str, set[dace_nodes.Node]],
enclosing_map: dace_nodes.MapEntry,
) -> dict[str, set[dace_nodes.Node]]:
"""Partition the dependencies.

Expand All @@ -581,6 +585,8 @@ def _filter_relocatable_dataflow(
that can be relocated, not yet filtered.
non_relocatable_dataflow: The connectors and their associated dataflow
that can not be relocated.
enclosing_map: The limiting node, i.e. the MapEntry of the Map `if_block`
is located in.
"""

# Remove the parts of the dataflow that is unrelocatable.
Expand All @@ -592,8 +598,9 @@ def _filter_relocatable_dataflow(
for conn_name, rel_df in raw_relocatable_dataflow.items()
}

# Now we determine the nodes that are in more than one sets.
# These sets must be removed, from the individual sets.
# Relocating nodes that are in more than one set is difficult. In the most
# common case of just two branches, this anyway means they have to be
# executed in any case. Thus we remove them now.
known_nodes: set[dace_nodes.Node] = set()
multiple_df_nodes: set[dace_nodes.Node] = set()
for rel_df in relocatable_dataflow.values():
Expand All @@ -606,35 +613,73 @@ def _filter_relocatable_dataflow(
for conn_name, rel_df in relocatable_dataflow.items()
}

# However, not all dataflow can be moved inside the branch. For example if
# something is used outside the dataflow, that is moved inside the `if`,
# then we can not relocate it.
# TODO(phimuell): If we operate outside of a Map we also have to make sure that
# the data is single use data, is not an AccessNode that refers to global
# memory nor is a source AccessNode.
def filter_nodes(
branch_nodes: set[dace_nodes.Node],
sdfg: dace.SDFG,
state: dace.SDFGState,
nodes_proposed_for_reloc: set[dace_nodes.Node],
) -> set[dace_nodes.Node]:
# For this to work the `if_block` must be considered part, we remove it later.
branch_nodes.add(if_block)
has_been_updated = True
while has_been_updated:
has_been_updated = False
for node in list(branch_nodes):
if node is if_block:

for reloc_node in list(nodes_proposed_for_reloc):
assert (
state.in_degree(reloc_node) > 0
) # Because we are currently always inside a Map

# If the node is needed by anything that is not also moved
# into the `if` body, then it has to remain outside. For that we
# have to pretend that `if_block` is also relocated.
if any(
oedge.dst not in nodes_proposed_for_reloc
for oedge in state.out_edges(reloc_node)
if oedge.dst is not if_block
):
nodes_proposed_for_reloc.discard(reloc_node)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why discard() and not remove()? remove() is more strict because it also checks that the element exists, right? And the node has to exist in nodes_proposed_for_reloc.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an interesting point.
If you look at line 680, you see that sometimes multiple nodes are removed in one go; because the loop iterates over a copy, a node might already have been removed in a previous iteration, thus it might no longer be inside the set.
However, your comment made me think and I now changed it, such that at the beginning, it is checked if the node is still there.
Which is probably nicer and a bit faster, because you perform an early exit.

has_been_updated = True
continue
if any(oedge.dst not in branch_nodes for oedge in state.out_edges(node)):
branch_nodes.remove(node)

# We do not look at all incoming nodes, but have to ignore some of them.
# We ignore `enclosed_map` because it acts as boundary, and the node
# on the other side of it is mapped into the `if` body anyway. Then we
# have to ignore all AccessNodes, since they are either relocated into
# the `if` body or are mapped into. We then have to look only at the
# remaining nodes.
incoming_nodes: set[dace_nodes.Node] = {
iedge.src
for iedge in state.in_edges(reloc_node)
if not (
(iedge.src is enclosing_map)
or isinstance(iedge.src, dace_nodes.AccessNode)
)
}
if incoming_nodes.issubset(nodes_proposed_for_reloc):
# All nodes will be moved into the `if` body too, so no problem.
pass

elif incoming_nodes.isdisjoint(nodes_proposed_for_reloc):
# None of the incoming nodes will be moved into the if body,
# thus `reloc_node` is an interface node, it might be _mapped_
# into the `if` body (if it is an `AccessNode`), but the node
# itself will not be moved into the `if` body.
nodes_proposed_for_reloc.discard(reloc_node)
has_been_updated = True
assert if_block in branch_nodes
branch_nodes.remove(if_block)
return branch_nodes

else:
# Only some of the incoming nodes will be moved into the `if`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first part of the comment is misleading, in my opinion. I would write:

Suggested change
# Only some of the incoming nodes will be moved into the `if`
# Only some of the incoming nodes could be moved into the `if`

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right your suggestion is better.

# body. This is legal only if the not moved nodes are
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and below:

Suggested change
# body. This is legal only if the not moved nodes are
# body. It would be legal to relocate the node if all incoming nodes were access nodes, this is why access nodes are not included in 'incoming_nodes'.
# In this case, we cannot relocate the node, so we also discard those 'incoming_nodes' that were candidates for relocation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kind of like this suggestion, but it is partially a repetition of the comments about input_nodes.
However, I have renamed input_nodes (which is a bad name) to non_mappable_input_nodes which is a much more accurate then the old one and have updated the description by incorporating your suggestions.
I also have refactored it a bit it should be now simpler.

# AccessNodes, because we have ignored them in the definition
# of `incoming_nodes`, `reloc_node` can not be moved into
# the `if` body and neither can the incoming nodes.
nodes_proposed_for_reloc.difference_update(incoming_nodes)
nodes_proposed_for_reloc.discard(reloc_node)
has_been_updated = True

return nodes_proposed_for_reloc

return {
conn_name: filter_nodes(rel_df, sdfg, state)
for conn_name, rel_df in relocatable_dataflow.items()
conn_name: filter_nodes(rel_df) for conn_name, rel_df in relocatable_dataflow.items()
}

def _partition_if_block(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1047,3 +1047,131 @@ def test_if_mover_symbolic_tasklet():
assert if_block.sdfg.symbols["__i"] in {dace.int32, dace.int64}
assert if_block.symbol_mapping.keys() == expected_symb.union(["__i"])
assert all(str(sym) == str(symval) for sym, symval in if_block.symbol_mapping.items())


def test_if_mover_access_node_between():
    """
    Essentially tests the following situation:
    ```python
    a = foo(...)
    b = bar(...)
    c = baz(...)
    bb = a if c else b
    cc = baz2(d, ...)
    aa = foo2(...)
    e = aa if cc else bb
    ```
    """
    # This test is temporarily disabled.
    # NOTE(review): everything below this `return` is dead code. As written it
    # would fail if re-enabled:
    #   - `t1`, `t2`, `d1` and `cc1` are used (e.g. as edge endpoints and in
    #     Memlets) but never bound — the unpacking below only creates
    #     `a1, b1, c1, a2, b2, c2, o` — so removing the `return` raises
    #     `NameError`.
    #   - the data descriptors "t1", "t2", "d1" and "cc1" are never added to
    #     the SDFG via `add_scalar`.
    #   - "cc" is read through `state.add_access("cc")` but is not in
    #     `input_names`, so no array "cc" is ever declared.
    # TODO(review): fix the node/descriptor names before re-enabling.
    return
    sdfg = dace.SDFG(util.unique_name("if_mover_chain_of_blocks"))
    state = sdfg.add_state(is_start_block=True)

    # Inputs
    input_names = ["a", "b", "c", "d", "e", "f"]
    for name in input_names:
        sdfg.add_array(
            name,
            shape=(10,),
            dtype=dace.float64,
            transient=False,
        )

    # Temporaries
    # NOTE(review): names starting with "c" become booleans (conditions),
    # all others float64; see the dtype selection below.
    temporary_names = ["a1", "b1", "c1", "a2", "b2", "c2", "o"]
    for name in temporary_names:
        sdfg.add_scalar(
            name, dtype=dace.bool_ if name.startswith("c") else dace.float64, transient=True
        )

    a1, b1, c1, a2, b2, c2, o = (state.add_access(name) for name in temporary_names)
    me, mx = state.add_map("comp", ndrange={"__i": "0:10"})

    # First branch of top `if_block`
    tasklet_a1 = state.add_tasklet(
        "tasklet_a1", inputs={"__in"}, outputs={"__out"}, code="__out = math.sin(__in)"
    )
    state.add_edge(state.add_access("a"), None, me, "IN_a", dace.Memlet("a[0:10]"))
    state.add_edge(me, "OUT_a", tasklet_a1, "__in", dace.Memlet("a[__i]"))
    state.add_edge(tasklet_a1, "__out", a1, None, dace.Memlet("a1[0]"))

    # Second branch of the top `if_block`
    tasklet_b1 = state.add_tasklet(
        "tasklet_b1", inputs={"__in"}, outputs={"__out"}, code="__out = math.cos(__in)"
    )
    state.add_edge(state.add_access("b"), None, me, "IN_b", dace.Memlet("b[0:10]"))
    state.add_edge(me, "OUT_b", tasklet_b1, "__in", dace.Memlet("b[__i]"))
    state.add_edge(tasklet_b1, "__out", b1, None, dace.Memlet("b1[0]"))

    # The condition of the top `if_block`
    tasklet_c1 = state.add_tasklet(
        "tasklet_c1", inputs={"__in"}, outputs={"__out"}, code="__out = __in < 0.5"
    )
    state.add_edge(state.add_access("c"), None, me, "IN_c", dace.Memlet("c[0:10]"))
    state.add_edge(me, "OUT_c", tasklet_c1, "__in", dace.Memlet("c[__i]"))
    state.add_edge(tasklet_c1, "__out", c1, None, dace.Memlet("c1[0]"))

    # Create the top `if_block`
    top_if_block = _make_if_block(state, sdfg)
    state.add_edge(a1, None, top_if_block, "__arg1", dace.Memlet("a1[0]"))
    state.add_edge(b1, None, top_if_block, "__arg2", dace.Memlet("b1[0]"))
    state.add_edge(c1, None, top_if_block, "__cond", dace.Memlet("c1[0]"))
    # NOTE(review): `t1` is undefined here (see note at top of the test).
    state.add_edge(top_if_block, "__output", t1, None, dace.Memlet("t1[0]"))

    # The first branch of the lower/second `if_block`, which uses data computed
    # by the top `if_block`.
    tasklet_t2 = state.add_tasklet(
        "tasklet_t2", inputs={"__in"}, outputs={"__out"}, code="__out = math.exp(__in)"
    )
    # NOTE(review): `t1` and `t2` are undefined here (see note at top).
    state.add_edge(t1, None, tasklet_t2, "__in", dace.Memlet("t1[0]"))
    state.add_edge(tasklet_t2, "__out", t2, None, dace.Memlet("t2[0]"))

    # Second branch of the second `if_block`.
    tasklet_d1 = state.add_tasklet(
        "tasklet_d1", inputs={"__in"}, outputs={"__out"}, code="__out = math.atan(__in)"
    )
    state.add_edge(state.add_access("d"), None, me, "IN_d", dace.Memlet("d[0:10]"))
    state.add_edge(me, "OUT_d", tasklet_d1, "__in", dace.Memlet("d[__i]"))
    # NOTE(review): `d1` is undefined here (see note at top).
    state.add_edge(tasklet_d1, "__out", d1, None, dace.Memlet("d1[0]"))

    # Condition branch of the second `if_block`.
    tasklet_cc1 = state.add_tasklet(
        "tasklet_cc1", inputs={"__in"}, outputs={"__out"}, code="__out = __in < 0.5"
    )
    # NOTE(review): "cc" is never declared in this SDFG (see note at top).
    state.add_edge(state.add_access("cc"), None, me, "IN_cc", dace.Memlet("cc[0:10]"))
    state.add_edge(me, "OUT_cc", tasklet_cc1, "__in", dace.Memlet("cc[__i]"))
    state.add_edge(tasklet_cc1, "__out", cc1, None, dace.Memlet("cc1[0]"))

    # Create the second `if_block`
    bot_if_block = _make_if_block(state, sdfg)
    state.add_edge(t2, None, bot_if_block, "__arg1", dace.Memlet("t2[0]"))
    state.add_edge(d1, None, bot_if_block, "__arg2", dace.Memlet("d1[0]"))
    state.add_edge(cc1, None, bot_if_block, "__cond", dace.Memlet("cc1[0]"))

    # Generate the output
    state.add_edge(bot_if_block, "__output", mx, "IN_e", dace.Memlet("e[__i]"))
    state.add_edge(mx, "OUT_e", state.add_access("e"), None, dace.Memlet("e[0:10]"))

    # Now add the connectors to the Map*
    # NOTE(review): "e" is the only output, so it gets connectors on the
    # MapExit; all other names get MapEntry connectors. "f" gets connectors
    # but is never connected to an edge — presumably leftover; confirm.
    for iname in input_names:
        if iname == "e":
            mx.add_in_connector(f"IN_{iname}")
            mx.add_out_connector(f"OUT_{iname}")
        else:
            me.add_in_connector(f"IN_{iname}")
            me.add_out_connector(f"OUT_{iname}")
    sdfg.validate()

    # It is not possible to apply the transformation on the lower `if_block`,
    # because it is limited by the top one.
    # NOTE(review): `explected_applies` looks like a typo of
    # `expected_applies` — confirm against `_perform_test`'s signature.
    _perform_test(
        sdfg,
        explected_applies=0,
        if_block=bot_if_block,
    )

    # But we are able to inline both.
    _perform_test(
        sdfg,
        explected_applies=2,
    )