Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions windspharm/iris.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# THE SOFTWARE.
from __future__ import absolute_import

from cf_units import Unit
from iris.cube import Cube
from iris.util import reverse

Expand Down Expand Up @@ -95,6 +96,9 @@ def __init__(self, u, v, rsphere=6.3712e6, legfunc='stored'):
v = v.copy()
u.transpose(apiorder)
v.transpose(apiorder)
if not u.units.is_unknown():
u.convert_units("m/s")
v.convert_units("m/s")
# Records the current shape and dimension coordinates of the inputs.
self._ishape = u.shape
self._coords = u.dim_coords
Expand Down Expand Up @@ -673,6 +677,7 @@ def gradient(self, chi, truncation=None):
if type(chi) is not Cube:
raise TypeError('scalar field must be an iris cube')
name = chi.name()
chi_units = chi.units
lat, lat_dim = _dim_coord_and_dim(chi, 'latitude')
lon, lon_dim = _dim_coord_and_dim(chi, 'longitude')
if (lat.points[0] < lat.points[1]):
Expand All @@ -698,6 +703,10 @@ def gradient(self, chi, truncation=None):
vchi.transpose(reorder)
uchi.long_name = 'zonal_gradient_of_{!s}'.format(name)
vchi.long_name = 'meridional_gradient_of_{!s}'.format(name)
if chi_units != "unknown":
result_units = chi_units / Unit("m")
uchi.units = result_units
vchi.units = result_units
return uchi, vchi

def truncate(self, field, truncation=None):
Expand Down
109 changes: 109 additions & 0 deletions windspharm/tests/test_units.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Test the unit handling of the iris and xarray interfaces."""

import numpy as np
import pytest

try:
import iris.coords
import iris.cube

import windspharm.iris
except ImportError:
iris = None

try:
import xarray as xr

import windspharm.xarray
except ImportError:
xr = None


def create_data(package: str, wind_units: str):
"""Create data for use by test functions."""
data = np.zeros((7, 12), dtype="f4")
data[[1, -1], :] = 1
data[[2, -2], :] = np.sqrt(3)
data[3, :] = 2
lats = np.arange(-90, 91, 30, dtype="f4")
lons = np.arange(0, 360, 30, dtype="f4")
if package == "iris":
result = iris.cube.Cube(
data,
units=wind_units,
dim_coords_and_dims=(
(iris.coords.DimCoord(lats, "latitude", units="degrees_north"), 0),
(
iris.coords.DimCoord(
lons, "longitude", units="degrees_east", circular=True
),
1,
),
),
)
elif package == "xarray":
result = xr.DataArray(
data,
dims=("latitude", "longitude"),
coords={
"latitude": (("latitude",), lats, {"units": "degrees_north"}),
"longitude": (("longitude",), lons, {"units": "degrees_easth"}),
},
attrs={"units": wind_units},
)
return result


@pytest.mark.skipif("iris is None")
@pytest.mark.parametrize("units", ["knots", "miles per hour"])
def test_iris_convert_units(units):
"""Test that iris will convert speed units to m/s."""
cube = create_data("iris", units)
vec_wind = windspharm.iris.VectorWind(cube, cube)
assert np.any(vec_wind.u() != cube)
assert np.any(vec_wind.v() != cube)
assert np.all(np.around(vec_wind.u().data[3, :]) == 1)


@pytest.mark.skipif("xr is None")
@pytest.mark.parametrize("units", ["knots", "mph"])
def test_xr_warns_units(units):
"""Test that XArray warns for non-m/s wind units."""
data = create_data("xarray", units)
with pytest.warns(UserWarning):
vec_wind = windspharm.xarray.VectorWind(data, data)


@pytest.mark.skipif("xr is None")
@pytest.mark.parametrize("units", ["m/s", "m / s", "m s**-1", "m s^-1", "m s ** -1"])
@pytest.mark.filterwarnings("error:Winds should have units of m/s")
def test_xr_unit_recognition(units):
"""Test that XArray doesn't warn for different spellings of m/s."""
data = create_data("xarray", units)
vec_wind = windspharm.xarray.VectorWind(data, data)


@pytest.mark.parametrize(
"package",
[
pytest.param("iris", marks=pytest.mark.skipif("iris is None")),
pytest.param("xarray", marks=pytest.mark.skipif("xr is None")),
],
)
@pytest.mark.parametrize("units", ["K", "mg/kg"])
def test_gradient_units(package, units):
"""Test the units of the gradient.

They should be the units of the input per meter.
"""
scalar_data = create_data(package, units)
wind_data = create_data(package, "m/s")
vec_wind = getattr(windspharm, package).VectorWind(wind_data, wind_data)
grad_components = vec_wind.gradient(scalar_data)
for component in grad_components:
if package == "iris":
new_units = component.units
else:
new_units = component.attrs["units"]
assert new_units != units
assert new_units == "{:s} / m".format(units)
18 changes: 17 additions & 1 deletion windspharm/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from __future__ import absolute_import
import warnings

try:
import xarray as xr
Expand All @@ -41,7 +42,8 @@ def __init__(self, u, v, rsphere=6.3712e6, legfunc='stored'):
Zonal and meridional components of the vector wind
respectively. Both components should be `~xarray.DataArray`
instances. The components must have the same dimension
coordinates and contain no missing values.
coordinates and contain no missing values. The wind
components should be in units of meters per second.

**Optional argument:**

Expand Down Expand Up @@ -78,6 +80,11 @@ def __init__(self, u, v, rsphere=6.3712e6, legfunc='stored'):
if not all([(uc == vc).all() for uc, vc in zip(ucoords, vcoords)]):
msg = 'u and v must have the same dimension coordinate values'
raise ValueError(msg)
if any(
comp.attrs.get("units", "m/s").replace(" ", "").replace("**", "^") not in ("m/s", "ms^-1")
for comp in (u, v)
):
warnings.warn("Winds should have units of m/s", UserWarning)
# Find the latitude and longitude coordinates and reverse the latitude
# dimension if necessary.
lat, lat_dim = _find_latitude_coordinate(u)
Expand Down Expand Up @@ -665,6 +672,12 @@ def gradient(self, chi, truncation=None):
if not isinstance(chi, xr.DataArray):
raise TypeError('scalar field must be an xarray.DataArray')
name = chi.name
try:
chi_units = chi.attrs["units"]
except KeyError:
grad_units = None
else:
grad_units = f"{chi_units:s} / m"
lat, lat_dim = _find_latitude_coordinate(chi)
lon, lon_dim = _find_longitude_coordinate(chi)
if (lat.values[0] < lat.values[1]):
Expand All @@ -689,6 +702,9 @@ def gradient(self, chi, truncation=None):
attrs={'long_name': vchi_name})
uchi = uchi.transpose(*reorder)
vchi = vchi.transpose(*reorder)
if grad_units is not None:
uchi.attrs["units"] = grad_units
vchi.attrs["units"] = grad_units
return uchi, vchi

def truncate(self, field, truncation=None):
Expand Down