Skip to content

Commit b8cde31

Browse files
committed
Add test for inference
1 parent 74a79c1 commit b8cde31

File tree

1 file changed

+139
-0
lines changed

1 file changed

+139
-0
lines changed

tests/dataset/test_dataset_export.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1599,3 +1599,142 @@ def test_dond_hypothesis_nd_grid(
15991599
expected_signal += weights[i] * grid.astype(float)
16001600

16011601
np.testing.assert_allclose(xr_ds["signal"].values, expected_signal)
1602+
1603+
1604+
@given(data=hst.data())
1605+
@settings(
1606+
max_examples=10,
1607+
suppress_health_check=(HealthCheck.function_scoped_fixture,),
1608+
deadline=None,
1609+
)
1610+
def test_measurement_hypothesis_nd_grid_with_inferred_param(
1611+
data: hst.DataObject, experiment: Experiment, caplog: LogCaptureFixture
1612+
) -> None:
1613+
"""
1614+
Randomized ND sweep using Measurement context manager with an inferred parameter:
1615+
- Draw N in [2, 4]
1616+
- For each dimension i, draw number of points n_i in [1, 5]
1617+
- Sweep each ManualParameter over a linspace of length n_i
1618+
- Choose m in [1, N-1] and a subset of m swept parameters for an inferred coord
1619+
- Register an inferred parameter depending on that subset and add its values
1620+
- Measure a deterministic function of the setpoints
1621+
- Assert xarray dims, coords (including inferred), and data match expectation
1622+
"""
1623+
# number of dimensions and points per dimension
1624+
n_dims = data.draw(hst.integers(min_value=2, max_value=4), label="n_dims")
1625+
points_per_dim = [
1626+
data.draw(hst.integers(min_value=1, max_value=5), label=f"n_points_dim_{i}")
1627+
for i in range(n_dims)
1628+
]
1629+
1630+
# build setpoint arrays and names
1631+
sp_names = [f"x{i}" for i in range(n_dims)]
1632+
sp_values: list[np.ndarray] = [
1633+
np.linspace(0.0, float(npts - 1), npts) for npts in points_per_dim
1634+
]
1635+
1636+
# choose subset for inferred parameter (strict subset)
1637+
m = data.draw(hst.integers(min_value=1, max_value=n_dims - 1), label="m")
1638+
inf_indices = sorted(
1639+
data.draw(
1640+
hst.lists(
1641+
hst.integers(min_value=0, max_value=n_dims - 1),
1642+
min_size=m,
1643+
max_size=m,
1644+
unique=True,
1645+
),
1646+
label="inf_indices",
1647+
)
1648+
)
1649+
inf_sp_names = [sp_names[i] for i in inf_indices]
1650+
1651+
# weights for measured signal
1652+
weights = [(i + 1) for i in range(n_dims)]
1653+
1654+
# Setup measurement with shapes so xarray direct path is used
1655+
meas = Measurement(exp=experiment, name="nd_grid_with_inferred")
1656+
# register setpoints
1657+
for name in sp_names:
1658+
meas.register_custom_parameter(name, paramtype="numeric")
1659+
# register inferred parameter (from subset of setpoints)
1660+
meas.register_custom_parameter(
1661+
"inf", basis=tuple(inf_sp_names), paramtype="numeric"
1662+
)
1663+
# register measured parameter depending on all setpoints
1664+
meas.register_custom_parameter(
1665+
"signal", setpoints=tuple(sp_names), paramtype="numeric"
1666+
)
1667+
meas.set_shapes({"signal": tuple(points_per_dim)})
1668+
1669+
# run measurement over full grid
1670+
with meas.run() as datasaver:
1671+
# iterate over grid indices
1672+
for idx in np.ndindex(*points_per_dim):
1673+
# collect setpoint values for this point
1674+
sp_items: list[tuple[str, float]] = [
1675+
(sp_names[k], float(sp_values[k][idx[k]])) for k in range(n_dims)
1676+
]
1677+
# measured signal: weighted sum of all setpoints
1678+
signal_val = float(
1679+
sum(weights[k] * float(sp_values[k][idx[k]]) for k in range(n_dims))
1680+
)
1681+
# inferred value: sum over selected subset of setpoints
1682+
inf_val = float(sum(float(sp_values[k][idx[k]]) for k in inf_indices))
1683+
results: list[tuple[str, float]] = [
1684+
*sp_items,
1685+
("inf", inf_val),
1686+
("signal", signal_val),
1687+
]
1688+
datasaver.add_result(*results)
1689+
1690+
ds = datasaver.dataset
1691+
1692+
# export to xarray and ensure direct path used
1693+
caplog.clear()
1694+
with caplog.at_level(logging.INFO):
1695+
xr_ds = ds.to_xarray_dataset()
1696+
1697+
assert any(
1698+
"Exporting signal to xarray using direct method" in record.message
1699+
for record in caplog.records
1700+
)
1701+
1702+
# Expected sizes per coordinate (all setpoints)
1703+
expected_sizes = {name: len(vals) for name, vals in zip(sp_names, sp_values)}
1704+
assert xr_ds.sizes == expected_sizes
1705+
1706+
# Check setpoint coords contents and order
1707+
for name, vals in zip(sp_names, sp_values):
1708+
assert name in xr_ds.coords
1709+
np.testing.assert_allclose(xr_ds.coords[name].values, vals)
1710+
1711+
# Measured data dims and values
1712+
assert "signal" in xr_ds.data_vars
1713+
assert xr_ds["signal"].dims == tuple(sp_names)
1714+
1715+
grids_all = np.meshgrid(*sp_values, indexing="ij")
1716+
expected_signal = np.zeros(tuple(points_per_dim), dtype=float)
1717+
for i, grid in enumerate(grids_all):
1718+
expected_signal += weights[i] * grid.astype(float)
1719+
np.testing.assert_allclose(xr_ds["signal"].values, expected_signal)
1720+
1721+
# Inferred coord should be present with dims equal to the subset order
1722+
assert "inf" in xr_ds.coords
1723+
expected_inf_dims = tuple(inf_sp_names)
1724+
assert xr_ds.coords["inf"].dims == expected_inf_dims
1725+
1726+
# Build expected inferred grid based only on the subset dims
1727+
subset_values = [sp_values[i] for i in inf_indices]
1728+
grids_subset = np.meshgrid(*subset_values, indexing="ij") if subset_values else []
1729+
expected_inf = np.zeros(tuple(points_per_dim[i] for i in inf_indices), dtype=float)
1730+
for grid in grids_subset:
1731+
expected_inf += grid.astype(float)
1732+
np.testing.assert_allclose(xr_ds.coords["inf"].values, expected_inf)
1733+
1734+
# The indexes of the inferred coord must correspond to the axes it depends on
1735+
# i.e., keys should match the inferred-from setpoint names, and each index equal
1736+
# to the dataset's index for that dimension
1737+
inf_indexes = xr_ds.coords["inf"].indexes
1738+
assert set(inf_indexes.keys()) == set(inf_sp_names)
1739+
for dim in inf_sp_names:
1740+
assert inf_indexes[dim].equals(xr_ds.indexes[dim])

0 commit comments

Comments
 (0)