Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions windspharm/iris.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# THE SOFTWARE.
from __future__ import absolute_import

from cf_units import Unit
from iris.cube import Cube
from iris.util import reverse

Expand Down Expand Up @@ -95,6 +96,9 @@ def __init__(self, u, v, rsphere=6.3712e6, legfunc='stored'):
v = v.copy()
u.transpose(apiorder)
v.transpose(apiorder)
if not u.units.is_unknown():
u.convert_units("m/s")
v.convert_units("m/s")
# Records the current shape and dimension coordinates of the inputs.
self._ishape = u.shape
self._coords = u.dim_coords
Expand Down Expand Up @@ -673,6 +677,7 @@ def gradient(self, chi, truncation=None):
if type(chi) is not Cube:
raise TypeError('scalar field must be an iris cube')
name = chi.name()
chi_units = chi.units
lat, lat_dim = _dim_coord_and_dim(chi, 'latitude')
lon, lon_dim = _dim_coord_and_dim(chi, 'longitude')
if (lat.points[0] < lat.points[1]):
Expand All @@ -698,6 +703,10 @@ def gradient(self, chi, truncation=None):
vchi.transpose(reorder)
uchi.long_name = 'zonal_gradient_of_{!s}'.format(name)
vchi.long_name = 'meridional_gradient_of_{!s}'.format(name)
if not chi_units.is_unknown():
result_units = chi_units / Unit("m")
uchi.units = result_units
vchi.units = result_units
return uchi, vchi

def truncate(self, field, truncation=None):
Expand Down
109 changes: 109 additions & 0 deletions windspharm/tests/test_units.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Test the unit handling of the iris and xarray interfaces."""

import numpy as np
import pytest

try:
import iris.coords
import iris.cube

import windspharm.iris
except ImportError:
iris = None

try:
import xarray as xr

import windspharm.xarray
except ImportError:
xr = None


def create_data(package: str, wind_units: str):
"""Create data for use by test functions."""
data = np.zeros((7, 12), dtype="f4")
data[[1, -1], :] = 1
data[[2, -2], :] = np.sqrt(3)
data[3, :] = 2
lats = np.arange(-90, 91, 30, dtype="f4")
lons = np.arange(0, 360, 30, dtype="f4")
if package == "iris":
result = iris.cube.Cube(
data,
units=wind_units,
dim_coords_and_dims=(
(iris.coords.DimCoord(lats, "latitude", units="degrees_north"), 0),
(
iris.coords.DimCoord(
lons, "longitude", units="degrees_east", circular=True
),
1,
),
),
)
elif package == "xarray":
result = xr.DataArray(
data,
dims=("latitude", "longitude"),
coords={
"latitude": (("latitude",), lats, {"units": "degrees_north"}),
"longitude": (("longitude",), lons, {"units": "degrees_easth"}),
},
attrs={"units": wind_units},
)
return result


@pytest.mark.skipif("iris is None")
@pytest.mark.parametrize("units", ["knots", "miles per hour"])
def test_iris_convert_units(units):
"""Test that iris will convert speed units to m/s."""
cube = create_data("iris", units)
vec_wind = windspharm.iris.VectorWind(cube, cube)
assert np.any(vec_wind.u() != cube)
assert np.any(vec_wind.v() != cube)
assert np.all(np.around(vec_wind.u().data[3, :]) == 1)


@pytest.mark.skipif("xr is None")
@pytest.mark.parametrize("units", ["knots", "mph"])
def test_xr_warns_units(units):
"""Test that XArray warns for non-m/s wind units."""
data = create_data("xarray", units)
with pytest.warns(UserWarning):
vec_wind = windspharm.xarray.VectorWind(data, data)


@pytest.mark.skipif("xr is None")
@pytest.mark.parametrize("units", ["m/s", "m / s", "m s**-1", "m s^-1", "m s ** -1"])
@pytest.mark.filterwarnings("error:Winds should have units of m/s")
def test_xr_unit_recognition(units):
"""Test that XArray doesn't warn for different spellings of m/s."""
data = create_data("xarray", units)
vec_wind = windspharm.xarray.VectorWind(data, data)


@pytest.mark.parametrize(
"package",
[
pytest.param("iris", marks=pytest.mark.skipif("iris is None")),
pytest.param("xarray", marks=pytest.mark.skipif("xr is None")),
],
)
@pytest.mark.parametrize("units", ["K", "mg/kg"])
def test_gradient_units(package, units):
"""Test the units of the gradient.

They should be the units of the input per meter.
"""
scalar_data = create_data(package, units)
wind_data = create_data(package, "m/s")
vec_wind = getattr(windspharm, package).VectorWind(wind_data, wind_data)
grad_components = vec_wind.gradient(scalar_data)
for component in grad_components:
if package == "iris":
new_units = component.units
else:
new_units = component.attrs["units"]
assert new_units != units
assert new_units == "{:s} / m".format(units)
18 changes: 17 additions & 1 deletion windspharm/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from __future__ import absolute_import
import warnings

try:
import xarray as xr
Expand All @@ -41,7 +42,8 @@ def __init__(self, u, v, rsphere=6.3712e6, legfunc='stored'):
Zonal and meridional components of the vector wind
respectively. Both components should be `~xarray.DataArray`
instances. The components must have the same dimension
coordinates and contain no missing values.
coordinates and contain no missing values. The wind
components should be in units of meters per second.

**Optional argument:**

Expand Down Expand Up @@ -78,6 +80,11 @@ def __init__(self, u, v, rsphere=6.3712e6, legfunc='stored'):
if not all([(uc == vc).all() for uc, vc in zip(ucoords, vcoords)]):
msg = 'u and v must have the same dimension coordinate values'
raise ValueError(msg)
if any(
comp.attrs.get("units", "m/s").replace(" ", "").replace("**", "^") not in ("m/s", "ms^-1")
for comp in (u, v)
):
warnings.warn("Winds should have units of m/s", UserWarning)
# Find the latitude and longitude coordinates and reverse the latitude
# dimension if necessary.
lat, lat_dim = _find_latitude_coordinate(u)
Expand Down Expand Up @@ -665,6 +672,12 @@ def gradient(self, chi, truncation=None):
if not isinstance(chi, xr.DataArray):
raise TypeError('scalar field must be an xarray.DataArray')
name = chi.name
try:
chi_units = chi.attrs["units"]
except KeyError:
grad_units = None
else:
grad_units = f"{chi_units:s} / m"
lat, lat_dim = _find_latitude_coordinate(chi)
lon, lon_dim = _find_longitude_coordinate(chi)
if (lat.values[0] < lat.values[1]):
Expand All @@ -689,6 +702,9 @@ def gradient(self, chi, truncation=None):
attrs={'long_name': vchi_name})
uchi = uchi.transpose(*reorder)
vchi = vchi.transpose(*reorder)
if grad_units is not None:
uchi.attrs["units"] = grad_units
vchi.attrs["units"] = grad_units
return uchi, vchi

def truncate(self, field, truncation=None):
Expand Down
Loading