4
4
import warnings
5
5
from importlib .metadata import version
6
6
from math import prod
7
- from typing import TYPE_CHECKING , Literal , cast
7
+ from typing import TYPE_CHECKING , Literal
8
8
9
9
from packaging import version as pversion
10
10
@@ -66,15 +66,13 @@ def _load_to_xarray_dataarray_dict_no_metadata(
66
66
datadict : Mapping [str , Mapping [str , npt .NDArray ]],
67
67
* ,
68
68
use_multi_index : Literal ["auto" , "always" , "never" ] = "auto" ,
69
- ) -> dict [str , xr .DataArray ]:
70
- import xarray as xr
71
-
69
+ ) -> dict [str , xr .Dataset ]:
72
70
if use_multi_index not in ("auto" , "always" , "never" ):
73
71
raise ValueError (
74
72
f"Invalid value for use_multi_index. Expected one of 'auto', 'always', 'never' but got { use_multi_index } "
75
73
)
76
74
77
- data_xrdarray_dict : dict [str , xr .DataArray ] = {}
75
+ data_xrdarray_dict : dict [str , xr .Dataset ] = {}
78
76
79
77
for name , subdict in datadict .items ():
80
78
shape_is_consistent = (
@@ -96,11 +94,9 @@ def _load_to_xarray_dataarray_dict_no_metadata(
96
94
)
97
95
98
96
if index is None :
99
- xrdarray : xr .DataArray = (
100
- _data_to_dataframe (subdict , index = index )
101
- .to_xarray ()
102
- .get (name , xr .DataArray ())
103
- )
97
+ xrdarray : xr .Dataset = _data_to_dataframe (
98
+ subdict , index = index
99
+ ).to_xarray ()
104
100
data_xrdarray_dict [name ] = xrdarray
105
101
elif index_is_unique :
106
102
df = _data_to_dataframe (subdict , index )
@@ -109,9 +105,7 @@ def _load_to_xarray_dataarray_dict_no_metadata(
109
105
)
110
106
else :
111
107
df = _data_to_dataframe (subdict , index )
112
- xrdata_temp = df .reset_index ().to_xarray ()
113
- for _name in subdict :
114
- data_xrdarray_dict [_name ] = xrdata_temp [_name ]
108
+ data_xrdarray_dict [name ] = df .reset_index ().to_xarray ()
115
109
116
110
return data_xrdarray_dict
117
111
@@ -122,7 +116,7 @@ def _xarray_data_array_from_pandas_multi_index(
122
116
name : str ,
123
117
df : pd .DataFrame ,
124
118
index : pd .Index | pd .MultiIndex ,
125
- ) -> xr .DataArray :
119
+ ) -> xr .Dataset :
126
120
import pandas as pd
127
121
import xarray as xr
128
122
@@ -148,16 +142,16 @@ def _xarray_data_array_from_pandas_multi_index(
148
142
)
149
143
150
144
coords = xr .Coordinates .from_pandas_multiindex (df .index , "multi_index" )
151
- xrdarray = xr .DataArray (df [name ], coords = coords )
145
+ xrdarray = xr .DataArray (df [name ], coords = coords ). to_dataset ( name = name )
152
146
else :
153
- xrdarray = df .to_xarray (). get ( name , xr . DataArray ())
147
+ xrdarray = df .to_xarray ()
154
148
155
149
return xrdarray
156
150
157
151
158
152
def _xarray_data_array_direct (
159
153
dataset : DataSetProtocol , name : str , subdict : Mapping [str , npt .NDArray ]
160
- ) -> xr .DataArray :
154
+ ) -> xr .Dataset :
161
155
import xarray as xr
162
156
163
157
meas_paramspec = dataset .description .interdeps .graph .nodes [name ]["value" ]
@@ -213,20 +207,20 @@ def _xarray_data_array_direct(
213
207
data_vars .update (extra_data_vars )
214
208
215
209
ds = xr .Dataset (data_vars , coords = coords )
216
- da = ds [name ]
217
- if len (extra_data_vars ) > 0 :
218
- # stash extra data vars to be added at dataset assembly time
219
- # mapping: var_name -> (dims_tuple, numpy array)
220
- da .attrs ["_qcodes_extra_data_vars" ] = extra_data_vars
221
- return da
210
+ # da = ds[name]
211
+ # if len(extra_data_vars) > 0:
212
+ # # stash extra data vars to be added at dataset assembly time
213
+ # # mapping: var_name -> (dims_tuple, numpy array)
214
+ # da.attrs["_qcodes_extra_data_vars"] = extra_data_vars
215
+ return ds
222
216
223
217
224
218
def load_to_xarray_dataarray_dict (
225
219
dataset : DataSetProtocol ,
226
220
datadict : Mapping [str , Mapping [str , npt .NDArray ]],
227
221
* ,
228
222
use_multi_index : Literal ["auto" , "always" , "never" ] = "auto" ,
229
- ) -> dict [str , xr .DataArray ]:
223
+ ) -> dict [str , xr .Dataset ]:
230
224
dataarrays = _load_to_xarray_dataarray_dict_no_metadata (
231
225
dataset , datadict , use_multi_index = use_multi_index
232
226
)
@@ -282,14 +276,14 @@ def load_to_xarray_dataset(
282
276
283
277
# Casting Hashable for the key type until python/mypy#1114
284
278
# and python/typing#445 are resolved.
285
- xrdataset = xr .Dataset ( cast ( "dict[Hashable, xr.DataArray]" , data_xrdarray_dict ) )
279
+ xrdataset = xr .merge ( data_xrdarray_dict . values (), compat = "equals" , join = "outer" )
286
280
287
281
# add any stashed extra data variables created during direct export
288
- for _ , dataarray in data_xrdarray_dict .items ():
289
- extras = dataarray .attrs .pop ("_qcodes_extra_data_vars" , None )
290
- if isinstance (extras , dict ):
291
- for var_name , (dims , values ) in extras .items ():
292
- xrdataset [var_name ] = (dims , values )
282
+ # for _, dataarray in data_xrdarray_dict.items():
283
+ # extras = dataarray.attrs.pop("_qcodes_extra_data_vars", None)
284
+ # if isinstance(extras, dict):
285
+ # for var_name, (dims, values) in extras.items():
286
+ # xrdataset[var_name] = (dims, values)
293
287
294
288
_add_param_spec_to_xarray_coords (dataset , xrdataset )
295
289
_add_param_spec_to_xarray_data_vars (dataset , xrdataset )
0 commit comments