Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion external/dace
Submodule dace updated 1044 files
2 changes: 1 addition & 1 deletion external/gt4py
Submodule gt4py updated 45 files
+7 −0 .github/workflows/daily-ci.yml
+9 −0 CHANGELOG.md
+10 −5 pyproject.toml
+1 −1 src/gt4py/__about__.py
+1 −31 src/gt4py/cartesian/backend/base.py
+1 −1 src/gt4py/cartesian/backend/dace_backend.py
+1 −1 src/gt4py/cartesian/gtc/dace/passes/push_vertical_map_down.py
+22 −74 src/gt4py/cartesian/gtc/dace/treeir_to_stree.py
+4 −0 src/gt4py/cartesian/gtc/numpy/npir.py
+5 −0 src/gt4py/cartesian/gtc/numpy/npir_codegen.py
+5 −3 src/gt4py/cartesian/gtc/numpy/oir_to_npir.py
+6 −0 src/gt4py/next/config.py
+2 −2 src/gt4py/next/embedded/operators.py
+2 −1 src/gt4py/next/ffront/past_passes/type_deduction.py
+3 −7 src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py
+124 −122 src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py
+12 −14 src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py
+9 −5 src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py
+2 −0 src/gt4py/next/program_processors/runners/dace/transformations/__init__.py
+36 −12 src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py
+2 −8 src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py
+7 −2 src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py
+66 −0 src/gt4py/next/program_processors/runners/dace/transformations/scan_loop_unrolling.py
+56 −13 src/gt4py/next/program_processors/runners/dace/transformations/simplify.py
+107 −8 src/gt4py/next/program_processors/runners/dace/transformations/split_access_nodes.py
+22 −6 src/gt4py/next/program_processors/runners/dace/transformations/splitting_tools.py
+0 −4 src/gt4py/next/program_processors/runners/dace/workflow/backend.py
+78 −30 src/gt4py/next/program_processors/runners/dace/workflow/bindings.py
+3 −3 src/gt4py/next/program_processors/runners/dace/workflow/common.py
+8 −2 src/gt4py/next/program_processors/runners/dace/workflow/compilation.py
+1 −1 src/gt4py/next/program_processors/runners/dace/workflow/decoration.py
+12 −0 src/gt4py/next/program_processors/runners/dace/workflow/translation.py
+2 −2 tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py
+22 −24 tests/cartesian_tests/unit_tests/backend_tests/test_dace_backend.py
+0 −1 tests/next_tests/definitions.py
+4 −4 tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py
+76 −0 tests/next_tests/integration_tests/feature_tests/ffront_tests/test_named_collections.py
+2 −2 tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py
+1 −11 tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py
+6 −1 tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py
+186 −49 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py
+93 −0 ...nit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_scan_loop_unrolling.py
+45 −1 .../unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_split_access_node.py
+64 −1 ...ts/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_splitting_tools.py
+90 −165 uv.lock
9 changes: 4 additions & 5 deletions ndsl/dsl/dace/dace_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import enum
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Self

import dace.config
Expand Down Expand Up @@ -326,11 +327,9 @@ def __init__(
)

# Attempt to kill the dace.conf to avoid confusion
if dace.config.Config._cfg_filename:
try:
os.remove(dace.config.Config._cfg_filename)
except OSError:
pass
dace_conf_to_kill = dace.config.Config.cfg_filename()
if dace_conf_to_kill is not None:
Path(dace_conf_to_kill).unlink(missing_ok=True)

self._backend = backend
self.tile_resolution = [tile_nx, tile_nx, tile_nz]
Expand Down
2 changes: 1 addition & 1 deletion ndsl/dsl/dace/labeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def set_label(
if isinstance(sdfg, dace.CompiledSDFG):
return

for state in sdfg.states():
for state in sdfg.nodes():
if sdfg.in_edges(state) == []:
# With the topmost SDFG we have to skip over the
# "init" state
Expand Down
7 changes: 2 additions & 5 deletions ndsl/dsl/dace/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
from dace import SDFG, SDFGState, dtypes
from dace.frontend.common import op_repository as oprepo
from dace.frontend.python.newast import ProgramVisitor
from dace.frontend.python.replacements import (
UfuncInput,
UfuncOutput,
_datatype_converter,
)
from dace.frontend.python.replacements import UfuncInput, UfuncOutput
from dace.frontend.python.replacements.array_manipulation import _datatype_converter

from ndsl.dsl.typing import Float, Int

Expand Down
2 changes: 1 addition & 1 deletion tests/stencils/test_stencils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __call__(
self._column_min_ddim_stencil(data_ddim, min_value_ddim, min_index_ddim)


def test_column_operations(boilerplate):
def test_column_operations(boilerplate: tuple[StencilFactory, QuantityFactory]):
stencil_factory, quantity_factory = boilerplate
quantity_factory.add_data_dimensions({"ddim": 2})
data = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "n/a")
Expand Down
174 changes: 132 additions & 42 deletions tests/stree_optimizer/test_merge.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import TypeAlias

import dace
import pytest

from ndsl import QuantityFactory, StencilFactory, orchestrate
from ndsl.boilerplate import get_factories_single_tile_orchestrated
Expand Down Expand Up @@ -52,7 +55,7 @@ def __init__(
"trivial_merge",
"missing_merge_of_forscope_and_map",
"overcompute_merge",
"block_merge_when_depandencies_is_found",
"block_merge_when_dependencies_are_found",
"push_non_cartesian_for",
]
for method in orchestratable_methods:
Expand Down Expand Up @@ -98,7 +101,7 @@ def missing_merge_of_forscope_and_map(
self.stencil_with_forward_K(in_field, out_field)
self.stencil(in_field, out_field)

def block_merge_when_depandencies_is_found(
def block_merge_when_dependencies_are_found(
self,
in_field: FloatField,
out_field: FloatField,
Expand All @@ -124,19 +127,29 @@ def push_non_cartesian_for(
self.stencil(in_field, out_field)


def test_stree_merge_maps_IJK() -> None:
domain = (3, 3, 4)
stencil_factory, quantity_factory = get_factories_single_tile_orchestrated(
domain[0], domain[1], domain[2], 0, backend="dace:cpu_kfirst"
)
Factories: TypeAlias = tuple[StencilFactory, QuantityFactory]


class TestStreeMergeMapsIJK:
@pytest.fixture
def factories(self) -> Factories:
domain = (3, 3, 4)
return get_factories_single_tile_orchestrated(
domain[0], domain[1], domain[2], 0, backend="dace:cpu_kfirst"
)

@pytest.fixture
def code(self, factories: Factories) -> OrchestratedCode:
return OrchestratedCode(*factories)

code = OrchestratedCode(stencil_factory, quantity_factory)
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")
def test_trivial_merge(self, code: OrchestratedCode, factories: Factories) -> None:
stencil_factory, quantity_factory = factories
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")

with StreeOptimization():
code.trivial_merge(in_qty, out_qty)

with StreeOptimization():
# Trivial merge
code.trivial_merge(in_qty, out_qty)
precompiled_sdfg = get_SDFG_and_purge(stencil_factory)
all_maps = [
(me, state)
Expand All @@ -147,8 +160,16 @@ def test_stree_merge_maps_IJK() -> None:
assert len(all_maps) == 3
assert (out_qty.field[:] == 2).all()

# Merge IJ - but do not merge K map & for (missing feature)
code.missing_merge_of_forscope_and_map(in_qty, out_qty)
def test_missing_merge_of_forscope_and_map(
self, code: OrchestratedCode, factories: Factories
) -> None:
stencil_factory, quantity_factory = factories
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")

with StreeOptimization():
code.missing_merge_of_forscope_and_map(in_qty, out_qty)

sdfg = get_SDFG_and_purge(stencil_factory).sdfg
all_maps = [
(me, state)
Expand All @@ -163,8 +184,16 @@ def test_stree_merge_maps_IJK() -> None:
]
assert len(all_loop_guard_state) == 1 # 1 For loop

# Overcompute merge in K - we merge and introduce an If guard
code.overcompute_merge(in_qty, out_qty)
def test_overcompute_merge(
self, code: OrchestratedCode, factories: Factories
) -> None:
stencil_factory, quantity_factory = factories
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")

with StreeOptimization():
code.overcompute_merge(in_qty, out_qty)

sdfg = get_SDFG_and_purge(stencil_factory).sdfg
all_maps = [
(me, state)
Expand All @@ -174,8 +203,17 @@ def test_stree_merge_maps_IJK() -> None:
# ⚠️ WE EXPECT A FAILURE TO MERGE K (because of index) ⚠️
assert len(all_maps) == 4 # Should be all dmerged = 3

# Forbid merging when data dependancy is detected
code.block_merge_when_depandencies_is_found(in_qty, out_qty)
def test_block_merge_when_dependencies_are_found(
self, code: OrchestratedCode, factories: Factories
) -> None:
stencil_factory, quantity_factory = factories
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")

with StreeOptimization():
# Forbid merging when data dependencies are detected
code.block_merge_when_dependencies_are_found(in_qty, out_qty)

sdfg = get_SDFG_and_purge(stencil_factory).sdfg
all_maps = [
me.params[0]
Expand All @@ -185,9 +223,18 @@ def test_stree_merge_maps_IJK() -> None:
# ⚠️ WE EXPECT A FAILURE TO MERGE K (because of index) ⚠️
assert len(all_maps) == 5 # Should be 4 = 2 IJ + 2 Ks (un-merged)

# Push non-cartesian ForScope inwward, which allow to potentially
# merge cartesian maps
code.push_non_cartesian_for(in_qty, out_qty)
def test_push_non_cartesian_for(
self, code: OrchestratedCode, factories: Factories
) -> None:
stencil_factory, quantity_factory = factories
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")

with StreeOptimization():
# Push non-cartesian ForScope inwards, which allow to potentially
# merge cartesian maps
code.push_non_cartesian_for(in_qty, out_qty)

sdfg = get_SDFG_and_purge(stencil_factory).sdfg
all_maps = [
(me, state)
Expand All @@ -203,19 +250,26 @@ def test_stree_merge_maps_IJK() -> None:
assert len(all_loop_guard_state) == 1 # 1 For loop


def test_stree_merge_maps_KJI() -> None:
domain = (3, 3, 4)
stencil_factory, quantity_factory = get_factories_single_tile_orchestrated(
domain[0], domain[1], domain[2], 0, backend="dace:cpu_KJI"
)
class TestStreeMergeMapsKJI:
@pytest.fixture
def factories(self) -> Factories:
domain = (3, 3, 4)
return get_factories_single_tile_orchestrated(
domain[0], domain[1], domain[2], 0, backend="dace:cpu_KJI"
)

@pytest.fixture
def code(self, factories: Factories) -> OrchestratedCode:
return OrchestratedCode(*factories)

def test_trivial_merge(self, code: OrchestratedCode, factories: Factories) -> None:
stencil_factory, quantity_factory = factories
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")

code = OrchestratedCode(stencil_factory, quantity_factory)
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")
with StreeOptimization():
code.trivial_merge(in_qty, out_qty)

with StreeOptimization():
# Trivial merge
code.trivial_merge(in_qty, out_qty)
precompiled_sdfg = get_SDFG_and_purge(stencil_factory)
all_maps = [
(me, state)
Expand All @@ -226,8 +280,17 @@ def test_stree_merge_maps_KJI() -> None:
assert len(all_maps) == 3
assert (out_qty.field[:] == 2).all()

# K iterative loop - blocks all merges
code.missing_merge_of_forscope_and_map(in_qty, out_qty)
def test_missing_merge_of_forscope_and_map(
self, code: OrchestratedCode, factories: Factories
) -> None:
stencil_factory, quantity_factory = factories
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")

with StreeOptimization():
# K iterative loop - blocks all merges
code.missing_merge_of_forscope_and_map(in_qty, out_qty)

sdfg = get_SDFG_and_purge(stencil_factory).sdfg
all_maps = [
(me, state)
Expand All @@ -242,8 +305,17 @@ def test_stree_merge_maps_KJI() -> None:
]
assert len(all_loop_guard_state) == 1 # 1 For loop

# Overcompute merge in K - we merge and introduce an If guard
code.overcompute_merge(in_qty, out_qty)
def test_overcompute_merge(
self, code: OrchestratedCode, factories: Factories
) -> None:
stencil_factory, quantity_factory = factories
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")

with StreeOptimization():
# Overcompute merge in K - we merge and introduce an If guard
code.overcompute_merge(in_qty, out_qty)

sdfg = get_SDFG_and_purge(stencil_factory).sdfg
all_maps = [
(me, state)
Expand All @@ -253,8 +325,17 @@ def test_stree_merge_maps_KJI() -> None:
# ⚠️ WE EXPECT A FAILURE TO MERGE K (because of index) ⚠️
assert len(all_maps) == 6

# Forbid merging when data dependancy is detected
code.block_merge_when_depandencies_is_found(in_qty, out_qty)
def test_block_merge_when_dependencies_are_found(
self, code: OrchestratedCode, factories: Factories
) -> None:
stencil_factory, quantity_factory = factories
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")

with StreeOptimization():
# Forbid merging when data dependencies are detected
code.block_merge_when_dependencies_are_found(in_qty, out_qty)

sdfg = get_SDFG_and_purge(stencil_factory).sdfg
all_maps = [
(me, state)
Expand All @@ -264,9 +345,18 @@ def test_stree_merge_maps_KJI() -> None:
# ⚠️ WE EXPECT A FAILURE TO MERGE K (because of index) ⚠️
assert len(all_maps) == 9

# Push non-cartesian ForScope inwward, which allow to potentially
# merge cartesian maps
code.push_non_cartesian_for(in_qty, out_qty)
def test_push_non_cartesian_for(
self, code: OrchestratedCode, factories: Factories
) -> None:
stencil_factory, quantity_factory = factories
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")

with StreeOptimization():
# Push non-cartesian ForScope inwards, which allow to potentially
# merge cartesian maps
code.push_non_cartesian_for(in_qty, out_qty)

sdfg = get_SDFG_and_purge(stencil_factory).sdfg
all_maps = [
(me, state)
Expand Down
Loading