Skip to content

Commit effe676

Browse files
authored
pyfabm: add NamedObjectList.index, improve annotations and mask access (#95)
* pyfabm: make index method of NamedObjectList take str * suppress exception chaining * improve type annotation * improve mask access
1 parent 1509670 commit effe676

File tree

1 file changed

+23
-12
lines changed

1 file changed

+23
-12
lines changed

src/pyfabm/__init__.py

+23-12
Original file line numberDiff line numberDiff line change
@@ -886,19 +886,27 @@ def __init__(self, *data: Iterable[T]):
886886
def __len__(self) -> int:
887887
return len(self._data)
888888

889-
def __getitem__(self, key: Union[str, int]) -> T:
889+
def __getitem__(self, key: Union[int, str]) -> T:
890890
if isinstance(key, str):
891891
return self.find(key)
892892
return self._data[key]
893893

894-
def __contains__(self, key: Union[str, int]) -> bool:
894+
def __contains__(self, key: Union[T, str]) -> bool:
895895
if isinstance(key, str):
896896
try:
897897
self.find(key)
898898
return True
899899
except KeyError:
900900
return False
901-
return super().__contains__(key)
901+
return key in self._data
902+
903+
def index(self, key: Union[T, str], *args) -> int:
904+
if isinstance(key, str):
905+
try:
906+
key = self.find(key)
907+
except KeyError:
908+
raise ValueError from None
909+
return self._data.index(key, *args)
902910

903911
def __repr__(self) -> str:
904912
return repr(self._data)
@@ -1079,10 +1087,15 @@ def link_mask(self, *masks: np.ndarray):
10791087
self._mask = masks
10801088
self.fabm.set_mask(self.pmodel, *self._mask)
10811089

1082-
def _get_mask(self) -> Union[np.ndarray, Sequence[np.ndarray]]:
1083-
return self._mask[0] if len(self._mask) == 1 else self._mask
1084-
1085-
def _set_mask(self, values: Union[npt.ArrayLike, Sequence[npt.ArrayLike]]):
1090+
@property
1091+
def mask(self) -> Union[np.ndarray, Sequence[np.ndarray], None]:
1092+
mask = self._mask
1093+
if mask is not None and len(mask) == 1:
1094+
mask = mask[0]
1095+
return mask
1096+
1097+
@mask.setter
1098+
def mask(self, values: Union[npt.ArrayLike, Sequence[npt.ArrayLike]]):
10861099
if self.fabm.mask_type == 1:
10871100
values = (values,)
10881101
if len(values) != self.fabm.mask_type:
@@ -1096,8 +1109,6 @@ def _set_mask(self, values: Union[npt.ArrayLike, Sequence[npt.ArrayLike]]):
10961109
if value is not mask:
10971110
mask[...] = value
10981111

1099-
mask = property(_get_mask, _set_mask)
1100-
11011112
def link_bottom_index(self, indices: np.ndarray):
11021113
if not self.fabm.variable_bottom_index:
11031114
raise FABMException(
@@ -1401,7 +1412,7 @@ def _update_configuration(self, settings: Optional[Tuple] = None):
14011412
+ self.horizontal_dependencies
14021413
+ self.scalar_dependencies
14031414
)
1404-
self.variables = (
1415+
self.variables: NamedObjectList[VariableFromPointer] = (
14051416
self.state_variables + self.diagnostic_variables + self.dependencies
14061417
)
14071418

@@ -1414,7 +1425,7 @@ def _update_configuration(self, settings: Optional[Tuple] = None):
14141425

14151426
self.itime = -1.0
14161427

1417-
def getRates(self, t: float = None, surface: bool = True, bottom: bool = True):
1428+
def getRates(self, t: Optional[float] = None, surface: bool = True, bottom: bool = True):
14181429
"""Returns the local rate of change in state variables,
14191430
given the current state and environment.
14201431
"""
@@ -1451,7 +1462,7 @@ def getRates(self, t: float = None, surface: bool = True, bottom: bool = True):
14511462

14521463
def get_sources(
14531464
self,
1454-
t: float = None,
1465+
t: Optional[float] = None,
14551466
out: Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]] = None,
14561467
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
14571468
if t is None:

0 commit comments

Comments
 (0)