Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@ repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.15.4
hooks:
- id: ruff
args: [--fix, --show-fixes]
- id: ruff-check
exclude: '(dev/.*|.*_)\.py$'
args:
- --line-length=120
- --fix
- --exit-non-zero-on-fix
- --preview
- id: ruff-format
- repo: https://github.com/executablebooks/mdformat
rev: 1.0.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,7 @@
"metadata": {},
"outputs": [],
"source": [
"reduced_data = ekt.spatial.reduce(\n",
" era5_xr, nuts_data, mask_dim=\"FID\", extra_reduce_dims=\"valid_time\", all_touched=True\n",
")\n",
"reduced_data = ekt.spatial.reduce(era5_xr, nuts_data, mask_dim=\"FID\", extra_reduce_dims=\"valid_time\", all_touched=True)\n",
"reduced_data"
]
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,7 @@
"era5_monthly_mean.t2m.isel(**isel_kwargs).plot(label=\"Monthly mean\", ax=ax, color=\"black\")\n",
"upper_m = era5_monthly_mean.t2m + era5_monthly_std.t2m\n",
"lower_m = era5_monthly_mean.t2m - era5_monthly_std.t2m\n",
"upper_m.isel(**isel_kwargs).plot(\n",
" ax=ax, label=\"Monthly standard deviation spread\", linestyle=\"--\", color=\"black\"\n",
")\n",
"upper_m.isel(**isel_kwargs).plot(ax=ax, label=\"Monthly standard deviation spread\", linestyle=\"--\", color=\"black\")\n",
"lower_m.isel(**isel_kwargs).plot(ax=ax, linestyle=\"--\", color=\"black\")\n",
"\n",
"\n",
Expand Down
4 changes: 1 addition & 3 deletions docs/notebooks/temporal/03-seas5-daily-statistics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,7 @@
"\n",
"for itime in range(3):\n",
" for number in range(25):\n",
" t_data = seas_daily_median_by_step.t2m.isel(\n",
" **isel_kwargs, number=number, forecast_reference_time=itime\n",
" )\n",
" t_data = seas_daily_median_by_step.t2m.isel(**isel_kwargs, number=number, forecast_reference_time=itime)\n",
" if number == 0:\n",
" extra_kwargs = {\"label\": f\"FC ref time: {str(t_data.forecast_reference_time.values)[:10]}\"}\n",
" else:\n",
Expand Down
12 changes: 3 additions & 9 deletions docs/notebooks/temporal/05-accumulation-to-rate-examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,7 @@
"color = \"tab:blue\"\n",
"ax1.set_xlabel(\"Time\")\n",
"ax1.set_ylabel(f\"Accumulation ({ds_era5['tp'].units})\", color=color)\n",
"ax1.bar(\n",
" ds_era5[\"valid_time\"], ds_era5[\"tp\"].isel(plot_point), color=color, label=\"Accumulation\", width=0.1, lw=0\n",
")\n",
"ax1.bar(ds_era5[\"valid_time\"], ds_era5[\"tp\"].isel(plot_point), color=color, label=\"Accumulation\", width=0.1, lw=0)\n",
"ax1.tick_params(axis=\"y\", labelcolor=color)\n",
"\n",
"ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis\n",
Expand Down Expand Up @@ -319,9 +317,7 @@
"metadata": {},
"outputs": [],
"source": [
"ds_seas5_rate = ekt.temporal.accumulation_to_rate(\n",
" ds_seas5, accumulation_type=\"start_of_forecast\", from_first_step=True\n",
")\n",
"ds_seas5_rate = ekt.temporal.accumulation_to_rate(ds_seas5, accumulation_type=\"start_of_forecast\", from_first_step=True)\n",
"ds_seas5_rate"
]
},
Expand Down Expand Up @@ -409,9 +405,7 @@
"metadata": {},
"outputs": [],
"source": [
"ds_seas5_deaccumulate = ekt.temporal.deaccumulate(\n",
" ds_seas5, accumulation_type=\"start_of_forecast\", from_first_step=True\n",
")\n",
"ds_seas5_deaccumulate = ekt.temporal.deaccumulate(ds_seas5, accumulation_type=\"start_of_forecast\", from_first_step=True)\n",
"ds_seas5_deaccumulate"
]
},
Expand Down
42 changes: 18 additions & 24 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,35 +62,29 @@ exclude = '^tests/legacy-api/'
ignore_missing_imports = true

[tool.ruff]
# Black line length is 88, but black does not format comments.
line-length = 110
select = [
# pyflakes
"F",
# pycodestyle
"E",
"W",
# isort
"I",
# pydocstyle
"D"
]
line-length = 120

[tool.ruff.lint]
ignore = [
# pydocstyle: Missing Docstrings
"D1"
"D1", # pydocstyle: Missing Docstrings
"D107", # pydocstyle: numpy convention
"D203",
"D205",
"D212",
"D213",
"D401",
"D402",
"D413",
"D415",
"D416",
"D417"
]
select = [
# pyflakes
"F",
# pycodestyle
"E",
"W",
# isort
"I",
# pydocstyle
"D"
"F", # pyflakes
"E", # pycodestyle
"W", # pycodestyle warnings
"I", # isort
"D" # pydocstyle
]

[tool.setuptools.packages.find]
Expand Down
18 changes: 5 additions & 13 deletions src/earthkit/transforms/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,7 @@ def latitude_weights(dataarray: xr.DataArray, lat_key: str | None = None):
xp = array_namespace_from_object(lat_array[lat_key])
return xp.cos(xp.radians(lat_array[lat_key]))

raise KeyError(
"Latitude variable name not detected or found in the dataarray. Please provide the correct key."
)
raise KeyError("Latitude variable name not detected or found in the dataarray. Please provide the correct key.")


HOW_METHODS = {
Expand Down Expand Up @@ -495,9 +493,7 @@ def get_dim_key(

# We have not been able to detect, so raise an error
if raise_error:
raise ValueError(
f"Unable to find dimension key for axis '{axis}' in dataarray with dims: {dataarray.dims}."
)
raise ValueError(f"Unable to find dimension key for axis '{axis}' in dataarray with dims: {dataarray.dims}.")

return axis

Expand Down Expand Up @@ -570,8 +566,7 @@ def groupby_time(
grouped_data = dataarray.groupby(f"{time_dim}.{frequency}")
except AttributeError:
raise ValueError(
f"Invalid frequency '{frequency}' - see xarray documentation for "
f"a full list of valid frequencies."
f"Invalid frequency '{frequency}' - see xarray documentation for a full list of valid frequencies."
)

return grouped_data
Expand All @@ -590,8 +585,7 @@ def groupby_bins(
grouped_data = dataarray.groupby_bins(f"{time_dim}.{frequency}", bin_widths)
except AttributeError:
raise ValueError(
f"Invalid frequency '{frequency}' - see xarray documentation for "
f"a full list of valid frequencies."
f"Invalid frequency '{frequency}' - see xarray documentation for a full list of valid frequencies."
)
return grouped_data

Expand Down Expand Up @@ -638,9 +632,7 @@ def _wrapper(_kwarg_types, _convert_types, *args, **kwargs):
_convert_types = {key: _convert_types for key in convert_kwargs}

convert_kwargs = [
k
for k in convert_kwargs
if isinstance(kwargs[k], _ensure_tuple(_convert_types.get(k, ())))
k for k in convert_kwargs if isinstance(kwargs[k], _ensure_tuple(_convert_types.get(k, ())))
]

# Transform args/kwargs
Expand Down
8 changes: 2 additions & 6 deletions src/earthkit/transforms/climatology/_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,9 +777,7 @@ def _anomaly_dataarray(
anomaly_array = groupby_time(dataarray, time_dim=time_dim, **groupby_kwargs) - climatology_da

if relative:
anomaly_array = (
groupby_time(anomaly_array, time_dim=time_dim, **groupby_kwargs) / climatology_da
) * 100.0
anomaly_array = (groupby_time(anomaly_array, time_dim=time_dim, **groupby_kwargs) / climatology_da) * 100.0
name_tag = "relative anomaly"
update_attrs = {"units": "%"}
else:
Expand All @@ -788,9 +786,7 @@ def _anomaly_dataarray(

anomaly_array = resample(anomaly_array, how="mean", **reduce_kwargs, **groupby_kwargs, dim=time_dim)

return _update_anomaly_array(
anomaly_array, dataarray, var_name, name_tag, update_attrs, how_label=how_label
)
return _update_anomaly_array(anomaly_array, dataarray, var_name, name_tag, update_attrs, how_label=how_label)


def _update_anomaly_array(anomaly_array, original_array, var_name, name_tag, update_attrs, how_label=None):
Expand Down
7 changes: 2 additions & 5 deletions src/earthkit/transforms/spatial/_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,9 +475,7 @@ def reduce(
else:
raise TypeError("Return as type not recognised or incompatible with inputs")
else:
return _reduce_dataarray_as_xarray(
dataarray, geodataframe=geodataframe, mask_arrays=_mask_arrays, **kwargs
)
return _reduce_dataarray_as_xarray(dataarray, geodataframe=geodataframe, mask_arrays=_mask_arrays, **kwargs)


def _reduce_dataarray_as_xarray(
Expand Down Expand Up @@ -717,8 +715,7 @@ def _reduce_dataarray_as_pandas(
out_dims = {dim: dataarray[dim].data for dim in _out_dims}
reduce_attrs[f"{out_xr.name}"].update({"dims": out_dims})
reduced_list = [
out_xr.sel(**{mask_dim_name: mask_dim_value}).data
for mask_dim_value in out_xr[mask_dim_name].data
out_xr.sel(**{mask_dim_name: mask_dim_value}).data for mask_dim_value in out_xr[mask_dim_name].data
]
out = out.assign(**{f"{out_xr.name}": reduced_list})

Expand Down
8 changes: 2 additions & 6 deletions src/earthkit/transforms/temporal/_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,9 @@ def standardise_time(
try:
source_times = [time_value.strftime(target_format) for time_value in dataarray[time_dim].data]
except AttributeError:
source_times = [
pd.to_datetime(time_value).strftime(target_format) for time_value in dataarray[time_dim].data
]
source_times = [pd.to_datetime(time_value).strftime(target_format) for time_value in dataarray[time_dim].data]

standardised_times = np.array(
[pd.to_datetime(time_string).to_datetime64() for time_string in source_times]
)
standardised_times = np.array([pd.to_datetime(time_string).to_datetime64() for time_string in source_times])

dataarray = dataarray.assign_coords({time_dim: standardised_times})

Expand Down
9 changes: 2 additions & 7 deletions src/earthkit/transforms/temporal/_rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,19 +372,14 @@ def _accumulation_to_rate_dataarray(
if "units" in dataarray.attrs:
output.attrs.update({"units": dataarray.attrs["units"] + rate_units_str})
if "long_name" in dataarray.attrs:
output.attrs["long_name"] = " ".join(
filter(None, [dataarray.attrs["long_name"], rate_label.replace("_", " ")])
)
output.attrs["long_name"] = " ".join(filter(None, [dataarray.attrs["long_name"], rate_label.replace("_", " ")]))
if provenance or "history" in dataarray.attrs:
output.attrs["history"] = "\n".join(
filter(
None,
[
dataarray.attrs.get("history", ""),
(
"Converted from accumulation to rate using "
"earthkit.transforms.temporal.accumulation_to_rate."
),
("Converted from accumulation to rate using earthkit.transforms.temporal.accumulation_to_rate."),
],
)
)
Expand Down
28 changes: 7 additions & 21 deletions tests/legacy-api/test_legacy_10_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ def dummy_func2(dataarray, *args, time_dim=None, **kwargs):
# Test case for the decorator when time_shift is provided
def test_time_dim_decorator_time_shift_provided():
# Prepare test data
dataarray = xr.DataArray(
[1, 2, 3], dims=["time"], coords={"time": pd.date_range("2000-01-01", periods=3)}
)
dataarray = xr.DataArray([1, 2, 3], dims=["time"], coords={"time": pd.date_range("2000-01-01", periods=3)})

# Call the decorated function with time_shift provided
result, result_time_dim = time_dim_decorator(dummy_func)(dataarray, time_shift={"days": 1})
Expand All @@ -52,14 +50,10 @@ def test_time_dim_decorator_time_shift_provided():
# Test case for the decorator when both time_dim and time_shift are provided
def test_time_dim_decorator_time_dim_and_time_shift_provided():
# Prepare test data
dataarray = xr.DataArray(
[1, 2, 3], dims=["dummy"], coords={"dummy": pd.date_range("2000-01-01", periods=3)}
)
dataarray = xr.DataArray([1, 2, 3], dims=["dummy"], coords={"dummy": pd.date_range("2000-01-01", periods=3)})

# Call the decorated function with both time_dim and time_shift provided
result, result_time_dim = time_dim_decorator(dummy_func)(
dataarray, time_dim="dummy", time_shift={"days": 1}
)
result, result_time_dim = time_dim_decorator(dummy_func)(dataarray, time_dim="dummy", time_shift={"days": 1})

# Check if the time dimension remains unchanged
assert result_time_dim == "dummy"
Expand All @@ -81,9 +75,7 @@ def test_time_dim_decorator_not_found_error():
# Test case for the decorator when time_shift is provided and remove_partial_periods=True
def test_time_dim_decorator_time_shift_provided_trim_shifted():
# Prepare test data
dataarray = xr.DataArray(
[1, 2, 3], dims=["time"], coords={"time": pd.date_range("2000-01-01", periods=3)}
)
dataarray = xr.DataArray([1, 2, 3], dims=["time"], coords={"time": pd.date_range("2000-01-01", periods=3)})

# Call the decorated function with time_shift provided
result = time_dim_decorator(dummy_func2)(dataarray, time_shift={"days": 1}, remove_partial_periods=True)
Expand Down Expand Up @@ -114,9 +106,7 @@ def test_groupby_kwargs_decorator_provided():
other_kwargs = {"method": "linear", "fill_value": 0}

# Call the decorated function with groupby_kwargs provided
result_groupby_kwargs, result_kwargs = groupby_kwargs_decorator(gb_dummy_func)(
**groupby_kwargs, **other_kwargs
)
result_groupby_kwargs, result_kwargs = groupby_kwargs_decorator(gb_dummy_func)(**groupby_kwargs, **other_kwargs)

assert result_groupby_kwargs == groupby_kwargs
assert result_kwargs == other_kwargs
Expand All @@ -129,9 +119,7 @@ def test_groupby_kwargs_decorator_partial_provided():
other_kwargs = {"method": "linear"}

# Call the decorated function with some groupby_kwargs provided as keyword arguments
result_groupby_kwargs, result_kwargs = groupby_kwargs_decorator(gb_dummy_func)(
**groupby_kwargs, **other_kwargs
)
result_groupby_kwargs, result_kwargs = groupby_kwargs_decorator(gb_dummy_func)(**groupby_kwargs, **other_kwargs)

assert result_groupby_kwargs == groupby_kwargs
assert result_kwargs == other_kwargs
Expand Down Expand Up @@ -209,9 +197,7 @@ def test_get_how_xp_np():
"data_object",
(
xr.DataArray(np.random.rand(10, 10), dims=["x", "y"]),
xr.Dataset(
{"var": (["x", "y"], np.random.rand(10, 10))}, coords={"x": np.arange(10), "y": np.arange(10)}
),
xr.Dataset({"var": (["x", "y"], np.random.rand(10, 10))}, coords={"x": np.arange(10), "y": np.arange(10)}),
np.random.rand(10, 10),
),
)
Expand Down
5 changes: 1 addition & 4 deletions tests/legacy-api/test_legacy_30_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,7 @@ def test_standardise_time_basic():
def test_standardise_time_monthly():
data = get_data().to_xarray()
data_standardised = temporal.standardise_time(data, target_format="%Y-%m-15")
assert all(
pd.to_datetime(time_value).day == 15
for time_value in data_standardised.forecast_reference_time.values
)
assert all(pd.to_datetime(time_value).day == 15 for time_value in data_standardised.forecast_reference_time.values)


@pytest.mark.parametrize(
Expand Down
Loading
Loading