From da5b8f1810142d3f66ee078e5c9dfe016195526b Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Tue, 30 Jan 2024 18:44:35 +0100 Subject: [PATCH] Issue #708 Refactor dataset initialization (#722) Fix https://github.com/Deltares/imod-python/issues/708 The changeset is incomplete, as it is not rolled out for all packages yet. Therefore tests will fail, hence the draft status. To make the review process more focused, reviewers are asked to provide feedback on the approach. Changes: - Variable names are inferred based on what is provided to locals(). This makes it possible to avoid having to manually assign variables in the __init__ of each specific package. - I had to add one function to work around this issue: https://github.com/Deltares/xugrid/issues/179 - Grids are merged with an exact join, to avoid xarray joining variable coordinates in an unwanted way when dataarrays are inconsistent, which causes issues like https://github.com/Deltares/imod-python/issues/674 - Refactored the riv unit tests to use pytest cases, so that tests (without the xr.merge) are run for both unstructured and structured data. Update: - Variable names are not inferred with locals() anymore, but instead with a ``pkg_init`` decorator. 
--- docs/api/changelog.rst | 3 + imod/mf6/adv.py | 4 +- imod/mf6/boundary_condition.py | 15 +- imod/mf6/buy.py | 22 +- imod/mf6/chd.py | 23 +- imod/mf6/cnc.py | 14 +- imod/mf6/dis.py | 10 +- imod/mf6/disv.py | 10 +- imod/mf6/drn.py | 26 +-- imod/mf6/dsp.py | 20 +- imod/mf6/evt.py | 33 ++- imod/mf6/ghb.py | 25 +-- imod/mf6/gwfgwf.py | 22 +- imod/mf6/gwfgwt.py | 9 +- imod/mf6/hfb.py | 4 +- imod/mf6/ic.py | 4 +- imod/mf6/ims.py | 59 ++--- imod/mf6/ist.py | 42 ++-- imod/mf6/lak.py | 210 +++++++++--------- imod/mf6/mf6_hfb_adapter.py | 14 +- imod/mf6/mf6_wel_adapter.py | 16 +- imod/mf6/mst.py | 24 +- imod/mf6/npf.py | 53 +++-- imod/mf6/oc.py | 17 +- imod/mf6/package.py | 21 +- imod/mf6/pkgbase.py | 18 +- imod/mf6/rch.py | 23 +- imod/mf6/regridding_utils.py | 19 +- imod/mf6/riv.py | 28 +-- imod/mf6/src.py | 14 +- imod/mf6/ssm.py | 22 +- imod/mf6/sto.py | 28 ++- imod/mf6/timedis.py | 10 +- imod/mf6/uzf.py | 78 +++---- imod/mf6/wel.py | 97 ++++---- imod/tests/test_mf6/test_exchangebase.py | 8 +- imod/tests/test_mf6/test_mf6_lak.py | 3 +- imod/tests/test_mf6/test_mf6_riv.py | 99 ++++++--- imod/tests/test_mf6/test_mf6_sto.py | 4 +- .../test_mf6/test_mf6_transport_model.py | 7 +- imod/typing/grid.py | 110 +++++++++ 41 files changed, 733 insertions(+), 535 deletions(-) diff --git a/docs/api/changelog.rst b/docs/api/changelog.rst index 4ff8f3970..b7e6787e4 100644 --- a/docs/api/changelog.rst +++ b/docs/api/changelog.rst @@ -14,6 +14,9 @@ Fixed - iMOD Python now supports versions of pandas >= 2 - Fixed bugs with clipping :class:`imod.mf6.HorizontalFlowBarrier` for structured grids +- Packages and boundary conditions in the ``imod.mf6`` module will now throw an + error upon initialization if coordinate labels are inconsistent amongst + variables - Improved performance for merging structured multimodel Modflow 6 output - Bug where :function:`imod.formats.idf.open_subdomains` did not properly support custom patterns diff --git a/imod/mf6/adv.py b/imod/mf6/adv.py index 
e054bfebf..62b49c73b 100644 --- a/imod/mf6/adv.py +++ b/imod/mf6/adv.py @@ -18,8 +18,8 @@ class Advection(Package): _template = Package._initialize_template(_pkg_id) def __init__(self, scheme): - super().__init__() - self.dataset["scheme"] = scheme + dict_dataset = {"scheme": scheme} + super().__init__(dict_dataset) def render(self, directory, pkgname, globaltimes, binary): scheme = self.dataset["scheme"].item() diff --git a/imod/mf6/boundary_condition.py b/imod/mf6/boundary_condition.py index bbfbd62bf..7fe1c7c40 100644 --- a/imod/mf6/boundary_condition.py +++ b/imod/mf6/boundary_condition.py @@ -7,9 +7,13 @@ import xarray as xr import xugrid as xu -from imod.mf6.auxiliary_variables import get_variable_names +from imod.mf6.auxiliary_variables import ( + expand_transient_auxiliary_variables, + get_variable_names, +) from imod.mf6.package import Package from imod.mf6.write_context import WriteContext +from imod.typing.grid import GridDataArray def _dis_recarr(arrdict, layer, notnull): @@ -64,6 +68,15 @@ class BoundaryCondition(Package, abc.ABC): not the array input which is used in :class:`Package`. 
""" + def __init__(self, allargs: dict[str, GridDataArray | float | int | bool | str]): + super().__init__(allargs) + if "concentration" in allargs.keys() and allargs["concentration"] is None: + # Remove vars inplace + del self.dataset["concentration"] + del self.dataset["concentration_boundary_type"] + else: + expand_transient_auxiliary_variables(self) + def _max_active_n(self): """ Determine the maximum active number of cells that are active diff --git a/imod/mf6/buy.py b/imod/mf6/buy.py index 63dd90e4d..79d42e494 100644 --- a/imod/mf6/buy.py +++ b/imod/mf6/buy.py @@ -110,17 +110,17 @@ def __init__( densityfile: str = None, validate: bool = True, ): - super().__init__(locals()) - self.dataset["reference_density"] = reference_density - # Assign a shared index: this also forces equal lenghts - self.dataset["density_concentration_slope"] = assign_index( - density_concentration_slope - ) - self.dataset["reference_concentration"] = assign_index(reference_concentration) - self.dataset["modelname"] = assign_index(modelname) - self.dataset["species"] = assign_index(species) - self.dataset["hhformulation_rhs"] = hhformulation_rhs - self.dataset["densityfile"] = densityfile + dict_dataset = { + "reference_density": reference_density, + # Assign a shared index: this also forces equal lenghts + "density_concentration_slope": assign_index(density_concentration_slope), + "reference_concentration": assign_index(reference_concentration), + "modelname": assign_index(modelname), + "species": assign_index(species), + "hhformulation_rhs": hhformulation_rhs, + "densityfile": densityfile, + } + super().__init__(dict_dataset) self.dependencies = [] self._validate_init_schemata(validate) diff --git a/imod/mf6/chd.py b/imod/mf6/chd.py index ad39a54fb..45986320e 100644 --- a/imod/mf6/chd.py +++ b/imod/mf6/chd.py @@ -1,6 +1,5 @@ import numpy as np -from imod.mf6.auxiliary_variables import expand_transient_auxiliary_variables from imod.mf6.boundary_condition import BoundaryCondition from 
imod.mf6.regridding_utils import RegridderType from imod.mf6.validation import BOUNDARY_DIMS_SCHEMA, CONC_DIMS_SCHEMA @@ -126,17 +125,17 @@ def __init__( validate: bool = True, repeat_stress=None, ): - super().__init__(locals()) - self.dataset["head"] = head - if concentration is not None: - self.dataset["concentration"] = concentration - self.dataset["concentration_boundary_type"] = concentration_boundary_type - expand_transient_auxiliary_variables(self) - self.dataset["print_input"] = print_input - self.dataset["print_flows"] = print_flows - self.dataset["save_flows"] = save_flows - self.dataset["observations"] = observations - self.dataset["repeat_stress"] = repeat_stress + dict_dataset = { + "head": head, + "concentration": concentration, + "concentration_boundary_type": concentration_boundary_type, + "print_input": print_input, + "print_flows": print_flows, + "save_flows": save_flows, + "observations": observations, + "repeat_stress": repeat_stress, + } + super().__init__(dict_dataset) self._validate_init_schemata(validate) def _validate(self, schemata, **kwargs): diff --git a/imod/mf6/cnc.py b/imod/mf6/cnc.py index 09cbe3936..284ff9612 100644 --- a/imod/mf6/cnc.py +++ b/imod/mf6/cnc.py @@ -74,10 +74,12 @@ def __init__( observations=None, validate: bool = True, ): - super().__init__(locals()) - self.dataset["concentration"] = concentration - self.dataset["print_input"] = print_input - self.dataset["print_flows"] = print_flows - self.dataset["save_flows"] = save_flows - self.dataset["observations"] = observations + dict_dataset = { + "concentration": concentration, + "print_input": print_input, + "print_flows": print_flows, + "save_flows": save_flows, + "observations": observations, + } + super().__init__(dict_dataset) self._validate_init_schemata(validate) diff --git a/imod/mf6/dis.py b/imod/mf6/dis.py index 43bd6ecce..e636f99b3 100644 --- a/imod/mf6/dis.py +++ b/imod/mf6/dis.py @@ -91,10 +91,12 @@ class StructuredDiscretization(Package): _skip_mask_arrays = 
["bottom", "idomain"] def __init__(self, top, bottom, idomain, validate: bool = True): - super(__class__, self).__init__(locals()) - self.dataset["idomain"] = idomain - self.dataset["top"] = top - self.dataset["bottom"] = bottom + dict_dataset = { + "idomain": idomain, + "top": top, + "bottom": bottom, + } + super().__init__(dict_dataset) self._validate_init_schemata(validate) def _delrc(self, dx): diff --git a/imod/mf6/disv.py b/imod/mf6/disv.py index 656d2d981..4196384c7 100644 --- a/imod/mf6/disv.py +++ b/imod/mf6/disv.py @@ -72,10 +72,12 @@ class VerticesDiscretization(Package): _skip_mask_arrays = ["bottom", "idomain"] def __init__(self, top, bottom, idomain, validate: bool = True): - super().__init__(locals()) - self.dataset["idomain"] = idomain - self.dataset["top"] = top - self.dataset["bottom"] = bottom + dict_dataset = { + "idomain": idomain, + "top": top, + "bottom": bottom, + } + super().__init__(dict_dataset) self._validate_init_schemata(validate) def render(self, directory, pkgname, binary): diff --git a/imod/mf6/drn.py b/imod/mf6/drn.py index 6e904e104..ec04afb17 100644 --- a/imod/mf6/drn.py +++ b/imod/mf6/drn.py @@ -1,6 +1,5 @@ import numpy as np -from imod.mf6.auxiliary_variables import expand_transient_auxiliary_variables from imod.mf6.boundary_condition import BoundaryCondition from imod.mf6.regridding_utils import RegridderType from imod.mf6.validation import BOUNDARY_DIMS_SCHEMA @@ -115,18 +114,19 @@ def __init__( validate: bool = True, repeat_stress=None, ): - super().__init__(locals()) - self.dataset["elevation"] = elevation - self.dataset["conductance"] = conductance - if concentration is not None: - self.dataset["concentration"] = concentration - self.dataset["concentration_boundary_type"] = concentration_boundary_type - expand_transient_auxiliary_variables(self) - self.dataset["print_input"] = print_input - self.dataset["print_flows"] = print_flows - self.dataset["save_flows"] = save_flows - self.dataset["observations"] = observations - 
self.dataset["repeat_stress"] = repeat_stress + dict_dataset = { + "elevation": elevation, + "conductance": conductance, + "concentration": concentration, + "concentration_boundary_type": concentration_boundary_type, + "print_input": print_input, + "print_flows": print_flows, + "save_flows": save_flows, + "observations": observations, + "repeat_stress": repeat_stress, + } + super().__init__(dict_dataset) + self._validate_init_schemata(validate) def _validate(self, schemata, **kwargs): diff --git a/imod/mf6/dsp.py b/imod/mf6/dsp.py index a9f640b2d..589e1d1ec 100644 --- a/imod/mf6/dsp.py +++ b/imod/mf6/dsp.py @@ -145,13 +145,15 @@ def __init__( xt3d_rhs=False, validate: bool = True, ): - super().__init__(locals()) - self.dataset["xt3d_off"] = xt3d_off - self.dataset["xt3d_rhs"] = xt3d_rhs - self.dataset["diffusion_coefficient"] = diffusion_coefficient - self.dataset["longitudinal_horizontal"] = longitudinal_horizontal - self.dataset["transversal_horizontal1"] = transversal_horizontal1 - self.dataset["longitudinal_vertical"] = longitudinal_vertical - self.dataset["transversal_horizontal2"] = transversal_horizontal2 - self.dataset["transversal_vertical"] = transversal_vertical + dict_dataset = { + "xt3d_off": xt3d_off, + "xt3d_rhs": xt3d_rhs, + "diffusion_coefficient": diffusion_coefficient, + "longitudinal_horizontal": longitudinal_horizontal, + "transversal_horizontal1": transversal_horizontal1, + "longitudinal_vertical": longitudinal_vertical, + "transversal_horizontal2": transversal_horizontal2, + "transversal_vertical": transversal_vertical, + } + super().__init__(dict_dataset) self._validate_init_schemata(validate) diff --git a/imod/mf6/evt.py b/imod/mf6/evt.py index dc92025e5..341dc08f3 100644 --- a/imod/mf6/evt.py +++ b/imod/mf6/evt.py @@ -2,7 +2,6 @@ import numpy as np -from imod.mf6.auxiliary_variables import expand_transient_auxiliary_variables from imod.mf6.boundary_condition import BoundaryCondition from imod.mf6.regridding_utils import RegridderType from 
imod.mf6.validation import BOUNDARY_DIMS_SCHEMA @@ -193,27 +192,27 @@ def __init__( validate: bool = True, repeat_stress=None, ): - super().__init__(locals()) - self.dataset["surface"] = surface - self.dataset["rate"] = rate - self.dataset["depth"] = depth if ("segment" in proportion_rate.dims) ^ ("segment" in proportion_depth.dims): raise ValueError( "Segment must be provided for both proportion_rate and" " proportion_depth, or for none at all." ) - self.dataset["proportion_rate"] = proportion_rate - self.dataset["proportion_depth"] = proportion_depth - if concentration is not None: - self.dataset["concentration"] = concentration - self.dataset["concentration_boundary_type"] = concentration_boundary_type - expand_transient_auxiliary_variables(self) - self.dataset["fixed_cell"] = fixed_cell - self.dataset["print_input"] = print_input - self.dataset["print_flows"] = print_flows - self.dataset["save_flows"] = save_flows - self.dataset["observations"] = observations - self.dataset["repeat_stress"] = repeat_stress + dict_dataset = { + "surface": surface, + "rate": rate, + "depth": depth, + "proportion_rate": proportion_rate, + "proportion_depth": proportion_depth, + "concentration": concentration, + "concentration_boundary_type": concentration_boundary_type, + "fixed_cell": fixed_cell, + "print_input": print_input, + "print_flows": print_flows, + "save_flows": save_flows, + "observations": observations, + "repeat_stress": repeat_stress, + } + super().__init__(dict_dataset) self._validate_init_schemata(validate) def _validate(self, schemata, **kwargs): diff --git a/imod/mf6/ghb.py b/imod/mf6/ghb.py index 7f6a0caff..05fa57a5a 100644 --- a/imod/mf6/ghb.py +++ b/imod/mf6/ghb.py @@ -1,6 +1,5 @@ import numpy as np -from imod.mf6.auxiliary_variables import expand_transient_auxiliary_variables from imod.mf6.boundary_condition import BoundaryCondition from imod.mf6.regridding_utils import RegridderType from imod.mf6.validation import BOUNDARY_DIMS_SCHEMA, CONC_DIMS_SCHEMA @@ 
-133,18 +132,18 @@ def __init__( validate: bool = True, repeat_stress=None, ): - super().__init__(locals()) - self.dataset["head"] = head - self.dataset["conductance"] = conductance - if concentration is not None: - self.dataset["concentration"] = concentration - self.dataset["concentration_boundary_type"] = concentration_boundary_type - expand_transient_auxiliary_variables(self) - self.dataset["print_input"] = print_input - self.dataset["print_flows"] = print_flows - self.dataset["save_flows"] = save_flows - self.dataset["observations"] = observations - self.dataset["repeat_stress"] = repeat_stress + dict_dataset = { + "head": head, + "conductance": conductance, + "concentration": concentration, + "concentration_boundary_type": concentration_boundary_type, + "print_input": print_input, + "print_flows": print_flows, + "save_flows": save_flows, + "observations": observations, + "repeat_stress": repeat_stress, + } + super().__init__(dict_dataset) self._validate_init_schemata(validate) def _validate(self, schemata, **kwargs): diff --git a/imod/mf6/gwfgwf.py b/imod/mf6/gwfgwf.py index a1d6df261..add2555c1 100644 --- a/imod/mf6/gwfgwf.py +++ b/imod/mf6/gwfgwf.py @@ -34,16 +34,18 @@ def __init__( angldegx: Optional[xr.DataArray] = None, cdist: Optional[xr.DataArray] = None, ): - super().__init__(locals()) - self.dataset["cell_id1"] = cell_id1 - self.dataset["cell_id2"] = cell_id2 - self.dataset["layer"] = layer - self.dataset["model_name_1"] = model_id1 - self.dataset["model_name_2"] = model_id2 - self.dataset["ihc"] = xr.DataArray(np.ones_like(cl1, dtype=int)) - self.dataset["cl1"] = cl1 - self.dataset["cl2"] = cl2 - self.dataset["hwva"] = hwva + dict_dataset = { + "cell_id1": cell_id1, + "cell_id2": cell_id2, + "layer": layer, + "model_name_1": model_id1, + "model_name_2": model_id2, + "ihc": xr.DataArray(np.ones_like(cl1, dtype=int)), + "cl1": cl1, + "cl2": cl2, + "hwva": hwva, + } + super().__init__(dict_dataset) auxiliary_variables = [var for var in [angldegx, 
cdist] if var is not None] if auxiliary_variables: diff --git a/imod/mf6/gwfgwt.py b/imod/mf6/gwfgwt.py index 35f8a0e75..62d6c0a48 100644 --- a/imod/mf6/gwfgwt.py +++ b/imod/mf6/gwfgwt.py @@ -14,9 +14,12 @@ class GWFGWT(ExchangeBase): _template = Package._initialize_template(_pkg_id) def __init__(self, model_id1: str, model_id2: str): - super().__init__(locals()) - self.dataset["model_name_1"] = model_id1 - self.dataset["model_name_2"] = model_id2 + dict_dataset = { + "model_name_1": model_id1, + "model_name_2": model_id2, + } + + super().__init__(dict_dataset) def clip_box( self, diff --git a/imod/mf6/hfb.py b/imod/mf6/hfb.py index 1f8cdbfa8..6595ee1fe 100644 --- a/imod/mf6/hfb.py +++ b/imod/mf6/hfb.py @@ -280,8 +280,8 @@ def __init__( geometry: gpd.GeoDataFrame, print_input: bool = False, ) -> None: - super().__init__(locals()) - self.dataset["print_input"] = print_input + dict_dataset = {"print_input": print_input} + super().__init__(dict_dataset) self.line_data = geometry diff --git a/imod/mf6/ic.py b/imod/mf6/ic.py index 5c0cf521f..88e996fdd 100644 --- a/imod/mf6/ic.py +++ b/imod/mf6/ic.py @@ -67,7 +67,6 @@ class InitialConditions(Package): } def __init__(self, start=None, head=None, validate: bool = True): - super().__init__(locals()) if start is None: start = head warnings.warn( @@ -80,7 +79,8 @@ def __init__(self, start=None, head=None, validate: bool = True): if head is not None: raise ValueError("start and head arguments cannot both be defined") - self.dataset["start"] = start + dict_dataset = {"start": start} + super().__init__(dict_dataset) self._validate_init_schemata(validate) def render(self, directory, pkgname, globaltimes, binary): diff --git a/imod/mf6/ims.py b/imod/mf6/ims.py index d5ed53b37..012d43e86 100644 --- a/imod/mf6/ims.py +++ b/imod/mf6/ims.py @@ -393,38 +393,39 @@ def __init__( no_ptc=False, validate: bool = True, ): - super().__init__() - self.dataset = xr.Dataset() + dict_dataset = { + "outer_dvclose": outer_dvclose, + 
"outer_maximum": outer_maximum, + "under_relaxation": under_relaxation, + "under_relaxation_theta": under_relaxation_theta, + "under_relaxation_kappa": under_relaxation_kappa, + "under_relaxation_gamma": under_relaxation_gamma, + "under_relaxation_momentum": under_relaxation_momentum, + "backtracking_number": backtracking_number, + "backtracking_tolerance": backtracking_tolerance, + "backtracking_reduction_factor": backtracking_reduction_factor, + "backtracking_residual_limit": backtracking_residual_limit, + "inner_maximum": inner_maximum, + "inner_dvclose": inner_dvclose, + "inner_rclose": inner_rclose, + "rclose_option": rclose_option, + "linear_acceleration": linear_acceleration, + "relaxation_factor": relaxation_factor, + "preconditioner_levels": preconditioner_levels, + "preconditioner_drop_tolerance": preconditioner_drop_tolerance, + "number_orthogonalizations": number_orthogonalizations, + "scaling_method": scaling_method, + "reordering_method": reordering_method, + "print_option": print_option, + "csv_output": csv_output, + "no_ptc": no_ptc, + } # Make sure the modelnames are set as a variable rather than dimension: if isinstance(modelnames, xr.DataArray): - self.dataset["modelnames"] = modelnames + dict_dataset["modelnames"] = modelnames else: - self.dataset["modelnames"] = ("model", modelnames) - self.dataset["outer_dvclose"] = outer_dvclose - self.dataset["outer_maximum"] = outer_maximum - self.dataset["under_relaxation"] = under_relaxation - self.dataset["under_relaxation_theta"] = under_relaxation_theta - self.dataset["under_relaxation_kappa"] = under_relaxation_kappa - self.dataset["under_relaxation_gamma"] = under_relaxation_gamma - self.dataset["under_relaxation_momentum"] = under_relaxation_momentum - self.dataset["backtracking_number"] = backtracking_number - self.dataset["backtracking_tolerance"] = backtracking_tolerance - self.dataset["backtracking_reduction_factor"] = backtracking_reduction_factor - self.dataset["backtracking_residual_limit"] = 
backtracking_residual_limit - self.dataset["inner_maximum"] = inner_maximum - self.dataset["inner_dvclose"] = inner_dvclose - self.dataset["inner_rclose"] = inner_rclose - self.dataset["rclose_option"] = rclose_option - self.dataset["linear_acceleration"] = linear_acceleration - self.dataset["relaxation_factor"] = relaxation_factor - self.dataset["preconditioner_levels"] = preconditioner_levels - self.dataset["preconditioner_drop_tolerance"] = preconditioner_drop_tolerance - self.dataset["number_orthogonalizations"] = number_orthogonalizations - self.dataset["scaling_method"] = scaling_method - self.dataset["reordering_method"] = reordering_method - self.dataset["print_option"] = print_option - self.dataset["csv_output"] = csv_output - self.dataset["no_ptc"] = no_ptc + dict_dataset["modelnames"] = ("model", modelnames) + super().__init__(dict_dataset) self._validate_init_schemata(validate) diff --git a/imod/mf6/ist.py b/imod/mf6/ist.py index 0d153af80..0344262a1 100644 --- a/imod/mf6/ist.py +++ b/imod/mf6/ist.py @@ -227,25 +227,25 @@ def __init__( "provided.", ) - super().__init__(locals()) - self.dataset["initial_immobile_concentration"] = initial_immobile_concentration - self.dataset[ - "mobile_immobile_mass_transfer_rate" - ] = mobile_immobile_mass_transfer_rate - self.dataset["immobile_porosity"] = immobile_porosity - self.dataset["decay"] = decay - self.dataset["decay_sorbed"] = decay_sorbed - self.dataset["bulk_density"] = bulk_density - self.dataset["distribution_coefficient"] = distribution_coefficient - self.dataset["save_flows"] = save_flows - self.dataset["budgetfile"] = budgetbinfile - self.dataset["budgetcsvfile"] = budgetcsvfile - self.dataset["sorption"] = sorption - self.dataset["first_order_decay"] = first_order_decay - self.dataset["zero_order_decay"] = zero_order_decay - self.dataset["cimfile"] = cimfile - self.dataset["columns "] = columns - self.dataset["width"] = width - self.dataset["digits"] = digits - self.dataset["format"] = format + 
dict_dataset = { + "initial_immobile_concentration": initial_immobile_concentration, + "mobile_immobile_mass_transfer_rate": mobile_immobile_mass_transfer_rate, + "immobile_porosity": immobile_porosity, + "decay": decay, + "decay_sorbed": decay_sorbed, + "bulk_density": bulk_density, + "distribution_coefficient": distribution_coefficient, + "save_flows": save_flows, + "budgetfile": budgetbinfile, + "budgetcsvfile": budgetcsvfile, + "sorption": sorption, + "first_order_decay": first_order_decay, + "zero_order_decay": zero_order_decay, + "cimfile": cimfile, + "columns ": columns, + "width": width, + "digits": digits, + "format": format, + } + super().__init__(dict_dataset) self._validate_init_schemata(validate) diff --git a/imod/mf6/lak.py b/imod/mf6/lak.py index dc3eb11ae..32df0c0c0 100644 --- a/imod/mf6/lak.py +++ b/imod/mf6/lak.py @@ -31,8 +31,8 @@ class LakeApi_Base(PackageBase): Base class for lake and outlet object. """ - def __init__(self): - super().__init__() + def __init__(self, allargs): + super().__init__(allargs) class LakeData(LakeApi_Base): @@ -98,25 +98,20 @@ def __init__( auxiliary=None, lake_table=None, ): - super().__init__() - self.dataset["starting_stage"] = starting_stage - self.dataset["boundname"] = boundname - self.dataset["connection_type"] = (connection_type.dims, connection_type.values) - self.dataset["bed_leak"] = (bed_leak.dims, bed_leak.values) - self.dataset["top_elevation"] = (top_elevation.dims, top_elevation.values) - self.dataset["bottom_elevation"] = (bot_elevation.dims, bot_elevation.values) - self.dataset["connection_length"] = ( - connection_length.dims, - connection_length.values, - ) - self.dataset["connection_width"] = ( - connection_width.dims, - connection_width.values, - ) - self.dataset["lake_table"] = lake_table + dict_dataset = { + "starting_stage": starting_stage, + "boundname": boundname, + "connection_type": (connection_type.dims, connection_type.values), + "bed_leak": (bed_leak.dims, bed_leak.values), + 
"top_elevation": (top_elevation.dims, top_elevation.values), + "bottom_elevation": (bot_elevation.dims, bot_elevation.values), + "connection_length": (connection_length.dims, connection_length.values), + "connection_width": (connection_width.dims, connection_width.values), + "lake_table": lake_table, + } + super().__init__(dict_dataset) # timeseries data - times = [] timeseries_dict = { "status": status, @@ -153,15 +148,8 @@ class OutletBase(LakeApi_Base): timeseries_names = ["rate", "invert", "rough", "width", "slope"] - def __init__(self, lakein: str, lakeout: str): - super().__init__() - self.dataset = xr.Dataset() - self.dataset["lakein"] = lakein - self.dataset["lakeout"] = lakeout - self.dataset["invert"] = None - self.dataset["width"] = None - self.dataset["roughness"] = None - self.dataset["slope"] = None + def __init__(self, allargs): + super().__init__(allargs) class OutletManning(OutletBase): @@ -181,11 +169,15 @@ def __init__( roughness, slope, ): - super().__init__(lakein, lakeout) - self.dataset["invert"] = invert - self.dataset["width"] = width - self.dataset["roughness"] = roughness - self.dataset["slope"] = slope + dict_dataset = { + "lakein": lakein, + "lakeout": lakeout, + "invert": invert, + "width": width, + "roughness": roughness, + "slope": slope, + } + super().__init__(dict_dataset) class OutletWeir(OutletBase): @@ -196,9 +188,15 @@ class OutletWeir(OutletBase): _couttype = "weir" def __init__(self, lakein: str, lakeout: str, invert, width): - super().__init__(lakein, lakeout) - self.dataset["invert"] = invert - self.dataset["width"] = width + dict_dataset = { + "lakein": lakein, + "lakeout": lakeout, + "invert": invert, + "width": width, + "roughness": None, + "slope": None, + } + super().__init__(dict_dataset) class OutletSpecified(OutletBase): @@ -209,8 +207,16 @@ class OutletSpecified(OutletBase): _couttype = "specified" def __init__(self, lakein: str, lakeout: str, rate): - super().__init__(lakein, lakeout) - self.dataset["rate"] = rate 
+ dict_dataset = { + "lakein": lakein, + "lakeout": lakeout, + "invert": None, + "width": None, + "roughness": None, + "slope": None, + "rate": rate, + } + super().__init__(dict_dataset) def create_connection_data(lakes): @@ -323,9 +329,9 @@ def create_outlet_data(outlets, name_to_number): def concatenate_timeseries(list_of_lakes_or_outlets, timeseries_name): """ - In this function we create a dataarray with a given time coorridnate axis. We add all - the timeseries of lakes or outlets with the given name. We also create a dimension to - specify the lake or outlet number. + In this function we create a dataarray with a given time coordinate axis. We + add all the timeseries of lakes or outlets with the given name. We also + create a dimension to specify the lake or outlet number. """ if list_of_lakes_or_outlets is None: return None @@ -726,68 +732,56 @@ def __init__( length_conversion=None, validate=True, ): - super().__init__(locals()) - self.dataset["lake_boundname"] = lake_boundname - self.dataset["lake_number"] = lake_number - self.dataset["lake_starting_stage"] = lake_starting_stage - - nr_indices = int(self.dataset["lake_number"].data.max()) - if outlet_lakein is not None: - nroutlets = len(outlet_lakein.data) - nr_indices = max(nr_indices, nroutlets) - - self.dataset = self.dataset.assign_coords(index=range(1, nr_indices + 1, 1)) - - self.dataset["connection_lake_number"] = connection_lake_number - self.dataset["connection_cell_id"] = connection_cell_id - self.dataset["connection_type"] = connection_type - self.dataset["connection_bed_leak"] = connection_bed_leak - self.dataset["connection_bottom_elevation"] = connection_bottom_elevation - self.dataset["connection_top_elevation"] = connection_top_elevation - self.dataset["connection_width"] = connection_width - self.dataset["connection_length"] = connection_length - - self.dataset["outlet_lakein"] = outlet_lakein - self.dataset["outlet_lakeout"] = outlet_lakeout - self.dataset["outlet_couttype"] = 
outlet_couttype - self.dataset["outlet_invert"] = outlet_invert - self.dataset["outlet_roughness"] = outlet_roughness - self.dataset["outlet_width"] = outlet_width - self.dataset["outlet_slope"] = outlet_slope - - self.dataset["print_input"] = print_input - self.dataset["print_stage"] = print_stage - self.dataset["print_flows"] = print_flows - self.dataset["save_flows"] = save_flows - - self.dataset["stagefile"] = stagefile - self.dataset["budgetfile"] = budgetfile - self.dataset["budgetcsvfile"] = budgetcsvfile - self.dataset["package_convergence_filename"] = package_convergence_filename - self.dataset["ts6_filename"] = ts6_filename - self.dataset["time_conversion"] = time_conversion - self.dataset["length_conversion"] = length_conversion - - self.dataset["ts_status"] = ts_status if ts_status is not None: - self.dataset["ts_status"] = self._convert_to_string_dataarray( - self.dataset["ts_status"] - ) - self.dataset["ts_stage"] = ts_stage - self.dataset["ts_rainfall"] = ts_rainfall - self.dataset["ts_evaporation"] = ts_evaporation - self.dataset["ts_runoff"] = ts_runoff - self.dataset["ts_inflow"] = ts_inflow - self.dataset["ts_withdrawal"] = ts_withdrawal - self.dataset["ts_auxiliary"] = ts_auxiliary - - self.dataset["ts_rate"] = ts_rate - self.dataset["ts_invert"] = ts_invert - self.dataset["ts_rough"] = ts_rough - self.dataset["ts_width"] = ts_width - self.dataset["ts_slope"] = ts_slope - - self.dataset["lake_tables"] = lake_tables + ts_status = self._convert_to_string_dataarray(xr.DataArray(ts_status)) + + dict_dataset = { + "lake_boundname": lake_boundname, + "lake_number": lake_number, + "lake_starting_stage": lake_starting_stage, + "connection_lake_number": connection_lake_number, + "connection_cell_id": connection_cell_id, + "connection_type": connection_type, + "connection_bed_leak": connection_bed_leak, + "connection_bottom_elevation": connection_bottom_elevation, + "connection_top_elevation": connection_top_elevation, + "connection_width": 
connection_width, + "connection_length": connection_length, + "outlet_lakein": outlet_lakein, + "outlet_lakeout": outlet_lakeout, + "outlet_couttype": outlet_couttype, + "outlet_invert": outlet_invert, + "outlet_roughness": outlet_roughness, + "outlet_width": outlet_width, + "outlet_slope": outlet_slope, + "print_input": print_input, + "print_stage": print_stage, + "print_flows": print_flows, + "save_flows": save_flows, + "stagefile": stagefile, + "budgetfile": budgetfile, + "budgetcsvfile": budgetcsvfile, + "package_convergence_filename": package_convergence_filename, + "ts6_filename": ts6_filename, + "time_conversion": time_conversion, + "length_conversion": length_conversion, + "ts_status": ts_status, + "ts_stage": ts_stage, + "ts_rainfall": ts_rainfall, + "ts_evaporation": ts_evaporation, + "ts_runoff": ts_runoff, + "ts_inflow": ts_inflow, + "ts_withdrawal": ts_withdrawal, + "ts_auxiliary": ts_auxiliary, + "ts_rate": ts_rate, + "ts_invert": ts_invert, + "ts_rough": ts_rough, + "ts_width": ts_width, + "ts_slope": ts_slope, + "lake_tables": lake_tables, + } + + super().__init__(dict_dataset) self._validate_init_schemata(validate) @@ -840,6 +834,18 @@ def from_lakes_and_outlets( shortname = ts_name[3:] package_content[ts_name] = concatenate_timeseries(outlets, shortname) + # Align timeseries variables + ts_vars = Lake._period_data_lakes + Lake._period_data_outlets + assigned_ts_vars = [ + ts_var for ts_var in ts_vars if package_content[ts_var] is not None + ] + aligned_timeseries = xr.align( + *(package_content[ts_var] for ts_var in assigned_ts_vars), join="outer" + ) + package_content.update( + {ts_var: ts for ts_var, ts in zip(assigned_ts_vars, aligned_timeseries)} + ) + if outlets is not None: outlet_data = create_outlet_data(outlets, name_to_number) package_content.update(outlet_data) diff --git a/imod/mf6/mf6_hfb_adapter.py b/imod/mf6/mf6_hfb_adapter.py index a01f4e7b3..dad7b2db9 100644 --- a/imod/mf6/mf6_hfb_adapter.py +++ b/imod/mf6/mf6_hfb_adapter.py @@ 
-115,12 +115,14 @@ def __init__( print_input: Union[bool, xr.DataArray] = False, validate: Union[bool, xr.DataArray] = True, ): - super().__init__(locals()) - self.dataset["cell_id1"] = cell_id1 - self.dataset["cell_id2"] = cell_id2 - self.dataset["layer"] = layer - self.dataset["hydraulic_characteristic"] = hydraulic_characteristic - self.dataset["print_input"] = print_input + dict_dataset = { + "cell_id1": cell_id1, + "cell_id2": cell_id2, + "layer": layer, + "hydraulic_characteristic": hydraulic_characteristic, + "print_input": print_input, + } + super().__init__(dict_dataset) def _get_bin_ds(self): bin_ds = self.dataset[ diff --git a/imod/mf6/mf6_wel_adapter.py b/imod/mf6/mf6_wel_adapter.py index 89dc5fd54..d08a68519 100644 --- a/imod/mf6/mf6_wel_adapter.py +++ b/imod/mf6/mf6_wel_adapter.py @@ -14,7 +14,6 @@ import numpy as np -from imod.mf6.auxiliary_variables import expand_transient_auxiliary_variables from imod.mf6.boundary_condition import BoundaryCondition from imod.schemata import DTypeSchema @@ -55,14 +54,13 @@ def __init__( concentration_boundary_type="aux", validate: bool = True, ): - super().__init__() - self.dataset["cellid"] = cellid - self.dataset["rate"] = rate - - if concentration is not None: - self.dataset["concentration"] = concentration - self.dataset["concentration_boundary_type"] = concentration_boundary_type - expand_transient_auxiliary_variables(self) + dict_dataset = { + "cellid": cellid, + "rate": rate, + "concentration": concentration, + "concentration_boundary_type": concentration_boundary_type, + } + super().__init__(dict_dataset) self._validate_init_schemata(validate) def _ds_to_arrdict(self, ds): diff --git a/imod/mf6/mst.py b/imod/mf6/mst.py index 5d6eb5ac5..f46a233d6 100644 --- a/imod/mf6/mst.py +++ b/imod/mf6/mst.py @@ -113,15 +113,17 @@ def __init__( raise ValueError( "zero_order_decay and first_order_decay may not both be True" ) - super().__init__(locals()) - self.dataset["porosity"] = porosity - self.dataset["decay"] = decay 
- self.dataset["decay_sorbed"] = decay_sorbed - self.dataset["bulk_density"] = bulk_density - self.dataset["distcoef"] = distcoef - self.dataset["sp2"] = sp2 - self.dataset["save_flows"] = save_flows - self.dataset["sorption"] = sorption - self.dataset["zero_order_decay"] = zero_order_decay - self.dataset["first_order_decay"] = first_order_decay + dict_dataset = { + "porosity": porosity, + "decay": decay, + "decay_sorbed": decay_sorbed, + "bulk_density": bulk_density, + "distcoef": distcoef, + "sp2": sp2, + "save_flows": save_flows, + "sorption": sorption, + "zero_order_decay": zero_order_decay, + "first_order_decay": first_order_decay, + } + super().__init__(dict_dataset) self._validate_init_schemata(validate) diff --git a/imod/mf6/npf.py b/imod/mf6/npf.py index 4952d0818..86cbb5fba 100644 --- a/imod/mf6/npf.py +++ b/imod/mf6/npf.py @@ -354,7 +354,6 @@ def __init__( rhs_option=False, validate: bool = True, ): - super().__init__(locals()) # check rewetting if not rewet and any( [rewet_layer, rewet_factor, rewet_iterations, rewet_method] @@ -363,38 +362,38 @@ def __init__( "rewet_layer, rewet_factor, rewet_iterations, and rewet_method should" " all be left at a default value of None if rewet is False." 
) - self.dataset["icelltype"] = icelltype - self.dataset["k"] = k - self.dataset["rewet"] = rewet - self.dataset["rewet_layer"] = rewet_layer - self.dataset["rewet_factor"] = rewet_factor - self.dataset["rewet_iterations"] = rewet_iterations - self.dataset["rewet_method"] = rewet_method - self.dataset["k22"] = k22 - self.dataset["k33"] = k33 - self.dataset["angle1"] = angle1 - self.dataset["angle2"] = angle2 - self.dataset["angle3"] = angle3 if cell_averaging is not None: warnings.warn( "Use of `cell_averaging` is deprecated, please use `alternative_cell_averaging` instead", DeprecationWarning, ) - self.dataset["alternative_cell_averaging"] = cell_averaging - else: - self.dataset["alternative_cell_averaging"] = alternative_cell_averaging + alternative_cell_averaging = cell_averaging - self.dataset["save_flows"] = save_flows - self.dataset[ - "starting_head_as_confined_thickness" - ] = starting_head_as_confined_thickness - self.dataset["variable_vertical_conductance"] = variable_vertical_conductance - self.dataset["dewatered"] = dewatered - self.dataset["perched"] = perched - self.dataset["save_specific_discharge"] = save_specific_discharge - self.dataset["save_saturation"] = save_saturation - self.dataset["xt3d_option"] = xt3d_option - self.dataset["rhs_option"] = rhs_option + dict_dataset = { + "icelltype": icelltype, + "k": k, + "rewet": rewet, + "rewet_layer": rewet_layer, + "rewet_factor": rewet_factor, + "rewet_iterations": rewet_iterations, + "rewet_method": rewet_method, + "k22": k22, + "k33": k33, + "angle1": angle1, + "angle2": angle2, + "angle3": angle3, + "alternative_cell_averaging": alternative_cell_averaging, + "save_flows": save_flows, + "starting_head_as_confined_thickness": starting_head_as_confined_thickness, + "variable_vertical_conductance": variable_vertical_conductance, + "dewatered": dewatered, + "perched": perched, + "save_specific_discharge": save_specific_discharge, + "save_saturation": save_saturation, + "xt3d_option": xt3d_option, + 
"rhs_option": rhs_option, + } + super().__init__(dict_dataset) self._validate_init_schemata(validate) def get_xt3d_option(self) -> bool: diff --git a/imod/mf6/oc.py b/imod/mf6/oc.py index 6047cd64c..8acc70253 100644 --- a/imod/mf6/oc.py +++ b/imod/mf6/oc.py @@ -91,8 +91,6 @@ def __init__( concentration_file=None, validate: bool = True, ): - super().__init__() - save_concentration = ( None if is_dataarray_none(save_concentration) else save_concentration ) @@ -102,12 +100,15 @@ def __init__( if save_head is not None and save_concentration is not None: raise ValueError("save_head and save_concentration cannot both be defined.") - self.dataset["save_head"] = save_head - self.dataset["save_concentration"] = save_concentration - self.dataset["save_budget"] = save_budget - self.dataset["head_file"] = head_file - self.dataset["budget_file"] = budget_file - self.dataset["concentration_file"] = concentration_file + dict_dataset = { + "save_head": save_head, + "save_concentration": save_concentration, + "save_budget": save_budget, + "head_file": head_file, + "budget_file": budget_file, + "concentration_file": concentration_file, + } + super().__init__(dict_dataset) self._validate_init_schemata(validate) def _get_ocsetting(self, setting): diff --git a/imod/mf6/package.py b/imod/mf6/package.py index ed0ddcd6d..a47d0bb41 100644 --- a/imod/mf6/package.py +++ b/imod/mf6/package.py @@ -21,6 +21,7 @@ from imod.mf6.regridding_utils import ( RegridderInstancesCollection, RegridderType, + assign_coord_if_present, get_non_grid_data, ) from imod.mf6.utilities.schemata import filter_schemata_dict @@ -53,7 +54,7 @@ class Package(PackageBase, IPackage, abc.ABC): _write_schemata: Dict[str, List[SchemaType] | Tuple[SchemaType]] = {} _keyword_map: Dict[str, str] = {} - def __init__(self, allargs=None): + def __init__(self, allargs: dict[str, GridDataArray | float | int | bool | str]): super().__init__(allargs) def isel(self): @@ -700,19 +701,15 @@ def regrid_like( regridder_function, 
target_grid, ) - - new_package = self.__class__(**new_package_data) - - # set dx and dy if present in target_grid - if "dx" in target_grid.coords: - new_package.dataset = new_package.dataset.assign_coords( - {"dx": target_grid.coords["dx"].values[()]} + # set dx and dy if present in target_grid + new_package_data[varname] = assign_coord_if_present( + "dx", target_grid, new_package_data[varname] ) - if "dy" in target_grid.coords: - new_package.dataset = new_package.dataset.assign_coords( - {"dy": target_grid.coords["dy"].values[()]} + new_package_data[varname] = assign_coord_if_present( + "dy", target_grid, new_package_data[varname] ) - return new_package + + return self.__class__(**new_package_data) def skip_masking_dataarray(self, array_name: str) -> bool: if hasattr(self, "_skip_mask_arrays"): diff --git a/imod/mf6/pkgbase.py b/imod/mf6/pkgbase.py index 7c9c69a36..2c3a89a42 100644 --- a/imod/mf6/pkgbase.py +++ b/imod/mf6/pkgbase.py @@ -8,6 +8,7 @@ import imod from imod.mf6.interfaces.ipackagebase import IPackageBase +from imod.typing.grid import GridDataArray, GridDataset, merge_with_dictionary TRANSPORT_PACKAGES = ("adv", "dsp", "ssm", "mst", "ist", "src") EXCHANGE_PACKAGES = ("gwfgwf", "gwfgwt") @@ -25,20 +26,19 @@ class PackageBase(IPackageBase, abc.ABC): def __new__(cls, *_, **__): return super(PackageBase, cls).__new__(cls) - def __init__(self, allargs=None): - if allargs is not None: - for arg in allargs.values(): - if isinstance(arg, xu.UgridDataArray): - self.__dataset = xu.UgridDataset(grids=arg.ugrid.grid) - return - self.__dataset = xr.Dataset() + def __init__( + self, variables_to_merge: dict[str, GridDataArray | float | int | bool | str] + ): + # Merge variables, perform exact join to verify if coordinates values + # are consistent amongst variables. 
+ self.__dataset = merge_with_dictionary(variables_to_merge, join="exact") @property - def dataset(self) -> xr.Dataset: + def dataset(self) -> GridDataset: return self.__dataset @dataset.setter - def dataset(self, value: xr.Dataset) -> None: + def dataset(self, value: GridDataset) -> None: self.__dataset = value def __getitem__(self, key): diff --git a/imod/mf6/rch.py b/imod/mf6/rch.py index 2fcca90c0..5ac98fa00 100644 --- a/imod/mf6/rch.py +++ b/imod/mf6/rch.py @@ -1,6 +1,5 @@ import numpy as np -from imod.mf6.auxiliary_variables import expand_transient_auxiliary_variables from imod.mf6.boundary_condition import BoundaryCondition from imod.mf6.regridding_utils import RegridderType from imod.mf6.validation import BOUNDARY_DIMS_SCHEMA, CONC_DIMS_SCHEMA @@ -120,17 +119,17 @@ def __init__( validate: bool = True, repeat_stress=None, ): - super().__init__(locals()) - self.dataset["rate"] = rate - if concentration is not None: - self.dataset["concentration"] = concentration - self.dataset["concentration_boundary_type"] = concentration_boundary_type - expand_transient_auxiliary_variables(self) - self.dataset["print_input"] = print_input - self.dataset["print_flows"] = print_flows - self.dataset["save_flows"] = save_flows - self.dataset["observations"] = observations - self.dataset["repeat_stress"] = repeat_stress + dict_dataset = { + "rate": rate, + "concentration": concentration, + "concentration_boundary_type": concentration_boundary_type, + "print_input": print_input, + "print_flows": print_flows, + "save_flows": save_flows, + "observations": observations, + "repeat_stress": repeat_stress, + } + super().__init__(dict_dataset) self._validate_init_schemata(validate) def _validate(self, schemata, **kwargs): diff --git a/imod/mf6/regridding_utils.py b/imod/mf6/regridding_utils.py index 9cb6aa0b2..83a9fc98b 100644 --- a/imod/mf6/regridding_utils.py +++ b/imod/mf6/regridding_utils.py @@ -1,11 +1,13 @@ import abc from enum import Enum -from typing import Dict, List, Optional, 
 Union +from typing import Any, Dict, List, Optional, Union import xarray as xr import xugrid as xu from xugrid.regrid.regridder import BaseRegridder +from imod.typing.grid import GridDataArray + class RegridderType(Enum): """ @@ -127,3 +129,18 @@ def get_non_grid_data(package, grid_names: List[str]) -> Dict[str, any]: else: result[name] = package.dataset[name].values[()] return result + + +def assign_coord_if_present( + coordname: str, target_grid: GridDataArray, maybe_has_coords_attr: Any +): + """ + If ``maybe_has_coords_attr`` has a ``coords`` attribute and if coordname in + target_grid, copy coord. + """ + if coordname in target_grid.coords: + if coordname in target_grid.coords and hasattr(maybe_has_coords_attr, "coords"): + maybe_has_coords_attr = maybe_has_coords_attr.assign_coords( + {coordname: target_grid.coords[coordname].values[()]} + ) + return maybe_has_coords_attr diff --git a/imod/mf6/riv.py b/imod/mf6/riv.py index 50ca419e2..d8f3bf904 100644 --- a/imod/mf6/riv.py +++ b/imod/mf6/riv.py @@ -1,6 +1,5 @@ import numpy as np -from imod.mf6.auxiliary_variables import expand_transient_auxiliary_variables from imod.mf6.boundary_condition import BoundaryCondition from imod.mf6.regridding_utils import RegridderType from imod.mf6.validation import BOUNDARY_DIMS_SCHEMA, CONC_DIMS_SCHEMA @@ -144,19 +143,20 @@ def __init__( validate: bool = True, repeat_stress=None, ): - super().__init__(locals()) - self.dataset["stage"] = stage - self.dataset["conductance"] = conductance - self.dataset["bottom_elevation"] = bottom_elevation - if concentration is not None: - self.dataset["concentration"] = concentration - self.dataset["concentration_boundary_type"] = concentration_boundary_type - expand_transient_auxiliary_variables(self) - self.dataset["print_input"] = print_input - self.dataset["print_flows"] = print_flows - self.dataset["save_flows"] = save_flows - self.dataset["observations"] = observations - self.dataset["repeat_stress"] = repeat_stress + dict_dataset = { + 
"stage": stage, + "conductance": conductance, + "bottom_elevation": bottom_elevation, + "concentration": concentration, + "concentration_boundary_type": concentration_boundary_type, + "print_input": print_input, + "print_flows": print_flows, + "save_flows": save_flows, + "observations": observations, + "repeat_stress": repeat_stress, + } + super().__init__(dict_dataset) + self._validate_init_schemata(validate) def _validate(self, schemata, **kwargs): diff --git a/imod/mf6/src.py b/imod/mf6/src.py index d395147b5..195de09ab 100644 --- a/imod/mf6/src.py +++ b/imod/mf6/src.py @@ -76,10 +76,12 @@ def __init__( observations=None, validate: bool = True, ): - super().__init__() - self.dataset["rate"] = rate - self.dataset["print_input"] = print_input - self.dataset["print_flows"] = print_flows - self.dataset["save_flows"] = save_flows - self.dataset["observations"] = observations + dict_dataset = { + "rate": rate, + "print_input": print_input, + "print_flows": print_flows, + "save_flows": save_flows, + "observations": observations, + } + super().__init__(dict_dataset) self._validate_init_schemata(validate) diff --git a/imod/mf6/ssm.py b/imod/mf6/ssm.py index 57acf51b1..f717a6bb0 100644 --- a/imod/mf6/ssm.py +++ b/imod/mf6/ssm.py @@ -47,18 +47,16 @@ def __init__( save_flows: bool = False, validate: bool = True, ): - super().__init__() - # By sharing the index, this will raise an error if lengths do not - # match. - self.dataset["package_names"] = with_index_dim(package_names) - self.dataset["concentration_boundary_type"] = with_index_dim( - concentration_boundary_type - ) - self.dataset["auxiliary_variable_name"] = with_index_dim( - auxiliary_variable_name - ) - self.dataset["print_flows"] = print_flows - self.dataset["save_flows"] = save_flows + dict_dataset = { + # By sharing the index, this will raise an error if lengths do not + # match. 
+ "package_names": with_index_dim(package_names), + "concentration_boundary_type": with_index_dim(concentration_boundary_type), + "auxiliary_variable_name": with_index_dim(auxiliary_variable_name), + "print_flows": print_flows, + "save_flows": save_flows, + } + super().__init__(dict_dataset) self._validate_init_schemata(validate) def render(self, directory, pkgname, globaltimes, binary): diff --git a/imod/mf6/sto.py b/imod/mf6/sto.py index 875f3b0cb..85179bfe6 100644 --- a/imod/mf6/sto.py +++ b/imod/mf6/sto.py @@ -172,12 +172,14 @@ def __init__( save_flows: bool = False, validate: bool = True, ): - super().__init__(locals()) - self.dataset["specific_storage"] = specific_storage - self.dataset["specific_yield"] = specific_yield - self.dataset["convertible"] = convertible - self.dataset["transient"] = transient - self.dataset["save_flows"] = save_flows + dict_dataset = { + "specific_storage": specific_storage, + "specific_yield": specific_yield, + "convertible": convertible, + "transient": transient, + "save_flows": save_flows, + } + super().__init__(dict_dataset) self._validate_init_schemata(validate) def render(self, directory, pkgname, globaltimes, binary): @@ -304,12 +306,14 @@ def __init__( save_flows: bool = False, validate: bool = True, ): - super().__init__(locals()) - self.dataset["storage_coefficient"] = storage_coefficient - self.dataset["specific_yield"] = specific_yield - self.dataset["convertible"] = convertible - self.dataset["transient"] = transient - self.dataset["save_flows"] = save_flows + dict_dataset = { + "storage_coefficient": storage_coefficient, + "specific_yield": specific_yield, + "convertible": convertible, + "transient": transient, + "save_flows": save_flows, + } + super().__init__(dict_dataset) self._validate_init_schemata(validate) def render(self, directory, pkgname, globaltimes, binary): diff --git a/imod/mf6/timedis.py b/imod/mf6/timedis.py index 914280695..582081f47 100644 --- a/imod/mf6/timedis.py +++ b/imod/mf6/timedis.py @@ 
-61,10 +61,12 @@ def __init__( timestep_multiplier=1.0, validate: bool = True, ): - super().__init__() - self.dataset["timestep_duration"] = timestep_duration - self.dataset["n_timesteps"] = n_timesteps - self.dataset["timestep_multiplier"] = timestep_multiplier + dict_dataset = { + "timestep_duration": timestep_duration, + "n_timesteps": n_timesteps, + "timestep_multiplier": timestep_multiplier, + } + super().__init__(dict_dataset) self._validate_init_schemata(validate) def render(self): diff --git a/imod/mf6/uzf.py b/imod/mf6/uzf.py index d9e472012..47f032181 100644 --- a/imod/mf6/uzf.py +++ b/imod/mf6/uzf.py @@ -229,16 +229,45 @@ def __init__( timeseries=None, validate: bool = True, ): - super().__init__(locals()) - # Package data - self.dataset["surface_depression_depth"] = surface_depression_depth - self.dataset["kv_sat"] = kv_sat - self.dataset["theta_res"] = theta_res - self.dataset["theta_sat"] = theta_sat - self.dataset["theta_init"] = theta_init - self.dataset["epsilon"] = epsilon - - # Stress period data + landflag = self._determine_landflag(kv_sat) + iuzno = self._create_uzf_numbers(landflag) + ivertcon = self._determine_vertical_connection(iuzno) + + dict_dataset = { + # Package data + "surface_depression_depth": surface_depression_depth, + "kv_sat": kv_sat, + "theta_res": theta_res, + "theta_sat": theta_sat, + "theta_init": theta_init, + "epsilon": epsilon, + # Stress period data + "infiltration_rate": infiltration_rate, + "et_pot": et_pot, + "extinction_depth": extinction_depth, + "extinction_theta": extinction_theta, + "air_entry_potential": air_entry_potential, + "root_potential": root_potential, + "root_activity": root_activity, + # Dimensions + "ntrailwaves": ntrailwaves, + "nwavesets": nwavesets, + # Options + "groundwater_ET_function": groundwater_ET_function, + "simulate_gwseep": simulate_groundwater_seepage, + "print_input": print_input, + "print_flows": print_flows, + "save_flows": save_flows, + "observations": observations, + "water_mover": 
water_mover, + "timeseries": timeseries, + # Additonal indices for Packagedata + "landflag": landflag, + "iuzno": iuzno, + "ivertcon": ivertcon, + } + super().__init__(dict_dataset) + self.dataset["iuzno"].name = "uzf_number" self._check_options( groundwater_ET_function, et_pot, @@ -248,35 +277,6 @@ def __init__( root_potential, root_activity, ) - - self.dataset["infiltration_rate"] = infiltration_rate - self.dataset["et_pot"] = et_pot - self.dataset["extinction_depth"] = extinction_depth - self.dataset["extinction_theta"] = extinction_theta - self.dataset["air_entry_potential"] = air_entry_potential - self.dataset["root_potential"] = root_potential - self.dataset["root_activity"] = root_activity - - # Dimensions - self.dataset["ntrailwaves"] = ntrailwaves - self.dataset["nwavesets"] = nwavesets - - # Options - self.dataset["groundwater_ET_function"] = groundwater_ET_function - self.dataset["simulate_gwseep"] = simulate_groundwater_seepage - self.dataset["print_input"] = print_input - self.dataset["print_flows"] = print_flows - self.dataset["save_flows"] = save_flows - self.dataset["observations"] = observations - self.dataset["water_mover"] = water_mover - self.dataset["timeseries"] = timeseries - - # Additonal indices for Packagedata - self.dataset["landflag"] = self._determine_landflag(kv_sat) - self.dataset["iuzno"] = self._create_uzf_numbers(self["landflag"]) - self.dataset["iuzno"].name = "uzf_number" - self.dataset["ivertcon"] = self._determine_vertical_connection(self["iuzno"]) - self._validate_init_schemata(validate) def fill_stress_perioddata(self): diff --git a/imod/mf6/wel.py b/imod/mf6/wel.py index 5d640b435..7897cdf4b 100644 --- a/imod/mf6/wel.py +++ b/imod/mf6/wel.py @@ -11,7 +11,6 @@ import xugrid as xu import imod -from imod.mf6.auxiliary_variables import expand_transient_auxiliary_variables from imod.mf6.boundary_condition import ( BoundaryCondition, DisStructuredBoundaryCondition, @@ -169,27 +168,27 @@ def __init__( validate: bool = True, 
repeat_stress: Optional[xr.DataArray] = None, ): - super().__init__() - self.dataset["screen_top"] = _assign_dims(screen_top) - self.dataset["screen_bottom"] = _assign_dims(screen_bottom) - self.dataset["y"] = _assign_dims(y) - self.dataset["x"] = _assign_dims(x) - self.dataset["rate"] = _assign_dims(rate) if id is None: - id = np.arange(self.dataset["x"].size).astype(str) - self.dataset["id"] = _assign_dims(id) - self.dataset["minimum_k"] = minimum_k - self.dataset["minimum_thickness"] = minimum_thickness - - self.dataset["print_input"] = print_input - self.dataset["print_flows"] = print_flows - self.dataset["save_flows"] = save_flows - self.dataset["observations"] = observations - self.dataset["repeat_stress"] = repeat_stress - if concentration is not None: - self.dataset["concentration"] = concentration - self.dataset["concentration_boundary_type"] = concentration_boundary_type - + id = np.arange(len(x)).astype(str) + + dict_dataset = { + "screen_top": _assign_dims(screen_top), + "screen_bottom": _assign_dims(screen_bottom), + "y": _assign_dims(y), + "x": _assign_dims(x), + "rate": _assign_dims(rate), + "id": _assign_dims(id), + "minimum_k": minimum_k, + "minimum_thickness": minimum_thickness, + "print_input": print_input, + "print_flows": print_flows, + "save_flows": save_flows, + "observations": observations, + "repeat_stress": repeat_stress, + "concentration": concentration, + "concentration_boundary_type": concentration_boundary_type, + } + super().__init__(dict_dataset) self._validate_init_schemata(validate) @classmethod @@ -664,22 +663,20 @@ def __init__( validate: bool = True, repeat_stress=None, ): - super().__init__() - self.dataset["layer"] = _assign_dims(layer) - self.dataset["row"] = _assign_dims(row) - self.dataset["column"] = _assign_dims(column) - self.dataset["rate"] = _assign_dims(rate) - self.dataset["print_input"] = print_input - self.dataset["print_flows"] = print_flows - self.dataset["save_flows"] = save_flows - self.dataset["observations"] 
= observations - self.dataset["repeat_stress"] = repeat_stress - - if concentration is not None: - self.dataset["concentration"] = concentration - self.dataset["concentration_boundary_type"] = concentration_boundary_type - expand_transient_auxiliary_variables(self) - + dict_dataset = { + "layer": _assign_dims(layer), + "row": _assign_dims(row), + "column": _assign_dims(column), + "rate": _assign_dims(rate), + "print_input": print_input, + "print_flows": print_flows, + "save_flows": save_flows, + "observations": observations, + "repeat_stress": repeat_stress, + "concentration": concentration, + "concentration_boundary_type": concentration_boundary_type, + } + super().__init__(dict_dataset) self._validate_init_schemata(validate) warnings.warn( @@ -823,20 +820,18 @@ def __init__( observations=None, validate: bool = True, ): - super().__init__() - self.dataset["layer"] = _assign_dims(layer) - self.dataset["cell2d"] = _assign_dims(cell2d) - self.dataset["rate"] = _assign_dims(rate) - self.dataset["print_input"] = print_input - self.dataset["print_flows"] = print_flows - self.dataset["save_flows"] = save_flows - self.dataset["observations"] = observations - - if concentration is not None: - self.dataset["concentration"] = concentration - self.dataset["concentration_boundary_type"] = concentration_boundary_type - expand_transient_auxiliary_variables(self) - + dict_dataset = { + "layer": _assign_dims(layer), + "cell2d": _assign_dims(cell2d), + "rate": _assign_dims(rate), + "print_input": print_input, + "print_flows": print_flows, + "save_flows": save_flows, + "observations": observations, + "concentration": concentration, + "concentration_boundary_type": concentration_boundary_type, + } + super().__init__(dict_dataset) self._validate_init_schemata(validate) warnings.warn( diff --git a/imod/tests/test_mf6/test_exchangebase.py b/imod/tests/test_mf6/test_exchangebase.py index a9410c78f..6659d8cb0 100644 --- a/imod/tests/test_mf6/test_exchangebase.py +++ 
b/imod/tests/test_mf6/test_exchangebase.py @@ -9,11 +9,13 @@ class DummyExchange(ExchangeBase): _pkg_id = "gwfgwt" def __init__(self, model_id1: str = None, model_id2: str = None): - super().__init__() + variables_to_merge = {} if model_id1: - self.dataset["model_name_1"] = model_id1 + variables_to_merge["model_name_1"] = model_id1 if model_id2: - self.dataset["model_name_2"] = model_id2 + variables_to_merge["model_name_2"] = model_id2 + + super().__init__(variables_to_merge) def test_package_name_construct_name(): diff --git a/imod/tests/test_mf6/test_mf6_lak.py b/imod/tests/test_mf6/test_mf6_lak.py index 8ca5375f5..14bddcd68 100644 --- a/imod/tests/test_mf6/test_mf6_lak.py +++ b/imod/tests/test_mf6/test_mf6_lak.py @@ -395,7 +395,8 @@ def test_lake_rendering_transient_all_timeseries(basic_dis, tmp_path): ) outlet1 = OutletManning("Naardermeer", "IJsselmeer", invert, 2, 3, 4) outlet2 = OutletSpecified("IJsselmeer", "Naardermeer", rate) - outlet3 = OutletWeir("IJsselmeer", "Naardermeer", invert, numeric) + invert_outlet3, width_outlet3 = xr.align(invert, numeric, join="inner") + outlet3 = OutletWeir("IJsselmeer", "Naardermeer", invert_outlet3, width_outlet3) lake_with_status = mf_lake.create_lake_data_structured( is_lake1, diff --git a/imod/tests/test_mf6/test_mf6_riv.py b/imod/tests/test_mf6/test_mf6_riv.py index ccbab31be..9e224ec05 100644 --- a/imod/tests/test_mf6/test_mf6_riv.py +++ b/imod/tests/test_mf6/test_mf6_riv.py @@ -6,6 +6,8 @@ import numpy as np import pytest import xarray as xr +import xugrid as xu +from pytest_cases import parametrize_with_cases import imod from imod.mf6.write_context import WriteContext @@ -49,8 +51,29 @@ def riv_dict(make_da): return dict(stage=da, conductance=da, bottom_elevation=bottom) -def test_render(riv_dict): - river = imod.mf6.River(**riv_dict) +def make_dict_unstructured(d): + return {key: xu.UgridDataArray.from_structured(value) for key, value in d.items()} + + +class RivCases: + def case_structured(self, riv_dict): + 
return riv_dict + + def case_unstructured(self, riv_dict): + return make_dict_unstructured(riv_dict) + + +class RivDisCases: + def case_structured(self, riv_dict, dis_dict): + return riv_dict, dis_dict + + def case_unstructured(self, riv_dict, dis_dict): + return make_dict_unstructured(riv_dict), make_dict_unstructured(dis_dict) + + +@parametrize_with_cases("riv_data", cases=RivCases) +def test_render(riv_data): + river = imod.mf6.River(**riv_data) directory = pathlib.Path("mymodel") globaltimes = [np.datetime64("2000-01-01")] actual = river.render(directory, "river", globaltimes, True) @@ -71,21 +94,23 @@ def test_render(riv_dict): assert actual == expected -def test_wrong_dtype(riv_dict): - riv_dict["stage"] = riv_dict["stage"].astype(int) +@parametrize_with_cases("riv_data", cases=RivCases) +def test_wrong_dtype(riv_data): + riv_data["stage"] = riv_data["stage"].astype(int) with pytest.raises(ValidationError): - imod.mf6.River(**riv_dict) + imod.mf6.River(**riv_data) -def test_all_nan(riv_dict, dis_dict): +@parametrize_with_cases("riv_data,dis_data", cases=RivDisCases) +def test_all_nan(riv_data, dis_data): # Use where to set everything to np.nan for var in ["stage", "conductance", "bottom_elevation"]: - riv_dict[var] = riv_dict[var].where(False) + riv_data[var] = riv_data[var].where(False) - river = imod.mf6.River(**riv_dict) + river = imod.mf6.River(**riv_data) - errors = river._validate(river._write_schemata, **dis_dict) + errors = river._validate(river._write_schemata, **dis_data) assert len(errors) == 1 @@ -93,20 +118,22 @@ def test_all_nan(riv_dict, dis_dict): assert var == "stage" -def test_inconsistent_nan(riv_dict, dis_dict): - riv_dict["stage"][:, 1, 2] = np.nan - river = imod.mf6.River(**riv_dict) +@parametrize_with_cases("riv_data,dis_data", cases=RivDisCases) +def test_inconsistent_nan(riv_data, dis_data): + riv_data["stage"][..., 2] = np.nan + river = imod.mf6.River(**riv_data) - errors = river._validate(river._write_schemata, **dis_dict) + errors 
= river._validate(river._write_schemata, **dis_data) assert len(errors) == 1 -def test_check_layer(riv_dict): +@parametrize_with_cases("riv_data", cases=RivCases) +def test_check_layer(riv_data): """ Test for error thrown if variable has no layer coord """ - riv_dict["stage"] = riv_dict["stage"].sel(layer=2, drop=True) + riv_data["stage"] = riv_data["stage"].sel(layer=2, drop=True) message = textwrap.dedent( """ @@ -118,7 +145,7 @@ def test_check_layer(riv_dict): ValidationError, match=re.escape(message), ): - imod.mf6.River(**riv_dict) + imod.mf6.River(**riv_data) def test_check_dimsize_zero(): @@ -152,69 +179,73 @@ def test_check_dimsize_zero(): imod.mf6.River(stage=da, conductance=da, bottom_elevation=da - 1.0) -def test_check_zero_conductance(riv_dict, dis_dict): +@parametrize_with_cases("riv_data,dis_data", cases=RivDisCases) +def test_check_zero_conductance(riv_data, dis_data): """ Test for zero conductance """ - riv_dict["conductance"] = riv_dict["conductance"] * 0.0 + riv_data["conductance"] = riv_data["conductance"] * 0.0 - river = imod.mf6.River(**riv_dict) + river = imod.mf6.River(**riv_data) - errors = river._validate(river._write_schemata, **dis_dict) + errors = river._validate(river._write_schemata, **dis_data) assert len(errors) == 1 for var, var_errors in errors.items(): assert var == "conductance" -def test_check_bottom_above_stage(riv_dict, dis_dict): +@parametrize_with_cases("riv_data,dis_data", cases=RivDisCases) +def test_check_bottom_above_stage(riv_data, dis_data): """ Check that river bottom is not above stage. 
""" - riv_dict["bottom_elevation"] = riv_dict["bottom_elevation"] + 10.0 + riv_data["bottom_elevation"] = riv_data["bottom_elevation"] + 10.0 - river = imod.mf6.River(**riv_dict) + river = imod.mf6.River(**riv_data) - errors = river._validate(river._write_schemata, **dis_dict) + errors = river._validate(river._write_schemata, **dis_data) assert len(errors) == 1 for var, var_errors in errors.items(): assert var == "stage" -def test_check_riv_bottom_above_dis_bottom(riv_dict, dis_dict): +@parametrize_with_cases("riv_data,dis_data", cases=RivDisCases) +def test_check_riv_bottom_above_dis_bottom(riv_data, dis_data): """ Check that river bottom not above dis bottom. """ - river = imod.mf6.River(**riv_dict) + river = imod.mf6.River(**riv_data) - river._validate(river._write_schemata, **dis_dict) + river._validate(river._write_schemata, **dis_data) - dis_dict["bottom"] += 2.0 + dis_data["bottom"] += 2.0 - errors = river._validate(river._write_schemata, **dis_dict) + errors = river._validate(river._write_schemata, **dis_data) assert len(errors) == 1 for var, var_errors in errors.items(): assert var == "bottom_elevation" -def test_check_boundary_outside_active_domain(riv_dict, dis_dict): +@parametrize_with_cases("riv_data,dis_data", cases=RivDisCases) +def test_check_boundary_outside_active_domain(riv_data, dis_data): """ Check that river not outside idomain """ - river = imod.mf6.River(**riv_dict) + river = imod.mf6.River(**riv_data) - errors = river._validate(river._write_schemata, **dis_dict) + errors = river._validate(river._write_schemata, **dis_data) assert len(errors) == 0 - dis_dict["idomain"][0, 0, 0] = 0 + dis_data["idomain"][..., 0] = 0 - errors = river._validate(river._write_schemata, **dis_dict) + errors = river._validate(river._write_schemata, **dis_data) assert len(errors) == 1 diff --git a/imod/tests/test_mf6/test_mf6_sto.py b/imod/tests/test_mf6/test_mf6_sto.py index b78c0e09c..185f67482 100644 --- a/imod/tests/test_mf6/test_mf6_sto.py +++ 
b/imod/tests/test_mf6/test_mf6_sto.py @@ -46,9 +46,9 @@ def convertible(idomain): @pytest.fixture(scope="function") def dis(idomain): - top = idomain.sel(layer=1) + top = idomain.sel(layer=1, drop=True) bottom = idomain - xr.DataArray( - data=[1.0, 2.0], dims=("layer",), coords={"layer": [2, 3]} + data=[0.0, 1.0, 2.0], dims=("layer",), coords={"layer": [1, 2, 3]} ) return imod.mf6.StructuredDiscretization( diff --git a/imod/tests/test_mf6/test_mf6_transport_model.py b/imod/tests/test_mf6/test_mf6_transport_model.py index ceff4999b..35c0d911d 100644 --- a/imod/tests/test_mf6/test_mf6_transport_model.py +++ b/imod/tests/test_mf6/test_mf6_transport_model.py @@ -2,6 +2,7 @@ import numpy as np import pytest +import xarray as xr import imod from imod.mf6.adv import AdvectionCentral @@ -45,7 +46,10 @@ def test_transport_model_rendering(): def test_assign_flow_discretization(basic_dis, concentration_fc): # define a grid - idomain, _, bottom = basic_dis + _, _, bottom = basic_dis + idomain = xr.ones_like( + concentration_fc.isel(species=0, time=0, drop=True), dtype=int + ) like = idomain.sel(layer=1).astype(np.float32) concentration = concentration_fc.sel(layer=1) @@ -54,6 +58,7 @@ def test_assign_flow_discretization(basic_dis, concentration_fc): gwf_model["dis"] = imod.mf6.StructuredDiscretization( top=200.0, bottom=bottom, idomain=idomain ) + gwf_model["riv-1"] = imod.mf6.River( stage=like, conductance=like, diff --git a/imod/typing/grid.py b/imod/typing/grid.py index 4e23f102c..79d262638 100644 --- a/imod/typing/grid.py +++ b/imod/typing/grid.py @@ -1,3 +1,5 @@ +import pickle +import textwrap from typing import Callable, Sequence import numpy as np @@ -93,6 +95,50 @@ def _type_dispatch_functions_on_grid_sequence( ) +# Typedispatching doesn't work based on types of dict elements, therefore resort +# to manual type testing +def _type_dispatch_functions_on_dict( + dict_of_objects: dict[str, GridDataArray | float | bool | int], + unstructured_func: Callable, + 
structured_func: Callable, + *args, + **kwargs, +): + """ + Typedispatch function on grid and scalar variables provided in dictionary. + Types do not need to be homogeneous as scalars and grids can be mixed. No + mixing of structured and unstructured grids is allowed. Also allows running + function on dictionary with purely scalars, in which case it will call to + the xarray function. + """ + + error_msg = textwrap.dedent( + """ + Received both structured grid (xr.DataArray) and xu.UgridDataArray. This + means structured grids as well as unstructured grids were provided. + """ + ) + + if dict_of_objects is None: + return xr.Dataset() + + types = [type(arg) for arg in dict_of_objects.values()] + has_unstructured = xu.UgridDataArray in types + # Test structured if xr.DataArray and spatial. + has_structured_grid = any( + [ + isinstance(arg, xr.DataArray) and is_spatial_2D(arg) + for arg in dict_of_objects.values() + ] + ) + if has_structured_grid and has_unstructured: + raise TypeError(error_msg) + if has_unstructured: + return unstructured_func([dict_of_objects], *args, **kwargs) + + return structured_func([dict_of_objects], *args, **kwargs) + + def merge( objects: Sequence[GridDataArray | GridDataset], *args, **kwargs ) -> GridDataset: @@ -117,6 +163,65 @@ def concat( ) +def merge_unstructured_dataset(variables_to_merge: list[dict], *args, **kwargs): + """ + Work around xugrid issue https://github.com/Deltares/xugrid/issues/179 + + Expects only one dictionary in list. List is used to have same API as + xr.merge(). + + Merges unstructured grids first, then manually assigns scalar variables. 
+ """ + if len(variables_to_merge) > 1: + raise ValueError( + f"Only one dict of variables expected, got {len(variables_to_merge)}" + ) + + variables_to_merge_dict = variables_to_merge[0] + + if not isinstance(variables_to_merge_dict, dict): + raise TypeError(f"Expected dict, got {type(variables_to_merge_dict)}") + + # Separate variables into list of grids and dict of scalar variables + grids_ls = [] + scalar_dict = {} + for name, variable in variables_to_merge_dict.items(): + if isinstance(variable, xu.UgridDataArray): + grids_ls.append(variable.rename(name)) + else: + scalar_dict[name] = variable + + # Merge grids + dataset = xu.merge(grids_ls, *args, **kwargs) + + # Temporarily work around this xugrid issue, until fixed: + # https://github.com/Deltares/xugrid/issues/206 + grid_hashes = [hash(pickle.dumps(grid)) for grid in dataset.ugrid.grids] + unique_grid_hashes = np.unique(grid_hashes) + if unique_grid_hashes.size > 1: + raise ValueError( + "Multiple grids provided, please provide data on one unique grid" + ) + else: + # Possibly won't work anymore if this ever gets implemented: + # https://github.com/Deltares/xugrid/issues/195 + dataset._grids = [dataset.grids[0]] + + # Assign scalar variables manually + for name, variable in scalar_dict.items(): + dataset[name] = variable + + return dataset + + +def merge_with_dictionary( + variables_to_merge: dict[str, GridDataArray | float | bool | int], *args, **kwargs +): + return _type_dispatch_functions_on_dict( + variables_to_merge, merge_unstructured_dataset, xr.merge, *args, **kwargs + ) + + @typedispatch def bounding_polygon(active: xr.DataArray): """Return bounding polygon of active cells""" @@ -156,3 +261,8 @@ def is_spatial_2D(array: xu.UgridDataArray) -> bool: has_spatial_coords = face_dim in coords has_spatial_dims = face_dim in dims return has_spatial_dims & has_spatial_coords + + +@typedispatch +def is_spatial_2D(_: object) -> bool: + return False