Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the SubgraphFusion bug #1688

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dace/codegen/targets/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,9 +417,10 @@ def reshape_strides(subset, strides, original_strides, copy_shape):
dims = len(copy_shape)

reduced_tile_sizes = [ts for ts, s in zip(subset.tile_sizes, original_copy_shape) if s != 1]
reduced_tile_sizes += [1] * (dims - len(reduced_tile_sizes)) # Pad the remainder with 1s to maintain dimensions.

reshaped_copy = copy_shape + [ts for ts in subset.tile_sizes if ts != 1]
reshaped_copy[:len(copy_shape)] = [s / ts for s, ts in zip(copy_shape, reduced_tile_sizes)]
reshaped_copy[:len(copy_shape)] = [s // ts for s, ts in zip(copy_shape, reduced_tile_sizes)]

new_strides = [0] * len(reshaped_copy)
elements_remaining = functools.reduce(sp.Mul, copy_shape, 1)
Expand Down
8 changes: 4 additions & 4 deletions dace/transformation/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dace.sdfg.state import ControlFlowRegion
from dace.subsets import Range, Subset, union
import dace.subsets as subsets
from typing import Dict, List, Optional, Tuple, Set, Union
from typing import Dict, List, Optional, Tuple, Set, Union, Iterable

from dace import data, dtypes, symbolic
from dace.codegen import control_flow as cf
Expand Down Expand Up @@ -275,7 +275,7 @@ def find_sdfg_control_flow(sdfg: SDFG) -> Dict[SDFGState, Set[SDFGState]]:
if isinstance(child, cf.BasicCFBlock):
if child.state in visited:
continue
components[child.state] = (set([child.state]), child)
components[child.state] = ({child.state}, child)
visited[child.state] = False
elif isinstance(child, (cf.ForScope, cf.WhileScope)):
guard = child.guard
Expand Down Expand Up @@ -1031,11 +1031,11 @@ def are_subsets_contiguous(subset_a: subsets.Subset, subset_b: subsets.Subset, d
return False


def find_contiguous_subsets(subset_list: List[subsets.Subset], dim: int = None) -> Set[subsets.Subset]:
def find_contiguous_subsets(subset_list: Iterable[subsets.Subset], dim: int = None) -> Set[subsets.Subset]:
"""
Finds the set of largest contiguous subsets in a list of subsets.

:param subsets: Iterable of subset objects.
:param subset_list: Iterable of subset objects.
:param dim: Check for contiguity only for the specified dimension.
:return: A list of contiguous subsets.
"""
Expand Down
53 changes: 23 additions & 30 deletions dace/transformation/subgraph/stencil_tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,27 @@
""" This module contains classes and functions that implement the orthogonal
stencil tiling transformation. """

import math
import itertools
import warnings
from collections import defaultdict
from copy import deepcopy as dcpy

import dace
from dace import dtypes, registry, symbolic
import dace.subsets as subsets
import dace.symbolic as symbolic
from dace import dtypes
from dace.properties import make_properties, Property, ShapeProperty
from dace.sdfg import nodes
from dace.transformation import transformation
from dace.sdfg.propagation import _propagate_node

from dace.transformation.dataflow.map_for_loop import MapToForLoop
from dace.transformation.dataflow.map_expansion import MapExpansion
from dace.transformation import transformation
from dace.transformation.dataflow.map_collapse import MapCollapse
from dace.transformation.dataflow.map_expansion import MapExpansion
from dace.transformation.dataflow.map_for_loop import MapToForLoop
from dace.transformation.dataflow.strip_mining import StripMining
from dace.transformation.interstate.loop_unroll import LoopUnroll
from dace.transformation.interstate.loop_detection import DetectLoop
from dace.transformation.subgraph import SubgraphFusion

from copy import deepcopy as dcpy

import dace.subsets as subsets
import dace.symbolic as symbolic

import itertools
import warnings

from collections import defaultdict

from dace.transformation.interstate.loop_unroll import LoopUnroll
from dace.transformation.subgraph import helpers
from dace.transformation.subgraph import subgraph_fusion


@make_properties
Expand All @@ -51,7 +44,7 @@ class StencilTiling(transformation.SubgraphTransformation):

prefix = Property(dtype=str, default="stencil", desc="Prefix for new inner tiled range symbols")

strides = ShapeProperty(dtype=tuple, default=(1, ), desc="Tile stride")
strides = ShapeProperty(dtype=tuple, default=(1,), desc="Tile stride")

schedule = Property(dtype=dace.dtypes.ScheduleType,
default=dace.dtypes.ScheduleType.Default,
Expand Down Expand Up @@ -200,13 +193,13 @@ def can_be_applied(sdfg, subgraph) -> bool:

# get intermediate_nodes, out_nodes from SubgraphFusion Transformation
try:
node_config = SubgraphFusion.get_adjacent_nodes(sdfg, graph, map_entries)
node_config = subgraph_fusion.get_adjacent_nodes(graph, map_entries)
(_, intermediate_nodes, out_nodes) = node_config
except NotImplementedError:
return False

# 1.4: check topological feasibility
if not SubgraphFusion.check_topo_feasibility(sdfg, graph, map_entries, intermediate_nodes, out_nodes):
if not subgraph_fusion.check_topo_feasibility(graph, map_entries, intermediate_nodes, out_nodes):
return False
# 1.5 nodes that are both intermediate and out nodes
# are not supported in StencilTiling
Expand All @@ -215,8 +208,8 @@ def can_be_applied(sdfg, subgraph) -> bool:

# 1.6 check that we only deal with compressible transients

subgraph_contains_data = SubgraphFusion.determine_compressible_nodes(sdfg, graph, intermediate_nodes,
map_entries, map_exits)
subgraph_contains_data = subgraph_fusion.determine_compressible_nodes(sdfg, graph, intermediate_nodes,
map_entries, map_exits)
if any([s == False for s in subgraph_contains_data.values()]):
return False

Expand Down Expand Up @@ -264,8 +257,8 @@ def can_be_applied(sdfg, subgraph) -> bool:
for i, (p_subset, c_subset) in enumerate(zip(parent_coverage, children_coverage)):

# transform into subset
p_subset = subsets.Range((p_subset, ))
c_subset = subsets.Range((c_subset, ))
p_subset = subsets.Range((p_subset,))
c_subset = subsets.Range((c_subset,))

# get associated parameter in memlet
params1 = symbolic.symlist(memlets[map_entry][1][data_name][i]).keys()
Expand All @@ -292,7 +285,7 @@ def can_be_applied(sdfg, subgraph) -> bool:
except KeyError:
return False

#parameter mapping must be the same
# parameter mapping must be the same
if param_parent_coverage != param_children_coverage:
return False

Expand Down Expand Up @@ -394,7 +387,7 @@ def apply(self, sdfg):
for data_name, ranges in local_ranges.items():
for param, r in zip(variable_mapping[data_name], ranges):
# create new range from this subset and assign
rng = subsets.Range((r, ))
rng = subsets.Range((r,))
if param:
inferred_ranges[map_entry][param] = subsets.union(inferred_ranges[map_entry][param], rng)

Expand Down Expand Up @@ -457,9 +450,9 @@ def apply(self, sdfg):
reference_range_current = self.reference_range[param]

min_diff = symbolic.SymExpr(reference_range_current.min_element()[0] \
- target_range_current.min_element()[0])
- target_range_current.min_element()[0])
max_diff = symbolic.SymExpr(target_range_current.max_element()[0] \
- reference_range_current.max_element()[0])
- reference_range_current.max_element()[0])

try:
min_diff = symbolic.evaluate(min_diff, {})
Expand Down
Loading
Loading