From 97ed194b4a60013d3324bbc61845a763159a8e9a Mon Sep 17 00:00:00 2001 From: tommelt Date: Mon, 16 Jun 2025 17:35:14 +0100 Subject: [PATCH 01/23] wip: add remapping subroutine --- src/physics/cam/gw_drag.F90 | 5 + src/utils/esmf_check_error_mod.F90 | 33 +++ src/utils/esmf_lonlat_grid_mod.F90 | 338 +++++++++++++++++++++++++++++ src/utils/esmf_phys2lonlat_mod.F90 | 234 ++++++++++++++++++++ src/utils/esmf_phys_mesh_mod.F90 | 202 +++++++++++++++++ src/utils/esmf_zonal_mean_mod.F90 | 155 +++++++++++++ src/utils/remap.F90 | 212 ++++++++++++++++++ 7 files changed, 1179 insertions(+) create mode 100644 src/utils/esmf_check_error_mod.F90 create mode 100644 src/utils/esmf_lonlat_grid_mod.F90 create mode 100644 src/utils/esmf_phys2lonlat_mod.F90 create mode 100644 src/utils/esmf_phys_mesh_mod.F90 create mode 100644 src/utils/esmf_zonal_mean_mod.F90 create mode 100644 src/utils/remap.F90 diff --git a/src/physics/cam/gw_drag.F90 b/src/physics/cam/gw_drag.F90 index 9a3a651060..6f929ad190 100644 --- a/src/physics/cam/gw_drag.F90 +++ b/src/physics/cam/gw_drag.F90 @@ -45,6 +45,7 @@ module gw_drag use gw_ml, only: gw_drag_convect_dp_ml_init, gw_drag_convect_dp_ml_final, & gw_drag_convect_dp_ml use gw_nlgw, only: gw_nlgw_dp_ml, gw_nlgw_dp_init, gw_nlgw_dp_finalize + use nlgw_remap_mod, only: nlgw_regrid_init, nlgw_regrid, nlgw_regrid_init ! Typical module header implicit none @@ -1489,6 +1490,10 @@ subroutine gw_tend(state, pbuf, dt, ptend, cam_in, flx_heat) ! constituents are all treated as wet mmr call set_dry_to_wet(state1) + call nlgw_regrid_init() + call nlgw_regrid(state1) + stop + lchnk = state1%lchnk ncol = state1%ncol diff --git a/src/utils/esmf_check_error_mod.F90 b/src/utils/esmf_check_error_mod.F90 new file mode 100644 index 0000000000..eb9db18a90 --- /dev/null +++ b/src/utils/esmf_check_error_mod.F90 @@ -0,0 +1,33 @@ +!------------------------------------------------------------------------------ +! ESMF error handler +!------------------------------------------------------------------------------ +module esmf_check_error_mod + use shr_kind_mod, only: cl=>SHR_KIND_CL + use spmd_utils, only: masterproc + use cam_logfile, only: iulog + use cam_abortutils, only: endrun + use ESMF, only: ESMF_SUCCESS + + implicit none + + private + public :: check_esmf_error + +contains + + subroutine check_esmf_error( rc, errmsg ) + + integer, intent(in) :: rc + character(len=*), intent(in) :: errmsg + + character(len=cl) :: errstr + + if (rc /= ESMF_SUCCESS) then + write(errstr,'(a,i6)') 'esmf_zonal_mod::'//trim(errmsg)//' -- ESMF ERROR code: ',rc + if (masterproc) write(iulog,*) trim(errstr) + call endrun(trim(errstr)) + end if + + end subroutine check_esmf_error + +end module esmf_check_error_mod diff --git a/src/utils/esmf_lonlat_grid_mod.F90 b/src/utils/esmf_lonlat_grid_mod.F90 new file mode 100644 index 0000000000..23b575ac61 --- /dev/null +++ b/src/utils/esmf_lonlat_grid_mod.F90 @@ -0,0 +1,338 @@ +!------------------------------------------------------------------------------- +! 
Encapsulates an ESMF regular longitude / latitude grid +!------------------------------------------------------------------------------- +module esmf_lonlat_grid_mod + use shr_kind_mod, only: r8 => shr_kind_r8 + use spmd_utils, only: masterproc, mpicom + use cam_logfile, only: iulog + use cam_abortutils, only: endrun + + use ESMF, only: ESMF_Grid, ESMF_GridCreate1PeriDim, ESMF_GridAddCoord + use ESMF, only: ESMF_GridGetCoord, ESMF_GridDestroy + use ESMF, only: ESMF_KIND_R8, ESMF_INDEX_GLOBAL, ESMF_STAGGERLOC_CENTER + use esmf_check_error_mod, only: check_esmf_error + + implicit none + + public + + type(ESMF_Grid), protected :: lonlat_grid + + integer, protected :: nlon = 0 + integer, protected :: nlat = 0 + + integer, protected :: lon_beg = -1 + integer, protected :: lon_end = -1 + integer, protected :: lat_beg = -1 + integer, protected :: lat_end = -1 + + real(r8), allocatable, protected :: glats(:) + real(r8), allocatable, protected :: glons(:) + + integer, protected :: zonal_comm ! zonal direction MPI communicator + +contains + + subroutine esmf_lonlat_grid_init(nlats_in) + use phys_grid, only: get_grid_dims + use mpi, only: mpi_comm_size, mpi_comm_rank, MPI_PROC_NULL, MPI_INTEGER + + integer, intent(in) :: nlats_in + + real(r8) :: delx, dely + + integer :: npes, ierr, mytid, irank, mytidi, mytidj + integer :: i,j, n + integer :: ntasks_lon, ntasks_lat + integer :: lons_per_task, lons_overflow, lats_per_task, lats_overflow + integer :: task_cnt + integer :: mynlats, mynlons + + integer, allocatable :: mytidi_send(:) + integer, allocatable :: mytidj_send(:) + integer, allocatable :: mytidi_recv(:) + integer, allocatable :: mytidj_recv(:) + + integer, allocatable :: nlons_send(:) + integer, allocatable :: nlats_send(:) + integer, allocatable :: nlons_recv(:) + integer, allocatable :: nlats_recv(:) + + integer, allocatable :: nlons_task(:) + integer, allocatable :: nlats_task(:) + + integer, parameter :: minlats_per_pe = 2 + integer, parameter :: minlons_per_pe = 2 + + integer, allocatable :: petmap(:,:,:) + integer :: petcnt, astat + + integer :: lbnd_lat, ubnd_lat, lbnd_lon, ubnd_lon + integer :: lbnd(1), ubnd(1) + real(ESMF_KIND_R8), pointer :: coordX(:), coordY(:) + + character(len=*), parameter :: subname = 'esmf_lonlat_grid_init: ' + + ! create reg lon lat grid + + nlat = nlats_in + dely = 180._r8/nlat + + nlon = 2*nlat + delx = 360._r8/nlon + + allocate(glons(nlon), stat=astat) + if (astat/=0) then + call endrun(subname//'not able to allocate glons array') + end if + allocate(glats(nlat), stat=astat) + if (astat/=0) then + call endrun(subname//'not able to allocate glats array') + end if + + glons(1) = 0._r8 + glats(1) = -90._r8 + 0.5_r8 * dely + + do i = 2,nlon + glons(i) = glons(i-1) + delx + end do + do i = 2,nlat + glats(i) = glats(i-1) + dely + end do + + ! decompose the grid across mpi tasks ... 
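+    ! Decomposition outline (a sketch of the intent; details may differ):
+    ! the loop below searches for a factorization npes = ntasks_lon*ntasks_lat
+    ! that leaves every task at least minlons_per_pe longitudes and
+    ! minlats_per_pe latitudes.  Each task then owns the local block
+    ! (lon_beg:lon_end, lat_beg:lat_end); the task layout is collected in
+    ! petmap (presumably passed to ESMF_GridCreate1PeriDim), and tasks that
+    ! share a latitude band are grouped into zonal_comm so zonal means can
+    ! be reduced over longitude only.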
+ + call mpi_comm_size(mpicom, npes, ierr) + call mpi_comm_rank(mpicom, mytid, ierr) + + decomp_loop: do ntasks_lon = 1,nlon + ntasks_lat = npes/ntasks_lon + if ( (minlats_per_pe*ntasks_latmytid) exit jloop + end do + enddo jloop + endif + + mynlats = lat_end-lat_beg+1 + mynlons = lon_end-lon_beg+1 + + if (mynlats shr_kind_r8 + use cam_logfile, only: iulog + use cam_abortutils, only: endrun + use spmd_utils, only: masterproc + use ppgrid, only: pver + + use ESMF, only: ESMF_RouteHandle, ESMF_Field, ESMF_ArraySpec, ESMF_ArraySpecSet + use ESMF, only: ESMF_FieldCreate, ESMF_FieldRegridStore + use ESMF, only: ESMF_FieldGet, ESMF_FieldRegrid + use ESMF, only: ESMF_KIND_I4, ESMF_KIND_R8, ESMF_TYPEKIND_R8 + use ESMF, only: ESMF_REGRIDMETHOD_BILINEAR, ESMF_POLEMETHOD_ALLAVG, ESMF_EXTRAPMETHOD_NEAREST_IDAVG + use ESMF, only: ESMF_TERMORDER_SRCSEQ, ESMF_MESHLOC_ELEMENT, ESMF_STAGGERLOC_CENTER + use ESMF, only: ESMF_FieldDestroy, ESMF_RouteHandleDestroy + use esmf_check_error_mod, only: check_esmf_error + + implicit none + + private + + public :: esmf_phys2lonlat_init + public :: esmf_phys2lonlat_regrid + public :: esmf_phys2lonlat_destroy + public :: fields_bundle_t + public :: nflds + + type(ESMF_RouteHandle) :: rh_phys2lonlat_3d + type(ESMF_RouteHandle) :: rh_phys2lonlat_2d + + type(ESMF_Field) :: physfld_3d + type(ESMF_Field) :: lonlatfld_3d + + type(ESMF_Field) :: physfld_2d + type(ESMF_Field) :: lonlatfld_2d + + interface esmf_phys2lonlat_regrid + module procedure esmf_phys2lonlat_regrid_2d + module procedure esmf_phys2lonlat_regrid_3d + end interface esmf_phys2lonlat_regrid + + type :: fields_bundle_t + real(r8), pointer :: fld(:,:,:) => null() + end type fields_bundle_t + + integer, parameter :: nflds = 5 + +contains + + !------------------------------------------------------------------------------ + !------------------------------------------------------------------------------ + subroutine esmf_phys2lonlat_init() + use esmf_phys_mesh_mod, only: physics_grid_mesh + use esmf_lonlat_grid_mod, only: lonlat_grid + + type(ESMF_ArraySpec) :: arrayspec + integer(ESMF_KIND_I4), pointer :: factorIndexList(:,:) + real(ESMF_KIND_R8), pointer :: factorList(:) + integer :: smm_srctermproc, smm_pipelinedep, rc + + character(len=*), parameter :: subname = 'esmf_phys2lonlat_init: ' + + smm_srctermproc = 0 + smm_pipelinedep = 16 + + ! create ESMF fields + + ! 3D phys fld + call ESMF_ArraySpecSet(arrayspec, 3, ESMF_TYPEKIND_R8, rc=rc) + call check_esmf_error(rc, subname//'ESMF_ArraySpecSet 3D phys fld ERROR') + + physfld_3d = ESMF_FieldCreate(physics_grid_mesh, arrayspec, & + gridToFieldMap=(/3/), meshloc=ESMF_MESHLOC_ELEMENT, & + ungriddedLBound=(/1,1/), ungriddedUBound=(/pver,nflds/), rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldCreate 3D phys fld ERROR') + + ! 3D lon lat grid + call ESMF_ArraySpecSet(arrayspec, 4, ESMF_TYPEKIND_R8, rc=rc) + call check_esmf_error(rc, subname//'ESMF_ArraySpecSet 3D lonlat fld ERROR') + + lonlatfld_3d = ESMF_FieldCreate( lonlat_grid, arrayspec, staggerloc=ESMF_STAGGERLOC_CENTER, & + ungriddedLBound=(/1,1/), ungriddedUBound=(/pver,nflds/), rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldCreate 3D lonlat fld ERROR') + + ! 2D phys fld + call ESMF_ArraySpecSet(arrayspec, 1, ESMF_TYPEKIND_R8, rc=rc) + call check_esmf_error(rc, subname//'ESMF_ArraySpecSet 2D phys fld ERROR') + + physfld_2d = ESMF_FieldCreate(physics_grid_mesh, arrayspec, & + meshloc=ESMF_MESHLOC_ELEMENT, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldCreate 2D phys fld ERROR') + + ! 
2D lon/lat grid + call ESMF_ArraySpecSet(arrayspec, 2, ESMF_TYPEKIND_R8, rc=rc) + call check_esmf_error(rc, subname//'ESMF_ArraySpecSet 2D lonlat fld ERROR') + + lonlatfld_2d = ESMF_FieldCreate( lonlat_grid, arrayspec, staggerloc=ESMF_STAGGERLOC_CENTER, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldCreate 2D lonlat fld ERROR') + + call ESMF_FieldRegridStore(srcField=physfld_3d, dstField=lonlatfld_3d, & + regridMethod=ESMF_REGRIDMETHOD_BILINEAR, & + polemethod=ESMF_POLEMETHOD_ALLAVG, & + extrapMethod=ESMF_EXTRAPMETHOD_NEAREST_IDAVG, & + routeHandle=rh_phys2lonlat_3d, factorIndexList=factorIndexList, & + factorList=factorList, srcTermProcessing=smm_srctermproc, & + pipelineDepth=smm_pipelinedep, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldRegridStore 3D routehandle ERROR') + + call ESMF_FieldRegridStore(srcField=physfld_2d, dstField=lonlatfld_2d, & + regridMethod=ESMF_REGRIDMETHOD_BILINEAR, & + polemethod=ESMF_POLEMETHOD_ALLAVG, & + extrapMethod=ESMF_EXTRAPMETHOD_NEAREST_IDAVG, & + routeHandle=rh_phys2lonlat_2d, factorIndexList=factorIndexList, & + factorList=factorList, srcTermProcessing=smm_srctermproc, & + pipelineDepth=smm_pipelinedep, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldRegridStore 3D routehandle ERROR') + + end subroutine esmf_phys2lonlat_init + + !------------------------------------------------------------------------------ + !------------------------------------------------------------------------------ + subroutine esmf_phys2lonlat_regrid_3d(physflds, lonlatflds) + use esmf_lonlat_grid_mod, only: lon_beg,lon_end,lat_beg,lat_end + use ppgrid, only: pcols, pver, begchunk, endchunk + use phys_grid, only: get_ncols_p + + type(fields_bundle_t), intent(in) :: physflds(nflds) + type(fields_bundle_t), intent(inout) :: lonlatflds(nflds) + + integer :: i, ichnk, ncol, ifld, ilev, icol, rc + real(ESMF_KIND_R8), pointer :: physptr(:,:,:) + real(ESMF_KIND_R8), pointer :: lonlatptr(:,:,:,:) + + character(len=*), parameter :: subname = 'esmf_phys2lonlat_regrid_3d: ' + + call ESMF_FieldGet(physfld_3d, localDe=0, farrayPtr=physptr, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldGet physptr') + + i = 0 + do ichnk = begchunk, endchunk + ncol = get_ncols_p(ichnk) + do icol = 1,ncol + i = i+1 + do ifld = 1,nflds + do ilev = 1,pver + physptr(ilev,ifld,i) = physflds(ifld)%fld(ilev,icol,ichnk) + end do + end do + end do + end do + + call ESMF_FieldRegrid(physfld_3d, lonlatfld_3d, rh_phys2lonlat_3d, & + termorderflag=ESMF_TERMORDER_SRCSEQ, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldRegrid physfld_3d->lonlatfld_3d') + + call ESMF_FieldGet(lonlatfld_3d, localDe=0, farrayPtr=lonlatptr, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldGet lonlatptr') + + do ifld = 1,nflds + lonlatflds(ifld)%fld(lon_beg:lon_end,lat_beg:lat_end,1:pver) = lonlatptr(lon_beg:lon_end,lat_beg:lat_end,1:pver,ifld) + end do + + end subroutine esmf_phys2lonlat_regrid_3d + + !------------------------------------------------------------------------------ + !------------------------------------------------------------------------------ + subroutine esmf_phys2lonlat_regrid_2d(physarr, lonlatarr) + use esmf_lonlat_grid_mod, only: lon_beg,lon_end,lat_beg,lat_end + use ppgrid, only: pcols, pver, begchunk, endchunk + use phys_grid, only: get_ncols_p + + real(r8),intent(in) :: physarr(pcols,begchunk:endchunk) + real(r8),intent(out) :: lonlatarr(lon_beg:lon_end,lat_beg:lat_end) + + integer :: i, ichnk, ncol, icol, rc + real(ESMF_KIND_R8), pointer :: physptr(:) + real(ESMF_KIND_R8), 
pointer :: lonlatptr(:,:) + + character(len=*), parameter :: subname = 'esmf_phys2lonlat_regrid_2d: ' + + call ESMF_FieldGet(physfld_2d, localDe=0, farrayPtr=physptr, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldGet physptr') + + i = 0 + do ichnk = begchunk, endchunk + ncol = get_ncols_p(ichnk) + do icol = 1,ncol + i = i+1 + physptr(i) = physarr(icol,ichnk) + end do + end do + + call ESMF_FieldRegrid(physfld_2d, lonlatfld_2d, rh_phys2lonlat_2d, & + termorderflag=ESMF_TERMORDER_SRCSEQ, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldRegrid physfld_3d->lonlatfld_3d') + + call ESMF_FieldGet(lonlatfld_2d, localDe=0, farrayPtr=lonlatptr, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldGet lonlatptr') + + lonlatarr(lon_beg:lon_end,lat_beg:lat_end) = lonlatptr(lon_beg:lon_end,lat_beg:lat_end) + + end subroutine esmf_phys2lonlat_regrid_2d + + !------------------------------------------------------------------------------ + !------------------------------------------------------------------------------ + subroutine esmf_phys2lonlat_destroy() + + integer :: rc + character(len=*), parameter :: subname = 'esmf_phys2lonlat_destroy: ' + + call ESMF_RouteHandleDestroy(rh_phys2lonlat_3d, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldDestroy rh_phys2lonlat_3d') + + call ESMF_RouteHandleDestroy(rh_phys2lonlat_2d, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldDestroy rh_phys2lonlat_2d') + + call ESMF_FieldDestroy(lonlatfld_3d, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldDestroy lonlatfld_3d') + + call ESMF_FieldDestroy(lonlatfld_2d, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldDestroy lonlatfld_2d') + + call ESMF_FieldDestroy(physfld_3d, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldDestroy physfld_3d') + + call ESMF_FieldDestroy(physfld_2d, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldDestroy physfld_2d') + + end subroutine esmf_phys2lonlat_destroy + +end module esmf_phys2lonlat_mod diff --git a/src/utils/esmf_phys_mesh_mod.F90 b/src/utils/esmf_phys_mesh_mod.F90 new file mode 100644 index 0000000000..4235452dc5 --- /dev/null +++ b/src/utils/esmf_phys_mesh_mod.F90 @@ -0,0 +1,202 @@ +!------------------------------------------------------------------------------- +! Encapsulates the CAM physics grid mesh +!------------------------------------------------------------------------------- +module esmf_phys_mesh_mod + use shr_kind_mod, only: r8 => shr_kind_r8, cs=>shr_kind_cs, cl=>shr_kind_cl + use spmd_utils, only: masterproc + use cam_logfile, only: iulog + use cam_abortutils, only: endrun + use ESMF, only: ESMF_DistGrid, ESMF_DistGridCreate, ESMF_MeshCreate + use ESMF, only: ESMF_FILEFORMAT_ESMFMESH,ESMF_MeshGet,ESMF_Mesh, ESMF_SUCCESS + use ESMF, only: ESMF_MeshDestroy, ESMF_DistGridDestroy + use esmf_check_error_mod, only: check_esmf_error + + implicit none + + private + + public :: esmf_phys_mesh_init + public :: esmf_phys_mesh_destroy + public :: physics_grid_mesh + + ! phys_mesh: Local copy of physics grid + type(ESMF_Mesh), protected :: physics_grid_mesh + + ! 
dist_grid_2d: DistGrid for 2D fields + type(ESMF_DistGrid) :: dist_grid_2d + +contains + + !----------------------------------------------------------------------------- + !----------------------------------------------------------------------------- + subroutine esmf_phys_mesh_init() + use phys_control, only: phys_getopts + use phys_grid, only: get_ncols_p, get_gcol_p, get_rlon_all_p, get_rlat_all_p + use ppgrid, only: pcols, begchunk, endchunk + use shr_const_mod,only: shr_const_pi + + ! Local variables + integer :: ncols + integer :: chnk, col, dindex + integer, allocatable :: decomp(:) + character(len=cl) :: grid_file + integer :: spatialDim + integer :: numOwnedElements + real(r8), pointer :: ownedElemCoords(:) + real(r8), pointer :: lat(:), latMesh(:) + real(r8), pointer :: lon(:), lonMesh(:) + real(r8) :: lats(pcols) ! array of chunk latitudes + real(r8) :: lons(pcols) ! array of chunk longitude + character(len=cs) :: tempc1,tempc2 + character(len=300) :: errstr + + integer :: i, c, n, total_cols, rc + + real(r8), parameter :: abstol = 1.e-3_r8 + real(r8), parameter :: radtodeg = 180.0_r8/shr_const_pi + character(len=*), parameter :: subname = 'esmf_phys_mesh_init: ' + + ! Find the physics grid file + call phys_getopts(physics_grid_out=grid_file) + + ! Compute the local decomp + total_cols = 0 + do chnk = begchunk, endchunk + total_cols = total_cols + get_ncols_p(chnk) + end do + allocate(decomp(total_cols), stat=rc) + if (rc/=0) then + call endrun(subname//'not able to allocate decomp') + end if + + dindex = 0 + do chnk = begchunk, endchunk + ncols = get_ncols_p(chnk) + do col = 1, ncols + dindex = dindex + 1 + decomp(dindex) = get_gcol_p(chnk, col) + end do + end do + + ! Create a DistGrid based on the physics decomp + dist_grid_2d = ESMF_DistGridCreate(arbSeqIndexList=decomp, rc=rc) + call check_esmf_error(rc, subname//'ESMF_DistGridCreate') + + ! Create an ESMF_mesh for the physics decomposition + physics_grid_mesh = ESMF_MeshCreate(trim(grid_file), ESMF_FILEFORMAT_ESMFMESH, & + elementDistgrid=dist_grid_2d, rc=rc) + call check_esmf_error(rc, subname//'ESMF_MeshCreate') + + ! Check that the mesh coordinates are consistent with the model physics column coordinates + + ! obtain mesh lats and lons + call ESMF_MeshGet(physics_grid_mesh, spatialDim=spatialDim, numOwnedElements=numOwnedElements, rc=rc) + call check_esmf_error(rc, subname//'ESMF_MeshGet') + + if (numOwnedElements /= total_cols) then + write(tempc1,'(i10)') numOwnedElements + write(tempc2,'(i10)') total_cols + call endrun(subname//"ERROR numOwnedElements "// & + trim(tempc1) //" not equal to local size "// trim(tempc2)) + end if + + allocate(ownedElemCoords(spatialDim*numOwnedElements), stat=rc) + if (rc/=0) then + call endrun(subname//'not able to allocate ownedElemCoords') + end if + + allocate(lonMesh(total_cols), stat=rc) + if (rc/=0) then + call endrun(subname//'not able to allocate lonMesh') + end if + + allocate(latMesh(total_cols), stat=rc) + if (rc/=0) then + call endrun(subname//'not able to allocate latMesh') + end if + + call ESMF_MeshGet(physics_grid_mesh, ownedElemCoords=ownedElemCoords) + call check_esmf_error(rc, subname//'ESMF_MeshGet') + + do n = 1,total_cols + lonMesh(n) = ownedElemCoords(2*n-1) + latMesh(n) = ownedElemCoords(2*n) + end do + + ! 
obtain internally generated cam lats and lons + allocate(lon(total_cols), stat=rc); + if (rc/=0) then + call endrun(subname//'not able to allocate lon') + end if + + lon(:) = 0._r8 + + allocate(lat(total_cols), stat=rc); + if (rc/=0) then + call endrun(subname//'not able to allocate lat') + end if + + lat(:) = 0._r8 + + n=0 + do c = begchunk, endchunk + ncols = get_ncols_p(c) + ! latitudes and longitudes returned in radians + call get_rlat_all_p(c, ncols, lats) + call get_rlon_all_p(c, ncols, lons) + do i=1,ncols + n = n+1 + lat(n) = lats(i)*radtodeg + lon(n) = lons(i)*radtodeg + end do + end do + + errstr = '' + ! error check differences between internally generated lons and those read in + do n = 1,total_cols + if (abs(lonMesh(n) - lon(n)) > abstol) then + if ( (abs(lonMesh(n)-lon(n)) > 360._r8+abstol) .or. (abs(lonMesh(n)-lon(n)) < 360._r8-abstol) ) then + write(errstr,100) n,lon(n),lonMesh(n), abs(lonMesh(n)-lon(n)) + write(iulog,*) trim(errstr) + endif + end if + if (abs(latMesh(n) - lat(n)) > abstol) then + ! poles in the 4x5 SCRIP file seem to be off by 1 degree + if (.not.( (abs(lat(n))>88.0_r8) .and. (abs(latMesh(n))>88.0_r8) )) then + write(errstr,101) n,lat(n),latMesh(n), abs(latMesh(n)-lat(n)) + write(iulog,*) trim(errstr) + endif + end if + end do + + if ( len_trim(errstr) > 0 ) then + call endrun(subname//'physics mesh coords do not match model coords') + end if + + ! deallocate memory + deallocate(ownedElemCoords) + deallocate(lon, lonMesh) + deallocate(lat, latMesh) + deallocate(decomp) + +100 format('esmf_phys_mesh_init: coord mismatch... n, lon(n), lonmesh(n), diff_lon = ',i6,2(f21.13,3x),d21.5) +101 format('esmf_phys_mesh_init: coord mismatch... n, lat(n), latmesh(n), diff_lat = ',i6,2(f21.13,3x),d21.5) + + end subroutine esmf_phys_mesh_init + + !----------------------------------------------------------------------------- + !----------------------------------------------------------------------------- + subroutine esmf_phys_mesh_destroy() + + integer :: rc + character(len=*), parameter :: subname = 'esmf_phys_mesh_destroy: ' + + call ESMF_MeshDestroy(physics_grid_mesh, rc=rc) + call check_esmf_error(rc, subname//'ESMF_MeshDestroy phys_mesh') + + call ESMF_DistGridDestroy(dist_grid_2d, rc=rc) + call check_esmf_error(rc, subname//'ESMF_DistGridDestroy dist_grid_2d') + + end subroutine esmf_phys_mesh_destroy + +end module esmf_phys_mesh_mod diff --git a/src/utils/esmf_zonal_mean_mod.F90 b/src/utils/esmf_zonal_mean_mod.F90 new file mode 100644 index 0000000000..822c7ae639 --- /dev/null +++ b/src/utils/esmf_zonal_mean_mod.F90 @@ -0,0 +1,155 @@ +!------------------------------------------------------------------------------ +! Provides methods for calculating zonal means on the ESMF regular longitude +! / latitude grid +!------------------------------------------------------------------------------ +module esmf_zonal_mean_mod + use shr_kind_mod, only: r8 => shr_kind_r8 + use cam_logfile, only: iulog + use cam_abortutils, only: endrun + use spmd_utils, only: masterproc + use cam_history_support, only : fillvalue + + implicit none + + private + + public :: esmf_zonal_mean_calc + public :: esmf_zonal_mean_masked + public :: esmf_zonal_mean_wsums + + interface esmf_zonal_mean_calc + module procedure esmf_zonal_mean_calc_2d + module procedure esmf_zonal_mean_calc_3d + end interface esmf_zonal_mean_calc + +contains + + !------------------------------------------------------------------------------ + ! Calculates zonal means of 3D fields. The wght option can be used to mask out + ! 
regions such as mountains. + !------------------------------------------------------------------------------ + subroutine esmf_zonal_mean_calc_3d(lonlatarr, zmarr, wght) + use ppgrid, only: pver + use esmf_lonlat_grid_mod, only: lon_beg,lon_end,lat_beg,lat_end, nlon + use esmf_lonlat_grid_mod, only: zonal_comm + use shr_reprosum_mod,only: shr_reprosum_calc + + real(r8), intent(in) :: lonlatarr(lon_beg:lon_end,lat_beg:lat_end,pver) + real(r8), intent(out) :: zmarr(lat_beg:lat_end,pver) + + real(r8), optional, intent(in) :: wght(lon_beg:lon_end,lat_beg:lat_end,pver) + + real(r8) :: tmparr(lon_beg:lon_end,pver) + real(r8) :: gsum(pver) + + real(r8) :: wsums(lat_beg:lat_end,pver) + + integer :: numlons, ilat, ilev + + numlons = lon_end-lon_beg+1 + + ! zonal mean + if (present(wght)) then + + wsums = esmf_zonal_mean_wsums(wght) + call esmf_zonal_mean_masked(lonlatarr, wght, wsums, zmarr) + + else + + do ilat = lat_beg, lat_end + call shr_reprosum_calc(lonlatarr(lon_beg:lon_end,ilat,:), gsum, numlons, numlons, pver, gbl_count=nlon, commid=zonal_comm) + zmarr(ilat,:) = gsum(:)/nlon + end do + + end if + + end subroutine esmf_zonal_mean_calc_3d + + !------------------------------------------------------------------------------ + ! Computes zonal mean for 2D lon / lat fields. + !------------------------------------------------------------------------------ + subroutine esmf_zonal_mean_calc_2d(lonlatarr, zmarr) + use esmf_lonlat_grid_mod, only: lon_beg,lon_end,lat_beg,lat_end, nlon + use esmf_lonlat_grid_mod, only: zonal_comm + use shr_reprosum_mod,only: shr_reprosum_calc + + real(r8), intent(in) :: lonlatarr(lon_beg:lon_end,lat_beg:lat_end) + real(r8), intent(out) :: zmarr(lat_beg:lat_end) + + real(r8) :: gsum(lat_beg:lat_end) + + integer :: numlons, numlats + + numlons = lon_end-lon_beg+1 + numlats = lat_end-lat_beg+1 + + ! zonal mean + + call shr_reprosum_calc(lonlatarr, gsum, numlons, numlons, numlats, gbl_count=nlon, commid=zonal_comm) + zmarr(:) = gsum(:)/nlon + + end subroutine esmf_zonal_mean_calc_2d + + !------------------------------------------------------------------------------ + ! Computes longitude sums of grid cell weights. + !------------------------------------------------------------------------------ + function esmf_zonal_mean_wsums(wght) result(wsums) + use esmf_lonlat_grid_mod, only: lon_beg,lon_end,lat_beg,lat_end, nlon + use esmf_lonlat_grid_mod, only: zonal_comm + use shr_reprosum_mod,only: shr_reprosum_calc + use ppgrid, only: pver + + real(r8), intent(in) :: wght(lon_beg:lon_end,lat_beg:lat_end,pver) + + real(r8) :: wsums(lat_beg:lat_end,pver) + integer :: numlons, ilat + + numlons = lon_end-lon_beg+1 + + do ilat = lat_beg, lat_end + + call shr_reprosum_calc(wght(lon_beg:lon_end,ilat,:), wsums(ilat,1:pver), & + numlons, numlons, pver, gbl_count=nlon, commid=zonal_comm) + + end do + + end function esmf_zonal_mean_wsums + + !------------------------------------------------------------------------------ + ! Masks out regions (e.g. mountains) from zonal mean calculation. + !------------------------------------------------------------------------------ + subroutine esmf_zonal_mean_masked(lonlatarr, wght, wsums, zmarr) + use esmf_lonlat_grid_mod, only: lon_beg,lon_end,lat_beg,lat_end, nlon + use esmf_lonlat_grid_mod, only: zonal_comm + use shr_reprosum_mod,only: shr_reprosum_calc + use ppgrid, only: pver + + real(r8), intent(in) :: lonlatarr(lon_beg:lon_end,lat_beg:lat_end,pver) + real(r8), intent(in) :: wght(lon_beg:lon_end,lat_beg:lat_end,pver) ! 
grid cell weights + real(r8), intent(in) :: wsums(lat_beg:lat_end,pver) ! pre-computed sums of grid cell weights + real(r8), intent(out) :: zmarr(lat_beg:lat_end,pver) ! zonal means + + real(r8) :: tmparr(lon_beg:lon_end,pver) + integer :: numlons, ilat, ilev + real(r8) :: gsum(pver) + + numlons = lon_end-lon_beg+1 + + do ilat = lat_beg, lat_end + + tmparr(lon_beg:lon_end,:) = wght(lon_beg:lon_end,ilat,:)*lonlatarr(lon_beg:lon_end,ilat,:) + call shr_reprosum_calc(tmparr, gsum, numlons, numlons, pver, gbl_count=nlon, commid=zonal_comm) + + do ilev = 1,pver + if (wsums(ilat,ilev)>0._r8) then + zmarr(ilat,ilev) = gsum(ilev)/wsums(ilat,ilev) + else + zmarr(ilat,ilev) = fillvalue + end if + end do + + end do + + end subroutine esmf_zonal_mean_masked + +end module esmf_zonal_mean_mod diff --git a/src/utils/remap.F90 b/src/utils/remap.F90 new file mode 100644 index 0000000000..cdc81607af --- /dev/null +++ b/src/utils/remap.F90 @@ -0,0 +1,212 @@ +!----------------------------------------------------------------------------- +! utilities to gather and distribute columns and remap them to/from +! cubed-sphere grid to a lat-lon grid. +!----------------------------------------------------------------------------- +module nlgw_remap_mod + use shr_kind_mod, only: r8 => shr_kind_r8, cx => SHR_KIND_CX + use ppgrid, only: begchunk, endchunk, pcols, pver, pverp + use physics_types, only: physics_state + use phys_grid, only: get_ncols_p + use spmd_utils, only: masterproc + use ref_pres, only: pref_mid + use esmf_lonlat_grid_mod, only: beglon=>lon_beg, endlon=>lon_end, beglat=>lat_beg, endlat=>lat_end + use cam_history, only: addfld, outfld, horiz_only + use cam_history_support, only : fillvalue + use perf_mod, only: t_startf, t_stopf + use cam_logfile, only: iulog + use cam_abortutils, only: endrun + + implicit none + + private + + public :: nlgw_regrid_init + public :: nlgw_regrid + public :: nlgw_regrid_final + +contains + + !----------------------------------------------------------------------------- + !----------------------------------------------------------------------------- + subroutine nlgw_regrid_init() + use cam_grid_support, only: horiz_coord_t, horiz_coord_create, iMap, cam_grid_register + use esmf_lonlat_grid_mod, only: glats, nlat, glons, nlon + use esmf_lonlat_grid_mod, only: esmf_lonlat_grid_init + use esmf_phys_mesh_mod, only: esmf_phys_mesh_init + use esmf_phys2lonlat_mod, only: esmf_phys2lonlat_init + + integer, parameter :: reg_decomp = 332 + + integer(iMap), pointer :: grid_map(:,:) + + integer(iMap), pointer :: coord_map(:) => null() + type(horiz_coord_t), pointer :: lon_coord + type(horiz_coord_t), pointer :: lat_coord + integer :: i, j, ind, astat + + character(len=*), parameter :: subname = 'ctem_diags_reg: ' + + ! initialize grids and mapping + call esmf_lonlat_grid_init(64) + call esmf_phys_mesh_init() + call esmf_phys2lonlat_init() + + ! 
for the lon-lat grid + allocate(grid_map(4, ((endlon - beglon + 1) * (endlat - beglat + 1))), stat=astat) + if (astat/=0) then + call endrun(subname//'not able to allocate grid_map array') + end if + + ind = 0 + do i = beglat, endlat + do j = beglon, endlon + ind = ind + 1 + grid_map(1, ind) = j + grid_map(2, ind) = i + grid_map(3, ind) = j + grid_map(4, ind) = i + end do + end do + + allocate(coord_map(endlat - beglat + 1), stat=astat) + if (astat/=0) then + call endrun(subname//'not able to allocate coord_map array') + end if + + if (beglon==1) then + coord_map = (/ (i, i = beglat, endlat) /) + else + coord_map = 0 + end if + lat_coord => horiz_coord_create('reglat', '', nlat, 'latitude', 'degrees_north', beglat, endlat, & + glats(beglat:endlat), map=coord_map) + + nullify(coord_map) + + allocate(coord_map(endlon - beglon + 1), stat=astat) + if (astat/=0) then + call endrun(subname//'not able to allocate coord_map array') + end if + + if (beglat==1) then + coord_map = (/ (i, i = beglon, endlon) /) + else + coord_map = 0 + end if + + lon_coord => horiz_coord_create('reglon', '', nlon, 'longitude', 'degrees_east', beglon, endlon, & + glons(beglon:endlon), map=coord_map) + + nullify(coord_map) + + call cam_grid_register('ctem_lonlat', reg_decomp, lat_coord, lon_coord, grid_map, unstruct=.false.) + + nullify(grid_map) + + end subroutine nlgw_regrid_init + + !----------------------------------------------------------------------------- + !----------------------------------------------------------------------------- + subroutine nlgw_regrid(phys_state) + use air_composition, only: mbarv ! g/mole + use shr_const_mod, only: rgas => shr_const_rgas ! J/K/kmole + use shr_const_mod, only: grav => shr_const_g ! m/s2 + use esmf_phys2lonlat_mod, only: esmf_phys2lonlat_regrid + use esmf_zonal_mean_mod, only: esmf_zonal_mean_calc, esmf_zonal_mean_wsums, esmf_zonal_mean_masked + use interpolate_data, only: lininterp + use esmf_phys2lonlat_mod, only: fields_bundle_t, nflds + + type(physics_state), intent(in) :: phys_state(begchunk:endchunk) + + ! arrays on physics grid + real(r8), target :: u_phys(pver,pcols,begchunk:endchunk) + real(r8), target :: v_phys(pver,pcols,begchunk:endchunk) + real(r8), target :: w_phys(pver,pcols,begchunk:endchunk) + real(r8), target :: t_phys(pver,pcols,begchunk:endchunk) + real(r8), target :: pmid_phys(pver,pcols,begchunk:endchunk) + real(r8) :: ps_phys(pcols,begchunk:endchunk) + real(r8) :: lat_phys(pcols,begchunk:endchunk) + real(r8) :: lon_phys(pcols,begchunk:endchunk) + + ! arrays on latlon grid + real(r8), target :: u_lonlat(beglon:endlon,beglat:endlat,pver) + real(r8), target :: v_lonlat(beglon:endlon,beglat:endlat,pver) + real(r8), target :: w_lonlat(beglon:endlon,beglat:endlat,pver) + real(r8), target :: t_lonlat(beglon:endlon,beglat:endlat,pver) + real(r8), target :: pmid_lonlat(beglon:endlon,beglat:endlat,pver) + real(r8) :: ps_lonlat(beglon:endlon,beglat:endlat) + real(r8) :: lat_lonlat(beglon:endlon,beglat:endlat) + real(r8) :: lon_lonlat(beglon:endlon,beglat:endlat) + + integer :: lchnk, ncol, i + + type(fields_bundle_t) :: physflds(nflds) + type(fields_bundle_t) :: lonlatflds(nflds) + + call t_startf('nlgw_gather') + + call t_startf('nlgw_unchunk') + + do lchnk = begchunk,endchunk + ncol = phys_state(lchnk)%ncol + do i = 1,ncol + ! 
wind components + u_phys(:,i,lchnk) = phys_state(lchnk)%u(i,:) + v_phys(:,i,lchnk) = phys_state(lchnk)%v(i,:) + w_phys(:,i,lchnk) = phys_state(lchnk)%omega(i,:) + t_phys(:,i,lchnk) = phys_state(lchnk)%t(i,:) + pmid_phys(:,i,lchnk) = phys_state(lchnk)%pmid(i,:) + + ps_phys(i,lchnk) = phys_state(lchnk)%ps(i) + lat_phys(i,lchnk) = phys_state(lchnk)%lat(i) + lon_phys(i,lchnk) = phys_state(lchnk)%lon(i) + + end do + end do + + call t_stopf('nlgw_unchunk') + + call t_startf('nlgw_regrid') + + ! regrid to lon/lat grid + + physflds(1)%fld => u_phys + physflds(2)%fld => v_phys + physflds(3)%fld => w_phys + physflds(4)%fld => t_phys + physflds(5)%fld => pmid_phys + + lonlatflds(1)%fld => u_lonlat + lonlatflds(2)%fld => v_lonlat + lonlatflds(3)%fld => w_lonlat + lonlatflds(4)%fld => t_lonlat + lonlatflds(5)%fld => pmid_lonlat + + call esmf_phys2lonlat_regrid(physflds, lonlatflds) + + call esmf_phys2lonlat_regrid(ps_phys, ps_lonlat) + call esmf_phys2lonlat_regrid(lat_phys, lat_lonlat) + call esmf_phys2lonlat_regrid(lon_phys, lon_lonlat) + + call t_stopf('nlgw_regrid') + + call t_stopf('nlgw_gather') + + end subroutine nlgw_regrid + + !----------------------------------------------------------------------------- + !----------------------------------------------------------------------------- + subroutine nlgw_regrid_final() + use esmf_phys2lonlat_mod, only: esmf_phys2lonlat_destroy + use esmf_lonlat_grid_mod, only: esmf_lonlat_grid_destroy + use esmf_phys_mesh_mod, only: esmf_phys_mesh_destroy + + if (.not.ctem_diags_active) return + + call esmf_phys2lonlat_destroy() + call esmf_lonlat_grid_destroy() + call esmf_phys_mesh_destroy() + + end subroutine nlgw_regrid_final + +end module nlgw_remap_mod From e4e12963f5e903e3e0b04dc5f37410bdde049159 Mon Sep 17 00:00:00 2001 From: tommelt Date: Fri, 27 Jun 2025 07:59:50 -0600 Subject: [PATCH 02/23] chore: remove unused code --- src/physics/cam/gw_drag.F90 | 5 ----- src/utils/remap.F90 | 2 -- 2 files changed, 7 deletions(-) diff --git a/src/physics/cam/gw_drag.F90 b/src/physics/cam/gw_drag.F90 index 6f929ad190..9a3a651060 100644 --- a/src/physics/cam/gw_drag.F90 +++ b/src/physics/cam/gw_drag.F90 @@ -45,7 +45,6 @@ module gw_drag use gw_ml, only: gw_drag_convect_dp_ml_init, gw_drag_convect_dp_ml_final, & gw_drag_convect_dp_ml use gw_nlgw, only: gw_nlgw_dp_ml, gw_nlgw_dp_init, gw_nlgw_dp_finalize - use nlgw_remap_mod, only: nlgw_regrid_init, nlgw_regrid, nlgw_regrid_init ! Typical module header implicit none @@ -1490,10 +1489,6 @@ subroutine gw_tend(state, pbuf, dt, ptend, cam_in, flx_heat) ! 
constituents are all treated as wet mmr call set_dry_to_wet(state1) - call nlgw_regrid_init() - call nlgw_regrid(state1) - stop - lchnk = state1%lchnk ncol = state1%ncol diff --git a/src/utils/remap.F90 b/src/utils/remap.F90 index cdc81607af..28431fdbc9 100644 --- a/src/utils/remap.F90 +++ b/src/utils/remap.F90 @@ -201,8 +201,6 @@ subroutine nlgw_regrid_final() use esmf_lonlat_grid_mod, only: esmf_lonlat_grid_destroy use esmf_phys_mesh_mod, only: esmf_phys_mesh_destroy - if (.not.ctem_diags_active) return - call esmf_phys2lonlat_destroy() call esmf_lonlat_grid_destroy() call esmf_phys_mesh_destroy() From dc338392d6174a1d6e7639b44c8550e32618900d Mon Sep 17 00:00:00 2001 From: tommelt Date: Wed, 30 Jul 2025 06:30:53 -0600 Subject: [PATCH 03/23] wip: regridding debug --- src/physics/cam_dev/physpkg.F90 | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/physics/cam_dev/physpkg.F90 b/src/physics/cam_dev/physpkg.F90 index 1c461c9a1c..516b39046f 100644 --- a/src/physics/cam_dev/physpkg.F90 +++ b/src/physics/cam_dev/physpkg.F90 @@ -869,7 +869,7 @@ subroutine phys_init( phys_state, phys_tend, pbuf2d, cam_in, cam_out ) call co2_init() end if - call gw_init() + ! call gw_init() call rayleigh_friction_init() @@ -1183,6 +1183,7 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, & use metdata, only: get_met_srf2 #endif use hemco_interface, only: HCOI_Chunk_Run + use nlgw_remap_mod, only: nlgw_regrid_init, nlgw_regrid, nlgw_regrid_final ! ! Input arguments ! @@ -1244,6 +1245,10 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, & call t_startf ('ac_physics') call t_adj_detailf(+1) + call nlgw_regrid_init() + call nlgw_regrid(phys_state) + stop + !$OMP PARALLEL DO PRIVATE (C, NCOL, phys_buffer_chunk) do c=begchunk,endchunk From 84e4f07b381d8e0f4651d7cad573b2354884d8ed Mon Sep 17 00:00:00 2001 From: tommelt Date: Wed, 20 Aug 2025 10:20:18 -0600 Subject: [PATCH 04/23] wip: use 192x288 grid instead --- src/utils/esmf_lonlat_grid_mod.F90 | 5 +++-- src/utils/remap.F90 | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/utils/esmf_lonlat_grid_mod.F90 b/src/utils/esmf_lonlat_grid_mod.F90 index 23b575ac61..394322ddc6 100644 --- a/src/utils/esmf_lonlat_grid_mod.F90 +++ b/src/utils/esmf_lonlat_grid_mod.F90 @@ -33,11 +33,12 @@ module esmf_lonlat_grid_mod contains - subroutine esmf_lonlat_grid_init(nlats_in) + subroutine esmf_lonlat_grid_init(nlats_in, nlons_in) use phys_grid, only: get_grid_dims use mpi, only: mpi_comm_size, mpi_comm_rank, MPI_PROC_NULL, MPI_INTEGER integer, intent(in) :: nlats_in + integer, intent(in) :: nlons_in real(r8) :: delx, dely @@ -78,7 +79,7 @@ subroutine esmf_lonlat_grid_init(nlats_in) nlat = nlats_in dely = 180._r8/nlat - nlon = 2*nlat + nlon = nlons_in delx = 360._r8/nlon allocate(glons(nlon), stat=astat) diff --git a/src/utils/remap.F90 b/src/utils/remap.F90 index 28431fdbc9..7769bcb88e 100644 --- a/src/utils/remap.F90 +++ b/src/utils/remap.F90 @@ -47,7 +47,7 @@ subroutine nlgw_regrid_init() character(len=*), parameter :: subname = 'ctem_diags_reg: ' ! 
initialize grids and mapping - call esmf_lonlat_grid_init(64) + call esmf_lonlat_grid_init(192, 288) call esmf_phys_mesh_init() call esmf_phys2lonlat_init() From 28e0530fd9bbc462445ae24b766293b5f4f89a3b Mon Sep 17 00:00:00 2001 From: tommelt Date: Wed, 20 Aug 2025 10:20:37 -0600 Subject: [PATCH 05/23] wip: try to fix corner issue in regrid --- src/utils/esmf_phys2lonlat_mod.F90 | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/utils/esmf_phys2lonlat_mod.F90 b/src/utils/esmf_phys2lonlat_mod.F90 index 15e46bc79c..a191cf31e5 100644 --- a/src/utils/esmf_phys2lonlat_mod.F90 +++ b/src/utils/esmf_phys2lonlat_mod.F90 @@ -13,7 +13,7 @@ module esmf_phys2lonlat_mod use ESMF, only: ESMF_FieldCreate, ESMF_FieldRegridStore use ESMF, only: ESMF_FieldGet, ESMF_FieldRegrid use ESMF, only: ESMF_KIND_I4, ESMF_KIND_R8, ESMF_TYPEKIND_R8 - use ESMF, only: ESMF_REGRIDMETHOD_BILINEAR, ESMF_POLEMETHOD_ALLAVG, ESMF_EXTRAPMETHOD_NEAREST_IDAVG + use ESMF, only: ESMF_REGRIDMETHOD_BILINEAR, ESMF_POLEMETHOD_NONE, ESMF_EXTRAPMETHOD_NEAREST_IDAVG use ESMF, only: ESMF_TERMORDER_SRCSEQ, ESMF_MESHLOC_ELEMENT, ESMF_STAGGERLOC_CENTER use ESMF, only: ESMF_FieldDestroy, ESMF_RouteHandleDestroy use esmf_check_error_mod, only: check_esmf_error @@ -33,9 +33,11 @@ module esmf_phys2lonlat_mod type(ESMF_Field) :: physfld_3d type(ESMF_Field) :: lonlatfld_3d + type(ESMF_Field) :: lonlatfld_3d_copy type(ESMF_Field) :: physfld_2d type(ESMF_Field) :: lonlatfld_2d + type(ESMF_Field) :: lonlatfld_2d_copy interface esmf_phys2lonlat_regrid module procedure esmf_phys2lonlat_regrid_2d @@ -84,6 +86,7 @@ subroutine esmf_phys2lonlat_init() lonlatfld_3d = ESMF_FieldCreate( lonlat_grid, arrayspec, staggerloc=ESMF_STAGGERLOC_CENTER, & ungriddedLBound=(/1,1/), ungriddedUBound=(/pver,nflds/), rc=rc) call check_esmf_error(rc, subname//'ESMF_FieldCreate 3D lonlat fld ERROR') + lonlatfld_3d_copy = lonlatfld_3d ! 
2D phys fld call ESMF_ArraySpecSet(arrayspec, 1, ESMF_TYPEKIND_R8, rc=rc) @@ -99,19 +102,20 @@ subroutine esmf_phys2lonlat_init() lonlatfld_2d = ESMF_FieldCreate( lonlat_grid, arrayspec, staggerloc=ESMF_STAGGERLOC_CENTER, rc=rc) call check_esmf_error(rc, subname//'ESMF_FieldCreate 2D lonlat fld ERROR') + lonlatfld_2d_copy = lonlatfld_2d - call ESMF_FieldRegridStore(srcField=physfld_3d, dstField=lonlatfld_3d, & + call ESMF_FieldRegridStore(srcField=physfld_3d, dstField=lonlatfld_3d_copy, & regridMethod=ESMF_REGRIDMETHOD_BILINEAR, & - polemethod=ESMF_POLEMETHOD_ALLAVG, & + polemethod=ESMF_POLEMETHOD_NONE, & extrapMethod=ESMF_EXTRAPMETHOD_NEAREST_IDAVG, & routeHandle=rh_phys2lonlat_3d, factorIndexList=factorIndexList, & factorList=factorList, srcTermProcessing=smm_srctermproc, & pipelineDepth=smm_pipelinedep, rc=rc) call check_esmf_error(rc, subname//'ESMF_FieldRegridStore 3D routehandle ERROR') - call ESMF_FieldRegridStore(srcField=physfld_2d, dstField=lonlatfld_2d, & + call ESMF_FieldRegridStore(srcField=physfld_2d, dstField=lonlatfld_2d_copy, & regridMethod=ESMF_REGRIDMETHOD_BILINEAR, & - polemethod=ESMF_POLEMETHOD_ALLAVG, & + polemethod=ESMF_POLEMETHOD_NONE, & extrapMethod=ESMF_EXTRAPMETHOD_NEAREST_IDAVG, & routeHandle=rh_phys2lonlat_2d, factorIndexList=factorIndexList, & factorList=factorList, srcTermProcessing=smm_srctermproc, & From d7de81ca3330ddcc51cc38f9fe2f4dcb8b87384c Mon Sep 17 00:00:00 2001 From: tommelt Date: Tue, 2 Sep 2025 09:06:05 -0600 Subject: [PATCH 06/23] chore: remove unnecessary vars --- src/utils/remap.F90 | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/utils/remap.F90 b/src/utils/remap.F90 index 7769bcb88e..b164a326dd 100644 --- a/src/utils/remap.F90 +++ b/src/utils/remap.F90 @@ -125,8 +125,9 @@ subroutine nlgw_regrid(phys_state) real(r8), target :: t_phys(pver,pcols,begchunk:endchunk) real(r8), target :: pmid_phys(pver,pcols,begchunk:endchunk) real(r8) :: ps_phys(pcols,begchunk:endchunk) - real(r8) :: lat_phys(pcols,begchunk:endchunk) - real(r8) :: lon_phys(pcols,begchunk:endchunk) + ! for debugging only + ! real(r8) :: lat_phys(pcols,begchunk:endchunk) + ! real(r8) :: lon_phys(pcols,begchunk:endchunk) ! arrays on latlon grid real(r8), target :: u_lonlat(beglon:endlon,beglat:endlat,pver) @@ -135,8 +136,6 @@ subroutine nlgw_regrid(phys_state) real(r8), target :: t_lonlat(beglon:endlon,beglat:endlat,pver) real(r8), target :: pmid_lonlat(beglon:endlon,beglat:endlat,pver) real(r8) :: ps_lonlat(beglon:endlon,beglat:endlat) - real(r8) :: lat_lonlat(beglon:endlon,beglat:endlat) - real(r8) :: lon_lonlat(beglon:endlon,beglat:endlat) integer :: lchnk, ncol, i @@ -158,8 +157,9 @@ subroutine nlgw_regrid(phys_state) pmid_phys(:,i,lchnk) = phys_state(lchnk)%pmid(i,:) ps_phys(i,lchnk) = phys_state(lchnk)%ps(i) - lat_phys(i,lchnk) = phys_state(lchnk)%lat(i) - lon_phys(i,lchnk) = phys_state(lchnk)%lon(i) + ! for debugging only + ! lat_phys(i,lchnk) = phys_state(lchnk)%lat(i) + ! 
lon_phys(i,lchnk) = phys_state(lchnk)%lon(i) end do end do @@ -185,8 +185,6 @@ subroutine nlgw_regrid(phys_state) call esmf_phys2lonlat_regrid(physflds, lonlatflds) call esmf_phys2lonlat_regrid(ps_phys, ps_lonlat) - call esmf_phys2lonlat_regrid(lat_phys, lat_lonlat) - call esmf_phys2lonlat_regrid(lon_phys, lon_lonlat) call t_stopf('nlgw_regrid') From 8ed58f0524dfe93a1d6ed0ea01a0fae83a40dbcc Mon Sep 17 00:00:00 2001 From: tommelt Date: Tue, 2 Sep 2025 16:57:27 +0100 Subject: [PATCH 07/23] wip: gathering ps onto masterproc --- src/utils/remap.F90 | 41 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/src/utils/remap.F90 b/src/utils/remap.F90 index b164a326dd..74146bb6c5 100644 --- a/src/utils/remap.F90 +++ b/src/utils/remap.F90 @@ -7,7 +7,7 @@ module nlgw_remap_mod use ppgrid, only: begchunk, endchunk, pcols, pver, pverp use physics_types, only: physics_state use phys_grid, only: get_ncols_p - use spmd_utils, only: masterproc + use spmd_utils, only: masterproc, npes use ref_pres, only: pref_mid use esmf_lonlat_grid_mod, only: beglon=>lon_beg, endlon=>lon_end, beglat=>lat_beg, endlat=>lat_end use cam_history, only: addfld, outfld, horiz_only @@ -47,7 +47,7 @@ subroutine nlgw_regrid_init() character(len=*), parameter :: subname = 'ctem_diags_reg: ' ! initialize grids and mapping - call esmf_lonlat_grid_init(192, 288) + call esmf_lonlat_grid_init(64, 128) call esmf_phys_mesh_init() call esmf_phys2lonlat_init() @@ -111,10 +111,12 @@ subroutine nlgw_regrid(phys_state) use air_composition, only: mbarv ! g/mole use shr_const_mod, only: rgas => shr_const_rgas ! J/K/kmole use shr_const_mod, only: grav => shr_const_g ! m/s2 + use esmf_lonlat_grid_mod, only: nlat, nlon use esmf_phys2lonlat_mod, only: esmf_phys2lonlat_regrid use esmf_zonal_mean_mod, only: esmf_zonal_mean_calc, esmf_zonal_mean_wsums, esmf_zonal_mean_masked use interpolate_data, only: lininterp use esmf_phys2lonlat_mod, only: fields_bundle_t, nflds + use mpishorthand type(physics_state), intent(in) :: phys_state(begchunk:endchunk) @@ -137,7 +139,11 @@ subroutine nlgw_regrid(phys_state) real(r8), target :: pmid_lonlat(beglon:endlon,beglat:endlat,pver) real(r8) :: ps_lonlat(beglon:endlon,beglat:endlat) - integer :: lchnk, ncol, i + real(r8), allocatable :: ps_flat(:) + real(r8), allocatable :: ps_grid(:, :) + + integer :: lchnk, ncol, i, sendcnt, disp_sum + integer, allocatable :: recvcnts(:), displs(:) type(fields_bundle_t) :: physflds(nflds) type(fields_bundle_t) :: lonlatflds(nflds) @@ -188,6 +194,35 @@ subroutine nlgw_regrid(phys_state) call t_stopf('nlgw_regrid') + call t_startf('nlgw_mpigather') + + ! 
gather ps_lonlat onto master proc here using MPI gather + allocate(recvcnts(npes)) + allocate(displs(npes)) + + sendcnt = (endlon - beglon) * (endlat - beglat) + + call mpigather(sendcnt, 1, mpiint, recvcnts, 1, mpiint, 0, mpicom) + + if (masterproc) then + allocate(ps_flat(nlon*nlat)) + allocate(ps_grid(nlon, nlat)) + disp_sum = 0 + do i = 1, npes + displs(i) = disp_sum + disp_sum = disp_sum + recvcnts(i) + end do + end if + + call mpigatherv(ps_lonlat, sendcnt, mpir8, ps_grid, recvcnts, displs, mpir8, 0, mpicom) + + call t_stopf('nlgw_mpigather') + + if (masterproc) then + deallocate(ps_flat) + deallocate(ps_grid) + end if + call t_stopf('nlgw_gather') end subroutine nlgw_regrid From e55c585fc8cd415334d4d1f0fb724b94f7b00dd2 Mon Sep 17 00:00:00 2001 From: tommelt Date: Wed, 3 Sep 2025 11:06:53 -0600 Subject: [PATCH 08/23] wip: 2d field gather works, need to do 3d now --- src/utils/remap.F90 | 54 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 46 insertions(+), 8 deletions(-) diff --git a/src/utils/remap.F90 b/src/utils/remap.F90 index 74146bb6c5..7c2ae82c8c 100644 --- a/src/utils/remap.F90 +++ b/src/utils/remap.F90 @@ -143,7 +143,12 @@ subroutine nlgw_regrid(phys_state) real(r8), allocatable :: ps_grid(:, :) integer :: lchnk, ncol, i, sendcnt, disp_sum + integer :: lonsize, latsize + integer :: tompver, tompcols + integer, allocatable :: recvcnts(:), displs(:) + integer, allocatable :: beglats(:), beglons(:) + integer, allocatable :: endlats(:), endlons(:) type(fields_bundle_t) :: physflds(nflds) type(fields_bundle_t) :: lonlatflds(nflds) @@ -173,8 +178,7 @@ subroutine nlgw_regrid(phys_state) call t_stopf('nlgw_unchunk') call t_startf('nlgw_regrid') - - ! regrid to lon/lat grid + ! this subsection does regridding physflds(1)%fld => u_phys physflds(2)%fld => v_phys @@ -188,25 +192,38 @@ subroutine nlgw_regrid(phys_state) lonlatflds(4)%fld => t_lonlat lonlatflds(5)%fld => pmid_lonlat + ! actual call to regrid to lon/lat grid call esmf_phys2lonlat_regrid(physflds, lonlatflds) - call esmf_phys2lonlat_regrid(ps_phys, ps_lonlat) + ! TODO + ! convert t to theta before gathering + ! we dont need ps we need phis + call t_stopf('nlgw_regrid') call t_startf('nlgw_mpigather') + ! this subsection gathers all variables onto a single process - ! gather ps_lonlat onto master proc here using MPI gather allocate(recvcnts(npes)) allocate(displs(npes)) + allocate(beglats(npes)) + allocate(beglons(npes)) + allocate(endlats(npes)) + allocate(endlons(npes)) + allocate(ps_flat(nlon * nlat)) + allocate(ps_grid(nlon, nlat)) - sendcnt = (endlon - beglon) * (endlat - beglat) + sendcnt = (endlon - beglon + 1) * (endlat - beglat + 1) + ! mpi gather book-keeping call mpigather(sendcnt, 1, mpiint, recvcnts, 1, mpiint, 0, mpicom) + call mpigather(beglat, 1, mpiint, beglats, 1, mpiint, 0, mpicom) + call mpigather(beglon, 1, mpiint, beglons, 1, mpiint, 0, mpicom) + call mpigather(endlat, 1, mpiint, endlats, 1, mpiint, 0, mpicom) + call mpigather(endlon, 1, mpiint, endlons, 1, mpiint, 0, mpicom) if (masterproc) then - allocate(ps_flat(nlon*nlat)) - allocate(ps_grid(nlon, nlat)) disp_sum = 0 do i = 1, npes displs(i) = disp_sum @@ -214,11 +231,32 @@ subroutine nlgw_regrid(phys_state) end do end if - call mpigatherv(ps_lonlat, sendcnt, mpir8, ps_grid, recvcnts, displs, mpir8, 0, mpicom) + ! 
gather variables onto master proc into a flat array (can't do 2D/3D mpigather) + call mpigatherv(ps_lonlat(beglon:endlon, beglat:endlat), sendcnt, mpir8, ps_flat, recvcnts, displs, mpir8, 0, mpicom) + + tompver = pver + tompcols = pcols + print *, tompver + print *, tompcols + + if (masterproc) then + do i = 1, npes + lonsize = endlons(i) - beglons(i) + 1 + latsize = endlats(i) - beglats(i) + 1 + ! reshape each ranks flattended data and populate each block into a single lonlat grid + ps_grid(beglons(i):endlons(i), beglats(i):endlats(i)) = & + reshape(ps_flat(displs(i)+1:displs(i)+sendcnt), (/ lonsize, latsize /)) + end do + end if call t_stopf('nlgw_mpigather') + ! TODO + ! convert fluxes to tendencies after regridding back to cubed sphere + ! that way we dont need pmid + if (masterproc) then + ! TODO ALL deallocates here deallocate(ps_flat) deallocate(ps_grid) end if From a7e7de0d0b655294b0510ea89843dafcc71be729 Mon Sep 17 00:00:00 2001 From: Francis Vitt Date: Wed, 3 Sep 2025 15:46:00 -0600 Subject: [PATCH 09/23] untested code for lonlat2phys regrid --- src/utils/esmf_lonlat2phys_mod.F90 | 160 +++++++++++++++++++++++++++++ src/utils/esmf_phys2lonlat_mod.F90 | 12 +-- src/utils/remap.F90 | 8 ++ 3 files changed, 172 insertions(+), 8 deletions(-) create mode 100644 src/utils/esmf_lonlat2phys_mod.F90 diff --git a/src/utils/esmf_lonlat2phys_mod.F90 b/src/utils/esmf_lonlat2phys_mod.F90 new file mode 100644 index 0000000000..63a388d10a --- /dev/null +++ b/src/utils/esmf_lonlat2phys_mod.F90 @@ -0,0 +1,160 @@ +!------------------------------------------------------------------------------ +! Provides methods for mapping from regular longitude / latitude grid +! to physics grid to via ESMF regridding capabilities +!------------------------------------------------------------------------------ +module esmf_lonlat2phys_mod + use shr_kind_mod, only: r8 => shr_kind_r8 + use cam_logfile, only: iulog + use cam_abortutils, only: endrun + use spmd_utils, only: masterproc + use ppgrid, only: pver + + use ESMF, only: ESMF_RouteHandle, ESMF_Field, ESMF_ArraySpec, ESMF_ArraySpecSet + use ESMF, only: ESMF_FieldCreate, ESMF_FieldRegridStore + use ESMF, only: ESMF_FieldGet, ESMF_FieldRegrid + use ESMF, only: ESMF_KIND_I4, ESMF_KIND_R8, ESMF_TYPEKIND_R8 + use ESMF, only: ESMF_REGRIDMETHOD_BILINEAR, ESMF_POLEMETHOD_NONE, ESMF_EXTRAPMETHOD_NEAREST_IDAVG + use ESMF, only: ESMF_TERMORDER_SRCSEQ, ESMF_MESHLOC_ELEMENT, ESMF_STAGGERLOC_CENTER + use ESMF, only: ESMF_FieldDestroy, ESMF_RouteHandleDestroy + use esmf_check_error_mod, only: check_esmf_error + + implicit none + + private + + public :: esmf_lonlat2phys_init + public :: esmf_lonlat2phys_regrid + public :: esmf_lonlat2phys_destroy + public :: fields_bundle_t + public :: n_flx_flds + + type(ESMF_RouteHandle) :: rh_lonlat2phys_3d + + type(ESMF_Field) :: physfld_3d + type(ESMF_Field) :: lonlatfld_3d + + interface esmf_lonlat2phys_regrid + module procedure esmf_lonlat2phys_regrid_3d + end interface esmf_lonlat2phys_regrid + + type :: fields_bundle_t + real(r8), pointer :: fld(:,:,:) => null() + end type fields_bundle_t + + integer, parameter :: n_flx_flds = 2 ! 
3D uflux and vflux + +contains + + !------------------------------------------------------------------------------ + !------------------------------------------------------------------------------ + subroutine esmf_lonlat2phys_init() + use esmf_phys_mesh_mod, only: physics_grid_mesh + use esmf_lonlat_grid_mod, only: lonlat_grid + + type(ESMF_ArraySpec) :: arrayspec + integer(ESMF_KIND_I4), pointer :: factorIndexList(:,:) + real(ESMF_KIND_R8), pointer :: factorList(:) + integer :: smm_srctermproc, smm_pipelinedep, rc + + character(len=*), parameter :: subname = 'esmf_lonlat2phys_init: ' + + smm_srctermproc = 0 + smm_pipelinedep = 16 + + ! create ESMF fields + + ! 3D phys fld + call ESMF_ArraySpecSet(arrayspec, 3, ESMF_TYPEKIND_R8, rc=rc) + call check_esmf_error(rc, subname//'ESMF_ArraySpecSet 3D phys fld ERROR') + + physfld_3d = ESMF_FieldCreate(physics_grid_mesh, arrayspec, & + gridToFieldMap=(/3/), meshloc=ESMF_MESHLOC_ELEMENT, & + ungriddedLBound=(/1,1/), ungriddedUBound=(/pver,n_flx_flds/), rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldCreate 3D phys fld ERROR') + + ! 3D lon lat grid + call ESMF_ArraySpecSet(arrayspec, 4, ESMF_TYPEKIND_R8, rc=rc) + call check_esmf_error(rc, subname//'ESMF_ArraySpecSet 3D lonlat fld ERROR') + + lonlatfld_3d = ESMF_FieldCreate( lonlat_grid, arrayspec, staggerloc=ESMF_STAGGERLOC_CENTER, & + ungriddedLBound=(/1,1/), ungriddedUBound=(/pver,n_flx_flds/), rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldCreate 3D lonlat fld ERROR') + + call ESMF_FieldRegridStore(srcField=lonlatfld_3d, dstField=physfld_3d, & + regridMethod=ESMF_REGRIDMETHOD_BILINEAR, & + polemethod=ESMF_POLEMETHOD_NONE, & + extrapMethod=ESMF_EXTRAPMETHOD_NEAREST_IDAVG, & + routeHandle=rh_lonlat2phys_3d, factorIndexList=factorIndexList, & + factorList=factorList, srcTermProcessing=smm_srctermproc, & + pipelineDepth=smm_pipelinedep, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldRegridStore 3D routehandle ERROR') + + end subroutine esmf_lonlat2phys_init + + !------------------------------------------------------------------------------ + !------------------------------------------------------------------------------ + subroutine esmf_lonlat2phys_regrid_3d(lonlatflds, physflds) + use esmf_lonlat_grid_mod, only: lon_beg,lon_end,lat_beg,lat_end + use ppgrid, only: pcols, pver, begchunk, endchunk + use phys_grid, only: get_ncols_p + + type(fields_bundle_t), intent(in) :: lonlatflds(n_flx_flds) + type(fields_bundle_t), intent(inout) :: physflds(n_flx_flds) + + integer :: i, ichnk, ncol, ifld, ilev, icol, rc + real(ESMF_KIND_R8), pointer :: physptr(:,:,:) + real(ESMF_KIND_R8), pointer :: lonlatptr(:,:,:,:) + + character(len=*), parameter :: subname = 'esmf_lonlat2phys_regrid_3d: ' + + ! set values of lonlatfld_3d ESMF field + call ESMF_FieldGet(lonlatfld_3d, localDe=0, farrayPtr=lonlatptr, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldGet lonlatptr') + + do ifld = 1,n_flx_flds + lonlatptr(lon_beg:lon_end,lat_beg:lat_end,1:pver,ifld) = lonlatflds(ifld)%fld(lon_beg:lon_end,lat_beg:lat_end,1:pver) + end do + + ! regrid + call ESMF_FieldRegrid(lonlatfld_3d, physfld_3d, rh_lonlat2phys_3d, & + termorderflag=ESMF_TERMORDER_SRCSEQ, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldRegrid lonlatfld_3d->physfld_3d') + + ! 
get values from physfld_3d ESMF field + call ESMF_FieldGet(physfld_3d, localDe=0, farrayPtr=physptr, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldGet physptr') + + i = 0 + do ichnk = begchunk, endchunk + ncol = get_ncols_p(ichnk) + do icol = 1,ncol + i = i+1 + do ifld = 1,n_flx_flds + do ilev = 1,pver + physflds(ifld)%fld(ilev,icol,ichnk) = physptr(ilev,ifld,i) + end do + end do + end do + end do + + end subroutine esmf_lonlat2phys_regrid_3d + + !------------------------------------------------------------------------------ + !------------------------------------------------------------------------------ + subroutine esmf_lonlat2phys_destroy() + + integer :: rc + character(len=*), parameter :: subname = 'esmf_lonlat2phys_destroy: ' + + call ESMF_RouteHandleDestroy(rh_lonlat2phys_3d, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldDestroy rh_lonlat2phys_3d') + + call ESMF_FieldDestroy(lonlatfld_3d, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldDestroy lonlatfld_3d') + + call ESMF_FieldDestroy(physfld_3d, rc=rc) + call check_esmf_error(rc, subname//'ESMF_FieldDestroy physfld_3d') + + end subroutine esmf_lonlat2phys_destroy + +end module esmf_lonlat2phys_mod diff --git a/src/utils/esmf_phys2lonlat_mod.F90 b/src/utils/esmf_phys2lonlat_mod.F90 index a191cf31e5..bf18e3b6ef 100644 --- a/src/utils/esmf_phys2lonlat_mod.F90 +++ b/src/utils/esmf_phys2lonlat_mod.F90 @@ -33,11 +33,9 @@ module esmf_phys2lonlat_mod type(ESMF_Field) :: physfld_3d type(ESMF_Field) :: lonlatfld_3d - type(ESMF_Field) :: lonlatfld_3d_copy type(ESMF_Field) :: physfld_2d type(ESMF_Field) :: lonlatfld_2d - type(ESMF_Field) :: lonlatfld_2d_copy interface esmf_phys2lonlat_regrid module procedure esmf_phys2lonlat_regrid_2d @@ -86,7 +84,6 @@ subroutine esmf_phys2lonlat_init() lonlatfld_3d = ESMF_FieldCreate( lonlat_grid, arrayspec, staggerloc=ESMF_STAGGERLOC_CENTER, & ungriddedLBound=(/1,1/), ungriddedUBound=(/pver,nflds/), rc=rc) call check_esmf_error(rc, subname//'ESMF_FieldCreate 3D lonlat fld ERROR') - lonlatfld_3d_copy = lonlatfld_3d ! 
2D phys fld call ESMF_ArraySpecSet(arrayspec, 1, ESMF_TYPEKIND_R8, rc=rc) @@ -102,20 +99,19 @@ subroutine esmf_phys2lonlat_init() lonlatfld_2d = ESMF_FieldCreate( lonlat_grid, arrayspec, staggerloc=ESMF_STAGGERLOC_CENTER, rc=rc) call check_esmf_error(rc, subname//'ESMF_FieldCreate 2D lonlat fld ERROR') - lonlatfld_2d_copy = lonlatfld_2d - call ESMF_FieldRegridStore(srcField=physfld_3d, dstField=lonlatfld_3d_copy, & + call ESMF_FieldRegridStore(srcField=physfld_3d, dstField=lonlatfld_3d, & regridMethod=ESMF_REGRIDMETHOD_BILINEAR, & - polemethod=ESMF_POLEMETHOD_NONE, & + polemethod=ESMF_POLEMETHOD_NONE, & extrapMethod=ESMF_EXTRAPMETHOD_NEAREST_IDAVG, & routeHandle=rh_phys2lonlat_3d, factorIndexList=factorIndexList, & factorList=factorList, srcTermProcessing=smm_srctermproc, & pipelineDepth=smm_pipelinedep, rc=rc) call check_esmf_error(rc, subname//'ESMF_FieldRegridStore 3D routehandle ERROR') - call ESMF_FieldRegridStore(srcField=physfld_2d, dstField=lonlatfld_2d_copy, & + call ESMF_FieldRegridStore(srcField=physfld_2d, dstField=lonlatfld_2d, & regridMethod=ESMF_REGRIDMETHOD_BILINEAR, & - polemethod=ESMF_POLEMETHOD_NONE, & + polemethod=ESMF_POLEMETHOD_NONE, & extrapMethod=ESMF_EXTRAPMETHOD_NEAREST_IDAVG, & routeHandle=rh_phys2lonlat_2d, factorIndexList=factorIndexList, & factorList=factorList, srcTermProcessing=smm_srctermproc, & diff --git a/src/utils/remap.F90 b/src/utils/remap.F90 index 7c2ae82c8c..af9ab9529f 100644 --- a/src/utils/remap.F90 +++ b/src/utils/remap.F90 @@ -34,6 +34,7 @@ subroutine nlgw_regrid_init() use esmf_lonlat_grid_mod, only: esmf_lonlat_grid_init use esmf_phys_mesh_mod, only: esmf_phys_mesh_init use esmf_phys2lonlat_mod, only: esmf_phys2lonlat_init + use esmf_lonlat2phys_mod, only: esmf_lonlat2phys_init integer, parameter :: reg_decomp = 332 @@ -50,6 +51,7 @@ subroutine nlgw_regrid_init() call esmf_lonlat_grid_init(64, 128) call esmf_phys_mesh_init() call esmf_phys2lonlat_init() + call esmf_lonlat2phys_init() ! 
for the lon-lat grid allocate(grid_map(4, ((endlon - beglon + 1) * (endlat - beglat + 1))), stat=astat) @@ -116,6 +118,7 @@ subroutine nlgw_regrid(phys_state) use esmf_zonal_mean_mod, only: esmf_zonal_mean_calc, esmf_zonal_mean_wsums, esmf_zonal_mean_masked use interpolate_data, only: lininterp use esmf_phys2lonlat_mod, only: fields_bundle_t, nflds + use esmf_lonlat2phys_mod, only: esmf_lonlat2phys_regrid, n_flx_flds use mpishorthand type(physics_state), intent(in) :: phys_state(begchunk:endchunk) @@ -153,6 +156,9 @@ subroutine nlgw_regrid(phys_state) type(fields_bundle_t) :: physflds(nflds) type(fields_bundle_t) :: lonlatflds(nflds) + type(fields_bundle_t) :: phys_flx_flds(n_flx_flds) + type(fields_bundle_t) :: lonlat_flx_flds(n_flx_flds) + call t_startf('nlgw_gather') call t_startf('nlgw_unchunk') @@ -269,10 +275,12 @@ end subroutine nlgw_regrid !----------------------------------------------------------------------------- subroutine nlgw_regrid_final() use esmf_phys2lonlat_mod, only: esmf_phys2lonlat_destroy + use esmf_lonlat2phys_mod, only: esmf_lonlat2phys_destroy use esmf_lonlat_grid_mod, only: esmf_lonlat_grid_destroy use esmf_phys_mesh_mod, only: esmf_phys_mesh_destroy call esmf_phys2lonlat_destroy() + call esmf_lonlat2phys_destroy() call esmf_lonlat_grid_destroy() call esmf_phys_mesh_destroy() From 27fa6a5a75dc5842b5466ef9c790e67c750c9cf5 Mon Sep 17 00:00:00 2001 From: tommelt Date: Thu, 4 Sep 2025 07:01:53 -0600 Subject: [PATCH 10/23] wip: gather 3d fields working now --- src/utils/esmf_phys2lonlat_mod.F90 | 2 +- src/utils/remap.F90 | 117 ++++++++++++++++++++++------- 2 files changed, 91 insertions(+), 28 deletions(-) diff --git a/src/utils/esmf_phys2lonlat_mod.F90 b/src/utils/esmf_phys2lonlat_mod.F90 index a191cf31e5..1f340e83e3 100644 --- a/src/utils/esmf_phys2lonlat_mod.F90 +++ b/src/utils/esmf_phys2lonlat_mod.F90 @@ -48,7 +48,7 @@ module esmf_phys2lonlat_mod real(r8), pointer :: fld(:,:,:) => null() end type fields_bundle_t - integer, parameter :: nflds = 5 + integer, parameter :: nflds = 4 contains diff --git a/src/utils/remap.F90 b/src/utils/remap.F90 index 7c2ae82c8c..09162df59b 100644 --- a/src/utils/remap.F90 +++ b/src/utils/remap.F90 @@ -125,8 +125,7 @@ subroutine nlgw_regrid(phys_state) real(r8), target :: v_phys(pver,pcols,begchunk:endchunk) real(r8), target :: w_phys(pver,pcols,begchunk:endchunk) real(r8), target :: t_phys(pver,pcols,begchunk:endchunk) - real(r8), target :: pmid_phys(pver,pcols,begchunk:endchunk) - real(r8) :: ps_phys(pcols,begchunk:endchunk) + real(r8) :: phis_phys(pcols,begchunk:endchunk) ! for debugging only ! real(r8) :: lat_phys(pcols,begchunk:endchunk) ! 
real(r8) :: lon_phys(pcols,begchunk:endchunk) @@ -136,15 +135,14 @@ subroutine nlgw_regrid(phys_state) real(r8), target :: v_lonlat(beglon:endlon,beglat:endlat,pver) real(r8), target :: w_lonlat(beglon:endlon,beglat:endlat,pver) real(r8), target :: t_lonlat(beglon:endlon,beglat:endlat,pver) - real(r8), target :: pmid_lonlat(beglon:endlon,beglat:endlat,pver) - real(r8) :: ps_lonlat(beglon:endlon,beglat:endlat) + real(r8) :: phis_lonlat(beglon:endlon,beglat:endlat) - real(r8), allocatable :: ps_flat(:) - real(r8), allocatable :: ps_grid(:, :) + real(r8), allocatable :: flat_array(:) + real(r8), dimension(:, :), allocatable :: phis_grid + real(r8), dimension(:,:,:), allocatable :: u_grid, v_grid, w_grid, t_grid integer :: lchnk, ncol, i, sendcnt, disp_sum integer :: lonsize, latsize - integer :: tompver, tompcols integer, allocatable :: recvcnts(:), displs(:) integer, allocatable :: beglats(:), beglons(:) @@ -165,9 +163,8 @@ subroutine nlgw_regrid(phys_state) v_phys(:,i,lchnk) = phys_state(lchnk)%v(i,:) w_phys(:,i,lchnk) = phys_state(lchnk)%omega(i,:) t_phys(:,i,lchnk) = phys_state(lchnk)%t(i,:) - pmid_phys(:,i,lchnk) = phys_state(lchnk)%pmid(i,:) - ps_phys(i,lchnk) = phys_state(lchnk)%ps(i) + phis_phys(i,lchnk) = phys_state(lchnk)%ps(i) ! for debugging only ! lat_phys(i,lchnk) = phys_state(lchnk)%lat(i) ! lon_phys(i,lchnk) = phys_state(lchnk)%lon(i) @@ -184,17 +181,15 @@ subroutine nlgw_regrid(phys_state) physflds(2)%fld => v_phys physflds(3)%fld => w_phys physflds(4)%fld => t_phys - physflds(5)%fld => pmid_phys lonlatflds(1)%fld => u_lonlat lonlatflds(2)%fld => v_lonlat lonlatflds(3)%fld => w_lonlat lonlatflds(4)%fld => t_lonlat - lonlatflds(5)%fld => pmid_lonlat ! actual call to regrid to lon/lat grid call esmf_phys2lonlat_regrid(physflds, lonlatflds) - call esmf_phys2lonlat_regrid(ps_phys, ps_lonlat) + call esmf_phys2lonlat_regrid(phis_phys, phis_lonlat) ! TODO ! convert t to theta before gathering @@ -211,8 +206,6 @@ subroutine nlgw_regrid(phys_state) allocate(beglons(npes)) allocate(endlats(npes)) allocate(endlons(npes)) - allocate(ps_flat(nlon * nlat)) - allocate(ps_grid(nlon, nlat)) sendcnt = (endlon - beglon + 1) * (endlat - beglat + 1) @@ -231,23 +224,30 @@ subroutine nlgw_regrid(phys_state) end do end if - ! gather variables onto master proc into a flat array (can't do 2D/3D mpigather) - call mpigatherv(ps_lonlat(beglon:endlon, beglat:endlat), sendcnt, mpir8, ps_flat, recvcnts, displs, mpir8, 0, mpicom) + allocate(flat_array(nlon * nlat)) + + call gather_2d(phis_lonlat(beglon:endlon, beglat:endlat), sendcnt, flat_array, recvcnts, displs, & + phis_grid, beglons, endlons, beglats, endlats) - tompver = pver - tompcols = pcols - print *, tompver - print *, tompcols + sendcnt = sendcnt * pver if (masterproc) then do i = 1, npes - lonsize = endlons(i) - beglons(i) + 1 - latsize = endlats(i) - beglats(i) + 1 - ! 
reshape each ranks flattended data and populate each block into a single lonlat grid - ps_grid(beglons(i):endlons(i), beglats(i):endlats(i)) = & - reshape(ps_flat(displs(i)+1:displs(i)+sendcnt), (/ lonsize, latsize /)) + displs(i) = displs(i) * pver + recvcnts(i) = recvcnts(i) * pver end do end if + deallocate(flat_array) + allocate(flat_array(nlon * nlat * pver)) + + call gather_3d(u_lonlat(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, recvcnts, displs, & + u_grid, beglons, endlons, beglats, endlats) + call gather_3d(v_lonlat(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, recvcnts, displs, & + v_grid, beglons, endlons, beglats, endlats) + call gather_3d(w_lonlat(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, recvcnts, displs, & + w_grid, beglons, endlons, beglats, endlats) + call gather_3d(t_lonlat(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, recvcnts, displs, & + t_grid, beglons, endlons, beglats, endlats) call t_stopf('nlgw_mpigather') @@ -257,14 +257,77 @@ subroutine nlgw_regrid(phys_state) if (masterproc) then ! TODO ALL deallocates here - deallocate(ps_flat) - deallocate(ps_grid) + deallocate(u_grid) + deallocate(v_grid) + deallocate(w_grid) + deallocate(t_grid) + deallocate(phis_grid) end if call t_stopf('nlgw_gather') end subroutine nlgw_regrid + !----------------------------------------------------------------------------- + !----------------------------------------------------------------------------- + subroutine gather_2d(local_array, sendcnt, flat_array, recvcnts, displs, grid_out, beglons, endlons, beglats, endlats) + use mpishorthand + use esmf_lonlat_grid_mod, only: nlat, nlon + real(r8), intent(in) :: local_array(:,:) ! Local 2D array section + integer, intent(in) :: sendcnt + real(r8), intent(inout) :: flat_array(:) ! Flattened array for gathering + integer, intent(in) :: recvcnts(:), displs(:) + real(r8), allocatable, intent(out) :: grid_out(:,:) ! Full gathered grid + integer, intent(in) :: beglons(:), endlons(:) + integer, intent(in) :: beglats(:), endlats(:) + + integer :: i, lonsize, latsize + + ! gather variables onto master proc into a flat array (can't do 2D/3D mpigather) + call mpigatherv(local_array, sendcnt, mpir8, flat_array, recvcnts, displs, mpir8, 0, mpicom) + + if (masterproc) then + allocate(grid_out(nlon, nlat)) + do i = 1, npes + lonsize = endlons(i) - beglons(i) + 1 + latsize = endlats(i) - beglats(i) + 1 + ! reshape each ranks flattended data and populate each block into a single lonlat grid + grid_out(beglons(i):endlons(i), beglats(i):endlats(i)) = & + reshape(flat_array(displs(i)+1:displs(i)+sendcnt), (/ lonsize, latsize /)) + end do + end if + end subroutine gather_2d + + !----------------------------------------------------------------------------- + !----------------------------------------------------------------------------- + subroutine gather_3d(local_array, sendcnt, flat_array, recvcnts, displs, grid_out, beglons, endlons, beglats, endlats) + use mpishorthand + use esmf_lonlat_grid_mod, only: nlat, nlon + real(r8), intent(in) :: local_array(:,:,:) ! Local 2D array section + integer, intent(in) :: sendcnt + real(r8), intent(inout) :: flat_array(:) ! Flattened array for gathering + integer, intent(in) :: recvcnts(:), displs(:) + real(r8), allocatable, intent(out) :: grid_out(:,:,:) ! Full gathered grid + integer, intent(in) :: beglons(:), endlons(:) + integer, intent(in) :: beglats(:), endlats(:) + + integer :: i, lonsize, latsize + + ! 
gather variables onto master proc into a flat array (can't do 2D/3D mpigather) + call mpigatherv(local_array, sendcnt, mpir8, flat_array, recvcnts, displs, mpir8, 0, mpicom) + + if (masterproc) then + allocate(grid_out(nlon, nlat, pver)) + do i = 1, npes + lonsize = endlons(i) - beglons(i) + 1 + latsize = endlats(i) - beglats(i) + 1 + ! reshape each ranks flattended data and populate each block into a single lonlat grid + grid_out(beglons(i):endlons(i), beglats(i):endlats(i), 1:pver) = & + reshape(flat_array(displs(i)+1:displs(i)+sendcnt), (/ lonsize, latsize, pver /)) + end do + end if + end subroutine gather_3d + !----------------------------------------------------------------------------- !----------------------------------------------------------------------------- subroutine nlgw_regrid_final() From 25f9b50f3589154f1ca83c643483e58c95db514d Mon Sep 17 00:00:00 2001 From: tommelt Date: Wed, 10 Sep 2025 05:13:41 -0600 Subject: [PATCH 11/23] feat: make grid arrays available to module --- src/utils/remap.F90 | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/src/utils/remap.F90 b/src/utils/remap.F90 index 09162df59b..0c9cf94b4a 100644 --- a/src/utils/remap.F90 +++ b/src/utils/remap.F90 @@ -24,6 +24,11 @@ module nlgw_remap_mod public :: nlgw_regrid public :: nlgw_regrid_final + ! these arrays contain the regridded variables of interest for the NN + real(r8), dimension(:, :), allocatable, public :: phis_grid + real(r8), dimension(:,:,:), allocatable, public :: u_grid, v_grid, w_grid, t_grid + real(r8), dimension(:,:,:), allocatable, public :: utgw_grid, vtgw_grid + contains !----------------------------------------------------------------------------- @@ -138,8 +143,6 @@ subroutine nlgw_regrid(phys_state) real(r8) :: phis_lonlat(beglon:endlon,beglat:endlat) real(r8), allocatable :: flat_array(:) - real(r8), dimension(:, :), allocatable :: phis_grid - real(r8), dimension(:,:,:), allocatable :: u_grid, v_grid, w_grid, t_grid integer :: lchnk, ncol, i, sendcnt, disp_sum integer :: lonsize, latsize @@ -255,15 +258,6 @@ subroutine nlgw_regrid(phys_state) ! convert fluxes to tendencies after regridding back to cubed sphere ! that way we dont need pmid - if (masterproc) then - ! TODO ALL deallocates here - deallocate(u_grid) - deallocate(v_grid) - deallocate(w_grid) - deallocate(t_grid) - deallocate(phis_grid) - end if - call t_stopf('nlgw_gather') end subroutine nlgw_regrid @@ -339,6 +333,18 @@ subroutine nlgw_regrid_final() call esmf_lonlat_grid_destroy() call esmf_phys_mesh_destroy() + if (masterproc) then + ! TODO ALL deallocates here + deallocate(phis_grid) + deallocate(u_grid) + deallocate(v_grid) + deallocate(w_grid) + deallocate(t_grid) + deallocate(utgw_grid) + deallocate(vtgw_grid) + end if + + end subroutine nlgw_regrid_final end module nlgw_remap_mod From a8cb8d1e706287017244e5663241b48f5a3db9a5 Mon Sep 17 00:00:00 2001 From: tommelt Date: Wed, 10 Sep 2025 05:14:16 -0600 Subject: [PATCH 12/23] Revert "wip: try to fix corner issue in regrid" This reverts commit 28e0530fd9bbc462445ae24b766293b5f4f89a3b. 
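
For reference, a minimal standalone sketch of the gather pattern used by the
gather_2d/gather_3d helpers above: each rank hands its contiguous (lon,lat)
block to MPI_Gatherv and the root reshapes every flattened block back into its
window of the global grid. The one-latitude-band-per-rank decomposition and
the array sizes are illustrative assumptions, not the CAM decomposition.

```fortran
program gather_blocks_sketch
  use mpi
  implicit none
  integer, parameter :: nlon_loc = 4, nlat_loc = 2
  integer :: ierr, rank, npes, i, sendcnt
  integer, allocatable :: recvcnts(:), displs(:)
  real(8) :: local_block(nlon_loc, nlat_loc)
  real(8), allocatable :: flat(:), grid(:,:)

  call MPI_Init(ierr)
  call MPI_Comm_rank(MPI_COMM_WORLD, rank, ierr)
  call MPI_Comm_size(MPI_COMM_WORLD, npes, ierr)

  ! each rank owns one latitude band of the toy global grid
  local_block = real(rank, 8)
  sendcnt = nlon_loc * nlat_loc

  ! uniform blocks here, so counts/offsets can be computed directly
  allocate(recvcnts(npes), displs(npes))
  recvcnts = sendcnt
  displs   = [(i*sendcnt, i = 0, npes-1)]
  allocate(flat(npes*sendcnt))

  ! gather all flattened blocks onto rank 0
  call MPI_Gatherv(local_block, sendcnt, MPI_DOUBLE_PRECISION, &
                   flat, recvcnts, displs, MPI_DOUBLE_PRECISION, &
                   0, MPI_COMM_WORLD, ierr)

  if (rank == 0) then
     allocate(grid(nlon_loc, nlat_loc*npes))
     do i = 1, npes
        ! reshape rank i-1's flattened block into its latitude band
        grid(:, (i-1)*nlat_loc+1 : i*nlat_loc) = &
             reshape(flat(displs(i)+1 : displs(i)+sendcnt), [nlon_loc, nlat_loc])
     end do
     print *, 'gathered grid shape:', shape(grid)
  end if

  call MPI_Finalize(ierr)
end program gather_blocks_sketch
```

Where the blocks are not uniform (as in CAM), recvcnts and displs must first be
gathered from each rank, which is what the mpi gather book-keeping above does.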
--- src/utils/esmf_phys2lonlat_mod.F90 | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/utils/esmf_phys2lonlat_mod.F90 b/src/utils/esmf_phys2lonlat_mod.F90 index 1f340e83e3..8f6ef3fdad 100644 --- a/src/utils/esmf_phys2lonlat_mod.F90 +++ b/src/utils/esmf_phys2lonlat_mod.F90 @@ -13,7 +13,7 @@ module esmf_phys2lonlat_mod use ESMF, only: ESMF_FieldCreate, ESMF_FieldRegridStore use ESMF, only: ESMF_FieldGet, ESMF_FieldRegrid use ESMF, only: ESMF_KIND_I4, ESMF_KIND_R8, ESMF_TYPEKIND_R8 - use ESMF, only: ESMF_REGRIDMETHOD_BILINEAR, ESMF_POLEMETHOD_NONE, ESMF_EXTRAPMETHOD_NEAREST_IDAVG + use ESMF, only: ESMF_REGRIDMETHOD_BILINEAR, ESMF_POLEMETHOD_ALLAVG, ESMF_EXTRAPMETHOD_NEAREST_IDAVG use ESMF, only: ESMF_TERMORDER_SRCSEQ, ESMF_MESHLOC_ELEMENT, ESMF_STAGGERLOC_CENTER use ESMF, only: ESMF_FieldDestroy, ESMF_RouteHandleDestroy use esmf_check_error_mod, only: check_esmf_error @@ -33,11 +33,9 @@ module esmf_phys2lonlat_mod type(ESMF_Field) :: physfld_3d type(ESMF_Field) :: lonlatfld_3d - type(ESMF_Field) :: lonlatfld_3d_copy type(ESMF_Field) :: physfld_2d type(ESMF_Field) :: lonlatfld_2d - type(ESMF_Field) :: lonlatfld_2d_copy interface esmf_phys2lonlat_regrid module procedure esmf_phys2lonlat_regrid_2d @@ -86,7 +84,6 @@ subroutine esmf_phys2lonlat_init() lonlatfld_3d = ESMF_FieldCreate( lonlat_grid, arrayspec, staggerloc=ESMF_STAGGERLOC_CENTER, & ungriddedLBound=(/1,1/), ungriddedUBound=(/pver,nflds/), rc=rc) call check_esmf_error(rc, subname//'ESMF_FieldCreate 3D lonlat fld ERROR') - lonlatfld_3d_copy = lonlatfld_3d ! 2D phys fld call ESMF_ArraySpecSet(arrayspec, 1, ESMF_TYPEKIND_R8, rc=rc) @@ -102,20 +99,19 @@ subroutine esmf_phys2lonlat_init() lonlatfld_2d = ESMF_FieldCreate( lonlat_grid, arrayspec, staggerloc=ESMF_STAGGERLOC_CENTER, rc=rc) call check_esmf_error(rc, subname//'ESMF_FieldCreate 2D lonlat fld ERROR') - lonlatfld_2d_copy = lonlatfld_2d - call ESMF_FieldRegridStore(srcField=physfld_3d, dstField=lonlatfld_3d_copy, & + call ESMF_FieldRegridStore(srcField=physfld_3d, dstField=lonlatfld_3d, & regridMethod=ESMF_REGRIDMETHOD_BILINEAR, & - polemethod=ESMF_POLEMETHOD_NONE, & + polemethod=ESMF_POLEMETHOD_ALLAVG, & extrapMethod=ESMF_EXTRAPMETHOD_NEAREST_IDAVG, & routeHandle=rh_phys2lonlat_3d, factorIndexList=factorIndexList, & factorList=factorList, srcTermProcessing=smm_srctermproc, & pipelineDepth=smm_pipelinedep, rc=rc) call check_esmf_error(rc, subname//'ESMF_FieldRegridStore 3D routehandle ERROR') - call ESMF_FieldRegridStore(srcField=physfld_2d, dstField=lonlatfld_2d_copy, & + call ESMF_FieldRegridStore(srcField=physfld_2d, dstField=lonlatfld_2d, & regridMethod=ESMF_REGRIDMETHOD_BILINEAR, & - polemethod=ESMF_POLEMETHOD_NONE, & + polemethod=ESMF_POLEMETHOD_ALLAVG, & extrapMethod=ESMF_EXTRAPMETHOD_NEAREST_IDAVG, & routeHandle=rh_phys2lonlat_2d, factorIndexList=factorIndexList, & factorList=factorList, srcTermProcessing=smm_srctermproc, & From b39fddeeac474e7916f6b177c6d6ccedeb6f958b Mon Sep 17 00:00:00 2001 From: tommelt Date: Wed, 10 Sep 2025 05:27:14 -0600 Subject: [PATCH 13/23] feat: change polemethod back to ESMF_POLEMETHOD_ALLAVG --- src/utils/esmf_lonlat2phys_mod.F90 | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils/esmf_lonlat2phys_mod.F90 b/src/utils/esmf_lonlat2phys_mod.F90 index 63a388d10a..e66ed27635 100644 --- a/src/utils/esmf_lonlat2phys_mod.F90 +++ b/src/utils/esmf_lonlat2phys_mod.F90 @@ -13,7 +13,7 @@ module esmf_lonlat2phys_mod use ESMF, only: ESMF_FieldCreate, ESMF_FieldRegridStore use ESMF, only: 
ESMF_FieldGet, ESMF_FieldRegrid use ESMF, only: ESMF_KIND_I4, ESMF_KIND_R8, ESMF_TYPEKIND_R8 - use ESMF, only: ESMF_REGRIDMETHOD_BILINEAR, ESMF_POLEMETHOD_NONE, ESMF_EXTRAPMETHOD_NEAREST_IDAVG + use ESMF, only: ESMF_REGRIDMETHOD_BILINEAR, ESMF_POLEMETHOD_ALLAVG, ESMF_EXTRAPMETHOD_NEAREST_IDAVG use ESMF, only: ESMF_TERMORDER_SRCSEQ, ESMF_MESHLOC_ELEMENT, ESMF_STAGGERLOC_CENTER use ESMF, only: ESMF_FieldDestroy, ESMF_RouteHandleDestroy use esmf_check_error_mod, only: check_esmf_error @@ -82,7 +82,7 @@ subroutine esmf_lonlat2phys_init() call ESMF_FieldRegridStore(srcField=lonlatfld_3d, dstField=physfld_3d, & regridMethod=ESMF_REGRIDMETHOD_BILINEAR, & - polemethod=ESMF_POLEMETHOD_NONE, & + polemethod=ESMF_POLEMETHOD_ALLAVG, & extrapMethod=ESMF_EXTRAPMETHOD_NEAREST_IDAVG, & routeHandle=rh_lonlat2phys_3d, factorIndexList=factorIndexList, & factorList=factorList, srcTermProcessing=smm_srctermproc, & From 3f1f8df6099811f95ed287675aeb38aef58b7d32 Mon Sep 17 00:00:00 2001 From: tommelt Date: Thu, 11 Sep 2025 09:41:36 -0600 Subject: [PATCH 14/23] feat: regridding works both ways! --- src/physics/cam_dev/physpkg.F90 | 5 +- src/utils/remap.F90 | 225 ++++++++++++++++++++++++-------- 2 files changed, 173 insertions(+), 57 deletions(-) diff --git a/src/physics/cam_dev/physpkg.F90 b/src/physics/cam_dev/physpkg.F90 index 516b39046f..e0689029e1 100644 --- a/src/physics/cam_dev/physpkg.F90 +++ b/src/physics/cam_dev/physpkg.F90 @@ -1183,7 +1183,7 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, & use metdata, only: get_met_srf2 #endif use hemco_interface, only: HCOI_Chunk_Run - use nlgw_remap_mod, only: nlgw_regrid_init, nlgw_regrid, nlgw_regrid_final + use nlgw_remap_mod, only: nlgw_regrid_init, nlgw_latlon_gather, nlgw_latlon_scatter, nlgw_regrid_final ! ! Input arguments ! @@ -1246,7 +1246,8 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, & call t_adj_detailf(+1) call nlgw_regrid_init() - call nlgw_regrid(phys_state) + call nlgw_latlon_gather(phys_state) + call nlgw_latlon_scatter() stop !$OMP PARALLEL DO PRIVATE (C, NCOL, phys_buffer_chunk) diff --git a/src/utils/remap.F90 b/src/utils/remap.F90 index c9b0e03f4c..edfabcc099 100644 --- a/src/utils/remap.F90 +++ b/src/utils/remap.F90 @@ -21,7 +21,8 @@ module nlgw_remap_mod private public :: nlgw_regrid_init - public :: nlgw_regrid + public :: nlgw_latlon_gather + public :: nlgw_latlon_scatter public :: nlgw_regrid_final ! these arrays contain the regridded variables of interest for the NN @@ -29,9 +30,21 @@ module nlgw_remap_mod real(r8), dimension(:,:,:), allocatable, public :: u_grid, v_grid, w_grid, t_grid real(r8), dimension(:,:,:), allocatable, public :: utgw_grid, vtgw_grid + ! arrays on physics grid for use in gw_drag routine + real(r8), dimension(:,:,:), allocatable, target :: utgw_phys + real(r8), dimension(:,:,:), allocatable, target :: vtgw_phys + + + ! private (for book-keeping in MPI calls) + integer, allocatable :: recvcnts(:), displs(:) + integer, allocatable :: beglats(:), beglons(:) + integer, allocatable :: endlats(:), endlons(:) + + contains !----------------------------------------------------------------------------- + ! Initialize arrays and grids for regridding/MPI calls !----------------------------------------------------------------------------- subroutine nlgw_regrid_init() use cam_grid_support, only: horiz_coord_t, horiz_coord_create, iMap, cam_grid_register @@ -53,7 +66,7 @@ subroutine nlgw_regrid_init() character(len=*), parameter :: subname = 'ctem_diags_reg: ' ! 
initialize grids and mapping - call esmf_lonlat_grid_init(64, 128) + call esmf_lonlat_grid_init(192, 288) call esmf_phys_mesh_init() call esmf_phys2lonlat_init() call esmf_lonlat2phys_init() @@ -110,20 +123,116 @@ subroutine nlgw_regrid_init() nullify(grid_map) + allocate(recvcnts(npes)) + allocate(displs(npes)) + allocate(beglats(npes)) + allocate(beglons(npes)) + allocate(endlats(npes)) + allocate(endlons(npes)) + + ! gathered grids only exist on masterproc + if (masterproc) then + allocate(u_grid(nlon, nlat, pver)) + allocate(v_grid(nlon, nlat, pver)) + allocate(w_grid(nlon, nlat, pver)) + allocate(t_grid(nlon, nlat, pver)) + allocate(phis_grid(nlon, nlat)) + allocate(utgw_grid(nlon, nlat, pver)) + allocate(vtgw_grid(nlon, nlat, pver)) + end if + + allocate(utgw_phys(pver,pcols,begchunk:endchunk)) + allocate(vtgw_phys(pver,pcols,begchunk:endchunk)) + end subroutine nlgw_regrid_init + + !----------------------------------------------------------------------------- + ! This routine takes variables on the regular lat/lon grid: + ! * uses MPI_Scatter to broadcast them from masterproc back to all ranks + ! * interpolates them back onto the cubed-sphere grid + ! * finally re-chunks the data so it can be used elsewhere + !----------------------------------------------------------------------------- + subroutine nlgw_latlon_scatter() + use esmf_lonlat_grid_mod, only: nlat, nlon + use esmf_lonlat2phys_mod, only: fields_bundle_t, n_flx_flds, esmf_lonlat2phys_regrid + use mpishorthand + + ! arrays on latlon grid + real(r8), target :: utgw_lonlat(beglon:endlon,beglat:endlat,pver) + real(r8), target :: vtgw_lonlat(beglon:endlon,beglat:endlat,pver) + + real(r8), allocatable :: flat_array(:) + + integer :: lchnk, ncol, i, sendcnt, disp_sum + + type(fields_bundle_t) :: phys_flx_flds(n_flx_flds) + type(fields_bundle_t) :: lonlat_flx_flds(n_flx_flds) + + call t_startf('nlgw_scatter') + + call t_startf('nlgw_mpiscatter') + ! this subsection gathers all variables onto a single process + + sendcnt = (endlon - beglon + 1) * (endlat - beglat + 1) * pver + + ! mpi gather book-keeping + call mpigather(sendcnt, 1, mpiint, recvcnts, 1, mpiint, 0, mpicom) + call mpigather(beglat, 1, mpiint, beglats, 1, mpiint, 0, mpicom) + call mpigather(beglon, 1, mpiint, beglons, 1, mpiint, 0, mpicom) + call mpigather(endlat, 1, mpiint, endlats, 1, mpiint, 0, mpicom) + call mpigather(endlon, 1, mpiint, endlons, 1, mpiint, 0, mpicom) + + if (masterproc) then + disp_sum = 0 + do i = 1, npes + displs(i) = disp_sum + disp_sum = disp_sum + recvcnts(i) + end do + end if + allocate(flat_array(nlon * nlat * pver)) + + ! unlike in gather case all ranks needs the displs and recvcnts + call mpibcast(displs, npes, mpiint, 0, mpicom) + call mpibcast(recvcnts, npes, mpiint, 0, mpicom) + + utgw_lonlat = 0._r8 + call scatter_3d(u_grid, sendcnt, flat_array, utgw_lonlat(beglon:endlon, beglat:endlat, 1:pver)) + call scatter_3d(vtgw_grid, sendcnt, flat_array, vtgw_lonlat(beglon:endlon, beglat:endlat, 1:pver)) + + deallocate(flat_array) + + call t_stopf('nlgw_mpiscatter') + + call t_startf('nlgw_latlon_gather') + ! this subsection does regridding + + phys_flx_flds(1)%fld => utgw_phys + phys_flx_flds(2)%fld => vtgw_phys + + lonlat_flx_flds(1)%fld => utgw_lonlat + lonlat_flx_flds(2)%fld => vtgw_lonlat + + ! 
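+    ! The bundles carry only pointers: esmf_lonlat2phys_regrid loops over the
+    ! n_flx_flds entries, packing them through a single ESMF field, so the
+    ! ordering of phys_flx_flds and lonlat_flx_flds must correspond.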
actual call to regrid to lon/lat grid + call esmf_lonlat2phys_regrid(lonlat_flx_flds, phys_flx_flds) + + call t_stopf('nlgw_latlon_gather') + + + call t_stopf('nlgw_scatter') + + end subroutine nlgw_latlon_scatter + + !----------------------------------------------------------------------------- + ! This routine takes variables on the irregular cubed-sphere grid: + ! * gathers all the chunks into a single data structure on each rank + ! * interpolates cubed-sphere variable to a regular lonlat grid + ! * uses MPI_Gather to collect regridded data from all ranks to the masterproc !----------------------------------------------------------------------------- - subroutine nlgw_regrid(phys_state) - use air_composition, only: mbarv ! g/mole - use shr_const_mod, only: rgas => shr_const_rgas ! J/K/kmole - use shr_const_mod, only: grav => shr_const_g ! m/s2 + subroutine nlgw_latlon_gather(phys_state) use esmf_lonlat_grid_mod, only: nlat, nlon - use esmf_phys2lonlat_mod, only: esmf_phys2lonlat_regrid - use esmf_zonal_mean_mod, only: esmf_zonal_mean_calc, esmf_zonal_mean_wsums, esmf_zonal_mean_masked - use interpolate_data, only: lininterp - use esmf_phys2lonlat_mod, only: fields_bundle_t, nflds - use esmf_lonlat2phys_mod, only: esmf_lonlat2phys_regrid, n_flx_flds + use esmf_phys2lonlat_mod, only: fields_bundle_t, nflds, esmf_phys2lonlat_regrid use mpishorthand type(physics_state), intent(in) :: phys_state(begchunk:endchunk) @@ -148,18 +257,10 @@ subroutine nlgw_regrid(phys_state) real(r8), allocatable :: flat_array(:) integer :: lchnk, ncol, i, sendcnt, disp_sum - integer :: lonsize, latsize - - integer, allocatable :: recvcnts(:), displs(:) - integer, allocatable :: beglats(:), beglons(:) - integer, allocatable :: endlats(:), endlons(:) type(fields_bundle_t) :: physflds(nflds) type(fields_bundle_t) :: lonlatflds(nflds) - type(fields_bundle_t) :: phys_flx_flds(n_flx_flds) - type(fields_bundle_t) :: lonlat_flx_flds(n_flx_flds) - call t_startf('nlgw_gather') call t_startf('nlgw_unchunk') @@ -183,7 +284,7 @@ subroutine nlgw_regrid(phys_state) call t_stopf('nlgw_unchunk') - call t_startf('nlgw_regrid') + call t_startf('nlgw_latlon_gather') ! this subsection does regridding physflds(1)%fld => u_phys @@ -204,18 +305,11 @@ subroutine nlgw_regrid(phys_state) ! convert t to theta before gathering ! we dont need ps we need phis - call t_stopf('nlgw_regrid') + call t_stopf('nlgw_latlon_gather') call t_startf('nlgw_mpigather') ! this subsection gathers all variables onto a single process - allocate(recvcnts(npes)) - allocate(displs(npes)) - allocate(beglats(npes)) - allocate(beglons(npes)) - allocate(endlats(npes)) - allocate(endlons(npes)) - sendcnt = (endlon - beglon + 1) * (endlat - beglat + 1) ! 
mpi gather book-keeping @@ -235,9 +329,7 @@ subroutine nlgw_regrid(phys_state) allocate(flat_array(nlon * nlat)) - call gather_2d(phis_lonlat(beglon:endlon, beglat:endlat), sendcnt, flat_array, recvcnts, displs, & - phis_grid, beglons, endlons, beglats, endlats) - + call gather_2d(phis_lonlat(beglon:endlon, beglat:endlat), sendcnt, flat_array, phis_grid) sendcnt = sendcnt * pver if (masterproc) then @@ -249,14 +341,12 @@ subroutine nlgw_regrid(phys_state) deallocate(flat_array) allocate(flat_array(nlon * nlat * pver)) - call gather_3d(u_lonlat(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, recvcnts, displs, & - u_grid, beglons, endlons, beglats, endlats) - call gather_3d(v_lonlat(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, recvcnts, displs, & - v_grid, beglons, endlons, beglats, endlats) - call gather_3d(w_lonlat(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, recvcnts, displs, & - w_grid, beglons, endlons, beglats, endlats) - call gather_3d(t_lonlat(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, recvcnts, displs, & - t_grid, beglons, endlons, beglats, endlats) + call gather_3d(u_lonlat(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, u_grid) + call gather_3d(v_lonlat(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, v_grid) + call gather_3d(w_lonlat(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, w_grid) + call gather_3d(t_lonlat(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, t_grid) + + deallocate(flat_array) call t_stopf('nlgw_mpigather') @@ -266,20 +356,17 @@ subroutine nlgw_regrid(phys_state) call t_stopf('nlgw_gather') - end subroutine nlgw_regrid + end subroutine nlgw_latlon_gather !----------------------------------------------------------------------------- + ! Utility function for gathering 2D data into a single array !----------------------------------------------------------------------------- - subroutine gather_2d(local_array, sendcnt, flat_array, recvcnts, displs, grid_out, beglons, endlons, beglats, endlats) + subroutine gather_2d(local_array, sendcnt, flat_array, grid_out) use mpishorthand - use esmf_lonlat_grid_mod, only: nlat, nlon real(r8), intent(in) :: local_array(:,:) ! Local 2D array section integer, intent(in) :: sendcnt real(r8), intent(inout) :: flat_array(:) ! Flattened array for gathering - integer, intent(in) :: recvcnts(:), displs(:) - real(r8), allocatable, intent(out) :: grid_out(:,:) ! Full gathered grid - integer, intent(in) :: beglons(:), endlons(:) - integer, intent(in) :: beglats(:), endlats(:) + real(r8), allocatable, intent(inout) :: grid_out(:,:) ! Full gathered grid integer :: i, lonsize, latsize @@ -287,7 +374,6 @@ subroutine gather_2d(local_array, sendcnt, flat_array, recvcnts, displs, grid_ou call mpigatherv(local_array, sendcnt, mpir8, flat_array, recvcnts, displs, mpir8, 0, mpicom) if (masterproc) then - allocate(grid_out(nlon, nlat)) do i = 1, npes lonsize = endlons(i) - beglons(i) + 1 latsize = endlats(i) - beglats(i) + 1 @@ -299,17 +385,14 @@ subroutine gather_2d(local_array, sendcnt, flat_array, recvcnts, displs, grid_ou end subroutine gather_2d !----------------------------------------------------------------------------- + ! 
Utility function for gathering 3D data into a single array !----------------------------------------------------------------------------- - subroutine gather_3d(local_array, sendcnt, flat_array, recvcnts, displs, grid_out, beglons, endlons, beglats, endlats) + subroutine gather_3d(local_array, sendcnt, flat_array, grid_out) use mpishorthand - use esmf_lonlat_grid_mod, only: nlat, nlon real(r8), intent(in) :: local_array(:,:,:) ! Local 2D array section integer, intent(in) :: sendcnt real(r8), intent(inout) :: flat_array(:) ! Flattened array for gathering - integer, intent(in) :: recvcnts(:), displs(:) - real(r8), allocatable, intent(out) :: grid_out(:,:,:) ! Full gathered grid - integer, intent(in) :: beglons(:), endlons(:) - integer, intent(in) :: beglats(:), endlats(:) + real(r8), allocatable, intent(inout) :: grid_out(:,:,:) ! Full gathered grid integer :: i, lonsize, latsize @@ -317,7 +400,6 @@ subroutine gather_3d(local_array, sendcnt, flat_array, recvcnts, displs, grid_ou call mpigatherv(local_array, sendcnt, mpir8, flat_array, recvcnts, displs, mpir8, 0, mpicom) if (masterproc) then - allocate(grid_out(nlon, nlat, pver)) do i = 1, npes lonsize = endlons(i) - beglons(i) + 1 latsize = endlats(i) - beglats(i) + 1 @@ -329,6 +411,33 @@ subroutine gather_3d(local_array, sendcnt, flat_array, recvcnts, displs, grid_ou end subroutine gather_3d !----------------------------------------------------------------------------- + ! Utility function for scattering 3D data into lonlat arrays + !----------------------------------------------------------------------------- + subroutine scatter_3d(grid_in, sendcnt, flat_array, lonlat_out) + use mpishorthand + real(r8), allocatable, intent(in) :: grid_in(:,:,:) ! Local 2D array section + integer, intent(in) :: sendcnt + real(r8), intent(inout) :: flat_array(:) ! temporary storage in flat array + real(r8), target, intent(inout) :: lonlat_out(:,:,:) ! Full scattered grid + + integer :: i, lonsize, latsize + + if (masterproc) then + do i = 1, npes + lonsize = endlons(i) - beglons(i) + 1 + latsize = endlats(i) - beglats(i) + 1 + flat_array(displs(i)+1:displs(i)+sendcnt) = & + reshape(grid_in(beglons(i):endlons(i), beglats(i):endlats(i), 1:pver), (/lonsize* latsize * pver/)) + end do + end if + + ! scatter variables onto master proc into a flat array (can't do 2D/3D mpiscatter) + call mpiscatterv(flat_array, recvcnts, displs, mpir8, lonlat_out, sendcnt, mpir8, 0, mpicom) + + end subroutine scatter_3d + + !----------------------------------------------------------------------------- + ! Tidy up (free allocated memory) !----------------------------------------------------------------------------- subroutine nlgw_regrid_final() use esmf_phys2lonlat_mod, only: esmf_phys2lonlat_destroy @@ -341,8 +450,8 @@ subroutine nlgw_regrid_final() call esmf_lonlat_grid_destroy() call esmf_phys_mesh_destroy() + ! TODO double check ALL deallocates here if (masterproc) then - ! 
TODO ALL deallocates here deallocate(phis_grid) deallocate(u_grid) deallocate(v_grid) @@ -352,6 +461,12 @@ subroutine nlgw_regrid_final() deallocate(vtgw_grid) end if + deallocate(recvcnts) + deallocate(displs) + deallocate(beglats) + deallocate(beglons) + deallocate(endlats) + deallocate(endlons) end subroutine nlgw_regrid_final From e460ad615f767a2e2337035abab079c91a56c09b Mon Sep 17 00:00:00 2001 From: tommelt Date: Fri, 12 Sep 2025 03:52:34 -0600 Subject: [PATCH 15/23] chore: restructure gw_nlgw ready for gw_nlgw_unet --- src/physics/cam/gw_drag.F90 | 8 +-- .../cam/{gw_nlgw.F90 => gw_nlgw_ann.F90} | 52 +++++-------------- src/physics/cam/gw_nlgw_utils.F90 | 43 +++++++++++++++ 3 files changed, 61 insertions(+), 42 deletions(-) rename src/physics/cam/{gw_nlgw.F90 => gw_nlgw_ann.F90} (92%) create mode 100644 src/physics/cam/gw_nlgw_utils.F90 diff --git a/src/physics/cam/gw_drag.F90 b/src/physics/cam/gw_drag.F90 index 9a3a651060..2edafbfaf6 100644 --- a/src/physics/cam/gw_drag.F90 +++ b/src/physics/cam/gw_drag.F90 @@ -44,7 +44,7 @@ module gw_drag use gw_front, only: CMSourceDesc use gw_ml, only: gw_drag_convect_dp_ml_init, gw_drag_convect_dp_ml_final, & gw_drag_convect_dp_ml - use gw_nlgw, only: gw_nlgw_dp_ml, gw_nlgw_dp_init, gw_nlgw_dp_finalize + use gw_nlgw_ann, only: gw_nlgw_ann_infer, gw_nlgw_ann_init, gw_nlgw_ann_finalize ! Typical module header implicit none @@ -578,7 +578,7 @@ subroutine gw_init() errMsg(__FILE__, __LINE__)) if ( use_gw_nlgw ) then - call gw_nlgw_dp_init(gw_nlgw_model_path) + call gw_nlgw_ann_init(gw_nlgw_model_path) end if if ( use_gw_oro ) then @@ -1299,7 +1299,7 @@ subroutine gw_final() call gw_drag_convect_dp_ml_final() endif if ( use_gw_nlgw ) then - call gw_nlgw_dp_finalize() + call gw_nlgw_ann_finalize() end if end subroutine gw_final @@ -1549,7 +1549,7 @@ subroutine gw_tend(state, pbuf, dt, ptend, cam_in, flx_heat) flx_heat = 0._r8 if ( use_gw_nlgw ) then - call gw_nlgw_dp_ml(state1,ptend) + call gw_nlgw_ann_infer(state1,ptend) end if if (use_gw_convect_dp) then diff --git a/src/physics/cam/gw_nlgw.F90 b/src/physics/cam/gw_nlgw_ann.F90 similarity index 92% rename from src/physics/cam/gw_nlgw.F90 rename to src/physics/cam/gw_nlgw_ann.F90 index 647b771e68..4d044ce8ed 100644 --- a/src/physics/cam/gw_nlgw.F90 +++ b/src/physics/cam/gw_nlgw_ann.F90 @@ -1,4 +1,4 @@ -module gw_nlgw +module gw_nlgw_ann ! ! This module predicts gravity wave forcings via PyTorch NNs trained to include non-local gravity wave effects @@ -17,7 +17,7 @@ module gw_nlgw implicit none -public :: gw_nlgw_dp_ml, gw_nlgw_dp_init, gw_nlgw_dp_finalize +public :: gw_nlgw_ann_infer, gw_nlgw_ann_init, gw_nlgw_ann_finalize private @@ -104,7 +104,9 @@ module gw_nlgw !========================================================================== -subroutine gw_nlgw_dp_ml(state_in, ptend) +subroutine gw_nlgw_ann_infer(state_in, ptend) + + use gw_nlgw_utils, only: flux_to_forcing ! inputs type(physics_state), intent(in) :: state_in @@ -170,8 +172,8 @@ subroutine gw_nlgw_dp_ml(state_in, ptend) call extract_output() call denormalise_data() - call flux_to_forcing(uflux, utgw) - call flux_to_forcing(vflux, vtgw) + call flux_to_forcing(uflux, utgw, pmid, ncol) + call flux_to_forcing(vflux, vtgw, pmid, ncol) ! 
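     ! utgw and vtgw now hold wind tendencies (m s-2), i.e. the vertical
     ! derivative of the predicted fluxes with respect to pressure; they are
     ! accumulated into ptend just below.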
update the tendencies ptend%u(:ncol,:pver) = ptend%u(:ncol,:pver) + utgw(:ncol,:pver) @@ -200,10 +202,10 @@ subroutine gw_nlgw_dp_ml(state_in, ptend) deallocate(net_inputs) deallocate(net_outputs) -end subroutine gw_nlgw_dp_ml +end subroutine gw_nlgw_ann_infer -subroutine gw_nlgw_dp_init(model_path) +subroutine gw_nlgw_ann_init(model_path) character(len=*), intent(in) :: model_path ! Filepath to PyTorch Torchscript net integer :: device_id @@ -219,17 +221,17 @@ subroutine gw_nlgw_dp_init(model_path) write(iulog,*)'nlgw model loaded from: ', model_path endif -end subroutine gw_nlgw_dp_init +end subroutine gw_nlgw_ann_init -subroutine gw_nlgw_dp_finalize() +subroutine gw_nlgw_ann_finalize() deallocate(net_inputs) deallocate(net_outputs) ! free model memory call torch_delete(nlgw_model) -end subroutine gw_nlgw_dp_finalize +end subroutine gw_nlgw_ann_finalize subroutine read_norms() @@ -259,6 +261,7 @@ subroutine read_norms() end subroutine read_norms subroutine normalise_data() + use gw_nlgw_utils, only: cbrt ! lat lon are in radians (convert to degrees first) lat = lat * 180. / pi @@ -363,31 +366,4 @@ subroutine denormalise_data() end subroutine denormalise_data -elemental function cbrt(a) result(root) - real(r8), intent(in) :: a - real(r8), parameter :: one_third = 1._r8/3._r8 - real(r8) :: root - root = sign(abs(a)**one_third, a) -end function cbrt - -subroutine flux_to_forcing(flux, forcing) - - real(r8), intent(in), dimension(:,:) :: flux - real(r8), intent(out), dimension(:,:) :: forcing ! forcing = -d(u'\omega')/d(p), units = m/s^2 - - integer :: level, col - - ! convert fluxes to tendencies - ! pressure profile must be in Pascals - - do col = 1, ncol - forcing(col,1) = -1*(flux(col,2) - flux(col,1))/(pmid(col,2) - pmid(col,1)) - do level = 2, pver-1 - forcing(col,level) = (flux(col,level+1) - flux(col,level-1)) / (pmid(col,level)*(log(pmid(col,level+1)) - log(pmid(col,level-1)))) - end do - forcing(col,pver) = -1*(flux(col,pver) - flux(col,pver-1)) / (pmid(col,pver) - pmid(col,pver-1)) - end do - -end subroutine flux_to_forcing - -end module gw_nlgw +end module gw_nlgw_ann diff --git a/src/physics/cam/gw_nlgw_utils.F90 b/src/physics/cam/gw_nlgw_utils.F90 new file mode 100644 index 0000000000..399fab29c5 --- /dev/null +++ b/src/physics/cam/gw_nlgw_utils.F90 @@ -0,0 +1,43 @@ +module gw_nlgw_utils + +use gw_utils, only: r8, r4 +use ppgrid, only: pver !vertical levels + +implicit none + +public :: cbrt, flux_to_forcing + +private + +contains + +elemental function cbrt(a) result(root) + real(r8), intent(in) :: a + real(r8), parameter :: one_third = 1._r8/3._r8 + real(r8) :: root + root = sign(abs(a)**one_third, a) +end function cbrt + +subroutine flux_to_forcing(flux, forcing, pmid, ncol) + + real(r8), intent(in), dimension(:,:) :: flux ! flux (Pa m/s^2) !TODO check with Aman + real(r8), intent(in), dimension(:,:) :: pmid ! midpoint pressure (Pa) + real(r8), intent(out), dimension(:,:) :: forcing ! forcing = -d(u'\omega')/d(p), units = m/s^2 + integer, intent(in) :: ncol + + integer :: level, col + + ! convert fluxes to tendencies + ! 
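+  ! Discretisation: the top and bottom levels use one-sided differences of
+  ! flux with respect to pmid; interior levels use a centred difference in
+  ! ln(pmid), approximating pmid(k+1)-pmid(k-1) by
+  ! pmid(k)*(log(pmid(k+1)) - log(pmid(k-1))).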
pressure profile must be in Pascals + + do col = 1, ncol + forcing(col,1) = -1*(flux(col,2) - flux(col,1))/(pmid(col,2) - pmid(col,1)) + do level = 2, pver-1 + forcing(col,level) = (flux(col,level+1) - flux(col,level-1)) / (pmid(col,level)*(log(pmid(col,level+1)) - log(pmid(col,level-1)))) + end do + forcing(col,pver) = -1*(flux(col,pver) - flux(col,pver-1)) / (pmid(col,pver) - pmid(col,pver-1)) + end do + +end subroutine flux_to_forcing + +end module gw_nlgw_utils From 298fdf9d8baf98e05c0bdcfe2912ea365c294032 Mon Sep 17 00:00:00 2001 From: tommelt Date: Fri, 12 Sep 2025 03:52:51 -0600 Subject: [PATCH 16/23] chore: tidy up remap.F90 --- src/utils/remap.F90 | 36 +++++++----------------------------- 1 file changed, 7 insertions(+), 29 deletions(-) diff --git a/src/utils/remap.F90 b/src/utils/remap.F90 index edfabcc099..3dbf6b34d5 100644 --- a/src/utils/remap.F90 +++ b/src/utils/remap.F90 @@ -26,7 +26,6 @@ module nlgw_remap_mod public :: nlgw_regrid_final ! these arrays contain the regridded variables of interest for the NN - real(r8), dimension(:, :), allocatable, public :: phis_grid real(r8), dimension(:,:,:), allocatable, public :: u_grid, v_grid, w_grid, t_grid real(r8), dimension(:,:,:), allocatable, public :: utgw_grid, vtgw_grid @@ -136,7 +135,6 @@ subroutine nlgw_regrid_init() allocate(v_grid(nlon, nlat, pver)) allocate(w_grid(nlon, nlat, pver)) allocate(t_grid(nlon, nlat, pver)) - allocate(phis_grid(nlon, nlat)) allocate(utgw_grid(nlon, nlat, pver)) allocate(vtgw_grid(nlon, nlat, pver)) end if @@ -242,7 +240,6 @@ subroutine nlgw_latlon_gather(phys_state) real(r8), target :: v_phys(pver,pcols,begchunk:endchunk) real(r8), target :: w_phys(pver,pcols,begchunk:endchunk) real(r8), target :: t_phys(pver,pcols,begchunk:endchunk) - real(r8) :: phis_phys(pcols,begchunk:endchunk) ! for debugging only ! real(r8) :: lat_phys(pcols,begchunk:endchunk) ! real(r8) :: lon_phys(pcols,begchunk:endchunk) @@ -252,7 +249,6 @@ subroutine nlgw_latlon_gather(phys_state) real(r8), target :: v_lonlat(beglon:endlon,beglat:endlat,pver) real(r8), target :: w_lonlat(beglon:endlon,beglat:endlat,pver) real(r8), target :: t_lonlat(beglon:endlon,beglat:endlat,pver) - real(r8) :: phis_lonlat(beglon:endlon,beglat:endlat) real(r8), allocatable :: flat_array(:) @@ -274,7 +270,6 @@ subroutine nlgw_latlon_gather(phys_state) w_phys(:,i,lchnk) = phys_state(lchnk)%omega(i,:) t_phys(:,i,lchnk) = phys_state(lchnk)%t(i,:) - phis_phys(i,lchnk) = phys_state(lchnk)%ps(i) ! for debugging only ! lat_phys(i,lchnk) = phys_state(lchnk)%lat(i) ! lon_phys(i,lchnk) = phys_state(lchnk)%lon(i) @@ -285,7 +280,6 @@ subroutine nlgw_latlon_gather(phys_state) call t_stopf('nlgw_unchunk') call t_startf('nlgw_latlon_gather') - ! this subsection does regridding physflds(1)%fld => u_phys physflds(2)%fld => v_phys @@ -299,18 +293,13 @@ subroutine nlgw_latlon_gather(phys_state) ! actual call to regrid to lon/lat grid call esmf_phys2lonlat_regrid(physflds, lonlatflds) - call esmf_phys2lonlat_regrid(phis_phys, phis_lonlat) - - ! TODO - ! convert t to theta before gathering - ! we dont need ps we need phis call t_stopf('nlgw_latlon_gather') call t_startf('nlgw_mpigather') ! this subsection gathers all variables onto a single process - sendcnt = (endlon - beglon + 1) * (endlat - beglat + 1) + sendcnt = (endlon - beglon + 1) * (endlat - beglat + 1) * pver ! 
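     ! sendcnt counts this rank's full (lon,lat,pver) block, so recvcnts and
     ! displs index straight into the flattened 3-D gather buffer.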
mpi gather book-keeping call mpigather(sendcnt, 1, mpiint, recvcnts, 1, mpiint, 0, mpicom) @@ -326,19 +315,6 @@ subroutine nlgw_latlon_gather(phys_state) disp_sum = disp_sum + recvcnts(i) end do end if - - allocate(flat_array(nlon * nlat)) - - call gather_2d(phis_lonlat(beglon:endlon, beglat:endlat), sendcnt, flat_array, phis_grid) - - sendcnt = sendcnt * pver - if (masterproc) then - do i = 1, npes - displs(i) = displs(i) * pver - recvcnts(i) = recvcnts(i) * pver - end do - end if - deallocate(flat_array) allocate(flat_array(nlon * nlat * pver)) call gather_3d(u_lonlat(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, u_grid) @@ -389,7 +365,7 @@ end subroutine gather_2d !----------------------------------------------------------------------------- subroutine gather_3d(local_array, sendcnt, flat_array, grid_out) use mpishorthand - real(r8), intent(in) :: local_array(:,:,:) ! Local 2D array section + real(r8), intent(in) :: local_array(:,:,:) ! Local 3D array section integer, intent(in) :: sendcnt real(r8), intent(inout) :: flat_array(:) ! Flattened array for gathering real(r8), allocatable, intent(inout) :: grid_out(:,:,:) ! Full gathered grid @@ -415,7 +391,7 @@ end subroutine gather_3d !----------------------------------------------------------------------------- subroutine scatter_3d(grid_in, sendcnt, flat_array, lonlat_out) use mpishorthand - real(r8), allocatable, intent(in) :: grid_in(:,:,:) ! Local 2D array section + real(r8), allocatable, intent(in) :: grid_in(:,:,:) ! Local 3D array section integer, intent(in) :: sendcnt real(r8), intent(inout) :: flat_array(:) ! temporary storage in flat array real(r8), target, intent(inout) :: lonlat_out(:,:,:) ! Full scattered grid @@ -431,7 +407,7 @@ subroutine scatter_3d(grid_in, sendcnt, flat_array, lonlat_out) end do end if - ! scatter variables onto master proc into a flat array (can't do 2D/3D mpiscatter) + ! scatter variables from flat_array back to all processes call mpiscatterv(flat_array, recvcnts, displs, mpir8, lonlat_out, sendcnt, mpir8, 0, mpicom) end subroutine scatter_3d @@ -452,7 +428,6 @@ subroutine nlgw_regrid_final() ! TODO double check ALL deallocates here if (masterproc) then - deallocate(phis_grid) deallocate(u_grid) deallocate(v_grid) deallocate(w_grid) @@ -468,6 +443,9 @@ subroutine nlgw_regrid_final() deallocate(endlats) deallocate(endlons) + deallocate(utgw_phys) + deallocate(vtgw_phys) + end subroutine nlgw_regrid_final end module nlgw_remap_mod From d593e054bc56644510a73adcca31a307b65187c4 Mon Sep 17 00:00:00 2001 From: tommelt Date: Wed, 29 Oct 2025 04:21:52 -0600 Subject: [PATCH 17/23] feat: (wip) add UNet model to CAM - Work in progress - UNet runs once successfully. Generating fluxes from the inputs. 
- Fluxes need to be double checked - Currently we do not compute tendencies - This will be done in a subsequent commit --- README.md | 21 ++- bld/build-namelist | 3 +- bld/namelist_files/namelist_definition.xml | 21 ++- src/physics/cam/gw_drag.F90 | 22 ++- src/physics/cam/gw_nlgw_ann.F90 | 3 +- src/physics/cam/gw_nlgw_unet.F90 | 195 ++++++++++++++++++++ src/physics/cam/gw_nlgw_utils.F90 | 45 ++++- src/physics/cam/phys_control.F90 | 13 +- src/physics/cam_dev/physpkg.F90 | 42 ++++- src/utils/remap.F90 | 201 ++++++++++++--------- 10 files changed, 450 insertions(+), 116 deletions(-) create mode 100644 src/physics/cam/gw_nlgw_unet.F90 diff --git a/README.md b/README.md index eb22f4e47b..e9a5f69082 100644 --- a/README.md +++ b/README.md @@ -120,15 +120,26 @@ gw_convect_dp_ml_norms='/path/to/norms' To run CAM using the non local gravity wave ML model to replace all parameterisations use the following configuration ```fortran -use_gw_nlgw=.true. -gw_nlgw_model_path='/path/to/nlgw-scripted-model.pt' +use_gw_nlgw_ann=.true. +use_gw_nlgw_unet=.true. +gw_nlgw_model_path_ann='/path/to/ann-scripted-model.pt' +gw_nlgw_model_path_unet='/path/to/unet-scripted-model.pt' ``` -* `use_gw_nlgw` (`logical`) +* `use_gw_nlgw_ann` (`logical`) - Whether or not to use the ML scheme for non local gravity waves. Default: `.false.` + Whether or not to use the ANN ML scheme for non local gravity waves. Default: `.false.` -* `gw_nlgw_model_path` +* `gw_nlgw_model_path_ann` + + Absolute filepath to the non local gravity wave neural net used when `use_gw_nlgw` is set to `.true.` (`.pt` + extension). + +* `use_gw_nlgw_unet` (`logical`) + + Whether or not to use the UNET ML scheme for non local gravity waves. Default: `.false.` + +* `gw_nlgw_model_path_unet` Absolute filepath to the non local gravity wave neural net used when `use_gw_nlgw` is set to `.true.` (`.pt` extension). diff --git a/bld/build-namelist b/bld/build-namelist index 0ffb144b48..9a8e1b88dc 100755 --- a/bld/build-namelist +++ b/bld/build-namelist @@ -3606,7 +3606,8 @@ if (!$simple_phys) { add_default($nl, 'use_gw_rdg_gamma' , 'val'=>'.false.'); add_default($nl, 'use_gw_front_igw' , 'val'=>'.false.'); add_default($nl, 'use_gw_convect_sh', 'val'=>'.false.'); - add_default($nl, 'use_gw_nlgw' , 'val'=>'.false.'); + add_default($nl, 'use_gw_nlgw_ann' , 'val'=>'.false.'); + add_default($nl, 'use_gw_nlgw_unet' , 'val'=>'.false.'); add_default($nl, 'gw_lndscl_sgh'); add_default($nl, 'gw_oro_south_fac'); add_default($nl, 'gw_limit_tau_without_eff'); diff --git a/bld/namelist_files/namelist_definition.xml b/bld/namelist_files/namelist_definition.xml index 52e8313b21..b783f836d7 100644 --- a/bld/namelist_files/namelist_definition.xml +++ b/bld/namelist_files/namelist_definition.xml @@ -1332,10 +1332,17 @@ Whether or not to enable gravity waves produced by shallow convection. Default: .false. - Whether or not to enable gravity waves produced by non-local gravity -wave ML model. +wave ANN ML model. +Default: set by build-namelist. + + + +Whether or not to enable gravity waves produced by non-local gravity +wave UNet ML model. Default: set by build-namelist. @@ -1428,10 +1435,16 @@ Absolute filepath to the deep convection gravity wave neural net used when Default: .false. - +Absolute filepath to the non local gravity wave traced model (.pt) +used when `use_gw_nlgw_ann` is set to `.true.`. + + + Absolute filepath to the non local gravity wave traced model (.pt) -used when `use_gw_nlgw` is set to `.true.`. +used when `use_gw_nlgw_unet` is set to `.true.`. 
masterprocid, masterproc, mpi_real8, iam +use cam_abortutils, only: endrun +use cam_logfile, only: iulog +use physconst, only: cappa +use gw_nlgw_utils, only: lonlat_vars, nlon, nlat + +use ftorch + +implicit none + +public :: gw_nlgw_unet_init, gw_nlgw_unet_infer, gw_nlgw_unet_finalize + +private + +type(torch_model) :: nlgw_model ! pytorch model + +real(r8), dimension(:,:,:), allocatable :: & + uflux, &! zonal wind flux (Pa) + vflux, &! meridional wind flux (Pa) + utgw, &! zonal wind tendency (m/s^2) + vtgw ! meridional wind tendency (m/s^2) + +real(r8), dimension(:,:,:), allocatable :: & + uflux_grid, &! zonal wind flux (Pa) + vflux_grid ! meridional wind flux (Pa) + +real(r4), dimension(:,:,:,:), allocatable, target :: net_inputs +real(r4), dimension(:,:,:,:), allocatable, target :: net_outputs + +! normalisation means and std devs +real(r8) :: u_mean, v_mean, w_mean, theta_mean +real(r8) :: u_std, v_std, w_std, theta_std + +real(r8) :: uflux_mean, vflux_mean +real(r8) :: uflux_std, vflux_std + +contains + +!========================================================================== + + +subroutine gw_nlgw_unet_init(model_path) + + character(len=*), intent(in) :: model_path ! Filepath to PyTorch Torchscript net + integer :: device_id + + device_id = 0 + + ! Load the convective drag net from TorchScript file + call torch_model_load(nlgw_model, model_path, device_type=torch_kCUDA, device_index=device_id) + ! read in normalisation weights + call read_norms() + + if (masterproc) then + write(iulog,*)'nlgw model loaded from: ', model_path + + ! UNet will only run on the master process + ! space for u, v, theta and w + allocate(net_inputs(1, pver*4, nlat, nlon)) + ! space for uflux and vflux + allocate(net_outputs(1, pver*2, nlat, nlon)) + endif + +end subroutine gw_nlgw_unet_init + +subroutine gw_nlgw_unet_infer(gathered_lonlat) + + ! global unet data + type(lonlat_vars), intent(inout) :: gathered_lonlat + + !---------------------------Local storage------------------------------- + type(torch_tensor) :: tensor_in(1), tensor_out(1) + integer :: ninputs = 1, noutputs = 1 + integer, dimension(4) :: layout = [1 , 2, 3, 4] + + integer :: device_id + + device_id = 0 + + ! Normalise and construct the input + call normalise_data(gathered_lonlat) + call construct_input(gathered_lonlat) + + ! send all columns from this process + call torch_tensor_from_array(tensor_in(1), net_inputs, layout, torch_kCUDA, device_id) + call torch_tensor_from_array(tensor_out(1), net_outputs, layout, torch_kCPU) + + ! Run net forward on data + call torch_model_forward(nlgw_model, tensor_in, tensor_out) + + ! Extract and denormalise outputs + call extract_output(gathered_lonlat) + call denormalise_data(gathered_lonlat) + + ! Clean up the tensors + call torch_delete(tensor_in) + call torch_delete(tensor_out) + +end subroutine gw_nlgw_unet_infer + +subroutine gw_nlgw_unet_finalize() + + if (masterproc) then + deallocate(net_inputs) + deallocate(net_outputs) + end if + ! free model memory + call torch_delete(nlgw_model) + +end subroutine gw_nlgw_unet_finalize + +subroutine read_norms() + + ! TODO + ! - replace hardcoded means/std devs with netcdf file? 
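+  ! Illustrative sketch only (not wired in): the standard netcdf-fortran API
+  ! could replace the hard-coded values below; the file name and variable
+  ! names here are assumptions.
+  !   use netcdf
+  !   integer :: ncid, varid, ierr
+  !   ierr = nf90_open('nlgw_unet_norms.nc', NF90_NOWRITE, ncid)
+  !   ierr = nf90_inq_varid(ncid, 'u_mean', varid)
+  !   ierr = nf90_get_var(ncid, varid, u_mean)
+  !   ierr = nf90_close(ncid)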
+ + u_mean = 6.717847278462159_r8 + v_mean = -0.002744777264668839_r8 + theta_mean = 0._r8 + w_mean = 0.0013401482063147452_r8 + + u_std = 20.760385183200206_r8 + v_std = 9.877389116738264_r8 + theta_std = 1000._r8 + w_std = 0.11202126259282257_r8 + + uflux_mean = -0.0004691528666736032_r8 + vflux_mean = -0.0002586195082961397_r8 + uflux_std = 0.032814051953840274_r8 + vflux_std = 0.03024781201672967_r8 + +end subroutine read_norms + +subroutine normalise_data(gathered_lonlat) + use gw_nlgw_utils, only: cbrt + type(lonlat_vars), intent(inout) :: gathered_lonlat + + gathered_lonlat%u = (gathered_lonlat%u-u_mean)/(3._r8 * u_std) + gathered_lonlat%v = (gathered_lonlat%v-v_mean)/(3._r8 * v_std) + gathered_lonlat%theta = (gathered_lonlat%theta-theta_mean)/theta_std + gathered_lonlat%w = (gathered_lonlat%w-w_mean)/w_std + gathered_lonlat%w = cbrt(gathered_lonlat%w) + +end subroutine normalise_data + +subroutine construct_input(gathered_lonlat) + + type(lonlat_vars), intent(inout) :: gathered_lonlat + integer :: idx_beg, idx_end, i + + idx_end = 0 + + idx_beg = idx_end + 1 + idx_end = idx_end + pver + net_inputs(1,idx_beg:idx_end,:,:) = reshape(gathered_lonlat%u, shape=[pver, nlat, nlon], order=[3,2,1]) + idx_beg = idx_end + 1 + idx_end = idx_end + pver + net_inputs(1,idx_beg:idx_end,:,:) = reshape(gathered_lonlat%v, shape=[pver, nlat, nlon], order=[3,2,1]) + idx_beg = idx_end + 1 + idx_end = idx_end + pver + net_inputs(1,idx_beg:idx_end,:,:) = reshape(gathered_lonlat%theta, shape=[pver, nlat, nlon], order=[3,2,1]) + idx_beg = idx_end + 1 + idx_end = idx_end + pver + net_inputs(1,idx_beg:idx_end,:,:) = reshape(gathered_lonlat%w, shape=[pver, nlat, nlon], order=[3,2,1]) + +end subroutine construct_input + +subroutine extract_output(gathered_lonlat) + + type(lonlat_vars), intent(inout) :: gathered_lonlat + + gathered_lonlat%uflux(:,:,:) = reshape(net_outputs(1,:pver,:,:), shape=[nlon, nlat, pver], order=[3,2,1]) + gathered_lonlat%vflux(:,:,:) = reshape(net_outputs(1,pver+1:,:,:), shape=[nlon, nlat, pver], order=[3,2,1]) + +end subroutine extract_output + +subroutine denormalise_data(gathered_lonlat) + + type(lonlat_vars), intent(inout) :: gathered_lonlat + + gathered_lonlat%uflux = gathered_lonlat%uflux**3._r8 * uflux_std + uflux_mean + gathered_lonlat%vflux = gathered_lonlat%vflux**3._r8 * vflux_std + vflux_mean + +end subroutine denormalise_data + +end module gw_nlgw_unet diff --git a/src/physics/cam/gw_nlgw_utils.F90 b/src/physics/cam/gw_nlgw_utils.F90 index 399fab29c5..d0cd5e7a10 100644 --- a/src/physics/cam/gw_nlgw_utils.F90 +++ b/src/physics/cam/gw_nlgw_utils.F90 @@ -1,14 +1,57 @@ module gw_nlgw_utils use gw_utils, only: r8, r4 -use ppgrid, only: pver !vertical levels +use ppgrid, only: begchunk, endchunk, pcols, pver, pverp implicit none public :: cbrt, flux_to_forcing +public :: phys_vars, lonlat_vars +integer, parameter, public :: p0 = 100000 ! 1000 hPa (Pa) +integer, parameter, public :: nlon = 288 ! number of longitude points on lonlat grid +integer, parameter, public :: nlat = 192 ! number of latitude points on lonlat grid private +! variables on cubed-sphere "phys" grid +type phys_vars +!dimension(pver,pcols,begchunk:endchunk) +real(r8), dimension(:,:,:), allocatable :: & + u, &! zonal wind (m/s) + v, &! meridional wind (m/s) + theta, &! temperature (K) + w, &! vertical pressure velocity (Pa/s) + pmid ! midpoint pressure (Pa) + +real(r8), dimension(:,:,:), allocatable :: & + uflux, &! zonal fluxes + vflux ! meridional fluxes + +real(r8), dimension(:,:,:), allocatable :: & + utgw, &! 
zonal tendencies + vtgw ! meridional tendencies + +! for debugging only +! dimension(pcols,begchunk:endchunk) +real(r8), dimension(:,:), allocatable :: & + lat, & + lon +end type + +! variables on regular lonlat grid +type lonlat_vars +! dimension(lon,lat,pver) +real(r8), dimension(:,:,:), allocatable :: & + u, &! zonal wind (m/s) + v, &! meridional wind (m/s) + theta, &! temperature (K) + w ! vertical pressure velocity (Pa/s) + +real(r8), dimension(:,:,:), allocatable :: & + uflux, &! zonal fluxes + vflux ! meridional fluxes +end type + contains elemental function cbrt(a) result(root) diff --git a/src/physics/cam/phys_control.F90 b/src/physics/cam/phys_control.F90 index 7497385583..93263e29a3 100644 --- a/src/physics/cam/phys_control.F90 +++ b/src/physics/cam/phys_control.F90 @@ -98,7 +98,8 @@ module phys_control logical, public, protected :: use_gw_front_igw = .false. ! Frontogenesis to inertial spectrum. logical, public, protected :: use_gw_convect_dp = .false. ! Deep convection. logical, public, protected :: use_gw_convect_sh = .false. ! Shallow convection. -logical, public, protected :: use_gw_nlgw = .false. ! non local GW ML model +logical, public, protected :: use_gw_nlgw_ann = .false. ! non local GW ML model (ANN - single column) +logical, public, protected :: use_gw_nlgw_unet = .false. ! non local GW ML model (UNet - global non local) ! FV dycore angular momentum correction logical, public, protected :: fv_am_correction = .false. @@ -137,7 +138,7 @@ subroutine phys_ctl_readnl(nlfile) history_waccmx, history_chemistry, history_carma, history_clubb, history_dust, & history_cesm_forcing, history_scwaccm_forcing, history_chemspecies_srf, & do_clubb_sgs, state_debug_checks, use_hetfrz_classnuc, use_gw_oro, use_gw_front, & - use_gw_front_igw, use_gw_convect_dp, use_gw_convect_sh, use_gw_nlgw, cld_macmic_num_steps, & + use_gw_front_igw, use_gw_convect_dp, use_gw_convect_sh, use_gw_nlgw_ann, use_gw_nlgw_unet, cld_macmic_num_steps, & offline_driver, convproc_do_aer, cam_snapshot_before_num, cam_snapshot_after_num, & cam_take_snapshot_before, cam_take_snapshot_after, cam_physics_mesh, use_hemco, do_hb_above_clubb !----------------------------------------------------------------------------- @@ -157,7 +158,10 @@ subroutine phys_ctl_readnl(nlfile) end if ! if we are using the nlgw ML model we need to disable all other parameterizations - if (masterproc .and. use_gw_nlgw==.true.) then + if (masterproc .and. (use_gw_nlgw_ann .or. use_gw_nlgw_unet)) then + if (use_gw_nlgw_ann .and. use_gw_nlgw_unet) then + call endrun(subname // ':: ERROR you can only select UNet or ANN not both.') + end if use_gw_oro = .false. use_gw_front = .false. use_gw_front_igw = .false. 
@@ -203,7 +207,8 @@ subroutine phys_ctl_readnl(nlfile) call mpi_bcast(use_gw_front_igw, 1, mpi_logical, masterprocid, mpicom, ierr) call mpi_bcast(use_gw_convect_dp, 1, mpi_logical, masterprocid, mpicom, ierr) call mpi_bcast(use_gw_convect_sh, 1, mpi_logical, masterprocid, mpicom, ierr) - call mpi_bcast(use_gw_nlgw, 1, mpi_logical, masterprocid, mpicom, ierr) + call mpi_bcast(use_gw_nlgw_ann, 1, mpi_logical, masterprocid, mpicom, ierr) + call mpi_bcast(use_gw_nlgw_unet, 1, mpi_logical, masterprocid, mpicom, ierr) call mpi_bcast(cld_macmic_num_steps, 1, mpi_integer, masterprocid, mpicom, ierr) call mpi_bcast(offline_driver, 1, mpi_logical, masterprocid, mpicom, ierr) call mpi_bcast(convproc_do_aer, 1, mpi_logical, masterprocid, mpicom, ierr) diff --git a/src/physics/cam_dev/physpkg.F90 b/src/physics/cam_dev/physpkg.F90 index e0689029e1..8a2d174d76 100644 --- a/src/physics/cam_dev/physpkg.F90 +++ b/src/physics/cam_dev/physpkg.F90 @@ -1183,7 +1183,9 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, & use metdata, only: get_met_srf2 #endif use hemco_interface, only: HCOI_Chunk_Run - use nlgw_remap_mod, only: nlgw_regrid_init, nlgw_latlon_gather, nlgw_latlon_scatter, nlgw_regrid_final + use nlgw_remap_mod, only: nlgw_regrid_init, nlgw_latlon_gather, nlgw_latlon_scatter, nlgw_regrid_final + use gw_nlgw_unet, only: gw_nlgw_unet_init, gw_nlgw_unet_infer, gw_nlgw_unet_finalize + use gw_nlgw_utils, only: phys_vars, lonlat_vars, flux_to_forcing ! ! Input arguments ! @@ -1204,6 +1206,10 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, & integer :: c ! chunk index integer :: ncol ! number of columns type(physics_buffer_desc),pointer, dimension(:) :: phys_buffer_chunk + ! for ML UNet model + type(phys_vars), target :: phys + type(lonlat_vars), target :: lonlat, gathered_lonlat + real(r8), dimension(pcols,pver) :: temp_uflux, temp_vflux, temp_utgw, temp_vtgw ! ! If exit condition just return ! @@ -1245,9 +1251,37 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, & call t_startf ('ac_physics') call t_adj_detailf(+1) - call nlgw_regrid_init() - call nlgw_latlon_gather(phys_state) - call nlgw_latlon_scatter() + call nlgw_regrid_init(phys, lonlat, gathered_lonlat) + + ! gather data from all procs, all chunks into a global lonlat grid + call nlgw_latlon_gather(phys_state, phys, lonlat, gathered_lonlat) + + if (masterproc) then + call gw_nlgw_unet_init('/glade/u/home/tmeltzer/nonlocal_gwfluxes/era5_training/nlgw_unet_gpu_scripted.pt') + ! run UNet model on globally gathered lonlat grid to compute fluxes + call gw_nlgw_unet_infer(gathered_lonlat) + call gw_nlgw_unet_finalize() + endif + + ! scatter back to all procs, into chunks and regrid back to phys grid + call nlgw_latlon_scatter(phys, lonlat, gathered_lonlat) + + do c = begchunk,endchunk + ncol = phys_state(c)%ncol + temp_uflux(:ncol,:pver) = transpose(phys%uflux(:pver,:ncol,c)) + temp_vflux(:ncol,:pver) = transpose(phys%vflux(:pver,:ncol,c)) + + ! update tendencies + call flux_to_forcing(temp_uflux, temp_utgw, phys_state(c)%pmid, ncol) + call flux_to_forcing(temp_vflux, temp_vtgw, phys_state(c)%pmid, ncol) + ! ptend%u(:ncol,:pver) = ptend%u(:ncol,:pver) + utgw(:ncol,:pver) + ! ptend%v(:ncol,:pver) = ptend%v(:ncol,:pver) + vtgw(:ncol,:pver) + ! call update_enegry(ptend) + end do + + ! TODO check energy conservation after tendency update + ! 
TODO look at Will Chapman's code + call nlgw_regrid_final(phys, lonlat, gathered_lonlat) stop !$OMP PARALLEL DO PRIVATE (C, NCOL, phys_buffer_chunk) diff --git a/src/utils/remap.F90 b/src/utils/remap.F90 index 3dbf6b34d5..c502fb2b76 100644 --- a/src/utils/remap.F90 +++ b/src/utils/remap.F90 @@ -15,6 +15,7 @@ module nlgw_remap_mod use perf_mod, only: t_startf, t_stopf use cam_logfile, only: iulog use cam_abortutils, only: endrun + use gw_nlgw_utils, only: phys_vars, lonlat_vars implicit none @@ -25,33 +26,27 @@ module nlgw_remap_mod public :: nlgw_latlon_scatter public :: nlgw_regrid_final - ! these arrays contain the regridded variables of interest for the NN - real(r8), dimension(:,:,:), allocatable, public :: u_grid, v_grid, w_grid, t_grid - real(r8), dimension(:,:,:), allocatable, public :: utgw_grid, vtgw_grid - - ! arrays on physics grid for use in gw_drag routine - real(r8), dimension(:,:,:), allocatable, target :: utgw_phys - real(r8), dimension(:,:,:), allocatable, target :: vtgw_phys - - ! private (for book-keeping in MPI calls) integer, allocatable :: recvcnts(:), displs(:) integer, allocatable :: beglats(:), beglons(:) integer, allocatable :: endlats(:), endlons(:) + logical, parameter, public :: debug = .true. + contains !----------------------------------------------------------------------------- ! Initialize arrays and grids for regridding/MPI calls !----------------------------------------------------------------------------- - subroutine nlgw_regrid_init() - use cam_grid_support, only: horiz_coord_t, horiz_coord_create, iMap, cam_grid_register - use esmf_lonlat_grid_mod, only: glats, nlat, glons, nlon + subroutine nlgw_regrid_init(phys, lonlat, gathered_lonlat) + use cam_grid_support, only: horiz_coord_t, horiz_coord_create, iMap, cam_grid_register + use esmf_lonlat_grid_mod, only: glats, glons use esmf_lonlat_grid_mod, only: esmf_lonlat_grid_init - use esmf_phys_mesh_mod, only: esmf_phys_mesh_init + use esmf_phys_mesh_mod, only: esmf_phys_mesh_init use esmf_phys2lonlat_mod, only: esmf_phys2lonlat_init use esmf_lonlat2phys_mod, only: esmf_lonlat2phys_init + use gw_nlgw_utils, only: nlon, nlat integer, parameter :: reg_decomp = 332 @@ -62,10 +57,14 @@ subroutine nlgw_regrid_init() type(horiz_coord_t), pointer :: lat_coord integer :: i, j, ind, astat + type(phys_vars), intent(inout), target :: phys + type(lonlat_vars), intent(inout), target :: lonlat + type(lonlat_vars), intent(inout), target :: gathered_lonlat + character(len=*), parameter :: subname = 'ctem_diags_reg: ' ! 
initialize grids and mapping - call esmf_lonlat_grid_init(192, 288) + call esmf_lonlat_grid_init(nlat, nlon) call esmf_phys_mesh_init() call esmf_phys2lonlat_init() call esmf_lonlat2phys_init() @@ -129,19 +128,38 @@ subroutine nlgw_regrid_init() allocate(endlats(npes)) allocate(endlons(npes)) + allocate(phys%u(pver,pcols,begchunk:endchunk)) + allocate(phys%v(pver,pcols,begchunk:endchunk)) + allocate(phys%theta(pver,pcols,begchunk:endchunk)) + allocate(phys%w(pver,pcols,begchunk:endchunk)) + allocate(phys%uflux(pver,pcols,begchunk:endchunk)) + allocate(phys%vflux(pver,pcols,begchunk:endchunk)) + allocate(phys%utgw(pver,pcols,begchunk:endchunk)) + allocate(phys%vtgw(pver,pcols,begchunk:endchunk)) + + + allocate(lonlat%u(beglon:endlon,beglat:endlat,pver)) + allocate(lonlat%v(beglon:endlon,beglat:endlat,pver)) + allocate(lonlat%w(beglon:endlon,beglat:endlat,pver)) + allocate(lonlat%theta(beglon:endlon,beglat:endlat,pver)) + allocate(lonlat%uflux(beglon:endlon,beglat:endlat,pver)) + allocate(lonlat%vflux(beglon:endlon,beglat:endlat,pver)) + + if (debug) then + allocate(phys%pmid(pver,pcols,begchunk:endchunk)) + allocate(phys%lon(pcols,begchunk:endchunk)) + allocate(phys%lat(pcols,begchunk:endchunk)) + end if + ! gathered grids only exist on masterproc if (masterproc) then - allocate(u_grid(nlon, nlat, pver)) - allocate(v_grid(nlon, nlat, pver)) - allocate(w_grid(nlon, nlat, pver)) - allocate(t_grid(nlon, nlat, pver)) - allocate(utgw_grid(nlon, nlat, pver)) - allocate(vtgw_grid(nlon, nlat, pver)) + allocate(gathered_lonlat%u(nlon, nlat, pver)) + allocate(gathered_lonlat%v(nlon, nlat, pver)) + allocate(gathered_lonlat%w(nlon, nlat, pver)) + allocate(gathered_lonlat%theta(nlon, nlat, pver)) + allocate(gathered_lonlat%uflux(nlon, nlat, pver)) + allocate(gathered_lonlat%vflux(nlon, nlat, pver)) end if - - allocate(utgw_phys(pver,pcols,begchunk:endchunk)) - allocate(vtgw_phys(pver,pcols,begchunk:endchunk)) - end subroutine nlgw_regrid_init @@ -151,14 +169,14 @@ end subroutine nlgw_regrid_init ! * interpolates them back onto the cubed-sphere grid ! * finally re-chunks the data so it can be used elsewhere !----------------------------------------------------------------------------- - subroutine nlgw_latlon_scatter() - use esmf_lonlat_grid_mod, only: nlat, nlon + subroutine nlgw_latlon_scatter(phys, lonlat, gathered_lonlat) + use gw_nlgw_utils, only: nlon, nlat use esmf_lonlat2phys_mod, only: fields_bundle_t, n_flx_flds, esmf_lonlat2phys_regrid use mpishorthand - ! 
arrays on latlon grid - real(r8), target :: utgw_lonlat(beglon:endlon,beglat:endlat,pver) - real(r8), target :: vtgw_lonlat(beglon:endlon,beglat:endlat,pver) + type(phys_vars), intent(inout), target :: phys + type(lonlat_vars), intent(inout), target :: lonlat + type(lonlat_vars), intent(inout), target :: gathered_lonlat real(r8), allocatable :: flat_array(:) @@ -194,9 +212,8 @@ subroutine nlgw_latlon_scatter() call mpibcast(displs, npes, mpiint, 0, mpicom) call mpibcast(recvcnts, npes, mpiint, 0, mpicom) - utgw_lonlat = 0._r8 - call scatter_3d(u_grid, sendcnt, flat_array, utgw_lonlat(beglon:endlon, beglat:endlat, 1:pver)) - call scatter_3d(vtgw_grid, sendcnt, flat_array, vtgw_lonlat(beglon:endlon, beglat:endlat, 1:pver)) + call scatter_3d(gathered_lonlat%uflux, sendcnt, flat_array, lonlat%uflux(beglon:endlon, beglat:endlat, 1:pver)) + call scatter_3d(gathered_lonlat%vflux, sendcnt, flat_array, lonlat%vflux(beglon:endlon, beglat:endlat, 1:pver)) deallocate(flat_array) @@ -205,18 +222,17 @@ subroutine nlgw_latlon_scatter() call t_startf('nlgw_latlon_gather') ! this subsection does regridding - phys_flx_flds(1)%fld => utgw_phys - phys_flx_flds(2)%fld => vtgw_phys + phys_flx_flds(1)%fld => phys%uflux + phys_flx_flds(2)%fld => phys%vflux - lonlat_flx_flds(1)%fld => utgw_lonlat - lonlat_flx_flds(2)%fld => vtgw_lonlat + lonlat_flx_flds(1)%fld => lonlat%uflux + lonlat_flx_flds(2)%fld => lonlat%vflux ! actual call to regrid to lon/lat grid call esmf_lonlat2phys_regrid(lonlat_flx_flds, phys_flx_flds) call t_stopf('nlgw_latlon_gather') - call t_stopf('nlgw_scatter') end subroutine nlgw_latlon_scatter @@ -228,27 +244,18 @@ end subroutine nlgw_latlon_scatter ! * interpolates cubed-sphere variable to a regular lonlat grid ! * uses MPI_Gather to collect regridded data from all ranks to the masterproc !----------------------------------------------------------------------------- - subroutine nlgw_latlon_gather(phys_state) - use esmf_lonlat_grid_mod, only: nlat, nlon + subroutine nlgw_latlon_gather(phys_state, phys, lonlat, gathered_lonlat) + use gw_nlgw_utils, only: nlon, nlat use esmf_phys2lonlat_mod, only: fields_bundle_t, nflds, esmf_phys2lonlat_regrid + use gw_nlgw_utils, only: p0 + use physconst, only: cappa use mpishorthand type(physics_state), intent(in) :: phys_state(begchunk:endchunk) - ! arrays on physics grid - real(r8), target :: u_phys(pver,pcols,begchunk:endchunk) - real(r8), target :: v_phys(pver,pcols,begchunk:endchunk) - real(r8), target :: w_phys(pver,pcols,begchunk:endchunk) - real(r8), target :: t_phys(pver,pcols,begchunk:endchunk) - ! for debugging only - ! real(r8) :: lat_phys(pcols,begchunk:endchunk) - ! real(r8) :: lon_phys(pcols,begchunk:endchunk) - - ! arrays on latlon grid - real(r8), target :: u_lonlat(beglon:endlon,beglat:endlat,pver) - real(r8), target :: v_lonlat(beglon:endlon,beglat:endlat,pver) - real(r8), target :: w_lonlat(beglon:endlon,beglat:endlat,pver) - real(r8), target :: t_lonlat(beglon:endlon,beglat:endlat,pver) + type(phys_vars), intent(inout), target :: phys + type(lonlat_vars), intent(inout), target :: lonlat + type(lonlat_vars), intent(inout), target :: gathered_lonlat real(r8), allocatable :: flat_array(:) @@ -259,20 +266,24 @@ subroutine nlgw_latlon_gather(phys_state) call t_startf('nlgw_gather') + call t_startf('nlgw_unchunk') do lchnk = begchunk,endchunk ncol = phys_state(lchnk)%ncol do i = 1,ncol ! 
wind components - u_phys(:,i,lchnk) = phys_state(lchnk)%u(i,:) - v_phys(:,i,lchnk) = phys_state(lchnk)%v(i,:) - w_phys(:,i,lchnk) = phys_state(lchnk)%omega(i,:) - t_phys(:,i,lchnk) = phys_state(lchnk)%t(i,:) + phys%u(:,i,lchnk) = phys_state(lchnk)%u(i,:) + phys%v(:,i,lchnk) = phys_state(lchnk)%v(i,:) + phys%w(:,i,lchnk) = phys_state(lchnk)%omega(i,:) + phys%theta(:,i,lchnk) = phys_state(lchnk)%t(i,:) * (p0 / phys_state(lchnk)%pmid(i,:)) ** cappa ! for debugging only - ! lat_phys(i,lchnk) = phys_state(lchnk)%lat(i) - ! lon_phys(i,lchnk) = phys_state(lchnk)%lon(i) + if (debug) then + phys%pmid(:,i,lchnk) = phys_state(lchnk)%pmid(i,:) + phys%lat(i,lchnk) = phys_state(lchnk)%lat(i) + phys%lon(i,lchnk) = phys_state(lchnk)%lon(i) + end if end do end do @@ -281,15 +292,15 @@ subroutine nlgw_latlon_gather(phys_state) call t_startf('nlgw_latlon_gather') - physflds(1)%fld => u_phys - physflds(2)%fld => v_phys - physflds(3)%fld => w_phys - physflds(4)%fld => t_phys + physflds(1)%fld => phys%u + physflds(2)%fld => phys%v + physflds(3)%fld => phys%w + physflds(4)%fld => phys%theta - lonlatflds(1)%fld => u_lonlat - lonlatflds(2)%fld => v_lonlat - lonlatflds(3)%fld => w_lonlat - lonlatflds(4)%fld => t_lonlat + lonlatflds(1)%fld => lonlat%u + lonlatflds(2)%fld => lonlat%v + lonlatflds(3)%fld => lonlat%w + lonlatflds(4)%fld => lonlat%theta ! actual call to regrid to lon/lat grid call esmf_phys2lonlat_regrid(physflds, lonlatflds) @@ -317,19 +328,15 @@ subroutine nlgw_latlon_gather(phys_state) end if allocate(flat_array(nlon * nlat * pver)) - call gather_3d(u_lonlat(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, u_grid) - call gather_3d(v_lonlat(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, v_grid) - call gather_3d(w_lonlat(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, w_grid) - call gather_3d(t_lonlat(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, t_grid) + call gather_3d(lonlat%u(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, gathered_lonlat%u) + call gather_3d(lonlat%v(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, gathered_lonlat%v) + call gather_3d(lonlat%w(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, gathered_lonlat%w) + call gather_3d(lonlat%theta(beglon:endlon, beglat:endlat, 1:pver), sendcnt, flat_array, gathered_lonlat%theta) deallocate(flat_array) call t_stopf('nlgw_mpigather') - ! TODO - ! convert fluxes to tendencies after regridding back to cubed sphere - ! that way we dont need pmid - call t_stopf('nlgw_gather') end subroutine nlgw_latlon_gather @@ -415,12 +422,16 @@ end subroutine scatter_3d !----------------------------------------------------------------------------- ! Tidy up (free allocated memory) !----------------------------------------------------------------------------- - subroutine nlgw_regrid_final() + subroutine nlgw_regrid_final(phys, lonlat, gathered_lonlat) use esmf_phys2lonlat_mod, only: esmf_phys2lonlat_destroy use esmf_lonlat2phys_mod, only: esmf_lonlat2phys_destroy use esmf_lonlat_grid_mod, only: esmf_lonlat_grid_destroy use esmf_phys_mesh_mod, only: esmf_phys_mesh_destroy + type(phys_vars), intent(inout), target :: phys + type(lonlat_vars), intent(inout), target :: lonlat + type(lonlat_vars), intent(inout), target :: gathered_lonlat + call esmf_phys2lonlat_destroy() call esmf_lonlat2phys_destroy() call esmf_lonlat_grid_destroy() @@ -428,12 +439,34 @@ subroutine nlgw_regrid_final() ! 
TODO double check ALL deallocates here if (masterproc) then - deallocate(u_grid) - deallocate(v_grid) - deallocate(w_grid) - deallocate(t_grid) - deallocate(utgw_grid) - deallocate(vtgw_grid) + deallocate(gathered_lonlat%u) + deallocate(gathered_lonlat%v) + deallocate(gathered_lonlat%w) + deallocate(gathered_lonlat%theta) + deallocate(gathered_lonlat%uflux) + deallocate(gathered_lonlat%vflux) + end if + + deallocate(phys%u) + deallocate(phys%v) + deallocate(phys%theta) + deallocate(phys%w) + deallocate(phys%uflux) + deallocate(phys%vflux) + deallocate(phys%utgw) + deallocate(phys%vtgw) + + deallocate(lonlat%u) + deallocate(lonlat%v) + deallocate(lonlat%w) + deallocate(lonlat%theta) + deallocate(lonlat%uflux) + deallocate(lonlat%vflux) + + if (debug) then + deallocate(phys%pmid) + deallocate(phys%lon) + deallocate(phys%lat) end if deallocate(recvcnts) @@ -442,10 +475,6 @@ subroutine nlgw_regrid_final() deallocate(beglons) deallocate(endlats) deallocate(endlons) - - deallocate(utgw_phys) - deallocate(vtgw_phys) - end subroutine nlgw_regrid_final end module nlgw_remap_mod From 276baacd0ac61c1e9d3b28aad37484c85ac66095 Mon Sep 17 00:00:00 2001 From: tommelt Date: Mon, 3 Nov 2025 06:56:46 -0700 Subject: [PATCH 18/23] feat: compute tendencies and tidy up --- src/physics/cam/gw_drag.F90 | 8 +++- src/physics/cam/gw_nlgw_unet.F90 | 30 +++++++++----- src/physics/cam_dev/physpkg.F90 | 71 ++++++++++++++++++-------------- src/utils/cam_grid_support.F90 | 2 +- src/utils/remap.F90 | 7 +++- 5 files changed, 72 insertions(+), 46 deletions(-) diff --git a/src/physics/cam/gw_drag.F90 b/src/physics/cam/gw_drag.F90 index ba70ecfecf..fb3ff463c9 100644 --- a/src/physics/cam/gw_drag.F90 +++ b/src/physics/cam/gw_drag.F90 @@ -37,7 +37,8 @@ module gw_drag ! These are the actual switches for different gravity wave sources. use phys_control, only: use_gw_oro, use_gw_front, use_gw_front_igw, & use_gw_convect_dp, use_gw_convect_sh, & - use_simple_phys, use_gw_nlgw_ann + use_simple_phys, & + use_gw_nlgw_ann, use_gw_nlgw_unet use gw_common, only: GWBand use gw_convect, only: BeresSourceDesc @@ -45,6 +46,7 @@ module gw_drag use gw_ml, only: gw_drag_convect_dp_ml_init, gw_drag_convect_dp_ml_final, & gw_drag_convect_dp_ml use gw_nlgw_ann, only: gw_nlgw_ann_infer, gw_nlgw_ann_init, gw_nlgw_ann_finalize + use gw_nlgw_unet, only: gw_nlgw_unet_update_ptend ! Typical module header implicit none @@ -1556,6 +1558,10 @@ subroutine gw_tend(state, pbuf, dt, ptend, cam_in, flx_heat) call gw_nlgw_ann_infer(state1,ptend) end if + if ( use_gw_nlgw_unet ) then + call gw_nlgw_unet_update_ptend(ptend, lchnk, ncol) + end if + if (use_gw_convect_dp) then !------------------------------------------------------------------ ! Convective gravity waves (Beres scheme, deep). diff --git a/src/physics/cam/gw_nlgw_unet.F90 b/src/physics/cam/gw_nlgw_unet.F90 index 9c4f562404..db6de681aa 100644 --- a/src/physics/cam/gw_nlgw_unet.F90 +++ b/src/physics/cam/gw_nlgw_unet.F90 @@ -17,22 +17,16 @@ module gw_nlgw_unet implicit none -public :: gw_nlgw_unet_init, gw_nlgw_unet_infer, gw_nlgw_unet_finalize +public :: gw_nlgw_unet_init, gw_nlgw_unet_infer, gw_nlgw_unet_finalize, gw_nlgw_unet_update_ptend + +real(r8), dimension(:,:,:), allocatable, public :: & + utgw_allchunk, &! zonal wind tendency (m/s^2) + vtgw_allchunk ! meridional wind tendency (m/s^2) private type(torch_model) :: nlgw_model ! pytorch model -real(r8), dimension(:,:,:), allocatable :: & - uflux, &! zonal wind flux (Pa) - vflux, &! meridional wind flux (Pa) - utgw, &! 
zonal wind tendency (m/s^2) - vtgw ! meridional wind tendency (m/s^2) - -real(r8), dimension(:,:,:), allocatable :: & - uflux_grid, &! zonal wind flux (Pa) - vflux_grid ! meridional wind flux (Pa) - real(r4), dimension(:,:,:,:), allocatable, target :: net_inputs real(r4), dimension(:,:,:,:), allocatable, target :: net_outputs @@ -118,6 +112,20 @@ subroutine gw_nlgw_unet_finalize() end subroutine gw_nlgw_unet_finalize +subroutine gw_nlgw_unet_update_ptend(ptend, lchnk, ncol) + + use gw_nlgw_utils, only: flux_to_forcing + + ! inputs + type(physics_ptend), intent(inout) :: ptend + integer, intent(in) :: lchnk, ncol + + ! update the tendencies + ptend%u(:ncol,:pver) = ptend%u(:ncol,:pver) + utgw_allchunk(:ncol,:pver, lchnk) + ptend%v(:ncol,:pver) = ptend%v(:ncol,:pver) + vtgw_allchunk(:ncol,:pver, lchnk) + +end subroutine gw_nlgw_unet_update_ptend + subroutine read_norms() ! TODO diff --git a/src/physics/cam_dev/physpkg.F90 b/src/physics/cam_dev/physpkg.F90 index 8a2d174d76..3611ca8cca 100644 --- a/src/physics/cam_dev/physpkg.F90 +++ b/src/physics/cam_dev/physpkg.F90 @@ -869,7 +869,7 @@ subroutine phys_init( phys_state, phys_tend, pbuf2d, cam_in, cam_out ) call co2_init() end if - ! call gw_init() + call gw_init() call rayleigh_friction_init() @@ -1184,8 +1184,9 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, & #endif use hemco_interface, only: HCOI_Chunk_Run use nlgw_remap_mod, only: nlgw_regrid_init, nlgw_latlon_gather, nlgw_latlon_scatter, nlgw_regrid_final - use gw_nlgw_unet, only: gw_nlgw_unet_init, gw_nlgw_unet_infer, gw_nlgw_unet_finalize + use gw_nlgw_unet, only: gw_nlgw_unet_init, gw_nlgw_unet_infer, gw_nlgw_unet_finalize, utgw_allchunk, vtgw_allchunk use gw_nlgw_utils, only: phys_vars, lonlat_vars, flux_to_forcing + use phys_control, only: use_gw_nlgw_unet ! ! Input arguments ! @@ -1253,37 +1254,42 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, & call nlgw_regrid_init(phys, lonlat, gathered_lonlat) - ! gather data from all procs, all chunks into a global lonlat grid - call nlgw_latlon_gather(phys_state, phys, lonlat, gathered_lonlat) - - if (masterproc) then - call gw_nlgw_unet_init('/glade/u/home/tmeltzer/nonlocal_gwfluxes/era5_training/nlgw_unet_gpu_scripted.pt') - ! run UNet model on globally gathered lonlat grid to compute fluxes - call gw_nlgw_unet_infer(gathered_lonlat) - call gw_nlgw_unet_finalize() + if (use_gw_nlgw_unet) then + ! gather data from all procs, all chunks into a global lonlat grid + call nlgw_latlon_gather(phys_state, phys, lonlat, gathered_lonlat) + + if (masterproc) then + call gw_nlgw_unet_init('/glade/u/home/tmeltzer/nonlocal_gwfluxes/era5_training/nlgw_unet_gpu_scripted.pt') + ! run UNet model on globally gathered lonlat grid to compute fluxes + call gw_nlgw_unet_infer(gathered_lonlat) + call gw_nlgw_unet_finalize() + endif + + ! scatter back to all procs, into chunks and regrid back to phys grid + call nlgw_latlon_scatter(phys, lonlat, gathered_lonlat) + + allocate(utgw_allchunk(pcols, pver, begchunk:endchunk)) + allocate(vtgw_allchunk(pcols, pver, begchunk:endchunk)) + + do c = begchunk,endchunk + ncol = phys_state(c)%ncol + temp_uflux(:ncol,:pver) = transpose(phys%uflux(:pver,:ncol,c)) + temp_vflux(:ncol,:pver) = transpose(phys%vflux(:pver,:ncol,c)) + + ! compute tendencies from fluxes + call flux_to_forcing(temp_uflux, temp_utgw, phys_state(c)%pmid, ncol) + call flux_to_forcing(temp_vflux, temp_vtgw, phys_state(c)%pmid, ncol) + ! 
store tendencies in unet module so they can be updated in gw_tend + utgw_allchunk(:ncol,:pver,c) = temp_utgw(:ncol,:pver) + vtgw_allchunk(:ncol,:pver,c) = temp_vtgw(:ncol,:pver) + end do + + ! TODO check energy conservation after tendency update + ! TODO look at Will Chapman's code + ! TODO Ideally update here but for now we do it in tphysac + call nlgw_regrid_final(phys, lonlat, gathered_lonlat) endif - ! scatter back to all procs, into chunks and regrid back to phys grid - call nlgw_latlon_scatter(phys, lonlat, gathered_lonlat) - - do c = begchunk,endchunk - ncol = phys_state(c)%ncol - temp_uflux(:ncol,:pver) = transpose(phys%uflux(:pver,:ncol,c)) - temp_vflux(:ncol,:pver) = transpose(phys%vflux(:pver,:ncol,c)) - - ! update tendencies - call flux_to_forcing(temp_uflux, temp_utgw, phys_state(c)%pmid, ncol) - call flux_to_forcing(temp_vflux, temp_vtgw, phys_state(c)%pmid, ncol) - ! ptend%u(:ncol,:pver) = ptend%u(:ncol,:pver) + utgw(:ncol,:pver) - ! ptend%v(:ncol,:pver) = ptend%v(:ncol,:pver) + vtgw(:ncol,:pver) - ! call update_enegry(ptend) - end do - - ! TODO check energy conservation after tendency update - ! TODO look at Will Chapman's code - call nlgw_regrid_final(phys, lonlat, gathered_lonlat) - stop - !$OMP PARALLEL DO PRIVATE (C, NCOL, phys_buffer_chunk) do c=begchunk,endchunk @@ -1301,6 +1307,9 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, & phys_state(c), phys_tend(c), phys_buffer_chunk) end do ! Chunk loop + deallocate(utgw_allchunk) + deallocate(vtgw_allchunk) + call t_adj_detailf(-1) call t_stopf('ac_physics') diff --git a/src/utils/cam_grid_support.F90 b/src/utils/cam_grid_support.F90 index d86c829e77..56d5c100f7 100644 --- a/src/utils/cam_grid_support.F90 +++ b/src/utils/cam_grid_support.F90 @@ -316,6 +316,7 @@ end subroutine print_attr_spec public :: cam_grid_compute_patch ! Functions for dealing with grid areas public :: cam_grid_get_areawt + public :: get_cam_grid_index interface cam_grid_attribute_register module procedure add_cam_grid_attribute_0d_int @@ -352,7 +353,6 @@ end subroutine print_attr_spec module procedure cam_grid_write_dist_array_3d_real end interface - ! Private interfaces interface get_cam_grid_index module procedure get_cam_grid_index_char ! For lookup by name module procedure get_cam_grid_index_int ! For lookup by ID diff --git a/src/utils/remap.F90 b/src/utils/remap.F90 index c502fb2b76..cb432cfb05 100644 --- a/src/utils/remap.F90 +++ b/src/utils/remap.F90 @@ -40,7 +40,7 @@ module nlgw_remap_mod ! Initialize arrays and grids for regridding/MPI calls !----------------------------------------------------------------------------- subroutine nlgw_regrid_init(phys, lonlat, gathered_lonlat) - use cam_grid_support, only: horiz_coord_t, horiz_coord_create, iMap, cam_grid_register + use cam_grid_support, only: horiz_coord_t, horiz_coord_create, iMap, cam_grid_register, get_cam_grid_index use esmf_lonlat_grid_mod, only: glats, glons use esmf_lonlat_grid_mod, only: esmf_lonlat_grid_init use esmf_phys_mesh_mod, only: esmf_phys_mesh_init @@ -117,7 +117,10 @@ subroutine nlgw_regrid_init(phys, lonlat, gathered_lonlat) nullify(coord_map) - call cam_grid_register('ctem_lonlat', reg_decomp, lat_coord, lon_coord, grid_map, unstruct=.false.) + ! only create grid if it doesn't exist + if (get_cam_grid_index('ctem_lonlat') == -1) then + call cam_grid_register('ctem_lonlat', reg_decomp, lat_coord, lon_coord, grid_map, unstruct=.false.) 
+ end if nullify(grid_map) From 91430cbf7acfe269fe04e9bc243f6886d9e7bc05 Mon Sep 17 00:00:00 2001 From: tommelt Date: Mon, 3 Nov 2025 06:57:06 -0700 Subject: [PATCH 19/23] wip: output utgw and vtgw --- src/physics/cam/gw_nlgw_unet.F90 | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/physics/cam/gw_nlgw_unet.F90 b/src/physics/cam/gw_nlgw_unet.F90 index db6de681aa..20972338df 100644 --- a/src/physics/cam/gw_nlgw_unet.F90 +++ b/src/physics/cam/gw_nlgw_unet.F90 @@ -10,6 +10,7 @@ module gw_nlgw_unet use spmd_utils, only: mpicom, mstrid=>masterprocid, masterproc, mpi_real8, iam use cam_abortutils, only: endrun use cam_logfile, only: iulog +use cam_history, only: outfld, addfld use physconst, only: cappa use gw_nlgw_utils, only: lonlat_vars, nlon, nlat @@ -64,6 +65,10 @@ subroutine gw_nlgw_unet_init(model_path) allocate(net_outputs(1, pver*2, nlat, nlon)) endif + + call addfld('UTGW_NL', (/ 'lev' /), 'A', 'm/s2', 'Nonlinear GW zonal wind tendency') + call addfld('VTGW_NL', (/ 'lev' /), 'A', 'm/s2', 'Nonlinear GW meridional wind tendency') + end subroutine gw_nlgw_unet_init subroutine gw_nlgw_unet_infer(gathered_lonlat) @@ -124,6 +129,9 @@ subroutine gw_nlgw_unet_update_ptend(ptend, lchnk, ncol) ptend%u(:ncol,:pver) = ptend%u(:ncol,:pver) + utgw_allchunk(:ncol,:pver, lchnk) ptend%v(:ncol,:pver) = ptend%v(:ncol,:pver) + vtgw_allchunk(:ncol,:pver, lchnk) + call outfld('UTGW_NL', utgw_allchunk(:ncol,:pver, lchnk), ncol, lchnk) + call outfld('VTGW_NL', vtgw_allchunk(:ncol,:pver, lchnk), ncol, lchnk) + end subroutine gw_nlgw_unet_update_ptend subroutine read_norms() From 9598a818112a7505e34b7ef02d1081cabc5a8734 Mon Sep 17 00:00:00 2001 From: Matt Archer Date: Mon, 3 Nov 2025 15:28:14 -0700 Subject: [PATCH 20/23] Fix output to initialise tends earlier --- src/physics/cam/gw_nlgw_unet.F90 | 3 --- src/physics/cam_dev/physpkg.F90 | 6 ++++++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/physics/cam/gw_nlgw_unet.F90 b/src/physics/cam/gw_nlgw_unet.F90 index 20972338df..534e2d5716 100644 --- a/src/physics/cam/gw_nlgw_unet.F90 +++ b/src/physics/cam/gw_nlgw_unet.F90 @@ -66,9 +66,6 @@ subroutine gw_nlgw_unet_init(model_path) endif - call addfld('UTGW_NL', (/ 'lev' /), 'A', 'm/s2', 'Nonlinear GW zonal wind tendency') - call addfld('VTGW_NL', (/ 'lev' /), 'A', 'm/s2', 'Nonlinear GW meridional wind tendency') - end subroutine gw_nlgw_unet_init subroutine gw_nlgw_unet_infer(gathered_lonlat) diff --git a/src/physics/cam_dev/physpkg.F90 b/src/physics/cam_dev/physpkg.F90 index 3611ca8cca..ab998fcbc0 100644 --- a/src/physics/cam_dev/physpkg.F90 +++ b/src/physics/cam_dev/physpkg.F90 @@ -768,6 +768,7 @@ subroutine phys_init( phys_state, phys_tend, pbuf2d, cam_in, cam_out ) use cam_history, only: addfld, register_vector_field, add_default use cam_budget, only: cam_budget_init use phys_grid_ctem, only: phys_grid_ctem_init + use phys_control, only: use_gw_nlgw_unet use ccpp_constituent_prop_mod, only: ccpp_const_props_init @@ -949,6 +950,11 @@ subroutine phys_init( phys_state, phys_tend, pbuf2d, cam_in, cam_out ) end if + + if (use_gw_nlgw_unet) then + call addfld('UTGW_NL', (/ 'lev' /), 'A', 'm/s2', 'Nonlinear GW zonal wind tendency') + call addfld('VTGW_NL', (/ 'lev' /), 'A', 'm/s2', 'Nonlinear GW meridional wind tendency') + end if ! Initialize CAM CCPP constituent properties array ! 
for use in CCPP-ized physics schemes: call ccpp_const_props_init() From 4de751dba6071aaa60844ad9bfd357685cbe1c95 Mon Sep 17 00:00:00 2001 From: tommelt Date: Wed, 12 Nov 2025 08:56:21 -0700 Subject: [PATCH 21/23] feat: update tendencies using Wills approach --- src/physics/cam/gw_drag.F90 | 5 --- src/physics/cam/gw_nlgw_unet.F90 | 53 +++++++++++++++++++------------- src/physics/cam_dev/physpkg.F90 | 30 +++++++++--------- 3 files changed, 47 insertions(+), 41 deletions(-) diff --git a/src/physics/cam/gw_drag.F90 b/src/physics/cam/gw_drag.F90 index fb3ff463c9..3a8d6fbf9b 100644 --- a/src/physics/cam/gw_drag.F90 +++ b/src/physics/cam/gw_drag.F90 @@ -46,7 +46,6 @@ module gw_drag use gw_ml, only: gw_drag_convect_dp_ml_init, gw_drag_convect_dp_ml_final, & gw_drag_convect_dp_ml use gw_nlgw_ann, only: gw_nlgw_ann_infer, gw_nlgw_ann_init, gw_nlgw_ann_finalize - use gw_nlgw_unet, only: gw_nlgw_unet_update_ptend ! Typical module header implicit none @@ -1558,10 +1557,6 @@ subroutine gw_tend(state, pbuf, dt, ptend, cam_in, flx_heat) call gw_nlgw_ann_infer(state1,ptend) end if - if ( use_gw_nlgw_unet ) then - call gw_nlgw_unet_update_ptend(ptend, lchnk, ncol) - end if - if (use_gw_convect_dp) then !------------------------------------------------------------------ ! Convective gravity waves (Beres scheme, deep). diff --git a/src/physics/cam/gw_nlgw_unet.F90 b/src/physics/cam/gw_nlgw_unet.F90 index 534e2d5716..bd82aba602 100644 --- a/src/physics/cam/gw_nlgw_unet.F90 +++ b/src/physics/cam/gw_nlgw_unet.F90 @@ -18,11 +18,7 @@ module gw_nlgw_unet implicit none -public :: gw_nlgw_unet_init, gw_nlgw_unet_infer, gw_nlgw_unet_finalize, gw_nlgw_unet_update_ptend - -real(r8), dimension(:,:,:), allocatable, public :: & - utgw_allchunk, &! zonal wind tendency (m/s^2) - vtgw_allchunk ! meridional wind tendency (m/s^2) +public :: gw_nlgw_unet_init, gw_nlgw_unet_infer, gw_nlgw_unet_finalize, gw_nlgw_unet_set_ptend private @@ -114,22 +110,37 @@ subroutine gw_nlgw_unet_finalize() end subroutine gw_nlgw_unet_finalize -subroutine gw_nlgw_unet_update_ptend(ptend, lchnk, ncol) - - use gw_nlgw_utils, only: flux_to_forcing - - ! inputs - type(physics_ptend), intent(inout) :: ptend - integer, intent(in) :: lchnk, ncol - - ! update the tendencies - ptend%u(:ncol,:pver) = ptend%u(:ncol,:pver) + utgw_allchunk(:ncol,:pver, lchnk) - ptend%v(:ncol,:pver) = ptend%v(:ncol,:pver) + vtgw_allchunk(:ncol,:pver, lchnk) - - call outfld('UTGW_NL', utgw_allchunk(:ncol,:pver, lchnk), ncol, lchnk) - call outfld('VTGW_NL', vtgw_allchunk(:ncol,:pver, lchnk), ncol, lchnk) - -end subroutine gw_nlgw_unet_update_ptend +subroutine gw_nlgw_unet_set_ptend(phys_state, ptend, utgw, vtgw) + ! initiailize and update tendencies + use physconst ,only: cpair + use physics_types,only: physics_state,physics_ptend,physics_ptend_init + use constituents ,only: cnst_get_ind,pcnst + use ppgrid ,only: pver,pcols,begchunk,endchunk + use cam_history ,only: outfld + + type(physics_state), intent(in) :: phys_state + type(physics_ptend), intent(out):: ptend + real(r8), dimension(pcols,pver), intent(in) :: utgw, vtgw + + ! local vars + integer indw,ncol,lchnk + logical lq(pcnst) + + call cnst_get_ind('Q',indw) + lq(:) =.false. + lq(indw)=.true. + call physics_ptend_init(ptend,phys_state%psetcols,'cb24cnn',lu=.true.,lv=.true.,ls=.true.,lq=lq) + + lchnk=phys_state%lchnk + ncol =phys_state%ncol + ptend%u(:ncol,:pver) = ptend%u(:ncol,:pver) + utgw(:ncol,:pver) + ptend%v(:ncol,:pver) = ptend%v(:ncol,:pver) + vtgw(:ncol,:pver) + ! 
ptend%s(:ncol,:pver) = cb24cnn_Sstep(:ncol,:pver,lchnk)*.35 + ! ptend%q(:ncol,:pver,indw)= cb24cnn_Qstep(:ncol,:pver,lchnk)*.35 + + call outfld('UTGW_NL', utgw(:ncol,:pver), ncol, lchnk) + call outfld('VTGW_NL', vtgw(:ncol,:pver), ncol, lchnk) +end subroutine gw_nlgw_unet_set_ptend subroutine read_norms() diff --git a/src/physics/cam_dev/physpkg.F90 b/src/physics/cam_dev/physpkg.F90 index ab998fcbc0..74f108ed93 100644 --- a/src/physics/cam_dev/physpkg.F90 +++ b/src/physics/cam_dev/physpkg.F90 @@ -1190,9 +1190,11 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, & #endif use hemco_interface, only: HCOI_Chunk_Run use nlgw_remap_mod, only: nlgw_regrid_init, nlgw_latlon_gather, nlgw_latlon_scatter, nlgw_regrid_final - use gw_nlgw_unet, only: gw_nlgw_unet_init, gw_nlgw_unet_infer, gw_nlgw_unet_finalize, utgw_allchunk, vtgw_allchunk + use gw_nlgw_unet, only: gw_nlgw_unet_init, gw_nlgw_unet_infer, gw_nlgw_unet_finalize, gw_nlgw_unet_set_ptend use gw_nlgw_utils, only: phys_vars, lonlat_vars, flux_to_forcing use phys_control, only: use_gw_nlgw_unet + use time_manager, only: get_nstep + use check_energy, only: check_energy_chng ! ! Input arguments ! @@ -1212,14 +1214,19 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, & ! integer :: c ! chunk index integer :: ncol ! number of columns + integer :: nstep ! current timestep number type(physics_buffer_desc),pointer, dimension(:) :: phys_buffer_chunk + ! for ML UNet model - type(phys_vars), target :: phys - type(lonlat_vars), target :: lonlat, gathered_lonlat + type(physics_ptend) :: ptend ! parameterization tendencies for nlgw + real(r8) :: zero(pcols) ! array of zeros + type(phys_vars), target :: phys + type(lonlat_vars), target :: lonlat, gathered_lonlat real(r8), dimension(pcols,pver) :: temp_uflux, temp_vflux, temp_utgw, temp_vtgw ! ! If exit condition just return ! + nstep = get_nstep() if(single_column.and.scm_crm_mode) then call diag_deallocate() @@ -1274,9 +1281,6 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, & ! scatter back to all procs, into chunks and regrid back to phys grid call nlgw_latlon_scatter(phys, lonlat, gathered_lonlat) - allocate(utgw_allchunk(pcols, pver, begchunk:endchunk)) - allocate(vtgw_allchunk(pcols, pver, begchunk:endchunk)) - do c = begchunk,endchunk ncol = phys_state(c)%ncol temp_uflux(:ncol,:pver) = transpose(phys%uflux(:pver,:ncol,c)) @@ -1285,14 +1289,13 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, & ! compute tendencies from fluxes call flux_to_forcing(temp_uflux, temp_utgw, phys_state(c)%pmid, ncol) call flux_to_forcing(temp_vflux, temp_vtgw, phys_state(c)%pmid, ncol) - ! store tendencies in unet module so they can be updated in gw_tend - utgw_allchunk(:ncol,:pver,c) = temp_utgw(:ncol,:pver) - vtgw_allchunk(:ncol,:pver,c) = temp_vtgw(:ncol,:pver) + ! update ptend + call gw_nlgw_unet_set_ptend(phys_state(c), ptend, temp_utgw, temp_vtgw) + ! update state and check energy change + call physics_update(phys_state(c), ptend, ztodt, phys_tend(c)) + call check_energy_chng(phys_state(c), phys_tend(c), "nlgw_unet", nstep, ztodt, zero, zero, zero, zero) end do - ! TODO check energy conservation after tendency update - ! TODO look at Will Chapman's code - ! TODO Ideally update here but for now we do it in tphysac call nlgw_regrid_final(phys, lonlat, gathered_lonlat) endif @@ -1313,9 +1316,6 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, & phys_state(c), phys_tend(c), phys_buffer_chunk) end do ! 
Chunk loop - deallocate(utgw_allchunk) - deallocate(vtgw_allchunk) - call t_adj_detailf(-1) call t_stopf('ac_physics') From 917de2623cdf53a949cc5b801c48a2646829b919 Mon Sep 17 00:00:00 2001 From: tommelt Date: Thu, 13 Nov 2025 08:18:53 -0700 Subject: [PATCH 22/23] chore: change name to nlgw_unet when init ptend --- src/physics/cam/gw_nlgw_unet.F90 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/physics/cam/gw_nlgw_unet.F90 b/src/physics/cam/gw_nlgw_unet.F90 index bd82aba602..c5be68ab93 100644 --- a/src/physics/cam/gw_nlgw_unet.F90 +++ b/src/physics/cam/gw_nlgw_unet.F90 @@ -129,7 +129,7 @@ subroutine gw_nlgw_unet_set_ptend(phys_state, ptend, utgw, vtgw) call cnst_get_ind('Q',indw) lq(:) =.false. lq(indw)=.true. - call physics_ptend_init(ptend,phys_state%psetcols,'cb24cnn',lu=.true.,lv=.true.,ls=.true.,lq=lq) + call physics_ptend_init(ptend,phys_state%psetcols,'nlgw_unet',lu=.true.,lv=.true.,ls=.true.,lq=lq) lchnk=phys_state%lchnk ncol =phys_state%ncol From 9b9753b92ea49d5da6c7bac854bb869f361f3e80 Mon Sep 17 00:00:00 2001 From: tommelt Date: Thu, 13 Nov 2025 09:05:59 -0700 Subject: [PATCH 23/23] chore: remove hardcoded path for unet and add paths to module --- src/physics/cam/gw_drag.F90 | 4 ++-- src/physics/cam/gw_nlgw_utils.F90 | 2 ++ src/physics/cam_dev/physpkg.F90 | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/physics/cam/gw_drag.F90 b/src/physics/cam/gw_drag.F90 index 3a8d6fbf9b..3e7f8bacb1 100644 --- a/src/physics/cam/gw_drag.F90 +++ b/src/physics/cam/gw_drag.F90 @@ -204,8 +204,6 @@ module gw_drag logical :: gw_convect_dp_ml_compare = .false. character(len=132) :: gw_convect_dp_ml_net_path character(len=132) :: gw_convect_dp_ml_norms - character(len=132) :: gw_nlgw_model_path_ann - character(len=132) :: gw_nlgw_model_path_unet !========================================================================== contains @@ -218,6 +216,7 @@ subroutine gw_drag_readnl(nlfile) use spmd_utils, only: mpicom, mstrid=>masterprocid, mpi_real8, & mpi_character, mpi_logical, mpi_integer use gw_rdg, only: gw_rdg_readnl + use gw_nlgw_utils, only: gw_nlgw_model_path_ann, gw_nlgw_model_path_unet ! File containing namelist input. character(len=*), intent(in) :: nlfile @@ -430,6 +429,7 @@ subroutine gw_init() use gw_common, only: gw_common_init use gw_front, only: gaussian_cm_desc + use gw_nlgw_utils, only: gw_nlgw_model_path_ann !---------------------------Local storage------------------------------- diff --git a/src/physics/cam/gw_nlgw_utils.F90 b/src/physics/cam/gw_nlgw_utils.F90 index d0cd5e7a10..b16b5d0fbf 100644 --- a/src/physics/cam/gw_nlgw_utils.F90 +++ b/src/physics/cam/gw_nlgw_utils.F90 @@ -10,6 +10,8 @@ module gw_nlgw_utils integer, parameter, public :: p0 = 100000 ! 1000 hPa (Pa) integer, parameter, public :: nlon = 288 ! number of longitude points on lonlat grid integer, parameter, public :: nlat = 192 ! 
number of latitude points on lonlat grid +character(len=132), public :: gw_nlgw_model_path_ann +character(len=132), public :: gw_nlgw_model_path_unet private diff --git a/src/physics/cam_dev/physpkg.F90 b/src/physics/cam_dev/physpkg.F90 index 74f108ed93..c5f0e15353 100644 --- a/src/physics/cam_dev/physpkg.F90 +++ b/src/physics/cam_dev/physpkg.F90 @@ -1191,7 +1191,7 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, & use hemco_interface, only: HCOI_Chunk_Run use nlgw_remap_mod, only: nlgw_regrid_init, nlgw_latlon_gather, nlgw_latlon_scatter, nlgw_regrid_final use gw_nlgw_unet, only: gw_nlgw_unet_init, gw_nlgw_unet_infer, gw_nlgw_unet_finalize, gw_nlgw_unet_set_ptend - use gw_nlgw_utils, only: phys_vars, lonlat_vars, flux_to_forcing + use gw_nlgw_utils, only: phys_vars, lonlat_vars, flux_to_forcing, gw_nlgw_model_path_unet use phys_control, only: use_gw_nlgw_unet use time_manager, only: get_nstep use check_energy, only: check_energy_chng @@ -1272,7 +1272,7 @@ subroutine phys_run2(phys_state, ztodt, phys_tend, pbuf2d, cam_out, & call nlgw_latlon_gather(phys_state, phys, lonlat, gathered_lonlat) if (masterproc) then - call gw_nlgw_unet_init('/glade/u/home/tmeltzer/nonlocal_gwfluxes/era5_training/nlgw_unet_gpu_scripted.pt') + call gw_nlgw_unet_init(gw_nlgw_model_path_unet) ! run UNet model on globally gathered lonlat grid to compute fluxes call gw_nlgw_unet_infer(gathered_lonlat) call gw_nlgw_unet_finalize()
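Editorial aside (not part of the patch): the body of flux_to_forcing, which phys_run2 uses to turn the UNet momentum fluxes into wind tendencies, is never shown in this series, and its interface differs from the sketch below (the real routine takes pver from ppgrid and is called as flux_to_forcing(flux, tend, pmid, ncol)). The following is only a guess at a plausible form -- a vertical divergence of the momentum flux in pressure coordinates, utgw ~ -g * d(uflux)/dp, with made-up boundary handling; the actual routine in gw_nlgw_utils may use a different sign convention, interface pressures, or layer thicknesses.

subroutine flux_to_forcing_sketch(flux, tend, pmid, ncol, nlev)
   ! Illustrative guess only -- not the flux_to_forcing used by the patch.
   implicit none
   integer, parameter :: r8 = selected_real_kind(12)
   integer,  intent(in)  :: ncol, nlev
   real(r8), intent(in)  :: flux(ncol, nlev)   ! momentum flux (Pa)
   real(r8), intent(in)  :: pmid(ncol, nlev)   ! midpoint pressure (Pa), increasing downward
   real(r8), intent(out) :: tend(ncol, nlev)   ! wind tendency (m/s2)
   real(r8), parameter :: gravit = 9.80616_r8  ! gravitational acceleration (m/s2)
   integer :: i, k

   do i = 1, ncol
      ! one-sided differences at the model top and bottom
      tend(i, 1)    = -gravit * (flux(i, 2) - flux(i, 1)) / (pmid(i, 2) - pmid(i, 1))
      tend(i, nlev) = -gravit * (flux(i, nlev) - flux(i, nlev-1)) / (pmid(i, nlev) - pmid(i, nlev-1))
      ! centred differences in the interior
      do k = 2, nlev - 1
         tend(i, k) = -gravit * (flux(i, k+1) - flux(i, k-1)) / (pmid(i, k+1) - pmid(i, k-1))
      end do
   end do
end subroutine flux_to_forcing_sketch

Whatever its exact form, the resulting per-chunk tendencies are what gw_nlgw_unet_set_ptend adds to ptend%u and ptend%v before physics_update and check_energy_chng are called in phys_run2.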