Skip to content

Commit bb913ee

Browse files
committed
Add test that coords are not duplicated
1 parent 0856fba commit bb913ee

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

tests/dataset/test_dataset_export.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1973,6 +1973,63 @@ def test_measurement_2d_with_inferred_setpoint(
19731973
assert inf_idx["y"].equals(xr_ds.indexes["y"])
19741974

19751975

1976+
def test_measurement_2d_with_inferred_setpoint_from_setpoint(
1977+
experiment: Experiment, caplog: LogCaptureFixture
1978+
) -> None:
1979+
"""
1980+
This is not a good idea but a user can do this
1981+
"""
1982+
# Grid sizes
1983+
nx, ny = 3, 4
1984+
x_vals = np.linspace(0.0, 2.0, nx)
1985+
y_vals = np.linspace(10.0, 13.0, ny)
1986+
1987+
meas = Measurement(exp=experiment, name="2d_with_inferred_setpoint")
1988+
# Register setpoint x
1989+
meas.register_custom_parameter("x", paramtype="numeric")
1990+
1991+
# Register y as setpoint inferred from basis
1992+
meas.register_custom_parameter("y", basis=("x"), paramtype="numeric")
1993+
# Register measured parameter depending on (x, y)
1994+
meas.register_custom_parameter("signal", setpoints=("x", "y"), paramtype="numeric")
1995+
meas.set_shapes({"signal": (nx, ny)})
1996+
1997+
with meas.run() as datasaver:
1998+
for ix in range(nx):
1999+
for iy in range(ny):
2000+
x = float(x_vals[ix])
2001+
y = float(y_vals[iy])
2002+
signal = x + 3.0 * y # deterministic function
2003+
datasaver.add_result(
2004+
("x", x),
2005+
("y", y),
2006+
("signal", signal),
2007+
)
2008+
2009+
ds = datasaver.dataset
2010+
2011+
caplog.clear()
2012+
with caplog.at_level(logging.INFO):
2013+
xr_ds = ds.to_xarray_dataset()
2014+
2015+
assert any(
2016+
"Exporting signal to xarray using direct method" in record.message
2017+
for record in caplog.records
2018+
)
2019+
2020+
# Sizes and coords
2021+
assert xr_ds.sizes == {"x": nx, "y": ny}
2022+
np.testing.assert_allclose(xr_ds.coords["x"].values, x_vals)
2023+
np.testing.assert_allclose(xr_ds.coords["y"].values, y_vals)
2024+
2025+
assert len(xr_ds.coords) == 2
2026+
2027+
# Signal dims and values
2028+
assert xr_ds["signal"].dims == ("x", "y")
2029+
expected_signal = x_vals[:, None] + 3.0 * y_vals[None, :]
2030+
np.testing.assert_allclose(xr_ds["signal"].values, expected_signal)
2031+
2032+
19762033
def test_measurement_2d_top_level_inferred_is_data_var(
19772034
experiment: Experiment, caplog: LogCaptureFixture
19782035
) -> None:

0 commit comments

Comments
 (0)