Skip to content

Commit 38abf50

Browse files
committed
Add test where loop is over infeered parameter
1 parent 3f398cd commit 38abf50

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed

tests/dataset/test_dataset_export.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1738,3 +1738,80 @@ def test_measurement_hypothesis_nd_grid_with_inferred_param(
17381738
assert set(inf_indexes.keys()) == set(inf_sp_names)
17391739
for dim in inf_sp_names:
17401740
assert inf_indexes[dim].equals(xr_ds.indexes[dim])
1741+
1742+
1743+
def test_measurement_2d_with_inferred_setpoint(
1744+
experiment: Experiment, caplog: LogCaptureFixture
1745+
) -> None:
1746+
"""
1747+
Sweep two parameters (x, y) where y is inferred from one or more basis parameters.
1748+
Verify that xarray export uses direct method, signal dims match, and basis
1749+
parameters appear as inferred coordinates with indexes corresponding to y.
1750+
"""
1751+
# Grid sizes
1752+
nx, ny = 3, 4
1753+
x_vals = np.linspace(0.0, 2.0, nx)
1754+
# Define basis parameters for y and compute y from these
1755+
y_b0_vals = np.linspace(10.0, 13.0, ny)
1756+
y_b1_vals = np.linspace(-1.0, 2.0, ny)
1757+
# y is inferred from (y_b0, y_b1)
1758+
y_vals = y_b0_vals + 2.0 * y_b1_vals
1759+
1760+
meas = Measurement(exp=experiment, name="2d_with_inferred_setpoint")
1761+
# Register setpoint x
1762+
meas.register_custom_parameter("x", paramtype="numeric")
1763+
# Register basis params for y
1764+
meas.register_custom_parameter("y_b0", paramtype="numeric")
1765+
meas.register_custom_parameter("y_b1", paramtype="numeric")
1766+
# Register y as setpoint inferred from basis
1767+
meas.register_custom_parameter("y", basis=("y_b0", "y_b1"), paramtype="numeric")
1768+
# Register measured parameter depending on (x, y)
1769+
meas.register_custom_parameter("signal", setpoints=("x", "y"), paramtype="numeric")
1770+
meas.set_shapes({"signal": (nx, ny)})
1771+
1772+
with meas.run() as datasaver:
1773+
for ix in range(nx):
1774+
for iy in range(ny):
1775+
x = float(x_vals[ix])
1776+
y_b0 = float(y_b0_vals[iy])
1777+
y_b1 = float(y_b1_vals[iy])
1778+
y = float(y_vals[iy])
1779+
signal = x + 3.0 * y # deterministic function
1780+
datasaver.add_result(
1781+
("x", x),
1782+
("y_b0", y_b0),
1783+
("y_b1", y_b1),
1784+
("y", y),
1785+
("signal", signal),
1786+
)
1787+
1788+
ds = datasaver.dataset
1789+
1790+
caplog.clear()
1791+
with caplog.at_level(logging.INFO):
1792+
xr_ds = ds.to_xarray_dataset()
1793+
1794+
assert any(
1795+
"Exporting signal to xarray using direct method" in record.message
1796+
for record in caplog.records
1797+
)
1798+
1799+
# Sizes and coords
1800+
assert xr_ds.sizes == {"x": nx, "y": ny}
1801+
np.testing.assert_allclose(xr_ds.coords["x"].values, x_vals)
1802+
np.testing.assert_allclose(xr_ds.coords["y"].values, y_vals)
1803+
1804+
# Signal dims and values
1805+
assert xr_ds["signal"].dims == ("x", "y")
1806+
expected_signal = x_vals[:, None] + 3.0 * y_vals[None, :]
1807+
np.testing.assert_allclose(xr_ds["signal"].values, expected_signal)
1808+
1809+
# Inferred coords for y_b0 and y_b1 exist with dims only along y
1810+
for name, vals in ("y_b0", y_b0_vals), ("y_b1", y_b1_vals):
1811+
assert name in xr_ds.coords
1812+
assert xr_ds.coords[name].dims == ("y",)
1813+
np.testing.assert_allclose(xr_ds.coords[name].values, vals)
1814+
# Indexes of inferred coords should correspond to the y axis index
1815+
inf_idx = xr_ds.coords[name].indexes
1816+
assert set(inf_idx.keys()) == {"y"}
1817+
assert inf_idx["y"].equals(xr_ds.indexes["y"])

0 commit comments

Comments
 (0)