diff --git a/windspharm/iris.py b/windspharm/iris.py index 03670c0..0ce84aa 100644 --- a/windspharm/iris.py +++ b/windspharm/iris.py @@ -20,6 +20,7 @@ # THE SOFTWARE. from __future__ import absolute_import +from cf_units import Unit from iris.cube import Cube from iris.util import reverse @@ -95,6 +96,9 @@ def __init__(self, u, v, rsphere=6.3712e6, legfunc='stored'): v = v.copy() u.transpose(apiorder) v.transpose(apiorder) + if not u.units.is_unknown(): + u.convert_units("m/s") + v.convert_units("m/s") # Records the current shape and dimension coordinates of the inputs. self._ishape = u.shape self._coords = u.dim_coords @@ -673,6 +677,7 @@ def gradient(self, chi, truncation=None): if type(chi) is not Cube: raise TypeError('scalar field must be an iris cube') name = chi.name() + chi_units = chi.units lat, lat_dim = _dim_coord_and_dim(chi, 'latitude') lon, lon_dim = _dim_coord_and_dim(chi, 'longitude') if (lat.points[0] < lat.points[1]): @@ -698,6 +703,10 @@ def gradient(self, chi, truncation=None): vchi.transpose(reorder) uchi.long_name = 'zonal_gradient_of_{!s}'.format(name) vchi.long_name = 'meridional_gradient_of_{!s}'.format(name) + if not chi_units.is_unknown(): + result_units = chi_units / Unit("m") + uchi.units = result_units + vchi.units = result_units return uchi, vchi def truncate(self, field, truncation=None): diff --git a/windspharm/tests/test_units.py b/windspharm/tests/test_units.py new file mode 100644 index 0000000..7a80d9a --- /dev/null +++ b/windspharm/tests/test_units.py @@ -0,0 +1,109 @@ +"""Test the unit handling of the iris and xarray interfaces.""" + +import numpy as np +import pytest + +try: + import iris.coords + import iris.cube + + import windspharm.iris +except ImportError: + iris = None + +try: + import xarray as xr + + import windspharm.xarray +except ImportError: + xr = None + + +def create_data(package: str, wind_units: str): + """Create data for use by test functions.""" + data = np.zeros((7, 12), dtype="f4") + data[[1, -1], :] = 1 + data[[2, -2], :] = np.sqrt(3) + data[3, :] = 2 + lats = np.arange(-90, 91, 30, dtype="f4") + lons = np.arange(0, 360, 30, dtype="f4") + if package == "iris": + result = iris.cube.Cube( + data, + units=wind_units, + dim_coords_and_dims=( + (iris.coords.DimCoord(lats, "latitude", units="degrees_north"), 0), + ( + iris.coords.DimCoord( + lons, "longitude", units="degrees_east", circular=True + ), + 1, + ), + ), + ) + elif package == "xarray": + result = xr.DataArray( + data, + dims=("latitude", "longitude"), + coords={ + "latitude": (("latitude",), lats, {"units": "degrees_north"}), + "longitude": (("longitude",), lons, {"units": "degrees_easth"}), + }, + attrs={"units": wind_units}, + ) + return result + + +@pytest.mark.skipif("iris is None") +@pytest.mark.parametrize("units", ["knots", "miles per hour"]) +def test_iris_convert_units(units): + """Test that iris will convert speed units to m/s.""" + cube = create_data("iris", units) + vec_wind = windspharm.iris.VectorWind(cube, cube) + assert np.any(vec_wind.u() != cube) + assert np.any(vec_wind.v() != cube) + assert np.all(np.around(vec_wind.u().data[3, :]) == 1) + + +@pytest.mark.skipif("xr is None") +@pytest.mark.parametrize("units", ["knots", "mph"]) +def test_xr_warns_units(units): + """Test that XArray warns for non-m/s wind units.""" + data = create_data("xarray", units) + with pytest.warns(UserWarning): + vec_wind = windspharm.xarray.VectorWind(data, data) + + +@pytest.mark.skipif("xr is None") +@pytest.mark.parametrize("units", ["m/s", "m / s", "m s**-1", "m s^-1", "m s ** -1"]) +@pytest.mark.filterwarnings("error:Winds should have units of m/s") +def test_xr_unit_recognition(units): + """Test that XArray doesn't warn for different spellings of m/s.""" + data = create_data("xarray", units) + vec_wind = windspharm.xarray.VectorWind(data, data) + + +@pytest.mark.parametrize( + "package", + [ + pytest.param("iris", marks=pytest.mark.skipif("iris is None")), + pytest.param("xarray", marks=pytest.mark.skipif("xr is None")), + ], +) +@pytest.mark.parametrize("units", ["K", "mg/kg"]) +def test_gradient_units(package, units): + """Test the units of the gradient. + + They should be the units of the input per meter. + """ + scalar_data = create_data(package, units) + wind_data = create_data(package, "m/s") + vec_wind = getattr(windspharm, package).VectorWind(wind_data, wind_data) + grad_components = vec_wind.gradient(scalar_data) + for component in grad_components: + if package == "iris": + new_units = component.units + else: + new_units = component.attrs["units"] + assert new_units != units + assert new_units == "{:s} / m".format(units) diff --git a/windspharm/xarray.py b/windspharm/xarray.py index b8cb69e..b0c96fa 100644 --- a/windspharm/xarray.py +++ b/windspharm/xarray.py @@ -19,6 +19,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. from __future__ import absolute_import +import warnings try: import xarray as xr @@ -41,7 +42,8 @@ def __init__(self, u, v, rsphere=6.3712e6, legfunc='stored'): Zonal and meridional components of the vector wind respectively. Both components should be `~xarray.DataArray` instances. The components must have the same dimension - coordinates and contain no missing values. + coordinates and contain no missing values. The wind + components should be in units of meters per second. **Optional argument:** @@ -78,6 +80,11 @@ def __init__(self, u, v, rsphere=6.3712e6, legfunc='stored'): if not all([(uc == vc).all() for uc, vc in zip(ucoords, vcoords)]): msg = 'u and v must have the same dimension coordinate values' raise ValueError(msg) + if any( + comp.attrs.get("units", "m/s").replace(" ", "").replace("**", "^") not in ("m/s", "ms^-1") + for comp in (u, v) + ): + warnings.warn("Winds should have units of m/s", UserWarning) # Find the latitude and longitude coordinates and reverse the latitude # dimension if necessary. lat, lat_dim = _find_latitude_coordinate(u) @@ -665,6 +672,12 @@ def gradient(self, chi, truncation=None): if not isinstance(chi, xr.DataArray): raise TypeError('scalar field must be an xarray.DataArray') name = chi.name + try: + chi_units = chi.attrs["units"] + except KeyError: + grad_units = None + else: + grad_units = f"{chi_units:s} / m" lat, lat_dim = _find_latitude_coordinate(chi) lon, lon_dim = _find_longitude_coordinate(chi) if (lat.values[0] < lat.values[1]): @@ -689,6 +702,9 @@ def gradient(self, chi, truncation=None): attrs={'long_name': vchi_name}) uchi = uchi.transpose(*reorder) vchi = vchi.transpose(*reorder) + if grad_units is not None: + uchi.attrs["units"] = grad_units + vchi.attrs["units"] = grad_units return uchi, vchi def truncate(self, field, truncation=None):