Skip to content

Commit f4c3431

Browse files
tbody-cfshassec
andauthored
Input_file_handling_robustness (#107)
* Improve flexibility of input file handling * Split input dict processing from yaml read * Repair tests * Update cfspopcon/input_file_handling.py * Update cfspopcon/input_file_handling.py Co-authored-by: Christoph Hasse <[email protected]> --------- Co-authored-by: Christoph Hasse <[email protected]>
1 parent 520f458 commit f4c3431

21 files changed

+181
-58
lines changed

.gitignore

+4-1
Original file line numberDiff line numberDiff line change
@@ -141,4 +141,7 @@ radas_dir/*
141141
popcon_algorithms.yaml
142142

143143
# Have an untracked folder for rough working
144-
untracked/
144+
untracked/
145+
# Have a cases folder for personal cases which shouldn't be added
146+
# to the index
147+
cases/*

cfspopcon/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from . import file_io, formulas, named_options, shaping_and_selection
99
from .algorithm_class import Algorithm, CompositeAlgorithm
1010
from .formulas.atomic_data import AtomicData
11-
from .input_file_handling import read_case
11+
from .input_file_handling import process_input_dictionary, read_case
1212
from .plotting import read_plot_style
1313
from .unit_handling import (
1414
convert_to_default_units,
@@ -23,6 +23,7 @@
2323
"named_options",
2424
"magnitude_in_default_units",
2525
"convert_to_default_units",
26+
"process_input_dictionary",
2627
"set_default_units",
2728
"convert_units",
2829
"read_case",

cfspopcon/algorithm_class.py

+10
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,16 @@ def function_wrapper(func: GenericFunctionType) -> GenericFunctionType:
174174

175175
return function_wrapper
176176

177+
@classmethod
178+
def empty(cls) -> Algorithm:
179+
"""Makes a 'do nothing' algorithm, in case you don't want to use the algorithm functionality."""
180+
181+
def do_nothing() -> dict[str, Any]:
182+
result_dict: dict[str, Any] = {}
183+
return result_dict
184+
185+
return cls(do_nothing, return_keys=[], name="empty", skip_registration=True)
186+
177187
def validate_inputs(
178188
self, configuration: Union[dict, xr.Dataset], quiet: bool = False, raise_error_on_missing_inputs: bool = False
179189
) -> bool:

cfspopcon/file_io.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,13 @@
2424
]
2525

2626

27-
def sanitize_variable(val: xr.DataArray, key: str) -> Union[xr.DataArray, str]:
28-
"""Strip units and Enum values from a variable so that it can be stored in a NetCDF file."""
27+
def sanitize_variable(val: xr.DataArray, key: str, coord: bool = False) -> Union[xr.DataArray, str]:
28+
"""Strip units and Enum values from a variable so that it can be stored in a NetCDF file.
29+
30+
If you set coord=True and you pass in a scalar val, val is wrapped in a length-1 array to
31+
circumvent an xarray issue regarding single-value coordinates.
32+
See https://github.com/pydata/xarray/issues/1709.
33+
"""
2934
try:
3035
val = convert_to_default_units(val, key).pint.dequantify()
3136
except KeyError:
@@ -34,9 +39,12 @@ def sanitize_variable(val: xr.DataArray, key: str) -> Union[xr.DataArray, str]:
3439
if val.dtype == object:
3540
try:
3641
if val.size == 1:
37-
val = val.item().name
42+
if not coord:
43+
val = val.item().name
44+
else:
45+
val = xr.DataArray([val.item().name])
3846
else:
39-
val = xr.DataArray([v.name for v in val.values])
47+
val = xr.DataArray([v.name for v in val.values], dims=val.dims)
4048
except AttributeError:
4149
warnings.warn(f"Cannot handle {key}. Dropping variable.", stacklevel=3)
4250
return "UNHANDLED"
@@ -62,7 +70,7 @@ def write_dataset_to_netcdf(
6270

6371
for key in serialized_dataset.coords: # type:ignore[assignment]
6472
assert isinstance(key, str)
65-
serialized_dataset[key] = sanitize_variable(dataset[key], key)
73+
serialized_dataset[key] = sanitize_variable(dataset[key], key, coord=True)
6674

6775
serialized_dataset.to_netcdf(filepath, engine=netcdf_writer)
6876

cfspopcon/input_file_handling.py

+45-20
Original file line numberDiff line numberDiff line change
@@ -22,43 +22,68 @@ def read_case(
2222
2323
kwargs can be an arbitrary dictionary of key-value pairs that overwrite the config values.
2424
"""
25-
if kwargs is None:
26-
kwargs = dict()
27-
if Path(case).exists():
28-
case = Path(case)
29-
if case.is_dir():
30-
input_file = case / "input.yaml"
31-
else:
32-
input_file = case
33-
else:
25+
case = Path(case)
26+
27+
if not case.exists():
3428
raise FileNotFoundError(f"Could not find {case}.")
3529

30+
if case.is_dir():
31+
case_dir = case
32+
input_file = case_dir / "input.yaml"
33+
else:
34+
case_dir = case.parent
35+
input_file = case
36+
3637
with open(input_file) as file:
37-
repr_d: dict[str, Any] = yaml.load(file, Loader=yaml.FullLoader)
38+
repr_d = yaml.load(file, Loader=yaml.FullLoader)
3839

39-
repr_d.update(kwargs)
40+
if kwargs is not None:
41+
repr_d.update(kwargs)
4042

41-
algorithms = repr_d.pop("algorithms")
42-
algorithm_list = [Algorithm.get_algorithm(algorithm) for algorithm in algorithms]
43+
return process_input_dictionary(repr_d, case_dir)
44+
45+
46+
def process_input_dictionary(
47+
repr_d: dict[str, Any], case_dir: Path
48+
) -> tuple[dict[str, Any], Union[CompositeAlgorithm, Algorithm], dict[str, Any], dict[str, Path]]:
49+
"""Convert an input dictionary into an processed dictionary, a CompositeAlgorithm and dictionaries defining points and plots.
50+
51+
Several processing steps are applied, including;
52+
* The `algorithms` entry is converted into a `cfspopcon.CompositeAlgorithm`. This basically gives the list of operations that we want to perform on the input data.
53+
* The `points` entry is stored in a separate dictionary. This gives a set of key-value pairs of 'optimal' points (for instance, giving the point with the maximum fusion power gain).
54+
* The `grids` entry is converted into an `xr.DataArray` storing a `np.linspace` or `np.logspace` of values which we scan over. We usually scan over `average_electron_density` and `average_electron_temp`, but there's nothing preventing you from scanning over other numerical input variables or having more than 2 dimensions which you scan over (n.b. this can get expensive!).
55+
* Each input variable is checked to see if its name matches one of the enumerators in `cfspopcon.named_options`. These are used to store switch values, such as `cfspopcon.named_options.ReactionType.DT` which indicates that we're interested in the DT fusion reaction.
56+
* Each input variable is converted into its default units. Default units are retrieved via the `cfspopcon.unit_handling.default_unit` function. This will set, for instance, the `average_electron_temp` values to have units of `keV`.
57+
58+
Args:
59+
repr_d: Dictionary to process
60+
case_dir: Relative paths specified in repr_d are interpreted as relative to this directory
61+
"""
62+
algorithms = repr_d.pop("algorithms", dict())
63+
algorithm_list: list[Union[Algorithm, CompositeAlgorithm]] = [Algorithm.get_algorithm(algorithm) for algorithm in algorithms]
4364

44-
# why doesn't mypy deduce the below without hint?
45-
algorithm: Union[Algorithm, CompositeAlgorithm] = CompositeAlgorithm(algorithm_list) if len(algorithm_list) > 1 else algorithm_list[0]
65+
if len(algorithm_list) > 1:
66+
algorithm = CompositeAlgorithm(algorithm_list)
67+
elif len(algorithm_list) == 1:
68+
algorithm = algorithm_list[0] # type:ignore[assignment]
69+
elif len(algorithm_list) == 0:
70+
algorithm = Algorithm.empty() # type:ignore[assignment]
4671

4772
points = repr_d.pop("points", dict())
4873
plots = repr_d.pop("plots", dict())
4974

5075
process_grid_values(repr_d)
5176
process_named_options(repr_d)
52-
process_paths(repr_d, input_file)
53-
process_paths(plots, input_file)
77+
process_paths(repr_d, case_dir)
78+
process_paths(plots, case_dir)
5479
process_units(repr_d)
5580

5681
return repr_d, algorithm, points, plots
5782

5883

5984
def process_grid_values(repr_d: dict[str, Any]): # type:ignore[no-untyped-def]
6085
"""Process the grid of values to run POPCON over."""
61-
grid_values = repr_d.pop("grid")
86+
grid_values = repr_d.pop("grid", dict())
6287
for key, grid_spec in grid_values.items():
6388
grid_spacing = grid_spec.get("spacing", "linear")
6489

@@ -81,15 +106,15 @@ def process_named_options(repr_d: dict[str, Any]): # type:ignore[no-untyped-def
81106
repr_d[key] = convert_named_options(key=key, val=val)
82107

83108

84-
def process_paths(repr_d: dict[str, Any], input_file: Path): # type:ignore[no-untyped-def]
109+
def process_paths(repr_d: dict[str, Any], case_dir: Path): # type:ignore[no-untyped-def]
85110
"""Process path tags, up to a maximum of one tag per input variable.
86111
87112
Allowed tags are:
88113
* CASE_DIR: the folder that the input.yaml file is located in
89114
* WORKING_DIR: the current working directory that the script is being run from
90115
"""
91116
path_mappings = dict(
92-
CASE_DIR=input_file.parent,
117+
CASE_DIR=case_dir,
93118
WORKING_DIR=Path("."),
94119
)
95120
if repr_d is None:

docs/conf.py

+10
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,16 @@
4141
r"https://doi.org/10.13182/FST43-67",
4242
r"https://www.tandfonline.com/doi/full/10.13182/FST43-67",
4343
r"https://www-internal.psfc.mit.edu/research/alcator/data/fst_cmod.pdf",
44+
# these links in the time_independent_inductances_and_fluxes notebook are on private servers that are sometimes down
45+
r"https://fire.pppl.gov/iaea06_ftp7_5_matsukawa.pdf",
46+
r"https://escholarship.org/content/qt78k0v04v/qt78k0v04v_noSplash_c44c701847deffab65024dd9ceff9c59.pdf?t=p15pc5",
47+
r"https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=585f5eb3f62f3bd76f3d667c1df357562f54c084",
48+
r"https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=585f5eb3f62f3bd76f3d667c1df357562f54c084",
49+
r"https://fire.pppl.gov/Snowmass_BP/FIRE.pdf",
50+
r"https://www.ipp.mpg.de/16208/einfuehrung",
51+
r"https://www.ipp.mpg.de/16701/jet",
52+
r"https://iopscience.iop.org/article/10.1088/1009-0630/13/1/01",
53+
r"https://www-internal.psfc.mit.edu/research/alcator/data/fst_cmod.pdf",
4454
]
4555
linkcheck_retries = 5
4656
linkcheck_timeout = 120

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ classifiers = [
1919

2020
[tool.poetry.scripts]
2121
popcon = 'cfspopcon.cli:run_popcon_cli'
22+
cfspopcon = 'cfspopcon.cli:run_popcon_cli'
2223
popcon_algorithms = 'cfspopcon.cli:write_algorithms_yaml'
2324

2425
[tool.poetry.dependencies]
-12 Bytes
Binary file not shown.

tests/regression_results/PRD.json

+1-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"Tungsten"
2222
],
2323
"dims": [
24-
"dim_0"
24+
"dim_species"
2525
]
2626
}
2727
},
@@ -1333,7 +1333,6 @@
13331333
}
13341334
},
13351335
"dims": {
1336-
"dim_0": 5,
13371336
"dim_rho": 50,
13381337
"dim_species": 5
13391338
}
-187 Bytes
Binary file not shown.

tests/test_algorithms_class.py

+10
Original file line numberDiff line numberDiff line change
@@ -270,3 +270,13 @@ def test_get_algorithm():
270270
for key in Algorithm.algorithms():
271271
alg = Algorithm.get_algorithm(key)
272272
assert alg._name in [f"run_{key}", key, "<lambda>"]
273+
274+
275+
def test_blank_algorithm():
276+
test_ds = xr.Dataset(data_vars=dict(a=xr.DataArray([1, 2, 3])))
277+
278+
algorithm = Algorithm.empty()
279+
280+
updated_ds = algorithm.update_dataset(test_ds)
281+
282+
xr.testing.assert_allclose(test_ds, updated_ds)

tests/test_confinement_switch.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from cfspopcon.formulas.energy_confinement.switch_confinement_scaling_on_threshold import switch_to_L_mode_confinement_below_threshold
2-
from cfspopcon.formulas.energy_confinement.solve_for_input_power import solve_energy_confinement_scaling_for_input_power
3-
from cfspopcon.formulas.energy_confinement.read_energy_confinement_scalings import read_confinement_scalings
4-
from cfspopcon.unit_handling import ureg, magnitude_in_units
51
import numpy as np
62

3+
from cfspopcon.formulas.energy_confinement.read_energy_confinement_scalings import read_confinement_scalings
4+
from cfspopcon.formulas.energy_confinement.solve_for_input_power import solve_energy_confinement_scaling_for_input_power
5+
from cfspopcon.formulas.energy_confinement.switch_confinement_scaling_on_threshold import switch_to_L_mode_confinement_below_threshold
6+
from cfspopcon.unit_handling import magnitude_in_units, ureg
7+
78

89
def test_switch_to_L_mode_confinement_below_threshold():
910
kwargs = dict(

tests/test_docs.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
import pytest
21
import subprocess
32
import warnings
43

4+
import pytest
5+
56
pytest.importorskip("sphinx")
67
from importlib.resources import files
78

89

910
@pytest.mark.docs
1011
def test_docs():
11-
"Test the Sphinx documentation."
12+
"""Test the Sphinx documentation."""
1213
popcon_directory = files("cfspopcon")
1314

1415
doctest_output = subprocess.run(

tests/test_for_anonymous_algorithms.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from cfspopcon.algorithm_class import Algorithm
2-
import cfspopcon
31
from importlib import import_module
42

3+
import cfspopcon
4+
from cfspopcon.algorithm_class import Algorithm
5+
56

67
def import_all_submodules(importable, module, prefix):
78
for module in module.__all__:

tests/test_helpers.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1+
import numpy as np
12
import pytest
23
import xarray as xr
3-
import numpy as np
44

55
from cfspopcon import named_options
6-
from cfspopcon.helpers import (
7-
convert_named_options,
8-
)
9-
from cfspopcon.named_options import AtomicSpecies
106
from cfspopcon.formulas.impurities.impurity_array_helpers import (
117
extend_impurity_concentration_array,
128
make_impurity_concentration_array,
139
make_impurity_concentration_array_from_kwargs,
1410
)
11+
from cfspopcon.helpers import (
12+
convert_named_options,
13+
)
14+
from cfspopcon.named_options import AtomicSpecies
1515

1616

1717
def test_convert_named_options():

tests/test_infra/test_line_selection.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import numpy as np
22
import xarray as xr
3-
from cfspopcon.unit_handling import ureg, Quantity, magnitude, convert_units
43

5-
from cfspopcon.shaping_and_selection.line_selection import interpolate_onto_line, find_coords_of_contour
4+
from cfspopcon.shaping_and_selection.line_selection import find_coords_of_contour, interpolate_onto_line
5+
from cfspopcon.unit_handling import Quantity, convert_units, magnitude, ureg
66

77

88
def test_extract_values_along_contour():

tests/test_input_file_handling.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from pathlib import Path
2+
3+
import pytest
4+
import yaml
5+
6+
from cfspopcon.algorithm_class import Algorithm, CompositeAlgorithm
7+
from cfspopcon.input_file_handling import read_case, process_input_dictionary
8+
9+
10+
@pytest.fixture
11+
def test_dict():
12+
return dict(Q=1.0)
13+
14+
15+
@pytest.fixture
16+
def case_dir():
17+
return Path(".").absolute()
18+
19+
20+
def test_blank_dictionary(test_dict, case_dir):
21+
process_input_dictionary(test_dict, case_dir)
22+
23+
24+
def test_blank_file(test_dict, tmp_path):
25+
with open(tmp_path / "input.yaml", "w") as file:
26+
yaml.dump(test_dict, file)
27+
28+
read_case(tmp_path)
29+
30+
31+
def test_blank_file_with_another_suffix(test_dict, tmp_path):
32+
with open(tmp_path / "another.filetype", "w") as file:
33+
yaml.dump(test_dict, file)
34+
35+
read_case(tmp_path / "another.filetype")
36+
37+
38+
def test_algorithm_read_single_from_input_file(case_dir):
39+
test_dict = dict(algorithms=["read_atomic_data"])
40+
41+
repr_d, algorithm, points, plots = process_input_dictionary(test_dict, case_dir)
42+
43+
assert isinstance(algorithm, Algorithm)
44+
45+
46+
def test_algorithm_read_multiple_from_input_file(case_dir):
47+
test_dict = dict(algorithms=["read_atomic_data", "set_up_impurity_concentration_array"])
48+
49+
repr_d, algorithm, points, plots = process_input_dictionary(test_dict, case_dir)
50+
51+
assert isinstance(algorithm, CompositeAlgorithm)
52+
53+
54+
def test_read_example_input_file():
55+
example_case = Path(__file__).parents[1] / "example_cases" / "SPARC_PRD" / "input.yaml"
56+
57+
read_case(example_case)

0 commit comments

Comments
 (0)