Skip to content

Commit

Permalink
add count process (#273)
Browse files Browse the repository at this point in the history
* add count process

* add count process

* add count process

* update tests
  • Loading branch information
ValentinaHutter authored Sep 10, 2024
1 parent fb59599 commit ebd456c
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
22 changes: 21 additions & 1 deletion openeo_processes_dask/process_implementations/arrays.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import itertools
import logging
from typing import Any, Optional
from typing import Any, Callable, Optional, Union

import dask.array as da
import numpy as np
Expand All @@ -10,6 +11,7 @@
from openeo_pg_parser_networkx.pg_schema import DateTime
from xarray.core.duck_array_ops import isnull, notnull

from openeo_processes_dask.process_implementations.comparison import is_valid
from openeo_processes_dask.process_implementations.cubes.utils import _is_dask_array
from openeo_processes_dask.process_implementations.exceptions import (
ArrayElementNotAvailable,
Expand All @@ -35,6 +37,7 @@
"order",
"rearrange",
"sort",
"count",
]


Expand Down Expand Up @@ -337,3 +340,20 @@ def sort(
return data_sorted_flip
elif nodata == True: # default sort behaviour, np.nan values are put last
return data_sorted


def count(
    data: ArrayLike,
    condition: Optional[Union[Callable, bool]] = None,
    context: Any = None,
    axis=None,
    keepdims=False,
):
    """Count the elements of ``data`` selected by ``condition``.

    Args:
        data: Array whose elements are counted.
        condition: Selects what is counted.
            ``None`` (default) counts the elements flagged by ``is_valid``;
            ``True`` counts every element;
            a callable is applied to ``data`` and the truthy entries of its
            result are counted.
        context: Currently unused; it is not forwarded to ``condition``.
        axis: Axis (or axes) to reduce over, forwarded to ``np.nansum``.
        keepdims: Forwarded to ``np.nansum``.

    Returns:
        The result of ``np.nansum`` over the selection mask.

    Raises:
        ValueError: If ``condition`` is neither ``None``, ``True`` nor a
            callable (e.g. ``False``). Previously this case silently
            returned ``None``.
    """
    if condition is None:
        mask = is_valid(data)
    elif condition is True:
        # Every element counts, so sum a mask of ones shaped like the input.
        mask = np.ones_like(data)
    elif callable(condition):
        mask = condition(data)
    else:
        # Fail loudly instead of falling through and returning None.
        raise ValueError(
            f"Unsupported value for 'condition': {condition!r}. "
            "Expected None, True or a callable."
        )
    return np.nansum(mask, axis=axis, keepdims=keepdims)
41 changes: 41 additions & 0 deletions tests/test_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,3 +459,44 @@ def test_reduce_dimension(
)
assert output_cube[0, 0, 0].data.compute().item() is True
assert not output_cube[slice(1, None), :, :].data.compute().any()


@pytest.mark.parametrize("size", [(3, 3, 2, 4)])
@pytest.mark.parametrize("dtype", [np.float32])
def test_count(temporal_interval, bounding_box, random_raster_data, process_registry):
    """Reduce the "bands" dimension with ``count`` and check the result.

    The reducer is exercised twice: once with the default condition and once
    with ``condition=True``. In both cases every pixel is expected to hold 4,
    matching the four bands of the input cube.
    """
    input_cube = create_fake_rastercube(
        data=random_raster_data,
        spatial_extent=bounding_box,
        temporal_extent=temporal_interval,
        bands=["B02", "B03", "B04", "B08"],
        backend="dask",
    )

    # First pass: no explicit condition; second pass: count all elements.
    for extra_kwargs in ({}, {"condition": True}):
        reducer = partial(
            process_registry["count"].implementation,
            data=ParameterReference(from_parameter="data"),
            **extra_kwargs,
        )
        counted = reduce_dimension(
            data=input_cube, reducer=reducer, dimension="bands"
        )
        general_output_checks(
            input_cube=input_cube,
            output_cube=counted,
            verify_attrs=False,
            verify_crs=True,
        )
        # The "bands" dimension must be gone after the reduction.
        assert counted.dims == ("x", "y", "t")
        xr.testing.assert_equal(counted, xr.zeros_like(counted) + 4)

0 comments on commit ebd456c

Please sign in to comment.