Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
97ed194
wip: add remapping subroutine
TomMelt Jun 16, 2025
e4e1296
chore: remove unused code
TomMelt Jun 27, 2025
dc33839
wip: regridding debug
TomMelt Jul 30, 2025
84e4f07
wip: use 192x288 grid instead
TomMelt Aug 20, 2025
28e0530
wip: try to fix corner issue in regrid
TomMelt Aug 20, 2025
d7de81c
chore: remove unnecessary vars
TomMelt Sep 2, 2025
8ed58f0
wip: gathering ps onto masterproc
TomMelt Sep 2, 2025
e55c585
wip: 2d field gather works, need to do 3d now
TomMelt Sep 3, 2025
a7e7de0
untested code for lonlat2phys regrid
fvitt Sep 3, 2025
27fa6a5
wip: gather 3d fields working now
TomMelt Sep 4, 2025
25f9b50
feat: make grid arrays available to module
TomMelt Sep 10, 2025
a8cb8d1
Revert "wip: try to fix corner issue in regrid"
TomMelt Sep 10, 2025
461566f
Merge branch 'nonlocal-gws-global' into lonlat2phys_regrid
TomMelt Sep 10, 2025
29e4950
Merge pull request #33 from fvitt/lonlat2phys_regrid
TomMelt Sep 10, 2025
b39fdde
feat: change polemethod back to ESMF_POLEMETHOD_ALLAVG
TomMelt Sep 10, 2025
3f1f8df
feat: regridding works both ways!
TomMelt Sep 11, 2025
e460ad6
chore: restructure gw_nlgw ready for gw_nlgw_unet
TomMelt Sep 12, 2025
298fdf9
chore: tidy up remap.F90
TomMelt Sep 12, 2025
d593e05
feat: (wip) add UNet model to CAM
TomMelt Oct 29, 2025
276baac
feat: compute tendencies and tidy up
TomMelt Nov 3, 2025
91430cb
wip: output utgw and vtgw
TomMelt Nov 3, 2025
9598a81
Fix output to initialise tends earlier
ma595 Nov 3, 2025
83c0e72
Merge pull request #35 from DataWaveProject/nonlocal-gws-global-output
TomMelt Nov 12, 2025
4de751d
feat: update tendencies using Wills approach
TomMelt Nov 12, 2025
917de26
chore: change name to nlgw_unet when init ptend
TomMelt Nov 13, 2025
9b9753b
chore: remove hardcoded path for unet and add paths to module
TomMelt Nov 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,26 @@ gw_convect_dp_ml_norms='/path/to/norms'

To run CAM using the non local gravity wave ML model to replace all parameterisations use the following configuration
```fortran
use_gw_nlgw=.true.
gw_nlgw_model_path='/path/to/nlgw-scripted-model.pt'
use_gw_nlgw_ann=.true.
use_gw_nlgw_unet=.true.
gw_nlgw_model_path_ann='/path/to/ann-scripted-model.pt'
gw_nlgw_model_path_unet='/path/to/unet-scripted-model.pt'
```

* `use_gw_nlgw` (`logical`)
* `use_gw_nlgw_ann` (`logical`)

Whether or not to use the ML scheme for non local gravity waves. Default: `.false.`
Whether or not to use the ANN ML scheme for non local gravity waves. Default: `.false.`

* `gw_nlgw_model_path`
* `gw_nlgw_model_path_ann`

Absolute filepath to the non local gravity wave neural net used when `use_gw_nlgw` is set to `.true.` (`.pt`
extension).

* `use_gw_nlgw_unet` (`logical`)

Whether or not to use the UNET ML scheme for non local gravity waves. Default: `.false.`

* `gw_nlgw_model_path_unet`

Absolute filepath to the non local gravity wave neural net used when `use_gw_nlgw` is set to `.true.` (`.pt`
extension).
Expand Down
3 changes: 2 additions & 1 deletion bld/build-namelist
Original file line number Diff line number Diff line change
Expand Up @@ -3606,7 +3606,8 @@ if (!$simple_phys) {
add_default($nl, 'use_gw_rdg_gamma' , 'val'=>'.false.');
add_default($nl, 'use_gw_front_igw' , 'val'=>'.false.');
add_default($nl, 'use_gw_convect_sh', 'val'=>'.false.');
add_default($nl, 'use_gw_nlgw' , 'val'=>'.false.');
add_default($nl, 'use_gw_nlgw_ann' , 'val'=>'.false.');
add_default($nl, 'use_gw_nlgw_unet' , 'val'=>'.false.');
add_default($nl, 'gw_lndscl_sgh');
add_default($nl, 'gw_oro_south_fac');
add_default($nl, 'gw_limit_tau_without_eff');
Expand Down
21 changes: 17 additions & 4 deletions bld/namelist_files/namelist_definition.xml
Original file line number Diff line number Diff line change
Expand Up @@ -1332,10 +1332,17 @@ Whether or not to enable gravity waves produced by shallow convection.
Default: .false.
</entry>

<entry id="use_gw_nlgw" type="logical" category="gw_drag"
<entry id="use_gw_nlgw_ann" type="logical" category="gw_drag"
group="phys_ctl_nl" valid_values="" >
Whether or not to enable gravity waves produced by non-local gravity
wave ML model.
wave ANN ML model.
Default: set by build-namelist.
</entry>

<entry id="use_gw_nlgw_unet" type="logical" category="gw_drag"
group="phys_ctl_nl" valid_values="" >
Whether or not to enable gravity waves produced by non-local gravity
wave UNet ML model.
Default: set by build-namelist.
</entry>

Expand Down Expand Up @@ -1428,10 +1435,16 @@ Absolute filepath to the deep convection gravity wave neural net used when
Default: .false.
</entry>

<entry id="gw_nlgw_model_path" type="char*132" input_pathname="abs" category="gw_drag"
<entry id="gw_nlgw_model_path_ann" type="char*132" input_pathname="abs" category="gw_drag"
group="gw_drag_nl" valid_values="" >
Absolute filepath to the non local gravity wave traced model (.pt)
used when `use_gw_nlgw_ann` is set to `.true.`.
</entry>

<entry id="gw_nlgw_model_path_unet" type="char*132" input_pathname="abs" category="gw_drag"
group="gw_drag_nl" valid_values="" >
Absolute filepath to the non local gravity wave traced model (.pt)
used when `use_gw_nlgw` is set to `.true.`.
used when `use_gw_nlgw_unet` is set to `.true.`.
</entry>

<entry id="effgw_cm" type="real" category="gw_drag"
Expand Down
29 changes: 17 additions & 12 deletions src/physics/cam/gw_drag.F90
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,15 @@ module gw_drag
! These are the actual switches for different gravity wave sources.
use phys_control, only: use_gw_oro, use_gw_front, use_gw_front_igw, &
use_gw_convect_dp, use_gw_convect_sh, &
use_simple_phys, use_gw_nlgw
use_simple_phys, &
use_gw_nlgw_ann, use_gw_nlgw_unet

use gw_common, only: GWBand
use gw_convect, only: BeresSourceDesc
use gw_front, only: CMSourceDesc
use gw_ml, only: gw_drag_convect_dp_ml_init, gw_drag_convect_dp_ml_final, &
gw_drag_convect_dp_ml
use gw_nlgw, only: gw_nlgw_dp_ml, gw_nlgw_dp_init, gw_nlgw_dp_finalize
use gw_nlgw_ann, only: gw_nlgw_ann_infer, gw_nlgw_ann_init, gw_nlgw_ann_finalize

! Typical module header
implicit none
Expand Down Expand Up @@ -203,7 +204,6 @@ module gw_drag
logical :: gw_convect_dp_ml_compare = .false.
character(len=132) :: gw_convect_dp_ml_net_path
character(len=132) :: gw_convect_dp_ml_norms
character(len=132) :: gw_nlgw_model_path

!==========================================================================
contains
Expand All @@ -216,6 +216,7 @@ subroutine gw_drag_readnl(nlfile)
use spmd_utils, only: mpicom, mstrid=>masterprocid, mpi_real8, &
mpi_character, mpi_logical, mpi_integer
use gw_rdg, only: gw_rdg_readnl
use gw_nlgw_utils, only: gw_nlgw_model_path_ann, gw_nlgw_model_path_unet

! File containing namelist input.
character(len=*), intent(in) :: nlfile
Expand Down Expand Up @@ -248,7 +249,7 @@ subroutine gw_drag_readnl(nlfile)
gw_top_taper, front_gaussian_width, &
gw_convect_dp_ml, gw_convect_dp_ml_compare, &
gw_convect_dp_ml_net_path, gw_convect_dp_ml_norms, &
gw_nlgw_model_path
gw_nlgw_model_path_ann, gw_nlgw_model_path_unet
!----------------------------------------------------------------------

if (use_simple_phys) return
Expand Down Expand Up @@ -364,8 +365,11 @@ subroutine gw_drag_readnl(nlfile)
call mpi_bcast(gw_convect_dp_ml_norms, len(gw_convect_dp_ml_norms), mpi_character, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: gw_convect_dp_ml_norms")

call mpi_bcast(gw_nlgw_model_path, len(gw_nlgw_model_path), mpi_character, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: gw_nlgw_model_path")
call mpi_bcast(gw_nlgw_model_path_ann, len(gw_nlgw_model_path_ann), mpi_character, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: gw_nlgw_model_path_ann")

call mpi_bcast(gw_nlgw_model_path_unet, len(gw_nlgw_model_path_unet), mpi_character, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: gw_nlgw_model_path_unet")

! Check if fcrit2 was set.
call shr_assert(fcrit2 /= unset_r8, &
Expand Down Expand Up @@ -425,6 +429,7 @@ subroutine gw_init()

use gw_common, only: gw_common_init
use gw_front, only: gaussian_cm_desc
use gw_nlgw_utils, only: gw_nlgw_model_path_ann

!---------------------------Local storage-------------------------------

Expand Down Expand Up @@ -577,8 +582,8 @@ subroutine gw_init()
call shr_assert(trim(errstring) == "", "gw_common_init: "//errstring// &
errMsg(__FILE__, __LINE__))

if ( use_gw_nlgw ) then
call gw_nlgw_dp_init(gw_nlgw_model_path)
if ( use_gw_nlgw_ann ) then
call gw_nlgw_ann_init(gw_nlgw_model_path_ann)
end if

if ( use_gw_oro ) then
Expand Down Expand Up @@ -1298,8 +1303,8 @@ subroutine gw_final()
if ((gw_convect_dp_ml) .or. (gw_convect_dp_ml_compare)) then
call gw_drag_convect_dp_ml_final()
endif
if ( use_gw_nlgw ) then
call gw_nlgw_dp_finalize()
if ( use_gw_nlgw_ann ) then
call gw_nlgw_ann_finalize()
end if
end subroutine gw_final

Expand Down Expand Up @@ -1548,8 +1553,8 @@ subroutine gw_tend(state, pbuf, dt, ptend, cam_in, flx_heat)
egwdffi_tot = 0._r8
flx_heat = 0._r8

if ( use_gw_nlgw ) then
call gw_nlgw_dp_ml(state1,ptend)
if ( use_gw_nlgw_ann ) then
call gw_nlgw_ann_infer(state1,ptend)
end if

if (use_gw_convect_dp) then
Expand Down
55 changes: 15 additions & 40 deletions src/physics/cam/gw_nlgw.F90 → src/physics/cam/gw_nlgw_ann.F90
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module gw_nlgw
module gw_nlgw_ann

!
! This module predicts gravity wave forcings via PyTorch NNs trained to include non-local gravity wave effects
Expand All @@ -11,18 +11,17 @@ module gw_nlgw
use cam_abortutils, only: endrun
use cam_logfile, only: iulog
use physconst, only: cappa, pi
use gw_nlgw_utils, only: p0
use interpolate_data, only: lininterp

use ftorch

implicit none

public :: gw_nlgw_dp_ml, gw_nlgw_dp_init, gw_nlgw_dp_finalize
public :: gw_nlgw_ann_infer, gw_nlgw_ann_init, gw_nlgw_ann_finalize

private

integer, parameter :: p0 = 100000 ! 1000 hPa (Pa)

type(torch_model) :: nlgw_model ! pytorch model

integer :: ncol ! number of vertical columns
Expand Down Expand Up @@ -104,7 +103,9 @@ module gw_nlgw

!==========================================================================

subroutine gw_nlgw_dp_ml(state_in, ptend)
subroutine gw_nlgw_ann_infer(state_in, ptend)

use gw_nlgw_utils, only: flux_to_forcing

! inputs
type(physics_state), intent(in) :: state_in
Expand Down Expand Up @@ -170,8 +171,8 @@ subroutine gw_nlgw_dp_ml(state_in, ptend)
call extract_output()
call denormalise_data()

call flux_to_forcing(uflux, utgw)
call flux_to_forcing(vflux, vtgw)
call flux_to_forcing(uflux, utgw, pmid, ncol)
call flux_to_forcing(vflux, vtgw, pmid, ncol)

! update the tendencies
ptend%u(:ncol,:pver) = ptend%u(:ncol,:pver) + utgw(:ncol,:pver)
Expand Down Expand Up @@ -200,10 +201,10 @@ subroutine gw_nlgw_dp_ml(state_in, ptend)
deallocate(net_inputs)
deallocate(net_outputs)

end subroutine gw_nlgw_dp_ml
end subroutine gw_nlgw_ann_infer


subroutine gw_nlgw_dp_init(model_path)
subroutine gw_nlgw_ann_init(model_path)

character(len=*), intent(in) :: model_path ! Filepath to PyTorch Torchscript net
integer :: device_id
Expand All @@ -219,17 +220,17 @@ subroutine gw_nlgw_dp_init(model_path)
write(iulog,*)'nlgw model loaded from: ', model_path
endif

end subroutine gw_nlgw_dp_init
end subroutine gw_nlgw_ann_init


subroutine gw_nlgw_dp_finalize()
subroutine gw_nlgw_ann_finalize()

deallocate(net_inputs)
deallocate(net_outputs)
! free model memory
call torch_delete(nlgw_model)

end subroutine gw_nlgw_dp_finalize
end subroutine gw_nlgw_ann_finalize


subroutine read_norms()
Expand Down Expand Up @@ -259,6 +260,7 @@ subroutine read_norms()
end subroutine read_norms

subroutine normalise_data()
use gw_nlgw_utils, only: cbrt

! lat lon are in radians (convert to degrees first)
lat = lat * 180. / pi
Expand Down Expand Up @@ -363,31 +365,4 @@ subroutine denormalise_data()

end subroutine denormalise_data

elemental function cbrt(a) result(root)
real(r8), intent(in) :: a
real(r8), parameter :: one_third = 1._r8/3._r8
real(r8) :: root
root = sign(abs(a)**one_third, a)
end function cbrt

subroutine flux_to_forcing(flux, forcing)

real(r8), intent(in), dimension(:,:) :: flux
real(r8), intent(out), dimension(:,:) :: forcing ! forcing = -d(u'\omega')/d(p), units = m/s^2

integer :: level, col

! convert fluxes to tendencies
! pressure profile must be in Pascals

do col = 1, ncol
forcing(col,1) = -1*(flux(col,2) - flux(col,1))/(pmid(col,2) - pmid(col,1))
do level = 2, pver-1
forcing(col,level) = (flux(col,level+1) - flux(col,level-1)) / (pmid(col,level)*(log(pmid(col,level+1)) - log(pmid(col,level-1))))
end do
forcing(col,pver) = -1*(flux(col,pver) - flux(col,pver-1)) / (pmid(col,pver) - pmid(col,pver-1))
end do

end subroutine flux_to_forcing

end module gw_nlgw
end module gw_nlgw_ann
Loading