@@ -1599,3 +1599,142 @@ def test_dond_hypothesis_nd_grid(
1599
1599
expected_signal += weights [i ] * grid .astype (float )
1600
1600
1601
1601
np .testing .assert_allclose (xr_ds ["signal" ].values , expected_signal )
1602
+
1603
+
1604
+ @given (data = hst .data ())
1605
+ @settings (
1606
+ max_examples = 10 ,
1607
+ suppress_health_check = (HealthCheck .function_scoped_fixture ,),
1608
+ deadline = None ,
1609
+ )
1610
+ def test_measurement_hypothesis_nd_grid_with_inferred_param (
1611
+ data : hst .DataObject , experiment : Experiment , caplog : LogCaptureFixture
1612
+ ) -> None :
1613
+ """
1614
+ Randomized ND sweep using Measurement context manager with an inferred parameter:
1615
+ - Draw N in [2, 4]
1616
+ - For each dimension i, draw number of points n_i in [1, 5]
1617
+ - Sweep each ManualParameter over a linspace of length n_i
1618
+ - Choose m in [1, N-1] and a subset of m swept parameters for an inferred coord
1619
+ - Register an inferred parameter depending on that subset and add its values
1620
+ - Measure a deterministic function of the setpoints
1621
+ - Assert xarray dims, coords (including inferred), and data match expectation
1622
+ """
1623
+ # number of dimensions and points per dimension
1624
+ n_dims = data .draw (hst .integers (min_value = 2 , max_value = 4 ), label = "n_dims" )
1625
+ points_per_dim = [
1626
+ data .draw (hst .integers (min_value = 1 , max_value = 5 ), label = f"n_points_dim_{ i } " )
1627
+ for i in range (n_dims )
1628
+ ]
1629
+
1630
+ # build setpoint arrays and names
1631
+ sp_names = [f"x{ i } " for i in range (n_dims )]
1632
+ sp_values : list [np .ndarray ] = [
1633
+ np .linspace (0.0 , float (npts - 1 ), npts ) for npts in points_per_dim
1634
+ ]
1635
+
1636
+ # choose subset for inferred parameter (strict subset)
1637
+ m = data .draw (hst .integers (min_value = 1 , max_value = n_dims - 1 ), label = "m" )
1638
+ inf_indices = sorted (
1639
+ data .draw (
1640
+ hst .lists (
1641
+ hst .integers (min_value = 0 , max_value = n_dims - 1 ),
1642
+ min_size = m ,
1643
+ max_size = m ,
1644
+ unique = True ,
1645
+ ),
1646
+ label = "inf_indices" ,
1647
+ )
1648
+ )
1649
+ inf_sp_names = [sp_names [i ] for i in inf_indices ]
1650
+
1651
+ # weights for measured signal
1652
+ weights = [(i + 1 ) for i in range (n_dims )]
1653
+
1654
+ # Setup measurement with shapes so xarray direct path is used
1655
+ meas = Measurement (exp = experiment , name = "nd_grid_with_inferred" )
1656
+ # register setpoints
1657
+ for name in sp_names :
1658
+ meas .register_custom_parameter (name , paramtype = "numeric" )
1659
+ # register inferred parameter (from subset of setpoints)
1660
+ meas .register_custom_parameter (
1661
+ "inf" , basis = tuple (inf_sp_names ), paramtype = "numeric"
1662
+ )
1663
+ # register measured parameter depending on all setpoints
1664
+ meas .register_custom_parameter (
1665
+ "signal" , setpoints = tuple (sp_names ), paramtype = "numeric"
1666
+ )
1667
+ meas .set_shapes ({"signal" : tuple (points_per_dim )})
1668
+
1669
+ # run measurement over full grid
1670
+ with meas .run () as datasaver :
1671
+ # iterate over grid indices
1672
+ for idx in np .ndindex (* points_per_dim ):
1673
+ # collect setpoint values for this point
1674
+ sp_items : list [tuple [str , float ]] = [
1675
+ (sp_names [k ], float (sp_values [k ][idx [k ]])) for k in range (n_dims )
1676
+ ]
1677
+ # measured signal: weighted sum of all setpoints
1678
+ signal_val = float (
1679
+ sum (weights [k ] * float (sp_values [k ][idx [k ]]) for k in range (n_dims ))
1680
+ )
1681
+ # inferred value: sum over selected subset of setpoints
1682
+ inf_val = float (sum (float (sp_values [k ][idx [k ]]) for k in inf_indices ))
1683
+ results : list [tuple [str , float ]] = [
1684
+ * sp_items ,
1685
+ ("inf" , inf_val ),
1686
+ ("signal" , signal_val ),
1687
+ ]
1688
+ datasaver .add_result (* results )
1689
+
1690
+ ds = datasaver .dataset
1691
+
1692
+ # export to xarray and ensure direct path used
1693
+ caplog .clear ()
1694
+ with caplog .at_level (logging .INFO ):
1695
+ xr_ds = ds .to_xarray_dataset ()
1696
+
1697
+ assert any (
1698
+ "Exporting signal to xarray using direct method" in record .message
1699
+ for record in caplog .records
1700
+ )
1701
+
1702
+ # Expected sizes per coordinate (all setpoints)
1703
+ expected_sizes = {name : len (vals ) for name , vals in zip (sp_names , sp_values )}
1704
+ assert xr_ds .sizes == expected_sizes
1705
+
1706
+ # Check setpoint coords contents and order
1707
+ for name , vals in zip (sp_names , sp_values ):
1708
+ assert name in xr_ds .coords
1709
+ np .testing .assert_allclose (xr_ds .coords [name ].values , vals )
1710
+
1711
+ # Measured data dims and values
1712
+ assert "signal" in xr_ds .data_vars
1713
+ assert xr_ds ["signal" ].dims == tuple (sp_names )
1714
+
1715
+ grids_all = np .meshgrid (* sp_values , indexing = "ij" )
1716
+ expected_signal = np .zeros (tuple (points_per_dim ), dtype = float )
1717
+ for i , grid in enumerate (grids_all ):
1718
+ expected_signal += weights [i ] * grid .astype (float )
1719
+ np .testing .assert_allclose (xr_ds ["signal" ].values , expected_signal )
1720
+
1721
+ # Inferred coord should be present with dims equal to the subset order
1722
+ assert "inf" in xr_ds .coords
1723
+ expected_inf_dims = tuple (inf_sp_names )
1724
+ assert xr_ds .coords ["inf" ].dims == expected_inf_dims
1725
+
1726
+ # Build expected inferred grid based only on the subset dims
1727
+ subset_values = [sp_values [i ] for i in inf_indices ]
1728
+ grids_subset = np .meshgrid (* subset_values , indexing = "ij" ) if subset_values else []
1729
+ expected_inf = np .zeros (tuple (points_per_dim [i ] for i in inf_indices ), dtype = float )
1730
+ for grid in grids_subset :
1731
+ expected_inf += grid .astype (float )
1732
+ np .testing .assert_allclose (xr_ds .coords ["inf" ].values , expected_inf )
1733
+
1734
+ # The indexes of the inferred coord must correspond to the axes it depends on
1735
+ # i.e., keys should match the inferred-from setpoint names, and each index equal
1736
+ # to the dataset's index for that dimension
1737
+ inf_indexes = xr_ds .coords ["inf" ].indexes
1738
+ assert set (inf_indexes .keys ()) == set (inf_sp_names )
1739
+ for dim in inf_sp_names :
1740
+ assert inf_indexes [dim ].equals (xr_ds .indexes [dim ])
0 commit comments