diff --git a/README.md b/README.md index eb22f4e47b..e9a5f69082 100644 --- a/README.md +++ b/README.md @@ -120,15 +120,26 @@ gw_convect_dp_ml_norms='/path/to/norms' To run CAM using the non local gravity wave ML model to replace all parameterisations use the following configuration ```fortran -use_gw_nlgw=.true. -gw_nlgw_model_path='/path/to/nlgw-scripted-model.pt' +use_gw_nlgw_ann=.true. +use_gw_nlgw_unet=.true. +gw_nlgw_model_path_ann='/path/to/ann-scripted-model.pt' +gw_nlgw_model_path_unet='/path/to/unet-scripted-model.pt' ``` -* `use_gw_nlgw` (`logical`) +* `use_gw_nlgw_ann` (`logical`) - Whether or not to use the ML scheme for non local gravity waves. Default: `.false.` + Whether or not to use the ANN ML scheme for non local gravity waves. Default: `.false.` -* `gw_nlgw_model_path` +* `gw_nlgw_model_path_ann` + + Absolute filepath to the non local gravity wave neural net used when `use_gw_nlgw` is set to `.true.` (`.pt` + extension). + +* `use_gw_nlgw_unet` (`logical`) + + Whether or not to use the UNET ML scheme for non local gravity waves. Default: `.false.` + +* `gw_nlgw_model_path_unet` Absolute filepath to the non local gravity wave neural net used when `use_gw_nlgw` is set to `.true.` (`.pt` extension). diff --git a/bld/build-namelist b/bld/build-namelist index 0ffb144b48..9a8e1b88dc 100755 --- a/bld/build-namelist +++ b/bld/build-namelist @@ -3606,7 +3606,8 @@ if (!$simple_phys) { add_default($nl, 'use_gw_rdg_gamma' , 'val'=>'.false.'); add_default($nl, 'use_gw_front_igw' , 'val'=>'.false.'); add_default($nl, 'use_gw_convect_sh', 'val'=>'.false.'); - add_default($nl, 'use_gw_nlgw' , 'val'=>'.false.'); + add_default($nl, 'use_gw_nlgw_ann' , 'val'=>'.false.'); + add_default($nl, 'use_gw_nlgw_unet' , 'val'=>'.false.'); add_default($nl, 'gw_lndscl_sgh'); add_default($nl, 'gw_oro_south_fac'); add_default($nl, 'gw_limit_tau_without_eff'); diff --git a/bld/namelist_files/namelist_definition.xml b/bld/namelist_files/namelist_definition.xml index 52e8313b21..b783f836d7 100644 --- a/bld/namelist_files/namelist_definition.xml +++ b/bld/namelist_files/namelist_definition.xml @@ -1332,10 +1332,17 @@ Whether or not to enable gravity waves produced by shallow convection. Default: .false. - Whether or not to enable gravity waves produced by non-local gravity -wave ML model. +wave ANN ML model. +Default: set by build-namelist. + + + +Whether or not to enable gravity waves produced by non-local gravity +wave UNet ML model. Default: set by build-namelist. @@ -1428,10 +1435,16 @@ Absolute filepath to the deep convection gravity wave neural net used when Default: .false. - +Absolute filepath to the non local gravity wave traced model (.pt) +used when `use_gw_nlgw_ann` is set to `.true.`. + + + Absolute filepath to the non local gravity wave traced model (.pt) -used when `use_gw_nlgw` is set to `.true.`. +used when `use_gw_nlgw_unet` is set to `.true.`. masterprocid, mpi_real8, & mpi_character, mpi_logical, mpi_integer use gw_rdg, only: gw_rdg_readnl + use gw_nlgw_utils, only: gw_nlgw_model_path_ann, gw_nlgw_model_path_unet ! File containing namelist input. character(len=*), intent(in) :: nlfile @@ -248,7 +249,7 @@ subroutine gw_drag_readnl(nlfile) gw_top_taper, front_gaussian_width, & gw_convect_dp_ml, gw_convect_dp_ml_compare, & gw_convect_dp_ml_net_path, gw_convect_dp_ml_norms, & - gw_nlgw_model_path + gw_nlgw_model_path_ann, gw_nlgw_model_path_unet !---------------------------------------------------------------------- if (use_simple_phys) return @@ -364,8 +365,11 @@ subroutine gw_drag_readnl(nlfile) call mpi_bcast(gw_convect_dp_ml_norms, len(gw_convect_dp_ml_norms), mpi_character, mstrid, mpicom, ierr) if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: gw_convect_dp_ml_norms") - call mpi_bcast(gw_nlgw_model_path, len(gw_nlgw_model_path), mpi_character, mstrid, mpicom, ierr) - if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: gw_nlgw_model_path") + call mpi_bcast(gw_nlgw_model_path_ann, len(gw_nlgw_model_path_ann), mpi_character, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: gw_nlgw_model_path_ann") + + call mpi_bcast(gw_nlgw_model_path_unet, len(gw_nlgw_model_path_unet), mpi_character, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: gw_nlgw_model_path_unet") ! Check if fcrit2 was set. call shr_assert(fcrit2 /= unset_r8, & @@ -425,6 +429,7 @@ subroutine gw_init() use gw_common, only: gw_common_init use gw_front, only: gaussian_cm_desc + use gw_nlgw_utils, only: gw_nlgw_model_path_ann !---------------------------Local storage------------------------------- @@ -577,8 +582,8 @@ subroutine gw_init() call shr_assert(trim(errstring) == "", "gw_common_init: "//errstring// & errMsg(__FILE__, __LINE__)) - if ( use_gw_nlgw ) then - call gw_nlgw_dp_init(gw_nlgw_model_path) + if ( use_gw_nlgw_ann ) then + call gw_nlgw_ann_init(gw_nlgw_model_path_ann) end if if ( use_gw_oro ) then @@ -1298,8 +1303,8 @@ subroutine gw_final() if ((gw_convect_dp_ml) .or. (gw_convect_dp_ml_compare)) then call gw_drag_convect_dp_ml_final() endif - if ( use_gw_nlgw ) then - call gw_nlgw_dp_finalize() + if ( use_gw_nlgw_ann ) then + call gw_nlgw_ann_finalize() end if end subroutine gw_final @@ -1548,8 +1553,8 @@ subroutine gw_tend(state, pbuf, dt, ptend, cam_in, flx_heat) egwdffi_tot = 0._r8 flx_heat = 0._r8 - if ( use_gw_nlgw ) then - call gw_nlgw_dp_ml(state1,ptend) + if ( use_gw_nlgw_ann ) then + call gw_nlgw_ann_infer(state1,ptend) end if if (use_gw_convect_dp) then diff --git a/src/physics/cam/gw_nlgw.F90 b/src/physics/cam/gw_nlgw_ann.F90 similarity index 91% rename from src/physics/cam/gw_nlgw.F90 rename to src/physics/cam/gw_nlgw_ann.F90 index 647b771e68..fc61d9bb60 100644 --- a/src/physics/cam/gw_nlgw.F90 +++ b/src/physics/cam/gw_nlgw_ann.F90 @@ -1,4 +1,4 @@ -module gw_nlgw +module gw_nlgw_ann ! ! This module predicts gravity wave forcings via PyTorch NNs trained to include non-local gravity wave effects @@ -11,18 +11,17 @@ module gw_nlgw use cam_abortutils, only: endrun use cam_logfile, only: iulog use physconst, only: cappa, pi +use gw_nlgw_utils, only: p0 use interpolate_data, only: lininterp use ftorch implicit none -public :: gw_nlgw_dp_ml, gw_nlgw_dp_init, gw_nlgw_dp_finalize +public :: gw_nlgw_ann_infer, gw_nlgw_ann_init, gw_nlgw_ann_finalize private -integer, parameter :: p0 = 100000 ! 1000 hPa (Pa) - type(torch_model) :: nlgw_model ! pytorch model integer :: ncol ! number of vertical columns @@ -104,7 +103,9 @@ module gw_nlgw !========================================================================== -subroutine gw_nlgw_dp_ml(state_in, ptend) +subroutine gw_nlgw_ann_infer(state_in, ptend) + + use gw_nlgw_utils, only: flux_to_forcing ! inputs type(physics_state), intent(in) :: state_in @@ -170,8 +171,8 @@ subroutine gw_nlgw_dp_ml(state_in, ptend) call extract_output() call denormalise_data() - call flux_to_forcing(uflux, utgw) - call flux_to_forcing(vflux, vtgw) + call flux_to_forcing(uflux, utgw, pmid, ncol) + call flux_to_forcing(vflux, vtgw, pmid, ncol) ! update the tendencies ptend%u(:ncol,:pver) = ptend%u(:ncol,:pver) + utgw(:ncol,:pver) @@ -200,10 +201,10 @@ subroutine gw_nlgw_dp_ml(state_in, ptend) deallocate(net_inputs) deallocate(net_outputs) -end subroutine gw_nlgw_dp_ml +end subroutine gw_nlgw_ann_infer -subroutine gw_nlgw_dp_init(model_path) +subroutine gw_nlgw_ann_init(model_path) character(len=*), intent(in) :: model_path ! Filepath to PyTorch Torchscript net integer :: device_id @@ -219,17 +220,17 @@ subroutine gw_nlgw_dp_init(model_path) write(iulog,*)'nlgw model loaded from: ', model_path endif -end subroutine gw_nlgw_dp_init +end subroutine gw_nlgw_ann_init -subroutine gw_nlgw_dp_finalize() +subroutine gw_nlgw_ann_finalize() deallocate(net_inputs) deallocate(net_outputs) ! free model memory call torch_delete(nlgw_model) -end subroutine gw_nlgw_dp_finalize +end subroutine gw_nlgw_ann_finalize subroutine read_norms() @@ -259,6 +260,7 @@ subroutine read_norms() end subroutine read_norms subroutine normalise_data() + use gw_nlgw_utils, only: cbrt ! lat lon are in radians (convert to degrees first) lat = lat * 180. / pi @@ -363,31 +365,4 @@ subroutine denormalise_data() end subroutine denormalise_data -elemental function cbrt(a) result(root) - real(r8), intent(in) :: a - real(r8), parameter :: one_third = 1._r8/3._r8 - real(r8) :: root - root = sign(abs(a)**one_third, a) -end function cbrt - -subroutine flux_to_forcing(flux, forcing) - - real(r8), intent(in), dimension(:,:) :: flux - real(r8), intent(out), dimension(:,:) :: forcing ! forcing = -d(u'\omega')/d(p), units = m/s^2 - - integer :: level, col - - ! convert fluxes to tendencies - ! pressure profile must be in Pascals - - do col = 1, ncol - forcing(col,1) = -1*(flux(col,2) - flux(col,1))/(pmid(col,2) - pmid(col,1)) - do level = 2, pver-1 - forcing(col,level) = (flux(col,level+1) - flux(col,level-1)) / (pmid(col,level)*(log(pmid(col,level+1)) - log(pmid(col,level-1)))) - end do - forcing(col,pver) = -1*(flux(col,pver) - flux(col,pver-1)) / (pmid(col,pver) - pmid(col,pver-1)) - end do - -end subroutine flux_to_forcing - -end module gw_nlgw +end module gw_nlgw_ann diff --git a/src/physics/cam/gw_nlgw_unet.F90 b/src/physics/cam/gw_nlgw_unet.F90 new file mode 100644 index 0000000000..c5be68ab93 --- /dev/null +++ b/src/physics/cam/gw_nlgw_unet.F90 @@ -0,0 +1,219 @@ +module gw_nlgw_unet + +! +! This module predicts gravity wave forcings via PyTorch NNs trained to include non-local gravity wave effects +! + +use gw_utils, only: r8, r4 +use ppgrid, only: pver !vertical levels +use physics_types, only: physics_state, physics_ptend +use spmd_utils, only: mpicom, mstrid=>masterprocid, masterproc, mpi_real8, iam +use cam_abortutils, only: endrun +use cam_logfile, only: iulog +use cam_history, only: outfld, addfld +use physconst, only: cappa +use gw_nlgw_utils, only: lonlat_vars, nlon, nlat + +use ftorch + +implicit none + +public :: gw_nlgw_unet_init, gw_nlgw_unet_infer, gw_nlgw_unet_finalize, gw_nlgw_unet_set_ptend + +private + +type(torch_model) :: nlgw_model ! pytorch model + +real(r4), dimension(:,:,:,:), allocatable, target :: net_inputs +real(r4), dimension(:,:,:,:), allocatable, target :: net_outputs + +! normalisation means and std devs +real(r8) :: u_mean, v_mean, w_mean, theta_mean +real(r8) :: u_std, v_std, w_std, theta_std + +real(r8) :: uflux_mean, vflux_mean +real(r8) :: uflux_std, vflux_std + +contains + +!========================================================================== + + +subroutine gw_nlgw_unet_init(model_path) + + character(len=*), intent(in) :: model_path ! Filepath to PyTorch Torchscript net + integer :: device_id + + device_id = 0 + + ! Load the convective drag net from TorchScript file + call torch_model_load(nlgw_model, model_path, device_type=torch_kCUDA, device_index=device_id) + ! read in normalisation weights + call read_norms() + + if (masterproc) then + write(iulog,*)'nlgw model loaded from: ', model_path + + ! UNet will only run on the master process + ! space for u, v, theta and w + allocate(net_inputs(1, pver*4, nlat, nlon)) + ! space for uflux and vflux + allocate(net_outputs(1, pver*2, nlat, nlon)) + endif + + +end subroutine gw_nlgw_unet_init + +subroutine gw_nlgw_unet_infer(gathered_lonlat) + + ! global unet data + type(lonlat_vars), intent(inout) :: gathered_lonlat + + !---------------------------Local storage------------------------------- + type(torch_tensor) :: tensor_in(1), tensor_out(1) + integer :: ninputs = 1, noutputs = 1 + integer, dimension(4) :: layout = [1 , 2, 3, 4] + + integer :: device_id + + device_id = 0 + + ! Normalise and construct the input + call normalise_data(gathered_lonlat) + call construct_input(gathered_lonlat) + + ! send all columns from this process + call torch_tensor_from_array(tensor_in(1), net_inputs, layout, torch_kCUDA, device_id) + call torch_tensor_from_array(tensor_out(1), net_outputs, layout, torch_kCPU) + + ! Run net forward on data + call torch_model_forward(nlgw_model, tensor_in, tensor_out) + + ! Extract and denormalise outputs + call extract_output(gathered_lonlat) + call denormalise_data(gathered_lonlat) + + ! Clean up the tensors + call torch_delete(tensor_in) + call torch_delete(tensor_out) + +end subroutine gw_nlgw_unet_infer + +subroutine gw_nlgw_unet_finalize() + + if (masterproc) then + deallocate(net_inputs) + deallocate(net_outputs) + end if + ! free model memory + call torch_delete(nlgw_model) + +end subroutine gw_nlgw_unet_finalize + +subroutine gw_nlgw_unet_set_ptend(phys_state, ptend, utgw, vtgw) + ! initiailize and update tendencies + use physconst ,only: cpair + use physics_types,only: physics_state,physics_ptend,physics_ptend_init + use constituents ,only: cnst_get_ind,pcnst + use ppgrid ,only: pver,pcols,begchunk,endchunk + use cam_history ,only: outfld + + type(physics_state), intent(in) :: phys_state + type(physics_ptend), intent(out):: ptend + real(r8), dimension(pcols,pver), intent(in) :: utgw, vtgw + + ! local vars + integer indw,ncol,lchnk + logical lq(pcnst) + + call cnst_get_ind('Q',indw) + lq(:) =.false. + lq(indw)=.true. + call physics_ptend_init(ptend,phys_state%psetcols,'nlgw_unet',lu=.true.,lv=.true.,ls=.true.,lq=lq) + + lchnk=phys_state%lchnk + ncol =phys_state%ncol + ptend%u(:ncol,:pver) = ptend%u(:ncol,:pver) + utgw(:ncol,:pver) + ptend%v(:ncol,:pver) = ptend%v(:ncol,:pver) + vtgw(:ncol,:pver) + ! ptend%s(:ncol,:pver) = cb24cnn_Sstep(:ncol,:pver,lchnk)*.35 + ! ptend%q(:ncol,:pver,indw)= cb24cnn_Qstep(:ncol,:pver,lchnk)*.35 + + call outfld('UTGW_NL', utgw(:ncol,:pver), ncol, lchnk) + call outfld('VTGW_NL', vtgw(:ncol,:pver), ncol, lchnk) +end subroutine gw_nlgw_unet_set_ptend + +subroutine read_norms() + + ! TODO + ! - replace hardcoded means/std devs with netcdf file? + + u_mean = 6.717847278462159_r8 + v_mean = -0.002744777264668839_r8 + theta_mean = 0._r8 + w_mean = 0.0013401482063147452_r8 + + u_std = 20.760385183200206_r8 + v_std = 9.877389116738264_r8 + theta_std = 1000._r8 + w_std = 0.11202126259282257_r8 + + uflux_mean = -0.0004691528666736032_r8 + vflux_mean = -0.0002586195082961397_r8 + uflux_std = 0.032814051953840274_r8 + vflux_std = 0.03024781201672967_r8 + +end subroutine read_norms + +subroutine normalise_data(gathered_lonlat) + use gw_nlgw_utils, only: cbrt + type(lonlat_vars), intent(inout) :: gathered_lonlat + + gathered_lonlat%u = (gathered_lonlat%u-u_mean)/(3._r8 * u_std) + gathered_lonlat%v = (gathered_lonlat%v-v_mean)/(3._r8 * v_std) + gathered_lonlat%theta = (gathered_lonlat%theta-theta_mean)/theta_std + gathered_lonlat%w = (gathered_lonlat%w-w_mean)/w_std + gathered_lonlat%w = cbrt(gathered_lonlat%w) + +end subroutine normalise_data + +subroutine construct_input(gathered_lonlat) + + type(lonlat_vars), intent(inout) :: gathered_lonlat + integer :: idx_beg, idx_end, i + + idx_end = 0 + + idx_beg = idx_end + 1 + idx_end = idx_end + pver + net_inputs(1,idx_beg:idx_end,:,:) = reshape(gathered_lonlat%u, shape=[pver, nlat, nlon], order=[3,2,1]) + idx_beg = idx_end + 1 + idx_end = idx_end + pver + net_inputs(1,idx_beg:idx_end,:,:) = reshape(gathered_lonlat%v, shape=[pver, nlat, nlon], order=[3,2,1]) + idx_beg = idx_end + 1 + idx_end = idx_end + pver + net_inputs(1,idx_beg:idx_end,:,:) = reshape(gathered_lonlat%theta, shape=[pver, nlat, nlon], order=[3,2,1]) + idx_beg = idx_end + 1 + idx_end = idx_end + pver + net_inputs(1,idx_beg:idx_end,:,:) = reshape(gathered_lonlat%w, shape=[pver, nlat, nlon], order=[3,2,1]) + +end subroutine construct_input + +subroutine extract_output(gathered_lonlat) + + type(lonlat_vars), intent(inout) :: gathered_lonlat + + gathered_lonlat%uflux(:,:,:) = reshape(net_outputs(1,:pver,:,:), shape=[nlon, nlat, pver], order=[3,2,1]) + gathered_lonlat%vflux(:,:,:) = reshape(net_outputs(1,pver+1:,:,:), shape=[nlon, nlat, pver], order=[3,2,1]) + +end subroutine extract_output + +subroutine denormalise_data(gathered_lonlat) + + type(lonlat_vars), intent(inout) :: gathered_lonlat + + gathered_lonlat%uflux = gathered_lonlat%uflux**3._r8 * uflux_std + uflux_mean + gathered_lonlat%vflux = gathered_lonlat%vflux**3._r8 * vflux_std + vflux_mean + +end subroutine denormalise_data + +end module gw_nlgw_unet diff --git a/src/physics/cam/gw_nlgw_utils.F90 b/src/physics/cam/gw_nlgw_utils.F90 new file mode 100644 index 0000000000..b16b5d0fbf --- /dev/null +++ b/src/physics/cam/gw_nlgw_utils.F90 @@ -0,0 +1,88 @@ +module gw_nlgw_utils + +use gw_utils, only: r8, r4 +use ppgrid, only: begchunk, endchunk, pcols, pver, pverp + +implicit none + +public :: cbrt, flux_to_forcing +public :: phys_vars, lonlat_vars +integer, parameter, public :: p0 = 100000 ! 1000 hPa (Pa) +integer, parameter, public :: nlon = 288 ! number of longitude points on lonlat grid +integer, parameter, public :: nlat = 192 ! number of latitude points on lonlat grid +character(len=132), public :: gw_nlgw_model_path_ann +character(len=132), public :: gw_nlgw_model_path_unet + +private + +! variables on cubed-sphere "phys" grid +type phys_vars +!dimension(pver,pcols,begchunk:endchunk) +real(r8), dimension(:,:,:), allocatable :: & + u, &! zonal wind (m/s) + v, &! meridional wind (m/s) + theta, &! temperature (K) + w, &! vertical pressure velocity (Pa/s) + pmid ! midpoint pressure (Pa) + +real(r8), dimension(:,:,:), allocatable :: & + uflux, &! zonal fluxes + vflux ! meridional fluxes + +real(r8), dimension(:,:,:), allocatable :: & + utgw, &! zonal tendencies + vtgw ! meridional tendencies + +! for debugging only +! dimension(pcols,begchunk:endchunk) +real(r8), dimension(:,:), allocatable :: & + lat, & + lon +end type + +! variables on regular lonlat grid +type lonlat_vars +! dimension(lon,lat,pver) +real(r8), dimension(:,:,:), allocatable :: & + u, &! zonal wind (m/s) + v, &! meridional wind (m/s) + theta, &! temperature (K) + w ! vertical pressure velocity (Pa/s) + +real(r8), dimension(:,:,:), allocatable :: & + uflux, &! zonal fluxes + vflux ! meridional fluxes +end type + +contains + +elemental function cbrt(a) result(root) + real(r8), intent(in) :: a + real(r8), parameter :: one_third = 1._r8/3._r8 + real(r8) :: root + root = sign(abs(a)**one_third, a) +end function cbrt + +subroutine flux_to_forcing(flux, forcing, pmid, ncol) + + real(r8), intent(in), dimension(:,:) :: flux ! flux (Pa m/s^2) !TODO check with Aman + real(r8), intent(in), dimension(:,:) :: pmid ! midpoint pressure (Pa) + real(r8), intent(out), dimension(:,:) :: forcing ! forcing = -d(u'\omega')/d(p), units = m/s^2 + integer, intent(in) :: ncol + + integer :: level, col + + ! convert fluxes to tendencies + ! pressure profile must be in Pascals + + do col = 1, ncol + forcing(col,1) = -1*(flux(col,2) - flux(col,1))/(pmid(col,2) - pmid(col,1)) + do level = 2, pver-1 + forcing(col,level) = (flux(col,level+1) - flux(col,level-1)) / (pmid(col,level)*(log(pmid(col,level+1)) - log(pmid(col,level-1)))) + end do + forcing(col,pver) = -1*(flux(col,pver) - flux(col,pver-1)) / (pmid(col,pver) - pmid(col,pver-1)) + end do + +end subroutine flux_to_forcing + +end module gw_nlgw_utils diff --git a/src/physics/cam/phys_control.F90 b/src/physics/cam/phys_control.F90 index 7497385583..93263e29a3 100644 --- a/src/physics/cam/phys_control.F90 +++ b/src/physics/cam/phys_control.F90 @@ -98,7 +98,8 @@ module phys_control logical, public, protected :: use_gw_front_igw = .false. ! Frontogenesis to inertial spectrum. logical, public, protected :: use_gw_convect_dp = .false. ! Deep convection. logical, public, protected :: use_gw_convect_sh = .false. ! Shallow convection. -logical, public, protected :: use_gw_nlgw = .false. ! non local GW ML model +logical, public, protected :: use_gw_nlgw_ann = .false. ! non local GW ML model (ANN - single column) +logical, public, protected :: use_gw_nlgw_unet = .false. ! non local GW ML model (UNet - global non local) ! FV dycore angular momentum correction logical, public, protected :: fv_am_correction = .false. @@ -137,7 +138,7 @@ subroutine phys_ctl_readnl(nlfile) history_waccmx, history_chemistry, history_carma, history_clubb, history_dust, & history_cesm_forcing, history_scwaccm_forcing, history_chemspecies_srf, & do_clubb_sgs, state_debug_checks, use_hetfrz_classnuc, use_gw_oro, use_gw_front, & - use_gw_front_igw, use_gw_convect_dp, use_gw_convect_sh, use_gw_nlgw, cld_macmic_num_steps, & + use_gw_front_igw, use_gw_convect_dp, use_gw_convect_sh, use_gw_nlgw_ann, use_gw_nlgw_unet, cld_macmic_num_steps, & offline_driver, convproc_do_aer, cam_snapshot_before_num, cam_snapshot_after_num, & cam_take_snapshot_before, cam_take_snapshot_after, cam_physics_mesh, use_hemco, do_hb_above_clubb !----------------------------------------------------------------------------- @@ -157,7 +158,10 @@ subroutine phys_ctl_readnl(nlfile) end if ! if we are using the nlgw ML model we need to disable all other parameterizations - if (masterproc .and. use_gw_nlgw==.true.) then + if (masterproc .and. (use_gw_nlgw_ann .or. use_gw_nlgw_unet)) then + if (use_gw_nlgw_ann .and. use_gw_nlgw_unet) then + call endrun(subname // ':: ERROR you can only select UNet or ANN not both.') + end if use_gw_oro = .false. use_gw_front = .false. use_gw_front_igw = .false. @@ -203,7 +207,8 @@ subroutine phys_ctl_readnl(nlfile) call mpi_bcast(use_gw_front_igw, 1, mpi_logical, masterprocid, mpicom, ierr) call mpi_bcast(use_gw_convect_dp, 1, mpi_logical, masterprocid, mpicom, ierr) call mpi_bcast(use_gw_convect_sh, 1, mpi_logical, masterprocid, mpicom, ierr) - call mpi_bcast(use_gw_nlgw, 1, mpi_logical, masterprocid, mpicom, ierr) + call mpi_bcast(use_gw_nlgw_ann, 1, mpi_logical, masterprocid, mpicom, ierr) + call mpi_bcast(use_gw_nlgw_unet, 1, mpi_logical, masterprocid, mpicom, ierr) call mpi_bcast(cld_macmic_num_steps, 1, mpi_integer, masterprocid, mpicom, ierr) call mpi_bcast(offline_driver, 1, mpi_logical, masterprocid, mpicom, ierr) call mpi_bcast(convproc_do_aer, 1, mpi_logical, masterprocid, mpicom, ierr) diff --git a/src/physics/cam_dev/physpkg.F90 b/src/physics/cam_dev/physpkg.F90 index 1c461c9a1c..c5f0e15353 100644 --- a/src/physics/cam_dev/physpkg.F90 +++ b/src/physics/cam_dev/physpkg.F90 @@ -768,6 +768,7 @@ subroutine phys_init( phys_state, phys_tend, pbuf2d, cam_in, cam_out ) use cam_history, only: addfld, register_vector_field, add_default use cam_budget, only: cam_budget_init use phys_grid_ctem, only: phys_grid_ctem_init + use phys_control, only: use_gw_nlgw_unet use ccpp_constituent_prop_mod, only: ccpp_const_props_init @@ -949,6 +950,11 @@ subroutine phys_init( phys_state, phys_tend, pbuf2d, cam_in, cam_out ) end if + + if (use_gw_nlgw_unet) then + call addfld('UTGW_NL', (/ 'lev' /), 'A', 'm/s2', 'Nonlinear GW zonal wind tendency') + call addfld('VTGW_NL', (/ 'lev' /), 'A', 'm/s2', 'Nonlinear GW meridional wind tendency') + end if ! Initialize CAM CCPP constituent properties array ! for use in CCPP-ized physics schemes: call ccpp_const_props_init() @@ -1183,6 +1189,12 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, & use metdata, only: get_met_srf2 #endif use hemco_interface, only: HCOI_Chunk_Run + use nlgw_remap_mod, only: nlgw_regrid_init, nlgw_latlon_gather, nlgw_latlon_scatter, nlgw_regrid_final + use gw_nlgw_unet, only: gw_nlgw_unet_init, gw_nlgw_unet_infer, gw_nlgw_unet_finalize, gw_nlgw_unet_set_ptend + use gw_nlgw_utils, only: phys_vars, lonlat_vars, flux_to_forcing, gw_nlgw_model_path_unet + use phys_control, only: use_gw_nlgw_unet + use time_manager, only: get_nstep + use check_energy, only: check_energy_chng ! ! Input arguments ! @@ -1202,10 +1214,19 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, & ! integer :: c ! chunk index integer :: ncol ! number of columns + integer :: nstep ! current timestep number type(physics_buffer_desc),pointer, dimension(:) :: phys_buffer_chunk + + ! for ML UNet model + type(physics_ptend) :: ptend ! parameterization tendencies for nlgw + real(r8) :: zero(pcols) ! array of zeros + type(phys_vars), target :: phys + type(lonlat_vars), target :: lonlat, gathered_lonlat + real(r8), dimension(pcols,pver) :: temp_uflux, temp_vflux, temp_utgw, temp_vtgw ! ! If exit condition just return ! + nstep = get_nstep() if(single_column.and.scm_crm_mode) then call diag_deallocate() @@ -1244,6 +1265,40 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, & call t_startf ('ac_physics') call t_adj_detailf(+1) + call nlgw_regrid_init(phys, lonlat, gathered_lonlat) + + if (use_gw_nlgw_unet) then + ! gather data from all procs, all chunks into a global lonlat grid + call nlgw_latlon_gather(phys_state, phys, lonlat, gathered_lonlat) + + if (masterproc) then + call gw_nlgw_unet_init(gw_nlgw_model_path_unet) + ! run UNet model on globally gathered lonlat grid to compute fluxes + call gw_nlgw_unet_infer(gathered_lonlat) + call gw_nlgw_unet_finalize() + endif + + ! scatter back to all procs, into chunks and regrid back to phys grid + call nlgw_latlon_scatter(phys, lonlat, gathered_lonlat) + + do c = begchunk,endchunk + ncol = phys_state(c)%ncol + temp_uflux(:ncol,:pver) = transpose(phys%uflux(:pver,:ncol,c)) + temp_vflux(:ncol,:pver) = transpose(phys%vflux(:pver,:ncol,c)) + + ! compute tendencies from fluxes + call flux_to_forcing(temp_uflux, temp_utgw, phys_state(c)%pmid, ncol) + call flux_to_forcing(temp_vflux, temp_vtgw, phys_state(c)%pmid, ncol) + ! update ptend + call gw_nlgw_unet_set_ptend(phys_state(c), ptend, temp_utgw, temp_vtgw) + ! update state and check energy change + call physics_update(phys_state(c), ptend, ztodt, phys_tend(c)) + call check_energy_chng(phys_state(c), phys_tend(c), "nlgw_unet", nstep, ztodt, zero, zero, zero, zero) + end do + + call nlgw_regrid_final(phys, lonlat, gathered_lonlat) + endif + !$OMP PARALLEL DO PRIVATE (C, NCOL, phys_buffer_chunk) do c=begchunk,endchunk diff --git a/src/utils/cam_grid_support.F90 b/src/utils/cam_grid_support.F90 index d86c829e77..56d5c100f7 100644 --- a/src/utils/cam_grid_support.F90 +++ b/src/utils/cam_grid_support.F90 @@ -316,6 +316,7 @@ end subroutine print_attr_spec public :: cam_grid_compute_patch ! Functions for dealing with grid areas public :: cam_grid_get_areawt + public :: get_cam_grid_index interface cam_grid_attribute_register module procedure add_cam_grid_attribute_0d_int @@ -352,7 +353,6 @@ end subroutine print_attr_spec module procedure cam_grid_write_dist_array_3d_real end interface - ! Private interfaces interface get_cam_grid_index module procedure get_cam_grid_index_char ! For lookup by name module procedure get_cam_grid_index_int ! For lookup by ID diff --git a/src/utils/esmf_check_error_mod.F90 b/src/utils/esmf_check_error_mod.F90 new file mode 100644 index 0000000000..eb9db18a90 --- /dev/null +++ b/src/utils/esmf_check_error_mod.F90 @@ -0,0 +1,33 @@ +!------------------------------------------------------------------------------ +! ESMF error handler +!------------------------------------------------------------------------------ +module esmf_check_error_mod + use shr_kind_mod, only: cl=>SHR_KIND_CL + use spmd_utils, only: masterproc + use cam_logfile, only: iulog + use cam_abortutils, only: endrun + use ESMF, only: ESMF_SUCCESS + + implicit none + + private + public :: check_esmf_error + +contains + + subroutine check_esmf_error( rc, errmsg ) + + integer, intent(in) :: rc + character(len=*), intent(in) :: errmsg + + character(len=cl) :: errstr + + if (rc /= ESMF_SUCCESS) then + write(errstr,'(a,i6)') 'esmf_zonal_mod::'//trim(errmsg)//' -- ESMF ERROR code: ',rc + if (masterproc) write(iulog,*) trim(errstr) + call endrun(trim(errstr)) + end if + + end subroutine check_esmf_error + +end module esmf_check_error_mod diff --git a/src/utils/esmf_lonlat2phys_mod.F90 b/src/utils/esmf_lonlat2phys_mod.F90 new file mode 100644 index 0000000000..e66ed27635 --- /dev/null +++ b/src/utils/esmf_lonlat2phys_mod.F90 @@ -0,0 +1,160 @@ +!------------------------------------------------------------------------------ +! Provides methods for mapping from regular longitude / latitude grid +! to physics grid to via ESMF regridding capabilities +!------------------------------------------------------------------------------ +module esmf_lonlat2phys_mod + use shr_kind_mod, only: r8 => shr_kind_r8 + use cam_logfile, only: iulog + use cam_abortutils, only: endrun + use spmd_utils, only: masterproc + use ppgrid, only: pver + + use ESMF, only: ESMF_RouteHandle, ESMF_Field, ESMF_ArraySpec, ESMF_ArraySpecSet + use ESMF, only: ESMF_FieldCreate, ESMF_FieldRegridStore + use ESMF, only: ESMF_FieldGet, ESMF_FieldRegrid + use ESMF, only: ESMF_KIND_I4, ESMF_KIND_R8, ESMF_TYPEKIND_R8 + use ESMF, only: ESMF_REGRIDMETHOD_BILINEAR, ESMF_POLEMETHOD_ALLAVG, ESMF_EXTRAPMETHOD_NEAREST_IDAVG + use ESMF, only: ESMF_TERMORDER_SRCSEQ, ESMF_MESHLOC_ELEMENT, ESMF_STAGGERLOC_CENTER + use ESMF, only: ESMF_FieldDestroy, ESMF_RouteHandleDestroy + use esmf_check_error_mod, only: check_esmf_error + + implicit none + + private + + public :: esmf_lonlat2phys_init + public :: esmf_lonlat2phys_regrid + public :: esmf_lonlat2phys_destroy + public :: fields_bundle_t + public :: n_flx_flds + + type(ESMF_RouteHandle) :: rh_lonlat2phys_3d + + type(ESMF_Field) :: physfld_3d + type(ESMF_Field) :: lonlatfld_3d + + interface esmf_lonlat2phys_regrid + module procedure esmf_lonlat2phys_regrid_3d + end interface esmf_lonlat2phys_regrid + + type :: fields_bundle_t + real(r8), pointer :: fld(:,:,:) => null() + end type fields_bundle_t + + integer, parameter :: n_flx_flds = 2 ! 3D uflux and vflux + +contains + + !------------------------------------------------------------------------------ + !------------------------------------------------------------------------------ + subroutine esmf_lonlat2phys_init() + use esmf_phys_mesh_mod, only: physics_grid_mesh + use esmf_lonlat_grid_mod, only: lonlat_grid + + type(ESMF_ArraySpec) :: arrayspec + integer(ESMF_KIND_I4), pointer :: factorIndexList(:,:) + real(ESMF_KIND_R8), pointer :: factorList(:) + integer :: smm_srctermproc, smm_pipelinedep, rc + + character(len=*), parameter :: subname = 'esmf_lonlat2phys_init: ' + + smm_srctermproc = 0 + smm_pipelinedep = 16 + + ! create ESMF fields + + ! 3D phys fld + call ESMF_ArraySpecSet(arrayspec, 3, ESMF_TYPEKIND_R8, rc=rc) + call check_esmf_error(rc, subname//'ESMF_ArraySpecSet 3D phys fld ERROR') + + physfld_3d = ESMF_FieldCreate(physics_grid_mesh, arrayspec, & + gridToFieldMap=(/3/), meshloc=ESMF_MESHLOC_ELEMENT, & + ungriddedLBound=(/1,1/), ungriddedUBound=(/pver,n_flx_flds/), rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldCreate 3D phys fld ERROR') + + ! 3D lon lat grid + call ESMF_ArraySpecSet(arrayspec, 4, ESMF_TYPEKIND_R8, rc=rc) + call check_esmf_error(rc, subname//'ESMF_ArraySpecSet 3D lonlat fld ERROR') + + lonlatfld_3d = ESMF_FieldCreate( lonlat_grid, arrayspec, staggerloc=ESMF_STAGGERLOC_CENTER, & + ungriddedLBound=(/1,1/), ungriddedUBound=(/pver,n_flx_flds/), rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldCreate 3D lonlat fld ERROR') + + call ESMF_FieldRegridStore(srcField=lonlatfld_3d, dstField=physfld_3d, & + regridMethod=ESMF_REGRIDMETHOD_BILINEAR, & + polemethod=ESMF_POLEMETHOD_ALLAVG, & + extrapMethod=ESMF_EXTRAPMETHOD_NEAREST_IDAVG, & + routeHandle=rh_lonlat2phys_3d, factorIndexList=factorIndexList, & + factorList=factorList, srcTermProcessing=smm_srctermproc, & + pipelineDepth=smm_pipelinedep, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldRegridStore 3D routehandle ERROR') + + end subroutine esmf_lonlat2phys_init + + !------------------------------------------------------------------------------ + !------------------------------------------------------------------------------ + subroutine esmf_lonlat2phys_regrid_3d(lonlatflds, physflds) + use esmf_lonlat_grid_mod, only: lon_beg,lon_end,lat_beg,lat_end + use ppgrid, only: pcols, pver, begchunk, endchunk + use phys_grid, only: get_ncols_p + + type(fields_bundle_t), intent(in) :: lonlatflds(n_flx_flds) + type(fields_bundle_t), intent(inout) :: physflds(n_flx_flds) + + integer :: i, ichnk, ncol, ifld, ilev, icol, rc + real(ESMF_KIND_R8), pointer :: physptr(:,:,:) + real(ESMF_KIND_R8), pointer :: lonlatptr(:,:,:,:) + + character(len=*), parameter :: subname = 'esmf_lonlat2phys_regrid_3d: ' + + ! set values of lonlatfld_3d ESMF field + call ESMF_FieldGet(lonlatfld_3d, localDe=0, farrayPtr=lonlatptr, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldGet lonlatptr') + + do ifld = 1,n_flx_flds + lonlatptr(lon_beg:lon_end,lat_beg:lat_end,1:pver,ifld) = lonlatflds(ifld)%fld(lon_beg:lon_end,lat_beg:lat_end,1:pver) + end do + + ! regrid + call ESMF_FieldRegrid(lonlatfld_3d, physfld_3d, rh_lonlat2phys_3d, & + termorderflag=ESMF_TERMORDER_SRCSEQ, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldRegrid lonlatfld_3d->physfld_3d') + + ! get values from physfld_3d ESMF field + call ESMF_FieldGet(physfld_3d, localDe=0, farrayPtr=physptr, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldGet physptr') + + i = 0 + do ichnk = begchunk, endchunk + ncol = get_ncols_p(ichnk) + do icol = 1,ncol + i = i+1 + do ifld = 1,n_flx_flds + do ilev = 1,pver + physflds(ifld)%fld(ilev,icol,ichnk) = physptr(ilev,ifld,i) + end do + end do + end do + end do + + end subroutine esmf_lonlat2phys_regrid_3d + + !------------------------------------------------------------------------------ + !------------------------------------------------------------------------------ + subroutine esmf_lonlat2phys_destroy() + + integer :: rc + character(len=*), parameter :: subname = 'esmf_lonlat2phys_destroy: ' + + call ESMF_RouteHandleDestroy(rh_lonlat2phys_3d, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldDestroy rh_lonlat2phys_3d') + + call ESMF_FieldDestroy(lonlatfld_3d, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldDestroy lonlatfld_3d') + + call ESMF_FieldDestroy(physfld_3d, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldDestroy physfld_3d') + + end subroutine esmf_lonlat2phys_destroy + +end module esmf_lonlat2phys_mod diff --git a/src/utils/esmf_lonlat_grid_mod.F90 b/src/utils/esmf_lonlat_grid_mod.F90 new file mode 100644 index 0000000000..394322ddc6 --- /dev/null +++ b/src/utils/esmf_lonlat_grid_mod.F90 @@ -0,0 +1,339 @@ +!------------------------------------------------------------------------------- +! Encapsulates an ESMF regular longitude / latitude grid +!------------------------------------------------------------------------------- +module esmf_lonlat_grid_mod + use shr_kind_mod, only: r8 => shr_kind_r8 + use spmd_utils, only: masterproc, mpicom + use cam_logfile, only: iulog + use cam_abortutils, only: endrun + + use ESMF, only: ESMF_Grid, ESMF_GridCreate1PeriDim, ESMF_GridAddCoord + use ESMF, only: ESMF_GridGetCoord, ESMF_GridDestroy + use ESMF, only: ESMF_KIND_R8, ESMF_INDEX_GLOBAL, ESMF_STAGGERLOC_CENTER + use esmf_check_error_mod, only: check_esmf_error + + implicit none + + public + + type(ESMF_Grid), protected :: lonlat_grid + + integer, protected :: nlon = 0 + integer, protected :: nlat = 0 + + integer, protected :: lon_beg = -1 + integer, protected :: lon_end = -1 + integer, protected :: lat_beg = -1 + integer, protected :: lat_end = -1 + + real(r8), allocatable, protected :: glats(:) + real(r8), allocatable, protected :: glons(:) + + integer, protected :: zonal_comm ! zonal direction MPI communicator + +contains + + subroutine esmf_lonlat_grid_init(nlats_in, nlons_in) + use phys_grid, only: get_grid_dims + use mpi, only: mpi_comm_size, mpi_comm_rank, MPI_PROC_NULL, MPI_INTEGER + + integer, intent(in) :: nlats_in + integer, intent(in) :: nlons_in + + real(r8) :: delx, dely + + integer :: npes, ierr, mytid, irank, mytidi, mytidj + integer :: i,j, n + integer :: ntasks_lon, ntasks_lat + integer :: lons_per_task, lons_overflow, lats_per_task, lats_overflow + integer :: task_cnt + integer :: mynlats, mynlons + + integer, allocatable :: mytidi_send(:) + integer, allocatable :: mytidj_send(:) + integer, allocatable :: mytidi_recv(:) + integer, allocatable :: mytidj_recv(:) + + integer, allocatable :: nlons_send(:) + integer, allocatable :: nlats_send(:) + integer, allocatable :: nlons_recv(:) + integer, allocatable :: nlats_recv(:) + + integer, allocatable :: nlons_task(:) + integer, allocatable :: nlats_task(:) + + integer, parameter :: minlats_per_pe = 2 + integer, parameter :: minlons_per_pe = 2 + + integer, allocatable :: petmap(:,:,:) + integer :: petcnt, astat + + integer :: lbnd_lat, ubnd_lat, lbnd_lon, ubnd_lon + integer :: lbnd(1), ubnd(1) + real(ESMF_KIND_R8), pointer :: coordX(:), coordY(:) + + character(len=*), parameter :: subname = 'esmf_lonlat_grid_init: ' + + ! create reg lon lat grid + + nlat = nlats_in + dely = 180._r8/nlat + + nlon = nlons_in + delx = 360._r8/nlon + + allocate(glons(nlon), stat=astat) + if (astat/=0) then + call endrun(subname//'not able to allocate glons array') + end if + allocate(glats(nlat), stat=astat) + if (astat/=0) then + call endrun(subname//'not able to allocate glats array') + end if + + glons(1) = 0._r8 + glats(1) = -90._r8 + 0.5_r8 * dely + + do i = 2,nlon + glons(i) = glons(i-1) + delx + end do + do i = 2,nlat + glats(i) = glats(i-1) + dely + end do + + ! decompose the grid across mpi tasks ... + + call mpi_comm_size(mpicom, npes, ierr) + call mpi_comm_rank(mpicom, mytid, ierr) + + decomp_loop: do ntasks_lon = 1,nlon + ntasks_lat = npes/ntasks_lon + if ( (minlats_per_pe*ntasks_latmytid) exit jloop + end do + enddo jloop + endif + + mynlats = lat_end-lat_beg+1 + mynlons = lon_end-lon_beg+1 + + if (mynlats shr_kind_r8 + use cam_logfile, only: iulog + use cam_abortutils, only: endrun + use spmd_utils, only: masterproc + use ppgrid, only: pver + + use ESMF, only: ESMF_RouteHandle, ESMF_Field, ESMF_ArraySpec, ESMF_ArraySpecSet + use ESMF, only: ESMF_FieldCreate, ESMF_FieldRegridStore + use ESMF, only: ESMF_FieldGet, ESMF_FieldRegrid + use ESMF, only: ESMF_KIND_I4, ESMF_KIND_R8, ESMF_TYPEKIND_R8 + use ESMF, only: ESMF_REGRIDMETHOD_BILINEAR, ESMF_POLEMETHOD_ALLAVG, ESMF_EXTRAPMETHOD_NEAREST_IDAVG + use ESMF, only: ESMF_TERMORDER_SRCSEQ, ESMF_MESHLOC_ELEMENT, ESMF_STAGGERLOC_CENTER + use ESMF, only: ESMF_FieldDestroy, ESMF_RouteHandleDestroy + use esmf_check_error_mod, only: check_esmf_error + + implicit none + + private + + public :: esmf_phys2lonlat_init + public :: esmf_phys2lonlat_regrid + public :: esmf_phys2lonlat_destroy + public :: fields_bundle_t + public :: nflds + + type(ESMF_RouteHandle) :: rh_phys2lonlat_3d + type(ESMF_RouteHandle) :: rh_phys2lonlat_2d + + type(ESMF_Field) :: physfld_3d + type(ESMF_Field) :: lonlatfld_3d + + type(ESMF_Field) :: physfld_2d + type(ESMF_Field) :: lonlatfld_2d + + interface esmf_phys2lonlat_regrid + module procedure esmf_phys2lonlat_regrid_2d + module procedure esmf_phys2lonlat_regrid_3d + end interface esmf_phys2lonlat_regrid + + type :: fields_bundle_t + real(r8), pointer :: fld(:,:,:) => null() + end type fields_bundle_t + + integer, parameter :: nflds = 4 + +contains + + !------------------------------------------------------------------------------ + !------------------------------------------------------------------------------ + subroutine esmf_phys2lonlat_init() + use esmf_phys_mesh_mod, only: physics_grid_mesh + use esmf_lonlat_grid_mod, only: lonlat_grid + + type(ESMF_ArraySpec) :: arrayspec + integer(ESMF_KIND_I4), pointer :: factorIndexList(:,:) + real(ESMF_KIND_R8), pointer :: factorList(:) + integer :: smm_srctermproc, smm_pipelinedep, rc + + character(len=*), parameter :: subname = 'esmf_phys2lonlat_init: ' + + smm_srctermproc = 0 + smm_pipelinedep = 16 + + ! create ESMF fields + + ! 3D phys fld + call ESMF_ArraySpecSet(arrayspec, 3, ESMF_TYPEKIND_R8, rc=rc) + call check_esmf_error(rc, subname//'ESMF_ArraySpecSet 3D phys fld ERROR') + + physfld_3d = ESMF_FieldCreate(physics_grid_mesh, arrayspec, & + gridToFieldMap=(/3/), meshloc=ESMF_MESHLOC_ELEMENT, & + ungriddedLBound=(/1,1/), ungriddedUBound=(/pver,nflds/), rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldCreate 3D phys fld ERROR') + + ! 3D lon lat grid + call ESMF_ArraySpecSet(arrayspec, 4, ESMF_TYPEKIND_R8, rc=rc) + call check_esmf_error(rc, subname//'ESMF_ArraySpecSet 3D lonlat fld ERROR') + + lonlatfld_3d = ESMF_FieldCreate( lonlat_grid, arrayspec, staggerloc=ESMF_STAGGERLOC_CENTER, & + ungriddedLBound=(/1,1/), ungriddedUBound=(/pver,nflds/), rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldCreate 3D lonlat fld ERROR') + + ! 2D phys fld + call ESMF_ArraySpecSet(arrayspec, 1, ESMF_TYPEKIND_R8, rc=rc) + call check_esmf_error(rc, subname//'ESMF_ArraySpecSet 2D phys fld ERROR') + + physfld_2d = ESMF_FieldCreate(physics_grid_mesh, arrayspec, & + meshloc=ESMF_MESHLOC_ELEMENT, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldCreate 2D phys fld ERROR') + + ! 2D lon/lat grid + call ESMF_ArraySpecSet(arrayspec, 2, ESMF_TYPEKIND_R8, rc=rc) + call check_esmf_error(rc, subname//'ESMF_ArraySpecSet 2D lonlat fld ERROR') + + lonlatfld_2d = ESMF_FieldCreate( lonlat_grid, arrayspec, staggerloc=ESMF_STAGGERLOC_CENTER, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldCreate 2D lonlat fld ERROR') + + call ESMF_FieldRegridStore(srcField=physfld_3d, dstField=lonlatfld_3d, & + regridMethod=ESMF_REGRIDMETHOD_BILINEAR, & + polemethod=ESMF_POLEMETHOD_ALLAVG, & + extrapMethod=ESMF_EXTRAPMETHOD_NEAREST_IDAVG, & + routeHandle=rh_phys2lonlat_3d, factorIndexList=factorIndexList, & + factorList=factorList, srcTermProcessing=smm_srctermproc, & + pipelineDepth=smm_pipelinedep, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldRegridStore 3D routehandle ERROR') + + call ESMF_FieldRegridStore(srcField=physfld_2d, dstField=lonlatfld_2d, & + regridMethod=ESMF_REGRIDMETHOD_BILINEAR, & + polemethod=ESMF_POLEMETHOD_ALLAVG, & + extrapMethod=ESMF_EXTRAPMETHOD_NEAREST_IDAVG, & + routeHandle=rh_phys2lonlat_2d, factorIndexList=factorIndexList, & + factorList=factorList, srcTermProcessing=smm_srctermproc, & + pipelineDepth=smm_pipelinedep, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldRegridStore 3D routehandle ERROR') + + end subroutine esmf_phys2lonlat_init + + !------------------------------------------------------------------------------ + !------------------------------------------------------------------------------ + subroutine esmf_phys2lonlat_regrid_3d(physflds, lonlatflds) + use esmf_lonlat_grid_mod, only: lon_beg,lon_end,lat_beg,lat_end + use ppgrid, only: pcols, pver, begchunk, endchunk + use phys_grid, only: get_ncols_p + + type(fields_bundle_t), intent(in) :: physflds(nflds) + type(fields_bundle_t), intent(inout) :: lonlatflds(nflds) + + integer :: i, ichnk, ncol, ifld, ilev, icol, rc + real(ESMF_KIND_R8), pointer :: physptr(:,:,:) + real(ESMF_KIND_R8), pointer :: lonlatptr(:,:,:,:) + + character(len=*), parameter :: subname = 'esmf_phys2lonlat_regrid_3d: ' + + call ESMF_FieldGet(physfld_3d, localDe=0, farrayPtr=physptr, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldGet physptr') + + i = 0 + do ichnk = begchunk, endchunk + ncol = get_ncols_p(ichnk) + do icol = 1,ncol + i = i+1 + do ifld = 1,nflds + do ilev = 1,pver + physptr(ilev,ifld,i) = physflds(ifld)%fld(ilev,icol,ichnk) + end do + end do + end do + end do + + call ESMF_FieldRegrid(physfld_3d, lonlatfld_3d, rh_phys2lonlat_3d, & + termorderflag=ESMF_TERMORDER_SRCSEQ, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldRegrid physfld_3d->lonlatfld_3d') + + call ESMF_FieldGet(lonlatfld_3d, localDe=0, farrayPtr=lonlatptr, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldGet lonlatptr') + + do ifld = 1,nflds + lonlatflds(ifld)%fld(lon_beg:lon_end,lat_beg:lat_end,1:pver) = lonlatptr(lon_beg:lon_end,lat_beg:lat_end,1:pver,ifld) + end do + + end subroutine esmf_phys2lonlat_regrid_3d + + !------------------------------------------------------------------------------ + !------------------------------------------------------------------------------ + subroutine esmf_phys2lonlat_regrid_2d(physarr, lonlatarr) + use esmf_lonlat_grid_mod, only: lon_beg,lon_end,lat_beg,lat_end + use ppgrid, only: pcols, pver, begchunk, endchunk + use phys_grid, only: get_ncols_p + + real(r8),intent(in) :: physarr(pcols,begchunk:endchunk) + real(r8),intent(out) :: lonlatarr(lon_beg:lon_end,lat_beg:lat_end) + + integer :: i, ichnk, ncol, icol, rc + real(ESMF_KIND_R8), pointer :: physptr(:) + real(ESMF_KIND_R8), pointer :: lonlatptr(:,:) + + character(len=*), parameter :: subname = 'esmf_phys2lonlat_regrid_2d: ' + + call ESMF_FieldGet(physfld_2d, localDe=0, farrayPtr=physptr, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldGet physptr') + + i = 0 + do ichnk = begchunk, endchunk + ncol = get_ncols_p(ichnk) + do icol = 1,ncol + i = i+1 + physptr(i) = physarr(icol,ichnk) + end do + end do + + call ESMF_FieldRegrid(physfld_2d, lonlatfld_2d, rh_phys2lonlat_2d, & + termorderflag=ESMF_TERMORDER_SRCSEQ, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldRegrid physfld_3d->lonlatfld_3d') + + call ESMF_FieldGet(lonlatfld_2d, localDe=0, farrayPtr=lonlatptr, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldGet lonlatptr') + + lonlatarr(lon_beg:lon_end,lat_beg:lat_end) = lonlatptr(lon_beg:lon_end,lat_beg:lat_end) + + end subroutine esmf_phys2lonlat_regrid_2d + + !------------------------------------------------------------------------------ + !------------------------------------------------------------------------------ + subroutine esmf_phys2lonlat_destroy() + + integer :: rc + character(len=*), parameter :: subname = 'esmf_phys2lonlat_destroy: ' + + call ESMF_RouteHandleDestroy(rh_phys2lonlat_3d, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldDestroy rh_phys2lonlat_3d') + + call ESMF_RouteHandleDestroy(rh_phys2lonlat_2d, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldDestroy rh_phys2lonlat_2d') + + call ESMF_FieldDestroy(lonlatfld_3d, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldDestroy lonlatfld_3d') + + call ESMF_FieldDestroy(lonlatfld_2d, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldDestroy lonlatfld_2d') + + call ESMF_FieldDestroy(physfld_3d, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldDestroy physfld_3d') + + call ESMF_FieldDestroy(physfld_2d, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldDestroy physfld_2d') + + end subroutine esmf_phys2lonlat_destroy + +end module esmf_phys2lonlat_mod diff --git a/src/utils/esmf_phys_mesh_mod.F90 b/src/utils/esmf_phys_mesh_mod.F90 new file mode 100644 index 0000000000..4235452dc5 --- /dev/null +++ b/src/utils/esmf_phys_mesh_mod.F90 @@ -0,0 +1,202 @@ +!------------------------------------------------------------------------------- +! Encapsulates the CAM physics grid mesh +!------------------------------------------------------------------------------- +module esmf_phys_mesh_mod + use shr_kind_mod, only: r8 => shr_kind_r8, cs=>shr_kind_cs, cl=>shr_kind_cl + use spmd_utils, only: masterproc + use cam_logfile, only: iulog + use cam_abortutils, only: endrun + use ESMF, only: ESMF_DistGrid, ESMF_DistGridCreate, ESMF_MeshCreate + use ESMF, only: ESMF_FILEFORMAT_ESMFMESH,ESMF_MeshGet,ESMF_Mesh, ESMF_SUCCESS + use ESMF, only: ESMF_MeshDestroy, ESMF_DistGridDestroy + use esmf_check_error_mod, only: check_esmf_error + + implicit none + + private + + public :: esmf_phys_mesh_init + public :: esmf_phys_mesh_destroy + public :: physics_grid_mesh + + ! phys_mesh: Local copy of physics grid + type(ESMF_Mesh), protected :: physics_grid_mesh + + ! dist_grid_2d: DistGrid for 2D fields + type(ESMF_DistGrid) :: dist_grid_2d + +contains + + !----------------------------------------------------------------------------- + !----------------------------------------------------------------------------- + subroutine esmf_phys_mesh_init() + use phys_control, only: phys_getopts + use phys_grid, only: get_ncols_p, get_gcol_p, get_rlon_all_p, get_rlat_all_p + use ppgrid, only: pcols, begchunk, endchunk + use shr_const_mod,only: shr_const_pi + + ! Local variables + integer :: ncols + integer :: chnk, col, dindex + integer, allocatable :: decomp(:) + character(len=cl) :: grid_file + integer :: spatialDim + integer :: numOwnedElements + real(r8), pointer :: ownedElemCoords(:) + real(r8), pointer :: lat(:), latMesh(:) + real(r8), pointer :: lon(:), lonMesh(:) + real(r8) :: lats(pcols) ! array of chunk latitudes + real(r8) :: lons(pcols) ! array of chunk longitude + character(len=cs) :: tempc1,tempc2 + character(len=300) :: errstr + + integer :: i, c, n, total_cols, rc + + real(r8), parameter :: abstol = 1.e-3_r8 + real(r8), parameter :: radtodeg = 180.0_r8/shr_const_pi + character(len=*), parameter :: subname = 'esmf_phys_mesh_init: ' + + ! Find the physics grid file + call phys_getopts(physics_grid_out=grid_file) + + ! Compute the local decomp + total_cols = 0 + do chnk = begchunk, endchunk + total_cols = total_cols + get_ncols_p(chnk) + end do + allocate(decomp(total_cols), stat=rc) + if (rc/=0) then + call endrun(subname//'not able to allocate decomp') + end if + + dindex = 0 + do chnk = begchunk, endchunk + ncols = get_ncols_p(chnk) + do col = 1, ncols + dindex = dindex + 1 + decomp(dindex) = get_gcol_p(chnk, col) + end do + end do + + ! Create a DistGrid based on the physics decomp + dist_grid_2d = ESMF_DistGridCreate(arbSeqIndexList=decomp, rc=rc) + call check_esmf_error(rc, subname//'ESMF_DistGridCreate') + + ! Create an ESMF_mesh for the physics decomposition + physics_grid_mesh = ESMF_MeshCreate(trim(grid_file), ESMF_FILEFORMAT_ESMFMESH, & + elementDistgrid=dist_grid_2d, rc=rc) + call check_esmf_error(rc, subname//'ESMF_MeshCreate') + + ! Check that the mesh coordinates are consistent with the model physics column coordinates + + ! obtain mesh lats and lons + call ESMF_MeshGet(physics_grid_mesh, spatialDim=spatialDim, numOwnedElements=numOwnedElements, rc=rc) + call check_esmf_error(rc, subname//'ESMF_MeshGet') + + if (numOwnedElements /= total_cols) then + write(tempc1,'(i10)') numOwnedElements + write(tempc2,'(i10)') total_cols + call endrun(subname//"ERROR numOwnedElements "// & + trim(tempc1) //" not equal to local size "// trim(tempc2)) + end if + + allocate(ownedElemCoords(spatialDim*numOwnedElements), stat=rc) + if (rc/=0) then + call endrun(subname//'not able to allocate ownedElemCoords') + end if + + allocate(lonMesh(total_cols), stat=rc) + if (rc/=0) then + call endrun(subname//'not able to allocate lonMesh') + end if + + allocate(latMesh(total_cols), stat=rc) + if (rc/=0) then + call endrun(subname//'not able to allocate latMesh') + end if + + call ESMF_MeshGet(physics_grid_mesh, ownedElemCoords=ownedElemCoords) + call check_esmf_error(rc, subname//'ESMF_MeshGet') + + do n = 1,total_cols + lonMesh(n) = ownedElemCoords(2*n-1) + latMesh(n) = ownedElemCoords(2*n) + end do + + ! obtain internally generated cam lats and lons + allocate(lon(total_cols), stat=rc); + if (rc/=0) then + call endrun(subname//'not able to allocate lon') + end if + + lon(:) = 0._r8 + + allocate(lat(total_cols), stat=rc); + if (rc/=0) then + call endrun(subname//'not able to allocate lat') + end if + + lat(:) = 0._r8 + + n=0 + do c = begchunk, endchunk + ncols = get_ncols_p(c) + ! latitudes and longitudes returned in radians + call get_rlat_all_p(c, ncols, lats) + call get_rlon_all_p(c, ncols, lons) + do i=1,ncols + n = n+1 + lat(n) = lats(i)*radtodeg + lon(n) = lons(i)*radtodeg + end do + end do + + errstr = '' + ! error check differences between internally generated lons and those read in + do n = 1,total_cols + if (abs(lonMesh(n) - lon(n)) > abstol) then + if ( (abs(lonMesh(n)-lon(n)) > 360._r8+abstol) .or. (abs(lonMesh(n)-lon(n)) < 360._r8-abstol) ) then + write(errstr,100) n,lon(n),lonMesh(n), abs(lonMesh(n)-lon(n)) + write(iulog,*) trim(errstr) + endif + end if + if (abs(latMesh(n) - lat(n)) > abstol) then + ! poles in the 4x5 SCRIP file seem to be off by 1 degree + if (.not.( (abs(lat(n))>88.0_r8) .and. (abs(latMesh(n))>88.0_r8) )) then + write(errstr,101) n,lat(n),latMesh(n), abs(latMesh(n)-lat(n)) + write(iulog,*) trim(errstr) + endif + end if + end do + + if ( len_trim(errstr) > 0 ) then + call endrun(subname//'physics mesh coords do not match model coords') + end if + + ! deallocate memory + deallocate(ownedElemCoords) + deallocate(lon, lonMesh) + deallocate(lat, latMesh) + deallocate(decomp) + +100 format('esmf_phys_mesh_init: coord mismatch... n, lon(n), lonmesh(n), diff_lon = ',i6,2(f21.13,3x),d21.5) +101 format('esmf_phys_mesh_init: coord mismatch... n, lat(n), latmesh(n), diff_lat = ',i6,2(f21.13,3x),d21.5) + + end subroutine esmf_phys_mesh_init + + !----------------------------------------------------------------------------- + !----------------------------------------------------------------------------- + subroutine esmf_phys_mesh_destroy() + + integer :: rc + character(len=*), parameter :: subname = 'esmf_phys_mesh_destroy: ' + + call ESMF_MeshDestroy(physics_grid_mesh, rc=rc) + call check_esmf_error(rc, subname//'ESMF_MeshDestroy phys_mesh') + + call ESMF_DistGridDestroy(dist_grid_2d, rc=rc) + call check_esmf_error(rc, subname//'ESMF_DistGridDestroy dist_grid_2d') + + end subroutine esmf_phys_mesh_destroy + +end module esmf_phys_mesh_mod diff --git a/src/utils/esmf_zonal_mean_mod.F90 b/src/utils/esmf_zonal_mean_mod.F90 new file mode 100644 index 0000000000..822c7ae639 --- /dev/null +++ b/src/utils/esmf_zonal_mean_mod.F90 @@ -0,0 +1,155 @@ +!------------------------------------------------------------------------------ +! Provides methods for calculating zonal means on the ESMF regular longitude +! / latitude grid +!------------------------------------------------------------------------------ +module esmf_zonal_mean_mod + use shr_kind_mod, only: r8 => shr_kind_r8 + use cam_logfile, only: iulog + use cam_abortutils, only: endrun + use spmd_utils, only: masterproc + use cam_history_support, only : fillvalue + + implicit none + + private + + public :: esmf_zonal_mean_calc + public :: esmf_zonal_mean_masked + public :: esmf_zonal_mean_wsums + + interface esmf_zonal_mean_calc + module procedure esmf_zonal_mean_calc_2d + module procedure esmf_zonal_mean_calc_3d + end interface esmf_zonal_mean_calc + +contains + + !------------------------------------------------------------------------------ + ! Calculates zonal means of 3D fields. The wght option can be used to mask out + ! regions such as mountains. + !------------------------------------------------------------------------------ + subroutine esmf_zonal_mean_calc_3d(lonlatarr, zmarr, wght) + use ppgrid, only: pver + use esmf_lonlat_grid_mod, only: lon_beg,lon_end,lat_beg,lat_end, nlon + use esmf_lonlat_grid_mod, only: zonal_comm + use shr_reprosum_mod,only: shr_reprosum_calc + + real(r8), intent(in) :: lonlatarr(lon_beg:lon_end,lat_beg:lat_end,pver) + real(r8), intent(out) :: zmarr(lat_beg:lat_end,pver) + + real(r8), optional, intent(in) :: wght(lon_beg:lon_end,lat_beg:lat_end,pver) + + real(r8) :: tmparr(lon_beg:lon_end,pver) + real(r8) :: gsum(pver) + + real(r8) :: wsums(lat_beg:lat_end,pver) + + integer :: numlons, ilat, ilev + + numlons = lon_end-lon_beg+1 + + ! zonal mean + if (present(wght)) then + + wsums = esmf_zonal_mean_wsums(wght) + call esmf_zonal_mean_masked(lonlatarr, wght, wsums, zmarr) + + else + + do ilat = lat_beg, lat_end + call shr_reprosum_calc(lonlatarr(lon_beg:lon_end,ilat,:), gsum, numlons, numlons, pver, gbl_count=nlon, commid=zonal_comm) + zmarr(ilat,:) = gsum(:)/nlon + end do + + end if + + end subroutine esmf_zonal_mean_calc_3d + + !------------------------------------------------------------------------------ + ! Computes zonal mean for 2D lon / lat fields. + !------------------------------------------------------------------------------ + subroutine esmf_zonal_mean_calc_2d(lonlatarr, zmarr) + use esmf_lonlat_grid_mod, only: lon_beg,lon_end,lat_beg,lat_end, nlon + use esmf_lonlat_grid_mod, only: zonal_comm + use shr_reprosum_mod,only: shr_reprosum_calc + + real(r8), intent(in) :: lonlatarr(lon_beg:lon_end,lat_beg:lat_end) + real(r8), intent(out) :: zmarr(lat_beg:lat_end) + + real(r8) :: gsum(lat_beg:lat_end) + + integer :: numlons, numlats + + numlons = lon_end-lon_beg+1 + numlats = lat_end-lat_beg+1 + + ! zonal mean + + call shr_reprosum_calc(lonlatarr, gsum, numlons, numlons, numlats, gbl_count=nlon, commid=zonal_comm) + zmarr(:) = gsum(:)/nlon + + end subroutine esmf_zonal_mean_calc_2d + + !------------------------------------------------------------------------------ + ! Computes longitude sums of grid cell weights. + !------------------------------------------------------------------------------ + function esmf_zonal_mean_wsums(wght) result(wsums) + use esmf_lonlat_grid_mod, only: lon_beg,lon_end,lat_beg,lat_end, nlon + use esmf_lonlat_grid_mod, only: zonal_comm + use shr_reprosum_mod,only: shr_reprosum_calc + use ppgrid, only: pver + + real(r8), intent(in) :: wght(lon_beg:lon_end,lat_beg:lat_end,pver) + + real(r8) :: wsums(lat_beg:lat_end,pver) + integer :: numlons, ilat + + numlons = lon_end-lon_beg+1 + + do ilat = lat_beg, lat_end + + call shr_reprosum_calc(wght(lon_beg:lon_end,ilat,:), wsums(ilat,1:pver), & + numlons, numlons, pver, gbl_count=nlon, commid=zonal_comm) + + end do + + end function esmf_zonal_mean_wsums + + !------------------------------------------------------------------------------ + ! Masks out regions (e.g. mountains) from zonal mean calculation. + !------------------------------------------------------------------------------ + subroutine esmf_zonal_mean_masked(lonlatarr, wght, wsums, zmarr) + use esmf_lonlat_grid_mod, only: lon_beg,lon_end,lat_beg,lat_end, nlon + use esmf_lonlat_grid_mod, only: zonal_comm + use shr_reprosum_mod,only: shr_reprosum_calc + use ppgrid, only: pver + + real(r8), intent(in) :: lonlatarr(lon_beg:lon_end,lat_beg:lat_end,pver) + real(r8), intent(in) :: wght(lon_beg:lon_end,lat_beg:lat_end,pver) ! grid cell weights + real(r8), intent(in) :: wsums(lat_beg:lat_end,pver) ! pre-computed sums of grid cell weights + real(r8), intent(out) :: zmarr(lat_beg:lat_end,pver) ! zonal means + + real(r8) :: tmparr(lon_beg:lon_end,pver) + integer :: numlons, ilat, ilev + real(r8) :: gsum(pver) + + numlons = lon_end-lon_beg+1 + + do ilat = lat_beg, lat_end + + tmparr(lon_beg:lon_end,:) = wght(lon_beg:lon_end,ilat,:)*lonlatarr(lon_beg:lon_end,ilat,:) + call shr_reprosum_calc(tmparr, gsum, numlons, numlons, pver, gbl_count=nlon, commid=zonal_comm) + + do ilev = 1,pver + if (wsums(ilat,ilev)>0._r8) then + zmarr(ilat,ilev) = gsum(ilev)/wsums(ilat,ilev) + else + zmarr(ilat,ilev) = fillvalue + end if + end do + + end do + + end subroutine esmf_zonal_mean_masked + +end module esmf_zonal_mean_mod diff --git a/src/utils/remap.F90 b/src/utils/remap.F90 new file mode 100644 index 0000000000..cb432cfb05 --- /dev/null +++ b/src/utils/remap.F90 @@ -0,0 +1,483 @@ +!----------------------------------------------------------------------------- +! utilities to gather and distribute columns and remap them to/from +! cubed-sphere grid to a lat-lon grid. +!----------------------------------------------------------------------------- +module nlgw_remap_mod + use shr_kind_mod, only: r8 => shr_kind_r8, cx => SHR_KIND_CX + use ppgrid, only: begchunk, endchunk, pcols, pver, pverp + use physics_types, only: physics_state + use phys_grid, only: get_ncols_p + use spmd_utils, only: masterproc, npes + use ref_pres, only: pref_mid + use esmf_lonlat_grid_mod, only: beglon=>lon_beg, endlon=>lon_end, beglat=>lat_beg, endlat=>lat_end + use cam_history, only: addfld, outfld, horiz_only + use cam_history_support, only : fillvalue + use perf_mod, only: t_startf, t_stopf + use cam_logfile, only: iulog + use cam_abortutils, only: endrun + use gw_nlgw_utils, only: phys_vars, lonlat_vars + + implicit none + + private + + public :: nlgw_regrid_init + public :: nlgw_latlon_gather + public :: nlgw_latlon_scatter + public :: nlgw_regrid_final + + ! private (for book-keeping in MPI calls) + integer, allocatable :: recvcnts(:), displs(:) + integer, allocatable :: beglats(:), beglons(:) + integer, allocatable :: endlats(:), endlons(:) + + logical, parameter, public :: debug = .true. + + +contains + + !----------------------------------------------------------------------------- + ! Initialize arrays and grids for regridding/MPI calls + !----------------------------------------------------------------------------- + subroutine nlgw_regrid_init(phys, lonlat, gathered_lonlat) + use cam_grid_support, only: horiz_coord_t, horiz_coord_create, iMap, cam_grid_register, get_cam_grid_index + use esmf_lonlat_grid_mod, only: glats, glons + use esmf_lonlat_grid_mod, only: esmf_lonlat_grid_init + use esmf_phys_mesh_mod, only: esmf_phys_mesh_init + use esmf_phys2lonlat_mod, only: esmf_phys2lonlat_init + use esmf_lonlat2phys_mod, only: esmf_lonlat2phys_init + use gw_nlgw_utils, only: nlon, nlat + + integer, parameter :: reg_decomp = 332 + + integer(iMap), pointer :: grid_map(:,:) + + integer(iMap), pointer :: coord_map(:) => null() + type(horiz_coord_t), pointer :: lon_coord + type(horiz_coord_t), pointer :: lat_coord + integer :: i, j, ind, astat + + type(phys_vars), intent(inout), target :: phys + type(lonlat_vars), intent(inout), target :: lonlat + type(lonlat_vars), intent(inout), target :: gathered_lonlat + + character(len=*), parameter :: subname = 'ctem_diags_reg: ' + + ! initialize grids and mapping + call esmf_lonlat_grid_init(nlat, nlon) + call esmf_phys_mesh_init() + call esmf_phys2lonlat_init() + call esmf_lonlat2phys_init() + + ! for the lon-lat grid + allocate(grid_map(4, ((endlon - beglon + 1) * (endlat - beglat + 1))), stat=astat) + if (astat/=0) then + call endrun(subname//'not able to allocate grid_map array') + end if + + ind = 0 + do i = beglat, endlat + do j = beglon, endlon + ind = ind + 1 + grid_map(1, ind) = j + grid_map(2, ind) = i + grid_map(3, ind) = j + grid_map(4, ind) = i + end do + end do + + allocate(coord_map(endlat - beglat + 1), stat=astat) + if (astat/=0) then + call endrun(subname//'not able to allocate coord_map array') + end if + + if (beglon==1) then + coord_map = (/ (i, i = beglat, endlat) /) + else + coord_map = 0 + end if + lat_coord => horiz_coord_create('reglat', '', nlat, 'latitude', 'degrees_north', beglat, endlat, & + glats(beglat:endlat), map=coord_map) + + nullify(coord_map) + + allocate(coord_map(endlon - beglon + 1), stat=astat) + if (astat/=0) then + call endrun(subname//'not able to allocate coord_map array') + end if + + if (beglat==1) then + coord_map = (/ (i, i = beglon, endlon) /) + else + coord_map = 0 + end if + + lon_coord => horiz_coord_create('reglon', '', nlon, 'longitude', 'degrees_east', beglon, endlon, & + glons(beglon:endlon), map=coord_map) + + nullify(coord_map) + + ! only create grid if it doesn't exist + if (get_cam_grid_index('ctem_lonlat') == -1) then + call cam_grid_register('ctem_lonlat', reg_decomp, lat_coord, lon_coord, grid_map, unstruct=.false.) + end if + + nullify(grid_map) + + allocate(recvcnts(npes)) + allocate(displs(npes)) + allocate(beglats(npes)) + allocate(beglons(npes)) + allocate(endlats(npes)) + allocate(endlons(npes)) + + allocate(phys%u(pver,pcols,begchunk:endchunk)) + allocate(phys%v(pver,pcols,begchunk:endchunk)) + allocate(phys%theta(pver,pcols,begchunk:endchunk)) + allocate(phys%w(pver,pcols,begchunk:endchunk)) + allocate(phys%uflux(pver,pcols,begchunk:endchunk)) + allocate(phys%vflux(pver,pcols,begchunk:endchunk)) + allocate(phys%utgw(pver,pcols,begchunk:endchunk)) + allocate(phys%vtgw(pver,pcols,begchunk:endchunk)) + + + allocate(lonlat%u(beglon:endlon,beglat:endlat,pver)) + allocate(lonlat%v(beglon:endlon,beglat:endlat,pver)) + allocate(lonlat%w(beglon:endlon,beglat:endlat,pver)) + allocate(lonlat%theta(beglon:endlon,beglat:endlat,pver)) + allocate(lonlat%uflux(beglon:endlon,beglat:endlat,pver)) + allocate(lonlat%vflux(beglon:endlon,beglat:endlat,pver)) + + if (debug) then + allocate(phys%pmid(pver,pcols,begchunk:endchunk)) + allocate(phys%lon(pcols,begchunk:endchunk)) + allocate(phys%lat(pcols,begchunk:endchunk)) + end if + + ! gathered grids only exist on masterproc + if (masterproc) then + allocate(gathered_lonlat%u(nlon, nlat, pver)) + allocate(gathered_lonlat%v(nlon, nlat, pver)) + allocate(gathered_lonlat%w(nlon, nlat, pver)) + allocate(gathered_lonlat%theta(nlon, nlat, pver)) + allocate(gathered_lonlat%uflux(nlon, nlat, pver)) + allocate(gathered_lonlat%vflux(nlon, nlat, pver)) + end if + end subroutine nlgw_regrid_init + + + !----------------------------------------------------------------------------- + ! This routine takes variables on the regular lat/lon grid: + ! * uses MPI_Scatter to broadcast them from masterproc back to all ranks + ! * interpolates them back onto the cubed-sphere grid + ! * finally re-chunks the data so it can be used elsewhere + !----------------------------------------------------------------------------- + subroutine nlgw_latlon_scatter(phys, lonlat, gathered_lonlat) + use gw_nlgw_utils, only: nlon, nlat + use esmf_lonlat2phys_mod, only: fields_bundle_t, n_flx_flds, esmf_lonlat2phys_regrid + use mpishorthand + + type(phys_vars), intent(inout), target :: phys + type(lonlat_vars), intent(inout), target :: lonlat + type(lonlat_vars), intent(inout), target :: gathered_lonlat + + real(r8), allocatable :: flat_array(:) + + integer :: lchnk, ncol, i, sendcnt, disp_sum + + type(fields_bundle_t) :: phys_flx_flds(n_flx_flds) + type(fields_bundle_t) :: lonlat_flx_flds(n_flx_flds) + + call t_startf('nlgw_scatter') + + call t_startf('nlgw_mpiscatter') + ! this subsection gathers all variables onto a single process + + sendcnt = (endlon - beglon + 1) * (endlat - beglat + 1) * pver + + ! mpi gather book-keeping + call mpigather(sendcnt, 1, mpiint, recvcnts, 1, mpiint, 0, mpicom) + call mpigather(beglat, 1, mpiint, beglats, 1, mpiint, 0, mpicom) + call mpigather(beglon, 1, mpiint, beglons, 1, mpiint, 0, mpicom) + call mpigather(endlat, 1, mpiint, endlats, 1, mpiint, 0, mpicom) + call mpigather(endlon, 1, mpiint, endlons, 1, mpiint, 0, mpicom) + + if (masterproc) then + disp_sum = 0 + do i = 1, npes + displs(i) = disp_sum + disp_sum = disp_sum + recvcnts(i) + end do + end if + allocate(flat_array(nlon * nlat * pver)) + + ! unlike in gather case all ranks needs the displs and recvcnts + call mpibcast(displs, npes, mpiint, 0, mpicom) + call mpibcast(recvcnts, npes, mpiint, 0, mpicom) + + call scatter_3d(gathered_lonlat%uflux, sendcnt, flat_array, lonlat%uflux(beglon:endlon, beglat:endlat, 1:pver)) + call scatter_3d(gathered_lonlat%vflux, sendcnt, flat_array, lonlat%vflux(beglon:endlon, beglat:endlat, 1:pver)) + + deallocate(flat_array) + + call t_stopf('nlgw_mpiscatter') + + call t_startf('nlgw_latlon_gather') + ! this subsection does regridding + + phys_flx_flds(1)%fld => phys%uflux + phys_flx_flds(2)%fld => phys%vflux + + lonlat_flx_flds(1)%fld => lonlat%uflux + lonlat_flx_flds(2)%fld => lonlat%vflux + + ! actual call to regrid to lon/lat grid + call esmf_lonlat2phys_regrid(lonlat_flx_flds, phys_flx_flds) + + call t_stopf('nlgw_latlon_gather') + + call t_stopf('nlgw_scatter') + + end subroutine nlgw_latlon_scatter + + + !----------------------------------------------------------------------------- + ! This routine takes variables on the irregular cubed-sphere grid: + ! * gathers all the chunks into a single data structure on each rank + ! * interpolates cubed-sphere variable to a regular lonlat grid + ! * uses MPI_Gather to collect regridded data from all ranks to the masterproc + !----------------------------------------------------------------------------- + subroutine nlgw_latlon_gather(phys_state, phys, lonlat, gathered_lonlat) + use gw_nlgw_utils, only: nlon, nlat + use esmf_phys2lonlat_mod, only: fields_bundle_t, nflds, esmf_phys2lonlat_regrid + use gw_nlgw_utils, only: p0 + use physconst, only: cappa + use mpishorthand + + type(physics_state), intent(in) :: phys_state(begchunk:endchunk) + + type(phys_vars), intent(inout), target :: phys + type(lonlat_vars), intent(inout), target :: lonlat + type(lonlat_vars), intent(inout), target :: gathered_lonlat + + real(r8), allocatable :: flat_array(:) + + integer :: lchnk, ncol, i, sendcnt, disp_sum + + type(fields_bundle_t) :: physflds(nflds) + type(fields_bundle_t) :: lonlatflds(nflds) + + call t_startf('nlgw_gather') + + + call t_startf('nlgw_unchunk') + + do lchnk = begchunk,endchunk + ncol = phys_state(lchnk)%ncol + do i = 1,ncol + ! wind components + phys%u(:,i,lchnk) = phys_state(lchnk)%u(i,:) + phys%v(:,i,lchnk) = phys_state(lchnk)%v(i,:) + phys%w(:,i,lchnk) = phys_state(lchnk)%omega(i,:) + phys%theta(:,i,lchnk) = phys_state(lchnk)%t(i,:) * (p0 / phys_state(lchnk)%pmid(i,:)) ** cappa + + ! for debugging only + if (debug) then + phys%pmid(:,i,lchnk) = phys_state(lchnk)%pmid(i,:) + phys%lat(i,lchnk) = phys_state(lchnk)%lat(i) + phys%lon(i,lchnk) = phys_state(lchnk)%lon(i) + end if + + end do + end do + + call t_stopf('nlgw_unchunk') + + call t_startf('nlgw_latlon_gather') + + physflds(1)%fld => phys%u + physflds(2)%fld => phys%v + physflds(3)%fld => phys%w + physflds(4)%fld => phys%theta + + lonlatflds(1)%fld => lonlat%u + lonlatflds(2)%fld => lonlat%v + lonlatflds(3)%fld => lonlat%w + lonlatflds(4)%fld => lonlat%theta + + ! actual call to regrid to lon/lat grid + call esmf_phys2lonlat_regrid(physflds, lonlatflds) + + call t_stopf('nlgw_latlon_gather') + + call t_startf('nlgw_mpigather') + ! this subsection gathers all variables onto a single process + + sendcnt = (endlon - beglon + 1) * (endlat - beglat + 1) * pver + + ! mpi gather book-keeping + call mpigather(sendcnt, 1, mpiint, recvcnts, 1, mpiint, 0, mpicom) + call mpigather(beglat, 1, mpiint, beglats, 1, mpiint, 0, mpicom) + call mpigather(beglon, 1, mpiint, beglons, 1, mpiint, 0, mpicom) + call mpigather(endlat, 1, mpiint, endlats, 1, mpiint, 0, mpicom) + call mpigather(endlon, 1, mpiint, endlons, 1, mpiint, 0, mpicom) + + if (masterproc) then + disp_sum = 0 + do i = 1, npes + displs(i) = disp_sum + disp_sum = disp_sum + recvcnts(i) + end do + end if + allocate(flat_array(nlon * nlat * pver)) + + call gather_3d(lonlat%u(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, gathered_lonlat%u) + call gather_3d(lonlat%v(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, gathered_lonlat%v) + call gather_3d(lonlat%w(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, gathered_lonlat%w) + call gather_3d(lonlat%theta(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, gathered_lonlat%theta) + + deallocate(flat_array) + + call t_stopf('nlgw_mpigather') + + call t_stopf('nlgw_gather') + + end subroutine nlgw_latlon_gather + + !----------------------------------------------------------------------------- + ! Utility function for gathering 2D data into a single array + !----------------------------------------------------------------------------- + subroutine gather_2d(local_array, sendcnt, flat_array, grid_out) + use mpishorthand + real(r8), intent(in) :: local_array(:,:) ! Local 2D array section + integer, intent(in) :: sendcnt + real(r8), intent(inout) :: flat_array(:) ! Flattened array for gathering + real(r8), allocatable, intent(inout) :: grid_out(:,:) ! Full gathered grid + + integer :: i, lonsize, latsize + + ! gather variables onto master proc into a flat array (can't do 2D/3D mpigather) + call mpigatherv(local_array, sendcnt, mpir8, flat_array, recvcnts, displs, mpir8, 0, mpicom) + + if (masterproc) then + do i = 1, npes + lonsize = endlons(i) - beglons(i) + 1 + latsize = endlats(i) - beglats(i) + 1 + ! reshape each ranks flattended data and populate each block into a single lonlat grid + grid_out(beglons(i):endlons(i), beglats(i):endlats(i)) = & + reshape(flat_array(displs(i)+1:displs(i)+sendcnt), (/ lonsize, latsize /)) + end do + end if + end subroutine gather_2d + + !----------------------------------------------------------------------------- + ! Utility function for gathering 3D data into a single array + !----------------------------------------------------------------------------- + subroutine gather_3d(local_array, sendcnt, flat_array, grid_out) + use mpishorthand + real(r8), intent(in) :: local_array(:,:,:) ! Local 3D array section + integer, intent(in) :: sendcnt + real(r8), intent(inout) :: flat_array(:) ! Flattened array for gathering + real(r8), allocatable, intent(inout) :: grid_out(:,:,:) ! Full gathered grid + + integer :: i, lonsize, latsize + + ! gather variables onto master proc into a flat array (can't do 2D/3D mpigather) + call mpigatherv(local_array, sendcnt, mpir8, flat_array, recvcnts, displs, mpir8, 0, mpicom) + + if (masterproc) then + do i = 1, npes + lonsize = endlons(i) - beglons(i) + 1 + latsize = endlats(i) - beglats(i) + 1 + ! reshape each ranks flattended data and populate each block into a single lonlat grid + grid_out(beglons(i):endlons(i), beglats(i):endlats(i), 1:pver) = & + reshape(flat_array(displs(i)+1:displs(i)+sendcnt), (/ lonsize, latsize, pver /)) + end do + end if + end subroutine gather_3d + + !----------------------------------------------------------------------------- + ! Utility function for scattering 3D data into lonlat arrays + !----------------------------------------------------------------------------- + subroutine scatter_3d(grid_in, sendcnt, flat_array, lonlat_out) + use mpishorthand + real(r8), allocatable, intent(in) :: grid_in(:,:,:) ! Local 3D array section + integer, intent(in) :: sendcnt + real(r8), intent(inout) :: flat_array(:) ! temporary storage in flat array + real(r8), target, intent(inout) :: lonlat_out(:,:,:) ! Full scattered grid + + integer :: i, lonsize, latsize + + if (masterproc) then + do i = 1, npes + lonsize = endlons(i) - beglons(i) + 1 + latsize = endlats(i) - beglats(i) + 1 + flat_array(displs(i)+1:displs(i)+sendcnt) = & + reshape(grid_in(beglons(i):endlons(i), beglats(i):endlats(i), 1:pver), (/lonsize* latsize * pver/)) + end do + end if + + ! scatter variables from flat_array back to all processes + call mpiscatterv(flat_array, recvcnts, displs, mpir8, lonlat_out, sendcnt, mpir8, 0, mpicom) + + end subroutine scatter_3d + + !----------------------------------------------------------------------------- + ! Tidy up (free allocated memory) + !----------------------------------------------------------------------------- + subroutine nlgw_regrid_final(phys, lonlat, gathered_lonlat) + use esmf_phys2lonlat_mod, only: esmf_phys2lonlat_destroy + use esmf_lonlat2phys_mod, only: esmf_lonlat2phys_destroy + use esmf_lonlat_grid_mod, only: esmf_lonlat_grid_destroy + use esmf_phys_mesh_mod, only: esmf_phys_mesh_destroy + + type(phys_vars), intent(inout), target :: phys + type(lonlat_vars), intent(inout), target :: lonlat + type(lonlat_vars), intent(inout), target :: gathered_lonlat + + call esmf_phys2lonlat_destroy() + call esmf_lonlat2phys_destroy() + call esmf_lonlat_grid_destroy() + call esmf_phys_mesh_destroy() + + ! TODO double check ALL deallocates here + if (masterproc) then + deallocate(gathered_lonlat%u) + deallocate(gathered_lonlat%v) + deallocate(gathered_lonlat%w) + deallocate(gathered_lonlat%theta) + deallocate(gathered_lonlat%uflux) + deallocate(gathered_lonlat%vflux) + end if + + deallocate(phys%u) + deallocate(phys%v) + deallocate(phys%theta) + deallocate(phys%w) + deallocate(phys%uflux) + deallocate(phys%vflux) + deallocate(phys%utgw) + deallocate(phys%vtgw) + + deallocate(lonlat%u) + deallocate(lonlat%v) + deallocate(lonlat%w) + deallocate(lonlat%theta) + deallocate(lonlat%uflux) + deallocate(lonlat%vflux) + + if (debug) then + deallocate(phys%pmid) + deallocate(phys%lon) + deallocate(phys%lat) + end if + + deallocate(recvcnts) + deallocate(displs) + deallocate(beglats) + deallocate(beglons) + deallocate(endlats) + deallocate(endlons) + end subroutine nlgw_regrid_final + +end module nlgw_remap_mod