Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import dace
from dace import data as dace_data
from dace.sdfg import nodes as dace_nodes, propagation as dace_propagation, utils as dace_sdutils
from dace.transformation import dataflow as dace_dataflow
from dace.transformation.auto import auto_optimize as dace_aoptimize
from dace.transformation.passes import analysis as dace_analysis

Expand Down Expand Up @@ -130,6 +131,7 @@ def gt_auto_optimize(
assume_pointwise: bool = True,
optimization_hooks: Optional[dict[GT4PyAutoOptHook, GT4PyAutoOptHookFun]] = None,
demote_fields: Optional[list[str]] = None,
fuse_tasklets: bool = False,
validate: bool = True,
validate_all: bool = False,
**kwargs: Any,
Expand Down Expand Up @@ -197,6 +199,7 @@ def gt_auto_optimize(
see `GT4PyAutoOptHook` for more information.
demote_fields: Consider these fields as transients for the purpose of optimization.
Use at your own risk. See Notes for all implications.
fuse_tasklets: Reduces the number of Tasklets by fusing them.
validate: Perform validation during the steps.
validate_all: Perform extensive validation.

Expand Down Expand Up @@ -324,6 +327,7 @@ def gt_auto_optimize(
blocking_only_if_independent_nodes=blocking_only_if_independent_nodes,
scan_loop_unrolling=scan_loop_unrolling,
scan_loop_unrolling_factor=scan_loop_unrolling_factor,
fuse_tasklets=fuse_tasklets,
validate_all=validate_all,
)

Expand Down Expand Up @@ -660,6 +664,7 @@ def _gt_auto_process_dataflow_inside_maps(
blocking_only_if_independent_nodes: Optional[bool],
scan_loop_unrolling: bool,
scan_loop_unrolling_factor: int,
fuse_tasklets: bool,
validate_all: bool,
) -> dace.SDFG:
"""Optimizes the dataflow inside the top level Maps of the SDFG inplace.
Expand All @@ -674,33 +679,51 @@ def _gt_auto_process_dataflow_inside_maps(
time, so the compiler will fully unroll them anyway.
"""

# Separate Tasklets into dependent and independent parts to promote data
# reusability. It is important that this step has to be performed before
# `TaskletFusion` is used.
if blocking_dim is not None:
sdfg.apply_transformations_once_everywhere(
gtx_transformations.LoopBlocking(
blocking_size=blocking_size,
blocking_parameter=blocking_dim,
require_independent_nodes=blocking_only_if_independent_nodes,
),
validate=False,
validate_all=validate_all,
)

# Merge Tasklets into bigger ones.
# NOTE: Empirical observation for Graupel have shown that this leads to an increase
# in performance, however, it has to be run before `GT4PyMoveTaskletIntoMap`
# (not fully clear why though, probably a compiler artefact) and as well as
# `MoveDataflowIntoIfBody` (not fully clear either, it `TaskletFusion` makes
# things simpler or prevent it from doing certain, negative, things).
# TODO(phimuell): Investigate more.
# TODO(phimuell): Restrict it to Tasklets only inside Maps.
if fuse_tasklets:
sdfg.apply_transformations_repeated(
dace_dataflow.TaskletFusion,
validate=False,
validate_all=validate_all,
)

# Constants (tasklets are needed to write them into a variable) should not be
# arguments to a kernel but be present inside the body.
sdfg.apply_transformations_once_everywhere(
gtx_transformations.GT4PyMoveTaskletIntoMap,
validate=False,
validate_all=validate_all,
)

# TODO(phimuell): figuring out if this is needed?
gtx_transformations.gt_simplify(
sdfg,
skip=gtx_transformations.constants._GT_AUTO_OPT_INNER_DATAFLOW_STAGE_SIMPLIFY_SKIP_LIST,
validate=False,
validate_all=validate_all,
)

# Blocking is performed first, because this ensures that as much as possible
# is moved into the k independent part.
if blocking_dim is not None:
sdfg.apply_transformations_once_everywhere(
gtx_transformations.LoopBlocking(
blocking_size=blocking_size,
blocking_parameter=blocking_dim,
require_independent_nodes=blocking_only_if_independent_nodes,
),
validate=False,
validate_all=validate_all,
)

# Move dataflow into the branches of the `if` such that they are only evaluated
# if they are needed. Important to call it repeatedly.
# TODO(phimuell): It is unclear if `MoveDataflowIntoIfBody` should be called
Expand All @@ -714,6 +737,8 @@ def _gt_auto_process_dataflow_inside_maps(
validate=False,
validate_all=validate_all,
)

# TODO(phimuell): figuring out if this is needed?
gtx_transformations.gt_simplify(
sdfg,
skip=gtx_transformations.constants._GT_AUTO_OPT_INNER_DATAFLOW_STAGE_SIMPLIFY_SKIP_LIST,
Expand Down