Skip to content

Commit 40d92cb

Browse files
committed
Wip merge via dataset rather than array
1 parent 01c1ff5 commit 40d92cb

File tree

1 file changed

+24
-30
lines changed

1 file changed

+24
-30
lines changed

src/qcodes/dataset/exporters/export_to_xarray.py

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import warnings
55
from importlib.metadata import version
66
from math import prod
7-
from typing import TYPE_CHECKING, Literal, cast
7+
from typing import TYPE_CHECKING, Literal
88

99
from packaging import version as pversion
1010

@@ -66,15 +66,13 @@ def _load_to_xarray_dataarray_dict_no_metadata(
6666
datadict: Mapping[str, Mapping[str, npt.NDArray]],
6767
*,
6868
use_multi_index: Literal["auto", "always", "never"] = "auto",
69-
) -> dict[str, xr.DataArray]:
70-
import xarray as xr
71-
69+
) -> dict[str, xr.Dataset]:
7270
if use_multi_index not in ("auto", "always", "never"):
7371
raise ValueError(
7472
f"Invalid value for use_multi_index. Expected one of 'auto', 'always', 'never' but got {use_multi_index}"
7573
)
7674

77-
data_xrdarray_dict: dict[str, xr.DataArray] = {}
75+
data_xrdarray_dict: dict[str, xr.Dataset] = {}
7876

7977
for name, subdict in datadict.items():
8078
shape_is_consistent = (
@@ -96,11 +94,9 @@ def _load_to_xarray_dataarray_dict_no_metadata(
9694
)
9795

9896
if index is None:
99-
xrdarray: xr.DataArray = (
100-
_data_to_dataframe(subdict, index=index)
101-
.to_xarray()
102-
.get(name, xr.DataArray())
103-
)
97+
xrdarray: xr.Dataset = _data_to_dataframe(
98+
subdict, index=index
99+
).to_xarray()
104100
data_xrdarray_dict[name] = xrdarray
105101
elif index_is_unique:
106102
df = _data_to_dataframe(subdict, index)
@@ -109,9 +105,7 @@ def _load_to_xarray_dataarray_dict_no_metadata(
109105
)
110106
else:
111107
df = _data_to_dataframe(subdict, index)
112-
xrdata_temp = df.reset_index().to_xarray()
113-
for _name in subdict:
114-
data_xrdarray_dict[_name] = xrdata_temp[_name]
108+
data_xrdarray_dict[name] = df.reset_index().to_xarray()
115109

116110
return data_xrdarray_dict
117111

@@ -122,7 +116,7 @@ def _xarray_data_array_from_pandas_multi_index(
122116
name: str,
123117
df: pd.DataFrame,
124118
index: pd.Index | pd.MultiIndex,
125-
) -> xr.DataArray:
119+
) -> xr.Dataset:
126120
import pandas as pd
127121
import xarray as xr
128122

@@ -148,16 +142,16 @@ def _xarray_data_array_from_pandas_multi_index(
148142
)
149143

150144
coords = xr.Coordinates.from_pandas_multiindex(df.index, "multi_index")
151-
xrdarray = xr.DataArray(df[name], coords=coords)
145+
xrdarray = xr.DataArray(df[name], coords=coords).to_dataset(name=name)
152146
else:
153-
xrdarray = df.to_xarray().get(name, xr.DataArray())
147+
xrdarray = df.to_xarray()
154148

155149
return xrdarray
156150

157151

158152
def _xarray_data_array_direct(
159153
dataset: DataSetProtocol, name: str, subdict: Mapping[str, npt.NDArray]
160-
) -> xr.DataArray:
154+
) -> xr.Dataset:
161155
import xarray as xr
162156

163157
meas_paramspec = dataset.description.interdeps.graph.nodes[name]["value"]
@@ -213,20 +207,20 @@ def _xarray_data_array_direct(
213207
data_vars.update(extra_data_vars)
214208

215209
ds = xr.Dataset(data_vars, coords=coords)
216-
da = ds[name]
217-
if len(extra_data_vars) > 0:
218-
# stash extra data vars to be added at dataset assembly time
219-
# mapping: var_name -> (dims_tuple, numpy array)
220-
da.attrs["_qcodes_extra_data_vars"] = extra_data_vars
221-
return da
210+
# da = ds[name]
211+
# if len(extra_data_vars) > 0:
212+
# # stash extra data vars to be added at dataset assembly time
213+
# # mapping: var_name -> (dims_tuple, numpy array)
214+
# da.attrs["_qcodes_extra_data_vars"] = extra_data_vars
215+
return ds
222216

223217

224218
def load_to_xarray_dataarray_dict(
225219
dataset: DataSetProtocol,
226220
datadict: Mapping[str, Mapping[str, npt.NDArray]],
227221
*,
228222
use_multi_index: Literal["auto", "always", "never"] = "auto",
229-
) -> dict[str, xr.DataArray]:
223+
) -> dict[str, xr.Dataset]:
230224
dataarrays = _load_to_xarray_dataarray_dict_no_metadata(
231225
dataset, datadict, use_multi_index=use_multi_index
232226
)
@@ -282,14 +276,14 @@ def load_to_xarray_dataset(
282276

283277
# Casting Hashable for the key type until python/mypy#1114
284278
# and python/typing#445 are resolved.
285-
xrdataset = xr.Dataset(cast("dict[Hashable, xr.DataArray]", data_xrdarray_dict))
279+
xrdataset = xr.merge(data_xrdarray_dict.values(), compat="equals", join="outer")
286280

287281
# add any stashed extra data variables created during direct export
288-
for _, dataarray in data_xrdarray_dict.items():
289-
extras = dataarray.attrs.pop("_qcodes_extra_data_vars", None)
290-
if isinstance(extras, dict):
291-
for var_name, (dims, values) in extras.items():
292-
xrdataset[var_name] = (dims, values)
282+
# for _, dataarray in data_xrdarray_dict.items():
283+
# extras = dataarray.attrs.pop("_qcodes_extra_data_vars", None)
284+
# if isinstance(extras, dict):
285+
# for var_name, (dims, values) in extras.items():
286+
# xrdataset[var_name] = (dims, values)
293287

294288
_add_param_spec_to_xarray_coords(dataset, xrdataset)
295289
_add_param_spec_to_xarray_data_vars(dataset, xrdataset)

0 commit comments

Comments
 (0)