diff --git a/README.md b/README.md
index eb22f4e47b..e9a5f69082 100644
--- a/README.md
+++ b/README.md
@@ -120,15 +120,26 @@ gw_convect_dp_ml_norms='/path/to/norms'
-To run CAM using the non local gravity wave ML model to replace all parameterisations use the following configuration
+To run CAM using a non local gravity wave ML model to replace all other gravity wave parameterisations, enable exactly one of the two schemes (selecting both is rejected at startup):
 ```fortran
-use_gw_nlgw=.true.
-gw_nlgw_model_path='/path/to/nlgw-scripted-model.pt'
+! either the single-column ANN scheme ...
+use_gw_nlgw_ann=.true.
+gw_nlgw_model_path_ann='/path/to/ann-scripted-model.pt'
+! ... or the global UNet scheme (not both)
+use_gw_nlgw_unet=.true.
+gw_nlgw_model_path_unet='/path/to/unet-scripted-model.pt'
```
-* `use_gw_nlgw` (`logical`)
+* `use_gw_nlgw_ann` (`logical`)
- Whether or not to use the ML scheme for non local gravity waves. Default: `.false.`
+ Whether or not to use the ANN ML scheme for non local gravity waves. Default: `.false.`
-* `gw_nlgw_model_path`
+* `gw_nlgw_model_path_ann`
+
+  Absolute filepath to the non local gravity wave neural net used when `use_gw_nlgw_ann` is set to `.true.` (`.pt`
+ extension).
+
+* `use_gw_nlgw_unet` (`logical`)
+
+ Whether or not to use the UNET ML scheme for non local gravity waves. Default: `.false.`
+
+* `gw_nlgw_model_path_unet`
-  Absolute filepath to the non local gravity wave neural net used when `use_gw_nlgw` is set to `.true.` (`.pt`
+  Absolute filepath to the non local gravity wave neural net used when `use_gw_nlgw_unet` is set to `.true.` (`.pt`
extension).
diff --git a/bld/build-namelist b/bld/build-namelist
index 0ffb144b48..9a8e1b88dc 100755
--- a/bld/build-namelist
+++ b/bld/build-namelist
@@ -3606,7 +3606,8 @@ if (!$simple_phys) {
add_default($nl, 'use_gw_rdg_gamma' , 'val'=>'.false.');
add_default($nl, 'use_gw_front_igw' , 'val'=>'.false.');
add_default($nl, 'use_gw_convect_sh', 'val'=>'.false.');
- add_default($nl, 'use_gw_nlgw' , 'val'=>'.false.');
+ add_default($nl, 'use_gw_nlgw_ann' , 'val'=>'.false.');
+ add_default($nl, 'use_gw_nlgw_unet' , 'val'=>'.false.');
add_default($nl, 'gw_lndscl_sgh');
add_default($nl, 'gw_oro_south_fac');
add_default($nl, 'gw_limit_tau_without_eff');
diff --git a/bld/namelist_files/namelist_definition.xml b/bld/namelist_files/namelist_definition.xml
index 52e8313b21..b783f836d7 100644
--- a/bld/namelist_files/namelist_definition.xml
+++ b/bld/namelist_files/namelist_definition.xml
@@ -1332,10 +1332,17 @@ Whether or not to enable gravity waves produced by shallow convection.
Default: .false.
-
Whether or not to enable gravity waves produced by non-local gravity
-wave ML model.
+wave ANN ML model.
+Default: set by build-namelist.
+
+
+
+Whether or not to enable gravity waves produced by non-local gravity
+wave UNet ML model.
Default: set by build-namelist.
@@ -1428,10 +1435,16 @@ Absolute filepath to the deep convection gravity wave neural net used when
Default: .false.
-
+Absolute filepath to the non local gravity wave traced model (.pt)
+used when `use_gw_nlgw_ann` is set to `.true.`.
+
+
+
Absolute filepath to the non local gravity wave traced model (.pt)
-used when `use_gw_nlgw` is set to `.true.`.
+used when `use_gw_nlgw_unet` is set to `.true.`.
diff --git a/src/physics/cam/gw_drag.F90 b/src/physics/cam/gw_drag.F90
--- a/src/physics/cam/gw_drag.F90
+++ b/src/physics/cam/gw_drag.F90
                                 masterprocid, mpi_real8, &
mpi_character, mpi_logical, mpi_integer
use gw_rdg, only: gw_rdg_readnl
+ use gw_nlgw_utils, only: gw_nlgw_model_path_ann, gw_nlgw_model_path_unet
! File containing namelist input.
character(len=*), intent(in) :: nlfile
@@ -248,7 +249,7 @@ subroutine gw_drag_readnl(nlfile)
gw_top_taper, front_gaussian_width, &
gw_convect_dp_ml, gw_convect_dp_ml_compare, &
gw_convect_dp_ml_net_path, gw_convect_dp_ml_norms, &
- gw_nlgw_model_path
+ gw_nlgw_model_path_ann, gw_nlgw_model_path_unet
!----------------------------------------------------------------------
if (use_simple_phys) return
@@ -364,8 +365,11 @@ subroutine gw_drag_readnl(nlfile)
call mpi_bcast(gw_convect_dp_ml_norms, len(gw_convect_dp_ml_norms), mpi_character, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: gw_convect_dp_ml_norms")
- call mpi_bcast(gw_nlgw_model_path, len(gw_nlgw_model_path), mpi_character, mstrid, mpicom, ierr)
- if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: gw_nlgw_model_path")
+ call mpi_bcast(gw_nlgw_model_path_ann, len(gw_nlgw_model_path_ann), mpi_character, mstrid, mpicom, ierr)
+ if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: gw_nlgw_model_path_ann")
+
+ call mpi_bcast(gw_nlgw_model_path_unet, len(gw_nlgw_model_path_unet), mpi_character, mstrid, mpicom, ierr)
+ if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: gw_nlgw_model_path_unet")
! Check if fcrit2 was set.
call shr_assert(fcrit2 /= unset_r8, &
@@ -425,6 +429,7 @@ subroutine gw_init()
use gw_common, only: gw_common_init
use gw_front, only: gaussian_cm_desc
+ use gw_nlgw_utils, only: gw_nlgw_model_path_ann
!---------------------------Local storage-------------------------------
@@ -577,8 +582,8 @@ subroutine gw_init()
call shr_assert(trim(errstring) == "", "gw_common_init: "//errstring// &
errMsg(__FILE__, __LINE__))
- if ( use_gw_nlgw ) then
- call gw_nlgw_dp_init(gw_nlgw_model_path)
+ if ( use_gw_nlgw_ann ) then
+ call gw_nlgw_ann_init(gw_nlgw_model_path_ann)
end if
if ( use_gw_oro ) then
@@ -1298,8 +1303,8 @@ subroutine gw_final()
if ((gw_convect_dp_ml) .or. (gw_convect_dp_ml_compare)) then
call gw_drag_convect_dp_ml_final()
endif
- if ( use_gw_nlgw ) then
- call gw_nlgw_dp_finalize()
+ if ( use_gw_nlgw_ann ) then
+ call gw_nlgw_ann_finalize()
end if
end subroutine gw_final
@@ -1548,8 +1553,8 @@ subroutine gw_tend(state, pbuf, dt, ptend, cam_in, flx_heat)
egwdffi_tot = 0._r8
flx_heat = 0._r8
- if ( use_gw_nlgw ) then
- call gw_nlgw_dp_ml(state1,ptend)
+ if ( use_gw_nlgw_ann ) then
+ call gw_nlgw_ann_infer(state1,ptend)
end if
if (use_gw_convect_dp) then
diff --git a/src/physics/cam/gw_nlgw.F90 b/src/physics/cam/gw_nlgw_ann.F90
similarity index 91%
rename from src/physics/cam/gw_nlgw.F90
rename to src/physics/cam/gw_nlgw_ann.F90
index 647b771e68..fc61d9bb60 100644
--- a/src/physics/cam/gw_nlgw.F90
+++ b/src/physics/cam/gw_nlgw_ann.F90
@@ -1,4 +1,4 @@
-module gw_nlgw
+module gw_nlgw_ann
!
! This module predicts gravity wave forcings via PyTorch NNs trained to include non-local gravity wave effects
@@ -11,18 +11,17 @@ module gw_nlgw
use cam_abortutils, only: endrun
use cam_logfile, only: iulog
use physconst, only: cappa, pi
+use gw_nlgw_utils, only: p0
use interpolate_data, only: lininterp
use ftorch
implicit none
-public :: gw_nlgw_dp_ml, gw_nlgw_dp_init, gw_nlgw_dp_finalize
+public :: gw_nlgw_ann_infer, gw_nlgw_ann_init, gw_nlgw_ann_finalize
private
-integer, parameter :: p0 = 100000 ! 1000 hPa (Pa)
-
type(torch_model) :: nlgw_model ! pytorch model
integer :: ncol ! number of vertical columns
@@ -104,7 +103,9 @@ module gw_nlgw
!==========================================================================
-subroutine gw_nlgw_dp_ml(state_in, ptend)
+subroutine gw_nlgw_ann_infer(state_in, ptend)
+
+ use gw_nlgw_utils, only: flux_to_forcing
! inputs
type(physics_state), intent(in) :: state_in
@@ -170,8 +171,8 @@ subroutine gw_nlgw_dp_ml(state_in, ptend)
call extract_output()
call denormalise_data()
- call flux_to_forcing(uflux, utgw)
- call flux_to_forcing(vflux, vtgw)
+ call flux_to_forcing(uflux, utgw, pmid, ncol)
+ call flux_to_forcing(vflux, vtgw, pmid, ncol)
! update the tendencies
ptend%u(:ncol,:pver) = ptend%u(:ncol,:pver) + utgw(:ncol,:pver)
@@ -200,10 +201,10 @@ subroutine gw_nlgw_dp_ml(state_in, ptend)
deallocate(net_inputs)
deallocate(net_outputs)
-end subroutine gw_nlgw_dp_ml
+end subroutine gw_nlgw_ann_infer
-subroutine gw_nlgw_dp_init(model_path)
+subroutine gw_nlgw_ann_init(model_path)
character(len=*), intent(in) :: model_path ! Filepath to PyTorch Torchscript net
integer :: device_id
@@ -219,17 +220,17 @@ subroutine gw_nlgw_dp_init(model_path)
write(iulog,*)'nlgw model loaded from: ', model_path
endif
-end subroutine gw_nlgw_dp_init
+end subroutine gw_nlgw_ann_init
-subroutine gw_nlgw_dp_finalize()
+subroutine gw_nlgw_ann_finalize()
deallocate(net_inputs)
deallocate(net_outputs)
! free model memory
call torch_delete(nlgw_model)
-end subroutine gw_nlgw_dp_finalize
+end subroutine gw_nlgw_ann_finalize
subroutine read_norms()
@@ -259,6 +260,7 @@ subroutine read_norms()
end subroutine read_norms
subroutine normalise_data()
+ use gw_nlgw_utils, only: cbrt
! lat lon are in radians (convert to degrees first)
lat = lat * 180. / pi
@@ -363,31 +365,4 @@ subroutine denormalise_data()
end subroutine denormalise_data
-elemental function cbrt(a) result(root)
- real(r8), intent(in) :: a
- real(r8), parameter :: one_third = 1._r8/3._r8
- real(r8) :: root
- root = sign(abs(a)**one_third, a)
-end function cbrt
-
-subroutine flux_to_forcing(flux, forcing)
-
- real(r8), intent(in), dimension(:,:) :: flux
- real(r8), intent(out), dimension(:,:) :: forcing ! forcing = -d(u'\omega')/d(p), units = m/s^2
-
- integer :: level, col
-
- ! convert fluxes to tendencies
- ! pressure profile must be in Pascals
-
- do col = 1, ncol
- forcing(col,1) = -1*(flux(col,2) - flux(col,1))/(pmid(col,2) - pmid(col,1))
- do level = 2, pver-1
- forcing(col,level) = (flux(col,level+1) - flux(col,level-1)) / (pmid(col,level)*(log(pmid(col,level+1)) - log(pmid(col,level-1))))
- end do
- forcing(col,pver) = -1*(flux(col,pver) - flux(col,pver-1)) / (pmid(col,pver) - pmid(col,pver-1))
- end do
-
-end subroutine flux_to_forcing
-
-end module gw_nlgw
+end module gw_nlgw_ann
diff --git a/src/physics/cam/gw_nlgw_unet.F90 b/src/physics/cam/gw_nlgw_unet.F90
new file mode 100644
index 0000000000..c5be68ab93
--- /dev/null
+++ b/src/physics/cam/gw_nlgw_unet.F90
@@ -0,0 +1,219 @@
+module gw_nlgw_unet
+
+!
+! This module predicts gravity wave forcings via PyTorch NNs trained to include non-local gravity wave effects
+!
+
+use gw_utils, only: r8, r4
+use ppgrid, only: pver !vertical levels
+use physics_types, only: physics_state, physics_ptend
+use spmd_utils, only: mpicom, mstrid=>masterprocid, masterproc, mpi_real8, iam
+use cam_abortutils, only: endrun
+use cam_logfile, only: iulog
+use cam_history, only: outfld, addfld
+use physconst, only: cappa
+use gw_nlgw_utils, only: lonlat_vars, nlon, nlat
+
+use ftorch
+
+implicit none
+
+public :: gw_nlgw_unet_init, gw_nlgw_unet_infer, gw_nlgw_unet_finalize, gw_nlgw_unet_set_ptend
+
+private
+
+type(torch_model) :: nlgw_model ! pytorch model
+
+real(r4), dimension(:,:,:,:), allocatable, target :: net_inputs
+real(r4), dimension(:,:,:,:), allocatable, target :: net_outputs
+
+! normalisation means and std devs
+real(r8) :: u_mean, v_mean, w_mean, theta_mean
+real(r8) :: u_std, v_std, w_std, theta_std
+
+real(r8) :: uflux_mean, vflux_mean
+real(r8) :: uflux_std, vflux_std
+
+contains
+
+!==========================================================================
+
+
+subroutine gw_nlgw_unet_init(model_path)
+
+ character(len=*), intent(in) :: model_path ! Filepath to PyTorch Torchscript net
+ integer :: device_id
+
+ device_id = 0
+
+  ! Load the non-local gravity wave UNet from the TorchScript file (onto CUDA device 0)
+ call torch_model_load(nlgw_model, model_path, device_type=torch_kCUDA, device_index=device_id)
+ ! read in normalisation weights
+ call read_norms()
+
+ if (masterproc) then
+ write(iulog,*)'nlgw model loaded from: ', model_path
+
+ ! UNet will only run on the master process
+ ! space for u, v, theta and w
+ allocate(net_inputs(1, pver*4, nlat, nlon))
+ ! space for uflux and vflux
+ allocate(net_outputs(1, pver*2, nlat, nlon))
+ endif
+
+
+end subroutine gw_nlgw_unet_init
+
+subroutine gw_nlgw_unet_infer(gathered_lonlat)
+
+ ! global unet data
+ type(lonlat_vars), intent(inout) :: gathered_lonlat
+
+ !---------------------------Local storage-------------------------------
+ type(torch_tensor) :: tensor_in(1), tensor_out(1)
+  integer, dimension(4) :: layout = [1, 2, 3, 4]
+
+ integer :: device_id
+
+ device_id = 0
+
+ ! Normalise and construct the input
+ call normalise_data(gathered_lonlat)
+ call construct_input(gathered_lonlat)
+
+  ! wrap the global input array in a GPU tensor; the output tensor lives on the CPU
+ call torch_tensor_from_array(tensor_in(1), net_inputs, layout, torch_kCUDA, device_id)
+ call torch_tensor_from_array(tensor_out(1), net_outputs, layout, torch_kCPU)
+
+ ! Run net forward on data
+ call torch_model_forward(nlgw_model, tensor_in, tensor_out)
+
+ ! Extract and denormalise outputs
+ call extract_output(gathered_lonlat)
+ call denormalise_data(gathered_lonlat)
+
+ ! Clean up the tensors
+ call torch_delete(tensor_in)
+ call torch_delete(tensor_out)
+
+end subroutine gw_nlgw_unet_infer
+
+subroutine gw_nlgw_unet_finalize()
+
+ if (masterproc) then
+ deallocate(net_inputs)
+ deallocate(net_outputs)
+ end if
+ ! free model memory
+ call torch_delete(nlgw_model)
+
+end subroutine gw_nlgw_unet_finalize
+
+subroutine gw_nlgw_unet_set_ptend(phys_state, ptend, utgw, vtgw)
+  ! initialize the ptend object and accumulate the UNet wind tendencies
+  use physics_types, only: physics_ptend_init
+  use constituents,  only: cnst_get_ind, pcnst
+  use ppgrid,        only: pcols
+
+  type(physics_state), intent(in)  :: phys_state
+  type(physics_ptend), intent(out) :: ptend
+  real(r8), dimension(pcols,pver), intent(in) :: utgw, vtgw
+
+  ! local vars
+  integer :: indw, ncol, lchnk
+  logical :: lq(pcnst)
+
+  call cnst_get_ind('Q', indw)
+  lq(:)    = .false.
+  lq(indw) = .true.
+  call physics_ptend_init(ptend, phys_state%psetcols, 'nlgw_unet', lu=.true., lv=.true., ls=.true., lq=lq)
+
+  lchnk = phys_state%lchnk
+  ncol  = phys_state%ncol
+  ptend%u(:ncol,:pver) = ptend%u(:ncol,:pver) + utgw(:ncol,:pver)
+  ptend%v(:ncol,:pver) = ptend%v(:ncol,:pver) + vtgw(:ncol,:pver)
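+  ! s and q tendencies are allocated and zeroed by physics_ptend_init and are
+  ! not modified here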
+
+ call outfld('UTGW_NL', utgw(:ncol,:pver), ncol, lchnk)
+ call outfld('VTGW_NL', vtgw(:ncol,:pver), ncol, lchnk)
+end subroutine gw_nlgw_unet_set_ptend
+
+subroutine read_norms()
+
+ ! TODO
+ ! - replace hardcoded means/std devs with netcdf file?
+
+ u_mean = 6.717847278462159_r8
+ v_mean = -0.002744777264668839_r8
+ theta_mean = 0._r8
+ w_mean = 0.0013401482063147452_r8
+
+ u_std = 20.760385183200206_r8
+ v_std = 9.877389116738264_r8
+ theta_std = 1000._r8
+ w_std = 0.11202126259282257_r8
+
+ uflux_mean = -0.0004691528666736032_r8
+ vflux_mean = -0.0002586195082961397_r8
+ uflux_std = 0.032814051953840274_r8
+ vflux_std = 0.03024781201672967_r8
+
+end subroutine read_norms
+
+subroutine normalise_data(gathered_lonlat)
+ use gw_nlgw_utils, only: cbrt
+ type(lonlat_vars), intent(inout) :: gathered_lonlat
+
+ gathered_lonlat%u = (gathered_lonlat%u-u_mean)/(3._r8 * u_std)
+ gathered_lonlat%v = (gathered_lonlat%v-v_mean)/(3._r8 * v_std)
+ gathered_lonlat%theta = (gathered_lonlat%theta-theta_mean)/theta_std
+ gathered_lonlat%w = (gathered_lonlat%w-w_mean)/w_std
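+  ! w is additionally cube-root compressed after standardisation, matching the
+  ! transform presumably applied to w during training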
+ gathered_lonlat%w = cbrt(gathered_lonlat%w)
+
+end subroutine normalise_data
+
+subroutine construct_input(gathered_lonlat)
+
+ type(lonlat_vars), intent(inout) :: gathered_lonlat
+ integer :: idx_beg, idx_end, i
+
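+  ! stack u, v, theta and w along the channel dimension: the UNet input is
+  ! shaped (batch=1, channels=4*pver, nlat, nlon)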
+ idx_end = 0
+
+ idx_beg = idx_end + 1
+ idx_end = idx_end + pver
+ net_inputs(1,idx_beg:idx_end,:,:) = reshape(gathered_lonlat%u, shape=[pver, nlat, nlon], order=[3,2,1])
+ idx_beg = idx_end + 1
+ idx_end = idx_end + pver
+ net_inputs(1,idx_beg:idx_end,:,:) = reshape(gathered_lonlat%v, shape=[pver, nlat, nlon], order=[3,2,1])
+ idx_beg = idx_end + 1
+ idx_end = idx_end + pver
+ net_inputs(1,idx_beg:idx_end,:,:) = reshape(gathered_lonlat%theta, shape=[pver, nlat, nlon], order=[3,2,1])
+ idx_beg = idx_end + 1
+ idx_end = idx_end + pver
+ net_inputs(1,idx_beg:idx_end,:,:) = reshape(gathered_lonlat%w, shape=[pver, nlat, nlon], order=[3,2,1])
+
+end subroutine construct_input
+
+subroutine extract_output(gathered_lonlat)
+
+ type(lonlat_vars), intent(inout) :: gathered_lonlat
+
+ gathered_lonlat%uflux(:,:,:) = reshape(net_outputs(1,:pver,:,:), shape=[nlon, nlat, pver], order=[3,2,1])
+ gathered_lonlat%vflux(:,:,:) = reshape(net_outputs(1,pver+1:,:,:), shape=[nlon, nlat, pver], order=[3,2,1])
+
+end subroutine extract_output
+
+subroutine denormalise_data(gathered_lonlat)
+
+ type(lonlat_vars), intent(inout) :: gathered_lonlat
+
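+  ! invert the output transform: cube the (cube-root-scaled) network outputs,
+  ! then undo the flux standardisation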
+ gathered_lonlat%uflux = gathered_lonlat%uflux**3._r8 * uflux_std + uflux_mean
+ gathered_lonlat%vflux = gathered_lonlat%vflux**3._r8 * vflux_std + vflux_mean
+
+end subroutine denormalise_data
+
+end module gw_nlgw_unet
diff --git a/src/physics/cam/gw_nlgw_utils.F90 b/src/physics/cam/gw_nlgw_utils.F90
new file mode 100644
index 0000000000..b16b5d0fbf
--- /dev/null
+++ b/src/physics/cam/gw_nlgw_utils.F90
@@ -0,0 +1,88 @@
+module gw_nlgw_utils
+
+use gw_utils, only: r8, r4
+use ppgrid, only: begchunk, endchunk, pcols, pver, pverp
+
+implicit none
+
+public :: cbrt, flux_to_forcing
+public :: phys_vars, lonlat_vars
+integer, parameter, public :: p0 = 100000 ! reference surface pressure: 100000 Pa (1000 hPa)
+integer, parameter, public :: nlon = 288 ! number of longitude points on lonlat grid
+integer, parameter, public :: nlat = 192 ! number of latitude points on lonlat grid
+character(len=132), public :: gw_nlgw_model_path_ann
+character(len=132), public :: gw_nlgw_model_path_unet
+
+private
+
+! variables on cubed-sphere "phys" grid
+type phys_vars
+!dimension(pver,pcols,begchunk:endchunk)
+real(r8), dimension(:,:,:), allocatable :: &
+ u, &! zonal wind (m/s)
+ v, &! meridional wind (m/s)
+    theta, &! potential temperature (K)
+ w, &! vertical pressure velocity (Pa/s)
+ pmid ! midpoint pressure (Pa)
+
+real(r8), dimension(:,:,:), allocatable :: &
+ uflux, &! zonal fluxes
+ vflux ! meridional fluxes
+
+real(r8), dimension(:,:,:), allocatable :: &
+ utgw, &! zonal tendencies
+ vtgw ! meridional tendencies
+
+! for debugging only
+! dimension(pcols,begchunk:endchunk)
+real(r8), dimension(:,:), allocatable :: &
+ lat, &
+ lon
+end type
+
+! variables on regular lonlat grid
+type lonlat_vars
+! dimension(lon,lat,pver)
+real(r8), dimension(:,:,:), allocatable :: &
+ u, &! zonal wind (m/s)
+ v, &! meridional wind (m/s)
+    theta, &! potential temperature (K)
+ w ! vertical pressure velocity (Pa/s)
+
+real(r8), dimension(:,:,:), allocatable :: &
+ uflux, &! zonal fluxes
+ vflux ! meridional fluxes
+end type
+
+contains
+
+elemental function cbrt(a) result(root)
+ real(r8), intent(in) :: a
+ real(r8), parameter :: one_third = 1._r8/3._r8
+ real(r8) :: root
+ root = sign(abs(a)**one_third, a)
+end function cbrt
+
+subroutine flux_to_forcing(flux, forcing, pmid, ncol)
+
+ real(r8), intent(in), dimension(:,:) :: flux ! flux (Pa m/s^2) !TODO check with Aman
+ real(r8), intent(in), dimension(:,:) :: pmid ! midpoint pressure (Pa)
+ real(r8), intent(out), dimension(:,:) :: forcing ! forcing = -d(u'\omega')/d(p), units = m/s^2
+ integer, intent(in) :: ncol
+
+ integer :: level, col
+
+ ! convert fluxes to tendencies
+ ! pressure profile must be in Pascals
+
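+  ! the top and bottom levels use one-sided differences of -dF/dp; interior
+  ! levels use a centred difference in log-pressure, dF / (p * dln(p))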
+ do col = 1, ncol
+ forcing(col,1) = -1*(flux(col,2) - flux(col,1))/(pmid(col,2) - pmid(col,1))
+ do level = 2, pver-1
+ forcing(col,level) = (flux(col,level+1) - flux(col,level-1)) / (pmid(col,level)*(log(pmid(col,level+1)) - log(pmid(col,level-1))))
+ end do
+ forcing(col,pver) = -1*(flux(col,pver) - flux(col,pver-1)) / (pmid(col,pver) - pmid(col,pver-1))
+ end do
+
+end subroutine flux_to_forcing
+
+end module gw_nlgw_utils
diff --git a/src/physics/cam/phys_control.F90 b/src/physics/cam/phys_control.F90
index 7497385583..93263e29a3 100644
--- a/src/physics/cam/phys_control.F90
+++ b/src/physics/cam/phys_control.F90
@@ -98,7 +98,8 @@ module phys_control
logical, public, protected :: use_gw_front_igw = .false. ! Frontogenesis to inertial spectrum.
logical, public, protected :: use_gw_convect_dp = .false. ! Deep convection.
logical, public, protected :: use_gw_convect_sh = .false. ! Shallow convection.
-logical, public, protected :: use_gw_nlgw = .false. ! non local GW ML model
+logical, public, protected :: use_gw_nlgw_ann = .false. ! non local GW ML model (ANN - single column)
+logical, public, protected :: use_gw_nlgw_unet = .false. ! non local GW ML model (UNet - global non local)
! FV dycore angular momentum correction
logical, public, protected :: fv_am_correction = .false.
@@ -137,7 +138,7 @@ subroutine phys_ctl_readnl(nlfile)
history_waccmx, history_chemistry, history_carma, history_clubb, history_dust, &
history_cesm_forcing, history_scwaccm_forcing, history_chemspecies_srf, &
do_clubb_sgs, state_debug_checks, use_hetfrz_classnuc, use_gw_oro, use_gw_front, &
- use_gw_front_igw, use_gw_convect_dp, use_gw_convect_sh, use_gw_nlgw, cld_macmic_num_steps, &
+ use_gw_front_igw, use_gw_convect_dp, use_gw_convect_sh, use_gw_nlgw_ann, use_gw_nlgw_unet, cld_macmic_num_steps, &
offline_driver, convproc_do_aer, cam_snapshot_before_num, cam_snapshot_after_num, &
cam_take_snapshot_before, cam_take_snapshot_after, cam_physics_mesh, use_hemco, do_hb_above_clubb
!-----------------------------------------------------------------------------
@@ -157,7 +158,10 @@ subroutine phys_ctl_readnl(nlfile)
end if
! if we are using the nlgw ML model we need to disable all other parameterizations
- if (masterproc .and. use_gw_nlgw==.true.) then
+ if (masterproc .and. (use_gw_nlgw_ann .or. use_gw_nlgw_unet)) then
+ if (use_gw_nlgw_ann .and. use_gw_nlgw_unet) then
+      call endrun(subname // ':: ERROR: use_gw_nlgw_ann and use_gw_nlgw_unet are mutually exclusive; select only one.')
+ end if
use_gw_oro = .false.
use_gw_front = .false.
use_gw_front_igw = .false.
@@ -203,7 +207,8 @@ subroutine phys_ctl_readnl(nlfile)
call mpi_bcast(use_gw_front_igw, 1, mpi_logical, masterprocid, mpicom, ierr)
call mpi_bcast(use_gw_convect_dp, 1, mpi_logical, masterprocid, mpicom, ierr)
call mpi_bcast(use_gw_convect_sh, 1, mpi_logical, masterprocid, mpicom, ierr)
- call mpi_bcast(use_gw_nlgw, 1, mpi_logical, masterprocid, mpicom, ierr)
+ call mpi_bcast(use_gw_nlgw_ann, 1, mpi_logical, masterprocid, mpicom, ierr)
+ call mpi_bcast(use_gw_nlgw_unet, 1, mpi_logical, masterprocid, mpicom, ierr)
call mpi_bcast(cld_macmic_num_steps, 1, mpi_integer, masterprocid, mpicom, ierr)
call mpi_bcast(offline_driver, 1, mpi_logical, masterprocid, mpicom, ierr)
call mpi_bcast(convproc_do_aer, 1, mpi_logical, masterprocid, mpicom, ierr)
diff --git a/src/physics/cam_dev/physpkg.F90 b/src/physics/cam_dev/physpkg.F90
index 1c461c9a1c..c5f0e15353 100644
--- a/src/physics/cam_dev/physpkg.F90
+++ b/src/physics/cam_dev/physpkg.F90
@@ -768,6 +768,7 @@ subroutine phys_init( phys_state, phys_tend, pbuf2d, cam_in, cam_out )
use cam_history, only: addfld, register_vector_field, add_default
use cam_budget, only: cam_budget_init
use phys_grid_ctem, only: phys_grid_ctem_init
+ use phys_control, only: use_gw_nlgw_unet
use ccpp_constituent_prop_mod, only: ccpp_const_props_init
@@ -949,6 +950,11 @@ subroutine phys_init( phys_state, phys_tend, pbuf2d, cam_in, cam_out )
end if
+
+ if (use_gw_nlgw_unet) then
+     call addfld('UTGW_NL', (/ 'lev' /), 'A', 'm/s2', 'Non-local GW zonal wind tendency')
+     call addfld('VTGW_NL', (/ 'lev' /), 'A', 'm/s2', 'Non-local GW meridional wind tendency')
+ end if
! Initialize CAM CCPP constituent properties array
! for use in CCPP-ized physics schemes:
call ccpp_const_props_init()
@@ -1183,6 +1189,12 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, &
use metdata, only: get_met_srf2
#endif
use hemco_interface, only: HCOI_Chunk_Run
+ use nlgw_remap_mod, only: nlgw_regrid_init, nlgw_latlon_gather, nlgw_latlon_scatter, nlgw_regrid_final
+ use gw_nlgw_unet, only: gw_nlgw_unet_init, gw_nlgw_unet_infer, gw_nlgw_unet_finalize, gw_nlgw_unet_set_ptend
+ use gw_nlgw_utils, only: phys_vars, lonlat_vars, flux_to_forcing, gw_nlgw_model_path_unet
+ use phys_control, only: use_gw_nlgw_unet
+ use time_manager, only: get_nstep
+ use check_energy, only: check_energy_chng
!
! Input arguments
!
@@ -1202,10 +1214,19 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, &
!
integer :: c ! chunk index
integer :: ncol ! number of columns
+ integer :: nstep ! current timestep number
type(physics_buffer_desc),pointer, dimension(:) :: phys_buffer_chunk
+
+ ! for ML UNet model
+ type(physics_ptend) :: ptend ! parameterization tendencies for nlgw
+ real(r8) :: zero(pcols) ! array of zeros
+ type(phys_vars), target :: phys
+ type(lonlat_vars), target :: lonlat, gathered_lonlat
+ real(r8), dimension(pcols,pver) :: temp_uflux, temp_vflux, temp_utgw, temp_vtgw
!
! If exit condition just return
!
+ nstep = get_nstep()
if(single_column.and.scm_crm_mode) then
call diag_deallocate()
@@ -1244,6 +1265,40 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, &
call t_startf ('ac_physics')
call t_adj_detailf(+1)
+   if (use_gw_nlgw_unet) then
+     call nlgw_regrid_init(phys, lonlat, gathered_lonlat)
+
+     ! gather data from all procs, all chunks into a global lonlat grid
+     call nlgw_latlon_gather(phys_state, phys, lonlat, gathered_lonlat)
+
+ if (masterproc) then
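+      ! the UNet runs on the master task only; the model is loaded from the
+      ! TorchScript file and freed again on every physics step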
+ call gw_nlgw_unet_init(gw_nlgw_model_path_unet)
+ ! run UNet model on globally gathered lonlat grid to compute fluxes
+ call gw_nlgw_unet_infer(gathered_lonlat)
+ call gw_nlgw_unet_finalize()
+ endif
+
+ ! scatter back to all procs, into chunks and regrid back to phys grid
+ call nlgw_latlon_scatter(phys, lonlat, gathered_lonlat)
+
+    zero(:) = 0._r8
+
+    do c = begchunk,endchunk
+ ncol = phys_state(c)%ncol
+ temp_uflux(:ncol,:pver) = transpose(phys%uflux(:pver,:ncol,c))
+ temp_vflux(:ncol,:pver) = transpose(phys%vflux(:pver,:ncol,c))
+
+ ! compute tendencies from fluxes
+ call flux_to_forcing(temp_uflux, temp_utgw, phys_state(c)%pmid, ncol)
+ call flux_to_forcing(temp_vflux, temp_vtgw, phys_state(c)%pmid, ncol)
+ ! update ptend
+ call gw_nlgw_unet_set_ptend(phys_state(c), ptend, temp_utgw, temp_vtgw)
+ ! update state and check energy change
+ call physics_update(phys_state(c), ptend, ztodt, phys_tend(c))
+ call check_energy_chng(phys_state(c), phys_tend(c), "nlgw_unet", nstep, ztodt, zero, zero, zero, zero)
+ end do
+
+ call nlgw_regrid_final(phys, lonlat, gathered_lonlat)
+ endif
+
!$OMP PARALLEL DO PRIVATE (C, NCOL, phys_buffer_chunk)
do c=begchunk,endchunk
diff --git a/src/utils/cam_grid_support.F90 b/src/utils/cam_grid_support.F90
index d86c829e77..56d5c100f7 100644
--- a/src/utils/cam_grid_support.F90
+++ b/src/utils/cam_grid_support.F90
@@ -316,6 +316,7 @@ end subroutine print_attr_spec
public :: cam_grid_compute_patch
! Functions for dealing with grid areas
public :: cam_grid_get_areawt
+ public :: get_cam_grid_index
interface cam_grid_attribute_register
module procedure add_cam_grid_attribute_0d_int
@@ -352,7 +353,6 @@ end subroutine print_attr_spec
module procedure cam_grid_write_dist_array_3d_real
end interface
- ! Private interfaces
interface get_cam_grid_index
module procedure get_cam_grid_index_char ! For lookup by name
module procedure get_cam_grid_index_int ! For lookup by ID
diff --git a/src/utils/esmf_check_error_mod.F90 b/src/utils/esmf_check_error_mod.F90
new file mode 100644
index 0000000000..eb9db18a90
--- /dev/null
+++ b/src/utils/esmf_check_error_mod.F90
@@ -0,0 +1,33 @@
+!------------------------------------------------------------------------------
+! ESMF error handler
+!------------------------------------------------------------------------------
+module esmf_check_error_mod
+ use shr_kind_mod, only: cl=>SHR_KIND_CL
+ use spmd_utils, only: masterproc
+ use cam_logfile, only: iulog
+ use cam_abortutils, only: endrun
+ use ESMF, only: ESMF_SUCCESS
+
+ implicit none
+
+ private
+ public :: check_esmf_error
+
+contains
+
+ subroutine check_esmf_error( rc, errmsg )
+
+ integer, intent(in) :: rc
+ character(len=*), intent(in) :: errmsg
+
+ character(len=cl) :: errstr
+
+ if (rc /= ESMF_SUCCESS) then
+      write(errstr,'(a,i6)') trim(errmsg)//' -- ESMF ERROR code: ', rc
+ if (masterproc) write(iulog,*) trim(errstr)
+ call endrun(trim(errstr))
+ end if
+
+ end subroutine check_esmf_error
+
+end module esmf_check_error_mod
diff --git a/src/utils/esmf_lonlat2phys_mod.F90 b/src/utils/esmf_lonlat2phys_mod.F90
new file mode 100644
index 0000000000..e66ed27635
--- /dev/null
+++ b/src/utils/esmf_lonlat2phys_mod.F90
@@ -0,0 +1,160 @@
+!------------------------------------------------------------------------------
+! Provides methods for mapping from a regular longitude / latitude grid
+! to the physics grid via ESMF regridding capabilities
+!------------------------------------------------------------------------------
+module esmf_lonlat2phys_mod
+ use shr_kind_mod, only: r8 => shr_kind_r8
+ use cam_logfile, only: iulog
+ use cam_abortutils, only: endrun
+ use spmd_utils, only: masterproc
+ use ppgrid, only: pver
+
+ use ESMF, only: ESMF_RouteHandle, ESMF_Field, ESMF_ArraySpec, ESMF_ArraySpecSet
+ use ESMF, only: ESMF_FieldCreate, ESMF_FieldRegridStore
+ use ESMF, only: ESMF_FieldGet, ESMF_FieldRegrid
+ use ESMF, only: ESMF_KIND_I4, ESMF_KIND_R8, ESMF_TYPEKIND_R8
+ use ESMF, only: ESMF_REGRIDMETHOD_BILINEAR, ESMF_POLEMETHOD_ALLAVG, ESMF_EXTRAPMETHOD_NEAREST_IDAVG
+ use ESMF, only: ESMF_TERMORDER_SRCSEQ, ESMF_MESHLOC_ELEMENT, ESMF_STAGGERLOC_CENTER
+ use ESMF, only: ESMF_FieldDestroy, ESMF_RouteHandleDestroy
+ use esmf_check_error_mod, only: check_esmf_error
+
+ implicit none
+
+ private
+
+ public :: esmf_lonlat2phys_init
+ public :: esmf_lonlat2phys_regrid
+ public :: esmf_lonlat2phys_destroy
+ public :: fields_bundle_t
+ public :: n_flx_flds
+
+ type(ESMF_RouteHandle) :: rh_lonlat2phys_3d
+
+ type(ESMF_Field) :: physfld_3d
+ type(ESMF_Field) :: lonlatfld_3d
+
+ interface esmf_lonlat2phys_regrid
+ module procedure esmf_lonlat2phys_regrid_3d
+ end interface esmf_lonlat2phys_regrid
+
+ type :: fields_bundle_t
+ real(r8), pointer :: fld(:,:,:) => null()
+ end type fields_bundle_t
+
+ integer, parameter :: n_flx_flds = 2 ! 3D uflux and vflux
+
+contains
+
+ !------------------------------------------------------------------------------
+ !------------------------------------------------------------------------------
+ subroutine esmf_lonlat2phys_init()
+ use esmf_phys_mesh_mod, only: physics_grid_mesh
+ use esmf_lonlat_grid_mod, only: lonlat_grid
+
+ type(ESMF_ArraySpec) :: arrayspec
+ integer(ESMF_KIND_I4), pointer :: factorIndexList(:,:)
+ real(ESMF_KIND_R8), pointer :: factorList(:)
+ integer :: smm_srctermproc, smm_pipelinedep, rc
+
+ character(len=*), parameter :: subname = 'esmf_lonlat2phys_init: '
+
+ smm_srctermproc = 0
+ smm_pipelinedep = 16
+
+ ! create ESMF fields
+
+ ! 3D phys fld
+ call ESMF_ArraySpecSet(arrayspec, 3, ESMF_TYPEKIND_R8, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_ArraySpecSet 3D phys fld ERROR')
+
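+    ! gridToFieldMap=(/3/) places the mesh (column) dimension third, so the
+    ! phys field array is addressed as (lev, fld, col)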
+ physfld_3d = ESMF_FieldCreate(physics_grid_mesh, arrayspec, &
+ gridToFieldMap=(/3/), meshloc=ESMF_MESHLOC_ELEMENT, &
+ ungriddedLBound=(/1,1/), ungriddedUBound=(/pver,n_flx_flds/), rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldCreate 3D phys fld ERROR')
+
+ ! 3D lon lat grid
+ call ESMF_ArraySpecSet(arrayspec, 4, ESMF_TYPEKIND_R8, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_ArraySpecSet 3D lonlat fld ERROR')
+
+ lonlatfld_3d = ESMF_FieldCreate( lonlat_grid, arrayspec, staggerloc=ESMF_STAGGERLOC_CENTER, &
+ ungriddedLBound=(/1,1/), ungriddedUBound=(/pver,n_flx_flds/), rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldCreate 3D lonlat fld ERROR')
+
+ call ESMF_FieldRegridStore(srcField=lonlatfld_3d, dstField=physfld_3d, &
+ regridMethod=ESMF_REGRIDMETHOD_BILINEAR, &
+ polemethod=ESMF_POLEMETHOD_ALLAVG, &
+ extrapMethod=ESMF_EXTRAPMETHOD_NEAREST_IDAVG, &
+ routeHandle=rh_lonlat2phys_3d, factorIndexList=factorIndexList, &
+ factorList=factorList, srcTermProcessing=smm_srctermproc, &
+ pipelineDepth=smm_pipelinedep, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldRegridStore 3D routehandle ERROR')
+
+ end subroutine esmf_lonlat2phys_init
+
+ !------------------------------------------------------------------------------
+ !------------------------------------------------------------------------------
+ subroutine esmf_lonlat2phys_regrid_3d(lonlatflds, physflds)
+ use esmf_lonlat_grid_mod, only: lon_beg,lon_end,lat_beg,lat_end
+ use ppgrid, only: pcols, pver, begchunk, endchunk
+ use phys_grid, only: get_ncols_p
+
+ type(fields_bundle_t), intent(in) :: lonlatflds(n_flx_flds)
+ type(fields_bundle_t), intent(inout) :: physflds(n_flx_flds)
+
+ integer :: i, ichnk, ncol, ifld, ilev, icol, rc
+ real(ESMF_KIND_R8), pointer :: physptr(:,:,:)
+ real(ESMF_KIND_R8), pointer :: lonlatptr(:,:,:,:)
+
+ character(len=*), parameter :: subname = 'esmf_lonlat2phys_regrid_3d: '
+
+ ! set values of lonlatfld_3d ESMF field
+ call ESMF_FieldGet(lonlatfld_3d, localDe=0, farrayPtr=lonlatptr, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldGet lonlatptr')
+
+ do ifld = 1,n_flx_flds
+ lonlatptr(lon_beg:lon_end,lat_beg:lat_end,1:pver,ifld) = lonlatflds(ifld)%fld(lon_beg:lon_end,lat_beg:lat_end,1:pver)
+ end do
+
+ ! regrid
+ call ESMF_FieldRegrid(lonlatfld_3d, physfld_3d, rh_lonlat2phys_3d, &
+ termorderflag=ESMF_TERMORDER_SRCSEQ, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldRegrid lonlatfld_3d->physfld_3d')
+
+ ! get values from physfld_3d ESMF field
+ call ESMF_FieldGet(physfld_3d, localDe=0, farrayPtr=physptr, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldGet physptr')
+
+ i = 0
+ do ichnk = begchunk, endchunk
+ ncol = get_ncols_p(ichnk)
+ do icol = 1,ncol
+ i = i+1
+ do ifld = 1,n_flx_flds
+ do ilev = 1,pver
+ physflds(ifld)%fld(ilev,icol,ichnk) = physptr(ilev,ifld,i)
+ end do
+ end do
+ end do
+ end do
+
+ end subroutine esmf_lonlat2phys_regrid_3d
+
+ !------------------------------------------------------------------------------
+ !------------------------------------------------------------------------------
+ subroutine esmf_lonlat2phys_destroy()
+
+ integer :: rc
+ character(len=*), parameter :: subname = 'esmf_lonlat2phys_destroy: '
+
+ call ESMF_RouteHandleDestroy(rh_lonlat2phys_3d, rc=rc)
+    call check_esmf_error(rc, subname//'ESMF_RouteHandleDestroy rh_lonlat2phys_3d')
+
+ call ESMF_FieldDestroy(lonlatfld_3d, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldDestroy lonlatfld_3d')
+
+ call ESMF_FieldDestroy(physfld_3d, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldDestroy physfld_3d')
+
+ end subroutine esmf_lonlat2phys_destroy
+
+end module esmf_lonlat2phys_mod
diff --git a/src/utils/esmf_lonlat_grid_mod.F90 b/src/utils/esmf_lonlat_grid_mod.F90
new file mode 100644
index 0000000000..394322ddc6
--- /dev/null
+++ b/src/utils/esmf_lonlat_grid_mod.F90
@@ -0,0 +1,339 @@
+!-------------------------------------------------------------------------------
+! Encapsulates an ESMF regular longitude / latitude grid
+!-------------------------------------------------------------------------------
+module esmf_lonlat_grid_mod
+ use shr_kind_mod, only: r8 => shr_kind_r8
+ use spmd_utils, only: masterproc, mpicom
+ use cam_logfile, only: iulog
+ use cam_abortutils, only: endrun
+
+ use ESMF, only: ESMF_Grid, ESMF_GridCreate1PeriDim, ESMF_GridAddCoord
+ use ESMF, only: ESMF_GridGetCoord, ESMF_GridDestroy
+ use ESMF, only: ESMF_KIND_R8, ESMF_INDEX_GLOBAL, ESMF_STAGGERLOC_CENTER
+ use esmf_check_error_mod, only: check_esmf_error
+
+ implicit none
+
+ public
+
+ type(ESMF_Grid), protected :: lonlat_grid
+
+ integer, protected :: nlon = 0
+ integer, protected :: nlat = 0
+
+ integer, protected :: lon_beg = -1
+ integer, protected :: lon_end = -1
+ integer, protected :: lat_beg = -1
+ integer, protected :: lat_end = -1
+
+ real(r8), allocatable, protected :: glats(:)
+ real(r8), allocatable, protected :: glons(:)
+
+ integer, protected :: zonal_comm ! zonal direction MPI communicator
+
+contains
+
+ subroutine esmf_lonlat_grid_init(nlats_in, nlons_in)
+ use phys_grid, only: get_grid_dims
+ use mpi, only: mpi_comm_size, mpi_comm_rank, MPI_PROC_NULL, MPI_INTEGER
+
+ integer, intent(in) :: nlats_in
+ integer, intent(in) :: nlons_in
+
+ real(r8) :: delx, dely
+
+ integer :: npes, ierr, mytid, irank, mytidi, mytidj
+ integer :: i,j, n
+ integer :: ntasks_lon, ntasks_lat
+ integer :: lons_per_task, lons_overflow, lats_per_task, lats_overflow
+ integer :: task_cnt
+ integer :: mynlats, mynlons
+
+ integer, allocatable :: mytidi_send(:)
+ integer, allocatable :: mytidj_send(:)
+ integer, allocatable :: mytidi_recv(:)
+ integer, allocatable :: mytidj_recv(:)
+
+ integer, allocatable :: nlons_send(:)
+ integer, allocatable :: nlats_send(:)
+ integer, allocatable :: nlons_recv(:)
+ integer, allocatable :: nlats_recv(:)
+
+ integer, allocatable :: nlons_task(:)
+ integer, allocatable :: nlats_task(:)
+
+ integer, parameter :: minlats_per_pe = 2
+ integer, parameter :: minlons_per_pe = 2
+
+ integer, allocatable :: petmap(:,:,:)
+ integer :: petcnt, astat
+
+ integer :: lbnd_lat, ubnd_lat, lbnd_lon, ubnd_lon
+ integer :: lbnd(1), ubnd(1)
+ real(ESMF_KIND_R8), pointer :: coordX(:), coordY(:)
+
+ character(len=*), parameter :: subname = 'esmf_lonlat_grid_init: '
+
+ ! create reg lon lat grid
+
+ nlat = nlats_in
+ dely = 180._r8/nlat
+
+ nlon = nlons_in
+ delx = 360._r8/nlon
+
+ allocate(glons(nlon), stat=astat)
+ if (astat/=0) then
+ call endrun(subname//'not able to allocate glons array')
+ end if
+ allocate(glats(nlat), stat=astat)
+ if (astat/=0) then
+ call endrun(subname//'not able to allocate glats array')
+ end if
+
+ glons(1) = 0._r8
+ glats(1) = -90._r8 + 0.5_r8 * dely
+
+ do i = 2,nlon
+ glons(i) = glons(i-1) + delx
+ end do
+ do i = 2,nlat
+ glats(i) = glats(i-1) + dely
+ end do
+
+ ! decompose the grid across mpi tasks ...
+
+ call mpi_comm_size(mpicom, npes, ierr)
+ call mpi_comm_rank(mpicom, mytid, ierr)
+
+    decomp_loop: do ntasks_lon = 1,nlon
+       ntasks_lat = npes/ntasks_lon
+       if ( (minlats_per_pe*ntasks_lat <= nlat) .and. &
+            (minlons_per_pe*ntasks_lon <= nlon) .and. &
+            (ntasks_lon*ntasks_lat == npes) ) exit decomp_loop
+    end do decomp_loop
+
+    ! (assignment of per-task lon/lat ranges, creation of the ESMF grid and
+    !  its coordinates, and setup of the zonal communicator not shown)
+
+    mynlats = lat_end-lat_beg+1
+    mynlons = lon_end-lon_beg+1
+
+  end subroutine esmf_lonlat_grid_init
+
+end module esmf_lonlat_grid_mod
diff --git a/src/utils/esmf_phys2lonlat_mod.F90 b/src/utils/esmf_phys2lonlat_mod.F90
new file mode 100644
--- /dev/null
+++ b/src/utils/esmf_phys2lonlat_mod.F90
+!------------------------------------------------------------------------------
+! Provides methods for mapping from the physics grid to a regular longitude /
+! latitude grid via ESMF regridding capabilities
+!------------------------------------------------------------------------------
+module esmf_phys2lonlat_mod
+  use shr_kind_mod, only: r8 => shr_kind_r8
+ use cam_logfile, only: iulog
+ use cam_abortutils, only: endrun
+ use spmd_utils, only: masterproc
+ use ppgrid, only: pver
+
+ use ESMF, only: ESMF_RouteHandle, ESMF_Field, ESMF_ArraySpec, ESMF_ArraySpecSet
+ use ESMF, only: ESMF_FieldCreate, ESMF_FieldRegridStore
+ use ESMF, only: ESMF_FieldGet, ESMF_FieldRegrid
+ use ESMF, only: ESMF_KIND_I4, ESMF_KIND_R8, ESMF_TYPEKIND_R8
+ use ESMF, only: ESMF_REGRIDMETHOD_BILINEAR, ESMF_POLEMETHOD_ALLAVG, ESMF_EXTRAPMETHOD_NEAREST_IDAVG
+ use ESMF, only: ESMF_TERMORDER_SRCSEQ, ESMF_MESHLOC_ELEMENT, ESMF_STAGGERLOC_CENTER
+ use ESMF, only: ESMF_FieldDestroy, ESMF_RouteHandleDestroy
+ use esmf_check_error_mod, only: check_esmf_error
+
+ implicit none
+
+ private
+
+ public :: esmf_phys2lonlat_init
+ public :: esmf_phys2lonlat_regrid
+ public :: esmf_phys2lonlat_destroy
+ public :: fields_bundle_t
+ public :: nflds
+
+ type(ESMF_RouteHandle) :: rh_phys2lonlat_3d
+ type(ESMF_RouteHandle) :: rh_phys2lonlat_2d
+
+ type(ESMF_Field) :: physfld_3d
+ type(ESMF_Field) :: lonlatfld_3d
+
+ type(ESMF_Field) :: physfld_2d
+ type(ESMF_Field) :: lonlatfld_2d
+
+ interface esmf_phys2lonlat_regrid
+ module procedure esmf_phys2lonlat_regrid_2d
+ module procedure esmf_phys2lonlat_regrid_3d
+ end interface esmf_phys2lonlat_regrid
+
+ type :: fields_bundle_t
+ real(r8), pointer :: fld(:,:,:) => null()
+ end type fields_bundle_t
+
+ integer, parameter :: nflds = 4
+
+contains
+
+ !------------------------------------------------------------------------------
+ !------------------------------------------------------------------------------
+ subroutine esmf_phys2lonlat_init()
+ use esmf_phys_mesh_mod, only: physics_grid_mesh
+ use esmf_lonlat_grid_mod, only: lonlat_grid
+
+ type(ESMF_ArraySpec) :: arrayspec
+ integer(ESMF_KIND_I4), pointer :: factorIndexList(:,:)
+ real(ESMF_KIND_R8), pointer :: factorList(:)
+ integer :: smm_srctermproc, smm_pipelinedep, rc
+
+ character(len=*), parameter :: subname = 'esmf_phys2lonlat_init: '
+
+ smm_srctermproc = 0
+ smm_pipelinedep = 16
+
+ ! create ESMF fields
+
+ ! 3D phys fld
+ call ESMF_ArraySpecSet(arrayspec, 3, ESMF_TYPEKIND_R8, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_ArraySpecSet 3D phys fld ERROR')
+
+ physfld_3d = ESMF_FieldCreate(physics_grid_mesh, arrayspec, &
+ gridToFieldMap=(/3/), meshloc=ESMF_MESHLOC_ELEMENT, &
+ ungriddedLBound=(/1,1/), ungriddedUBound=(/pver,nflds/), rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldCreate 3D phys fld ERROR')
+
+ ! 3D lon lat grid
+ call ESMF_ArraySpecSet(arrayspec, 4, ESMF_TYPEKIND_R8, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_ArraySpecSet 3D lonlat fld ERROR')
+
+ lonlatfld_3d = ESMF_FieldCreate( lonlat_grid, arrayspec, staggerloc=ESMF_STAGGERLOC_CENTER, &
+ ungriddedLBound=(/1,1/), ungriddedUBound=(/pver,nflds/), rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldCreate 3D lonlat fld ERROR')
+
+ ! 2D phys fld
+ call ESMF_ArraySpecSet(arrayspec, 1, ESMF_TYPEKIND_R8, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_ArraySpecSet 2D phys fld ERROR')
+
+ physfld_2d = ESMF_FieldCreate(physics_grid_mesh, arrayspec, &
+ meshloc=ESMF_MESHLOC_ELEMENT, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldCreate 2D phys fld ERROR')
+
+ ! 2D lon/lat grid
+ call ESMF_ArraySpecSet(arrayspec, 2, ESMF_TYPEKIND_R8, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_ArraySpecSet 2D lonlat fld ERROR')
+
+ lonlatfld_2d = ESMF_FieldCreate( lonlat_grid, arrayspec, staggerloc=ESMF_STAGGERLOC_CENTER, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldCreate 2D lonlat fld ERROR')
+
+ call ESMF_FieldRegridStore(srcField=physfld_3d, dstField=lonlatfld_3d, &
+ regridMethod=ESMF_REGRIDMETHOD_BILINEAR, &
+ polemethod=ESMF_POLEMETHOD_ALLAVG, &
+ extrapMethod=ESMF_EXTRAPMETHOD_NEAREST_IDAVG, &
+ routeHandle=rh_phys2lonlat_3d, factorIndexList=factorIndexList, &
+ factorList=factorList, srcTermProcessing=smm_srctermproc, &
+ pipelineDepth=smm_pipelinedep, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldRegridStore 3D routehandle ERROR')
+
+ call ESMF_FieldRegridStore(srcField=physfld_2d, dstField=lonlatfld_2d, &
+ regridMethod=ESMF_REGRIDMETHOD_BILINEAR, &
+ polemethod=ESMF_POLEMETHOD_ALLAVG, &
+ extrapMethod=ESMF_EXTRAPMETHOD_NEAREST_IDAVG, &
+ routeHandle=rh_phys2lonlat_2d, factorIndexList=factorIndexList, &
+ factorList=factorList, srcTermProcessing=smm_srctermproc, &
+ pipelineDepth=smm_pipelinedep, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldRegridStore 3D routehandle ERROR')
+
+ end subroutine esmf_phys2lonlat_init
+
+ !------------------------------------------------------------------------------
+ !------------------------------------------------------------------------------
+ subroutine esmf_phys2lonlat_regrid_3d(physflds, lonlatflds)
+ use esmf_lonlat_grid_mod, only: lon_beg,lon_end,lat_beg,lat_end
+ use ppgrid, only: pcols, pver, begchunk, endchunk
+ use phys_grid, only: get_ncols_p
+
+ type(fields_bundle_t), intent(in) :: physflds(nflds)
+ type(fields_bundle_t), intent(inout) :: lonlatflds(nflds)
+
+ integer :: i, ichnk, ncol, ifld, ilev, icol, rc
+ real(ESMF_KIND_R8), pointer :: physptr(:,:,:)
+ real(ESMF_KIND_R8), pointer :: lonlatptr(:,:,:,:)
+
+ character(len=*), parameter :: subname = 'esmf_phys2lonlat_regrid_3d: '
+
+ call ESMF_FieldGet(physfld_3d, localDe=0, farrayPtr=physptr, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldGet physptr')
+
+ i = 0
+ do ichnk = begchunk, endchunk
+ ncol = get_ncols_p(ichnk)
+ do icol = 1,ncol
+ i = i+1
+ do ifld = 1,nflds
+ do ilev = 1,pver
+ physptr(ilev,ifld,i) = physflds(ifld)%fld(ilev,icol,ichnk)
+ end do
+ end do
+ end do
+ end do
+
+ call ESMF_FieldRegrid(physfld_3d, lonlatfld_3d, rh_phys2lonlat_3d, &
+ termorderflag=ESMF_TERMORDER_SRCSEQ, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldRegrid physfld_3d->lonlatfld_3d')
+
+ call ESMF_FieldGet(lonlatfld_3d, localDe=0, farrayPtr=lonlatptr, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldGet lonlatptr')
+
+ do ifld = 1,nflds
+ lonlatflds(ifld)%fld(lon_beg:lon_end,lat_beg:lat_end,1:pver) = lonlatptr(lon_beg:lon_end,lat_beg:lat_end,1:pver,ifld)
+ end do
+
+ end subroutine esmf_phys2lonlat_regrid_3d
+
+ !------------------------------------------------------------------------------
+ !------------------------------------------------------------------------------
+ subroutine esmf_phys2lonlat_regrid_2d(physarr, lonlatarr)
+ use esmf_lonlat_grid_mod, only: lon_beg,lon_end,lat_beg,lat_end
+ use ppgrid, only: pcols, pver, begchunk, endchunk
+ use phys_grid, only: get_ncols_p
+
+ real(r8),intent(in) :: physarr(pcols,begchunk:endchunk)
+ real(r8),intent(out) :: lonlatarr(lon_beg:lon_end,lat_beg:lat_end)
+
+ integer :: i, ichnk, ncol, icol, rc
+ real(ESMF_KIND_R8), pointer :: physptr(:)
+ real(ESMF_KIND_R8), pointer :: lonlatptr(:,:)
+
+ character(len=*), parameter :: subname = 'esmf_phys2lonlat_regrid_2d: '
+
+ call ESMF_FieldGet(physfld_2d, localDe=0, farrayPtr=physptr, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldGet physptr')
+
+ i = 0
+ do ichnk = begchunk, endchunk
+ ncol = get_ncols_p(ichnk)
+ do icol = 1,ncol
+ i = i+1
+ physptr(i) = physarr(icol,ichnk)
+ end do
+ end do
+
+ call ESMF_FieldRegrid(physfld_2d, lonlatfld_2d, rh_phys2lonlat_2d, &
+ termorderflag=ESMF_TERMORDER_SRCSEQ, rc=rc)
+    call check_esmf_error(rc, subname//'ESMF_FieldRegrid physfld_2d->lonlatfld_2d')
+
+ call ESMF_FieldGet(lonlatfld_2d, localDe=0, farrayPtr=lonlatptr, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldGet lonlatptr')
+
+ lonlatarr(lon_beg:lon_end,lat_beg:lat_end) = lonlatptr(lon_beg:lon_end,lat_beg:lat_end)
+
+ end subroutine esmf_phys2lonlat_regrid_2d
+
+ !------------------------------------------------------------------------------
+ !------------------------------------------------------------------------------
+ subroutine esmf_phys2lonlat_destroy()
+
+ integer :: rc
+ character(len=*), parameter :: subname = 'esmf_phys2lonlat_destroy: '
+
+ call ESMF_RouteHandleDestroy(rh_phys2lonlat_3d, rc=rc)
+    call check_esmf_error(rc, subname//'ESMF_RouteHandleDestroy rh_phys2lonlat_3d')
+
+ call ESMF_RouteHandleDestroy(rh_phys2lonlat_2d, rc=rc)
+    call check_esmf_error(rc, subname//'ESMF_RouteHandleDestroy rh_phys2lonlat_2d')
+
+ call ESMF_FieldDestroy(lonlatfld_3d, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldDestroy lonlatfld_3d')
+
+ call ESMF_FieldDestroy(lonlatfld_2d, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldDestroy lonlatfld_2d')
+
+ call ESMF_FieldDestroy(physfld_3d, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldDestroy physfld_3d')
+
+ call ESMF_FieldDestroy(physfld_2d, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_FieldDestroy physfld_2d')
+
+ end subroutine esmf_phys2lonlat_destroy
+
+end module esmf_phys2lonlat_mod
diff --git a/src/utils/esmf_phys_mesh_mod.F90 b/src/utils/esmf_phys_mesh_mod.F90
new file mode 100644
index 0000000000..4235452dc5
--- /dev/null
+++ b/src/utils/esmf_phys_mesh_mod.F90
@@ -0,0 +1,202 @@
+!-------------------------------------------------------------------------------
+! Encapsulates the CAM physics grid mesh
+!-------------------------------------------------------------------------------
+module esmf_phys_mesh_mod
+ use shr_kind_mod, only: r8 => shr_kind_r8, cs=>shr_kind_cs, cl=>shr_kind_cl
+ use spmd_utils, only: masterproc
+ use cam_logfile, only: iulog
+ use cam_abortutils, only: endrun
+ use ESMF, only: ESMF_DistGrid, ESMF_DistGridCreate, ESMF_MeshCreate
+ use ESMF, only: ESMF_FILEFORMAT_ESMFMESH,ESMF_MeshGet,ESMF_Mesh, ESMF_SUCCESS
+ use ESMF, only: ESMF_MeshDestroy, ESMF_DistGridDestroy
+ use esmf_check_error_mod, only: check_esmf_error
+
+ implicit none
+
+ private
+
+ public :: esmf_phys_mesh_init
+ public :: esmf_phys_mesh_destroy
+ public :: physics_grid_mesh
+
+ ! phys_mesh: Local copy of physics grid
+ type(ESMF_Mesh), protected :: physics_grid_mesh
+
+ ! dist_grid_2d: DistGrid for 2D fields
+ type(ESMF_DistGrid) :: dist_grid_2d
+
+contains
+
+ !-----------------------------------------------------------------------------
+ !-----------------------------------------------------------------------------
+ subroutine esmf_phys_mesh_init()
+ use phys_control, only: phys_getopts
+ use phys_grid, only: get_ncols_p, get_gcol_p, get_rlon_all_p, get_rlat_all_p
+ use ppgrid, only: pcols, begchunk, endchunk
+ use shr_const_mod,only: shr_const_pi
+
+ ! Local variables
+ integer :: ncols
+ integer :: chnk, col, dindex
+ integer, allocatable :: decomp(:)
+ character(len=cl) :: grid_file
+ integer :: spatialDim
+ integer :: numOwnedElements
+ real(r8), pointer :: ownedElemCoords(:)
+ real(r8), pointer :: lat(:), latMesh(:)
+ real(r8), pointer :: lon(:), lonMesh(:)
+ real(r8) :: lats(pcols) ! array of chunk latitudes
+ real(r8) :: lons(pcols) ! array of chunk longitude
+ character(len=cs) :: tempc1,tempc2
+ character(len=300) :: errstr
+
+ integer :: i, c, n, total_cols, rc
+
+ real(r8), parameter :: abstol = 1.e-3_r8
+ real(r8), parameter :: radtodeg = 180.0_r8/shr_const_pi
+ character(len=*), parameter :: subname = 'esmf_phys_mesh_init: '
+
+ ! Find the physics grid file
+ call phys_getopts(physics_grid_out=grid_file)
+
+ ! Compute the local decomp
+ total_cols = 0
+ do chnk = begchunk, endchunk
+ total_cols = total_cols + get_ncols_p(chnk)
+ end do
+ allocate(decomp(total_cols), stat=rc)
+ if (rc/=0) then
+ call endrun(subname//'not able to allocate decomp')
+ end if
+
+ dindex = 0
+ do chnk = begchunk, endchunk
+ ncols = get_ncols_p(chnk)
+ do col = 1, ncols
+ dindex = dindex + 1
+ decomp(dindex) = get_gcol_p(chnk, col)
+ end do
+ end do
+
+ ! Create a DistGrid based on the physics decomp
+ dist_grid_2d = ESMF_DistGridCreate(arbSeqIndexList=decomp, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_DistGridCreate')
+
+ ! Create an ESMF_mesh for the physics decomposition
+ physics_grid_mesh = ESMF_MeshCreate(trim(grid_file), ESMF_FILEFORMAT_ESMFMESH, &
+ elementDistgrid=dist_grid_2d, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_MeshCreate')
+
+ ! Check that the mesh coordinates are consistent with the model physics column coordinates
+
+ ! obtain mesh lats and lons
+ call ESMF_MeshGet(physics_grid_mesh, spatialDim=spatialDim, numOwnedElements=numOwnedElements, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_MeshGet')
+
+ if (numOwnedElements /= total_cols) then
+ write(tempc1,'(i10)') numOwnedElements
+ write(tempc2,'(i10)') total_cols
+ call endrun(subname//"ERROR numOwnedElements "// &
+ trim(tempc1) //" not equal to local size "// trim(tempc2))
+ end if
+
+ allocate(ownedElemCoords(spatialDim*numOwnedElements), stat=rc)
+ if (rc/=0) then
+ call endrun(subname//'not able to allocate ownedElemCoords')
+ end if
+
+ allocate(lonMesh(total_cols), stat=rc)
+ if (rc/=0) then
+ call endrun(subname//'not able to allocate lonMesh')
+ end if
+
+ allocate(latMesh(total_cols), stat=rc)
+ if (rc/=0) then
+ call endrun(subname//'not able to allocate latMesh')
+ end if
+
+    call ESMF_MeshGet(physics_grid_mesh, ownedElemCoords=ownedElemCoords, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_MeshGet')
+
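+    ! ownedElemCoords interleaves coordinate pairs as (lon1, lat1, lon2, lat2, ...),
+    ! in degrees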
+ do n = 1,total_cols
+ lonMesh(n) = ownedElemCoords(2*n-1)
+ latMesh(n) = ownedElemCoords(2*n)
+ end do
+
+ ! obtain internally generated cam lats and lons
+    allocate(lon(total_cols), stat=rc)
+ if (rc/=0) then
+ call endrun(subname//'not able to allocate lon')
+ end if
+
+ lon(:) = 0._r8
+
+    allocate(lat(total_cols), stat=rc)
+ if (rc/=0) then
+ call endrun(subname//'not able to allocate lat')
+ end if
+
+ lat(:) = 0._r8
+
+ n=0
+ do c = begchunk, endchunk
+ ncols = get_ncols_p(c)
+ ! latitudes and longitudes returned in radians
+ call get_rlat_all_p(c, ncols, lats)
+ call get_rlon_all_p(c, ncols, lons)
+ do i=1,ncols
+ n = n+1
+ lat(n) = lats(i)*radtodeg
+ lon(n) = lons(i)*radtodeg
+ end do
+ end do
+
+ errstr = ''
+ ! error check differences between internally generated lons and those read in
+ do n = 1,total_cols
+ if (abs(lonMesh(n) - lon(n)) > abstol) then
+ if ( (abs(lonMesh(n)-lon(n)) > 360._r8+abstol) .or. (abs(lonMesh(n)-lon(n)) < 360._r8-abstol) ) then
+ write(errstr,100) n,lon(n),lonMesh(n), abs(lonMesh(n)-lon(n))
+ write(iulog,*) trim(errstr)
+ endif
+ end if
+ if (abs(latMesh(n) - lat(n)) > abstol) then
+ ! poles in the 4x5 SCRIP file seem to be off by 1 degree
+ if (.not.( (abs(lat(n))>88.0_r8) .and. (abs(latMesh(n))>88.0_r8) )) then
+ write(errstr,101) n,lat(n),latMesh(n), abs(latMesh(n)-lat(n))
+ write(iulog,*) trim(errstr)
+ endif
+ end if
+ end do
+
+ if ( len_trim(errstr) > 0 ) then
+ call endrun(subname//'physics mesh coords do not match model coords')
+ end if
+
+ ! deallocate memory
+ deallocate(ownedElemCoords)
+ deallocate(lon, lonMesh)
+ deallocate(lat, latMesh)
+ deallocate(decomp)
+
+100 format('esmf_phys_mesh_init: coord mismatch... n, lon(n), lonmesh(n), diff_lon = ',i6,2(f21.13,3x),d21.5)
+101 format('esmf_phys_mesh_init: coord mismatch... n, lat(n), latmesh(n), diff_lat = ',i6,2(f21.13,3x),d21.5)
+
+ end subroutine esmf_phys_mesh_init
+
+ !-----------------------------------------------------------------------------
+ !-----------------------------------------------------------------------------
+ subroutine esmf_phys_mesh_destroy()
+
+ integer :: rc
+ character(len=*), parameter :: subname = 'esmf_phys_mesh_destroy: '
+
+ call ESMF_MeshDestroy(physics_grid_mesh, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_MeshDestroy phys_mesh')
+
+ call ESMF_DistGridDestroy(dist_grid_2d, rc=rc)
+ call check_esmf_error(rc, subname//'ESMF_DistGridDestroy dist_grid_2d')
+
+ end subroutine esmf_phys_mesh_destroy
+
+end module esmf_phys_mesh_mod
diff --git a/src/utils/esmf_zonal_mean_mod.F90 b/src/utils/esmf_zonal_mean_mod.F90
new file mode 100644
index 0000000000..822c7ae639
--- /dev/null
+++ b/src/utils/esmf_zonal_mean_mod.F90
@@ -0,0 +1,155 @@
+!------------------------------------------------------------------------------
+! Provides methods for calculating zonal means on the ESMF regular longitude
+! / latitude grid
+!------------------------------------------------------------------------------
+module esmf_zonal_mean_mod
+ use shr_kind_mod, only: r8 => shr_kind_r8
+ use cam_logfile, only: iulog
+ use cam_abortutils, only: endrun
+ use spmd_utils, only: masterproc
+ use cam_history_support, only : fillvalue
+
+ implicit none
+
+ private
+
+ public :: esmf_zonal_mean_calc
+ public :: esmf_zonal_mean_masked
+ public :: esmf_zonal_mean_wsums
+
+ interface esmf_zonal_mean_calc
+ module procedure esmf_zonal_mean_calc_2d
+ module procedure esmf_zonal_mean_calc_3d
+ end interface esmf_zonal_mean_calc
+
+contains
+
+ !------------------------------------------------------------------------------
+ ! Calculates zonal means of 3D fields. The wght option can be used to mask out
+ ! regions such as mountains.
+ !------------------------------------------------------------------------------
+ subroutine esmf_zonal_mean_calc_3d(lonlatarr, zmarr, wght)
+ use ppgrid, only: pver
+ use esmf_lonlat_grid_mod, only: lon_beg,lon_end,lat_beg,lat_end, nlon
+ use esmf_lonlat_grid_mod, only: zonal_comm
+ use shr_reprosum_mod,only: shr_reprosum_calc
+
+ real(r8), intent(in) :: lonlatarr(lon_beg:lon_end,lat_beg:lat_end,pver)
+ real(r8), intent(out) :: zmarr(lat_beg:lat_end,pver)
+
+ real(r8), optional, intent(in) :: wght(lon_beg:lon_end,lat_beg:lat_end,pver)
+
+ real(r8) :: gsum(pver)
+
+ real(r8) :: wsums(lat_beg:lat_end,pver)
+
+    integer :: numlons, ilat
+
+ numlons = lon_end-lon_beg+1
+
+ ! zonal mean
+ if (present(wght)) then
+
+ wsums = esmf_zonal_mean_wsums(wght)
+ call esmf_zonal_mean_masked(lonlatarr, wght, wsums, zmarr)
+
+ else
+
+ do ilat = lat_beg, lat_end
+ call shr_reprosum_calc(lonlatarr(lon_beg:lon_end,ilat,:), gsum, numlons, numlons, pver, gbl_count=nlon, commid=zonal_comm)
+ zmarr(ilat,:) = gsum(:)/nlon
+ end do
+
+ end if
+
+ end subroutine esmf_zonal_mean_calc_3d
+
+ !------------------------------------------------------------------------------
+ ! Computes zonal mean for 2D lon / lat fields.
+ !------------------------------------------------------------------------------
+ subroutine esmf_zonal_mean_calc_2d(lonlatarr, zmarr)
+ use esmf_lonlat_grid_mod, only: lon_beg,lon_end,lat_beg,lat_end, nlon
+ use esmf_lonlat_grid_mod, only: zonal_comm
+ use shr_reprosum_mod,only: shr_reprosum_calc
+
+ real(r8), intent(in) :: lonlatarr(lon_beg:lon_end,lat_beg:lat_end)
+ real(r8), intent(out) :: zmarr(lat_beg:lat_end)
+
+ real(r8) :: gsum(lat_beg:lat_end)
+
+ integer :: numlons, numlats
+
+ numlons = lon_end-lon_beg+1
+ numlats = lat_end-lat_beg+1
+
+ ! zonal mean
+
+ call shr_reprosum_calc(lonlatarr, gsum, numlons, numlons, numlats, gbl_count=nlon, commid=zonal_comm)
+ zmarr(:) = gsum(:)/nlon
+
+ end subroutine esmf_zonal_mean_calc_2d
+
+ !------------------------------------------------------------------------------
+ ! Computes longitude sums of grid cell weights.
+ !------------------------------------------------------------------------------
+ function esmf_zonal_mean_wsums(wght) result(wsums)
+ use esmf_lonlat_grid_mod, only: lon_beg,lon_end,lat_beg,lat_end, nlon
+ use esmf_lonlat_grid_mod, only: zonal_comm
+ use shr_reprosum_mod,only: shr_reprosum_calc
+ use ppgrid, only: pver
+
+ real(r8), intent(in) :: wght(lon_beg:lon_end,lat_beg:lat_end,pver)
+
+ real(r8) :: wsums(lat_beg:lat_end,pver)
+ integer :: numlons, ilat
+
+ numlons = lon_end-lon_beg+1
+
+ do ilat = lat_beg, lat_end
+
+ call shr_reprosum_calc(wght(lon_beg:lon_end,ilat,:), wsums(ilat,1:pver), &
+ numlons, numlons, pver, gbl_count=nlon, commid=zonal_comm)
+
+ end do
+
+ end function esmf_zonal_mean_wsums
+
+ !------------------------------------------------------------------------------
+ ! Masks out regions (e.g. mountains) from zonal mean calculation.
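+  ! For each (lat,level): zm = sum_lon(wght*fld) / sum_lon(wght), with
+  ! fillvalue wherever the weight sum vanishes.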
+ !------------------------------------------------------------------------------
+ subroutine esmf_zonal_mean_masked(lonlatarr, wght, wsums, zmarr)
+ use esmf_lonlat_grid_mod, only: lon_beg,lon_end,lat_beg,lat_end, nlon
+ use esmf_lonlat_grid_mod, only: zonal_comm
+    use shr_reprosum_mod, only: shr_reprosum_calc
+ use ppgrid, only: pver
+
+ real(r8), intent(in) :: lonlatarr(lon_beg:lon_end,lat_beg:lat_end,pver)
+ real(r8), intent(in) :: wght(lon_beg:lon_end,lat_beg:lat_end,pver) ! grid cell weights
+ real(r8), intent(in) :: wsums(lat_beg:lat_end,pver) ! pre-computed sums of grid cell weights
+ real(r8), intent(out) :: zmarr(lat_beg:lat_end,pver) ! zonal means
+
+ real(r8) :: tmparr(lon_beg:lon_end,pver)
+ integer :: numlons, ilat, ilev
+ real(r8) :: gsum(pver)
+
+ numlons = lon_end-lon_beg+1
+
+ do ilat = lat_beg, lat_end
+
+ tmparr(lon_beg:lon_end,:) = wght(lon_beg:lon_end,ilat,:)*lonlatarr(lon_beg:lon_end,ilat,:)
+ call shr_reprosum_calc(tmparr, gsum, numlons, numlons, pver, gbl_count=nlon, commid=zonal_comm)
+
+ do ilev = 1,pver
+ if (wsums(ilat,ilev)>0._r8) then
+ zmarr(ilat,ilev) = gsum(ilev)/wsums(ilat,ilev)
+ else
+ zmarr(ilat,ilev) = fillvalue
+ end if
+ end do
+
+ end do
+
+ end subroutine esmf_zonal_mean_masked
+
+end module esmf_zonal_mean_mod
diff --git a/src/utils/remap.F90 b/src/utils/remap.F90
new file mode 100644
index 0000000000..cb432cfb05
--- /dev/null
+++ b/src/utils/remap.F90
@@ -0,0 +1,483 @@
+!-----------------------------------------------------------------------------
+! Utilities to gather and distribute physics columns and remap them between
+! the cubed-sphere grid and a regular lat-lon grid.
+!-----------------------------------------------------------------------------
+module nlgw_remap_mod
+ use shr_kind_mod, only: r8 => shr_kind_r8, cx => SHR_KIND_CX
+ use ppgrid, only: begchunk, endchunk, pcols, pver, pverp
+ use physics_types, only: physics_state
+ use phys_grid, only: get_ncols_p
+ use spmd_utils, only: masterproc, npes
+ use ref_pres, only: pref_mid
+ use esmf_lonlat_grid_mod, only: beglon=>lon_beg, endlon=>lon_end, beglat=>lat_beg, endlat=>lat_end
+ use cam_history, only: addfld, outfld, horiz_only
+ use cam_history_support, only : fillvalue
+ use perf_mod, only: t_startf, t_stopf
+ use cam_logfile, only: iulog
+ use cam_abortutils, only: endrun
+ use gw_nlgw_utils, only: phys_vars, lonlat_vars
+
+ implicit none
+
+ private
+
+ public :: nlgw_regrid_init
+ public :: nlgw_latlon_gather
+ public :: nlgw_latlon_scatter
+ public :: nlgw_regrid_final
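+
+  ! Typical call sequence (a sketch; the ML inference itself happens outside
+  ! this module, on the gathered lat-lon fields):
+  !   call nlgw_regrid_init(phys, lonlat, gathered_lonlat)
+  !   call nlgw_latlon_gather(phys_state, phys, lonlat, gathered_lonlat)
+  !   ! ... run the non-local GW model, filling gathered_lonlat%uflux/%vflux ...
+  !   call nlgw_latlon_scatter(phys, lonlat, gathered_lonlat)
+  !   call nlgw_regrid_final(phys, lonlat, gathered_lonlat)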
+
+ ! private (for book-keeping in MPI calls)
+ integer, allocatable :: recvcnts(:), displs(:)
+ integer, allocatable :: beglats(:), beglons(:)
+ integer, allocatable :: endlats(:), endlons(:)
+
+ logical, parameter, public :: debug = .true.
+
+
+contains
+
+ !-----------------------------------------------------------------------------
+ ! Initialize arrays and grids for regridding/MPI calls
+ !-----------------------------------------------------------------------------
+ subroutine nlgw_regrid_init(phys, lonlat, gathered_lonlat)
+ use cam_grid_support, only: horiz_coord_t, horiz_coord_create, iMap, cam_grid_register, get_cam_grid_index
+ use esmf_lonlat_grid_mod, only: glats, glons
+ use esmf_lonlat_grid_mod, only: esmf_lonlat_grid_init
+ use esmf_phys_mesh_mod, only: esmf_phys_mesh_init
+ use esmf_phys2lonlat_mod, only: esmf_phys2lonlat_init
+ use esmf_lonlat2phys_mod, only: esmf_lonlat2phys_init
+ use gw_nlgw_utils, only: nlon, nlat
+
+    type(phys_vars), intent(inout), target :: phys
+    type(lonlat_vars), intent(inout), target :: lonlat
+    type(lonlat_vars), intent(inout), target :: gathered_lonlat
+
+    integer, parameter :: reg_decomp = 332
+
+    integer(iMap), pointer :: grid_map(:,:)
+    integer(iMap), pointer :: coord_map(:) => null()
+    type(horiz_coord_t), pointer :: lon_coord
+    type(horiz_coord_t), pointer :: lat_coord
+    integer :: i, j, ind, astat
+
+    character(len=*), parameter :: subname = 'nlgw_regrid_init: '
+
+ ! initialize grids and mapping
+ call esmf_lonlat_grid_init(nlat, nlon)
+ call esmf_phys_mesh_init()
+ call esmf_phys2lonlat_init()
+ call esmf_lonlat2phys_init()
+
+ ! for the lon-lat grid
+ allocate(grid_map(4, ((endlon - beglon + 1) * (endlat - beglat + 1))), stat=astat)
+ if (astat/=0) then
+ call endrun(subname//'not able to allocate grid_map array')
+ end if
+
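+    ! each grid_map column holds (local lon, local lat, global lon, global lat)
+    ! for one point of this rank's block; local and global indices coincide here
+    ! because the decomposition is over contiguous lat/lon ranges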
+ ind = 0
+ do i = beglat, endlat
+ do j = beglon, endlon
+ ind = ind + 1
+ grid_map(1, ind) = j
+ grid_map(2, ind) = i
+ grid_map(3, ind) = j
+ grid_map(4, ind) = i
+ end do
+ end do
+
+ allocate(coord_map(endlat - beglat + 1), stat=astat)
+ if (astat/=0) then
+ call endrun(subname//'not able to allocate coord_map array')
+ end if
+
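+    ! only one rank per coordinate range contributes the coordinate values;
+    ! a zero map entry means this rank does not write that coordinate point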
+ if (beglon==1) then
+ coord_map = (/ (i, i = beglat, endlat) /)
+ else
+ coord_map = 0
+ end if
+ lat_coord => horiz_coord_create('reglat', '', nlat, 'latitude', 'degrees_north', beglat, endlat, &
+ glats(beglat:endlat), map=coord_map)
+
+ nullify(coord_map)
+
+ allocate(coord_map(endlon - beglon + 1), stat=astat)
+ if (astat/=0) then
+ call endrun(subname//'not able to allocate coord_map array')
+ end if
+
+ if (beglat==1) then
+ coord_map = (/ (i, i = beglon, endlon) /)
+ else
+ coord_map = 0
+ end if
+
+ lon_coord => horiz_coord_create('reglon', '', nlon, 'longitude', 'degrees_east', beglon, endlon, &
+ glons(beglon:endlon), map=coord_map)
+
+ nullify(coord_map)
+
+ ! only create grid if it doesn't exist
+ if (get_cam_grid_index('ctem_lonlat') == -1) then
+ call cam_grid_register('ctem_lonlat', reg_decomp, lat_coord, lon_coord, grid_map, unstruct=.false.)
+ end if
+
+ nullify(grid_map)
+
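+    ! MPI bookkeeping: per-rank counts/offsets into the flat gather buffer and
+    ! each rank's lon/lat block extents (filled by the gather/scatter routines)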
+ allocate(recvcnts(npes))
+ allocate(displs(npes))
+ allocate(beglats(npes))
+ allocate(beglons(npes))
+ allocate(endlats(npes))
+ allocate(endlons(npes))
+
+ allocate(phys%u(pver,pcols,begchunk:endchunk))
+ allocate(phys%v(pver,pcols,begchunk:endchunk))
+ allocate(phys%theta(pver,pcols,begchunk:endchunk))
+ allocate(phys%w(pver,pcols,begchunk:endchunk))
+ allocate(phys%uflux(pver,pcols,begchunk:endchunk))
+ allocate(phys%vflux(pver,pcols,begchunk:endchunk))
+ allocate(phys%utgw(pver,pcols,begchunk:endchunk))
+ allocate(phys%vtgw(pver,pcols,begchunk:endchunk))
+
+
+ allocate(lonlat%u(beglon:endlon,beglat:endlat,pver))
+ allocate(lonlat%v(beglon:endlon,beglat:endlat,pver))
+ allocate(lonlat%w(beglon:endlon,beglat:endlat,pver))
+ allocate(lonlat%theta(beglon:endlon,beglat:endlat,pver))
+ allocate(lonlat%uflux(beglon:endlon,beglat:endlat,pver))
+ allocate(lonlat%vflux(beglon:endlon,beglat:endlat,pver))
+
+ if (debug) then
+ allocate(phys%pmid(pver,pcols,begchunk:endchunk))
+ allocate(phys%lon(pcols,begchunk:endchunk))
+ allocate(phys%lat(pcols,begchunk:endchunk))
+ end if
+
+ ! gathered grids only exist on masterproc
+ if (masterproc) then
+ allocate(gathered_lonlat%u(nlon, nlat, pver))
+ allocate(gathered_lonlat%v(nlon, nlat, pver))
+ allocate(gathered_lonlat%w(nlon, nlat, pver))
+ allocate(gathered_lonlat%theta(nlon, nlat, pver))
+ allocate(gathered_lonlat%uflux(nlon, nlat, pver))
+ allocate(gathered_lonlat%vflux(nlon, nlat, pver))
+ end if
+ end subroutine nlgw_regrid_init
+
+
+ !-----------------------------------------------------------------------------
+ ! This routine takes variables on the regular lat/lon grid:
+  ! * uses MPI_Scatterv to distribute them from masterproc back to all ranks
+  ! * interpolates them back onto the cubed-sphere physics grid
+  ! * leaves the results in the chunked phys arrays so they can be used elsewhere
+ !-----------------------------------------------------------------------------
+ subroutine nlgw_latlon_scatter(phys, lonlat, gathered_lonlat)
+ use gw_nlgw_utils, only: nlon, nlat
+ use esmf_lonlat2phys_mod, only: fields_bundle_t, n_flx_flds, esmf_lonlat2phys_regrid
+ use mpishorthand
+
+ type(phys_vars), intent(inout), target :: phys
+ type(lonlat_vars), intent(inout), target :: lonlat
+ type(lonlat_vars), intent(inout), target :: gathered_lonlat
+
+ real(r8), allocatable :: flat_array(:)
+
+    integer :: i, sendcnt, disp_sum
+
+ type(fields_bundle_t) :: phys_flx_flds(n_flx_flds)
+ type(fields_bundle_t) :: lonlat_flx_flds(n_flx_flds)
+
+ call t_startf('nlgw_scatter')
+
+ call t_startf('nlgw_mpiscatter')
+    ! this subsection scatters the gathered variables from masterproc back to all ranks
+
+ sendcnt = (endlon - beglon + 1) * (endlat - beglat + 1) * pver
+
+ ! mpi gather book-keeping
+ call mpigather(sendcnt, 1, mpiint, recvcnts, 1, mpiint, 0, mpicom)
+ call mpigather(beglat, 1, mpiint, beglats, 1, mpiint, 0, mpicom)
+ call mpigather(beglon, 1, mpiint, beglons, 1, mpiint, 0, mpicom)
+ call mpigather(endlat, 1, mpiint, endlats, 1, mpiint, 0, mpicom)
+ call mpigather(endlon, 1, mpiint, endlons, 1, mpiint, 0, mpicom)
+
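+    ! displs(i) is rank i's offset into flat_array: the running sum of the
+    ! preceding recvcnts (e.g. counts (/c1,c2,c3/) give displs (/0,c1,c1+c2/))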
+ if (masterproc) then
+ disp_sum = 0
+ do i = 1, npes
+ displs(i) = disp_sum
+ disp_sum = disp_sum + recvcnts(i)
+ end do
+ end if
+ allocate(flat_array(nlon * nlat * pver))
+
+    ! unlike in the gather case, all ranks need the displs and recvcnts
+ call mpibcast(displs, npes, mpiint, 0, mpicom)
+ call mpibcast(recvcnts, npes, mpiint, 0, mpicom)
+
+ call scatter_3d(gathered_lonlat%uflux, sendcnt, flat_array, lonlat%uflux(beglon:endlon, beglat:endlat, 1:pver))
+ call scatter_3d(gathered_lonlat%vflux, sendcnt, flat_array, lonlat%vflux(beglon:endlon, beglat:endlat, 1:pver))
+
+ deallocate(flat_array)
+
+ call t_stopf('nlgw_mpiscatter')
+
+    call t_startf('nlgw_latlon_scatter_regrid')
+    ! this subsection regrids the lat/lon fields back onto the physics grid
+
+ phys_flx_flds(1)%fld => phys%uflux
+ phys_flx_flds(2)%fld => phys%vflux
+
+ lonlat_flx_flds(1)%fld => lonlat%uflux
+ lonlat_flx_flds(2)%fld => lonlat%vflux
+
+    ! actual call to regrid from the lon/lat grid back to the physics grid
+ call esmf_lonlat2phys_regrid(lonlat_flx_flds, phys_flx_flds)
+
+    call t_stopf('nlgw_latlon_scatter_regrid')
+
+ call t_stopf('nlgw_scatter')
+
+ end subroutine nlgw_latlon_scatter
+
+
+ !-----------------------------------------------------------------------------
+ ! This routine takes variables on the irregular cubed-sphere grid:
+ ! * gathers all the chunks into a single data structure on each rank
+  ! * interpolates the cubed-sphere variables onto a regular lon/lat grid
+  ! * uses MPI_Gatherv to collect the regridded data from all ranks onto masterproc
+ !-----------------------------------------------------------------------------
+ subroutine nlgw_latlon_gather(phys_state, phys, lonlat, gathered_lonlat)
+ use gw_nlgw_utils, only: nlon, nlat
+ use esmf_phys2lonlat_mod, only: fields_bundle_t, nflds, esmf_phys2lonlat_regrid
+ use gw_nlgw_utils, only: p0
+ use physconst, only: cappa
+ use mpishorthand
+
+ type(physics_state), intent(in) :: phys_state(begchunk:endchunk)
+
+ type(phys_vars), intent(inout), target :: phys
+ type(lonlat_vars), intent(inout), target :: lonlat
+ type(lonlat_vars), intent(inout), target :: gathered_lonlat
+
+ real(r8), allocatable :: flat_array(:)
+
+ integer :: lchnk, ncol, i, sendcnt, disp_sum
+
+ type(fields_bundle_t) :: physflds(nflds)
+ type(fields_bundle_t) :: lonlatflds(nflds)
+
+ call t_startf('nlgw_gather')
+
+
+ call t_startf('nlgw_unchunk')
+
+ do lchnk = begchunk,endchunk
+ ncol = phys_state(lchnk)%ncol
+ do i = 1,ncol
+        ! wind components and vertical pressure velocity (omega)
+ phys%u(:,i,lchnk) = phys_state(lchnk)%u(i,:)
+ phys%v(:,i,lchnk) = phys_state(lchnk)%v(i,:)
+ phys%w(:,i,lchnk) = phys_state(lchnk)%omega(i,:)
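+        ! potential temperature: theta = T * (p0/pmid)**cappa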
+ phys%theta(:,i,lchnk) = phys_state(lchnk)%t(i,:) * (p0 / phys_state(lchnk)%pmid(i,:)) ** cappa
+
+ ! for debugging only
+ if (debug) then
+ phys%pmid(:,i,lchnk) = phys_state(lchnk)%pmid(i,:)
+ phys%lat(i,lchnk) = phys_state(lchnk)%lat(i)
+ phys%lon(i,lchnk) = phys_state(lchnk)%lon(i)
+ end if
+
+ end do
+ end do
+
+ call t_stopf('nlgw_unchunk')
+
+ call t_startf('nlgw_latlon_gather')
+
+ physflds(1)%fld => phys%u
+ physflds(2)%fld => phys%v
+ physflds(3)%fld => phys%w
+ physflds(4)%fld => phys%theta
+
+ lonlatflds(1)%fld => lonlat%u
+ lonlatflds(2)%fld => lonlat%v
+ lonlatflds(3)%fld => lonlat%w
+ lonlatflds(4)%fld => lonlat%theta
+
+ ! actual call to regrid to lon/lat grid
+ call esmf_phys2lonlat_regrid(physflds, lonlatflds)
+
+ call t_stopf('nlgw_latlon_gather')
+
+ call t_startf('nlgw_mpigather')
+ ! this subsection gathers all variables onto a single process
+
+ sendcnt = (endlon - beglon + 1) * (endlat - beglat + 1) * pver
+
+ ! mpi gather book-keeping
+ call mpigather(sendcnt, 1, mpiint, recvcnts, 1, mpiint, 0, mpicom)
+ call mpigather(beglat, 1, mpiint, beglats, 1, mpiint, 0, mpicom)
+ call mpigather(beglon, 1, mpiint, beglons, 1, mpiint, 0, mpicom)
+ call mpigather(endlat, 1, mpiint, endlats, 1, mpiint, 0, mpicom)
+ call mpigather(endlon, 1, mpiint, endlons, 1, mpiint, 0, mpicom)
+
+ if (masterproc) then
+ disp_sum = 0
+ do i = 1, npes
+ displs(i) = disp_sum
+ disp_sum = disp_sum + recvcnts(i)
+ end do
+ end if
+ allocate(flat_array(nlon * nlat * pver))
+
+ call gather_3d(lonlat%u(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, gathered_lonlat%u)
+ call gather_3d(lonlat%v(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, gathered_lonlat%v)
+ call gather_3d(lonlat%w(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, gathered_lonlat%w)
+ call gather_3d(lonlat%theta(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, gathered_lonlat%theta)
+
+ deallocate(flat_array)
+
+ call t_stopf('nlgw_mpigather')
+
+ call t_stopf('nlgw_gather')
+
+ end subroutine nlgw_latlon_gather
+
+ !-----------------------------------------------------------------------------
+ ! Utility function for gathering 2D data into a single array
+ !-----------------------------------------------------------------------------
+ subroutine gather_2d(local_array, sendcnt, flat_array, grid_out)
+ use mpishorthand
+ real(r8), intent(in) :: local_array(:,:) ! Local 2D array section
+ integer, intent(in) :: sendcnt
+ real(r8), intent(inout) :: flat_array(:) ! Flattened array for gathering
+ real(r8), allocatable, intent(inout) :: grid_out(:,:) ! Full gathered grid
+
+ integer :: i, lonsize, latsize
+
+    ! gather variables onto masterproc into a flat array (mpigatherv needs a contiguous 1D buffer)
+ call mpigatherv(local_array, sendcnt, mpir8, flat_array, recvcnts, displs, mpir8, 0, mpicom)
+
+ if (masterproc) then
+ do i = 1, npes
+ lonsize = endlons(i) - beglons(i) + 1
+ latsize = endlats(i) - beglats(i) + 1
+        ! reshape each rank's flattened data and copy its block into the single lon/lat grid
+        grid_out(beglons(i):endlons(i), beglats(i):endlats(i)) = &
+          reshape(flat_array(displs(i)+1:displs(i)+lonsize*latsize), (/ lonsize, latsize /))
+ end do
+ end if
+ end subroutine gather_2d
+
+ !-----------------------------------------------------------------------------
+ ! Utility function for gathering 3D data into a single array
+ !-----------------------------------------------------------------------------
+ subroutine gather_3d(local_array, sendcnt, flat_array, grid_out)
+ use mpishorthand
+ real(r8), intent(in) :: local_array(:,:,:) ! Local 3D array section
+ integer, intent(in) :: sendcnt
+ real(r8), intent(inout) :: flat_array(:) ! Flattened array for gathering
+ real(r8), allocatable, intent(inout) :: grid_out(:,:,:) ! Full gathered grid
+
+ integer :: i, lonsize, latsize
+
+    ! gather variables onto masterproc into a flat array (mpigatherv needs a contiguous 1D buffer)
+ call mpigatherv(local_array, sendcnt, mpir8, flat_array, recvcnts, displs, mpir8, 0, mpicom)
+
+ if (masterproc) then
+ do i = 1, npes
+ lonsize = endlons(i) - beglons(i) + 1
+ latsize = endlats(i) - beglats(i) + 1
+        ! reshape each rank's flattened data and copy its block into the single lon/lat grid
+        grid_out(beglons(i):endlons(i), beglats(i):endlats(i), 1:pver) = &
+          reshape(flat_array(displs(i)+1:displs(i)+recvcnts(i)), (/ lonsize, latsize, pver /))
+ end do
+ end if
+ end subroutine gather_3d
+
+ !-----------------------------------------------------------------------------
+ ! Utility function for scattering 3D data into lonlat arrays
+ !-----------------------------------------------------------------------------
+ subroutine scatter_3d(grid_in, sendcnt, flat_array, lonlat_out)
+ use mpishorthand
+    real(r8), allocatable, intent(in) :: grid_in(:,:,:) ! Full gathered grid (masterproc only)
+    integer, intent(in) :: sendcnt
+    real(r8), intent(inout) :: flat_array(:) ! temporary storage in flat array
+    real(r8), target, intent(inout) :: lonlat_out(:,:,:) ! this rank's lon/lat section
+
+ integer :: i, lonsize, latsize
+
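+    ! on masterproc, pack each rank's lon/lat block into the flat send buffer
+    ! at that rank's displacement (the inverse of the unpacking in gather_3d)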
+ if (masterproc) then
+ do i = 1, npes
+ lonsize = endlons(i) - beglons(i) + 1
+ latsize = endlats(i) - beglats(i) + 1
+        flat_array(displs(i)+1:displs(i)+recvcnts(i)) = &
+          reshape(grid_in(beglons(i):endlons(i), beglats(i):endlats(i), 1:pver), (/ lonsize * latsize * pver /))
+ end do
+ end if
+
+ ! scatter variables from flat_array back to all processes
+ call mpiscatterv(flat_array, recvcnts, displs, mpir8, lonlat_out, sendcnt, mpir8, 0, mpicom)
+
+ end subroutine scatter_3d
+
+ !-----------------------------------------------------------------------------
+ ! Tidy up (free allocated memory)
+ !-----------------------------------------------------------------------------
+ subroutine nlgw_regrid_final(phys, lonlat, gathered_lonlat)
+ use esmf_phys2lonlat_mod, only: esmf_phys2lonlat_destroy
+ use esmf_lonlat2phys_mod, only: esmf_lonlat2phys_destroy
+ use esmf_lonlat_grid_mod, only: esmf_lonlat_grid_destroy
+ use esmf_phys_mesh_mod, only: esmf_phys_mesh_destroy
+
+ type(phys_vars), intent(inout), target :: phys
+ type(lonlat_vars), intent(inout), target :: lonlat
+ type(lonlat_vars), intent(inout), target :: gathered_lonlat
+
+ call esmf_phys2lonlat_destroy()
+ call esmf_lonlat2phys_destroy()
+ call esmf_lonlat_grid_destroy()
+ call esmf_phys_mesh_destroy()
+
+ ! TODO double check ALL deallocates here
+ if (masterproc) then
+ deallocate(gathered_lonlat%u)
+ deallocate(gathered_lonlat%v)
+ deallocate(gathered_lonlat%w)
+ deallocate(gathered_lonlat%theta)
+ deallocate(gathered_lonlat%uflux)
+ deallocate(gathered_lonlat%vflux)
+ end if
+
+ deallocate(phys%u)
+ deallocate(phys%v)
+ deallocate(phys%theta)
+ deallocate(phys%w)
+ deallocate(phys%uflux)
+ deallocate(phys%vflux)
+ deallocate(phys%utgw)
+ deallocate(phys%vtgw)
+
+ deallocate(lonlat%u)
+ deallocate(lonlat%v)
+ deallocate(lonlat%w)
+ deallocate(lonlat%theta)
+ deallocate(lonlat%uflux)
+ deallocate(lonlat%vflux)
+
+ if (debug) then
+ deallocate(phys%pmid)
+ deallocate(phys%lon)
+ deallocate(phys%lat)
+ end if
+
+ deallocate(recvcnts)
+ deallocate(displs)
+ deallocate(beglats)
+ deallocate(beglons)
+ deallocate(endlats)
+ deallocate(endlons)
+ end subroutine nlgw_regrid_final
+
+end module nlgw_remap_mod