Skip to content

Commit

Permalink
Control .to_yaml and .to_dict methods
Browse files Browse the repository at this point in the history
  • Loading branch information
jmccreight committed Oct 10, 2023
1 parent b58484d commit 554899c
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 103 deletions.
8 changes: 8 additions & 0 deletions autotest/test_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,11 @@ def test_setitem_setattr(domain):
# The value for options must be a dictionary
with pytest.raises(ValueError):
ctl.options = None


def test_yaml_roundtrip(domain, tmp_path):
ctl = Control.load_prms(domain["control_file"], warn_unused_options=False)
yml_file = tmp_path / "control.yaml"
ctl.to_yaml(yml_file)
ctl_2 = Control.from_yaml(yml_file)
np.testing.assert_equal(ctl.to_dict(), ctl_2.to_dict())
9 changes: 5 additions & 4 deletions autotest/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
}


invoke_style = ("prms", "model_dict", "model_dict_from_yml")
invoke_style = ("prms", "model_dict", "model_dict_from_yaml")


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -106,9 +106,9 @@ def model_args(domain, control, discretization, request):
"parameters": None,
}

elif invoke_style == "model_dict_from_yml":
yml_file = domain["dir"] / "nhm_model.yml"
model_dict = Model.model_dict_from_yml(yml_file)
elif invoke_style == "model_dict_from_yaml":
yaml_file = domain["dir"] / "nhm_model.yml"
model_dict = Model.model_dict_from_yaml(yaml_file)

args = {
"process_list_or_model_dict": model_dict,
Expand Down Expand Up @@ -143,6 +143,7 @@ def test_model(domain, model_args, tmp_path):
control = model_args["control"]

control.options["input_dir"] = input_dir
control.options["netcdf_output_dir"] = tmp_path / "output"

model = Model(**model_args)

Expand Down
15 changes: 3 additions & 12 deletions examples/00_processes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
}
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -545,22 +544,14 @@
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:pws2] *",
"language": "python",
"name": "conda-env-pws2-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
"name": "ipython"
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"nbconvert_exporter": "python"
}
},
"nbformat": 4,
Expand Down
87 changes: 35 additions & 52 deletions examples/01_multi-process_models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -351,15 +351,8 @@
"source": [
"run_dir = pl.Path(nb_output_dir / \"nhm_yaml\")\n",
"run_dir.mkdir(exist_ok=True)\n",
"control_dict = control.options | {\n",
" \"start_time\": str(control.start_time),\n",
" \"end_time\": str(control.end_time),\n",
" \"time_step\": str(control.time_step)[0:2],\n",
" \"time_step_units\": str(control.time_step)[3:4],\n",
" \"netcdf_output_dir\": run_dir,\n",
"}\n",
"\n",
"pprint(control_dict, sort_dicts=False)"
"control_yaml_file = run_dir / \"control.yml\"\n",
"control.to_yaml(control_yaml_file)"
]
},
{
Expand Down Expand Up @@ -389,11 +382,7 @@
" elif isinstance(val, pl.Path):\n",
" the_dict[key] = str(val)\n",
"\n",
" return the_dict\n",
"\n",
"\n",
"control_dict = dict_pl_to_str(control_dict)\n",
"pprint(control_dict, sort_dicts=False)"
" return the_dict"
]
},
{
Expand All @@ -411,7 +400,6 @@
"metadata": {},
"outputs": [],
"source": [
"control_yaml_file = run_dir / \"control.yml\"\n",
"model_dict = {\n",
" \"control\": control_yaml_file.resolve(),\n",
" \"dis_hru\": domain_dir / \"parameters_dis_hru.nc\",\n",
Expand Down Expand Up @@ -486,11 +474,8 @@
"outputs": [],
"source": [
"model_dict_yaml_file = run_dir / \"model_dict.yml\"\n",
"# the control yaml file was given above and is in the model_dict\n",
"dump_dict = {control_yaml_file: control_dict, model_dict_yaml_file: model_dict}\n",
"for key, val in dump_dict.items():\n",
" with open(key, \"w\") as file:\n",
" documents = yaml.dump(val, file)"
"with open(model_dict_yaml_file, \"w\") as file:\n",
" _ = yaml.dump(model_dict, file)"
]
},
{
Expand Down Expand Up @@ -536,7 +521,7 @@
"metadata": {},
"outputs": [],
"source": [
"model_yml = pws.Model.from_yml(model_dict_yaml_file)\n",
"model_yml = pws.Model.from_yaml(model_dict_yaml_file)\n",
"model_yml"
]
},
Expand Down Expand Up @@ -698,8 +683,8 @@
"metadata": {},
"outputs": [],
"source": [
"control_dict_copy = deepcopy(control_dict)\n",
"model_dict_copy = deepcopy(model_dict)"
"run_dir = pl.Path(nb_output_dir / \"yml_less_output\").resolve()\n",
"run_dir.mkdir(exist_ok=True)"
]
},
{
Expand All @@ -709,20 +694,20 @@
"metadata": {},
"outputs": [],
"source": [
"run_dir = pl.Path(nb_output_dir / \"yml_less_output\").resolve()\n",
"run_dir.mkdir(exist_ok=True)\n",
"\n",
"control_dict_copy[\"netcdf_output_dir\"] = str(run_dir.resolve())\n",
"control_yaml_file = run_dir / \"control.yml\"\n",
"control_dict_copy[\"netcdf_output_var_names\"] = [\n",
"control_cp = deepcopy(control)\n",
"control_cp.options[\"netcdf_output_dir\"] = str(run_dir.resolve())\n",
"control_cp.options[\"netcdf_output_var_names\"] = [\n",
" var\n",
" for ll in [\n",
" pws.PRMSGroundwater.get_variables(),\n",
" pws.PRMSChannel.get_variables(),\n",
" ]\n",
" for var in ll\n",
"]\n",
"pprint(control_dict_copy, sort_dicts=False)"
"pprint(control_cp.to_dict(), sort_dicts=False)\n",
"\n",
"control_yaml_file = run_dir / \"control.yml\"\n",
"control_cp.to_yaml(control_yaml_file)"
]
},
{
Expand All @@ -740,6 +725,7 @@
"metadata": {},
"outputs": [],
"source": [
"model_dict_copy = deepcopy(model_dict)\n",
"model_dict_copy[\"control\"] = str(control_yaml_file)\n",
"model_dict_yaml_file = run_dir / \"model_dict.yml\""
]
Expand All @@ -759,13 +745,8 @@
"metadata": {},
"outputs": [],
"source": [
"dump_dict = {\n",
" control_yaml_file: control_dict_copy,\n",
" model_dict_yaml_file: model_dict_copy,\n",
"}\n",
"for key, val in dump_dict.items():\n",
" with open(key, \"w\") as file:\n",
" documents = yaml.dump(val, file)"
"with open(model_dict_yaml_file, \"w\") as file:\n",
" _ = yaml.dump(model_dict_copy, file)"
]
},
{
Expand All @@ -783,7 +764,7 @@
"metadata": {},
"outputs": [],
"source": [
"submodel = pws.Model.from_yml(model_dict_yaml_file)\n",
"submodel = pws.Model.from_yaml(model_dict_yaml_file)\n",
"submodel"
]
},
Expand Down Expand Up @@ -915,9 +896,9 @@
},
"outputs": [],
"source": [
"yml_output_dir = pl.Path(control_dict[\"netcdf_output_dir\"])\n",
"yaml_output_dir = pl.Path(control.options[\"netcdf_output_dir\"])\n",
"for ii in submodel_file_inputs:\n",
" input_file = yml_output_dir / f\"{ii}.nc\"\n",
" input_file = yaml_output_dir / f\"{ii}.nc\"\n",
" assert input_file.exists()\n",
" print(input_file)"
]
Expand Down Expand Up @@ -946,9 +927,12 @@
"run_dir.mkdir(exist_ok=True)\n",
"\n",
"# key that inputs exist from previous full-model run\n",
"control_dict[\"input_dir\"] = str(yml_output_dir.resolve())\n",
"control_dict[\"netcdf_output_dir\"] = str(run_dir.resolve())\n",
"control_yaml_file = run_dir / \"control.yml\""
"control_cp = deepcopy(control)\n",
"control_cp.options[\"input_dir\"] = yaml_output_dir.resolve()\n",
"control_cp.options[\"netcdf_output_dir\"] = run_dir.resolve()\n",
"control_yaml_file = run_dir / \"control.yml\"\n",
"control_cp.to_yaml(control_yaml_file)\n",
"pprint(control.to_dict(), sort_dicts=False)"
]
},
{
Expand All @@ -973,7 +957,8 @@
"for kk in list(model_dict.keys()):\n",
" if isinstance(model_dict[kk], dict) and kk not in keep_procs:\n",
" del model_dict[kk]\n",
"pprint(control_dict, sort_dicts=False)\n",
"\n",
"\n",
"pprint(model_dict, sort_dicts=False)"
]
},
Expand All @@ -992,10 +977,8 @@
"metadata": {},
"outputs": [],
"source": [
"dump_dict = {control_yaml_file: control_dict, model_dict_yaml_file: model_dict}\n",
"for key, val in dump_dict.items():\n",
" with open(key, \"w\") as file:\n",
" documents = yaml.dump(val, file)"
"with open(model_dict_yaml_file, \"w\") as file:\n",
" _ = yaml.dump(model_dict, file)"
]
},
{
Expand All @@ -1013,7 +996,7 @@
"metadata": {},
"outputs": [],
"source": [
"submodel = pws.Model.from_yml(model_dict_yaml_file)\n",
"submodel = pws.Model.from_yaml(model_dict_yaml_file)\n",
"submodel"
]
},
Expand Down Expand Up @@ -1154,7 +1137,7 @@
"outputs": [],
"source": [
"var = \"recharge\"\n",
"nhm_ds = xr.open_dataset(yml_output_dir / f\"{var}.nc\")\n",
"nhm_ds = xr.open_dataset(yaml_output_dir / f\"{var}.nc\")\n",
"sub_ds = xr.open_dataset(run_dir / f\"{var}.nc\")"
]
},
Expand Down Expand Up @@ -1187,7 +1170,7 @@
"outputs": [],
"source": [
"for var in submodel_variables:\n",
" nhm_da = xr.open_dataset(yml_output_dir / f\"{var}.nc\")[var]\n",
" nhm_da = xr.open_dataset(yaml_output_dir / f\"{var}.nc\")[var]\n",
" sub_da = xr.open_dataset(run_dir / f\"{var}.nc\")[var]\n",
" xr.testing.assert_equal(nhm_da, sub_da)"
]
Expand All @@ -1200,7 +1183,7 @@
"outputs": [],
"source": [
"# var_name = \"dprst_seep_hru\"\n",
"nhm_da = xr.open_dataset(yml_output_dir / f\"{var_name}.nc\")[var_name]\n",
"nhm_da = xr.open_dataset(yaml_output_dir / f\"{var_name}.nc\")[var_name]\n",
"sub_da = xr.open_dataset(run_dir / f\"{var_name}.nc\")[var_name]\n",
"scat = xr.merge(\n",
" [nhm_da.rename(f\"{var_name}_yaml\"), sub_da.rename(f\"{var_name}_subset\")]\n",
Expand Down
18 changes: 18 additions & 0 deletions examples/02_prms_legacy_models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,24 @@
"}"
]
},
{
"cell_type": "markdown",
"id": "ed82f8d1-8bfc-469e-a968-f86e029c7a5f",
"metadata": {},
"source": [
"We note that the `netcdf_output_var_names` in `control.options` is the combination of `nhruOutVar_names` and `nsegmentOutVar_names` from the PRMS-native `control.test` file. In the next section we'll customize this list of variables names, but here we list what we'll output with our current simulation."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e28f2df1-df17-451f-87ed-5d8d1e9d8b7e",
"metadata": {},
"outputs": [],
"source": [
"control.options[\"netcdf_output_var_names\"]"
]
},
{
"cell_type": "markdown",
"id": "0b46e9ca-e84b-40b3-bdc5-179fd6c85555",
Expand Down
Loading

0 comments on commit 554899c

Please sign in to comment.