diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index 562da246ab..6acec11dbc 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -26,7 +26,7 @@ ) import dace -from dace import subsets as dace_subsets +from dace import nodes as dace_nodes, subsets as dace_subsets from gt4py import eve from gt4py.eve.extended_typing import MaybeNestedInTuple, NestedTuple @@ -68,7 +68,7 @@ class ValueExpr: gt_dtype: GT4Py data type, which includes the `offset_type` local dimension for lists. """ - dc_node: dace.nodes.AccessNode + dc_node: dace_nodes.AccessNode gt_dtype: ts.ListType | ts.ScalarType def __post_init__(self) -> None: @@ -89,7 +89,7 @@ class MemletExpr: subset: The memlet subset to retrieve the local data. """ - dc_node: dace.nodes.AccessNode + dc_node: dace_nodes.AccessNode gt_field: ts.FieldType subset: dace_subsets.Range @@ -131,7 +131,7 @@ class IteratorExpr: or the result of a tasklet computation like neighbors connectivity or dynamic offset. """ - field: dace.nodes.AccessNode + field: dace_nodes.AccessNode gt_dtype: ts.ListType | ts.ScalarType field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymbolicType]] indices: dict[gtx_common.Dimension, DataExpr] @@ -195,7 +195,7 @@ class DataflowInputEdge(Protocol): """ @abc.abstractmethod - def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: ... + def connect(self, map_entry: Optional[dace_nodes.MapEntry]) -> None: ... @dataclasses.dataclass(frozen=True) @@ -208,12 +208,12 @@ class MemletInputEdge(DataflowInputEdge): """ state: dace.SDFGState - source: dace.nodes.AccessNode + source: dace_nodes.AccessNode subset: dace_subsets.Range - dest: dace.nodes.AccessNode | dace.nodes.Tasklet + dest: dace_nodes.AccessNode | dace_nodes.Tasklet dest_conn: Optional[str] - def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: + def connect(self, map_entry: Optional[dace_nodes.MapEntry]) -> None: memlet = dace.Memlet(data=self.source.data, subset=self.subset) if map_entry is None: self.state.add_edge(self.source, None, self.dest, self.dest_conn, memlet) @@ -237,9 +237,9 @@ class EmptyInputEdge(DataflowInputEdge): """ state: dace.SDFGState - node: dace.nodes.Tasklet + node: dace_nodes.Tasklet - def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: + def connect(self, map_entry: Optional[dace_nodes.MapEntry]) -> None: if map_entry is None: # outside of a map scope it is possible to instantiate a tasklet node # without input connectors @@ -266,8 +266,8 @@ class DataflowOutputEdge: def connect( self, - map_exit: Optional[dace.nodes.MapExit], - dest: dace.nodes.AccessNode, + map_exit: Optional[dace_nodes.MapExit], + dest: dace_nodes.AccessNode, dest_subset: dace_subsets.Range, ) -> bool: """Create a connection to the `dest` node, writing the given `dest_subset`. @@ -281,11 +281,11 @@ def connect( write_edge = self.state.in_edges(self.result.dc_node)[0] # Check the kind of node which writes the result - if isinstance(write_edge.src, dace.nodes.Tasklet): + if isinstance(write_edge.src, dace_nodes.Tasklet): # The temporary data written by a tasklet can be safely deleted. assert map_exit is not None remove_last_node = True - elif isinstance(write_edge.src, dace.nodes.NestedSDFG): + elif isinstance(write_edge.src, dace_nodes.NestedSDFG): if isinstance(dest_desc, dace.data.Scalar): # We keep scalar temporary storage, as a general rule, since it # does not affect performance of the generated code. This scalar @@ -412,9 +412,9 @@ class LambdaToDataflow(eve.NodeVisitor): def _add_input_data_edge( self, - src: dace.nodes.AccessNode, + src: dace_nodes.AccessNode, src_subset: dace_subsets.Range, - dst_node: dace.nodes.Node, + dst_node: dace_nodes.Node, dst_conn: Optional[str] = None, ) -> None: edge = MemletInputEdge(self.state, src, src_subset, dst_node, dst_conn) @@ -439,7 +439,7 @@ def _add_map( List[Tuple[str, Union[str, dace.subsets.Subset]]], ], **kwargs: Any, - ) -> Tuple[dace.nodes.MapEntry, dace.nodes.MapExit]: + ) -> Tuple[dace_nodes.MapEntry, dace_nodes.MapExit]: """ Helper method to add a map in current state. @@ -455,7 +455,7 @@ def _add_tasklet( outputs: set[str] | Mapping[str, dace.dtypes.typeclass | None], code: str, **kwargs: Any, - ) -> tuple[dace.nodes.Tasklet, dict[str, str]]: + ) -> tuple[dace_nodes.Tasklet, dict[str, str]]: """ Helper method to add a tasklet in current state. @@ -482,7 +482,7 @@ def _add_mapped_tasklet( code: str, outputs: Mapping[str, dace.Memlet], **kwargs: Any, - ) -> tuple[dace.nodes.Tasklet, dace.nodes.MapEntry, dace.nodes.MapExit, dict[str, str]]: + ) -> tuple[dace_nodes.Tasklet, dace_nodes.MapEntry, dace_nodes.MapExit, dict[str, str]]: """ Helper method to add a mapped tasklet in current state. @@ -522,7 +522,7 @@ def _construct_local_view(self, field: MemletExpr | ValueExpr) -> ValueExpr: def _construct_tasklet_result( self, dc_dtype: dace.typeclass, - src_node: dace.nodes.Tasklet, + src_node: dace_nodes.Tasklet, src_connector: str, use_array: bool = False, ) -> ValueExpr: @@ -1327,7 +1327,7 @@ def _make_reduce_with_skip_values( reduce_init: SymbolExpr, reduce_identity: SymbolExpr, reduce_wcr: str, - result_node: dace.nodes.AccessNode, + result_node: dace_nodes.AccessNode, ) -> None: """ Helper method to lower reduction on a local field containing skip values. @@ -1600,7 +1600,7 @@ def _make_cartesian_shift( def _make_dynamic_neighbor_offset( self, offset_expr: MemletExpr | ValueExpr, - offset_table_node: dace.nodes.AccessNode, + offset_table_node: dace_nodes.AccessNode, origin_index: SymbolExpr, ) -> ValueExpr: """ @@ -1647,7 +1647,7 @@ def _make_unstructured_shift( self, it: IteratorExpr, conn_type: gtx_common.NeighborConnectivityType, - conn_node: dace.nodes.AccessNode, + conn_node: dace_nodes.AccessNode, offset_expr: DataExpr, ) -> IteratorExpr: """Implements shift in unstructured domain by means of a neighbor table.""" diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index 1d0e20cad9..eb84de9185 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -19,7 +19,7 @@ from typing import Any, Dict, Iterable, List, Mapping, Optional, Protocol, Sequence, Tuple, Union import dace -from dace import subsets as dace_subsets +from dace import nodes as dace_nodes, subsets as dace_subsets from dace.frontend.python import astutils as dace_astutils from gt4py import eve @@ -103,7 +103,7 @@ def add_map( List[Tuple[str, Union[str, dace.subsets.Subset]]], ], **kwargs: Any, - ) -> Tuple[dace.nodes.MapEntry, dace.nodes.MapExit]: + ) -> Tuple[dace_nodes.MapEntry, dace_nodes.MapExit]: """Wrapper of `dace.SDFGState.add_map` that assigns unique name.""" unique_name = self.unique_map_name(name) return state.add_map(unique_name, ndrange, **kwargs) @@ -118,7 +118,7 @@ def add_tasklet( code: str, language: dace.dtypes.Language = dace.dtypes.Language.Python, **kwargs: Any, - ) -> dace.nodes.Tasklet: + ) -> dace_nodes.Tasklet: """Wrapper of `dace.SDFGState.add_tasklet` that assigns a unique name. It also modifies the tasklet connectors by adding a prefix string (see @@ -157,7 +157,7 @@ def add_mapped_tasklet( outputs: Mapping[str, dace.Memlet], language: dace.dtypes.Language = dace.dtypes.Language.Python, **kwargs: Any, - ) -> tuple[dace.nodes.Tasklet, dace.nodes.MapEntry, dace.nodes.MapExit, dict[str, str]]: + ) -> tuple[dace_nodes.Tasklet, dace_nodes.MapEntry, dace_nodes.MapExit, dict[str, str]]: """Wrapper of `dace.SDFGState.add_mapped_tasklet` that assigns a unique name. It also modifies the tasklet connectors, in the same way as `add_tasklet()`. @@ -317,7 +317,7 @@ class SDFGBuilder(DataflowBuilder, Protocol): @abc.abstractmethod def make_field( self, - data_node: dace.nodes.AccessNode, + data_node: dace_nodes.AccessNode, data_type: ts.FieldType, ) -> gtir_to_sdfg_types.FieldopData: """Retrieve the field descriptor of a data node, including the origin information. @@ -391,7 +391,7 @@ def add_nested_sdfg( data_args: Mapping[str, gtir_to_sdfg_types.FieldopData | None], inner_result: gtir_to_sdfg_types.FieldopResult, capture_outer_data: bool, - ) -> tuple[dace.nodes.NestedSDFG, Mapping[str, dace.Memlet]]: + ) -> tuple[dace_nodes.NestedSDFG, Mapping[str, dace.Memlet]]: """ Helper function that prepares the input connections and symbol mapping before calling `SDFG.add_nestd_sdfg()` to add the given SDFG as a nested SDFG node @@ -525,7 +525,7 @@ def get_offset_provider_type(self, offset: str) -> gtx_common.OffsetProviderType def make_field( self, - data_node: dace.nodes.AccessNode, + data_node: dace_nodes.AccessNode, data_type: ts.FieldType, ) -> gtir_to_sdfg_types.FieldopData: local_dims = [dim for dim in data_type.dims if dim.kind == gtx_common.DimensionKind.LOCAL] @@ -658,7 +658,7 @@ def add_nested_sdfg( data_args: Mapping[str, gtir_to_sdfg_types.FieldopData | None], inner_result: gtir_to_sdfg_types.FieldopResult, capture_outer_data: bool, - ) -> tuple[dace.nodes.NestedSDFG, Mapping[str, dace.Memlet]]: + ) -> tuple[dace_nodes.NestedSDFG, Mapping[str, dace.Memlet]]: assert data_args.keys().isdisjoint(symbolic_args.keys()) # Collect the names of all output data, by flattening any tuple structure. diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_concat_where.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_concat_where.py index bbc0b3c1c4..ded61af77b 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_concat_where.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_concat_where.py @@ -17,7 +17,7 @@ from typing import Sequence import dace -from dace import subsets as dace_subsets +from dace import nodes as dace_nodes, subsets as dace_subsets from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir @@ -46,7 +46,7 @@ def _translate_concat_where_branch( output_domain: domain_utils.SymbolicDomain, output_type: ts.FieldType, output_desc: dace.data.Array, - output_node: dace.nodes.AccessNode, + output_node: dace_nodes.AccessNode, output_origin: Sequence[dace.symbolic.SymbolicType], ) -> None: """ diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py index 2428607236..59f1879bdd 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py @@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Iterable, Optional, Protocol import dace -from dace import subsets as dace_subsets +from dace import nodes as dace_nodes, subsets as dace_subsets from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as gtir @@ -91,7 +91,7 @@ def _create_field_operator_impl( field_domain: gtir_domain.FieldopDomain, output_edge: gtir_dataflow.DataflowOutputEdge, output_type: ts.FieldType, - map_exit: dace.nodes.MapExit, + map_exit: dace_nodes.MapExit, ) -> gtir_to_sdfg_types.FieldopData: """ Helper method to allocate a temporary array that stores one field computed @@ -546,7 +546,7 @@ def _get_symbolic_value( symbolic_expr: dace.symbolic.SymExpr, scalar_type: ts.ScalarType, temp_name: Optional[str] = None, -) -> dace.nodes.AccessNode: +) -> dace_nodes.AccessNode: tasklet_node, connector_mapping = sdfg_builder.add_tasklet( name="get_value", sdfg=sdfg, diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py index b8f958f6ab..545e5c3d48 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py @@ -25,7 +25,7 @@ from typing import Iterable, Sequence import dace -from dace import subsets as dace_subsets +from dace import nodes as dace_nodes, subsets as dace_subsets from gt4py import eve from gt4py.eve.extended_typing import MaybeNestedInTuple @@ -92,7 +92,7 @@ def _create_scan_field_operator_impl( output_edge: gtir_dataflow.DataflowOutputEdge | None, output_domain: infer_domain.NonTupleDomainAccess, output_type: ts.FieldType, - map_exit: dace.nodes.MapExit | None, + map_exit: dace_nodes.MapExit | None, ) -> gtir_to_sdfg_types.FieldopData | None: """ Helper method to allocate a temporary array that stores one field computed @@ -180,7 +180,7 @@ def _create_scan_field_operator_impl( # to modify the stride of the scan column array inside the nested SDFG to match # the strides outside. nsdfg_scan = field_node_path[0].src - assert isinstance(nsdfg_scan, dace.nodes.NestedSDFG) + assert isinstance(nsdfg_scan, dace_nodes.NestedSDFG) inner_output_name = field_node_path[0].src_conn inner_output_desc = nsdfg_scan.sdfg.arrays[inner_output_name] assert len(inner_output_desc.shape) == 1 @@ -530,7 +530,7 @@ def connect_scan_output( def _handle_dataflow_result_of_nested_sdfg( sdfg_builder: gtir_to_sdfg.SDFGBuilder, - nsdfg_node: dace.nodes.NestedSDFG, + nsdfg_node: dace_nodes.NestedSDFG, inner_ctx: gtir_to_sdfg.SubgraphContext, outer_ctx: gtir_to_sdfg.SubgraphContext, inner_data: gtir_to_sdfg_types.FieldopData, diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py index 2d4e128a0a..9cc2b60968 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py @@ -14,7 +14,7 @@ from typing import Final, TypeAlias import dace -from dace import subsets as dace_subsets +from dace import nodes as dace_nodes, subsets as dace_subsets from gt4py.eve.extended_typing import MaybeNestedInTuple from gt4py.next import common as gtx_common @@ -40,7 +40,7 @@ class FieldopData: Pass an empty tuple for `ScalarType` data or zero-dimensional fields. """ - dc_node: dace.nodes.AccessNode + dc_node: dace_nodes.AccessNode gt_type: ts.FieldType | ts.ScalarType origin: tuple[dace.symbolic.SymbolicType, ...]