Skip to content

Commit 0a521e8

Browse files
committed
add type hints to fabm_evaluate.py
1 parent c464d84 commit 0a521e8

File tree

1 file changed

+38
-39
lines changed

1 file changed

+38
-39
lines changed

src/pyfabm/utils/fabm_evaluate.py

+38-39
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212

1313
import sys
1414
import os
15+
from typing import Union, MutableMapping, Mapping, Iterable, cast
1516

16-
import numpy
17-
import netCDF4 # type: ignore
17+
import numpy as np
18+
import netCDF4
1819
import yaml
1920

2021
try:
@@ -25,20 +26,24 @@
2526

2627

2728
def evaluate(
28-
yaml_path,
29-
sources=(),
30-
location={},
31-
assignments={},
32-
verbose=True,
33-
ignore_missing=False,
34-
surface=True,
35-
bottom=True,
29+
yaml_path: str,
30+
sources: Iterable[str] = (),
31+
location: Mapping[str, int] = {},
32+
assignments: Mapping[str, float] = {},
33+
verbose: bool = True,
34+
ignore_missing: bool = False,
35+
surface: bool = True,
36+
bottom: bool = True,
3637
):
3738
# Create model object from YAML file.
3839
model = pyfabm.Model(yaml_path)
3940

40-
allvariables = list(model.state_variables) + list(model.dependencies)
41-
name2variable = {}
41+
allvariables: pyfabm.NamedObjectList[
42+
Union[pyfabm.StateVariable, pyfabm.Dependency]
43+
] = pyfabm.NamedObjectList(model.state_variables, model.dependencies)
44+
name2variable: MutableMapping[
45+
str, Union[pyfabm.StateVariable, pyfabm.Dependency]
46+
] = {}
4247
for variable in allvariables:
4348
name2variable[variable.name] = variable
4449
if hasattr(variable, "output_name"):
@@ -47,9 +52,11 @@ def evaluate(
4752
[(name.lower(), variable) for (name, variable) in name2variable.items()]
4853
)
4954

50-
def set_state(**dim2index):
55+
def set_state(**dim2index: int):
5156
missing = set(allvariables)
52-
variable2source = {}
57+
variable2source: MutableMapping[
58+
Union[pyfabm.StateVariable, pyfabm.Dependency], str
59+
] = {}
5360

5461
def set_variable(variable, value, source):
5562
missing.discard(variable)
@@ -65,7 +72,7 @@ def set_variable(variable, value, source):
6572
for path in sources:
6673
if path.endswith("yaml"):
6774
with open(path) as f:
68-
data = yaml.load(f)
75+
data = yaml.safe_load(f)
6976
for name, value in data.items():
7077
variable = name2variable.get(name)
7178
if variable is None:
@@ -105,22 +112,16 @@ def set_variable(variable, value, source):
105112
variable = name2variable[name]
106113
missing.discard(variable)
107114
variable2source[variable] = "command line"
108-
variable.value = float(value)
115+
variable.value = cast(np.ndarray, float(value))
109116

110117
if verbose:
111118
print()
112119
print("State:")
113-
for variable in sorted(model.state_variables, key=lambda x: x.name.lower()):
114-
print(
115-
f" {variable.name}: {variable.value}"
116-
f" [{variable2source.get(variable)}]"
117-
)
120+
for sv in sorted(model.state_variables, key=lambda x: x.name.lower()):
121+
print(f" {sv.name}: {sv.value}" f" [{variable2source.get(sv)}]")
118122
print("Environment:")
119-
for variable in sorted(model.dependencies, key=lambda x: x.name.lower()):
120-
print(
121-
f" {variable.name}: {variable.value}"
122-
f" [{variable2source.get(variable)}]"
123-
)
123+
for d in sorted(model.dependencies, key=lambda x: x.name.lower()):
124+
print(f" {d.name}: {d.value} [{variable2source.get(d)}]")
124125

125126
if missing:
126127
print("The following variables are still missing:")
@@ -137,10 +138,10 @@ def set_variable(variable, value, source):
137138
sys.exit(1)
138139

139140
print("State variables with largest value:")
140-
for variable in sorted(
141-
model.state_variables, key=lambda x: abs(x.value), reverse=True
141+
for sv in sorted(
142+
model.state_variables, key=lambda x: abs(float(x.value)), reverse=True
142143
)[:3]:
143-
print(f" {variable.name}: {variable.value} {variable.units}")
144+
print(f" {sv.name}: {sv.value} {sv.units}")
144145

145146
# Get model rates
146147
rates = model.getRates(surface=surface, bottom=bottom)
@@ -150,22 +151,20 @@ def set_variable(variable, value, source):
150151

151152
if verbose:
152153
print("Diagnostics:")
153-
for variable in sorted(
154-
model.diagnostic_variables, key=lambda x: x.name.lower()
155-
):
156-
if variable.output:
157-
print(f" {variable.name}: {variable.value} {variable.units}")
154+
for dv in sorted(model.diagnostic_variables, key=lambda x: x.name.lower()):
155+
if dv.output:
156+
print(f" {dv.name}: {dv.value} {dv.units}")
158157

159158
# Check whether rates of change are valid numbers
160-
valids = numpy.isfinite(rates)
159+
valids = np.isfinite(rates)
161160
if not valids.all():
162161
print("The following state variables have an invalid rate of change:")
163-
for variable, rate, valid in zip(model.state_variables, rates, valids):
162+
for sv, rate, valid in zip(model.state_variables, rates, valids):
164163
if not valid:
165-
print(f" {variable.name}: {rate}")
164+
print(f" {sv.name}: {rate}")
166165

167166
eps = 1e-30
168-
relative_rates = numpy.array(
167+
relative_rates = np.array(
169168
[
170169
rate / (variable.value + eps)
171170
for variable, rate in zip(model.state_variables, rates)
@@ -189,7 +188,7 @@ def set_variable(variable, value, source):
189188
)[:3]:
190189
print(f" {variable.name}: {86400 * relative_rate} d-1")
191190

192-
i = relative_rates.argmin()
191+
i = int(relative_rates.argmin())
193192
print(
194193
f"Minimum time step = {-1.0 / relative_rates[i]:%.3f} s due to decrease"
195194
f" in {model.state_variables[i].name}"

0 commit comments

Comments
 (0)