Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,28 @@ def create_multinode_node_exchange(
class GlobalReductions(Reductions):
props: definitions.ProcessProperties

@staticmethod
def _min_identity(dtype: np.dtype, array_ns: ModuleType = np) -> data_alloc.NDArray:
    """Return a one-element array holding the identity of a MIN reduction.

    The identity is the largest representable value of ``dtype`` (``iinfo(...).max``
    for integers, ``+inf`` for floats), so an empty local buffer cannot influence
    the global minimum.

    Raises:
        TypeError: if ``dtype`` is neither an integer nor a floating dtype.
    """
    if array_ns.issubdtype(dtype, array_ns.integer):
        identity = array_ns.iinfo(dtype).max
    elif array_ns.issubdtype(dtype, array_ns.floating):
        identity = array_ns.inf
    else:
        raise TypeError(f"Unsupported dtype for min identity: {dtype}")
    return array_ns.asarray([dtype.type(identity)])

@staticmethod
def _max_identity(dtype: np.dtype, array_ns: ModuleType = np) -> data_alloc.NDArray:
    """Return a one-element array holding the identity of a MAX reduction.

    The identity is the smallest representable value of ``dtype`` (``iinfo(...).min``
    for integers, ``-inf`` for floats), so an empty local buffer cannot influence
    the global maximum.

    Raises:
        TypeError: if ``dtype`` is neither an integer nor a floating dtype.
    """
    if array_ns.issubdtype(dtype, array_ns.floating):
        identity = -array_ns.inf
    elif array_ns.issubdtype(dtype, array_ns.integer):
        identity = array_ns.iinfo(dtype).min
    else:
        raise TypeError(f"Unsupported dtype for max identity: {dtype}")
    return array_ns.asarray([dtype.type(identity)])

@staticmethod
def _sum_identity(dtype: np.dtype, array_ns: ModuleType = np) -> data_alloc.NDArray:
    """Return a one-element array holding zero of ``dtype`` — the SUM identity."""
    zero = dtype.type(0)
    return array_ns.asarray([zero])

def _reduce(
self,
buffer: data_alloc.NDArray,
Expand Down Expand Up @@ -459,7 +481,7 @@ def min(self, buffer: data_alloc.NDArray, array_ns: ModuleType = np) -> state_ut
if self._calc_buffer_size(buffer, array_ns) == 0:
raise ValueError("global_min requires a non-empty buffer")
return self._reduce(
buffer if buffer.size != 0 else array_ns.asarray([array_ns.inf]),
buffer if buffer.size != 0 else self._min_identity(buffer.dtype, array_ns),
array_ns.min,
mpi4py.MPI.MIN,
array_ns,
Expand All @@ -469,7 +491,7 @@ def max(self, buffer: data_alloc.NDArray, array_ns: ModuleType = np) -> state_ut
if self._calc_buffer_size(buffer, array_ns) == 0:
raise ValueError("global_max requires a non-empty buffer")
return self._reduce(
buffer if buffer.size != 0 else array_ns.asarray([-array_ns.inf]),
buffer if buffer.size != 0 else self._max_identity(buffer.dtype, array_ns),
array_ns.max,
mpi4py.MPI.MAX,
array_ns,
Expand All @@ -479,7 +501,7 @@ def sum(self, buffer: data_alloc.NDArray, array_ns: ModuleType = np) -> state_ut
if self._calc_buffer_size(buffer, array_ns) == 0:
raise ValueError("global_sum requires a non-empty buffer")
return self._reduce(
buffer if buffer.size != 0 else array_ns.asarray([0]),
buffer if buffer.size != 0 else self._sum_identity(buffer.dtype, array_ns),
array_ns.sum,
mpi4py.MPI.SUM,
array_ns,
Expand All @@ -492,7 +514,7 @@ def mean(self, buffer: data_alloc.NDArray, array_ns: ModuleType = np) -> state_u

return (
self._reduce(
buffer if buffer.size != 0 else array_ns.asarray([0]),
(buffer if buffer.size != 0 else self._sum_identity(buffer.dtype, array_ns)),
array_ns.sum,
mpi4py.MPI.SUM,
array_ns,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def test_halo_exchange_for_sparse_field(
assert test_helpers.dallclose(result.asnumpy(), field_ref.asnumpy())


inputs_ls = [[2.0, 2.0, 4.0, 1.0], [2.0, 1.0], [30.0], []]
inputs_ls = [[2.0, 2.0, 4.0, 1.0], [2.0, 1.0], [30.0], [], [-10, 20, 4]]


@pytest.mark.parametrize("global_list", inputs_ls)
Expand All @@ -335,7 +335,7 @@ def test_halo_exchange_for_sparse_field(
def test_global_reductions_min(
processor_props: definitions.ProcessProperties,
backend_like: model_backends.BackendLike,
global_list: list[float],
global_list: list[data_alloc.ScalarT],
) -> None:
my_rank = processor_props.rank
xp = data_alloc.import_array_ns(model_backends.get_allocator(backend_like))
Expand All @@ -360,7 +360,7 @@ def test_global_reductions_min(
def test_global_reductions_max(
processor_props: definitions.ProcessProperties,
backend_like: model_backends.BackendLike,
global_list: list[float],
global_list: list[data_alloc.ScalarT],
) -> None:
my_rank = processor_props.rank
xp = data_alloc.import_array_ns(model_backends.get_allocator(backend_like))
Expand All @@ -385,7 +385,7 @@ def test_global_reductions_max(
def test_global_reductions_sum(
processor_props: definitions.ProcessProperties,
backend_like: model_backends.BackendLike,
global_list: list[float],
global_list: list[data_alloc.ScalarT],
) -> None:
my_rank = processor_props.rank
xp = data_alloc.import_array_ns(model_backends.get_allocator(backend_like))
Expand All @@ -410,7 +410,7 @@ def test_global_reductions_sum(
def test_global_reductions_mean(
processor_props: definitions.ProcessProperties,
backend_like: model_backends.BackendLike,
global_list: list[float],
global_list: list[data_alloc.ScalarT],
) -> None:
my_rank = processor_props.rank
xp = data_alloc.import_array_ns(model_backends.get_allocator(backend_like))
Expand Down