Skip to content

Commit

Permalink
introduce common comparison utils for in-memory and netcdf output fil…
Browse files Browse the repository at this point in the history
…es, and a probeable/customizable version of numpy.testing.assert_allclose, refactor test_prms_groundwater.py to use these.
  • Loading branch information
jmccreight committed Sep 25, 2023
1 parent 7f31e68 commit 00401f6
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 30 deletions.
58 changes: 28 additions & 30 deletions autotest/test_prms_groundwater.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@
import pytest

from pywatershed import Control, Parameters, PRMSGroundwater
from pywatershed.base.adapter import adapter_factory
from pywatershed.hydrology.prms_groundwater import has_prmsgroundwater_f
from pywatershed.parameters import PrmsParameters
from pywatershed.utils.netcdf_utils import NetCdfCompare

from utils_compare import assert_allclose, compare_in_memory, compare_netcdfs

# compare in memory (faster) or full output files
compare_output_files = False
rtol = atol = 1.0e-13

calc_methods = ("numpy", "numba", "fortran")
params = ("params_sep", "params_one")
Expand Down Expand Up @@ -44,7 +50,6 @@ def test_compare_prms(

tmp_path = pl.Path(tmp_path)

# load csv files into dataframes
output_dir = domain["prms_output_dir"]
input_variables = {}
for key in PRMSGroundwater.get_inputs():
Expand All @@ -59,43 +64,36 @@ def test_compare_prms(
budget_type="error",
calc_method=calc_method,
)
nc_parent = tmp_path / domain["domain_name"]
gw.initialize_netcdf(nc_parent)

output_compare = {}
vars_compare = PRMSGroundwater.get_variables()
for key in PRMSGroundwater.get_variables():
if key not in vars_compare:
continue
base_nc_path = output_dir / f"{key}.nc"
compare_nc_path = tmp_path / domain["domain_name"] / f"{key}.nc"
output_compare[key] = (base_nc_path, compare_nc_path)

print(f"base_nc_path: {base_nc_path}")
print(f"compare_nc_path: {compare_nc_path}")
if compare_output_files:
nc_parent = tmp_path / domain["domain_name"]
gw.initialize_netcdf(nc_parent)
else:
answers = {}
for var in PRMSGroundwater.get_variables():
var_pth = output_dir / f"{var}.nc"
answers[var] = adapter_factory(
var_pth, variable_name=var, control=control
)

for istep in range(control.n_times):
control.advance()

gw.advance()

gw.calculate(float(istep))

gw.output()

if not compare_output_files:
compare_in_memory(gw, answers, atol=atol, rtol=rtol)

gw.finalize()

assert_error = False
for key, (base, compare) in output_compare.items():
success, diff = NetCdfCompare(base, compare).compare()
if not success:
print(
f"comparison for {key} failed: "
+ f"maximum error {diff[key][0]} "
+ f"(maximum allowed error {diff[key][1]}) "
+ f"in column {diff[key][2]}"
)
assert_error = True
assert not assert_error, "comparison failed"
if compare_output_files:
compare_netcdfs(
PRMSGroundwater.get_variables(),
tmp_path / domain["domain_name"],
output_dir,
atol=atol,
rtol=rtol,
)

return
120 changes: 120 additions & 0 deletions autotest/utils_compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import pathlib as pl

import numpy as np
import pywatershed as pws
import xarray as xr


def assert_allclose(
    actual: np.ndarray,
    desired: np.ndarray,
    rtol: float = 1.0e-15,
    atol: float = 1.0e-15,
    equal_nan: bool = True,
    strict: bool = False,
    also_check_w_np: bool = True,
    error_message: str = "Comparison unsuccessful (default message)",
):
    """Reinvent np.testing.assert_allclose to get useful diagnostics in debug

    Args:
        actual: Array obtained.
        desired: Array desired.
        rtol: Relative tolerance.
        atol: Absolute tolerance.
        equal_nan: If True, NaNs will compare equal.
        strict: If True, raise an ``AssertionError`` when either the shape or
            the data type of the arguments does not match. The special
            handling of scalars mentioned in the Notes section is disabled.
        also_check_w_np: first check using np.testing.assert_allclose using
            the same options.
        error_message: Message attached to any failing assertion below.
    """

    if also_check_w_np:
        np.testing.assert_allclose(
            actual,
            desired,
            rtol=rtol,
            atol=atol,
            equal_nan=equal_nan,
            # strict=strict, # to add to newer versions of numpy
        )

    if strict:
        assert actual.shape == desired.shape, error_message
        assert isinstance(actual, type(desired)), error_message
        assert isinstance(desired, type(actual)), error_message

    actual_nan = np.isnan(actual)
    desired_nan = np.isnan(desired)
    if equal_nan:
        # NaNs must occur at exactly the same locations in both arrays
        assert (actual_nan == desired_nan).all(), error_message

    abs_diff = np.abs(actual - desired)
    # Use |desired| in the denominator: dividing by a negative desired
    # value would give a negative "relative error" that always compares
    # below rtol, silently passing arbitrarily large mismatches.
    with np.errstate(divide="ignore", invalid="ignore"):
        rel_abs_diff = abs_diff / np.abs(desired)

    abs_close = abs_diff < atol
    # NaN < rtol evaluates False, so NaN relative errors (0/0) are
    # already treated as "not close" here.
    rel_close = rel_abs_diff < rtol

    close = abs_close | rel_close
    if equal_nan:
        # Positions where both values are NaN count as close; otherwise
        # abs_diff is NaN there and the final assert would fail even
        # though the arrays agree.
        close = close | (actual_nan & desired_nan)

    assert close.all(), error_message


def compare_in_memory(
    process: pws.base.Process,
    answers: dict[pws.base.adapter.AdapterNetcdf],
    rtol: float = 1.0e-15,
    atol: float = 1.0e-15,
    equal_nan: bool = True,
    strict: bool = False,
    also_check_w_np: bool = True,
    error_message: str = None,
):
    """Compare a process' in-memory values against answer adapters.

    For each variable of ``process``, advances the matching adapter in
    ``answers`` one timestep and asserts the process' current value is
    allclose to the adapter's current data.
    """
    for var_name in process.get_variables():
        answer = answers[var_name]
        answer.advance()
        assert_allclose(
            process[var_name],
            answer.current.data,
            atol=atol,
            rtol=rtol,
            equal_nan=equal_nan,
            strict=strict,
            also_check_w_np=also_check_w_np,
            error_message=error_message,
        )


def compare_netcdfs(
    var_list: list,
    results_dir: pl.Path,
    answers_dir: pl.Path,
    rtol: float = 1.0e-15,
    atol: float = 1.0e-15,
    equal_nan: bool = True,
    strict: bool = False,
    also_check_w_np: bool = True,
    error_message: str = None,
):
    """Compare result NetCDF files against answer files variable-by-variable.

    Args:
        var_list: Variable names; each ``<var>.nc`` must exist in both dirs.
        results_dir: Directory holding the files under test.
        answers_dir: Directory holding the reference files.
        rtol: Relative tolerance passed to assert_allclose.
        atol: Absolute tolerance passed to assert_allclose.
        equal_nan: If True, NaNs compare equal.
        strict: If True, also require matching shape and type.
        also_check_w_np: Also verify with np.testing.assert_allclose.
        error_message: Override for the per-variable failure message.
    """
    # TODO: improve error message
    # TODO: collect failures in a try and report at end
    for var in var_list:
        answer = xr.open_dataarray(answers_dir / f"{var}.nc")
        result = xr.open_dataarray(results_dir / f"{var}.nc")

        # Build the message in a local: the original assigned back to
        # error_message, so every variable after the first reported the
        # first variable's name on failure.
        if error_message is None:
            var_message = f"Comparison of variable '{var}' was unsuccessful"
        else:
            var_message = error_message

        try:
            assert_allclose(
                actual=result.values,
                desired=answer.values,
                rtol=rtol,
                atol=atol,
                equal_nan=equal_nan,
                strict=strict,
                also_check_w_np=also_check_w_np,
                error_message=var_message,
            )
        finally:
            # Close even on assertion failure so file handles are released.
            answer.close()
            result.close()

0 comments on commit 00401f6

Please sign in to comment.