diff --git a/CHANGES.rst b/CHANGES.rst index 15bac8b0..653d4af5 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,9 +1,13 @@ What's new ========== -0.7.1 (unreleased) +0.8.0 (unreleased) ------------------ +New features +~~~~~~~~~~~ +- Expose ESMF capability to use ``pole_kind`` to specify monopolar or bipolar grid types, useful for regridding tripolar ocean grids. By `Benjamin Cash `_. + Bug fixes ~~~~~~~~~ - Fix ``Mesh.from_polygons`` to support ``shapely`` 2.0. By `Pascal Bourgault `_. diff --git a/xesmf/backend.py b/xesmf/backend.py index c497f9c0..a21cbce3 100644 --- a/xesmf/backend.py +++ b/xesmf/backend.py @@ -58,7 +58,7 @@ def warn_lat_range(lat): class Grid(ESMF.Grid): @classmethod - def from_xarray(cls, lon, lat, periodic=False, mask=None): + def from_xarray(cls, lon, lat, periodic=False, mask=None, pole_kind=None): """ Create an ESMF.Grid object, for constructing ESMF.Field and ESMF.Regrid. @@ -83,6 +83,17 @@ def from_xarray(cls, lon, lat, periodic=False, mask=None): Shape should be ``(Nlon, Nlat)`` for rectilinear grid, or ``(Nx, Ny)`` for general quadrilateral grid. + pole_kind : [int, int] or None + Two item list which specifies the type of connection which occurs at the pole. + The first value specifies the connection that occurs at the minimum end of the + pole dimension. The second value specifies the connection that occurs at the + maximum end of the pole dimension. Options are 0 (no connections at pole), + 1 (monopole, this edge is connected to itself. Given that the edge is n elements long, + then element i is connected to element i+n/2), and 2 (bipole, this edge is connected + to itself. Given that the edge is n elements long, element i is connected to element n-i-1. + If None, defaults to [1,1] for monopole connections. See :attr:`ESMF.api.constants.PoleKind`. + Requires ESMF >= 8.0.1 + Returns ------- grid : ESMF.Grid object @@ -111,13 +122,21 @@ def from_xarray(cls, lon, lat, periodic=False, mask=None): # they will be set to default values (CENTER and SPH_DEG). # However, they actually need to be set explicitly, # otherwise grid._coord_sys and grid._staggerloc will still be None. - grid = cls( - np.array(lon.shape), + kwds = dict( staggerloc=staggerloc, coord_sys=ESMF.CoordSys.SPH_DEG, num_peri_dims=num_peri_dims, + pole_kind=pole_kind, ) + # `pole_kind` option supported since 8.0.1 + if ESMF.__version__ < '8.0.1': + if pole_kind is not None: + raise ValueError('The `pole_kind` option requires esmpy >= 8.0.1') + kwds.pop('pole_kind') + + grid = cls(np.array(lon.shape), **kwds) + # The grid object points to the underlying Fortran arrays in ESMF. # To modify lat/lon coordinates, need to get pointers to them lon_pointer = grid.get_coords(coord_dim=0, staggerloc=staggerloc) diff --git a/xesmf/frontend.py b/xesmf/frontend.py index 70e601dd..2e1caac2 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -127,11 +127,16 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None): else: mask = None + if 'pole_kind' in ds: + pole_kind = np.asarray(ds['pole_kind']) + else: + pole_kind = None + # tranpose the arrays so they become Fortran-ordered if mask is not None: - grid = Grid.from_xarray(lon.T, lat.T, periodic=periodic, mask=mask.T) + grid = Grid.from_xarray(lon.T, lat.T, periodic=periodic, mask=mask.T, pole_kind=pole_kind) else: - grid = Grid.from_xarray(lon.T, lat.T, periodic=periodic, mask=None) + grid = Grid.from_xarray(lon.T, lat.T, periodic=periodic, mask=None, pole_kind=pole_kind) if need_bounds: lon_b, lat_b = _get_lon_lat_bounds(ds) diff --git a/xesmf/tests/test_frontend.py b/xesmf/tests/test_frontend.py index 4aee9507..73304e0e 100644 --- a/xesmf/tests/test_frontend.py +++ b/xesmf/tests/test_frontend.py @@ -854,3 +854,49 @@ def test_spatial_averager_mask(): savg = xe.SpatialAverager(dsm, [poly], geom_dim_name='my_geom') out = savg(dsm.abc) assert_allclose(out, 2, rtol=1e-3) + + +def test_regrid_polekind(): + + # Open tripole SST file + ds_in = xr.open_dataset('mom6_tripole_SST.nc') + + # Open input grid specification + ds_ingrid = xr.open_dataset('grid_spec.nc') + ds_sst_grid = ds_ingrid.rename({'geolat': 'lat', 'geolon': 'lon'}) + ds_sst_grid['mask'] = ds_ingrid['wet'] + + # Get MOM6 mask + ds_ingrid['mask'] = ds_ingrid['wet'] + + # Open output grid specification + ds_outgrid = xr.open_dataset('C384_gaussian_grid.nc') + + # Get C384 land-sea mask + ds_outgrid['mask'] = 1 - ds_outgrid['land'].where(ds_outgrid['land'] < 2.0).squeeze() + + # Create regridder without specifying pole kind + base_regrid = xe.Regridder(ds_sst_grid, ds_outgrid, 'bilinear', periodic=True) + base_result = base_regrid(ds_in['SST']) + + # Add monopole grid information. 1 denotes monopole, 2 bipole + ds_sst_grid['pole_kind'] = np.array([1, 1]) + ds_outgrid['pole_kind'] = np.array([1, 1]) + + monopole_regrid = xe.Regridder(ds_sst_grid, ds_outgrid, 'bilinear', periodic=True) + monopole_result = monopole_regrid(ds_in['SST']) + + # Check behavior unchanged + assert monopole_result.equals(base_result) + + # Add bipole grid information + ds_sst_grid['pole_kind'] = np.array([1, 2], np.int32) + bipole_regrid = xe.Regridder(ds_sst_grid, ds_outgrid, 'bilinear', periodic=True) + bipole_result = bipole_regrid(ds_in['SST']) + + # Confirm results have changed + assert not bipole_result.equals(monopole_result) + + # Confirm results match saved values + verif_in = xr.open_dataset('verify_bipole_regrid_SST.nc')['SST'] + assert bipole_result.equals(verif_in)