argopy/tests/test_utils_transform.py (191 changes: 191 additions & 0 deletions)
@@ -0,0 +1,191 @@
import pytest
import logging
import xarray as xr
import numpy as np
from pathlib import Path

from argopy import tutorial
from argopy.errors import InvalidDatasetStructure
from argopy.utils import (
    fill_variables_not_in_all_datasets,
    drop_variables_not_in_all_datasets,
    merge_param_with_param_adjusted,
    filter_param_by_data_mode,
    split_data_mode,
)

log = logging.getLogger("argopy.tests.utils.transform")


def get_da(name, N_PROF, N_LEVELS):
    return xr.DataArray(
        np.random.rand(N_PROF, N_LEVELS),
        coords={"N_PROF": np.arange(0, N_PROF), "N_LEVELS": np.arange(0, N_LEVELS)},
        name=name,
    )


def get_ds(names, N_PROF=2, N_LEVELS=1):
    d = {}
    for name in names:
        d.update({name: get_da(name, N_PROF, N_LEVELS)})
    return xr.Dataset(d)
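

# For reference: get_ds(["PRES", "TEMP"], N_PROF=3, N_LEVELS=6) returns a Dataset
# holding two random float variables, each of shape (3, 6) over (N_PROF, N_LEVELS).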


def test_drop_variables_not_in_all_datasets():
    # Create a list of dummy datasets:
    ds1 = get_ds(["PRES", "TEMP", "PSAL"], 3, 6)
    ds2 = get_ds(["PRES", "TEMP", "PSAL", "DOXY"], 3, 6)
    # Drop:
    ds_list = drop_variables_not_in_all_datasets([ds1, ds2])
    # Assert:
    assert len(ds_list) == 2
    assert "DOXY" not in ds_list[1]
    for key in ["PRES", "TEMP", "PSAL"]:
        assert key in ds_list[0]
        assert key in ds_list[1]


def test_fill_variables_not_in_all_datasets():
    # Create a list of dummy datasets:
    ds1 = get_ds(["PRES", "TEMP", "PSAL"], 3, 6)
    ds2 = get_ds(["PRES", "TEMP", "PSAL", "DOXY"], 3, 6)
    # Fill:
    ds_list = fill_variables_not_in_all_datasets([ds1, ds2], concat_dim="N_PROF")
    # Assert:
    assert len(ds_list) == 2
    for key in ["PRES", "TEMP", "PSAL", "DOXY"]:
        assert key in ds_list[0]
        assert key in ds_list[1]
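

# Note on the semantics exercised above: drop_variables_not_in_all_datasets
# removes variables that are not common to every dataset in the list (here DOXY
# is dropped from ds2), while fill_variables_not_in_all_datasets keeps them and
# instead adds the missing variable to the other datasets, filled with a
# dtype-appropriate fill value, so that the collection can later be
# concatenated along `concat_dim`.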


class Test_merge_param_with_param_adjusted:
    """Tests for merge_param_with_param_adjusted, which merges <PARAM>_ADJUSTED
    variables into <PARAM> according to DATA_MODE."""

    def _create_ds(self):
        # Create a list of datasets that should be mergeable
        # and cover all possible parameter presence combinations
        ds_list = []
        ds = get_ds(["PRES", "DATA_MODE", "TEMP_ADJUSTED"], 2, 1)
        ds = ds.stack({"N_POINTS": ["N_PROF", "N_LEVELS"]})
        ds_list.append(ds)

        ds = get_ds(
            ["PRES", "DATA_MODE", "TEMP", "TEMP_ADJUSTED", "TEMP_ADJUSTED_QC"], 2, 1
        )
        ds = ds.stack({"N_POINTS": ["N_PROF", "N_LEVELS"]})
        ds_list.append(ds)

        ds = get_ds(
            [
                "PRES",
                "DATA_MODE",
                "TEMP",
                "TEMP_ADJUSTED",
                "TEMP_ADJUSTED_QC",
                "TEMP_ADJUSTED_ERROR",
            ],
            2,
            1,
        )
        ds = ds.stack({"N_POINTS": ["N_PROF", "N_LEVELS"]})
        ds_list.append(ds)

        # Now fill data mode with:
        # R only, A only, D only,
        # R & A, R & D, A & D
        ds_list_final = []
        for dd in ds_list:
            for ii in range(0, 6):
                if ii == 0:
                    dd["DATA_MODE"].values = ["R", "R"]
                elif ii == 1:
                    dd["DATA_MODE"].values = ["A", "A"]
                elif ii == 2:
                    dd["DATA_MODE"].values = ["D", "D"]
                elif ii == 3:
                    dd["DATA_MODE"].values = ["R", "A"]
                elif ii == 4:
                    dd["DATA_MODE"].values = ["R", "D"]
                elif ii == 5:
                    dd["DATA_MODE"].values = ["A", "D"]
                ds_list_final.append(dd.copy())

        return ds_list_final

    def test_ds_structure_errors(self):
        ds = get_ds(["PRES", "TEMP", "PSAL"], 3, 6)
        with pytest.raises(InvalidDatasetStructure):
            merge_param_with_param_adjusted(ds, "TEMP", errors="raise")
        assert ds.equals(merge_param_with_param_adjusted(ds, "TEMP", errors="ignore"))

        ds = get_ds(["PRES", "TEMP", "TEMP_ADJUSTED"], 3, 6)
        with pytest.raises(InvalidDatasetStructure):
            merge_param_with_param_adjusted(ds, "TEMP", errors="raise")

        ds = get_ds(["PRES", "TEMP", "TEMP_ADJUSTED"], 3, 6)
        ds = ds.stack({"N_POINTS": ["N_PROF", "N_LEVELS"]})
        with pytest.raises(InvalidDatasetStructure):
            merge_param_with_param_adjusted(ds, "TEMP", errors="raise")
        assert ds.equals(merge_param_with_param_adjusted(ds, "TEMP", errors="ignore"))

    def test_for_parameters(self):
        ds_list = self._create_ds()
        for ds in ds_list:
            ds_merged = merge_param_with_param_adjusted(ds, "TEMP", errors="raise")
            assert "TEMP_ADJUSTED" not in ds_merged
            assert "TEMP_ADJUSTED_ERROR" not in ds_merged
            assert "TEMP_ADJUSTED_QC" not in ds_merged


class Test_filter_param_by_data_mode:
    """Tests for filter_param_by_data_mode, which keeps only the N_POINTS whose
    DATA_MODE is in the requested list of data modes."""

    def test_ds_structure_errors(self):
        ds = get_ds(["PRES", "TEMP", "PSAL"], 3, 6)
        with pytest.raises(InvalidDatasetStructure):
            filter_param_by_data_mode(ds, "TEMP", errors="raise")
        assert ds.equals(filter_param_by_data_mode(ds, "TEMP", errors="ignore"))

    def test_single_value_filter(self):
        ds = get_ds(["PRES", "TEMP", "TEMP_ADJUSTED", "DATA_MODE"])
        ds = ds.stack({"N_POINTS": ["N_PROF", "N_LEVELS"]})

        ds["DATA_MODE"].values = ["R", "R"]
        ds = filter_param_by_data_mode(ds, "TEMP", dm=["R"], mask=False)
        assert len(ds["N_POINTS"]) == 2

        ds = filter_param_by_data_mode(ds, "TEMP", dm=["A"], mask=False)
        assert len(ds["N_POINTS"]) == 0

    def test_multiple_value_filter(self):
        ds = get_ds(["PRES", "TEMP", "TEMP_ADJUSTED", "DATA_MODE"])
        ds = ds.stack({"N_POINTS": ["N_PROF", "N_LEVELS"]})

        ds["DATA_MODE"].values = ["R", "R"]
        dsf = filter_param_by_data_mode(ds.copy(), "TEMP", dm=["R", "A"], mask=False)
        assert len(dsf["N_POINTS"]) == 2
        assert np.unique(ds["DATA_MODE"]) == "R"

        ds["DATA_MODE"].values = ["R", "A"]
        dsf = filter_param_by_data_mode(ds.copy(), "TEMP", dm=["A", "D"], mask=False)
        assert len(dsf["N_POINTS"]) == 1
        assert dsf["DATA_MODE"].values == ["A"]

        ds["DATA_MODE"].values = ["R", "R"]
        dsf = filter_param_by_data_mode(ds.copy(), "TEMP", dm=["D", "A"], mask=False)
        assert len(dsf["N_POINTS"]) == 0


def test_split_data_mode():
    ds = get_ds(["PRES", "TEMP", "DOXY", "STATION_PARAMETERS", "PARAMETER_DATA_MODE"])
    ds = ds.stack({"N_POINTS": ["N_PROF", "N_LEVELS"]})
    with pytest.raises(InvalidDatasetStructure):
        split_data_mode(ds)

    host = tutorial.open_dataset("gdac")[0]
    ds = xr.open_dataset(
        Path(host).joinpath("dac/coriolis/3902131/3902131_Sprof.nc"), engine="argo"
    )
    dss = split_data_mode(ds)
    assert "PARAMETER_DATA_MODE" not in dss
    assert "TEMP_DATA_MODE" in dss
    assert "PRES_DATA_MODE" in dss
    assert "PSAL_DATA_MODE" in dss
    assert "DOXY_DATA_MODE" in dss
argopy/utils/transform.py (52 changes: 25 additions & 27 deletions)
@@ -4,7 +4,6 @@

 import numpy as np
 import xarray as xr
-import pandas as pd
 import logging
 from typing import List, Union

@@ -121,14 +120,12 @@ def fillvalue(da):
     for res in ds_collection:
         [vlist.append(v) for v in list(res.variables) if concat_dim in res[v].dims]
     vlist = np.unique(vlist)
-    log.debug("variables: %s" % vlist)

     # List all possible coordinates:
     clist = []
     for res in ds_collection:
         [clist.append(c) for c in list(res.coords) if concat_dim in res[c].dims]
     clist = np.unique(clist)
-    log.debug("coordinates: %s" % clist)

     # Get the first occurrence of each variable, to be used as a template for attributes and dtype
     meta = {}
@@ -140,7 +137,6 @@
                     "dtype": ds[v].dtype,
                     "fill_value": fillvalue(ds[v]),
                 }
-    [log.debug(meta[m]) for m in meta.keys()]

     # Add missing variables to dataset
     datasets = [ds.copy() for ds in ds_collection]
@@ -251,23 +247,25 @@ def merge_param_with_param_adjusted(
             ),
         )
     )
+    copy_adj = np.count_nonzero(ii_measured_adj) > 0

     # Copy param_adjusted values onto param indexes where data_mode is in 'a' or 'd':
-    ds["%s" % param].loc[dict(N_POINTS=ii_measured_adj)] = ds[
-        "%s_ADJUSTED" % param
-    ].loc[dict(N_POINTS=ii_measured_adj)]
+    if copy_adj:
+        ds["%s" % param].loc[dict(N_POINTS=ii_measured_adj)] = ds[
+            "%s_ADJUSTED" % param
+        ].loc[dict(N_POINTS=ii_measured_adj)]
     ds = ds.drop_vars(["%s_ADJUSTED" % param])
-
-    if "%s_ADJUSTED_QC" % param in ds and "%s_ADJUSTED_QC" % param in ds:
-        ds["%s_QC" % param].loc[dict(N_POINTS=ii_measured_adj)] = ds[
-            "%s_ADJUSTED_QC" % param
-        ].loc[dict(N_POINTS=ii_measured_adj)]
+    if "%s_ADJUSTED_QC" % param in ds:
+        if copy_adj:
+            ds["%s_QC" % param].loc[dict(N_POINTS=ii_measured_adj)] = ds[
+                "%s_ADJUSTED_QC" % param
+            ].loc[dict(N_POINTS=ii_measured_adj)]
         ds = ds.drop_vars(["%s_ADJUSTED_QC" % param])

     if "%s_ERROR" % param in ds and "%s_ADJUSTED_ERROR" % param in ds:
-        ds["%s_ERROR" % param].loc[dict(N_POINTS=ii_measured_adj)] = ds[
-            "%s_ADJUSTED_ERROR" % param
-        ].loc[dict(N_POINTS=ii_measured_adj)]
+        if copy_adj:
+            ds["%s_ERROR" % param].loc[dict(N_POINTS=ii_measured_adj)] = ds[
+                "%s_ADJUSTED_ERROR" % param
+            ].loc[dict(N_POINTS=ii_measured_adj)]
         ds = ds.drop_vars(["%s_ADJUSTED_ERROR" % param])

     if core_ds:
@@ -380,34 +378,34 @@ def split_data_mode(ds: xr.Dataset) -> xr.Dataset:

     def read_data_mode_for(ds: xr.Dataset, param: str) -> xr.DataArray:
         """Return data mode of a given parameter"""
-        da_masked = ds['PARAMETER_DATA_MODE'].where(ds['STATION_PARAMETERS'] == u64(param))
+        da_masked = ds["PARAMETER_DATA_MODE"].where(
+            ds["STATION_PARAMETERS"] == u64(param), ""
+        )

-        def _dropna(x):
-            # x('N_PARAM') is reduced to the first non nan value, a scalar, no dimension
-            y = pd.Series(x).dropna().tolist()
-            if len(y) == 0:
-                return ""
-            else:
-                return y[0]
+        _dropna = lambda x: next(
+            (item for item in x if item != ""), ""
+        )  # noqa: E731

         kwargs = dict(
             dask="parallelized",
             input_core_dims=[["N_PARAM"]],  # Function takes N_PARAM as input
             output_core_dims=[[]],  # Function reduces to a scalar (no dimension)
-            vectorize=True  # Apply function element-wise along the other dimensions
+            vectorize=True,  # Apply function element-wise along the other dimensions
         )

         dm = xr.apply_ufunc(_dropna, da_masked, **kwargs)
         dm = dm.rename("%s_DATA_MODE" % param)
-        dm.attrs = ds['PARAMETER_DATA_MODE'].attrs
+        dm.attrs = ds["PARAMETER_DATA_MODE"].attrs
         return dm

     for param in params:
         name = "%s_DATA_MODE" % param.replace("_PARAMETER", "").replace(
             "PARAMETER_", ""
         )
         if name == "_DATA_MODE":
-            log.error("This dataset has an error in 'STATION_PARAMETERS': it contains an empty string")
+            log.error(
+                "This dataset has an error in 'STATION_PARAMETERS': it contains an empty string"
+            )
         else:
             ds[name] = read_data_mode_for(ds, param)
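
For context, a minimal, self-contained sketch of the reduction implemented by read_data_mode_for above (an illustration, not the argopy implementation; the toy dataset is made up): mask PARAMETER_DATA_MODE where STATION_PARAMETERS matches the parameter, then keep the first non-empty value along N_PARAM.

import numpy as np
import xarray as xr

ds = xr.Dataset(
    {
        "STATION_PARAMETERS": (
            ("N_PROF", "N_PARAM"),
            np.array([["PRES", "TEMP"], ["PRES", "TEMP"]]),
        ),
        "PARAMETER_DATA_MODE": (
            ("N_PROF", "N_PARAM"),
            np.array([["R", "D"], ["A", "D"]]),
        ),
    }
)

first_non_empty = lambda x: next((v for v in x if v != ""), "")  # noqa: E731

for param in np.unique(ds["STATION_PARAMETERS"].values):
    masked = ds["PARAMETER_DATA_MODE"].where(ds["STATION_PARAMETERS"] == param, "")
    ds["%s_DATA_MODE" % param] = xr.apply_ufunc(
        first_non_empty, masked, input_core_dims=[["N_PARAM"]], vectorize=True
    )

print(ds["TEMP_DATA_MODE"].values)  # -> ['D' 'D']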
