Skip to content

Commit f1dc9af

Browse files
authored
Merge pull request #7502 from jenshnielsen/jenshnielsen/fix_cache_with_inferred_params_2
Write NaN values to cache when data is missing
2 parents 8ea9634 + 98a07cc commit f1dc9af

File tree

3 files changed

+78
-10
lines changed

3 files changed

+78
-10
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed a bug where omitting data for one or more variables could result in an inconsistent dataset cache. Missing data is now filled with appropriate empty values (0, "", or NaN, depending on the data type).

src/qcodes/dataset/data_set_cache.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def _merge_data(
325325
shape: tuple[int, ...] | None,
326326
single_tree_write_status: int | None,
327327
meas_parameter: str,
328-
) -> tuple[dict[str, npt.NDArray], int | None]:
328+
) -> tuple[dict[str, npt.NDArray], int]:
329329
subtree_merged_data = {}
330330
subtree_parameters = existing_data.keys()
331331

@@ -335,20 +335,19 @@ def _merge_data(
335335
"The following keys were unexpected: "
336336
f"{set(new_data.keys() - existing_data.keys())}"
337337
)
338-
339-
new_write_status: int | None
340-
single_param_merged_data, new_write_status = _merge_data_single_param(
338+
single_param_merged_data, data_written = _merge_data_single_param(
341339
existing_data.get(meas_parameter),
342340
new_data.get(meas_parameter),
343341
shape,
344342
single_tree_write_status,
345343
)
344+
new_write_status = data_written if data_written is not None else 0
346345
if single_param_merged_data is not None:
347346
subtree_merged_data[meas_parameter] = single_param_merged_data
348347

349348
for subtree_param in subtree_parameters:
350349
if subtree_param != meas_parameter:
351-
single_param_merged_data, new_write_status = _merge_data_single_param(
350+
single_param_merged_data, data_written = _merge_data_single_param(
352351
existing_data.get(subtree_param),
353352
new_data.get(subtree_param),
354353
shape,
@@ -357,6 +356,9 @@ def _merge_data(
357356
if single_param_merged_data is not None:
358357
subtree_merged_data[subtree_param] = single_param_merged_data
359358

359+
if data_written is not None and data_written > new_write_status:
360+
new_write_status = data_written
361+
360362
return subtree_merged_data, new_write_status
361363

362364

@@ -373,22 +375,34 @@ def _merge_data_single_param(
373375
(merged_data, new_write_status) = _insert_into_data_dict(
374376
existing_values, new_values, single_tree_write_status, shape=shape
375377
)
376-
elif new_values is not None:
378+
elif new_values is not None or shape is not None:
377379
(merged_data, new_write_status) = _create_new_data_dict(new_values, shape)
378380
elif existing_values is not None:
379381
merged_data = existing_values
380382
new_write_status = single_tree_write_status
383+
elif shape is None and new_values is None:
384+
merged_data = existing_values
385+
new_write_status = single_tree_write_status
381386
else:
382387
merged_data = None
383388
new_write_status = None
384389
return merged_data, new_write_status
385390

386391

387392
def _create_new_data_dict(
388-
new_values: npt.NDArray, shape: tuple[int, ...] | None
389-
) -> tuple[npt.NDArray, int]:
390-
if shape is None:
393+
new_values: npt.NDArray | None, shape: tuple[int, ...] | None
394+
) -> tuple[npt.NDArray, int | None]:
395+
if shape is None and new_values is None:
396+
raise RuntimeError("Cannot create new data dict without new values")
397+
elif shape is None:
398+
assert new_values is not None
391399
return new_values, new_values.size
400+
elif new_values is None:
401+
# we don't know the datatype so use float which can hold NaN
402+
# since that is the most common?
403+
data = np.zeros(shape)
404+
data[:] = np.nan
405+
return data, None
392406
elif new_values.size > 0:
393407
n_values = new_values.size
394408
data = np.zeros(shape, dtype=new_values.dtype)

tests/dataset/test_dataset_in_memory.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,29 @@
33
import re
44
import shutil
55
from pathlib import Path
6+
from typing import TYPE_CHECKING
67

78
import hypothesis.strategies as hst
89
import numpy as np
910
import pytest
1011
import xarray as xr
12+
from deepdiff import DeepDiff # type: ignore[import-untyped]
1113
from hypothesis import HealthCheck, given, settings
1214
from numpy.testing import assert_almost_equal
1315

1416
import qcodes
15-
from qcodes.dataset import load_by_id, load_by_run_spec
17+
from qcodes.dataset import Measurement, load_by_id, load_by_run_spec
1618
from qcodes.dataset.data_set_in_memory import DataSetInMem, load_from_file
1719
from qcodes.dataset.data_set_protocol import DataSetType
1820
from qcodes.dataset.descriptions.dependencies import InterDependencies_
1921
from qcodes.dataset.descriptions.param_spec import ParamSpecBase
2022
from qcodes.dataset.sqlite.connection import AtomicConnection, atomic_transaction
23+
from qcodes.parameters import ManualParameter, Parameter
2124
from qcodes.station import Station
2225

26+
if TYPE_CHECKING:
27+
from qcodes.dataset.experiment_container import Experiment
28+
2329

2430
def test_dataset_in_memory_reload_from_db(
2531
meas_with_registered_param, DMM, DAC, tmp_path
@@ -676,3 +682,50 @@ def test_load_from_db_dataset_moved(
676682
not in new_xr_ds.attrs
677683
)
678684
assert new_xr_ds.attrs["metadata_added_after_set_new_netcdf_location"] == 6969
685+
686+
687+
@pytest.mark.parametrize("include_inferred_data", [True, False])
688+
def test_dataset_in_mem_with_inferred_parameters(
689+
experiment: "Experiment", include_inferred_data: bool
690+
) -> None:
691+
inferred1 = ManualParameter("inferred1", initial_value=0.0)
692+
inferred2 = ManualParameter("inferred2", initial_value=0.0)
693+
control1 = ManualParameter("control1", initial_value=0.0)
694+
control2 = ManualParameter("control2", initial_value=0.0)
695+
dependent = Parameter("dependent", get_cmd=lambda: control1(), set_cmd=False)
696+
meas = Measurement(exp=experiment, name="via Measurement")
697+
698+
meas.register_parameter(control1)
699+
meas.register_parameter(control2)
700+
meas.register_parameter(inferred1, basis=(control1, control2))
701+
meas.register_parameter(inferred2, basis=(control1, control2))
702+
meas.register_parameter(dependent, setpoints=(control1, control2))
703+
meas.set_shapes({dependent.register_name: (11, 11)})
704+
with meas.run() as datasaver:
705+
for i in range(11):
706+
for j in range(11):
707+
control1(float(i))
708+
control2(float(j))
709+
if include_inferred_data:
710+
datasaver.add_result(
711+
(inferred1, inferred1()),
712+
(inferred2, inferred2()),
713+
(control1, control1()),
714+
(control2, control2()),
715+
(dependent, dependent()),
716+
)
717+
else:
718+
datasaver.add_result(
719+
(control1, control1()),
720+
(control2, control2()),
721+
(dependent, dependent()),
722+
)
723+
ds = datasaver.dataset
724+
725+
param_data = ds.get_parameter_data()
726+
cache_data = ds.cache.data()
727+
728+
assert set(param_data.keys()) == set(cache_data.keys())
729+
assert set(param_data["dependent"].keys()) == set(cache_data["dependent"].keys())
730+
731+
assert DeepDiff(param_data, cache_data, ignore_nan_inequality=True) == {}

0 commit comments

Comments
 (0)