Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions justfile
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@
set dotenv-path := ".env"

val:
uv run ty check src/cascade
uv run ty check tests/cascade
uv run ty check src
uv run ty check tests
uv run ty check integration_tests
# TODO eventually broaden type coverage to ekw as well
uv run pytest -n8 tests
fmt:
uv run prek --all-files
Expand Down
9 changes: 8 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,14 @@ dynamic = ["version"]
readme = "README.md"

[dependency-groups]
dev = ["pytest", "pytest-xdist>=3.8", "prek", "ty==0.0.2", "build", "bokeh"]
dev = [
"pytest",
"pytest-xdist>=3.8",
"prek",
"ty==0.0.2",
"build",
"bokeh",
]


[tool.setuptools]
Expand Down
4 changes: 2 additions & 2 deletions src/earthkit/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from . import fluent, mark
from .graph import Graph, deduplicate_nodes
from .graph.export import deserialise, serialise
from .visualise import visualise
from .visualise import visualise as _visualise_fn


class Cascade:
Expand All @@ -46,7 +46,7 @@ def serialise(self, filename: str):
dill.dump(data, f)

def visualise(self, *args, **kwargs):
return visualise(self._graph, *args, **kwargs)
return _visualise_fn(self._graph, *args, **kwargs)

def __add__(self, other: "Cascade") -> "Cascade":
if not isinstance(other, Cascade):
Expand Down
2 changes: 1 addition & 1 deletion src/earthkit/workflows/_qubed.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _convert_num_to_abc(num: int) -> str:

def get_name(child: "Qube", index: int) -> str:
if "name" in child.metadata:
name_meta = child.metadata["name"]
name_meta = child.metadata["name"] # type: ignore[index]
return str(np.unique_values(name_meta).flatten()[0])
return _convert_num_to_abc(index)

Expand Down
2 changes: 2 additions & 0 deletions src/earthkit/workflows/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import logging

logger = logging.getLogger(__name__)

try:
from cascade.low.core import DefaultTaskOutput

Expand Down
2 changes: 1 addition & 1 deletion src/earthkit/workflows/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def decorator(func: Callable) -> Callable:
def check_num_args(*args, **kwargs):
if accept_nested and len(args) == 1:
args = args[0]
assert len(args) == expect, f"{func.__name__} expects two input arguments, got {len(args)}"
assert len(args) == expect, f"{func.__name__} expects two input arguments, got {len(args)}" # type: ignore[union-attr]
return func(*args, **kwargs)

return check_num_args
Expand Down
47 changes: 25 additions & 22 deletions src/earthkit/workflows/backends/earthkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from typing import TypeAlias
from typing import Callable, TypeAlias

import array_api_compat
from earthkit.data import FieldList
Expand All @@ -33,7 +33,7 @@ def comp_str2func(array_module, comparison: str):
return array_module.greater


Metadata: TypeAlias = "dict | callable | None"
Metadata: TypeAlias = dict | Callable | None


def resolve_metadata(metadata: Metadata, *args) -> dict:
Expand Down Expand Up @@ -64,7 +64,7 @@ def new_fieldlist(data, metadata: list[ekdMetadata], overrides: dict):


class FieldListBackend:
def _merge(*fieldlists: list[FieldList]):
def _merge(*fieldlists: FieldList):
"""Merge fieldlist elements into a single array. fieldlists with
different number of fields must be concatenated, otherwise, the
elements in each fieldlist are stacked along a new dimension
Expand All @@ -76,7 +76,7 @@ def _merge(*fieldlists: list[FieldList]):
xp = array_api_compat.array_namespace(*values)
return xp.asarray(values)

def multi_arg_function(func: str, *arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
def multi_arg_function(func: str, *arrays: FieldList, metadata: Metadata = None) -> FieldList:
merged_array = FieldListBackend._merge(*arrays)
xp = array_api_compat.array_namespace(*merged_array)
res = standardise_output(getattr(xp, func)(merged_array, axis=0))
Expand All @@ -101,53 +101,56 @@ def two_arg_function(func: str, *arrays: FieldList, metadata: Metadata = None) -
res = getattr(xp, func)(val1, val2)
return new_fieldlist(res, [arrays[0][x].metadata() for x in range(len(res))], metadata)

def mean(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
def mean(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
return FieldListBackend.multi_arg_function("mean", *arrays, metadata=metadata)

def std(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
def std(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
return FieldListBackend.multi_arg_function("std", *arrays, metadata=metadata)

def min(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
def min(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
return FieldListBackend.multi_arg_function("min", *arrays, metadata=metadata)

def max(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
def max(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
return FieldListBackend.multi_arg_function("max", *arrays, metadata=metadata)

def sum(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
def sum(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
return FieldListBackend.multi_arg_function("sum", *arrays, metadata=metadata)

def prod(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
def prod(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
return FieldListBackend.multi_arg_function("prod", *arrays, metadata=metadata)

def var(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
def var(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
return FieldListBackend.multi_arg_function("var", *arrays, metadata=metadata)

def stack(*arrays: list[FieldList], axis: int = 0) -> FieldList:
def stack(*arrays: FieldList, axis: int = 0) -> FieldList:
if axis != 0:
raise ValueError("Can not stack FieldList along axis != 0")
assert all([len(x) == 1 for x in arrays]), "Can not stack FieldLists with more than one element, use concat"
return FieldListBackend.concat(*arrays)

def add(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
def add(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
return FieldListBackend.two_arg_function("add", *arrays, metadata=metadata)

def subtract(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
def subtract(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
return FieldListBackend.two_arg_function("subtract", *arrays, metadata=metadata)

@num_args(2)
def diff(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
return FieldListBackend.multiply(FieldListBackend.subtract(*arrays, metadata=metadata), -1)
def diff(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
return FieldListBackend.multiply(
FieldListBackend.subtract(*arrays, metadata=metadata),
-1, # type: ignore[arg-type]
)

def multiply(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
def multiply(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
return FieldListBackend.two_arg_function("multiply", *arrays, metadata=metadata)

def divide(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
def divide(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
return FieldListBackend.two_arg_function("divide", *arrays, metadata=metadata)

def pow(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
def pow(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
return FieldListBackend.two_arg_function("pow", *arrays, metadata=metadata)

def concat(*arrays: list[FieldList]) -> FieldList:
def concat(*arrays: FieldList) -> FieldList:
"""Concatenates the list of fields inside each FieldList into a single
FieldList object

Expand Down Expand Up @@ -176,7 +179,7 @@ def take(
if dim != 0:
raise ValueError("Can not slice from FieldList along dim != 0")
if isinstance(indices, int):
indices = [indices]
indices = [indices] # type: ignore[assignment]
ret = array[indices]
else:
if not isinstance(dim, str):
Expand All @@ -190,7 +193,7 @@ def take(

return FieldList.from_array(ret.values, ret.metadata())

def norm(*arrays: list[FieldList], metadata: Metadata = None) -> FieldList:
def norm(*arrays: FieldList, metadata: Metadata = None) -> FieldList:
merged_array = FieldListBackend._merge(*arrays)
xp = array_api_compat.array_namespace(merged_array)
norm = standardise_output(xp.sqrt(xp.sum(xp.pow(merged_array, 2), axis=0)))
Expand Down
4 changes: 3 additions & 1 deletion src/earthkit/workflows/compilers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

from earthkit.workflows.graph import Graph, serialise

logger = logging.getLogger(__name__)

try:
from cascade.low.core import JobInstance
from cascade.low.into import graph2job as cascadeInto
Expand All @@ -28,7 +30,7 @@
def cascadeInto(graph: dict) -> Any:
raise NotImplementedError("failed to import cascade execution engine")

JobInstance = Any
JobInstance = object # type: ignore[assignment]

Engine = Literal["cascade"]

Expand Down
Loading
Loading