Skip to content

Commit 01c1ff5

Browse files
committed
Insert derived parameters in data
1 parent 6dd19f3 commit 01c1ff5

File tree

2 files changed

+89
-14
lines changed

2 files changed

+89
-14
lines changed

src/qcodes/dataset/exporters/export_to_xarray.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def _xarray_data_array_direct(
173173
]
174174

175175
extra_coords: dict[str, tuple[tuple[str, ...], npt.NDArray]] = {}
176+
extra_data_vars: dict[str, tuple[tuple[str, ...], npt.NDArray]] = {}
176177
for inf in inferred:
177178
# skip parameters already used as primary coordinate axes
178179
if inf.name in dep_axis:
@@ -187,26 +188,37 @@ def _xarray_data_array_direct(
187188
related_top_level = inf_related.intersection({meas_paramspec})
188189

189190
if len(related_top_level) > 0:
190-
raise NotImplementedError(
191-
"Adding inferred coords related to top level param is not yet supported"
192-
)
193-
194-
inf_data = subdict[inf.name][
195-
tuple(slice(None) if dep in related_deps else 0 for dep in deps)
196-
]
197-
inf_coords = [dep.name for dep in deps if dep in related_deps]
191+
# If inferred param is related to the top-level measurement parameter,
192+
# add it as a data variable with the full dependency dimensions
193+
inf_data_full = subdict[inf.name]
194+
inf_dims_full = tuple(dep_axis.keys())
195+
extra_data_vars[inf.name] = (inf_dims_full, inf_data_full)
196+
else:
197+
# Otherwise, add as a coordinate along the related dependency axes only
198+
inf_data = subdict[inf.name][
199+
tuple(slice(None) if dep in related_deps else 0 for dep in deps)
200+
]
201+
inf_coords = [dep.name for dep in deps if dep in related_deps]
198202

199-
extra_coords[inf.name] = (tuple(inf_coords), inf_data)
203+
extra_coords[inf.name] = (tuple(inf_coords), inf_data)
200204

201205
# Compose coordinates dict including dependency axes and extra inferred coords
202206
coords: dict[str, tuple[tuple[str, ...], npt.NDArray] | npt.NDArray]
203207
coords = {**dep_axis, **extra_coords}
204208

205-
ds = xr.Dataset(
206-
{name: (tuple(dep_axis.keys()), subdict[name])},
207-
coords=coords,
208-
)
209-
return ds[name]
209+
# Compose data variables dict including measured var and any inferred data vars
210+
data_vars: dict[str, tuple[tuple[str, ...], npt.NDArray]] = {
211+
name: (tuple(dep_axis.keys()), subdict[name])
212+
}
213+
data_vars.update(extra_data_vars)
214+
215+
ds = xr.Dataset(data_vars, coords=coords)
216+
da = ds[name]
217+
if len(extra_data_vars) > 0:
218+
# stash extra data vars to be added at dataset assembly time
219+
# mapping: var_name -> (dims_tuple, numpy array)
220+
da.attrs["_qcodes_extra_data_vars"] = extra_data_vars
221+
return da
210222

211223

212224
def load_to_xarray_dataarray_dict(
@@ -272,6 +284,13 @@ def load_to_xarray_dataset(
272284
# and python/typing#445 are resolved.
273285
xrdataset = xr.Dataset(cast("dict[Hashable, xr.DataArray]", data_xrdarray_dict))
274286

287+
# add any stashed extra data variables created during direct export
288+
for _, dataarray in data_xrdarray_dict.items():
289+
extras = dataarray.attrs.pop("_qcodes_extra_data_vars", None)
290+
if isinstance(extras, dict):
291+
for var_name, (dims, values) in extras.items():
292+
xrdataset[var_name] = (dims, values)
293+
275294
_add_param_spec_to_xarray_coords(dataset, xrdataset)
276295
_add_param_spec_to_xarray_data_vars(dataset, xrdataset)
277296
_add_metadata_to_xarray(dataset, xrdataset)

tests/dataset/test_dataset_export.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1971,3 +1971,59 @@ def test_measurement_2d_with_inferred_setpoint(
19711971
inf_idx = xr_ds.coords[name].indexes
19721972
assert set(inf_idx.keys()) == {"y"}
19731973
assert inf_idx["y"].equals(xr_ds.indexes["y"])
1974+
1975+
1976+
def test_measurement_2d_top_level_inferred_is_data_var(
1977+
experiment: Experiment, caplog: LogCaptureFixture
1978+
) -> None:
1979+
"""
1980+
If an inferred parameter is related to the top-level measured parameter,
1981+
it must be exported as a data variable (not a coordinate) with the full
1982+
dependency dimensions.
1983+
"""
1984+
nx, ny = 2, 3
1985+
x_vals = np.linspace(0.0, 1.0, nx)
1986+
y_vals = np.linspace(10.0, 12.0, ny)
1987+
1988+
# Define a measured signal and an inferred param both defined on (x, y)
1989+
# The inferred param is related to the measured top-level param in the graph
1990+
meas = Measurement(exp=experiment, name="2d_top_level_inferred")
1991+
meas.register_custom_parameter("x", paramtype="numeric")
1992+
meas.register_custom_parameter("y", paramtype="numeric")
1993+
# Register measured top-level
1994+
meas.register_custom_parameter("signal", setpoints=("x", "y"), paramtype="numeric")
1995+
# Register inferred related to top-level (basis includes the measured top-level)
1996+
meas.register_custom_parameter("derived", basis=("signal",), paramtype="numeric")
1997+
meas.set_shapes({"signal": (nx, ny)})
1998+
1999+
with meas.run() as datasaver:
2000+
for ix in range(nx):
2001+
for iy in range(ny):
2002+
x = float(x_vals[ix])
2003+
y = float(y_vals[iy])
2004+
signal = x + y
2005+
derived = 2.0 * signal # inferred from top-level
2006+
datasaver.add_result(
2007+
("x", x), ("y", y), ("signal", signal), ("derived", derived)
2008+
)
2009+
2010+
ds = datasaver.dataset
2011+
caplog.clear()
2012+
with caplog.at_level(logging.INFO):
2013+
xr_ds = ds.to_xarray_dataset()
2014+
2015+
# Direct path log should be present
2016+
assert any(
2017+
"Exporting signal to xarray using direct method" in record.message
2018+
for record in caplog.records
2019+
)
2020+
2021+
# The derived param should be a data variable with dims (x, y), not a coord
2022+
assert "derived" in xr_ds.data_vars
2023+
assert "derived" not in xr_ds.coords
2024+
assert xr_ds["derived"].dims == ("x", "y")
2025+
2026+
expected_signal = x_vals[:, None] + y_vals[None, :]
2027+
expected_derived = 2.0 * expected_signal
2028+
np.testing.assert_allclose(xr_ds["signal"].values, expected_signal)
2029+
np.testing.assert_allclose(xr_ds["derived"].values, expected_derived)

0 commit comments

Comments
 (0)