-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Refactor into CoupledModels -> DriverCoupling -> Mappings * More refactoring * More work * Remove dataclasses for driver coupling classes, use pydantic basemodel instead * Cleanup mapping types * Add arbitrary_types_allowed for pydantic and geodataframes * Fix ribasim type hints * Avoid index kwarg type error * Export DriverCoupling. Move mappings one level up for ease of import in tests * Get test_primod running again * Update cases * Fixes for pre-processing tests * DriverCoupling name changes * Tests running locally * Create SvatUserDemandMapping class Try to improve logic a bit, probably WIP... * Please mypy * Address review comments * lower dir names. Fixes #228
- Loading branch information
Showing
36 changed files
with
1,722 additions
and
1,684 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,21 @@ | ||
from primod.driver_coupling import ( | ||
MetaModDriverCoupling, | ||
RibaMetaDriverCoupling, | ||
RibaModActiveDriverCoupling, | ||
RibaModPassiveDriverCoupling, | ||
) | ||
from primod.metamod import MetaMod | ||
from primod.ribametamod import RibaMetaMod | ||
from primod.ribamod import RibaMod | ||
|
||
__all__ = ["MetaMod", "RibaMod", "RibaMetaMod"] | ||
__all__ = ( | ||
"MetaMod", | ||
"RibaMod", | ||
"RibaMetaMod", | ||
"MetaModDriverCoupling", | ||
"RibaMetaDriverCoupling", | ||
"RibaModActiveDriverCoupling", | ||
"RibaModPassiveDriverCoupling", | ||
) | ||
|
||
__version__ = "2024.2.1" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import abc | ||
from collections.abc import Sequence | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
from primod.driver_coupling.driver_coupling_base import DriverCoupling | ||
|
||
|
||
class CoupledModel(abc.ABC): | ||
coupling_list: Sequence[DriverCoupling] | ||
|
||
@abc.abstractmethod | ||
def write(self, directory: str | Path, *args: Any, **kwargs: Any) -> None: | ||
pass | ||
|
||
@abc.abstractmethod | ||
def write_toml(self, directory: str | Path, *args: Any, **kwargs: Any) -> None: | ||
pass | ||
|
||
@staticmethod | ||
def _merge_coupling_dicts(dicts: list[dict[str, Any]]) -> dict[str, Any]: | ||
coupling_dict: dict[str, dict[str, Any] | Any] = {} | ||
for top_dict in dicts: | ||
for top_key, top_value in top_dict.items(): | ||
if isinstance(top_value, dict): | ||
if top_key not in coupling_dict: | ||
coupling_dict[top_key] = {} | ||
for key, filename in top_value.items(): | ||
coupling_dict[top_key][key] = filename | ||
else: | ||
coupling_dict[top_key] = top_value | ||
return coupling_dict | ||
|
||
def write_exchanges(self, directory: str | Path) -> dict[str, Any]: | ||
""" | ||
Write exchanges and return their filenames for the coupler | ||
configuration file. | ||
""" | ||
directory = Path(directory) | ||
exchange_dir = Path(directory) / "exchanges" | ||
exchange_dir.mkdir(exist_ok=True, parents=True) | ||
|
||
coupling_dicts = [] | ||
for coupling in self.coupling_list: | ||
coupling_dict = coupling.write_exchanges( | ||
directory=exchange_dir, coupled_model=self | ||
) | ||
coupling_dicts.append(coupling_dict) | ||
|
||
# FUTURE: if we support multiple MF6 models, group them by name before | ||
# merging, and return a list of coupling_dicts. | ||
merged_coupling_dict = self._merge_coupling_dicts(coupling_dicts) | ||
return merged_coupling_dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from primod.driver_coupling.metamod import MetaModDriverCoupling | ||
from primod.driver_coupling.ribameta import RibaMetaDriverCoupling | ||
from primod.driver_coupling.ribamod import ( | ||
RibaModActiveDriverCoupling, | ||
RibaModPassiveDriverCoupling, | ||
) | ||
|
||
__all__ = ( | ||
"MetaModDriverCoupling", | ||
"RibaMetaDriverCoupling", | ||
"RibaModActiveDriverCoupling", | ||
"RibaModPassiveDriverCoupling", | ||
) |
22 changes: 22 additions & 0 deletions
22
pre-processing/primod/driver_coupling/driver_coupling_base.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import abc | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
from pydantic import BaseModel | ||
|
||
|
||
class DriverCoupling(BaseModel, abc.ABC): | ||
""" | ||
Abstract base class for driver couplings. | ||
""" | ||
|
||
# Config required for e.g. geodataframes | ||
model_config = {"arbitrary_types_allowed": True} | ||
|
||
@abc.abstractmethod | ||
def derive_mapping(self, *args: Any, **kwargs: Any) -> Any: | ||
pass | ||
|
||
@abc.abstractmethod | ||
def write_exchanges(self, directory: Path, coupled_model: Any) -> dict[str, Any]: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
from imod.mf6 import GroundwaterFlowModel | ||
from imod.msw import GridData, MetaSwapModel, Sprinkling | ||
|
||
from primod.driver_coupling.driver_coupling_base import DriverCoupling | ||
from primod.mapping.node_svat_mapping import NodeSvatMapping | ||
from primod.mapping.rch_svat_mapping import RechargeSvatMapping | ||
from primod.mapping.wel_svat_mapping import WellSvatMapping | ||
|
||
|
||
class MetaModDriverCoupling(DriverCoupling): | ||
""" | ||
Attributes | ||
---------- | ||
mf6_model : str | ||
The model of the driver. | ||
mf6_recharge_package: str | ||
Key of Modflow 6 recharge package to which MetaSWAP is coupled. | ||
mf6_wel_package: str or None | ||
Optional key of Modflow 6 well package to which MetaSWAP sprinkling is | ||
coupled. | ||
""" | ||
|
||
mf6_model: str | ||
mf6_recharge_package: str | ||
mf6_wel_package: str | None = None | ||
|
||
def _check_sprinkling( | ||
self, msw_model: MetaSwapModel, gwf_model: GroundwaterFlowModel | ||
) -> bool: | ||
sprinkling_key = msw_model._get_pkg_key(Sprinkling, optional_package=True) | ||
sprinkling_in_msw = sprinkling_key is not None | ||
sprinkling_in_mf6 = self.mf6_wel_package in gwf_model.keys() | ||
|
||
value = False | ||
match (sprinkling_in_msw, sprinkling_in_mf6): | ||
case (True, False): | ||
raise ValueError( | ||
f"No package named {self.mf6_wel_package} found in Modflow 6 model, " | ||
"but Sprinkling package found in MetaSWAP. " | ||
"iMOD Coupler requires a Well Package " | ||
"to couple wells." | ||
) | ||
case (False, True): | ||
raise ValueError( | ||
f"Modflow 6 Well package {self.mf6_wel_package} specified for sprinkling, " | ||
"but no Sprinkling package found in MetaSWAP model." | ||
) | ||
case (True, True): | ||
value = True | ||
case (False, False): | ||
value = False | ||
|
||
return value | ||
|
||
def derive_mapping( | ||
self, msw_model: MetaSwapModel, gwf_model: GroundwaterFlowModel | ||
) -> tuple[NodeSvatMapping, RechargeSvatMapping, WellSvatMapping | None]: | ||
if self.mf6_recharge_package not in gwf_model.keys(): | ||
raise ValueError( | ||
f"No package named {self.mf6_recharge_package} detected in Modflow 6 model. " | ||
"iMOD_coupler requires a Recharge package." | ||
) | ||
|
||
grid_data_key = [ | ||
pkgname for pkgname, pkg in msw_model.items() if isinstance(pkg, GridData) | ||
][0] | ||
|
||
dis = gwf_model[gwf_model._get_pkgkey("dis")] | ||
|
||
index, svat = msw_model[grid_data_key].generate_index_array() | ||
grid_mapping = NodeSvatMapping(svat=svat, modflow_dis=dis, index=index) | ||
|
||
recharge = gwf_model[self.mf6_recharge_package] | ||
|
||
rch_mapping = RechargeSvatMapping(svat, recharge, index=index) | ||
|
||
if self._check_sprinkling(msw_model=msw_model, gwf_model=gwf_model): | ||
well = gwf_model[self.mf6_wel_package] | ||
well_mapping = WellSvatMapping(svat, well, index=index) | ||
return grid_mapping, rch_mapping, well_mapping | ||
else: | ||
return grid_mapping, rch_mapping, None | ||
|
||
def write_exchanges(self, directory: Path, coupled_model: Any) -> dict[str, Any]: | ||
mf6_simulation = coupled_model.mf6_simulation | ||
gwf_model = mf6_simulation[self.mf6_model] | ||
msw_model = coupled_model.msw_model | ||
|
||
grid_mapping, rch_mapping, well_mapping = self.derive_mapping( | ||
msw_model=msw_model, | ||
gwf_model=gwf_model, | ||
) | ||
|
||
coupling_dict: dict[str, Any] = {} | ||
coupling_dict["mf6_model"] = self.mf6_model | ||
|
||
coupling_dict["mf6_msw_node_map"] = grid_mapping.write(directory) | ||
coupling_dict["mf6_msw_recharge_pkg"] = self.mf6_recharge_package | ||
coupling_dict["mf6_msw_recharge_map"] = rch_mapping.write(directory) | ||
coupling_dict["enable_sprinkling"] = False | ||
|
||
if well_mapping is not None: | ||
coupling_dict["enable_sprinkling"] = True | ||
coupling_dict["mf6_msw_well_pkg"] = self.mf6_wel_package | ||
coupling_dict["mf6_msw_sprinkling_map"] = well_mapping.write(directory) | ||
|
||
return coupling_dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
import copy | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
import geopandas as gpd | ||
import imod | ||
import numpy as np | ||
import ribasim | ||
from imod.msw import GridData, MetaSwapModel, Sprinkling | ||
|
||
from primod.driver_coupling.driver_coupling_base import DriverCoupling | ||
from primod.driver_coupling.util import ( | ||
_nullify_ribasim_exchange_input, | ||
_validate_node_ids, | ||
) | ||
from primod.mapping.svat_basin_mapping import SvatBasinMapping | ||
from primod.mapping.svat_user_demand_mapping import SvatUserDemandMapping | ||
|
||
|
||
class RibaMetaDriverCoupling(DriverCoupling): | ||
"""A dataclass representing one coupling scenario for the RibaMod driver. | ||
Attributes | ||
---------- | ||
basin_definition: gpd.GeoDataFrame | ||
GeoDataFrame of basin polygons | ||
user_demand_definition: gpd.GeoDataFrame | ||
GeoDataFrame of user demand polygons | ||
""" | ||
|
||
ribasim_basin_definition: gpd.GeoDataFrame | ||
ribasim_user_demand_definition: gpd.GeoDataFrame | None = None | ||
|
||
def _check_sprinkling(self, msw_model: MetaSwapModel) -> bool: | ||
sprinkling_key = msw_model._get_pkg_key(Sprinkling, optional_package=True) | ||
sprinkling_in_msw = sprinkling_key is not None | ||
sprinkling_in_ribasim = self.ribasim_user_demand_definition is not None | ||
|
||
if sprinkling_in_ribasim: | ||
if sprinkling_in_msw: | ||
return True | ||
else: | ||
raise ValueError( | ||
"Ribasim UserDemand definition provided, " | ||
"but no Sprinkling package found in MetaSWAP model." | ||
) | ||
else: | ||
return False | ||
|
||
def derive_mapping( | ||
self, | ||
ribasim_model: ribasim.Model, | ||
msw_model: MetaSwapModel, | ||
) -> tuple[SvatBasinMapping, SvatUserDemandMapping | None]: | ||
grid_data_key = [ | ||
pkgname for pkgname, pkg in msw_model.items() if isinstance(pkg, GridData) | ||
][0] | ||
|
||
index, svat = msw_model[grid_data_key].generate_index_array() | ||
basin_ids = _validate_node_ids( | ||
ribasim_model.basin.node.df, self.ribasim_basin_definition | ||
) | ||
gridded_basin = imod.prepare.rasterize( | ||
self.ribasim_basin_definition, | ||
like=svat, | ||
column="node_id", | ||
) | ||
svat_basin_mapping = SvatBasinMapping( | ||
name="msw_ponding", | ||
gridded_basin=gridded_basin, | ||
basin_ids=basin_ids, | ||
svat=svat, | ||
index=index, | ||
) | ||
|
||
if self._check_sprinkling(msw_model=msw_model): | ||
user_demand_ids = _validate_node_ids( | ||
ribasim_model.user_demand.node.df, self.ribasim_user_demand_definition | ||
) | ||
gridded_user_demand = imod.prepare.rasterize( | ||
self.ribasim_basin_definition, | ||
like=svat, | ||
column="node_id", | ||
) | ||
# sprinkling surface water for subsection of svats determined in 'sprinkling' | ||
swspr_grid_data = copy.deepcopy(msw_model[grid_data_key]) | ||
nsu = swspr_grid_data.dataset["area"].sizes["subunit"] | ||
swsprmax = msw_model["sprinkling"] | ||
swspr_grid_data.dataset["area"].values = np.tile( | ||
swsprmax["max_abstraction_surfacewater_m3_d"].values, | ||
(nsu, 1, 1), | ||
) | ||
index_swspr, svat_swspr = swspr_grid_data.generate_index_array() | ||
svat_user_demand_mapping = SvatUserDemandMapping( | ||
name="msw_sw_sprinkling", | ||
gridded_user_demand=gridded_user_demand, | ||
user_demand_ids=user_demand_ids, | ||
svat=svat_swspr, | ||
index=index_swspr, | ||
) | ||
return svat_basin_mapping, svat_user_demand_mapping | ||
else: | ||
return svat_basin_mapping, None | ||
|
||
def write_exchanges(self, directory: Path, coupled_model: Any) -> dict[str, Any]: | ||
ribasim_model = coupled_model.ribasim_model | ||
msw_model = coupled_model.msw_model | ||
|
||
svat_basin_mapping, svat_user_demand_mapping = self.derive_mapping( | ||
ribasim_model=ribasim_model, | ||
msw_model=msw_model, | ||
) | ||
|
||
coupling_dict: dict[str, Any] = {} | ||
coupling_dict["rib_msw_ponding_map_surface_water"] = svat_basin_mapping.write( | ||
directory=directory | ||
) | ||
|
||
# Set Ribasim runoff input to Null for coupled basins | ||
basin_ids = _validate_node_ids( | ||
ribasim_model.basin.node.df, self.ribasim_basin_definition | ||
) | ||
coupled_basin_indices = svat_basin_mapping.dataframe["basin_index"] | ||
coupled_basin_node_ids = basin_ids[coupled_basin_indices] | ||
_nullify_ribasim_exchange_input( | ||
ribasim_component=ribasim_model.basin, | ||
coupled_node_ids=coupled_basin_node_ids, | ||
columns=["runoff"], | ||
) | ||
|
||
# Now deal with sprinkling if set | ||
if svat_user_demand_mapping is not None: | ||
user_demand_ids = _validate_node_ids( | ||
ribasim_model.user_demand.node.df, self.ribasim_user_demand_definition | ||
) | ||
coupling_dict["rib_msw_sprinkling_map_surface_water"] = ( | ||
svat_user_demand_mapping.write(directory=directory) | ||
) | ||
coupled_user_demand_indices = svat_user_demand_mapping.dataframe[ | ||
"user_demand_index" | ||
] | ||
coupled_user_demand_node_ids = user_demand_ids[coupled_user_demand_indices] | ||
_nullify_ribasim_exchange_input( | ||
ribasim_component=ribasim_model.user_demand, | ||
coupled_node_ids=coupled_user_demand_node_ids, | ||
columns=["demand"], | ||
) | ||
return coupling_dict |
Oops, something went wrong.