@@ -162,21 +162,52 @@ def _xarray_data_array_direct(
162
162
import xarray as xr
163
163
164
164
meas_paramspec = dataset .description .interdeps .graph .nodes [name ]["value" ]
165
- _ , deps , _ = dataset .description .interdeps .all_parameters_in_tree_by_group (
165
+ _ , deps , inferred = dataset .description .interdeps .all_parameters_in_tree_by_group (
166
166
meas_paramspec
167
167
)
168
- dep_axis = {}
168
+ # Build coordinate axes from direct dependencies preserving their order
169
+ dep_axis : dict [str , npt .NDArray ] = {}
169
170
for axis , dep in enumerate (deps ):
170
171
dep_array = subdict [dep .name ]
171
172
dep_axis [dep .name ] = dep_array [
172
173
tuple (slice (None ) if i == axis else 0 for i in range (dep_array .ndim ))
173
174
]
174
175
175
- da = xr .Dataset (
176
+ extra_coords : dict [str , tuple [tuple [str , ...] | tuple [str ], npt .NDArray ]] = {}
177
+ for inf in inferred :
178
+ # skip parameters already used as primary coordinate axes
179
+ if inf .name in dep_axis :
180
+ continue
181
+ # add only if data for this parameter is available
182
+ if inf .name not in subdict :
183
+ continue
184
+
185
+ inf_related = dataset .description .interdeps .find_all_parameters_in_tree (inf )
186
+
187
+ related_deps = inf_related .intersection (set (deps ))
188
+ related_top_level = inf_related .intersection ({meas_paramspec })
189
+
190
+ if len (related_top_level ) > 0 :
191
+ raise NotImplementedError (
192
+ "Adding inferred coords related to top level param is not yet supported"
193
+ )
194
+
195
+ inf_data = subdict [inf .name ][
196
+ tuple (slice (None ) if dep in related_deps else 0 for dep in deps )
197
+ ]
198
+ inf_coords = [dep .name for dep in deps if dep in related_deps ]
199
+
200
+ extra_coords [inf .name ] = (tuple (inf_coords ), inf_data )
201
+
202
+ # Compose coordinates dict including dependency axes and extra inferred coords
203
+ coords : dict [str , tuple [tuple [str , ...] | tuple [str ], npt .NDArray ] | npt .NDArray ]
204
+ coords = {** dep_axis , ** extra_coords }
205
+
206
+ ds = xr .Dataset (
176
207
{name : (tuple (dep_axis .keys ()), subdict [name ])},
177
- coords = dep_axis ,
178
- )[ name ]
179
- return da
208
+ coords = coords ,
209
+ )
210
+ return ds [ name ]
180
211
181
212
182
213
def load_to_xarray_dataarray_dict (
0 commit comments