Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ML coupling #14

Open
wants to merge 22 commits into
base: datawave_ml
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
fd6e669
Add variable for the convect_dp_gw net filepath to be read from namel…
jatkinson1000 Jun 17, 2024
bed301e
WIP Adapt gw_drag.F90 to import ftorch and read in a net from file. T…
jatkinson1000 Jun 17, 2024
7968fa9
Move reading of NN to initialisation routine for gw_drag.
jatkinson1000 Jun 24, 2024
b3ac3e9
Add a gw_final() subroutine to destroy NN instance.
jatkinson1000 Jun 24, 2024
16b53ac
Remove spurious test code.
jatkinson1000 Aug 5, 2024
dfd451d
Update source to use latest version of FTorch following API Changes.
jatkinson1000 Jul 30, 2024
0d334db
WIP add gw_ml file with starting point for a routine to run the ML pa…
jatkinson1000 Aug 5, 2024
01950bc
Add normalisation file to namelist variables, read, and distribute.
jatkinson1000 Aug 16, 2024
603c78a
Create gw_ml initialisation routine to read in net and normalisation …
jatkinson1000 Aug 19, 2024
72a50a9
Add finalisation routine to gw_ml to destroy the net.
jatkinson1000 Aug 19, 2024
a9f95ee
Add normalisation routine for inputs to net.
jatkinson1000 Aug 19, 2024
19fa91f
Add denormalisation routine for the gw_convect_dp net.
jatkinson1000 Aug 19, 2024
7bf43e1
BugFix: Remove redundant net variable
jatkinson1000 Aug 19, 2024
c01db7e
Couple gw_drag to the ML scheme
jatkinson1000 Aug 19, 2024
921a189
Bugfix: Update ML switches to be logicals when calling gw_drag init a…
jatkinson1000 Aug 19, 2024
aaecc4a
[DEBUG: Run net with ones as the input to compare to same operation u…
jatkinson1000 Sep 9, 2024
5990616
Revert changes to debug commit using ones as net input.
jatkinson1000 Sep 9, 2024
83cb8c9
Add output for U and V tendencies from GW from Net and Beres schemes.
jatkinson1000 Sep 9, 2024
03bd7f5
Set tendencies not predicted by the net to 0.0 for the ml scheme rath…
jatkinson1000 Oct 15, 2024
bf2a8bb
Add iulog to gw_ml to write outputs to diagnostics in init scheme.
jatkinson1000 Oct 15, 2024
218242f
Add code to produce output required by the testcase.
jatkinson1000 Sep 30, 2024
a2d881c
Update rhoi output to be ncol rather than pcol.
jatkinson1000 Oct 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bld/namelist_files/namelist_defaults_cam.xml
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,7 @@
<gw_convect_dp_ml >.false.</gw_convect_dp_ml>
<gw_convect_dp_ml_compare >.false.</gw_convect_dp_ml_compare>
<gw_convect_dp_ml_net_path >NONE </gw_convect_dp_ml_net_path>
<gw_convect_dp_ml_norms >NONE </gw_convect_dp_ml_norms>

<!-- setting for gravity waves from shallow convection. -->
<effgw_beres_sh>0.03D0</effgw_beres_sh>
Expand Down
6 changes: 6 additions & 0 deletions bld/namelist_files/namelist_definition.xml
Original file line number Diff line number Diff line change
Expand Up @@ -1405,6 +1405,12 @@ Absolute filepath to the deep convection gravity wave neural net used when
`gw_convect_dp_ml` is set to `.true.`.
</entry>

<entry id="gw_convect_dp_ml_norms" type="char*132" input_pathname="abs" category="gw_drag"
group="gw_drag_nl" valid_values="" >
Absolute filepath to the deep convection gravity wave normalisation weights (NetCDF)
used when `gw_convect_dp_ml` is set to `.true.`.
</entry>

<entry id="gw_convect_dp_ml_compare" type="logical" category="gw_drag"
group="gw_drag_nl" valid_values="" >
Whether or not to run a piggybacking comparison of the ML deep convection gravity
Expand Down
106 changes: 90 additions & 16 deletions src/physics/cam/gw_drag.F90
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ module gw_drag
use cam_logfile, only: iulog
use cam_abortutils, only: endrun

use ftorch

use ref_pres, only: do_molec_diff, nbot_molec, press_lim_idx
use physconst, only: cpair

Expand All @@ -40,6 +42,8 @@ module gw_drag
use gw_common, only: GWBand
use gw_convect, only: BeresSourceDesc
use gw_front, only: CMSourceDesc
use gw_ml, only: gw_drag_convect_dp_ml_init, gw_drag_convect_dp_ml_final, &
gw_drag_convect_dp_ml

! Typical module header
implicit none
Expand All @@ -52,6 +56,7 @@ module gw_drag
public :: gw_drag_readnl ! Read namelist
public :: gw_init ! Initialization
public :: gw_tend ! interface to actual parameterization
public :: gw_final ! Finalization

!
! PRIVATE: Rest of the data and interfaces are private to this module
Expand Down Expand Up @@ -195,6 +200,8 @@ module gw_drag
! Switch for using ML GW parameterisation for deep convection source
logical :: gw_convect_dp_ml = .false.
logical :: gw_convect_dp_ml_compare = .false.
character(len=132) :: gw_convect_dp_ml_net_path
character(len=132) :: gw_convect_dp_ml_norms

!==========================================================================
contains
Expand Down Expand Up @@ -237,7 +244,8 @@ subroutine gw_drag_readnl(nlfile)
gw_oro_south_fac, gw_limit_tau_without_eff, &
gw_lndscl_sgh, gw_prndl, gw_apply_tndmax, gw_qbo_hdepth_scaling, &
gw_top_taper, front_gaussian_width, &
gw_convect_dp_ml, gw_convect_dp_ml_compare
gw_convect_dp_ml, gw_convect_dp_ml_compare, &
gw_convect_dp_ml_net_path, gw_convect_dp_ml_norms
!----------------------------------------------------------------------

if (use_simple_phys) return
Expand Down Expand Up @@ -347,6 +355,12 @@ subroutine gw_drag_readnl(nlfile)
call mpi_bcast(gw_convect_dp_ml_compare, 1, mpi_logical, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: gw_convect_dp_ml_compare")

call mpi_bcast(gw_convect_dp_ml_net_path, len(gw_convect_dp_ml_net_path), mpi_character, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: gw_convect_dp_ml_net_path")

call mpi_bcast(gw_convect_dp_ml_norms, len(gw_convect_dp_ml_norms), mpi_character, mstrid, mpicom, ierr)
if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: gw_convect_dp_ml_norms")

! Check if fcrit2 was set.
call shr_assert(fcrit2 /= unset_r8, &
"gw_drag_readnl: fcrit2 must be set via the namelist."// &
Expand Down Expand Up @@ -974,6 +988,27 @@ subroutine gw_init()

end if

! Set up neccessary attributes if using ML scheme for convective drag
if ((gw_convect_dp_ml) .or. (gw_convect_dp_ml_compare)) then
! Load the convective drag net from TorchScript file
call gw_drag_convect_dp_ml_init(gw_convect_dp_ml_net_path, gw_convect_dp_ml_norms)

! Register fields with the output buffer
call addfld ('UTGW_NN ', (/ 'lev' /), 'A', 'm/s2', &
'U tendency due to convective gravity wave drag from NN.')
call addfld ('VTGW_NN ', (/ 'lev' /), 'A', 'm/s2', &
'V tendency due to convective gravity wave drag from NN.')
call register_vector_field('UTGW_NN', 'VTGW_NN')

! Register fields with the output buffer
call addfld ('UTGW_BERES ', (/ 'lev' /), 'A', 'm/s2', &
'U tendency due to convective gravity wave drag from BERES.')
call addfld ('VTGW_BERES ', (/ 'lev' /), 'A', 'm/s2', &
'V tendency due to convective gravity wave drag from NN.')
call register_vector_field('UTGW_BERES', 'VTGW_BERES')

endif

if (use_gw_convect_sh) then

ttend_sh_idx = pbuf_get_index('TTEND_SH')
Expand Down Expand Up @@ -1028,6 +1063,15 @@ subroutine gw_init()
call add_default('EKGW', 1, ' ')
end if

call addfld ('RHOI', (/ 'ilev' /), 'A', 'kg/m3', &
'density at interfaces')

call addfld ('DSE', (/ 'lev' /), 'I', 'J/kg', &
'dry static energy')

call addfld ('NMBV', (/ 'lev' /), 'I', 'J/kg', &
'Brunt Vaisala Frequency')

call addfld ('UTGW_TOTAL', (/ 'lev' /), 'A','m/s2', &
'Total U tendency due to gravity wave drag')
call addfld ('VTGW_TOTAL', (/ 'lev' /), 'A','m/s2', &
Expand Down Expand Up @@ -1236,6 +1280,15 @@ end subroutine handle_pio_error

!==========================================================================

subroutine gw_final()
! Destroy neccessary attributes if using ML scheme for convective drag
if ((gw_convect_dp_ml) .or. (gw_convect_dp_ml_compare)) then
call gw_drag_convect_dp_ml_final()
endif
end subroutine gw_final

!==========================================================================

subroutine gw_tend(state, pbuf, dt, ptend, cam_in, flx_heat)
!-----------------------------------------------------------------------
! Interface for multiple gravity wave drag parameterization.
Expand All @@ -1255,6 +1308,7 @@ subroutine gw_tend(state, pbuf, dt, ptend, cam_in, flx_heat)
use gw_oro, only: gw_oro_src
use gw_front, only: gw_cm_src
use gw_convect, only: gw_beres_src
use gw_ml, only: gw_drag_convect_dp_ml

!------------------------------Arguments--------------------------------
type(physics_state), intent(in) :: state ! physics state structure
Expand Down Expand Up @@ -1408,6 +1462,9 @@ subroutine gw_tend(state, pbuf, dt, ptend, cam_in, flx_heat)
real(r8) :: piln(state%ncol,pver+1)
real(r8) :: zm(state%ncol,pver)
real(r8) :: zi(state%ncol,pver+1)
real(r8) :: ps(state%ncol)
real(r8) :: lat(state%ncol)
real(r8) :: lon(state%ncol)
!------------------------------------------------------------------------

! Make local copy of input state.
Expand All @@ -1429,6 +1486,9 @@ subroutine gw_tend(state, pbuf, dt, ptend, cam_in, flx_heat)
piln = state1%lnpint(:ncol,:)
zm = state1%zm(:ncol,:)
zi = state1%zi(:ncol,:)
ps = state1%ps(:ncol)
lat = state1%lat(:ncol)
lon = state1%lon(:ncol)

lq = .true.
call physics_ptend_init(ptend, state1%psetcols, "Gravity wave drag", &
Expand Down Expand Up @@ -1498,6 +1558,9 @@ subroutine gw_tend(state, pbuf, dt, ptend, cam_in, flx_heat)
u, v, ttend_dp(:ncol,:), zm, src_level, tend_level, tau, &
ubm, ubi, xv, yv, c, hdepth, maxq0)

! TODO: If we are running with the ML scheme save tau to a temp variable so we
! Can reset to it instead of having it updated in the physics scheme.

if ((.not. gw_convect_dp_ml) .or. (gw_convect_dp_ml_compare)) then
! Solve for the drag profile with Beres source spectrum.
call gw_drag_prof(ncol, band_mid, p, src_level, tend_level, dt, &
Expand All @@ -1507,6 +1570,12 @@ subroutine gw_tend(state, pbuf, dt, ptend, cam_in, flx_heat)
ttgw_temp, qtgw_temp, egwdffi, gwut, dttdf, dttke, &
lapply_effgw_in=gw_apply_tndmax)

if (gw_convect_dp_ml_compare) then
! write fields out for comparison
call outfld('UTGW_BERES', utgw_temp, ncol, lchnk)
call outfld('VTGW_BERES', vtgw_temp, ncol, lchnk)
end if

if (.not. gw_convect_dp_ml) then
! Save the results to apply to ptend for simulation updates
qtgw = qtgw_temp
Expand All @@ -1521,37 +1590,38 @@ subroutine gw_tend(state, pbuf, dt, ptend, cam_in, flx_heat)
write(iulog,*) "Using the ML scheme for convective gravity waves."
end if

! Solve for the drag profile with Beres source spectrum.
! Placeholder to be replaced with the ML scheme
call gw_drag_prof(ncol, band_mid, p, src_level, tend_level, dt, &
t, vramp, &
piln, rhoi, nm, ni, ubm, ubi, xv, yv, &
effgw, c, kvtt, q, dse, tau, utgw_temp, vtgw_temp, &
ttgw_temp, qtgw_temp, egwdffi, gwut, dttdf, dttke, &
lapply_effgw_in=gw_apply_tndmax)
call gw_drag_convect_dp_ml(ncol, dt, &
u, v, t, dse, nm, ttend_dp(:ncol,:), zm, rhoi, ps, &
lat, lon, &
utgw_temp, vtgw_temp)

! write fields out for comparison
call outfld('UTGW_NN', utgw_temp, ncol, lchnk)
call outfld('VTGW_NN', vtgw_temp, ncol, lchnk)

if (gw_convect_dp_ml) then
! Save the results to apply to ptend for simulation updates
! TODO: Check how to handle tendencies not output by ML scheme
qtgw = qtgw_temp ! in the ml scheme there is no qtgw so use qtgw = 0.0
ttgw = ttgw_temp ! in the ml scheme there is no ttgw so use ttgw = 0.0
! Save the results to apply to ptend for simulation updates.
! Some tendencies are not supplied by the NN, set these to 0.0 as minor.
ttgw = 0.0
qtgw = 0.0
egwdffi = 0.0
gwut = 0.0
dttdf = 0.0
dttke = 0.0
utgw = utgw_temp
vtgw = vtgw_temp
! in the ml scheme there is not egwdffi set, so use egwdffi = 0.0
end if
end if

! Project stress into directional components.
taucd = calc_taucd(ncol, band_mid%ngwv, tend_level, tau, c, xv, yv, ubi)

! add the diffusion coefficients
! TODO: Check how to handle egwdffi not output by ML scheme
do k = 1, pver+1
egwdffi_tot(:,k) = egwdffi_tot(:,k) + egwdffi(:,k)
end do

! Store constituents tendencies
! TODO: Check how to handle qtgw not output by ML scheme
do m=1, pcnst
do k = 1, pver
ptend%q(:ncol,k,m) = ptend%q(:ncol,k,m) + qtgw(:,k,m)
Expand Down Expand Up @@ -2035,6 +2105,10 @@ subroutine gw_tend(state, pbuf, dt, ptend, cam_in, flx_heat)
call outfld('CLDLIQTGW', ptend%q(:,:,ixcldliq), pcols, lchnk)
call outfld('CLDICETGW', ptend%q(:,:,ixcldice), pcols, lchnk)

call outfld('RHOI', rhoi, ncol, lchnk)
call outfld('DSE', dse, ncol, lchnk)
call outfld('NMBV', nm, ncol, lchnk)

! Destroy objects.
call p%finalize()

Expand Down
Loading