diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 7fcfd1b1..f8257a29 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -3,7 +3,7 @@ name: cd on: push: tags: - - '**' + - '**' jobs: pypi_binwheels: diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d32c43b3..d80a36e3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,9 +2,9 @@ name: ci on: push: - branches: [ "main", "develop" ] + branches: ["main", "develop"] pull_request: - branches: [ "main", "develop" ] + branches: ["main", "develop"] jobs: qa-pre-commit: diff --git a/.github/workflows/downstream-ci.yml b/.github/workflows/downstream-ci.yml index a9b1bb00..d9a5d11b 100644 --- a/.github/workflows/downstream-ci.yml +++ b/.github/workflows/downstream-ci.yml @@ -4,16 +4,16 @@ on: # Trigger the workflow on push to master or develop, except tag creation push: branches: - - 'main' - - 'develop' + - 'main' + - 'develop' tags-ignore: - - '**' + - '**' # Trigger the workflow on pull request - pull_request: ~ + pull_request: # Trigger the workflow manually - workflow_dispatch: ~ + workflow_dispatch: # Trigger after public PR approved for CI pull_request_target: diff --git a/.github/workflows/test-pypi.yml b/.github/workflows/test-pypi.yml index c73b85f9..f1aef3ab 100644 --- a/.github/workflows/test-pypi.yml +++ b/.github/workflows/test-pypi.yml @@ -2,7 +2,7 @@ name: test-cd on: pull_request: - branches: [ "main" ] + branches: ["main"] jobs: pypi_binwheels: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0cfe1815..962789fa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,45 +1,43 @@ default_language_version: - python: python3 + python: python3 default_stages: - - pre-commit - - pre-push +- pre-commit +- pre-push repos: - - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.5.6 - hooks: - - id: ruff # fix linting violations - types_or: [ python, pyi, jupyter ] - args: [ --fix ] - # 
- id: ruff-format # fix formatting - # types_or: [ python, pyi, jupyter ] - - repo: https://github.com/psf/black - rev: 25.1.0 - hooks: - - id: black - - repo: https://github.com/pycqa/isort - rev: 5.13.2 - hooks: - - id: isort - - repo: https://github.com/pycqa/flake8 - rev: 7.0.0 - hooks: - - id: flake8 - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 - hooks: - - id: detect-private-key - - id: check-ast - - id: end-of-file-fixer - - id: mixed-line-ending - args: [--fix=lf] - - id: trailing-whitespace - - id: check-case-conflict - - repo: local - hooks: - - id: forbid-to-commit - name: Don't commit rej files - entry: | - Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates. - Fix the merge conflicts manually and remove the .rej files. - language: fail - files: '.*\.rej$' +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: detect-private-key + - id: check-ast + - id: end-of-file-fixer + - id: mixed-line-ending + args: [--fix=lf] + - id: trailing-whitespace + - id: check-case-conflict +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.15.4 + hooks: + - id: ruff-check + exclude: '(dev/.*|.*_)\.py$' + args: + - --line-length=120 + - --fix + - --exit-non-zero-on-fix + - --preview + - id: ruff-format +- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks + rev: v2.16.0 + hooks: + - id: pretty-format-yaml + args: [--autofix, --preserve-quotes] + - id: pretty-format-toml + args: [--autofix] +- repo: local + hooks: + - id: forbid-to-commit + name: Don't commit rej files + entry: | + Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates. + Fix the merge conflicts manually and remove the .rej files. 
+ language: fail + files: '.*\.rej$' diff --git a/.readthedocs.yml b/.readthedocs.yml index bc4b12d1..0a8b2a80 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -6,16 +6,16 @@ build: python: "3.11" jobs: pre_build: - - rm -rf _build - - rm -rf docs/_build - - cd docs && make rst + - rm -rf _build + - rm -rf docs/_build + - cd docs && make rst python: install: - - method: pip - path: . - extra_requirements: - - docs + - method: pip + path: . + extra_requirements: + - docs sphinx: configuration: docs/source/conf.py diff --git a/Cargo.toml b/Cargo.toml index d9fd6c2e..839a8a44 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,18 +1,18 @@ -[package] -name = "earthkit-hydro" -version = "0.0.0" # placeholder, will be overwritten -edition = "2021" - [dependencies] -pyo3 = { version = "0.26", features = ["extension-module"] } +fixedbitset = "0.5" numpy = "0.26" +pyo3 = {version = "0.26", features = ["extension-module"]} rayon = "1.7" -fixedbitset = "0.5" [lib] +crate-type = ["cdylib"] # See https://github.com/PyO3/pyo3 for details name = "_rust" # private module to be nested into Python package path = "rust/lib.rs" -crate-type = ["cdylib"] + +[package] +edition = "2021" +name = "earthkit-hydro" +version = "0.0.0" # placeholder, will be overwritten # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/docs/source/tutorials/array_backend.ipynb b/docs/source/tutorials/array_backend.ipynb index db088140..f27c4ccb 100644 --- a/docs/source/tutorials/array_backend.ipynb +++ b/docs/source/tutorials/array_backend.ipynb @@ -7,11 +7,12 @@ "metadata": {}, "outputs": [], "source": [ - "import earthkit.hydro as ekh\n", - "import numpy as np\n", "import cupy as cp\n", + "import numpy as np\n", "import torch\n", - "import xarray as xr" + "import xarray as xr\n", + "\n", + "import earthkit.hydro as ekh" ] }, { @@ -208,7 +209,9 @@ } ], "source": [ - "ekh.upstream.array.sum(torch_network, torch.ones(numpy_network.shape, 
device=torch_network.device), return_type=\"masked\")" + "ekh.upstream.array.sum(\n", + " torch_network, torch.ones(numpy_network.shape, device=torch_network.device), return_type=\"masked\"\n", + ")" ] }, { @@ -231,24 +234,24 @@ "example_arr_numpy = np.random.rand(10, *numpy_network.shape)\n", "example_arr_cupy = cp.random.rand(10, *numpy_network.shape)\n", "\n", - "lat = numpy_network.coords['lat']\n", - "lon = numpy_network.coords['lon']\n", + "lat = numpy_network.coords[\"lat\"]\n", + "lon = numpy_network.coords[\"lon\"]\n", "step = np.arange(10)\n", "\n", "example_da_numpy = xr.DataArray(\n", " example_arr_numpy,\n", - " dims = [\"step\", \"lat\", \"lon\"],\n", - " coords = {\"step\": step, \"lat\": lat, \"lon\": lon},\n", - " name = \"precip\",\n", - " attrs={\"units\": \"m\", \"description\": \"Sample precipitation data\"}\n", + " dims=[\"step\", \"lat\", \"lon\"],\n", + " coords={\"step\": step, \"lat\": lat, \"lon\": lon},\n", + " name=\"precip\",\n", + " attrs={\"units\": \"m\", \"description\": \"Sample precipitation data\"},\n", ")\n", "\n", "example_da_cupy = xr.DataArray(\n", " example_arr_cupy,\n", - " dims = [\"step\", \"lat\", \"lon\"],\n", - " coords = {\"step\": step, \"lat\": lat, \"lon\": lon},\n", - " name = \"precip\",\n", - " attrs={\"units\": \"m\", \"description\": \"Sample precipitation data\"}\n", + " dims=[\"step\", \"lat\", \"lon\"],\n", + " coords={\"step\": step, \"lat\": lat, \"lon\": lon},\n", + " name=\"precip\",\n", + " attrs={\"units\": \"m\", \"description\": \"Sample precipitation data\"},\n", ")" ] }, diff --git a/docs/source/tutorials/catchment_statistics.ipynb b/docs/source/tutorials/catchment_statistics.ipynb index 2a704ad6..c41fde24 100644 --- a/docs/source/tutorials/catchment_statistics.ipynb +++ b/docs/source/tutorials/catchment_statistics.ipynb @@ -15,9 +15,10 @@ } ], "source": [ - "import earthkit.hydro as ekh\n", "import numpy as np\n", "\n", + "import earthkit.hydro as ekh\n", + "\n", "network = 
ekh.river_network.load(\"efas\", \"5\", use_cache=False)" ] }, @@ -507,12 +508,9 @@ } ], "source": [ - "field = np.ones(network.shape) # or load array/xarray from file\n", + "field = np.ones(network.shape) # or load array/xarray from file\n", "\n", - "ekh.catchments.sum(network, field, locations = {\n", - " \"gauge_1\": (70.475, 28.32),\n", - " \"gauge_2\": (42.225, 50.24)\n", - "})" + "ekh.catchments.sum(network, field, locations={\"gauge_1\": (70.475, 28.32), \"gauge_2\": (42.225, 50.24)})" ] } ], diff --git a/docs/source/tutorials/computing_accumulations.ipynb b/docs/source/tutorials/computing_accumulations.ipynb index c48f8155..a47a148e 100644 --- a/docs/source/tutorials/computing_accumulations.ipynb +++ b/docs/source/tutorials/computing_accumulations.ipynb @@ -15,9 +15,10 @@ } ], "source": [ - "import earthkit.hydro as ekh\n", - "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "import earthkit.hydro as ekh\n", "\n", "network = ekh.river_network.load(\"efas\", \"5\", use_cache=False)" ] @@ -64,7 +65,7 @@ } ], "source": [ - "field = np.ones(network.shape) # or load array/xarray from file\n", + "field = np.ones(network.shape) # or load array/xarray from file\n", "\n", "da = ekh.upstream.sum(network, field)\n", "\n", diff --git a/docs/source/tutorials/creating_subnetworks.ipynb b/docs/source/tutorials/creating_subnetworks.ipynb index 955b60b4..564d7161 100644 --- a/docs/source/tutorials/creating_subnetworks.ipynb +++ b/docs/source/tutorials/creating_subnetworks.ipynb @@ -15,9 +15,10 @@ } ], "source": [ - "import earthkit.hydro as ekh\n", - "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "import earthkit.hydro as ekh\n", "\n", "network = ekh.river_network.load(\"efas\", \"5\", use_cache=False)" ] @@ -64,7 +65,7 @@ } ], "source": [ - "node_mask = ekh.catchments.array.find(network, locations = {\"gauge_1\": (42.225, 50.24)})\n", + "node_mask = 
ekh.catchments.array.find(network, locations={\"gauge_1\": (42.225, 50.24)})\n", "# has nans for missing, and 0s for the catchment\n", "# => convert to boolean mask\n", "node_mask = node_mask == 0\n", diff --git a/docs/source/tutorials/distance_length.ipynb b/docs/source/tutorials/distance_length.ipynb index 3cbb3902..0387a53d 100644 --- a/docs/source/tutorials/distance_length.ipynb +++ b/docs/source/tutorials/distance_length.ipynb @@ -15,9 +15,10 @@ } ], "source": [ - "import earthkit.hydro as ekh\n", - "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "import earthkit.hydro as ekh\n", "\n", "network = ekh.river_network.load(\"efas\", \"5\", use_cache=False)" ] @@ -62,10 +63,12 @@ } ], "source": [ - "da = ekh.length.min(network, locations = {\n", - " \"gauge_1\": (47.04166666666667, 47.40833333333333),\n", - " \"gauge_2\": (42.225, 50.24)\n", - "}, upstream=True, downstream=False)\n", + "da = ekh.length.min(\n", + " network,\n", + " locations={\"gauge_1\": (47.04166666666667, 47.40833333333333), \"gauge_2\": (42.225, 50.24)},\n", + " upstream=True,\n", + " downstream=False,\n", + ")\n", "\n", "da.plot.contourf(cmap=\"viridis\", levels=100)\n", "plt.show()" @@ -89,10 +92,12 @@ } ], "source": [ - "da = ekh.length.min(network, locations = {\n", - " \"gauge_1\": (47.04166666666667, 47.40833333333333),\n", - " \"gauge_2\": (42.225, 50.24)\n", - "}, upstream=False, downstream=True)\n", + "da = ekh.length.min(\n", + " network,\n", + " locations={\"gauge_1\": (47.04166666666667, 47.40833333333333), \"gauge_2\": (42.225, 50.24)},\n", + " upstream=False,\n", + " downstream=True,\n", + ")\n", "\n", "da.plot.contourf(cmap=\"viridis\", levels=100)\n", "plt.show()" @@ -116,10 +121,12 @@ } ], "source": [ - "da = ekh.length.min(network, locations = {\n", - " \"gauge_1\": (47.04166666666667, 47.40833333333333),\n", - " \"gauge_2\": (42.225, 50.24)\n", - "}, upstream=True, downstream=True)\n", + "da = ekh.length.min(\n", + " 
network,\n", + " locations={\"gauge_1\": (47.04166666666667, 47.40833333333333), \"gauge_2\": (42.225, 50.24)},\n", + " upstream=True,\n", + " downstream=True,\n", + ")\n", "\n", "da.plot.contourf(cmap=\"viridis\", levels=100)\n", "plt.show()" @@ -153,10 +160,13 @@ "source": [ "pixel_lengths = np.random.rand(*network.shape)\n", "\n", - "da = ekh.length.min(network, locations = {\n", - " \"gauge_1\": (47.04166666666667, 47.40833333333333),\n", - " \"gauge_2\": (42.225, 50.24)\n", - "}, field=pixel_lengths, upstream=True, downstream=True)\n", + "da = ekh.length.min(\n", + " network,\n", + " locations={\"gauge_1\": (47.04166666666667, 47.40833333333333), \"gauge_2\": (42.225, 50.24)},\n", + " field=pixel_lengths,\n", + " upstream=True,\n", + " downstream=True,\n", + ")\n", "\n", "da.plot.contourf(cmap=\"viridis\", levels=100)\n", "plt.show()" @@ -192,12 +202,12 @@ } ], "source": [ - "locations = {\n", - " \"gauge_1\": (47.04166666666667, 47.40833333333333),\n", - " \"gauge_2\": (42.225, 50.24)\n", - "}\n", + "locations = {\"gauge_1\": (47.04166666666667, 47.40833333333333), \"gauge_2\": (42.225, 50.24)}\n", "\n", - "np.all(ekh.distance.array.min(network, locations, return_type=\"masked\") == ekh.length.array.min(network, locations, return_type=\"masked\") - 1)" + "np.all(\n", + " ekh.distance.array.min(network, locations, return_type=\"masked\")\n", + " == ekh.length.array.min(network, locations, return_type=\"masked\") - 1\n", + ")" ] }, { @@ -228,7 +238,7 @@ "source": [ "edge_weights = np.random.rand(network.n_edges)\n", "\n", - "da = ekh.distance.min(network, locations, field = edge_weights, upstream=True, downstream=True)\n", + "da = ekh.distance.min(network, locations, field=edge_weights, upstream=True, downstream=True)\n", "\n", "da.plot.contourf(cmap=\"viridis\", levels=100)\n", "plt.show()" @@ -264,7 +274,7 @@ "source": [ "edge_weights = np.random.rand(network.n_edges)\n", "\n", - "da = ekh.distance.to_source(network, field = edge_weights, 
path=\"shortest\")\n", + "da = ekh.distance.to_source(network, field=edge_weights, path=\"shortest\")\n", "\n", "da.plot.contourf(cmap=\"viridis\", levels=100)\n", "plt.show()" @@ -290,7 +300,7 @@ "source": [ "edge_weights = np.random.rand(network.n_nodes)\n", "\n", - "da = ekh.length.to_sink(network, field = edge_weights, path=\"longest\")\n", + "da = ekh.length.to_sink(network, field=edge_weights, path=\"longest\")\n", "\n", "da.plot.contourf(cmap=\"viridis\", levels=100)\n", "plt.show()" diff --git a/docs/source/tutorials/finding_catchments.ipynb b/docs/source/tutorials/finding_catchments.ipynb index 17d3c94b..714120ce 100644 --- a/docs/source/tutorials/finding_catchments.ipynb +++ b/docs/source/tutorials/finding_catchments.ipynb @@ -15,9 +15,10 @@ } ], "source": [ - "import earthkit.hydro as ekh\n", - "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "import earthkit.hydro as ekh\n", "\n", "network = ekh.river_network.load(\"efas\", \"5\", use_cache=False)" ] @@ -60,10 +61,7 @@ } ], "source": [ - "da = ekh.catchments.find(network, locations={\n", - " \"gauge_1\": (70.475, 28.32),\n", - " \"gauge_2\": (42.225, 50.24)\n", - "})\n", + "da = ekh.catchments.find(network, locations={\"gauge_1\": (70.475, 28.32), \"gauge_2\": (42.225, 50.24)})\n", "\n", "da.plot.contourf(cmap=\"viridis\", levels=100)\n", "plt.show()" @@ -88,7 +86,7 @@ ], "source": [ "test = np.ones(network.shape, dtype=bool)\n", - "test[(106, 3214)] = 0\n", + "test[(106, 3214)] = 0\n", "test.flat[network.mask].argmin()" ] }, @@ -118,10 +116,7 @@ } ], "source": [ - "da = ekh.catchments.find(network, locations=[\n", - " (106, 3214),\n", - " (1801, 4529)\n", - "])\n", + "da = ekh.catchments.find(network, locations=[(106, 3214), (1801, 4529)])\n", "\n", "da.plot.contourf(cmap=\"viridis\", levels=100)\n", "plt.show()" @@ -189,10 +184,9 @@ } ], "source": [ - "da = ekh.catchments.find(network, locations={\n", - " \"gauge_1\": (47.04166666666667, 
47.40833333333333),\n", - " \"gauge_2\": (42.225, 50.24)\n", - "}, overwrite=False)\n", + "da = ekh.catchments.find(\n", + " network, locations={\"gauge_1\": (47.04166666666667, 47.40833333333333), \"gauge_2\": (42.225, 50.24)}, overwrite=False\n", + ")\n", "\n", "da.plot.contourf(cmap=\"viridis\", levels=100)\n", "plt.show()" @@ -224,10 +218,9 @@ } ], "source": [ - "da = ekh.catchments.find(network, locations={\n", - " \"gauge_1\": (47.04166666666667, 47.40833333333333),\n", - " \"gauge_2\": (42.225, 50.24)\n", - "})\n", + "da = ekh.catchments.find(\n", + " network, locations={\"gauge_1\": (47.04166666666667, 47.40833333333333), \"gauge_2\": (42.225, 50.24)}\n", + ")\n", "\n", "da.plot.contourf(cmap=\"viridis\", levels=100)\n", "plt.show()" diff --git a/docs/source/tutorials/gridded_masked.ipynb b/docs/source/tutorials/gridded_masked.ipynb index a0594217..9eefc70d 100644 --- a/docs/source/tutorials/gridded_masked.ipynb +++ b/docs/source/tutorials/gridded_masked.ipynb @@ -15,10 +15,11 @@ } ], "source": [ - "import earthkit.hydro as ekh\n", "import numpy as np\n", "import xarray as xr\n", "\n", + "import earthkit.hydro as ekh\n", + "\n", "network = ekh.river_network.load(\"efas\", \"5\", use_cache=False)" ] }, @@ -1089,8 +1090,20 @@ } ], "source": [ - "print(\"gridded shape:\", ekh.upstream.array.sum(network, np.random.rand(*network.shape), return_type=\"gridded\").shape, \"=\", \"network.shape:\", network.shape)\n", - "print(\"masked shape:\", ekh.upstream.array.sum(network, np.random.rand(*network.shape), return_type=\"masked\").shape, \"=\", \"network.n_nodes\", network.n_nodes)" + "print(\n", + " \"gridded shape:\",\n", + " ekh.upstream.array.sum(network, np.random.rand(*network.shape), return_type=\"gridded\").shape,\n", + " \"=\",\n", + " \"network.shape:\",\n", + " network.shape,\n", + ")\n", + "print(\n", + " \"masked shape:\",\n", + " ekh.upstream.array.sum(network, np.random.rand(*network.shape), return_type=\"masked\").shape,\n", + " \"=\",\n", + " 
\"network.n_nodes\",\n", + " network.n_nodes,\n", + ")" ] }, { @@ -2149,10 +2162,10 @@ "\n", "example_da = xr.DataArray(\n", " example_arr,\n", - " dims = [\"index\"],\n", - " coords = {\"index\": index},\n", - " name = \"precip\",\n", - " attrs={\"units\": \"m\", \"description\": \"Sample precipitation data\"}\n", + " dims=[\"index\"],\n", + " coords={\"index\": index},\n", + " name=\"precip\",\n", + " attrs={\"units\": \"m\", \"description\": \"Sample precipitation data\"},\n", ")\n", "\n", "ekh.upstream.sum(network, example_da, return_type=\"gridded\")" diff --git a/docs/source/tutorials/streamorder.ipynb b/docs/source/tutorials/streamorder.ipynb index 48a073ed..4400b4fe 100644 --- a/docs/source/tutorials/streamorder.ipynb +++ b/docs/source/tutorials/streamorder.ipynb @@ -7,8 +7,9 @@ "metadata": {}, "outputs": [], "source": [ - "import earthkit.hydro as ekh\n", - "import matplotlib.pyplot as plt" + "import matplotlib.pyplot as plt\n", + "\n", + "import earthkit.hydro as ekh" ] }, { diff --git a/docs/source/tutorials/xarray_array.ipynb b/docs/source/tutorials/xarray_array.ipynb index 48de116b..4577c10a 100644 --- a/docs/source/tutorials/xarray_array.ipynb +++ b/docs/source/tutorials/xarray_array.ipynb @@ -7,9 +7,10 @@ "metadata": {}, "outputs": [], "source": [ - "import earthkit.hydro as ekh\n", "import numpy as np\n", - "import xarray as xr" + "import xarray as xr\n", + "\n", + "import earthkit.hydro as ekh" ] }, { @@ -636,16 +637,16 @@ "source": [ "example_arr = np.random.rand(2, *network.shape)\n", "\n", - "lat = network.coords['lat']\n", - "lon = network.coords['lon']\n", + "lat = network.coords[\"lat\"]\n", + "lon = network.coords[\"lon\"]\n", "step = np.arange(2)\n", "\n", "example_da = xr.DataArray(\n", " example_arr,\n", - " dims = [\"step\", \"lat\", \"lon\"],\n", - " coords = {\"step\": step, \"lat\": lat, \"lon\": lon},\n", - " name = \"precip\",\n", - " attrs={\"units\": \"m\", \"description\": \"Sample precipitation data\"}\n", + " dims=[\"step\", 
\"lat\", \"lon\"],\n", + " coords={\"step\": step, \"lat\": lat, \"lon\": lon},\n", + " name=\"precip\",\n", + " attrs={\"units\": \"m\", \"description\": \"Sample precipitation data\"},\n", ")\n", "example_da" ] @@ -1715,11 +1716,7 @@ } ], "source": [ - "ekh.catchments.sum(network, example_da, locations = {\n", - " \"gauge_1\": (70.475, 28.32),\n", - " \"gauge_2\": (42.225, 50.24)\n", - " }\n", - ")" + "ekh.catchments.sum(network, example_da, locations={\"gauge_1\": (70.475, 28.32), \"gauge_2\": (42.225, 50.24)})" ] }, { @@ -2281,9 +2278,7 @@ } ], "source": [ - "example_ds = xr.Dataset(\n", - " data_vars={\"var1\": example_da, \"var2\": example_da+1}\n", - ")\n", + "example_ds = xr.Dataset(data_vars={\"var1\": example_da, \"var2\": example_da + 1})\n", "example_ds" ] }, @@ -2858,10 +2853,10 @@ "source": [ "example_da_uninferrable = xr.DataArray(\n", " example_arr,\n", - " dims = [\"step\", \"uninferrable_name_1\", \"uninferrable_name_2\"],\n", - " coords = {\"step\": step, \"uninferrable_name_1\": lat, \"uninferrable_name_2\": lon},\n", - " name = \"precip\",\n", - " attrs={\"units\": \"m\", \"description\": \"Sample precipitation data\"}\n", + " dims=[\"step\", \"uninferrable_name_1\", \"uninferrable_name_2\"],\n", + " coords={\"step\": step, \"uninferrable_name_1\": lat, \"uninferrable_name_2\": lon},\n", + " name=\"precip\",\n", + " attrs={\"units\": \"m\", \"description\": \"Sample precipitation data\"},\n", ")" ] }, @@ -4028,7 +4023,12 @@ } ], "source": [ - "ekh.upstream.sum(network, example_da, node_weights=example_da_uninferrable, input_core_dims=[[\"lat\", \"lon\"], [\"uninferrable_name_1\", \"uninferrable_name_2\"]])" + "ekh.upstream.sum(\n", + " network,\n", + " example_da,\n", + " node_weights=example_da_uninferrable,\n", + " input_core_dims=[[\"lat\", \"lon\"], [\"uninferrable_name_1\", \"uninferrable_name_2\"]],\n", + ")" ] }, { @@ -4697,11 +4697,7 @@ } ], "source": [ - "ekh.catchments.array.sum(network, example_arr, locations = {\n", - " \"gauge_1\": 
(70.475, 28.32),\n", - " \"gauge_2\": (42.225, 50.24)\n", - " }\n", - ")" + "ekh.catchments.array.sum(network, example_arr, locations={\"gauge_1\": (70.475, 28.32), \"gauge_2\": (42.225, 50.24)})" ] } ], diff --git a/pyproject.toml b/pyproject.toml index 3f4695cc..c44432de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,108 +1,134 @@ [build-system] -requires = ["setuptools>=77", "setuptools-rust", "setuptools_scm[toml]>=8"] build-backend = "setuptools.build_meta" +requires = ["setuptools>=77", "setuptools-rust", "setuptools_scm[toml]>=8"] [project] -name = "earthkit-hydro" -requires-python = ">=3.9" authors = [ - { name = "European Centre for Medium-Range Weather Forecasts (ECMWF)", email = "software.support@ecmwf.int" } - ] -maintainers = [ - {name = "Oisín M. Morrison", email = "oisin.morrison@ecmwf.int"}, - {name = "Corentin Carton de Wiart", email = "corentin.carton@ecmwf.int"} - ] -description = "A Python library for common hydrological functions" -license = "Apache-2.0" -license-files = [ "LICENSE" ] + {name = "European Centre for Medium-Range Weather Forecasts (ECMWF)", email = "software.support@ecmwf.int"} +] classifiers = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "Intended Audience :: Science/Research", - "Natural Language :: English", - "Operating System :: OS Independent", - "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", - "Programming Language :: Python :: 3.14", - "Topic :: Scientific/Engineering" - ] + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Natural Language :: English", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.9", + "Programming 
Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Topic :: Scientific/Engineering" +] +dependencies = [ + "numpy", + "joblib", + "xarray", + "earthkit-utils>=0.2.1" +] +description = "A Python library for common hydrological functions" dynamic = ["version", "readme"] +license = "Apache-2.0" +license-files = ["LICENSE"] +maintainers = [ + {name = "Oisín M. Morrison", email = "oisin.morrison@ecmwf.int"}, + {name = "Corentin Carton de Wiart", email = "corentin.carton@ecmwf.int"} +] +name = "earthkit-hydro" +requires-python = ">=3.9" -dependencies = [ - "numpy", - "joblib", - "xarray", - "earthkit-utils>=0.2.1" +[project.optional-dependencies] +all = [ + "earthkit-data[geotiff]>=0.13.8", + "pytest", + "pre-commit" +] +dev = [ + "pytest", + "pre-commit" +] +docs = [ + "sphinx", + "furo", + "sphinxcontrib-bibtex", + "nbsphinx", + "nbconvert", + "ipykernel" +] +grit = [ + "geopandas", + "pandas" +] +readers = [ + "earthkit-data[geotiff]>=0.13.8" +] +tests = [ + "pytest", + "torch", + "jax" ] [project.urls] - repository = "https://github.com/ecmwf/earthkit-hydro" - documentation = "https://earthkit-hydro.readthedocs.io/" - issues = "https://github.com/ecmwf/earthkit-hydro/issues" +documentation = "https://earthkit-hydro.readthedocs.io/" +issues = "https://github.com/ecmwf/earthkit-hydro/issues" +repository = "https://github.com/ecmwf/earthkit-hydro" -[project.optional-dependencies] - readers = [ - "earthkit-data[geotiff]>=0.13.8" - ] - grit = [ - "geopandas", - "pandas" - ] - tests = [ - "pytest", - "torch", - "jax" - ] - dev = [ - "pytest", - "pre-commit" - ] - docs = [ - "sphinx", - "furo", - "sphinxcontrib-bibtex", - "nbsphinx", - "nbconvert", - "ipykernel" - ] - all = [ - "earthkit-data[geotiff]>=0.13.8", - "pytest", - "pre-commit" - ] - -[tool.black] -line-length = 88 -skip-string-normalization = false +# takes 
inspiration from https://github.com/pypa/cibuildwheel/discussions/1814 +[tool.cibuildwheel] +skip = ["*musl*", "*-win_arm64"] # ignore some problematic platforms +linux.before-all = "curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --profile minimal" +linux.environment = {PATH = "$HOME/.cargo/bin:$PATH"} +macos.before-all = "curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --profile minimal" +windows.archs = "all" +windows.before-all = "rustup target add aarch64-pc-windows-msvc i686-pc-windows-msvc x86_64-pc-windows-msvc" -[tool.isort] -profile = "black" # Ensures compatibility with Black's formatting. -line_length = 88 # Same as Black's line length for consistency. +# Testing +[tool.pytest] +addopts = "--pdbcls=IPython.terminal.debugger:Pdb" +testpaths = ["tests"] # Linting settings [tool.ruff] -line-length = 88 +line-length = 120 [tool.ruff.format] quote-style = "double" +[tool.ruff.lint] +ignore = [ + "D1", # pydocstyle: Missing Docstrings + "D107", # pydocstyle: numpy convention + "D203", + "D205", + "D212", + "D213", + "D401", + "D402", + "D413", + "D415", + "D416", + "D417" +] +select = [ + "F", # pyflakes + "E", # pycodestyle + "W", # pycodestyle warnings + "I", # isort + "D" # pydocstyle +] + [tool.ruff.lint.per-file-ignores] "__init__.py" = [ - "F401", # unused imports - ] + "F401" # unused imports +] +"docs/source/conf.py" = [ + "E501" # SVG path data cannot be split +] "tests/*" = [ - "F405", # variable may be undefined, or defined from star imports - "F403", # use of wildcard imports - ] - -# Testing -[tool.pytest] -addopts = "--pdbcls=IPython.terminal.debugger:Pdb" -testpaths = ["tests"] + "F405", # variable may be undefined, or defined from star imports + "F403" # use of wildcard imports +] # Packaging/setuptools options [tool.setuptools] @@ -112,24 +138,15 @@ include-package-data = true readme = {file = ["README.md"], content-type = "text/markdown"} [tool.setuptools.packages.find] -include = [ 
"earthkit.hydro" ] -where = [ "src/" ] +include = ["earthkit.hydro"] +where = ["src/"] [tool.setuptools_scm] +fallback_version = '0.1.0' +local_scheme = "no-local-version" +parentdir_prefix_version = 'earthkit-hydro-' # get version from GitHub-like tarballs version_file = "src/earthkit/hydro/_version.py" version_file_template = ''' # Do not change! Do not track in version control! __version__ = "{version}" ''' -parentdir_prefix_version='earthkit-hydro-' # get version from GitHub-like tarballs -fallback_version='0.1.0' -local_scheme = "no-local-version" - -# takes inspiration from https://github.com/pypa/cibuildwheel/discussions/1814 -[tool.cibuildwheel] -skip = ["*musl*", "*-win_arm64"] # ignore some problematic platforms -linux.before-all = "curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --profile minimal" -linux.environment = { PATH="$HOME/.cargo/bin:$PATH" } -macos.before-all = "curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --profile minimal" -windows.before-all = "rustup target add aarch64-pc-windows-msvc i686-pc-windows-msvc x86_64-pc-windows-msvc" -windows.archs = "all" diff --git a/src/earthkit/hydro/_backends/tensorflow_backend.py b/src/earthkit/hydro/_backends/tensorflow_backend.py index 412fd753..ae3a6f46 100644 --- a/src/earthkit/hydro/_backends/tensorflow_backend.py +++ b/src/earthkit/hydro/_backends/tensorflow_backend.py @@ -25,15 +25,11 @@ def scatter_assign(self, target, indices, updates): batch_range = tf.range(num_batch)[:, None] batch_ids = tf.tile(batch_range, [1, num_idx]) - scatter_idx = tf.stack( - [batch_ids, tf.tile(tf.expand_dims(indices, 0), [num_batch, 1])], axis=-1 - ) + scatter_idx = tf.stack([batch_ids, tf.tile(tf.expand_dims(indices, 0), [num_batch, 1])], axis=-1) scatter_idx = tf.reshape(scatter_idx, (-1, 2)) scatter_vals = tf.reshape(flat_values, (-1,)) - flat_result = tf.tensor_scatter_nd_update( - flat_target, scatter_idx, scatter_vals - ) + flat_result = 
tf.tensor_scatter_nd_update(flat_target, scatter_idx, scatter_vals) return tf.reshape(flat_result, target_shape) diff --git a/src/earthkit/hydro/_core/_find.py b/src/earthkit/hydro/_core/_find.py index 5b501b54..f02f3f29 100644 --- a/src/earthkit/hydro/_core/_find.py +++ b/src/earthkit/hydro/_core/_find.py @@ -57,18 +57,12 @@ def _find_catchments(xp, field, did, uid, eid, overwrite): None """ down_not_missing = ~xp.isnan(xp.gather(field, uid, axis=-1)) - did = did[ - down_not_missing - ] # only update nodes where the downstream belongs to a catchment + did = did[down_not_missing] # only update nodes where the downstream belongs to a catchment if not overwrite: up_is_missing = xp.isnan(xp.gather(field, did, axis=-1)) did = did[up_is_missing] else: up_is_missing = None - uid = ( - uid[down_not_missing][up_is_missing] - if up_is_missing is not None - else uid[down_not_missing] - ) + uid = uid[down_not_missing][up_is_missing] if up_is_missing is not None else uid[down_not_missing] updates = xp.gather(field, uid, axis=-1) return xp.scatter_assign(field, did, updates) diff --git a/src/earthkit/hydro/_core/move.py b/src/earthkit/hydro/_core/move.py index ede20a82..a26343e4 100644 --- a/src/earthkit/hydro/_core/move.py +++ b/src/earthkit/hydro/_core/move.py @@ -18,9 +18,7 @@ def calculate_move_metric( invert_graph = False node_modifier_use_upstream = True else: - raise ValueError( - f"flow_direction must be 'up' or 'down', got {flow_direction}." 
- ) + raise ValueError(f"flow_direction must be 'up' or 'down', got {flow_direction}.") if node_weights is None: if metric == "mean" or metric == "std" or metric == "var": @@ -66,9 +64,7 @@ def calculate_move_metric( xp.zeros(field.shape), func, invert_graph, - node_additive_weight=( - field**2 if node_weights is None else field**2 * node_weights - ), + node_additive_weight=(field**2 if node_weights is None else field**2 * node_weights), node_modifier_use_upstream=node_modifier_use_upstream, edge_multiplicative_weight=edge_weights, ) diff --git a/src/earthkit/hydro/_core/online.py b/src/earthkit/hydro/_core/online.py index 0132fec5..b8438380 100644 --- a/src/earthkit/hydro/_core/online.py +++ b/src/earthkit/hydro/_core/online.py @@ -16,9 +16,7 @@ def calculate_online_metric( elif flow_direction == "down": invert_graph = False else: - raise ValueError( - f"flow_direction must be 'up' or 'down', got {flow_direction}." - ) + raise ValueError(f"flow_direction must be 'up' or 'down', got {flow_direction}.") field = xp.copy(field) diff --git a/src/earthkit/hydro/_readers/readers.py b/src/earthkit/hydro/_readers/readers.py index 6abd72f2..70420a07 100644 --- a/src/earthkit/hydro/_readers/readers.py +++ b/src/earthkit/hydro/_readers/readers.py @@ -128,10 +128,8 @@ def from_cama_downxy(dx, dy): missing_mask = x_offsets != -9999 x_offsets = x_offsets[mask_upstream] y_offsets = y_offsets[mask_upstream] - upstream_indices, downstream_indices = ( - find_upstream_downstream_indices_from_offsets( - x_offsets, y_offsets, missing_mask, mask_upstream, shape - ) + upstream_indices, downstream_indices = find_upstream_downstream_indices_from_offsets( + x_offsets, y_offsets, missing_mask, mask_upstream, shape ) return create_network(upstream_indices, downstream_indices, missing_mask, shape) @@ -157,14 +155,10 @@ def from_d8(data, river_network_format="pcr_d8"): missing_mask = np.isin(data_flat, range(1, 10)) mask_upstream = data_flat != 5 elif river_network_format == "esri_d8": - 
missing_mask = np.isin(data_flat, np.append(0, 2 ** np.arange(8))) & ( - data_flat != 255 - ) + missing_mask = np.isin(data_flat, np.append(0, 2 ** np.arange(8))) & (data_flat != 255) mask_upstream = (data_flat != 0) & (data_flat != -1) elif river_network_format == "merit_d8": - missing_mask = np.isin(data_flat, np.append(0, 2 ** np.arange(8))) & ( - data_flat != 247 - ) + missing_mask = np.isin(data_flat, np.append(0, 2 ** np.arange(8))) & (data_flat != 247) mask_upstream = (data_flat != 0) & (data_flat != 255) else: raise ValueError(f"Unsupported river network format: {river_network_format}.") @@ -180,17 +174,13 @@ def from_d8(data, river_network_format="pcr_d8"): x_offsets = np.vectorize(x_mapping.get)(directions) y_offsets = -np.vectorize(y_mapping.get)(directions) del directions - upstream_indices, downstream_indices = ( - find_upstream_downstream_indices_from_offsets( - x_offsets, y_offsets, missing_mask, mask_upstream, shape - ) + upstream_indices, downstream_indices = find_upstream_downstream_indices_from_offsets( + x_offsets, y_offsets, missing_mask, mask_upstream, shape ) return create_network(upstream_indices, downstream_indices, missing_mask, shape) -def find_upstream_downstream_indices_from_offsets( - x_offsets, y_offsets, missing_mask, mask_upstream, shape -): +def find_upstream_downstream_indices_from_offsets(x_offsets, y_offsets, missing_mask, mask_upstream, shape): """ Function to convert from offsets to absolute indices. 
@@ -269,9 +259,7 @@ def create_network(upstream_indices, downstream_indices, missing_mask, shape): n_nodes, )[has_downstream] - sort_indices = np.lexsort( - (nodes[has_downstream], distances) - ) # np.argsort(distances) + sort_indices = np.lexsort((nodes[has_downstream], distances)) # np.argsort(distances) sorted_distances = distances[sort_indices] # from source to sink up_ids_sort = up_ids[sort_indices] @@ -352,9 +340,7 @@ def from_grit(path): offsets[1:] = np.cumsum(counts) del counts - topological_labels = compute_topological_labels_bifurcations( - down_ids, offsets, sources, sinks - ) + topological_labels = compute_topological_labels_bifurcations(down_ids, offsets, sources, sinks) topological_labels = topological_labels[up_ids] sort_indices = np.argsort(topological_labels) diff --git a/src/earthkit/hydro/_utils/decorators/array_backend.py b/src/earthkit/hydro/_utils/decorators/array_backend.py index e7ef32a7..62a8351f 100644 --- a/src/earthkit/hydro/_utils/decorators/array_backend.py +++ b/src/earthkit/hydro/_utils/decorators/array_backend.py @@ -13,7 +13,6 @@ def wrapper(**kwargs): backend_name = xp.name kwargs["xp"] = xp if backend_name == "jax" and allow_jax_jit: - nonlocal compiled_jax_fn if compiled_jax_fn is None: from jax import jit diff --git a/src/earthkit/hydro/_utils/decorators/xarray.py b/src/earthkit/hydro/_utils/decorators/xarray.py index 4ece7119..02e1b3d6 100644 --- a/src/earthkit/hydro/_utils/decorators/xarray.py +++ b/src/earthkit/hydro/_utils/decorators/xarray.py @@ -49,9 +49,7 @@ def reshuffled_func(*only_xr_args, **non_xr_kwargs): return reshuffled_func -def get_input_output_core_dims( - input_core_dims, output_core_dims, xr_args, river_network, return_grid -): +def get_input_output_core_dims(input_core_dims, output_core_dims, xr_args, river_network, return_grid): if input_core_dims is None: input_core_dims = [get_core_dims(xr_arg) for xr_arg in xr_args] elif len(input_core_dims) == 1: @@ -62,9 +60,7 @@ def get_input_output_core_dims( if 
len(input_core_dims[0]) == 2: # grid in and out output_core_dims = [input_core_dims[0]] else: - output_core_dims = [ - list(river_network.coords.keys()) - ] # 1d in, grid out + output_core_dims = [list(river_network.coords.keys())] # 1d in, grid out else: if len(input_core_dims[0]) == 1: # 1d in and out output_core_dims = [input_core_dims[0]] @@ -101,10 +97,7 @@ def wrapper(*args, **kwargs): offset = 2 if return_grid else 1 ndim = output.ndim dim_names = [f"axis{i + 1}" for i in range(ndim - offset)] - coords = { - dim: np.arange(size) - for dim, size in zip(dim_names, output.shape[:-offset]) - } + coords = {dim: np.arange(size) for dim, size in zip(dim_names, output.shape[:-offset])} if return_grid: for k, v in river_network.coords.items(): @@ -124,7 +117,6 @@ def wrapper(*args, **kwargs): } result = result.assign_coords(**assign_dict) else: - reshuffled_func = get_reshuffled_func(func, arg_order) input_core_dims, output_core_dims = get_input_output_core_dims( diff --git a/src/earthkit/hydro/_utils/locations.py b/src/earthkit/hydro/_utils/locations.py index 2a50b782..579454db 100644 --- a/src/earthkit/hydro/_utils/locations.py +++ b/src/earthkit/hydro/_utils/locations.py @@ -6,16 +6,12 @@ def locations_to_1d(xp, river_network, locations): orig_locations = locations dict_locations = isinstance(locations, dict) if dict_locations: - coord1_network_vals, coord2_network_vals = river_network.coords.values() locations = [] if river_network.shape is None: # vector network for coord1_val, coord2_val in orig_locations.values(): - indx = ( - (coord1_val - coord1_network_vals) ** 2 - + (coord2_val - coord2_network_vals) ** 2 - ).argmin() + indx = ((coord1_val - coord1_network_vals) ** 2 + (coord2_val - coord2_network_vals) ** 2).argmin() locations.append(int(indx)) else: for coord1_val, coord2_val in orig_locations.values(): @@ -38,14 +34,10 @@ def locations_to_1d(xp, river_network, locations): dtype=int, device=river_network.device, ) - reverse_map[flat_mask] = xp.arange( - 
flat_mask.shape[0], device=river_network.device - ) + reverse_map[flat_mask] = xp.arange(flat_mask.shape[0], device=river_network.device) masked_indices = reverse_map[flat_indices] if xp.any(masked_indices < 0): - raise ValueError( - "Some station points are not included in the masked array." - ) + raise ValueError("Some station points are not included in the masked array.") stations = xp.asarray(masked_indices, device=river_network.device) else: assert stations.ndim == 1 diff --git a/src/earthkit/hydro/_utils/readers.py b/src/earthkit/hydro/_utils/readers.py index 4d127084..0fbf9ad1 100644 --- a/src/earthkit/hydro/_utils/readers.py +++ b/src/earthkit/hydro/_utils/readers.py @@ -68,21 +68,16 @@ def _replace_missing_f8(cur, new): def from_file(path, mask=False): """Load a .map file into a numpy array.""" - with open(path, "rb") as f: bytes = f.read() nbytes_header = 64 + 2 + 2 + 8 + 8 + 8 + 8 + 4 + 4 + 8 + 8 + 8 - _, cellRepr, _, _, _, _, nrRows, nrCols, _, _, _ = unpack( - "=hhddddIIddd", bytes[64:nbytes_header] - ) + _, cellRepr, _, _, _, _, nrRows, nrCols, _, _, _ = unpack("=hhddddIIddd", bytes[64:nbytes_header]) try: celltype = CELLREPR[cellRepr] except KeyError: - raise Exception( - "{}: invalid cellRepr value ({}) in header".format(path, cellRepr) - ) + raise Exception("{}: invalid cellRepr value ({}) in header".format(path, cellRepr)) dtype = celltype["dtype"] diff --git a/src/earthkit/hydro/catchments/_toplevel.py b/src/earthkit/hydro/catchments/_toplevel.py index d94d1e20..c4e48e78 100644 --- a/src/earthkit/hydro/catchments/_toplevel.py +++ b/src/earthkit/hydro/catchments/_toplevel.py @@ -48,7 +48,8 @@ def var( - :math:`\bar{x}_j` is the weighted average at node :math:`j`, - :math:`\mathrm{Var}(x)_j` is the weighted variance at node :math:`j`. - Accumulation proceeds in topological order from the sources to the sinks. This formulation computes the population variance. + Accumulation proceeds in topological order from the sources to the sinks. 
+ This formulation computes the population variance. Parameters ---------- @@ -126,7 +127,8 @@ def std( - :math:`\mathrm{Var}(x)_j` is the weighted variance at node :math:`j`. - :math:`\mathrm{Std}(x)_j` is the weighted standard deviation at node :math:`j`. - Accumulation proceeds in topological order from the sources to the sinks. This formulation computes the population standard deviation. + Accumulation proceeds in topological order from the sources to the sinks. + This formulation computes the population standard deviation. Parameters ---------- @@ -432,14 +434,13 @@ def max( @find_xarray -def find( - river_network, locations, overwrite=True, return_type=None, input_core_dims=None -): +def find(river_network, locations, overwrite=True, return_type=None, input_core_dims=None): r""" Delineates catchment areas. Given a field indicating one or more start locations (e.g., outlet points or pour points), - this function delineates the catchments upstream of each start location by grouping all cells that flow into these points. + this function delineates the catchments upstream of each start location + by grouping all cells that flow into these points. 
Parameters ---------- diff --git a/src/earthkit/hydro/catchments/_xarray.py b/src/earthkit/hydro/catchments/_xarray.py index e247f9ce..db3827ad 100644 --- a/src/earthkit/hydro/catchments/_xarray.py +++ b/src/earthkit/hydro/catchments/_xarray.py @@ -41,9 +41,7 @@ def wrapper(*args, **kwargs): locations = all_args["locations"] - stations_1d, locations, orig_locations = locations_to_1d( - xp, river_network, locations - ) + stations_1d, locations, orig_locations = locations_to_1d(xp, river_network, locations) all_args["locations"] = stations_1d @@ -55,9 +53,7 @@ def wrapper(*args, **kwargs): ndim = output.ndim dim_names = [f"axis{i + 1}" for i in range(ndim - 1)] - coords = { - dim: np.arange(size) for dim, size in zip(dim_names, output.shape[:-1]) - } + coords = {dim: np.arange(size) for dim, size in zip(dim_names, output.shape[:-1])} coords[node_default_coord] = np.arange(river_network.n_nodes)[stations_1d] dim_names.append(node_default_coord) @@ -65,7 +61,6 @@ def wrapper(*args, **kwargs): result = xr.DataArray(output, dims=dim_names, coords=coords, name="out") else: - reshuffled_func = get_reshuffled_func(func, arg_order) input_core_dims = get_input_core_dims(input_core_dims, xr_args) @@ -75,9 +70,7 @@ def wrapper(*args, **kwargs): *xr_args, input_core_dims=input_core_dims, output_core_dims=[[node_default_coord]], - dask_gufunc_kwargs={ - "output_sizes": {node_default_coord: stations_1d.shape[0]} - }, + dask_gufunc_kwargs={"output_sizes": {node_default_coord: stations_1d.shape[0]}}, output_dtypes=[float], dask="parallelized", kwargs=non_xr_kwargs, diff --git a/src/earthkit/hydro/catchments/array/_operations.py b/src/earthkit/hydro/catchments/array/_operations.py index 296b235e..9a136505 100644 --- a/src/earthkit/hydro/catchments/array/_operations.py +++ b/src/earthkit/hydro/catchments/array/_operations.py @@ -6,49 +6,37 @@ @multi_backend(allow_jax_jit=False) def var(xp, river_network, field, locations, node_weights, edge_weights): stations_1d, _, _ = 
locations_to_1d(xp, river_network, locations) - return _operations.var( - xp, river_network, field, stations_1d, node_weights, edge_weights - ) + return _operations.var(xp, river_network, field, stations_1d, node_weights, edge_weights) @multi_backend(allow_jax_jit=False) def std(xp, river_network, field, locations, node_weights, edge_weights): stations_1d, _, _ = locations_to_1d(xp, river_network, locations) - return _operations.std( - xp, river_network, field, stations_1d, node_weights, edge_weights - ) + return _operations.std(xp, river_network, field, stations_1d, node_weights, edge_weights) @multi_backend(allow_jax_jit=False) def mean(xp, river_network, field, locations, node_weights, edge_weights): stations_1d, _, _ = locations_to_1d(xp, river_network, locations) - return _operations.mean( - xp, river_network, field, stations_1d, node_weights, edge_weights - ) + return _operations.mean(xp, river_network, field, stations_1d, node_weights, edge_weights) @multi_backend(allow_jax_jit=False) def sum(xp, river_network, field, locations, node_weights, edge_weights): stations_1d, _, _ = locations_to_1d(xp, river_network, locations) - return _operations.sum( - xp, river_network, field, stations_1d, node_weights, edge_weights - ) + return _operations.sum(xp, river_network, field, stations_1d, node_weights, edge_weights) @multi_backend(allow_jax_jit=False) def min(xp, river_network, field, locations, node_weights, edge_weights): stations_1d, _, _ = locations_to_1d(xp, river_network, locations) - return _operations.min( - xp, river_network, field, stations_1d, node_weights, edge_weights - ) + return _operations.min(xp, river_network, field, stations_1d, node_weights, edge_weights) @multi_backend(allow_jax_jit=False) def max(xp, river_network, field, locations, node_weights, edge_weights): stations_1d, _, _ = locations_to_1d(xp, river_network, locations) - return _operations.max( - xp, river_network, field, stations_1d, node_weights, edge_weights - ) + return 
_operations.max(xp, river_network, field, stations_1d, node_weights, edge_weights) @multi_backend() diff --git a/src/earthkit/hydro/catchments/array/_toplevel.py b/src/earthkit/hydro/catchments/array/_toplevel.py index ece00357..26fb5d8f 100644 --- a/src/earthkit/hydro/catchments/array/_toplevel.py +++ b/src/earthkit/hydro/catchments/array/_toplevel.py @@ -36,7 +36,8 @@ def var(river_network, field, locations, node_weights=None, edge_weights=None): - :math:`\bar{x}_j` is the weighted average at node :math:`j`, - :math:`\mathrm{Var}(x)_j` is the weighted variance at node :math:`j`. - Accumulation proceeds in topological order from the sources to the sinks. This formulation computes the population variance. + Accumulation proceeds in topological order from the sources to the sinks. + This formulation computes the population variance. Parameters ---------- @@ -102,7 +103,8 @@ def std(river_network, field, locations, node_weights=None, edge_weights=None): - :math:`\mathrm{Var}(x)_j` is the weighted variance at node :math:`j`. - :math:`\mathrm{Std}(x)_j` is the weighted standard deviation at node :math:`j`. - Accumulation proceeds in topological order from the sources to the sinks. This formulation computes the population standard deviation. + Accumulation proceeds in topological order from the sources to the sinks. + This formulation computes the population standard deviation. Parameters ---------- @@ -360,7 +362,8 @@ def find(river_network, locations, overwrite=True, return_type=None): Delineates catchment areas. Given a field indicating one or more start locations (e.g., outlet points or pour points), - this function delineates the catchments upstream of each start location by grouping all cells that flow into these points. + this function delineates the catchments upstream of each start location + by grouping all cells that flow into these points. 
Parameters ---------- diff --git a/src/earthkit/hydro/data_structures/_network.py b/src/earthkit/hydro/data_structures/_network.py index d474f484..07b10cf0 100644 --- a/src/earthkit/hydro/data_structures/_network.py +++ b/src/earthkit/hydro/data_structures/_network.py @@ -66,7 +66,8 @@ def to_device(self, device=None, array_backend=None): Parameters ---------- device : str, optional - The device to which to transfer. Default is None, which is `'cpu'` for all backends except cupy, which is `'gpu'`. + The device to which to transfer. + Default is None, which is `'cpu'` for all backends except cupy, which is `'gpu'`. array_backend : str, optional The array backend. One of "numpy", "np", "cupy", "cp", "pytorch", "torch", "jax", "jnp", "tensorflow", "tf", "mlx" or "mx". @@ -77,7 +78,6 @@ def to_device(self, device=None, array_backend=None): RiverNetwork The modified RiverNetwork. """ - from earthkit.utils.array.convert import convert # TODO: use xp.asarray @@ -103,14 +103,9 @@ def to_device(self, device=None, array_backend=None): array_backend = self.array_backend if array_backend in ["torch", "cupy", "numpy"]: - self.groups = [ - convert(group, device=device, array_namespace=array_backend) - for group in self.groups - ] + self.groups = [convert(group, device=device, array_namespace=array_backend) for group in self.groups] self.mask = convert(self.mask, device=device, array_namespace=array_backend) - self.data = [ - convert(self.data[0], device=device, array_namespace=array_backend) - ] + self.data = [convert(self.data[0], device=device, array_namespace=array_backend)] elif array_backend == "jax": assert device == "cpu" import jax.numpy as jnp @@ -155,9 +150,7 @@ def set_default_return_type(self, return_type): None """ if return_type not in ["gridded", "masked"]: - raise ValueError( - f'Invalid return_type {return_type}. Valid types are "gridded", "masked"' - ) + raise ValueError(f'Invalid return_type {return_type}. 
Valid types are "gridded", "masked"') self.return_type = return_type def export(self, fpath="river_network.joblib", compression=1): diff --git a/src/earthkit/hydro/distance/_toplevel.py b/src/earthkit/hydro/distance/_toplevel.py index 51020b43..31af856e 100644 --- a/src/earthkit/hydro/distance/_toplevel.py +++ b/src/earthkit/hydro/distance/_toplevel.py @@ -31,7 +31,8 @@ def min( where: - :math:`w_{ij}` is the edge distance (e.g., downstream distance), - - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, which can include upstream and/or downstream nodes depending on passed arguments. + - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, + which can include upstream and/or downstream nodes depending on passed arguments. - :math:`d_j` is the total distance at node :math:`j`. Unreachable nodes are given a distance of :math:`\infty`. @@ -93,7 +94,8 @@ def max( where: - :math:`w_{ij}` is the edge distance (e.g., downstream distance), - - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, which can include upstream and/or downstream nodes depending on passed arguments. + - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, + which can include upstream and/or downstream nodes depending on passed arguments. - :math:`d_j` is the total distance at node :math:`j`. Unreachable nodes are given a distance of :math:`-\infty`. @@ -152,11 +154,13 @@ def to_source( where: - :math:`w_{ij}` is the edge distance (e.g., downstream distance), - - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, which can include upstream and/or downstream nodes depending on passed arguments. + - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, + which can include upstream and/or downstream nodes depending on passed arguments. 
- :math:`\bigoplus` is the aggregation function (max for longest path or min for shortest path). - :math:`d_j` is the total distance at node :math:`j`. - Unreachable nodes are given a distance of :math:`-\infty` if :math:`\bigoplus` is a maximum, and :math:`\infty` if :math:`\bigoplus` is a minimum. + Unreachable nodes are given a distance of :math:`-\infty` if :math:`\bigoplus` is a maximum, + and :math:`\infty` if :math:`\bigoplus` is a minimum. Parameters ---------- @@ -208,11 +212,13 @@ def to_sink( where: - :math:`w_{ij}` is the edge distance (e.g., downstream distance), - - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, which can include upstream and/or downstream nodes depending on passed arguments. + - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, + which can include upstream and/or downstream nodes depending on passed arguments. - :math:`\bigoplus` is the aggregation function (max for longest path or min for shortest path). - :math:`d_j` is the total distance at node :math:`j`. - Unreachable nodes are given a distance of :math:`-\infty` if :math:`\bigoplus` is a maximum, and :math:`\infty` if :math:`\bigoplus` is a minimum. + Unreachable nodes are given a distance of :math:`-\infty` if :math:`\bigoplus` is a maximum, + and :math:`\infty` if :math:`\bigoplus` is a minimum. Parameters ---------- diff --git a/src/earthkit/hydro/distance/array/_toplevel.py b/src/earthkit/hydro/distance/array/_toplevel.py index 5928c3bb..93755e28 100644 --- a/src/earthkit/hydro/distance/array/_toplevel.py +++ b/src/earthkit/hydro/distance/array/_toplevel.py @@ -28,7 +28,8 @@ def min( where: - :math:`w_{ij}` is the edge distance (e.g., downstream distance), - - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, which can include upstream and/or downstream nodes depending on passed arguments. 
+ - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, + which can include upstream and/or downstream nodes depending on passed arguments. - :math:`d_j` is the total distance at node :math:`j`. Unreachable nodes are given a distance of :math:`\infty`. @@ -91,7 +92,8 @@ def max( where: - :math:`w_{ij}` is the edge distance (e.g., downstream distance), - - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, which can include upstream and/or downstream nodes depending on passed arguments. + - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, + which can include upstream and/or downstream nodes depending on passed arguments. - :math:`d_j` is the total distance at node :math:`j`. Unreachable nodes are given a distance of :math:`-\infty`. @@ -151,11 +153,13 @@ def to_source( where: - :math:`w_{ij}` is the edge distance (e.g., downstream distance), - - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, which can include upstream and/or downstream nodes depending on passed arguments. + - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, + which can include upstream and/or downstream nodes depending on passed arguments. - :math:`\bigoplus` is the aggregation function (max for longest path or min for shortest path). - :math:`d_j` is the total distance at node :math:`j`. - Unreachable nodes are given a distance of :math:`-\infty` if :math:`\bigoplus` is a maximum, and :math:`\infty` if :math:`\bigoplus` is a minimum. + Unreachable nodes are given a distance of :math:`-\infty` if :math:`\bigoplus` is a maximum, + and :math:`\infty` if :math:`\bigoplus` is a minimum. 
Parameters ---------- @@ -202,11 +206,13 @@ def to_sink(river_network, field=None, path="shortest", return_type=None): where: - :math:`w_{ij}` is the edge distance (e.g., downstream distance), - - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, which can include upstream and/or downstream nodes depending on passed arguments. + - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, + which can include upstream and/or downstream nodes depending on passed arguments. - :math:`\bigoplus` is the aggregation function (max for longest path or min for shortest path). - :math:`d_j` is the total distance at node :math:`j`. - Unreachable nodes are given a distance of :math:`-\infty` if :math:`\bigoplus` is a maximum, and :math:`\infty` if :math:`\bigoplus` is a minimum. + Unreachable nodes are given a distance of :math:`-\infty` if :math:`\bigoplus` is a maximum, + and :math:`\infty` if :math:`\bigoplus` is a minimum. Parameters ---------- diff --git a/src/earthkit/hydro/downstream/_toplevel.py b/src/earthkit/hydro/downstream/_toplevel.py index a6a78f0c..bc6e488b 100644 --- a/src/earthkit/hydro/downstream/_toplevel.py +++ b/src/earthkit/hydro/downstream/_toplevel.py @@ -44,7 +44,8 @@ def var( - :math:`\bar{x}_j` is the weighted average at node :math:`j`, - :math:`\mathrm{Var}(x)_j` is the weighted variance at node :math:`j`. - Accumulation proceeds in inverse topological order from the sinks to the sources. This formulation computes the population variance. + Accumulation proceeds in inverse topological order from the sinks to the sources. + This formulation computes the population variance. Parameters ---------- @@ -116,7 +117,8 @@ def std( - :math:`\mathrm{Var}(x)_j` is the weighted variance at node :math:`j`. - :math:`\mathrm{Std}(x)_j` is the weighted standard deviation at node :math:`j`. - Accumulation proceeds in inverse topological order from the sinks to the sources. 
This formulation computes the population standard deviation. + Accumulation proceeds in inverse topological order from the sinks to the sources. + This formulation computes the population standard deviation. Parameters ---------- diff --git a/src/earthkit/hydro/downstream/array/_operations.py b/src/earthkit/hydro/downstream/array/_operations.py index 1c8f91f5..a03f9492 100644 --- a/src/earthkit/hydro/downstream/array/_operations.py +++ b/src/earthkit/hydro/downstream/array/_operations.py @@ -26,9 +26,7 @@ def var(xp, river_network, field, node_weights, edge_weights, return_type): return_type = river_network.return_type if return_type is None else return_type if return_type not in ["gridded", "masked"]: raise ValueError("return_type must be either 'gridded' or 'masked'.") - decorated_calculate_downstream_metric = mask(return_type == "gridded")( - calculate_downstream_metric - ) + decorated_calculate_downstream_metric = mask(return_type == "gridded")(calculate_downstream_metric) return decorated_calculate_downstream_metric( xp, river_network, @@ -44,9 +42,7 @@ def std(xp, river_network, field, node_weights, edge_weights, return_type): return_type = river_network.return_type if return_type is None else return_type if return_type not in ["gridded", "masked"]: raise ValueError("return_type must be either 'gridded' or 'masked'.") - decorated_calculate_downstream_metric = mask(return_type == "gridded")( - calculate_downstream_metric - ) + decorated_calculate_downstream_metric = mask(return_type == "gridded")(calculate_downstream_metric) return decorated_calculate_downstream_metric( xp, river_network, @@ -62,9 +58,7 @@ def mean(xp, river_network, field, node_weights, edge_weights, return_type): return_type = river_network.return_type if return_type is None else return_type if return_type not in ["gridded", "masked"]: raise ValueError("return_type must be either 'gridded' or 'masked'.") - decorated_calculate_downstream_metric = mask(return_type == "gridded")( - 
calculate_downstream_metric - ) + decorated_calculate_downstream_metric = mask(return_type == "gridded")(calculate_downstream_metric) return decorated_calculate_downstream_metric( xp, river_network, @@ -80,9 +74,7 @@ def sum(xp, river_network, field, node_weights, edge_weights, return_type): return_type = river_network.return_type if return_type is None else return_type if return_type not in ["gridded", "masked"]: raise ValueError("return_type must be either 'gridded' or 'masked'.") - decorated_calculate_downstream_metric = mask(return_type == "gridded")( - calculate_downstream_metric - ) + decorated_calculate_downstream_metric = mask(return_type == "gridded")(calculate_downstream_metric) return decorated_calculate_downstream_metric( xp, river_network, @@ -98,9 +90,7 @@ def min(xp, river_network, field, node_weights, edge_weights, return_type): return_type = river_network.return_type if return_type is None else return_type if return_type not in ["gridded", "masked"]: raise ValueError("return_type must be either 'gridded' or 'masked'.") - decorated_calculate_downstream_metric = mask(return_type == "gridded")( - calculate_downstream_metric - ) + decorated_calculate_downstream_metric = mask(return_type == "gridded")(calculate_downstream_metric) return decorated_calculate_downstream_metric( xp, river_network, @@ -116,9 +106,7 @@ def max(xp, river_network, field, node_weights, edge_weights, return_type): return_type = river_network.return_type if return_type is None else return_type if return_type not in ["gridded", "masked"]: raise ValueError("return_type must be either 'gridded' or 'masked'.") - decorated_calculate_downstream_metric = mask(return_type == "gridded")( - calculate_downstream_metric - ) + decorated_calculate_downstream_metric = mask(return_type == "gridded")(calculate_downstream_metric) return decorated_calculate_downstream_metric( xp, river_network, diff --git a/src/earthkit/hydro/downstream/array/_toplevel.py 
b/src/earthkit/hydro/downstream/array/_toplevel.py index f526d4e7..50f2fd48 100644 --- a/src/earthkit/hydro/downstream/array/_toplevel.py +++ b/src/earthkit/hydro/downstream/array/_toplevel.py @@ -35,7 +35,8 @@ def var(river_network, field, node_weights=None, edge_weights=None, return_type= - :math:`\bar{x}_j` is the weighted average at node :math:`j`, - :math:`\mathrm{Var}(x)_j` is the weighted variance at node :math:`j`. - Accumulation proceeds in inverse topological order from the sinks to the sources. This formulation computes the population variance. + Accumulation proceeds in inverse topological order from the sinks to the sources. + This formulation computes the population variance. Parameters ---------- @@ -101,7 +102,8 @@ def std(river_network, field, node_weights=None, edge_weights=None, return_type= - :math:`\mathrm{Var}(x)_j` is the weighted variance at node :math:`j`. - :math:`\mathrm{Std}(x)_j` is the weighted standard deviation at node :math:`j`. - Accumulation proceeds in inverse topological order from the sinks to the sources. This formulation computes the population standard deviation. + Accumulation proceeds in inverse topological order from the sinks to the sources. + This formulation computes the population standard deviation. Parameters ---------- diff --git a/src/earthkit/hydro/length/_toplevel.py b/src/earthkit/hydro/length/_toplevel.py index e0b11977..aba738e6 100644 --- a/src/earthkit/hydro/length/_toplevel.py +++ b/src/earthkit/hydro/length/_toplevel.py @@ -31,7 +31,8 @@ def min( where: - :math:`w_i` is the node length (e.g., pixel length), - - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, which can include upstream and/or downstream nodes depending on passed arguments. + - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, + which can include upstream and/or downstream nodes depending on passed arguments. - :math:`l_j` is the total length at node :math:`j`. 
Unreachable nodes are given a length of :math:`\infty`. @@ -93,7 +94,8 @@ def max( where: - :math:`w_i` is the node length (e.g., pixel length), - - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, which can include upstream and/or downstream nodes depending on passed arguments. + - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, + which can include upstream and/or downstream nodes depending on passed arguments. - :math:`l_j` is the total length at node :math:`j`. Unreachable nodes are given a length of :math:`-\infty`. @@ -152,11 +154,13 @@ def to_source( where: - :math:`w_i` is the node length (e.g., pixel length), - - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, which can include upstream and/or downstream nodes depending on passed arguments. + - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, + which can include upstream and/or downstream nodes depending on passed arguments. - :math:`\bigoplus` is the aggregation function (max for longest path or min for shortest path). - :math:`l_j` is the total length at node :math:`j`. - Unreachable nodes are given a distance of :math:`-\infty` if :math:`\bigoplus` is a maximum, and :math:`\infty` if :math:`\bigoplus` is a minimum. + Unreachable nodes are given a distance of :math:`-\infty` if :math:`\bigoplus` is a maximum, + and :math:`\infty` if :math:`\bigoplus` is a minimum. Parameters ---------- @@ -208,11 +212,13 @@ def to_sink( where: - :math:`w_i` is the node length (e.g., pixel length), - - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, which can include upstream and/or downstream nodes depending on passed arguments. + - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, + which can include upstream and/or downstream nodes depending on passed arguments. 
- :math:`\bigoplus` is the aggregation function (max for longest path or min for shortest path). - :math:`l_j` is the total length at node :math:`j`. - Unreachable nodes are given a distance of :math:`-\infty` if :math:`\bigoplus` is a maximum, and :math:`\infty` if :math:`\bigoplus` is a minimum. + Unreachable nodes are given a distance of :math:`-\infty` if :math:`\bigoplus` is a maximum, + and :math:`\infty` if :math:`\bigoplus` is a minimum. Parameters ---------- diff --git a/src/earthkit/hydro/length/array/_toplevel.py b/src/earthkit/hydro/length/array/_toplevel.py index 395237fe..6363c220 100644 --- a/src/earthkit/hydro/length/array/_toplevel.py +++ b/src/earthkit/hydro/length/array/_toplevel.py @@ -28,7 +28,8 @@ def min( where: - :math:`w_i` is the node length (e.g., pixel length), - - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, which can include upstream and/or downstream nodes depending on passed arguments. + - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, + which can include upstream and/or downstream nodes depending on passed arguments. - :math:`l_j` is the total length at node :math:`j`. Unreachable nodes are given a length of :math:`\infty`. @@ -91,7 +92,8 @@ def max( where: - :math:`w_i` is the node length (e.g., pixel length), - - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, which can include upstream and/or downstream nodes depending on passed arguments. + - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, + which can include upstream and/or downstream nodes depending on passed arguments. - :math:`l_j` is the total length at node :math:`j`. Unreachable nodes are given a length of :math:`-\infty`. 
@@ -151,11 +153,13 @@ def to_source( where: - :math:`w_i` is the node length (e.g., pixel length), - - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, which can include upstream and/or downstream nodes depending on passed arguments. + - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, + which can include upstream and/or downstream nodes depending on passed arguments. - :math:`\bigoplus` is the aggregation function (max for longest path or min for shortest path). - :math:`l_j` is the total length at node :math:`j`. - Unreachable nodes are given a distance of :math:`-\infty` if :math:`\bigoplus` is a maximum, and :math:`\infty` if :math:`\bigoplus` is a minimum. + Unreachable nodes are given a distance of :math:`-\infty` if :math:`\bigoplus` is a maximum, + and :math:`\infty` if :math:`\bigoplus` is a minimum. Parameters ---------- @@ -207,11 +211,13 @@ def to_sink( where: - :math:`w_i` is the node length (e.g., pixel length), - - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, which can include upstream and/or downstream nodes depending on passed arguments. + - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, + which can include upstream and/or downstream nodes depending on passed arguments. - :math:`\bigoplus` is the aggregation function (max for longest path or min for shortest path). - :math:`l_j` is the total length at node :math:`j`. - Unreachable nodes are given a distance of :math:`-\infty` if :math:`\bigoplus` is a maximum, and :math:`\infty` if :math:`\bigoplus` is a minimum. + Unreachable nodes are given a distance of :math:`-\infty` if :math:`\bigoplus` is a maximum, + and :math:`\infty` if :math:`\bigoplus` is a minimum. 
Parameters ---------- diff --git a/src/earthkit/hydro/move/_toplevel.py b/src/earthkit/hydro/move/_toplevel.py index 45b817f3..be2b932a 100644 --- a/src/earthkit/hydro/move/_toplevel.py +++ b/src/earthkit/hydro/move/_toplevel.py @@ -59,11 +59,10 @@ def upstream( Returns ------- xarray object - Array of values after movement up the river network for every river network node or gridcell, depending on `return_type`. + Array of values after movement up the river network for every river network node or gridcell, + depending on `return_type`. """ - return array.upstream( - river_network, field, node_weights, edge_weights, metric, return_type - ) + return array.upstream(river_network, field, node_weights, edge_weights, metric, return_type) @xarray @@ -123,8 +122,7 @@ def downstream( Returns ------- xarray object - Array of values after movement down the river network for every river network node or gridcell, depending on `return_type`. + Array of values after movement down the river network for every river network node or gridcell, + depending on `return_type`. 
""" - return array.downstream( - river_network, field, node_weights, edge_weights, metric, return_type - ) + return array.downstream(river_network, field, node_weights, edge_weights, metric, return_type) diff --git a/src/earthkit/hydro/move/array/_operations.py b/src/earthkit/hydro/move/array/_operations.py index 29491550..bcd82b7a 100644 --- a/src/earthkit/hydro/move/array/_operations.py +++ b/src/earthkit/hydro/move/array/_operations.py @@ -12,9 +12,7 @@ def upstream(xp, river_network, field, node_weights, edge_weights, metric, retur @multi_backend(jax_static_args=["xp", "river_network", "return_type", "metric"]) -def downstream( - xp, river_network, field, node_weights, edge_weights, metric, return_type -): +def downstream(xp, river_network, field, node_weights, edge_weights, metric, return_type): return_type = river_network.return_type if return_type is None else return_type if return_type not in ["gridded", "masked"]: raise ValueError("return_type must be either 'gridded' or 'masked'.") diff --git a/src/earthkit/hydro/move/array/_toplevel.py b/src/earthkit/hydro/move/array/_toplevel.py index beea663e..53a676c0 100644 --- a/src/earthkit/hydro/move/array/_toplevel.py +++ b/src/earthkit/hydro/move/array/_toplevel.py @@ -54,7 +54,8 @@ def upstream( Returns ------- array-like - Array of values after movement up the river network for every river network node or gridcell, depending on `return_type`. + Array of values after movement up the river network for every river network node or gridcell, + depending on `return_type`. """ return _operations.upstream( river_network=river_network, @@ -119,7 +120,8 @@ def downstream( Returns ------- array-like - Array of values after movement down the river network for every river network node or gridcell, depending on `return_type`. + Array of values after movement down the river network for every river network node or gridcell, + depending on `return_type`. 
""" return _operations.downstream( river_network=river_network, diff --git a/src/earthkit/hydro/river_network/_cache.py b/src/earthkit/hydro/river_network/_cache.py index 93f7e7c7..36a0318a 100644 --- a/src/earthkit/hydro/river_network/_cache.py +++ b/src/earthkit/hydro/river_network/_cache.py @@ -4,8 +4,8 @@ from hashlib import sha256 import joblib - from earthkit.hydro._version import __version__ as ekh_version + from earthkit.hydro.data_structures._network import RiverNetwork # read in only up to second decimal point diff --git a/src/earthkit/hydro/river_network/_river_network.py b/src/earthkit/hydro/river_network/_river_network.py index 175ca828..262bc553 100644 --- a/src/earthkit/hydro/river_network/_river_network.py +++ b/src/earthkit/hydro/river_network/_river_network.py @@ -2,6 +2,7 @@ from urllib.request import urlopen import joblib +from earthkit.hydro._version import __version__ as ekh_version from earthkit.hydro._readers import ( find_main_var, @@ -12,7 +13,6 @@ ) from earthkit.hydro._utils.coords import get_core_grid_dims from earthkit.hydro._utils.readers import from_file -from earthkit.hydro._version import __version__ as ekh_version from earthkit.hydro.data_structures._network import RiverNetwork from ._cache import cache @@ -21,11 +21,7 @@ # if dev version, try add +1 to major version # i.e. 0.1.dev90+gfdf4e33.d20250107 -> 1 # i.e. 0.1.0 -> 0 -ekh_version = ( - int(ekh_version.split(".")[0]) + 1 - if "dev" in ekh_version - else int(ekh_version.split(".")[0]) -) +ekh_version = int(ekh_version.split(".")[0]) + 1 if "dev" in ekh_version else int(ekh_version.split(".")[0]) @cache @@ -57,7 +53,8 @@ def create( use_cache : bool, optional Whether to cache the loaded/created river network for quicker reloading. Default is True. cache_dir : str, optional - Where to store the cached river networks. Default is None, which uses `tempfile.mkdtemp(suffix="_earthkit_hydro")`. + Where to store the cached river networks. 
+ Default is None, which uses `tempfile.mkdtemp(suffix="_earthkit_hydro")`. cache_fname : str, optional A string template for the cache filename convention. cache_compression : int, optional @@ -75,10 +72,7 @@ def create( with urlopen(path) as response: river_network_storage = joblib.load(BytesIO(response.read())) else: - raise ValueError( - "Unsupported source for river network format" - f"{river_network_format}: {source}." - ) + raise ValueError(f"Unsupported source for river network format{river_network_format}: {source}.") elif river_network_format == "cama": ekd = import_earthkit_or_prompt_install(river_network_format, source) data = ekd.from_source(source, path).to_xarray(mask_and_scale=False) @@ -89,25 +83,17 @@ def create( coord1: data[coord1].values, coord2: data[coord2].values, } - elif ( - river_network_format == "pcr_d8" - or river_network_format == "esri_d8" - or river_network_format == "merit_d8" - ): + elif river_network_format == "pcr_d8" or river_network_format == "esri_d8" or river_network_format == "merit_d8": if path.endswith(".map"): data = from_file(path, mask=False) - river_network_storage = from_d8( - data, river_network_format=river_network_format - ) + river_network_storage = from_d8(data, river_network_format=river_network_format) # coords not available else: ekd = import_earthkit_or_prompt_install(river_network_format, source) data = ekd.from_source(source, path).to_xarray(mask_and_scale=False) coord1, coord2 = get_core_grid_dims(data) var_name = find_main_var(data) - river_network_storage = from_d8( - data[var_name].values, river_network_format=river_network_format - ) + river_network_storage = from_d8(data[var_name].values, river_network_format=river_network_format) river_network_storage.coords = { coord1: data[coord1].values, coord2: data[coord2].values, @@ -188,13 +174,27 @@ def load( References ---------- - .. 
[1] Choulga, Margarita; Moschini, Francesca; Mazzetti, Cinzia; Grimaldi, Stefania; Disperati, Juliana; Beck, Hylke; Salamon, Peter; Prudhomme, Christel (2023): LISFLOOD static and parameter maps for Europe. European Commission, Joint Research Centre (JRC) [Dataset] PID: http://data.europa.eu/89h/f572c443-7466-4adf-87aa-c0847a169f23 - .. [2] Choulga, Margarita; Moschini, Francesca; Mazzetti, Cinzia; Disperati, Juliana; Grimaldi, Stefania; Beck, Hylke; Salamon, Peter; Prudhomme, Christel (2023): LISFLOOD static and parameter maps for GloFAS. European Commission, Joint Research Centre (JRC) [Dataset] PID: http://data.europa.eu/89h/68050d73-9c06-499c-a441-dc5053cb0c86 - .. [3] Yamazaki, Dai; Ikeshima, Daiki; Sosa, Jeison; Bates, Paul D.; Allen, George H.; Pavelsky, Tamlin M. (2019): MERIT Hydro: A high-resolution global hydrography map based on latest topography datasets. Water Resources Research, vol.55, pp.5053-5073, 2019, DOI: 10.1029/2019WR024873 - .. [4] Lehner, Bernhard; Verdin, Kristine; Jarvis, Andy (2008): New global hydrography derived from spaceborne elevation data. Eos, Transactions, 89(10): 93-94. Data available at https://www.hydrosheds.org. - .. [5] Wortmann, Michel; Slater, Louise; Hawker, Laurence; Liu, Yinxue; Neal, Jeffrey; Zhang, Boen; Schwenk, Jon; Allen, George H.; Ashworth, Philip; Boothroyd, Richard; Cloke, Hannah; Delorme, Pauline; Gebrechorkos, Solomon H.; Griffith, Helen; Leyland, Julian; McLelland, Stuart; Nicholas, Andrew P.; Sambrook-Smith, Gregory; Vahidi, Elham; Parsons, Daniel; Darby, Stephen E. (2025). Global River Topology (GRIT): A bifurcating river hydrography. Water Resources Research, 61(5), DOI: 10.1029/2024WR038308 + .. [1] Choulga, Margarita; Moschini, Francesca; Mazzetti, Cinzia; Grimaldi, Stefania; Disperati, Juliana; + Beck, Hylke; Salamon, Peter; Prudhomme, Christel (2023): LISFLOOD static and parameter maps for Europe. 
+ European Commission, Joint Research Centre (JRC) [Dataset] + PID: http://data.europa.eu/89h/f572c443-7466-4adf-87aa-c0847a169f23 + .. [2] Choulga, Margarita; Moschini, Francesca; Mazzetti, Cinzia; Disperati, Juliana; Grimaldi, Stefania; + Beck, Hylke; Salamon, Peter; Prudhomme, Christel (2023): LISFLOOD static and parameter maps for GloFAS. + European Commission, Joint Research Centre (JRC) [Dataset] + PID: http://data.europa.eu/89h/68050d73-9c06-499c-a441-dc5053cb0c86 + .. [3] Yamazaki, Dai; Ikeshima, Daiki; Sosa, Jeison; Bates, Paul D.; Allen, George H.; Pavelsky, Tamlin M. + (2019): MERIT Hydro: A high-resolution global hydrography map based on latest topography datasets. + Water Resources Research, vol.55, pp.5053-5073, 2019, DOI: 10.1029/2019WR024873 + .. [4] Lehner, Bernhard; Verdin, Kristine; Jarvis, Andy (2008): New global hydrography derived from + spaceborne elevation data. Eos, Transactions, 89(10): 93-94. + Data available at https://www.hydrosheds.org. + .. [5] Wortmann, Michel; Slater, Louise; Hawker, Laurence; Liu, Yinxue; Neal, Jeffrey; Zhang, Boen; + Schwenk, Jon; Allen, George H.; Ashworth, Philip; Boothroyd, Richard; Cloke, Hannah; Delorme, + Pauline; Gebrechorkos, Solomon H.; Griffith, Helen; Leyland, Julian; McLelland, Stuart; Nicholas, + Andrew P.; Sambrook-Smith, Gregory; Vahidi, Elham; Parsons, Daniel; Darby, Stephen E. (2025). + Global River Topology (GRIT): A bifurcating river hydrography. + Water Resources Research, 61(5), DOI: 10.1029/2024WR038308 """ - try: uri = data_source.format( ekh_version=ekh_version, @@ -224,7 +224,6 @@ def available( data_source : str, optional Base URI to read available networks from. 
""" - with urlopen(data_source) as response: html = response.read() diff --git a/src/earthkit/hydro/streamorder/array/_operations.py b/src/earthkit/hydro/streamorder/array/_operations.py index 05c315e2..1cd8a1ba 100644 --- a/src/earthkit/hydro/streamorder/array/_operations.py +++ b/src/earthkit/hydro/streamorder/array/_operations.py @@ -16,9 +16,7 @@ def _ufunc_strahler( maxes = xp.scatter_max(maxes, did, up_maxes) maxes_did = xp.gather(maxes, did) counts_uid = xp.gather(counts, uid) - counts = xp.scatter_assign( - counts, did, (old_maxes == maxes_did).astype(int) * counts_uid - ) + counts = xp.scatter_assign(counts, did, (old_maxes == maxes_did).astype(int) * counts_uid) counts = xp.scatter_add(counts, did, (up_maxes == maxes_did).astype(int)) counts_did = xp.gather(counts, did) maxes = xp.scatter_assign(maxes, did, maxes_did + (counts_did > 1).astype(int)) @@ -63,9 +61,7 @@ def operation( def strahler(xp, river_network, return_type): field = xp.zeros(river_network.n_nodes, dtype=float) - field = xp.scatter_assign( - field, river_network.sources, xp.ones(river_network.sources.shape, dtype=float) - ) + field = xp.scatter_assign(field, river_network.sources, xp.ones(river_network.sources.shape, dtype=float)) counts = xp.zeros(river_network.n_nodes, dtype=float) decorated_func = mask(return_type == "gridded")(flow_strahler) @@ -75,9 +71,7 @@ def strahler(xp, river_network, return_type): @multi_backend(jax_static_args=["xp", "river_network", "return_type"]) def shreve(xp, river_network, return_type): field = xp.zeros(river_network.n_nodes, dtype=float) - field = xp.scatter_assign( - field, river_network.sources, xp.ones(river_network.sources.shape, dtype=float) - ) + field = xp.scatter_assign(field, river_network.sources, xp.ones(river_network.sources.shape, dtype=float)) return upstream_sum( river_network=river_network, diff --git a/src/earthkit/hydro/subnetwork/_toplevel.py b/src/earthkit/hydro/subnetwork/_toplevel.py index f73b853a..be563a7b 100644 --- 
a/src/earthkit/hydro/subnetwork/_toplevel.py +++ b/src/earthkit/hydro/subnetwork/_toplevel.py @@ -35,9 +35,7 @@ def from_mask(river_network: RiverNetwork, node_mask=None, edge_mask=None, copy= if node_mask is not None: if node_mask.shape[-2:] == river_network.shape: - node_mask = mask_last2_dims( - np, node_mask, river_network.mask, node_mask.shape - ) + node_mask = mask_last2_dims(np, node_mask, river_network.mask, node_mask.shape) node_relabel = np.empty(river_network.n_nodes, dtype=int) node_relabel[node_mask] = np.arange(node_mask.sum()) @@ -48,9 +46,7 @@ def from_mask(river_network: RiverNetwork, node_mask=None, edge_mask=None, copy= node_mask[storage.sorted_data[0]] & node_mask[storage.sorted_data[1]] ) elif edge_mask is None: - valid_edges = ( - node_mask[storage.sorted_data[0]] & node_mask[storage.sorted_data[1]] - ) + valid_edges = node_mask[storage.sorted_data[0]] & node_mask[storage.sorted_data[1]] else: valid_edges = edge_mask[storage.sorted_data[2]] @@ -88,7 +84,6 @@ def crop(river_network: RiverNetwork, copy=True): RiverNetwork The river network object created from the given data. """ - if river_network.array_backend != "numpy" or copy is not True: raise NotImplementedError @@ -101,9 +96,7 @@ def crop(river_network: RiverNetwork, copy=True): storage.shape = (int(row_max - row_min + 1), int(col_max - col_min + 1)) - storage.mask = np.ravel_multi_index( - (rows - row_min, cols - col_min), dims=storage.shape - ) + storage.mask = np.ravel_multi_index((rows - row_min, cols - col_min), dims=storage.shape) for i, key in enumerate(storage.coords.keys()): if i == 0: diff --git a/src/earthkit/hydro/upstream/_toplevel.py b/src/earthkit/hydro/upstream/_toplevel.py index ce4aa0f4..fb67b1a5 100644 --- a/src/earthkit/hydro/upstream/_toplevel.py +++ b/src/earthkit/hydro/upstream/_toplevel.py @@ -44,7 +44,8 @@ def var( - :math:`\bar{x}_j` is the weighted average at node :math:`j`, - :math:`\mathrm{Var}(x)_j` is the weighted variance at node :math:`j`. 
- Accumulation proceeds in topological order from the sources to the sinks. This formulation computes the population variance. + Accumulation proceeds in topological order from the sources to the sinks. + This formulation computes the population variance. Parameters ---------- @@ -116,7 +117,8 @@ def std( - :math:`\mathrm{Var}(x)_j` is the weighted variance at node :math:`j`. - :math:`\mathrm{Std}(x)_j` is the weighted standard deviation at node :math:`j`. - Accumulation proceeds in topological order from the sources to the sinks. This formulation computes the population standard deviation. + Accumulation proceeds in topological order from the sources to the sinks. + This formulation computes the population standard deviation. Parameters ---------- diff --git a/src/earthkit/hydro/upstream/array/_operations.py b/src/earthkit/hydro/upstream/array/_operations.py index 7f563f6a..df1d2760 100644 --- a/src/earthkit/hydro/upstream/array/_operations.py +++ b/src/earthkit/hydro/upstream/array/_operations.py @@ -26,9 +26,7 @@ def var(xp, river_network, field, node_weights, edge_weights, return_type): return_type = river_network.return_type if return_type is None else return_type if return_type not in ["gridded", "masked"]: raise ValueError("return_type must be either 'gridded' or 'masked'.") - decorated_calculate_upstream_metric = mask(return_type == "gridded")( - calculate_upstream_metric - ) + decorated_calculate_upstream_metric = mask(return_type == "gridded")(calculate_upstream_metric) return decorated_calculate_upstream_metric( xp, river_network, @@ -51,9 +49,7 @@ def std( return_type = river_network.return_type if return_type is None else return_type if return_type not in ["gridded", "masked"]: raise ValueError("return_type must be either 'gridded' or 'masked'.") - decorated_calculate_upstream_metric = mask(return_type == "gridded")( - calculate_upstream_metric - ) + decorated_calculate_upstream_metric = mask(return_type == "gridded")(calculate_upstream_metric) return 
decorated_calculate_upstream_metric( xp, river_network, @@ -69,9 +65,7 @@ def mean(xp, river_network, field, node_weights, edge_weights, return_type): return_type = river_network.return_type if return_type is None else return_type if return_type not in ["gridded", "masked"]: raise ValueError("return_type must be either 'gridded' or 'masked'.") - decorated_calculate_upstream_metric = mask(return_type == "gridded")( - calculate_upstream_metric - ) + decorated_calculate_upstream_metric = mask(return_type == "gridded")(calculate_upstream_metric) return decorated_calculate_upstream_metric( xp, river_network, @@ -87,9 +81,7 @@ def sum(xp, river_network, field, node_weights, edge_weights, return_type): return_type = river_network.return_type if return_type is None else return_type if return_type not in ["gridded", "masked"]: raise ValueError("return_type must be either 'gridded' or 'masked'.") - decorated_calculate_upstream_metric = mask(return_type == "gridded")( - calculate_upstream_metric - ) + decorated_calculate_upstream_metric = mask(return_type == "gridded")(calculate_upstream_metric) return decorated_calculate_upstream_metric( xp, river_network, @@ -105,9 +97,7 @@ def min(xp, river_network, field, node_weights, edge_weights, return_type): return_type = river_network.return_type if return_type is None else return_type if return_type not in ["gridded", "masked"]: raise ValueError("return_type must be either 'gridded' or 'masked'.") - decorated_calculate_upstream_metric = mask(return_type == "gridded")( - calculate_upstream_metric - ) + decorated_calculate_upstream_metric = mask(return_type == "gridded")(calculate_upstream_metric) return decorated_calculate_upstream_metric( xp, river_network, @@ -123,9 +113,7 @@ def max(xp, river_network, field, node_weights, edge_weights, return_type): return_type = river_network.return_type if return_type is None else return_type if return_type not in ["gridded", "masked"]: raise ValueError("return_type must be either 'gridded' or 
'masked'.") - decorated_calculate_upstream_metric = mask(return_type == "gridded")( - calculate_upstream_metric - ) + decorated_calculate_upstream_metric = mask(return_type == "gridded")(calculate_upstream_metric) return decorated_calculate_upstream_metric( xp, river_network, diff --git a/src/earthkit/hydro/upstream/array/_toplevel.py b/src/earthkit/hydro/upstream/array/_toplevel.py index 48273e1f..4e73a0cb 100644 --- a/src/earthkit/hydro/upstream/array/_toplevel.py +++ b/src/earthkit/hydro/upstream/array/_toplevel.py @@ -35,7 +35,8 @@ def var(river_network, field, node_weights=None, edge_weights=None, return_type= - :math:`\bar{x}_j` is the weighted average at node :math:`j`, - :math:`\mathrm{Var}(x)_j` is the weighted variance at node :math:`j`. - Accumulation proceeds in topological order from the sources to the sinks. This formulation computes the population variance. + Accumulation proceeds in topological order from the sources to the sinks. + This formulation computes the population variance. Parameters ---------- @@ -101,7 +102,8 @@ def std(river_network, field, node_weights=None, edge_weights=None, return_type= - :math:`\mathrm{Var}(x)_j` is the weighted variance at node :math:`j`. - :math:`\mathrm{Std}(x)_j` is the weighted standard deviation at node :math:`j`. - Accumulation proceeds in topological order from the sources to the sinks. This formulation computes the population standard deviation. + Accumulation proceeds in topological order from the sources to the sinks. + This formulation computes the population standard deviation. 
Parameters ---------- diff --git a/tests/_test_inputs/accumulation.py b/tests/_test_inputs/accumulation.py index 9add82a6..e8fe2dba 100644 --- a/tests/_test_inputs/accumulation.py +++ b/tests/_test_inputs/accumulation.py @@ -3,50 +3,30 @@ # RIVER NETWORK ONE # 1a: unit field input -input_field_1a = np.array( - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int -) +input_field_1a = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int) mv_1a = np.iinfo(np.int64).max -upstream_metric_sum_1a = np.array( - [1, 1, 1, 1, 1, 2, 2, 3, 2, 1, 3, 3, 9, 3, 1, 1, 20, 3, 2, 1], dtype=int -) +upstream_metric_sum_1a = np.array([1, 1, 1, 1, 1, 2, 2, 3, 2, 1, 3, 3, 9, 3, 1, 1, 20, 3, 2, 1], dtype=int) -upstream_metric_max_1a = np.array( - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int -) +upstream_metric_max_1a = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int) -upstream_metric_min_1a = np.array( - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int -) +upstream_metric_min_1a = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int) -upstream_metric_mean_1a = np.array( - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=float -) +upstream_metric_mean_1a = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=float) -upstream_metric_product_1a = np.array( - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int -) +upstream_metric_product_1a = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int) # 1b: non-missing integer field input -input_field_1b = np.array( - [1, 2, 3, -1, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, -1], dtype=int -) +input_field_1b = np.array([1, 2, 3, -1, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, -1], dtype=int) mv_1b = np.iinfo(np.int64).max -upstream_metric_sum_1b = np.array( - [1, 2, 3, -1, 5, 7, 9, 10, 14, 10, 8, 11, 
46, 19, 5, 6, 94, 16, 8, -1], dtype=int -) +upstream_metric_sum_1b = np.array([1, 2, 3, -1, 5, 7, 9, 10, 14, 10, 8, 11, 46, 19, 5, 6, 94, 16, 8, -1], dtype=int) -upstream_metric_max_1b = np.array( - [1, 2, 3, -1, 5, 6, 7, 8, 9, 10, 6, 7, 10, 10, 5, 6, 10, 9, 9, -1], dtype=int -) +upstream_metric_max_1b = np.array([1, 2, 3, -1, 5, 6, 7, 8, 9, 10, 6, 7, 10, 10, 5, 6, 10, 9, 9, -1], dtype=int) -upstream_metric_min_1b = np.array( - [1, 2, 3, -1, 5, 1, 2, -1, 5, 10, 1, 2, -1, 4, 5, 6, -1, -1, -1, -1], dtype=int -) +upstream_metric_min_1b = np.array([1, 2, 3, -1, 5, 1, 2, -1, 5, 10, 1, 2, -1, 4, 5, 6, -1, -1, -1, -1], dtype=int) upstream_metric_mean_1b = np.array( [ @@ -498,23 +478,15 @@ # 1g: missing integer field input with mv=-1 -input_field_1g = np.array( - [1, 2, 3, -1, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, -1], dtype=int -) +input_field_1g = np.array([1, 2, 3, -1, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, -1], dtype=int) mv_1g = -1 -upstream_metric_sum_1g = np.array( - [1, 2, 3, -1, 5, 7, 9, -1, 14, 10, 8, 11, -1, 19, 5, 6, -1, -1, -1, -1], dtype=int -) +upstream_metric_sum_1g = np.array([1, 2, 3, -1, 5, 7, 9, -1, 14, 10, 8, 11, -1, 19, 5, 6, -1, -1, -1, -1], dtype=int) -upstream_metric_max_1g = np.array( - [1, 2, 3, -1, 5, 6, 7, -1, 9, 10, 6, 7, -1, 10, 5, 6, -1, -1, -1, -1], dtype=int -) +upstream_metric_max_1g = np.array([1, 2, 3, -1, 5, 6, 7, -1, 9, 10, 6, 7, -1, 10, 5, 6, -1, -1, -1, -1], dtype=int) -upstream_metric_min_1g = np.array( - [1, 2, 3, -1, 5, 1, 2, -1, 5, 10, 1, 2, -1, 4, 5, 6, -1, -1, -1, -1], dtype=int -) +upstream_metric_min_1g = np.array([1, 2, 3, -1, 5, 1, 2, -1, 5, 10, 1, 2, -1, 4, 5, 6, -1, -1, -1, -1], dtype=int) upstream_metric_mean_1g = np.array( [ @@ -553,44 +525,26 @@ mv_2a = np.iinfo(np.int64).max -upstream_metric_sum_2a = np.array( - [2, 1, 2, 1, 1, 2, 7, 3, 1, 1, 10, 6, 1, 13, 1, 2], dtype=int -) +upstream_metric_sum_2a = np.array([2, 1, 2, 1, 1, 2, 7, 3, 1, 1, 10, 6, 1, 13, 1, 2], dtype=int) -upstream_metric_max_2a = 
np.array( - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int -) +upstream_metric_max_2a = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int) -upstream_metric_min_2a = np.array( - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int -) +upstream_metric_min_2a = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int) -upstream_metric_mean_2a = np.array( - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=float -) +upstream_metric_mean_2a = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=float) -upstream_metric_product_2a = np.array( - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int -) +upstream_metric_product_2a = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int) # 2b: non-missing integer field input -input_field_2b = np.array( - [1, 2, -1, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1, 14, 15, 16], dtype=int -) +input_field_2b = np.array([1, 2, -1, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1, 14, 15, 16], dtype=int) mv_2b = np.iinfo(np.int64).max -upstream_metric_sum_2b = np.array( - [0, 2, 1, 4, 5, 11, 59, 9, 9, 10, 81, 52, -1, 114, 15, 31], dtype=int -) +upstream_metric_sum_2b = np.array([0, 2, 1, 4, 5, 11, 59, 9, 9, 10, 81, 52, -1, 114, 15, 31], dtype=int) -upstream_metric_max_2b = np.array( - [1, 2, 2, 4, 5, 6, 16, 8, 9, 10, 16, 16, -1, 16, 15, 16], dtype=int -) +upstream_metric_max_2b = np.array([1, 2, 2, 4, 5, 6, 16, 8, 9, 10, 16, 16, -1, 16, 15, 16], dtype=int) -upstream_metric_min_2b = np.array( - [-1, 2, -1, 4, 5, 5, -1, -1, 9, 10, -1, -1, -1, -1, 15, 15], dtype=int -) +upstream_metric_min_2b = np.array([-1, 2, -1, 4, 5, 5, -1, -1, 9, 10, -1, -1, -1, -1, 15, 15], dtype=int) upstream_metric_mean_2b = np.array( [ @@ -646,12 +600,8 @@ # 2g: missing integer field input with mv=-1 -input_field_2g = np.array( - [1, 2, -1, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1, 14, 15, 16], dtype=int -) +input_field_2g = np.array([1, 2, -1, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1, 14, 15, 16], dtype=int) mv_2g = 
-1 -upstream_metric_sum_2g = np.array( - [-1, 2, -1, 4, 5, 11, -1, -1, 9, 10, -1, -1, -1, -1, 15, 31], dtype=int -) +upstream_metric_sum_2g = np.array([-1, 2, -1, 4, 5, 11, -1, -1, 9, 10, -1, -1, -1, -1, 15, 31], dtype=int) diff --git a/tests/_test_inputs/catchment.py b/tests/_test_inputs/catchment.py index 26e749f4..55a3f9e2 100644 --- a/tests/_test_inputs/catchment.py +++ b/tests/_test_inputs/catchment.py @@ -1,13 +1,13 @@ import numpy as np # catchment_query_field_1 = np.array( -# [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, 1, np.nan, 5, 4, 2, 3, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan], dtype="int" +# [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, 1, np.nan, 5, 4, 2, 3, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan], dtype="int" # noqa: E501 # ) catchment_query_field_1 = [8, 12, 13, 11, 10] # catchment_query_field_2 = np.array( -# [4, np.nan, 1, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, 3, np.nan, 2, np.nan, np.nan], dtype="int" +# [4, np.nan, 1, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, 3, np.nan, 2, np.nan, np.nan], dtype="int" # noqa: E501 # ) catchment_query_field_2 = [2, 13, 11, 0] diff --git a/tests/_test_inputs/movement.py b/tests/_test_inputs/movement.py index b808fd07..cf60a319 100644 --- a/tests/_test_inputs/movement.py +++ b/tests/_test_inputs/movement.py @@ -1,13 +1,9 @@ import numpy as np -upstream_1 = np.array( - [0, 0, 0, 0, 0, 1, 2, 7, 5, 0, 6, 7, 31, 25, 0, 0, 70, 19, 20, 0], dtype=float -) +upstream_1 = np.array([0, 0, 0, 0, 0, 1, 2, 7, 5, 0, 6, 7, 31, 25, 0, 0, 70, 19, 20, 0], dtype=float) -upstream_2 = np.array( - [13, 0, 2, 0, 0, 5, 12, 3, 0, 0, 13, 24, 0, 30, 0, 15], dtype=float -) +upstream_2 = np.array([13, 0, 2, 0, 0, 5, 12, 3, 0, 0, 13, 24, 0, 30, 0, 15], dtype=float) downstream_1 = np.array( @@ -16,6 +12,4 @@ ) -downstream_2 = np.array( - [0, 3, 8, 0, 6, 11, 11, 12, 14, 14, 14, 7, 1, 0, 16, 12], dtype=float -) +downstream_2 = np.array([0, 3, 8, 
0, 6, 11, 11, 12, 14, 14, 14, 7, 1, 0, 16, 12], dtype=float) diff --git a/tests/catchments/array/test_find.py b/tests/catchments/array/test_find.py index 57abef8b..771f0891 100644 --- a/tests/catchments/array/test_find.py +++ b/tests/catchments/array/test_find.py @@ -21,12 +21,8 @@ def test_find_catchments_2d(river_network, query_field, find_catchments): # field = np.zeros(river_network.mask.shape, dtype="int") # field[river_network.mask] = query_field - network_find_catchments = ekh.catchments.array.find( - river_network, locations=query_field - ) + network_find_catchments = ekh.catchments.array.find(river_network, locations=query_field) print(find_catchments) print(network_find_catchments) - np.testing.assert_array_equal( - network_find_catchments.flat[river_network.mask], find_catchments - ) + np.testing.assert_array_equal(network_find_catchments.flat[river_network.mask], find_catchments) # np.testing.assert_array_equal(network_find_catchments[~river_network.mask], 0) diff --git a/tests/distance/array/test_max.py b/tests/distance/array/test_max.py index 9496b17d..6b6b32c8 100644 --- a/tests/distance/array/test_max.py +++ b/tests/distance/array/test_max.py @@ -28,9 +28,7 @@ ], indirect=["river_network"], ) -def test_distance_max( - river_network, stations_list, upstream, downstream, weights, result -): +def test_distance_max(river_network, stations_list, upstream, downstream, weights, result): dist = ekh.distance.array.max( river_network, stations_list, diff --git a/tests/distance/array/test_min.py b/tests/distance/array/test_min.py index 74f3ef31..720677f1 100644 --- a/tests/distance/array/test_min.py +++ b/tests/distance/array/test_min.py @@ -36,9 +36,7 @@ ], indirect=["river_network"], ) -def test_distance_min( - river_network, stations_list, upstream, downstream, weights, result -): +def test_distance_min(river_network, stations_list, upstream, downstream, weights, result): dist = ekh.distance.array.min( river_network, stations_list, diff --git 
a/tests/length/array/test_max.py b/tests/length/array/test_max.py index 41e9860f..30c7e6ba 100644 --- a/tests/length/array/test_max.py +++ b/tests/length/array/test_max.py @@ -28,9 +28,7 @@ ], indirect=["river_network"], ) -def test_length_max( - river_network, stations_list, upstream, downstream, weights, result -): +def test_length_max(river_network, stations_list, upstream, downstream, weights, result): dist = ekh.length.array.max( river_network, stations_list, diff --git a/tests/length/array/test_min.py b/tests/length/array/test_min.py index a491a762..00c98dcc 100644 --- a/tests/length/array/test_min.py +++ b/tests/length/array/test_min.py @@ -36,9 +36,7 @@ ], indirect=["river_network"], ) -def test_length_min( - river_network, stations_list, upstream, downstream, weights, result -): +def test_length_min(river_network, stations_list, upstream, downstream, weights, result): dist = ekh.length.array.min( river_network, stations_list, diff --git a/tests/upstream/array/test_max.py b/tests/upstream/array/test_max.py index bcc7c374..7bec4517 100644 --- a/tests/upstream/array/test_max.py +++ b/tests/upstream/array/test_max.py @@ -25,9 +25,7 @@ indirect=["river_network"], ) def test_calculate_upstream_metric_max(river_network, input_field, flow_downstream, mv): - output_field = ekh.upstream.array.max( - river_network, input_field, node_weights=None, return_type="masked" - ) + output_field = ekh.upstream.array.max(river_network, input_field, node_weights=None, return_type="masked") print(output_field) print(flow_downstream) assert output_field.dtype == flow_downstream.dtype diff --git a/tests/upstream/array/test_mean.py b/tests/upstream/array/test_mean.py index f9c78694..e7a8c9ce 100644 --- a/tests/upstream/array/test_mean.py +++ b/tests/upstream/array/test_mean.py @@ -25,12 +25,8 @@ ], indirect=["river_network"], ) -def test_calculate_upstream_metric_mean( - river_network, input_field, flow_downstream, mv -): - output_field = ekh.upstream.array.mean( - river_network, 
input_field, node_weights=None, return_type="masked" - ) +def test_calculate_upstream_metric_mean(river_network, input_field, flow_downstream, mv): + output_field = ekh.upstream.array.mean(river_network, input_field, node_weights=None, return_type="masked") assert output_field.dtype == flow_downstream.dtype np.testing.assert_allclose(output_field, flow_downstream) diff --git a/tests/upstream/array/test_min.py b/tests/upstream/array/test_min.py index dda15b41..ba1d7542 100644 --- a/tests/upstream/array/test_min.py +++ b/tests/upstream/array/test_min.py @@ -25,9 +25,7 @@ indirect=["river_network"], ) def test_calculate_upstream_metric_min(river_network, input_field, flow_downstream, mv): - output_field = ekh.upstream.array.min( - river_network, input_field, node_weights=None, return_type="masked" - ) + output_field = ekh.upstream.array.min(river_network, input_field, node_weights=None, return_type="masked") print(output_field) print(flow_downstream) assert output_field.dtype == flow_downstream.dtype diff --git a/tests/upstream/array/test_sum.py b/tests/upstream/array/test_sum.py index 125fda0f..e456f4a6 100644 --- a/tests/upstream/array/test_sum.py +++ b/tests/upstream/array/test_sum.py @@ -26,9 +26,7 @@ indirect=["river_network"], ) @pytest.mark.parametrize("array_backend", ["numpy", "torch", "jax"]) -def test_upstream_metric_sum( - river_network, input_field, flow_downstream, mv, array_backend -): +def test_upstream_metric_sum(river_network, input_field, flow_downstream, mv, array_backend): river_network = river_network.to_device("cpu", array_backend) xp = ekh._backends.find.get_array_backend(array_backend) output_field = ekh.upstream.array.sum( @@ -46,9 +44,7 @@ def test_upstream_metric_sum( flow_downstream = convert_to_2d(river_network, flow_downstream, 0) print(mv, input_field.dtype) print(input_field, flow_downstream) - output_field = ekh.upstream.array.sum( - river_network, xp.asarray(input_field), node_weights=None - ) + output_field = 
ekh.upstream.array.sum(river_network, xp.asarray(input_field), node_weights=None) output_field = np.asarray(output_field).flatten() flow_downstream = np.asarray(xp.asarray(flow_downstream)) print(output_field)