diff --git a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py
index aed61639c2..7d77432a8f 100644
--- a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py
+++ b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py
@@ -432,6 +432,31 @@ def create_multinode_node_exchange(
 class GlobalReductions(Reductions):
     props: definitions.ProcessProperties
 
+    @staticmethod
+    def _min_identity(dtype: np.dtype, array_ns: ModuleType = np) -> data_alloc.NDArray:
+        """Return a one-element array with the ``min`` identity (int max / +inf) for ``dtype``."""
+        if array_ns.issubdtype(dtype, array_ns.integer):
+            return array_ns.asarray([dtype.type(array_ns.iinfo(dtype).max)])
+        elif array_ns.issubdtype(dtype, array_ns.floating):
+            return array_ns.asarray([dtype.type(array_ns.inf)])
+        else:
+            raise TypeError(f"Unsupported dtype for min identity: {dtype}")
+
+    @staticmethod
+    def _max_identity(dtype: np.dtype, array_ns: ModuleType = np) -> data_alloc.NDArray:
+        """Return a one-element array with the ``max`` identity (int min / -inf) for ``dtype``."""
+        if array_ns.issubdtype(dtype, array_ns.integer):
+            return array_ns.asarray([dtype.type(array_ns.iinfo(dtype).min)])
+        elif array_ns.issubdtype(dtype, array_ns.floating):
+            return array_ns.asarray([dtype.type(-array_ns.inf)])
+        else:
+            raise TypeError(f"Unsupported dtype for max identity: {dtype}")
+
+    @staticmethod
+    def _sum_identity(dtype: np.dtype, array_ns: ModuleType = np) -> data_alloc.NDArray:
+        """Return a one-element array holding zero typed as ``dtype``."""
+        return array_ns.asarray([dtype.type(0)])
+
     def _reduce(
         self,
         buffer: data_alloc.NDArray,
@@ -459,7 +484,7 @@ def min(self, buffer: data_alloc.NDArray, array_ns: ModuleType = np) -> state_ut
         if self._calc_buffer_size(buffer, array_ns) == 0:
             raise ValueError("global_min requires a non-empty buffer")
         return self._reduce(
-            buffer if buffer.size != 0 else array_ns.asarray([array_ns.inf]),
+            buffer if buffer.size != 0 else self._min_identity(buffer.dtype, array_ns),
             array_ns.min,
             mpi4py.MPI.MIN,
             array_ns,
@@ -469,7 +494,7 @@ def max(self, buffer: data_alloc.NDArray, array_ns: ModuleType = np) -> state_ut
         if self._calc_buffer_size(buffer, array_ns) == 0:
             raise ValueError("global_max requires a non-empty buffer")
         return self._reduce(
-            buffer if buffer.size != 0 else array_ns.asarray([-array_ns.inf]),
+            buffer if buffer.size != 0 else self._max_identity(buffer.dtype, array_ns),
             array_ns.max,
             mpi4py.MPI.MAX,
             array_ns,
@@ -479,7 +504,7 @@ def sum(self, buffer: data_alloc.NDArray, array_ns: ModuleType = np) -> state_ut
         if self._calc_buffer_size(buffer, array_ns) == 0:
             raise ValueError("global_sum requires a non-empty buffer")
         return self._reduce(
-            buffer if buffer.size != 0 else array_ns.asarray([0]),
+            buffer if buffer.size != 0 else self._sum_identity(buffer.dtype, array_ns),
             array_ns.sum,
             mpi4py.MPI.SUM,
             array_ns,
@@ -492,7 +517,7 @@ def mean(self, buffer: data_alloc.NDArray, array_ns: ModuleType = np) -> state_u
 
         return (
             self._reduce(
-                buffer if buffer.size != 0 else array_ns.asarray([0]),
+                (buffer if buffer.size != 0 else self._sum_identity(buffer.dtype, array_ns)),
                 array_ns.sum,
                 mpi4py.MPI.SUM,
                 array_ns,
diff --git a/model/common/tests/common/decomposition/mpi_tests/test_mpi_decomposition.py b/model/common/tests/common/decomposition/mpi_tests/test_mpi_decomposition.py
index 93a1fab38a..ed998a8970 100644
--- a/model/common/tests/common/decomposition/mpi_tests/test_mpi_decomposition.py
+++ b/model/common/tests/common/decomposition/mpi_tests/test_mpi_decomposition.py
@@ -326,7 +326,7 @@ def test_halo_exchange_for_sparse_field(
     assert test_helpers.dallclose(result.asnumpy(), field_ref.asnumpy())
 
 
-inputs_ls = [[2.0, 2.0, 4.0, 1.0], [2.0, 1.0], [30.0], []]
+inputs_ls = [[2.0, 2.0, 4.0, 1.0], [2.0, 1.0], [30.0], [], [-10, 20, 4]]
 
 
 @pytest.mark.parametrize("global_list", inputs_ls)
@@ -335,7 +335,7 @@ def test_halo_exchange_for_sparse_field(
 def test_global_reductions_min(
     processor_props: definitions.ProcessProperties,
     backend_like: model_backends.BackendLike,
-    global_list: list[float],
+    global_list: list[data_alloc.ScalarT],
 ) -> None:
     my_rank = processor_props.rank
     xp = data_alloc.import_array_ns(model_backends.get_allocator(backend_like))
@@ -360,7 +360,7 @@ def test_global_reductions_min(
 def test_global_reductions_max(
     processor_props: definitions.ProcessProperties,
     backend_like: model_backends.BackendLike,
-    global_list: list[float],
+    global_list: list[data_alloc.ScalarT],
 ) -> None:
     my_rank = processor_props.rank
     xp = data_alloc.import_array_ns(model_backends.get_allocator(backend_like))
@@ -385,7 +385,7 @@ def test_global_reductions_max(
 def test_global_reductions_sum(
     processor_props: definitions.ProcessProperties,
     backend_like: model_backends.BackendLike,
-    global_list: list[float],
+    global_list: list[data_alloc.ScalarT],
 ) -> None:
     my_rank = processor_props.rank
     xp = data_alloc.import_array_ns(model_backends.get_allocator(backend_like))
@@ -410,7 +410,7 @@ def test_global_reductions_sum(
 def test_global_reductions_mean(
     processor_props: definitions.ProcessProperties,
     backend_like: model_backends.BackendLike,
-    global_list: list[float],
+    global_list: list[data_alloc.ScalarT],
 ) -> None:
     my_rank = processor_props.rank
     xp = data_alloc.import_array_ns(model_backends.get_allocator(backend_like))