@@ -173,6 +173,7 @@ def _xarray_data_array_direct(
173
173
]
174
174
175
175
extra_coords : dict [str , tuple [tuple [str , ...], npt .NDArray ]] = {}
176
+ extra_data_vars : dict [str , tuple [tuple [str , ...], npt .NDArray ]] = {}
176
177
for inf in inferred :
177
178
# skip parameters already used as primary coordinate axes
178
179
if inf .name in dep_axis :
@@ -187,26 +188,37 @@ def _xarray_data_array_direct(
187
188
related_top_level = inf_related .intersection ({meas_paramspec })
188
189
189
190
if len (related_top_level ) > 0 :
190
- raise NotImplementedError (
191
- "Adding inferred coords related to top level param is not yet supported"
192
- )
193
-
194
- inf_data = subdict [inf .name ][
195
- tuple (slice (None ) if dep in related_deps else 0 for dep in deps )
196
- ]
197
- inf_coords = [dep .name for dep in deps if dep in related_deps ]
191
+ # If inferred param is related to the top-level measurement parameter,
192
+ # add it as a data variable with the full dependency dimensions
193
+ inf_data_full = subdict [inf .name ]
194
+ inf_dims_full = tuple (dep_axis .keys ())
195
+ extra_data_vars [inf .name ] = (inf_dims_full , inf_data_full )
196
+ else :
197
+ # Otherwise, add as a coordinate along the related dependency axes only
198
+ inf_data = subdict [inf .name ][
199
+ tuple (slice (None ) if dep in related_deps else 0 for dep in deps )
200
+ ]
201
+ inf_coords = [dep .name for dep in deps if dep in related_deps ]
198
202
199
- extra_coords [inf .name ] = (tuple (inf_coords ), inf_data )
203
+ extra_coords [inf .name ] = (tuple (inf_coords ), inf_data )
200
204
201
205
# Compose coordinates dict including dependency axes and extra inferred coords
202
206
coords : dict [str , tuple [tuple [str , ...], npt .NDArray ] | npt .NDArray ]
203
207
coords = {** dep_axis , ** extra_coords }
204
208
205
- ds = xr .Dataset (
206
- {name : (tuple (dep_axis .keys ()), subdict [name ])},
207
- coords = coords ,
208
- )
209
- return ds [name ]
209
+ # Compose data variables dict including measured var and any inferred data vars
210
+ data_vars : dict [str , tuple [tuple [str , ...], npt .NDArray ]] = {
211
+ name : (tuple (dep_axis .keys ()), subdict [name ])
212
+ }
213
+ data_vars .update (extra_data_vars )
214
+
215
+ ds = xr .Dataset (data_vars , coords = coords )
216
+ da = ds [name ]
217
+ if len (extra_data_vars ) > 0 :
218
+ # stash extra data vars to be added at dataset assembly time
219
+ # mapping: var_name -> (dims_tuple, numpy array)
220
+ da .attrs ["_qcodes_extra_data_vars" ] = extra_data_vars
221
+ return da
210
222
211
223
212
224
def load_to_xarray_dataarray_dict (
@@ -272,6 +284,13 @@ def load_to_xarray_dataset(
272
284
# and python/typing#445 are resolved.
273
285
xrdataset = xr .Dataset (cast ("dict[Hashable, xr.DataArray]" , data_xrdarray_dict ))
274
286
287
+ # add any stashed extra data variables created during direct export
288
+ for _ , dataarray in data_xrdarray_dict .items ():
289
+ extras = dataarray .attrs .pop ("_qcodes_extra_data_vars" , None )
290
+ if isinstance (extras , dict ):
291
+ for var_name , (dims , values ) in extras .items ():
292
+ xrdataset [var_name ] = (dims , values )
293
+
275
294
_add_param_spec_to_xarray_coords (dataset , xrdataset )
276
295
_add_param_spec_to_xarray_data_vars (dataset , xrdataset )
277
296
_add_metadata_to_xarray (dataset , xrdataset )
0 commit comments