Merged
Changes from 10 commits
Commits
21 commits
adf8c64
add optional dependency group with gribjump dependencies
andreas-grafberger Oct 16, 2025
93c606a
chore: min ekh version
Oisin-M Oct 16, 2025
8385bae
feat: add opt_1d_index to outputted station mapping df
Oisin-M Oct 16, 2025
8943050
feat: first basic implementation using gribjump for timeseries extrac…
Oisin-M Oct 16, 2025
fc04787
fix: update returned xr metadata
Oisin-M Oct 16, 2025
3b7f8d2
fix: add station as dimension
Oisin-M Oct 16, 2025
b481a3e
refactor: same handling of ekd config for extract_timeseries and comp…
Oisin-M Nov 3, 2025
f25c4aa
feat: update notebooks with new config parsing
Oisin-M Nov 3, 2025
c2b6081
refactor: share load_da between extract_timeseries and compute_hydros…
Oisin-M Nov 3, 2025
b1dbb9b
fix: remove duplicate ekd.from_source call in load_da
andreas-grafberger Nov 4, 2025
d5f561a
opt: pass ranges instead of indices to gribjump source
andreas-grafberger Nov 5, 2025
0ccd02e
update earthkit-data minimum version to 0.18.0
andreas-grafberger Nov 13, 2025
37fd63d
test(extract_timeseries): add basic unit tests for extractor (non-gri…
andreas-grafberger Nov 13, 2025
8a6973c
refactor(extractor): simplify control flow and improve readability
andreas-grafberger Nov 13, 2025
f1c6953
feat(extractor): add type hints to some functions
andreas-grafberger Nov 13, 2025
7db064f
chore: add disclaimer to gribjumplib in pyproject.toml
andreas-grafberger Nov 17, 2025
d5dda4d
chore: add installation instructions for experimental gribjump extras
andreas-grafberger Nov 17, 2025
83c861b
chore: minor cosmetic changes like comments
andreas-grafberger Nov 19, 2025
11470b4
chore: bump minimum earthkit-data version to 0.18.2
andreas-grafberger Nov 19, 2025
d413ccb
Change ProgressBar import from dask
andreas-grafberger Nov 19, 2025
4d171bb
remove gribjumplib as a dependency
andreas-grafberger Nov 25, 2025
14 changes: 3 additions & 11 deletions hat/compute_hydrostats/stat_calc.py
@@ -1,17 +1,9 @@
-import earthkit.data as ekd
-from earthkit.hydro._readers import find_main_var
+from hat.core import load_da
 import numpy as np
 import xarray as xr
 from hat.compute_hydrostats import stats
 
 
-def load_da(ds_config):
-    ds = ekd.from_source(*ds_config["source"]).to_xarray()
-    var_name = find_main_var(ds, 2)
-    da = ds[var_name]
-    return da
-
-
 def find_valid_subset(sim_da, obs_da, sim_coords, obs_coords, new_coords):
     sim_station_colname = sim_coords.get("s", "station")
     obs_station_colname = obs_coords.get("s", "station")
@@ -35,9 +27,9 @@ def find_valid_subset(sim_da, obs_da, sim_coords, obs_coords, new_coords):
 
 def stat_calc(config):
     sim_config = config["sim"]
-    sim_da = load_da(config["sim"])
+    sim_da, _ = load_da(sim_config, 2)
     obs_config = config["obs"]
-    obs_da = load_da(obs_config)
+    obs_da, _ = load_da(obs_config, 2)
     new_coords = config["output"]["coords"]
     sim_da, obs_da = find_valid_subset(sim_da, obs_da, sim_config["coords"], obs_config["coords"], new_coords)
     stat_dict = {}
10 changes: 10 additions & 0 deletions hat/core.py
@@ -0,0 +1,10 @@
+import earthkit.data as ekd
+from earthkit.hydro._readers import find_main_var
+
+
+def load_da(ds_config, n_dims):
+    src_name = list(ds_config["source"].keys())[0]
+    ds = ekd.from_source(src_name, **ds_config["source"][src_name]).to_xarray(**ds_config.get("to_xarray_options", {}))
+    var_name = find_main_var(ds, n_dims)
+    da = ds[var_name]
+    return da, var_name
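
For reference, load_da expects the "source" entry to map an earthkit-data source name to the keyword arguments for ekd.from_source. A minimal sketch of the shape it unpacks; the concrete source name and kwargs here are illustrative assumptions, not taken from this PR:

    # Sketch of the config shape load_da consumes (illustrative values):
    ds_config = {
        "source": {"gribjump": {"indices": [3, 57, 120]}},  # source name -> kwargs for ekd.from_source
        "to_xarray_options": {},                            # forwarded to .to_xarray()
    }
    src_name = list(ds_config["source"].keys())[0]  # -> "gribjump"
    src_kwargs = ds_config["source"][src_name]      # -> unpacked as **kwargs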
70 changes: 38 additions & 32 deletions hat/extract_timeseries/extractor.py
@@ -2,24 +2,17 @@
 import pandas as pd
 import xarray as xr
 import numpy as np
-import earthkit.data as ekd
-from earthkit.hydro._readers import find_main_var
+from hat.core import load_da
 
 from hat import _LOGGER as logger
 
 
 def process_grid_inputs(grid_config):
-    src_name = list(grid_config["source"].keys())[0]
-    logger.info(f"Processing grid inputs from source: {src_name}")
-    logger.debug(f"Grid config: {grid_config['source'][src_name]}")
-    ds = ekd.from_source(src_name, **grid_config["source"][src_name]).to_xarray(
-        **grid_config.get("to_xarray_options", {})
-    )
-    var_name = find_main_var(ds, 3)
-    da = ds[var_name]
+    da, var_name = load_da(grid_config, 3)
     logger.info(f"Xarray created from source:\n{da}\n")
-    gridx_colname = grid_config.get("coord_x", "lat")
-    gridy_colname = grid_config.get("coord_y", "lon")
+    coord_config = grid_config.get("coords", {})
+    gridx_colname = coord_config.get("x", "lat")
+    gridy_colname = coord_config.get("y", "lon")
     da = da.sortby([gridx_colname, gridy_colname])
     shape = da[gridx_colname].shape[0], da[gridy_colname].shape[0]
     return da, var_name, gridx_colname, gridy_colname, shape
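
The grid coordinate names also move from flat coord_x/coord_y keys to a nested "coords" mapping, matching the notebook config updated later in this PR:

    # New nested coords form (mirrors the updated timeseries_extraction notebook below);
    # "lat" and "lon" remain the defaults when "coords" is omitted.
    grid_config = {
        "source": {"file": "./sim.nc"},
        "coords": {"x": "lat", "y": "lon"},
    }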
@@ -61,7 +54,7 @@ def create_mask_from_coords(coords_config, df, gridx, gridy, shape):
     return mask, duplication_indexes
 
 
-def process_inputs(station_config, grid_config):
+def parse_stations(station_config):
     logger.debug(f"Reading station file, {station_config}")
     df = pd.read_csv(station_config["file"])
     filters = station_config.get("filter")
@@ -72,23 +65,39 @@
 
     index_config = station_config.get("index", None)
     coords_config = station_config.get("coords", None)
+    index_1d_config = station_config.get("index_1d", None)
+    return index_config, coords_config, index_1d_config, station_names, df
+
+
+def process_inputs(station_config, grid_config):
+    index_config, coords_config, index_1d_config, station_names, df = parse_stations(station_config)
+
     # TODO: better malformed config handling
     if index_config is not None and coords_config is not None:
         raise ValueError("Use either index or coords, not both.")
 
-    da, da_varname, gridx_colname, gridy_colname, shape = process_grid_inputs(grid_config)
-
-    if index_config is not None:
-        mask, duplication_indexes = create_mask_from_index(index_config, df, shape)
-    elif coords_config is not None:
-        mask, duplication_indexes = create_mask_from_coords(
-            coords_config, df, da[gridx_colname].values, da[gridy_colname].values, shape
-        )
+    if list(grid_config["source"].keys())[0] == "gribjump":
+        assert index_1d_config is not None
+        unique_indices, duplication_indexes = np.unique(df[index_1d_config].values, return_inverse=True)
+        grid_config["source"]["gribjump"]["indices"] = unique_indices
[Review thread on the line above]

colonesej (Collaborator) commented on Nov 5, 2025:
    Shouldn't we force and always use the option fetch_coords_from_fdb=True in .from_source("gribjump", ...)?

Oisin-M (Collaborator, Author) replied on Nov 5, 2025:
    ranges=[(1234, 2345)] will extract all 1D indices from 1234 up until (and not including) 2345; it's not a 2D index. Since we really just want a few indices here, I don't think it's suitable for this application? Happy to be corrected if I misunderstood something.
    I need to read up on the fetch_coords_from_fdb things, but it might be useful indeed!

colonesej (Collaborator) replied:
    You are right, sorry. I misunderstood the meaning of ranges and deleted that part of the comment. Cool, I think having the coordinates in the NetCDF output would be nice.

[End of review thread]

+        masked_da, da_varname = load_da(grid_config, 2)
     else:
-        # default to index approach
-        mask, duplication_indexes = create_mask_from_index(index_config, df, shape)
+        da, da_varname, gridx_colname, gridy_colname, shape = process_grid_inputs(grid_config)
+
+        if index_config is not None:
+            mask, duplication_indexes = create_mask_from_index(index_config, df, shape)
+        elif coords_config is not None:
+            mask, duplication_indexes = create_mask_from_coords(
+                coords_config, df, da[gridx_colname].values, da[gridy_colname].values, shape
+            )
+        else:
+            # default to index approach
+            mask, duplication_indexes = create_mask_from_index(index_config, df, shape)
 
-    return da, da_varname, gridx_colname, gridy_colname, mask, station_names, duplication_indexes
+        logger.info("Extracting timeseries at selected stations")
+        masked_da = apply_mask(da, mask, gridx_colname, gridy_colname)
+
+    return da_varname, station_names, duplication_indexes, masked_da
 
 
 def mask_array_np(arr, mask):
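
The gribjump branch above leans on np.unique(..., return_inverse=True): stations that share a grid cell are collapsed to unique 1D indices for the request, and the inverse mapping records how to re-expand the extracted series per station. A standalone illustration:

    import numpy as np

    station_idx = np.array([120, 57, 120, 3])  # 1D grid index per station
    unique_indices, duplication_indexes = np.unique(station_idx, return_inverse=True)
    print(unique_indices)       # [  3  57 120] -> passed to the gribjump source
    print(duplication_indexes)  # [2 1 2 0]     -> consumed later via ds.isel(index=...)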
@@ -101,12 +110,12 @@ def apply_mask(da, mask, coordx, coordy):
         da,
         mask,
         input_core_dims=[(coordx, coordy), (coordx, coordy)],
-        output_core_dims=[["station"]],
+        output_core_dims=[["index"]],
         output_dtypes=[da.dtype],
         exclude_dims={coordx, coordy},
         dask="parallelized",
         dask_gufunc_kwargs={
-            "output_sizes": {"station": int(mask.sum())},
+            "output_sizes": {"index": int(mask.sum())},
             "allow_rechunk": True,
         },
     )
@@ -115,13 +124,10 @@
 
 
 def extractor(config):
-    da, da_varname, gridx_colname, gridy_colname, mask, station_names, duplication_indexes = process_inputs(
-        config["station"], config["grid"]
-    )
-    logger.info("Extracting timeseries at selected stations")
-    masked_da = apply_mask(da, mask, gridx_colname, gridy_colname)
+    da_varname, station_names, duplication_indexes, masked_da = process_inputs(config["station"], config["grid"])
     ds = xr.Dataset({da_varname: masked_da})
-    ds = ds.isel(station=duplication_indexes)
+    ds = ds.isel(index=duplication_indexes)
+    ds = ds.rename({"index": "station"})
     ds["station"] = station_names
     if config.get("output", None) is not None:
         logger.info(f"Saving output to {config['output']['file']}")
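
On the review thread above (and commit d5f561a, "pass ranges instead of indices to gribjump source"): runs of consecutive 1D indices can be coalesced into the half-open (start, stop) ranges the source accepts. A sketch of what such a conversion could look like, not the PR's actual implementation:

    import numpy as np

    def indices_to_ranges(indices):
        # Coalesce unique 1D indices into half-open (start, stop) ranges, stop
        # exclusive as described in the thread, e.g. [3, 4, 5, 9] -> [(3, 6), (9, 10)].
        indices = np.unique(indices)
        breaks = np.where(np.diff(indices) > 1)[0] + 1
        return [(int(g[0]), int(g[-1]) + 1) for g in np.split(indices, breaks)]

    print(indices_to_ranges([3, 4, 5, 9]))  # [(3, 6), (9, 10)]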
4 changes: 3 additions & 1 deletion hat/station_mapping/mapper.py
@@ -47,14 +47,15 @@ def apply_blacklist(blacklist_config, metric_grid, grid_area_coords1, grid_area_coords2):
     return metric_grid, grid_area_coords1, grid_area_coords2
 
 
-def outputs_to_df(df, indx, indy, cindx, cindy, errors, grid_area_coords1, grid_area_coords2, filename):
+def outputs_to_df(df, indx, indy, cindx, cindy, errors, grid_area_coords1, grid_area_coords2, shape, filename):
     df["opt_x_index"] = indx
     df["opt_y_index"] = indy
     df["near_x_index"] = cindx
     df["near_y_index"] = cindy
     df["opt_error"] = errors
     df["opt_x_coord"] = grid_area_coords1[indx, 0]
     df["opt_y_coord"] = grid_area_coords2[0, indy]
+    df["opt_1d_index"] = indy + shape[1] * indx
     if filename is not None:
         df.to_csv(filename, index=False)
     return df
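
The new opt_1d_index column is the row-major (C-order) flattening of (opt_x_index, opt_y_index), so it agrees with numpy's raveling; a quick check on a toy shape (not from the PR):

    import numpy as np

    shape = (4, 5)  # (n_x, n_y)
    indx, indy = 2, 3
    assert indy + shape[1] * indx == np.ravel_multi_index((indx, indy), shape) == 13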
@@ -109,6 +110,7 @@ def mapper(config):
         *mapping_outputs,
         grid_area_coords1,
         grid_area_coords2,
+        shape=grid_area_coords1.shape,
         filename=config["output"]["file"] if config.get("output", None) is not None else None,
     )
     generate_summary_plots(df, config.get("plot", None))
6 changes: 3 additions & 3 deletions notebooks/workflow/hydrostats_computation.ipynb
@@ -19,14 +19,14 @@
    "source": [
     "config = {\n",
     "    \"sim\": {\n",
-    "        \"source\": [\"file\", \"extracted_timeseries.nc\"],\n",
+    "        \"source\": {\"file\": \"extracted_timeseries.nc\"},\n",
     "        \"coords\": {\n",
     "            \"s\": \"station\",\n",
     "            \"t\": \"time\"\n",
     "        }\n",
     "    },\n",
     "    \"obs\": {\n",
-    "        \"source\": [\"file\", \"observations.nc\"],\n",
+    "        \"source\": {\"file\": \"observations.nc\"},\n",
     "        \"coords\": {\n",
     "            \"s\": \"station\",\n",
     "            \"t\": \"time\"\n",
@@ -49,7 +49,7 @@
   ],
   "metadata": {
    "kernelspec": {
-    "display_name": "Python 3",
+    "display_name": "hat",
     "language": "python",
     "name": "python3"
    },
2 changes: 1 addition & 1 deletion notebooks/workflow/timeseries_extraction.ipynb
@@ -35,7 +35,7 @@
     "        \"name\": \"station_id\"\n",
     "    },\n",
     "    \"grid\": {\n",
-    "        \"source\": [\"file\", \"./sim.nc\"],\n",
+    "        \"source\": {\"file\": \"./sim.nc\"},\n",
     "        \"coords\": {\n",
     "            \"x\": \"lat\",\n",
     "            \"y\": \"lon\",\n",
6 changes: 5 additions & 1 deletion pyproject.toml
@@ -42,7 +42,7 @@ dependencies = [
     "ipyleaflet",
     "ipywidgets",
     "earthkit-data>=0.13.8",
-    "earthkit-hydro",
+    "earthkit-hydro>=1.0.0",
     "earthkit-meteo",
     "cfgrib", # check if necessary
     "netCDF4", # check if necessary
@@ -68,6 +68,10 @@ dependencies = [
     "ruff",
     "pre-commit"
 ]
+gribjump = [
+    "earthkit-data[gribjump]",
+    "gribjumplib==0.10.3.dev20250908"
+]
 
 [project.scripts]
 hat-extract-timeseries = "hat.cli:extractor_cli"
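
With this extras group, the experimental gribjump dependencies can be installed via pip's extras syntax, e.g. pip install "hat[gribjump]" (assuming the distribution name is hat; commit d5dda4d adds the corresponding installation instructions). Note that commit 4d171bb, outside the 10 commits shown here, removes gribjumplib as a direct dependency again.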