diff --git a/justfile b/justfile
index 86c7af6..bb30bb2 100644
--- a/justfile
+++ b/justfile
@@ -10,10 +10,9 @@
 set dotenv-path := ".env"
 
 val:
-    uv run ty check src/cascade
-    uv run ty check tests/cascade
+    uv run ty check src
+    uv run ty check tests
     uv run ty check integration_tests
-    # TODO eventually broaden type coverage to ekw as well
     uv run pytest -n8 tests
 fmt:
     uv run prek --all-files
diff --git a/pyproject.toml b/pyproject.toml
index 25e16c8..2eb505e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,7 +41,14 @@
 dynamic = ["version"]
 readme = "README.md"
 
 [dependency-groups]
-dev = ["pytest", "pytest-xdist>=3.8", "prek", "ty==0.0.2", "build", "bokeh"]
+dev = [
+    "pytest",
+    "pytest-xdist>=3.8",
+    "prek",
+    "ty==0.0.2",
+    "build",
+    "bokeh",
+]
 
 [tool.setuptools]
diff --git a/src/earthkit/workflows/__init__.py b/src/earthkit/workflows/__init__.py
index 5891f5e..6feb522 100644
--- a/src/earthkit/workflows/__init__.py
+++ b/src/earthkit/workflows/__init__.py
@@ -20,7 +20,7 @@
 from . import fluent, mark
 from .graph import Graph, deduplicate_nodes
 from .graph.export import deserialise, serialise
-from .visualise import visualise
+from .visualise import visualise as _visualise_fn
 
 
 class Cascade:
@@ -46,7 +46,7 @@ def serialise(self, filename: str):
             dill.dump(data, f)
 
     def visualise(self, *args, **kwargs):
-        return visualise(self._graph, *args, **kwargs)
+        return _visualise_fn(self._graph, *args, **kwargs)
 
     def __add__(self, other: "Cascade") -> "Cascade":
         if not isinstance(other, Cascade):
diff --git a/src/earthkit/workflows/_qubed.py b/src/earthkit/workflows/_qubed.py
index 1e6f8b4..98105a1 100644
--- a/src/earthkit/workflows/_qubed.py
+++ b/src/earthkit/workflows/_qubed.py
@@ -35,7 +35,7 @@ def _convert_num_to_abc(num: int) -> str:
 
 def get_name(child: "Qube", index: int) -> str:
     if "name" in child.metadata:
-        name_meta = child.metadata["name"]
+        name_meta = child.metadata["name"]  # type: ignore[index]
        return str(np.unique_values(name_meta).flatten()[0])
    return _convert_num_to_abc(index)
diff --git a/src/earthkit/workflows/adapters.py b/src/earthkit/workflows/adapters.py
index 1177f32..2251d4d 100644
--- a/src/earthkit/workflows/adapters.py
+++ b/src/earthkit/workflows/adapters.py
@@ -16,6 +16,8 @@
 
 import logging
 
+logger = logging.getLogger(__name__)
+
 try:
     from cascade.low.core import DefaultTaskOutput
diff --git a/src/earthkit/workflows/backends/__init__.py b/src/earthkit/workflows/backends/__init__.py
index 0aaf404..277da52 100644
--- a/src/earthkit/workflows/backends/__init__.py
+++ b/src/earthkit/workflows/backends/__init__.py
@@ -86,7 +86,7 @@ def decorator(func: Callable) -> Callable:
         def check_num_args(*args, **kwargs):
             if accept_nested and len(args) == 1:
                 args = args[0]
-            assert len(args) == expect, f"{func.__name__} expects two input arguments, got {len(args)}"
+            assert len(args) == expect, f"{func.__name__} expects two input arguments, got {len(args)}"  # type: ignore[union-attr]
             return func(*args, **kwargs)
 
         return check_num_args
diff --git a/src/earthkit/workflows/backends/earthkit.py b/src/earthkit/workflows/backends/earthkit.py
index 6db43fe..3937194 100644
--- a/src/earthkit/workflows/backends/earthkit.py
+++ b/src/earthkit/workflows/backends/earthkit.py
@@ -6,7 +6,7 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
-from typing import TypeAlias
+from typing import Callable, TypeAlias
 
 import array_api_compat
 from earthkit.data import FieldList
@@ -33,7 +33,7 @@ def comp_str2func(array_module, comparison: str):
     return array_module.greater
 
 
-Metadata: TypeAlias = "dict | callable | None"
+Metadata: TypeAlias = dict | Callable | None
 
 
 def resolve_metadata(metadata: Metadata, *args) -> dict:
@@ -64,7 +64,7 @@ def new_fieldlist(data, metadata: list[ekdMetadata], overrides: dict):
 
 
 class FieldListBackend:
-    def _merge(*fieldlists: list[FieldList]):
+    def _merge(*fieldlists: FieldList):
         """Merge fieldlist elements into a single array.
         fieldlists with different number of fields must be concatenated,
         otherwise, the elements in each fieldlist are stacked along a new dimension
@@ -76,7 +76,7 @@ def _merge(*fieldlists: list[FieldList]):
         xp = array_api_compat.array_namespace(*values)
         return xp.asarray(values)
 
-    def multi_arg_function(func: str, *arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
+    def multi_arg_function(func: str, *arrays: FieldList, metadata: Metadata = None) -> FieldList:
         merged_array = FieldListBackend._merge(*arrays)
         xp = array_api_compat.array_namespace(*merged_array)
         res = standardise_output(getattr(xp, func)(merged_array, axis=0))
@@ -101,53 +101,56 @@ def two_arg_function(func: str, *arrays: FieldList, metadata: Metadata = None) -
         res = getattr(xp, func)(val1, val2)
         return new_fieldlist(res, [arrays[0][x].metadata() for x in range(len(res))], metadata)
 
-    def mean(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
+    def mean(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
         return FieldListBackend.multi_arg_function("mean", *arrays, metadata=metadata)
 
-    def std(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
+    def std(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
         return FieldListBackend.multi_arg_function("std", *arrays, metadata=metadata)
 
-    def min(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
+    def min(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
         return FieldListBackend.multi_arg_function("min", *arrays, metadata=metadata)
 
-    def max(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
+    def max(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
         return FieldListBackend.multi_arg_function("max", *arrays, metadata=metadata)
 
-    def sum(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
+    def sum(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
         return FieldListBackend.multi_arg_function("sum", *arrays, metadata=metadata)
 
-    def prod(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
+    def prod(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
         return FieldListBackend.multi_arg_function("prod", *arrays, metadata=metadata)
 
-    def var(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
+    def var(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
         return FieldListBackend.multi_arg_function("var", *arrays, metadata=metadata)
 
-    def stack(*arrays: list[FieldList], axis: int = 0) -> FieldList:
+    def stack(*arrays: FieldList, axis: int = 0) -> FieldList:
         if axis != 0:
             raise ValueError("Can not stack FieldList along axis != 0")
         assert all([len(x) == 1 for x in arrays]), "Can not stack FieldLists with more than one element, use concat"
         return FieldListBackend.concat(*arrays)
 
-    def add(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
+    def add(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
         return FieldListBackend.two_arg_function("add", *arrays, metadata=metadata)
 
-    def subtract(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
+    def subtract(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
         return FieldListBackend.two_arg_function("subtract", *arrays, metadata=metadata)
 
     @num_args(2)
-    def diff(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
-        return FieldListBackend.multiply(FieldListBackend.subtract(*arrays, metadata=metadata), -1)
+    def diff(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
+        return FieldListBackend.multiply(
+            FieldListBackend.subtract(*arrays, metadata=metadata),
+            -1,  # type: ignore[arg-type]
+        )
 
-    def multiply(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
+    def multiply(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
         return FieldListBackend.two_arg_function("multiply", *arrays, metadata=metadata)
 
-    def divide(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
+    def divide(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
         return FieldListBackend.two_arg_function("divide", *arrays, metadata=metadata)
 
-    def pow(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
+    def pow(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
         return FieldListBackend.two_arg_function("pow", *arrays, metadata=metadata)
 
-    def concat(*arrays: list[FieldList]) -> FieldList:
+    def concat(*arrays: FieldList) -> FieldList:
         """Concatenates the list of fields inside each FieldList into a single
         FieldList object
 
@@ -176,7 +179,7 @@ def take(
         if dim != 0:
             raise ValueError("Can not slice from FieldList along dim != 0")
         if isinstance(indices, int):
-            indices = [indices]
+            indices = [indices]  # type: ignore[assignment]
             ret = array[indices]
         else:
             if not isinstance(dim, str):
@@ -190,7 +193,7 @@
 
         return FieldList.from_array(ret.values, ret.metadata())
 
-    def norm(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
+    def norm(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
         merged_array = FieldListBackend._merge(*arrays)
         xp = array_api_compat.array_namespace(merged_array)
         norm = standardise_output(xp.sqrt(xp.sum(xp.pow(merged_array, 2), axis=0)))
diff --git a/src/earthkit/workflows/compilers.py b/src/earthkit/workflows/compilers.py
index ac6fb3e..d9f2edd 100644
--- a/src/earthkit/workflows/compilers.py
+++ b/src/earthkit/workflows/compilers.py
@@ -19,6 +19,8 @@
 
 from earthkit.workflows.graph import Graph, serialise
 
+logger = logging.getLogger(__name__)
+
 try:
     from cascade.low.core import JobInstance
     from cascade.low.into import graph2job as cascadeInto
@@ -28,7 +30,7 @@
     def cascadeInto(graph: dict) -> Any:
         raise NotImplementedError("failed to import cascade execution engine")
 
-    JobInstance = Any
+    JobInstance = object  # type: ignore[assignment]
 
 
 Engine = Literal["cascade"]
diff --git a/src/earthkit/workflows/fluent.py b/src/earthkit/workflows/fluent.py
index b008c19..0a338b9 100644
--- a/src/earthkit/workflows/fluent.py
+++ b/src/earthkit/workflows/fluent.py
@@ -75,7 +75,7 @@ def name(self) -> str:
         if isinstance(self.func, str):
             return self.func
         if hasattr(self.func, "__name__"):
-            return self.func.__name__
+            return self.func.__name__  # type: ignore[union-attr]
         return ""
 
     def __str__(self) -> str:
@@ -166,7 +166,7 @@ def __str__(self) -> str:
         return f"Node {self.name}, inputs: {[x.parent.name for x in self.inputs.values()]}, payload: {self.payload}"
 
     def copy(self) -> "Node":
-        return self.__class__(*self._for_copy)
+        return self.__class__(*self._for_copy)  # type: ignore[arg-type]
 
 
 class Action:
@@ -270,7 +270,7 @@ def join(
 
     def transform(
         self,
-        func: Callable[["Action", Any], "Action"],
+        func: Callable[..., "Action"],
         params: list,
         dim: str | Coord,
         axis: int = 0,
@@ -350,7 +350,7 @@ def broadcast(
         )
         broadcasted_nodes = array.broadcast_like(narray, exclude=exclude)
         new_nodes = np.empty(broadcasted_nodes.shape, dtype=object)
-        it = np.nditer(
+        it = np.nditer(  # type: ignore[call-overload]
             array.transpose(*broadcasted_nodes.dims, missing_dims="ignore"),
             flags=["multi_index", "refs_ok"],
         )
@@ -412,7 +412,7 @@ def expand(
 
     def map(
         self,
-        payload: PayloadFunc | Payload | np.ndarray[Any, Any],
+        payload: PayloadFunc | Payload | np.ndarray[Any, Any] | list,
         yields: Coord | None = None,
         path: Optional[str] = None,
     ) -> "Action":
@@ -448,7 +448,7 @@ def map(
 
         # Applies operation to every node, keeping node array structure
         new_nodes = np.empty(narray.shape, dtype=object)
-        it = np.nditer(narray, flags=["multi_index", "refs_ok"])
+        it = np.nditer(narray, flags=["multi_index", "refs_ok"])  # type: ignore[call-overload]
         node_payload = payload
         for node in it:
             if not isinstance(payload, PayloadFunc | Payload):  # type: ignore
@@ -516,7 +516,9 @@ def reduce(
             raise ValueError("Can not batch the execution of a generator")
         if batch_size > 1 and batch_size < nodetree_array(batched.nodes).sizes[dim]:
             if not getattr(payload.func, "batchable", False):
-                raise ValueError(f"Function {payload.func.name()} is not batchable, but batch_size {batch_size} is specified")
+                raise ValueError(
+                    f"Function {payload.func.name()} is not batchable, but batch_size {batch_size} is specified"  # type: ignore[union-attr]
+                )
 
             while batch_size < nodetree_array(batched.nodes).sizes[dim]:
                 lst = nodetree_array(batched.nodes).coords[dim].data
@@ -536,7 +538,7 @@ def reduce(
         new_dims = [x for x in batched_narray.dims if x != dim]
         transposed_nodes = batched_narray.transpose(dim, *new_dims)
         new_nodes = np.empty(transposed_nodes.shape[1:], dtype=object)
-        it = np.nditer(new_nodes, flags=["multi_index", "refs_ok"])
+        it = np.nditer(new_nodes, flags=["multi_index", "refs_ok"])  # type: ignore[call-overload]
         for _ in it:
             inputs = transposed_nodes[(slice(None, None, 1), *it.multi_index)].data
             new_nodes[it.multi_index] = Node(payload, inputs, num_outputs=len(yields[1]) if yields else 1)
@@ -568,6 +570,7 @@ def flatten(
         axis: int = 0,
         path: Optional[str] = None,
         backend_kwargs: dict = {},
+        payload_metadata: dict | None = None,
     ) -> "Action":
         """Flattens the array of nodes along specified dimension by creating
         new nodes from stacking internal data of nodes along that dimension.
@@ -604,7 +607,7 @@ def set_path(self, path: str) -> "Action":
            raise NotImplementedError("Multiple node arrays present, can not set single path")
        return type(self)(nodetree_from_dict({path: nodetree_array(self.nodes)}))
 
-    def split(self, expansion: Optional[dict[str, PayloadFunc | Payload]] = None) -> "Action":
+    def split(self, expansion: dict[str, PayloadFunc | Payload]) -> "Action":
         """Create action containing new node arrays by splitting an existing
         node array by the specified functions in expansion
 
@@ -629,7 +632,7 @@ def split(self, expansion: Optional[dict[str, PayloadFunc | Payload]] = None) ->
             node_arrays[path] = nodetree_array(action.map(func).nodes, parent)
         return type(self)(nodetree_from_dict(node_arrays))
 
-    def _validate_criteria(cls, array: xr.DataArray, criteria: dict) -> tuple[bool, dict]:
+    def _validate_criteria(self, array: xr.DataArray, criteria: dict) -> tuple[bool, dict]:
         keys = list(criteria.keys())
         new_criteria = criteria.copy()
         for key in keys:
@@ -664,7 +667,7 @@ def select(
         nodes = self.nodes if path is None else self.nodes[path]
 
         new_nodes = {}
-        for npath, narray in nodetree_arrays(nodes):
+        for npath, narray in nodetree_arrays(nodes):  # type: ignore[arg-type]
             valid, new_criteria = self._validate_criteria(narray, crit)
             if valid:
                 try:
@@ -702,7 +705,7 @@ def iselect(
         nodes = self.nodes if path is None else self.nodes[path]
 
         new_nodes = {}
-        for npath, narray in nodetree_arrays(nodes):
+        for npath, narray in nodetree_arrays(nodes):  # type: ignore[arg-type]
             valid, new_criteria = self._validate_criteria(narray, crit)
             if valid:
                 try:
@@ -724,6 +727,7 @@ def concatenate(
         keep_dim: bool = False,
         path: Optional[str] = None,
         backend_kwargs: dict = {},
+        payload_metadata: dict | None = None,
     ) -> "Action":
         return _combine_nodes(self, "concat", dim, batch_size, keep_dim, path, backend_kwargs)
 
@@ -736,6 +740,7 @@ def stack(
         axis: int = 0,
         path: Optional[str] = None,
         backend_kwargs: dict = {},
+        payload_metadata: dict | None = None,
     ) -> "Action":
         return _combine_nodes(
             self,
@@ -755,6 +760,7 @@ def sum(
         keep_dim: bool = False,
         path: Optional[str] = None,
         backend_kwargs: dict = {},
+        payload_metadata: dict | None = None,
     ) -> "Action":
         return self.reduce(
             Payload(backends.sum, kwargs=backend_kwargs),
@@ -772,6 +778,7 @@ def mean(
         keep_dim: bool = False,
         path: Optional[str] = None,
         backend_kwargs: dict = {},
+        payload_metadata: dict | None = None,
     ) -> "Action":
         action = self
         for npath, narray in nodetree_arrays(self.select(path=path).nodes):
@@ -804,6 +811,7 @@ def std(
         keep_dim: bool = False,
         path: Optional[str] = None,
         backend_kwargs: dict = {},
+        payload_metadata: dict | None = None,
     ) -> "Action":
         action = self
         for npath, narray in nodetree_arrays(self.select(path=path).nodes):
@@ -844,6 +852,7 @@ def max(
         keep_dim: bool = False,
         path: Optional[str] = None,
         backend_kwargs: dict = {},
+        payload_metadata: dict | None = None,
     ) -> "Action":
         return self.reduce(
             Payload(backends.max, kwargs=backend_kwargs),
@@ -861,6 +870,7 @@ def min(
         keep_dim: bool = False,
         path: Optional[str] = None,
         backend_kwargs: dict = {},
+        payload_metadata: dict | None = None,
     ) -> "Action":
         return self.reduce(
             Payload(backends.min, kwargs=backend_kwargs),
@@ -878,6 +888,7 @@ def prod(
         keep_dim: bool = False,
         path: Optional[str] = None,
         backend_kwargs: dict = {},
+        payload_metadata: dict | None = None,
     ) -> "Action":
         return self.reduce(
             Payload(backends.prod, kwargs=backend_kwargs),
@@ -906,6 +917,7 @@ def subtract(
         other: "Action | float",
         path: Optional[str] = None,
         backend_kwargs: dict = {},
+        payload_metadata: dict | None = None,
     ) -> "Action":
         return self.__two_arg_method(backends.subtract, other, path=path, **backend_kwargs)
 
@@ -915,6 +927,7 @@ def divide(
         other: "Action | float",
         path: Optional[str] = None,
         backend_kwargs: dict = {},
+        payload_metadata: dict | None = None,
     ) -> "Action":
         return self.__two_arg_method(backends.divide, other, path=path, **backend_kwargs)
 
@@ -924,6 +937,7 @@ def add(
         other: "Action | float",
         path: Optional[str] = None,
         backend_kwargs: dict = {},
+        payload_metadata: dict | None = None,
     ) -> "Action":
         return self.__two_arg_method(backends.add, other, path=path, **backend_kwargs)
 
@@ -933,6 +947,7 @@ def multiply(
         other: "Action | float",
         path: Optional[str] = None,
         backend_kwargs: dict = {},
+        payload_metadata: dict | None = None,
     ) -> "Action":
         return self.__two_arg_method(backends.multiply, other, path=path, **backend_kwargs)
 
@@ -942,6 +957,7 @@ def power(
         other: "Action | float",
         path: Optional[str] = None,
         backend_kwargs: dict = {},
+        payload_metadata: dict | None = None,
     ) -> "Action":
         return self.__two_arg_method(backends.pow, other, path=path, **backend_kwargs)
 
@@ -1040,20 +1056,21 @@ def _combine_nodes(
 
 
 def from_source(
-    payloads_list: (np.ndarray[Any, Any] | dict[str, np.ndarray[Any, Any]]),  # values are Callables
+    payloads_list: (np.ndarray[Any, Any] | dict[str, np.ndarray[Any, Any]] | list[Any] | PayloadFunc | Payload),  # values are Callables
     yields: Coord | None = None,
     dims: list | None = None,
    coords: dict | None = None,
     action=Action,
 ) -> Action:
-    if not isinstance(payloads_list, dict):
-        payloads_list = {"/": payloads_list}
+    payloads_dict: dict[str, Any] = (  # type: ignore[assignment]
+        payloads_list if isinstance(payloads_list, dict) else {"/": payloads_list}
+    )
 
     node_arrays = {}
-    for nindex, (path, parray) in enumerate(payloads_list.items()):
+    for nindex, (path, parray) in enumerate(payloads_dict.items()):
         payloads = xr.DataArray(parray, dims=dims, coords=coords)
         nodes = xr.DataArray(np.empty(payloads.shape, dtype=object), dims=dims, coords=coords)
-        it = np.nditer(payloads, flags=["multi_index", "refs_ok"])
+        it = np.nditer(payloads, flags=["multi_index", "refs_ok"])  # type: ignore[call-overload]
         # Ensure all source nodes have a unique name
         node_names = set()
         for item in it:
diff --git a/src/earthkit/workflows/graph/expand.py b/src/earthkit/workflows/graph/expand.py
index 6d4e4c7..348018d 100644
--- a/src/earthkit/workflows/graph/expand.py
+++ b/src/earthkit/workflows/graph/expand.py
@@ -192,8 +192,8 @@ def node(self, n: Node, **inputs: Output) -> Node | _Subgraph:
             output_map = None
         else:
             expanded, input_map, output_map = expanded  # type: ignore # expanded[2] is dict[str, str|None]
-        sp = self.splicer(n.name, inputs, input_map, n.outputs, output_map)
-        return sp.transform(expanded)
+        sp = self.splicer(n.name, inputs, input_map, n.outputs, output_map)  # type: ignore[arg-type]
+        return sp.transform(expanded)  # type: ignore[arg-type]
 
     def graph(self, graph: Graph, sinks: list[Node | _Subgraph]) -> Graph:
         new_sinks = []
diff --git a/src/earthkit/workflows/graph/graphviz.py b/src/earthkit/workflows/graph/graphviz.py
index e7fbd8d..8dd9fb6 100644
--- a/src/earthkit/workflows/graph/graphviz.py
+++ b/src/earthkit/workflows/graph/graphviz.py
@@ -46,7 +46,7 @@ def render_graph(graph: Graph, **kwargs) -> str:
 
     Keyword arguments are passed to `graphviz.Source.render`.
     """
-    import graphviz
+    import graphviz  # type: ignore[import-untyped]
 
     dot = to_dot(graph)
     src = graphviz.Source(dot)
diff --git a/src/earthkit/workflows/graph/transform.py b/src/earthkit/workflows/graph/transform.py
index 02174bf..4d5a289 100644
--- a/src/earthkit/workflows/graph/transform.py
+++ b/src/earthkit/workflows/graph/transform.py
@@ -88,7 +88,7 @@ def __transform(self, node: Node, inputs: dict[str, Any]) -> Any:
 
     def __transform_output(self, node: Any, output: Output) -> Any:
         if hasattr(self, "output"):
-            return self.output(node, output.name)
+            return getattr(self, "output")(node, output.name)
         if isinstance(node, dict):
             return node[output.name]
         try:
@@ -98,5 +98,5 @@ def __transform_output(self, node: Any, output: Output) -> Any:
 
     def __transform_graph(self, graph: Graph, sinks: list[dict[str, Any]]) -> Any:
         if hasattr(self, "graph"):
-            return self.graph(graph, sinks)
+            return getattr(self, "graph")(graph, sinks)
         return sinks
diff --git a/src/earthkit/workflows/nodetree.py b/src/earthkit/workflows/nodetree.py
index 5cb446d..b8a043f 100644
--- a/src/earthkit/workflows/nodetree.py
+++ b/src/earthkit/workflows/nodetree.py
@@ -23,7 +23,7 @@ def nodetree_from_dict(data: dict[str, xr.DataArray] | dict[str, xr.Dataset], *a
         raise ValueError("NodeTree can only be created from dict of xr.DataArray or xr.Dataset")
     tree = xr.DataTree.from_dict(new_data, *args, **kwargs)
     for leaf in tree.leaves:
-        var = list(leaf.dataset.data_vars.keys())[0]
+        var: str = list(leaf.dataset.data_vars.keys())[0]  # type: ignore[assignment]
         if np.any(leaf[var].isnull()):
             raise ValueError(f"Nodes in Action can not contain NaNs. Found NaN in nodeset {leaf.path}, variable {var}")
     if not tree.is_hollow:
@@ -33,7 +33,7 @@ def nodetree_from_dict(data: dict[str, xr.DataArray] | dict[str, xr.Dataset], *a
 
 def nodetree_arrays(nodetree: xr.DataTree) -> Iterable[Tuple[str, xr.DataArray]]:
     for leaf in nodetree.leaves:
-        var = list(leaf.dataset.data_vars.keys())[0]
+        var: str = list(leaf.dataset.data_vars.keys())[0]  # type: ignore[assignment]
         yield leaf.path, leaf[var]
 
 
@@ -46,8 +46,8 @@ def nodetree_array(nodetree: xr.DataTree, path: Optional[str] = None) -> xr.Data
     if path not in nodetree.leaves:
         raise KeyError(f"Path {path} not found in nodetree")
     leaf = nodetree[path]
-    var = list(leaf.dataset.data_vars.keys())[0]
-    return leaf[var]
+    var: str = list(leaf.dataset.data_vars.keys())[0]  # type: ignore[assignment]
+    return leaf[var]  # type: ignore[return-value]
 
 
 def nodetree_size(nodetree: xr.DataTree) -> int:
diff --git a/tests/earthkit_workflows/backends/test_arrayapi.py b/tests/earthkit_workflows/backends/test_arrayapi.py
index 31d666c..d4f5795 100644
--- a/tests/earthkit_workflows/backends/test_arrayapi.py
+++ b/tests/earthkit_workflows/backends/test_arrayapi.py
@@ -14,7 +14,7 @@
 
 
 class TestArrayAPIBackend(BackendBase):
-    def input_generator(self, num_inputs: int, input_shape=(2, 3)):
+    def input_generator(self, num_inputs: int, input_shape=(2, 3)):  # type: ignore[override]
         return [np.random.rand(*input_shape) for _ in range(num_inputs)]
 
     def shape(self, array):
diff --git a/tests/earthkit_workflows/backends/test_xarray.py b/tests/earthkit_workflows/backends/test_xarray.py
index 20b1f3d..27ad62c 100644
--- a/tests/earthkit_workflows/backends/test_xarray.py
+++ b/tests/earthkit_workflows/backends/test_xarray.py
@@ -15,6 +15,9 @@
 
 
 class XarrayBackend(BackendBase):
+    def values(self, array):
+        raise NotImplementedError
+
     @pytest.mark.parametrize(
         ["num_inputs", "kwargs", "output_shape"],
         [
@@ -80,7 +83,7 @@ def test_take_extended(self, args, kwargs, output_shape):
 
 
 class TestXarrayDataArrayBackend(XarrayBackend):
-    def input_generator(self, number: int, shape=(2, 3)):
+    def input_generator(self, number: int, shape=(2, 3)):  # type: ignore[override]
         return [
             xr.DataArray(
                 np.random.rand(*shape),
@@ -98,7 +101,7 @@ def values(self, array):
 
 
 class TestXarrayDatasetBackend(XarrayBackend):
-    def input_generator(self, number: int, shape=(2, 3)):
+    def input_generator(self, number: int, shape=(2, 3)):  # type: ignore[override]
         return [
             xr.Dataset(
                 {
diff --git a/tests/earthkit_workflows/helpers.py b/tests/earthkit_workflows/helpers.py
index 91bf449..f5a96cf 100644
--- a/tests/earthkit_workflows/helpers.py
+++ b/tests/earthkit_workflows/helpers.py
@@ -20,7 +20,7 @@ def __init__(self, name: str):
 
 def mock_action(shape: tuple) -> Action:
     nodes = np.empty(shape, dtype=object)
-    it = np.nditer(nodes, flags=["multi_index", "refs_ok"])
+    it = np.nditer(nodes, flags=["multi_index", "refs_ok"])  # type: ignore[call-overload]
     for _ in it:
         nodes[it.multi_index] = MockNode(f"{it.multi_index}")
     nodes_xr = xr.DataArray(nodes, coords={f"dim_{x}": list(range(dim)) for x, dim in enumerate(shape)})
diff --git a/tests/earthkit_workflows/test_fluent.py b/tests/earthkit_workflows/test_fluent.py
index f8fbaab..e7b850a 100644
--- a/tests/earthkit_workflows/test_fluent.py
+++ b/tests/earthkit_workflows/test_fluent.py
@@ -120,7 +120,7 @@ def test_broadcast():
     out_array = nodetree_array(output_action.nodes)
     assert out_array.shape == (2, 3, 3)
     assert len(out_array.data.item(0).inputs) == 1
-    it = np.nditer(out_array, flags=["multi_index", "refs_ok"])
+    it = np.nditer(out_array, flags=["multi_index", "refs_ok"])  # type: ignore[call-overload]
     for _ in it:
         print(it.multi_index)
         assert out_array[it.multi_index].item(0).inputs["input0"].parent == nodetree_array(input_action.nodes)[it.multi_index[:2]].item(0)
@@ -248,7 +248,7 @@ def test_serialisation(tmpdir, task_graph):
 
 def test_invalid_registration():
     with pytest.raises(TypeError):
-        Action.register("test", None)
+        Action.register("test", None)  # type: ignore[arg-type]
 
 
 def test_registration():