Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)

import dace
from dace import subsets as dace_subsets
from dace import nodes as dace_nodes, subsets as dace_subsets

from gt4py import eve
from gt4py.eve.extended_typing import MaybeNestedInTuple, NestedTuple
Expand Down Expand Up @@ -68,7 +68,7 @@ class ValueExpr:
gt_dtype: GT4Py data type, which includes the `offset_type` local dimension for lists.
"""

dc_node: dace.nodes.AccessNode
dc_node: dace_nodes.AccessNode
gt_dtype: ts.ListType | ts.ScalarType

def __post_init__(self) -> None:
Expand All @@ -89,7 +89,7 @@ class MemletExpr:
subset: The memlet subset to retrieve the local data.
"""

dc_node: dace.nodes.AccessNode
dc_node: dace_nodes.AccessNode
gt_field: ts.FieldType
subset: dace_subsets.Range

Expand Down Expand Up @@ -131,7 +131,7 @@ class IteratorExpr:
or the result of a tasklet computation like neighbors connectivity or dynamic offset.
"""

field: dace.nodes.AccessNode
field: dace_nodes.AccessNode
gt_dtype: ts.ListType | ts.ScalarType
field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymbolicType]]
indices: dict[gtx_common.Dimension, DataExpr]
Expand Down Expand Up @@ -195,7 +195,7 @@ class DataflowInputEdge(Protocol):
"""

@abc.abstractmethod
def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: ...
def connect(self, map_entry: Optional[dace_nodes.MapEntry]) -> None: ...


@dataclasses.dataclass(frozen=True)
Expand All @@ -208,12 +208,12 @@ class MemletInputEdge(DataflowInputEdge):
"""

state: dace.SDFGState
source: dace.nodes.AccessNode
source: dace_nodes.AccessNode
subset: dace_subsets.Range
dest: dace.nodes.AccessNode | dace.nodes.Tasklet
dest: dace_nodes.AccessNode | dace_nodes.Tasklet
dest_conn: Optional[str]

def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None:
def connect(self, map_entry: Optional[dace_nodes.MapEntry]) -> None:
memlet = dace.Memlet(data=self.source.data, subset=self.subset)
if map_entry is None:
self.state.add_edge(self.source, None, self.dest, self.dest_conn, memlet)
Expand All @@ -237,9 +237,9 @@ class EmptyInputEdge(DataflowInputEdge):
"""

state: dace.SDFGState
node: dace.nodes.Tasklet
node: dace_nodes.Tasklet

def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None:
def connect(self, map_entry: Optional[dace_nodes.MapEntry]) -> None:
if map_entry is None:
# outside of a map scope it is possible to instantiate a tasklet node
# without input connectors
Expand All @@ -266,8 +266,8 @@ class DataflowOutputEdge:

def connect(
self,
map_exit: Optional[dace.nodes.MapExit],
dest: dace.nodes.AccessNode,
map_exit: Optional[dace_nodes.MapExit],
dest: dace_nodes.AccessNode,
dest_subset: dace_subsets.Range,
) -> bool:
"""Create a connection to the `dest` node, writing the given `dest_subset`.
Expand All @@ -281,11 +281,11 @@ def connect(
write_edge = self.state.in_edges(self.result.dc_node)[0]

# Check the kind of node which writes the result
if isinstance(write_edge.src, dace.nodes.Tasklet):
if isinstance(write_edge.src, dace_nodes.Tasklet):
# The temporary data written by a tasklet can be safely deleted.
assert map_exit is not None
remove_last_node = True
elif isinstance(write_edge.src, dace.nodes.NestedSDFG):
elif isinstance(write_edge.src, dace_nodes.NestedSDFG):
if isinstance(dest_desc, dace.data.Scalar):
# We keep scalar temporary storage, as a general rule, since it
# does not affect performance of the generated code. This scalar
Expand Down Expand Up @@ -412,9 +412,9 @@ class LambdaToDataflow(eve.NodeVisitor):

def _add_input_data_edge(
self,
src: dace.nodes.AccessNode,
src: dace_nodes.AccessNode,
src_subset: dace_subsets.Range,
dst_node: dace.nodes.Node,
dst_node: dace_nodes.Node,
dst_conn: Optional[str] = None,
) -> None:
edge = MemletInputEdge(self.state, src, src_subset, dst_node, dst_conn)
Expand All @@ -439,7 +439,7 @@ def _add_map(
List[Tuple[str, Union[str, dace.subsets.Subset]]],
],
**kwargs: Any,
) -> Tuple[dace.nodes.MapEntry, dace.nodes.MapExit]:
) -> Tuple[dace_nodes.MapEntry, dace_nodes.MapExit]:
"""
Helper method to add a map in current state.

Expand All @@ -455,7 +455,7 @@ def _add_tasklet(
outputs: set[str] | Mapping[str, dace.dtypes.typeclass | None],
code: str,
**kwargs: Any,
) -> tuple[dace.nodes.Tasklet, dict[str, str]]:
) -> tuple[dace_nodes.Tasklet, dict[str, str]]:
"""
Helper method to add a tasklet in current state.

Expand All @@ -482,7 +482,7 @@ def _add_mapped_tasklet(
code: str,
outputs: Mapping[str, dace.Memlet],
**kwargs: Any,
) -> tuple[dace.nodes.Tasklet, dace.nodes.MapEntry, dace.nodes.MapExit, dict[str, str]]:
) -> tuple[dace_nodes.Tasklet, dace_nodes.MapEntry, dace_nodes.MapExit, dict[str, str]]:
"""
Helper method to add a mapped tasklet in current state.

Expand Down Expand Up @@ -522,7 +522,7 @@ def _construct_local_view(self, field: MemletExpr | ValueExpr) -> ValueExpr:
def _construct_tasklet_result(
self,
dc_dtype: dace.typeclass,
src_node: dace.nodes.Tasklet,
src_node: dace_nodes.Tasklet,
src_connector: str,
use_array: bool = False,
) -> ValueExpr:
Expand Down Expand Up @@ -1327,7 +1327,7 @@ def _make_reduce_with_skip_values(
reduce_init: SymbolExpr,
reduce_identity: SymbolExpr,
reduce_wcr: str,
result_node: dace.nodes.AccessNode,
result_node: dace_nodes.AccessNode,
) -> None:
"""
Helper method to lower reduction on a local field containing skip values.
Expand Down Expand Up @@ -1600,7 +1600,7 @@ def _make_cartesian_shift(
def _make_dynamic_neighbor_offset(
self,
offset_expr: MemletExpr | ValueExpr,
offset_table_node: dace.nodes.AccessNode,
offset_table_node: dace_nodes.AccessNode,
origin_index: SymbolExpr,
) -> ValueExpr:
"""
Expand Down Expand Up @@ -1647,7 +1647,7 @@ def _make_unstructured_shift(
self,
it: IteratorExpr,
conn_type: gtx_common.NeighborConnectivityType,
conn_node: dace.nodes.AccessNode,
conn_node: dace_nodes.AccessNode,
offset_expr: DataExpr,
) -> IteratorExpr:
"""Implements shift in unstructured domain by means of a neighbor table."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Any, Dict, Iterable, List, Mapping, Optional, Protocol, Sequence, Tuple, Union

import dace
from dace import subsets as dace_subsets
from dace import nodes as dace_nodes, subsets as dace_subsets
from dace.frontend.python import astutils as dace_astutils

from gt4py import eve
Expand Down Expand Up @@ -103,7 +103,7 @@ def add_map(
List[Tuple[str, Union[str, dace.subsets.Subset]]],
],
**kwargs: Any,
) -> Tuple[dace.nodes.MapEntry, dace.nodes.MapExit]:
) -> Tuple[dace_nodes.MapEntry, dace_nodes.MapExit]:
"""Wrapper of `dace.SDFGState.add_map` that assigns unique name."""
unique_name = self.unique_map_name(name)
return state.add_map(unique_name, ndrange, **kwargs)
Expand All @@ -118,7 +118,7 @@ def add_tasklet(
code: str,
language: dace.dtypes.Language = dace.dtypes.Language.Python,
**kwargs: Any,
) -> dace.nodes.Tasklet:
) -> dace_nodes.Tasklet:
"""Wrapper of `dace.SDFGState.add_tasklet` that assigns a unique name.

It also modifies the tasklet connectors by adding a prefix string (see
Expand Down Expand Up @@ -157,7 +157,7 @@ def add_mapped_tasklet(
outputs: Mapping[str, dace.Memlet],
language: dace.dtypes.Language = dace.dtypes.Language.Python,
**kwargs: Any,
) -> tuple[dace.nodes.Tasklet, dace.nodes.MapEntry, dace.nodes.MapExit, dict[str, str]]:
) -> tuple[dace_nodes.Tasklet, dace_nodes.MapEntry, dace_nodes.MapExit, dict[str, str]]:
"""Wrapper of `dace.SDFGState.add_mapped_tasklet` that assigns a unique name.

It also modifies the tasklet connectors, in the same way as `add_tasklet()`.
Expand Down Expand Up @@ -317,7 +317,7 @@ class SDFGBuilder(DataflowBuilder, Protocol):
@abc.abstractmethod
def make_field(
self,
data_node: dace.nodes.AccessNode,
data_node: dace_nodes.AccessNode,
data_type: ts.FieldType,
) -> gtir_to_sdfg_types.FieldopData:
"""Retrieve the field descriptor of a data node, including the origin information.
Expand Down Expand Up @@ -391,7 +391,7 @@ def add_nested_sdfg(
data_args: Mapping[str, gtir_to_sdfg_types.FieldopData | None],
inner_result: gtir_to_sdfg_types.FieldopResult,
capture_outer_data: bool,
) -> tuple[dace.nodes.NestedSDFG, Mapping[str, dace.Memlet]]:
) -> tuple[dace_nodes.NestedSDFG, Mapping[str, dace.Memlet]]:
"""
Helper function that prepares the input connections and symbol mapping before
calling `SDFG.add_nestd_sdfg()` to add the given SDFG as a nested SDFG node
Expand Down Expand Up @@ -525,7 +525,7 @@ def get_offset_provider_type(self, offset: str) -> gtx_common.OffsetProviderType

def make_field(
self,
data_node: dace.nodes.AccessNode,
data_node: dace_nodes.AccessNode,
data_type: ts.FieldType,
) -> gtir_to_sdfg_types.FieldopData:
local_dims = [dim for dim in data_type.dims if dim.kind == gtx_common.DimensionKind.LOCAL]
Expand Down Expand Up @@ -658,7 +658,7 @@ def add_nested_sdfg(
data_args: Mapping[str, gtir_to_sdfg_types.FieldopData | None],
inner_result: gtir_to_sdfg_types.FieldopResult,
capture_outer_data: bool,
) -> tuple[dace.nodes.NestedSDFG, Mapping[str, dace.Memlet]]:
) -> tuple[dace_nodes.NestedSDFG, Mapping[str, dace.Memlet]]:
assert data_args.keys().isdisjoint(symbolic_args.keys())

# Collect the names of all output data, by flattening any tuple structure.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Sequence

import dace
from dace import subsets as dace_subsets
from dace import nodes as dace_nodes, subsets as dace_subsets

from gt4py.next import common as gtx_common
from gt4py.next.iterator import ir as gtir
Expand Down Expand Up @@ -46,7 +46,7 @@ def _translate_concat_where_branch(
output_domain: domain_utils.SymbolicDomain,
output_type: ts.FieldType,
output_desc: dace.data.Array,
output_node: dace.nodes.AccessNode,
output_node: dace_nodes.AccessNode,
output_origin: Sequence[dace.symbolic.SymbolicType],
) -> None:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import TYPE_CHECKING, Iterable, Optional, Protocol

import dace
from dace import subsets as dace_subsets
from dace import nodes as dace_nodes, subsets as dace_subsets

from gt4py.next import common as gtx_common, utils as gtx_utils
from gt4py.next.iterator import ir as gtir
Expand Down Expand Up @@ -91,7 +91,7 @@ def _create_field_operator_impl(
field_domain: gtir_domain.FieldopDomain,
output_edge: gtir_dataflow.DataflowOutputEdge,
output_type: ts.FieldType,
map_exit: dace.nodes.MapExit,
map_exit: dace_nodes.MapExit,
) -> gtir_to_sdfg_types.FieldopData:
"""
Helper method to allocate a temporary array that stores one field computed
Expand Down Expand Up @@ -546,7 +546,7 @@ def _get_symbolic_value(
symbolic_expr: dace.symbolic.SymExpr,
scalar_type: ts.ScalarType,
temp_name: Optional[str] = None,
) -> dace.nodes.AccessNode:
) -> dace_nodes.AccessNode:
tasklet_node, connector_mapping = sdfg_builder.add_tasklet(
name="get_value",
sdfg=sdfg,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from typing import Iterable, Sequence

import dace
from dace import subsets as dace_subsets
from dace import nodes as dace_nodes, subsets as dace_subsets

from gt4py import eve
from gt4py.eve.extended_typing import MaybeNestedInTuple
Expand Down Expand Up @@ -92,7 +92,7 @@ def _create_scan_field_operator_impl(
output_edge: gtir_dataflow.DataflowOutputEdge | None,
output_domain: infer_domain.NonTupleDomainAccess,
output_type: ts.FieldType,
map_exit: dace.nodes.MapExit | None,
map_exit: dace_nodes.MapExit | None,
) -> gtir_to_sdfg_types.FieldopData | None:
"""
Helper method to allocate a temporary array that stores one field computed
Expand Down Expand Up @@ -180,7 +180,7 @@ def _create_scan_field_operator_impl(
# to modify the stride of the scan column array inside the nested SDFG to match
# the strides outside.
nsdfg_scan = field_node_path[0].src
assert isinstance(nsdfg_scan, dace.nodes.NestedSDFG)
assert isinstance(nsdfg_scan, dace_nodes.NestedSDFG)
inner_output_name = field_node_path[0].src_conn
inner_output_desc = nsdfg_scan.sdfg.arrays[inner_output_name]
assert len(inner_output_desc.shape) == 1
Expand Down Expand Up @@ -530,7 +530,7 @@ def connect_scan_output(

def _handle_dataflow_result_of_nested_sdfg(
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
nsdfg_node: dace.nodes.NestedSDFG,
nsdfg_node: dace_nodes.NestedSDFG,
inner_ctx: gtir_to_sdfg.SubgraphContext,
outer_ctx: gtir_to_sdfg.SubgraphContext,
inner_data: gtir_to_sdfg_types.FieldopData,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Final, TypeAlias

import dace
from dace import subsets as dace_subsets
from dace import nodes as dace_nodes, subsets as dace_subsets

from gt4py.eve.extended_typing import MaybeNestedInTuple
from gt4py.next import common as gtx_common
Expand All @@ -40,7 +40,7 @@ class FieldopData:
Pass an empty tuple for `ScalarType` data or zero-dimensional fields.
"""

dc_node: dace.nodes.AccessNode
dc_node: dace_nodes.AccessNode
gt_type: ts.FieldType | ts.ScalarType
origin: tuple[dace.symbolic.SymbolicType, ...]

Expand Down