diff --git a/src/physics/cam/gw_drag.F90 b/src/physics/cam/gw_drag.F90 index ca36732005..252e054e01 100644 --- a/src/physics/cam/gw_drag.F90 +++ b/src/physics/cam/gw_drag.F90 @@ -42,6 +42,7 @@ module gw_drag use gw_common, only: GWBand use gw_convect, only: BeresSourceDesc use gw_front, only: CMSourceDesc + use gw_ml, only: gw_drag_convect_dp_ml_init, gw_drag_convect_dp_ml ! Typical module header implicit none @@ -990,7 +991,7 @@ subroutine gw_init() ! Set up neccessary attributes if using ML scheme for convective drag if ((gw_convect_dp_ml == 'on') .or. (gw_convect_dp_ml == 'bothon')) then ! Load the convective drag net from TorchScript file - call torch_model_load(gw_convect_dp_nn, gw_convect_dp_ml_net) + call gw_drag_convect_dp_ml_init(gw_convect_dp_ml_net, gw_convect_dp_ml_norms) endif if (use_gw_convect_sh) then diff --git a/src/physics/cam/gw_ml.F90 b/src/physics/cam/gw_ml.F90 index bb2204116a..9ec295308a 100644 --- a/src/physics/cam/gw_ml.F90 +++ b/src/physics/cam/gw_ml.F90 @@ -7,6 +7,8 @@ module gw_ml use gw_utils, only: r8 use ppgrid, only: pver +use spmd_utils, only: mpicom, mstrid=>masterprocid, masterproc, mpi_real8 +use cam_abortutils, only: endrun use ftorch @@ -14,22 +16,49 @@ module gw_ml private save -public :: gw_drag_convect_dp_ml +public :: gw_drag_convect_dp_ml, gw_drag_convect_dp_ml_init + +! Neural Net as read in by FTorch +type(torch_model) :: convect_net + +! Means for normalisation +real(r8) :: utgw_mean(pver), vtgw_mean(pver) +real(r8) :: u_mean(pver), v_mean(pver) +real(r8) :: t_mean(pver) +real(r8) :: dse_mean(pver) +real(r8) :: nm_mean(pver) +real(r8) :: netdt_mean(pver) +real(r8) :: zm_mean(pver) +real(r8) :: rhoi_mean(pver+1) +real(r8) :: ps_mean +real(r8) :: lat_mean +real(r8) :: lon_mean +! Standard deviations for normalisation +real(r8) :: utgw_std(pver), vtgw_std(pver) +real(r8) :: u_std(pver), v_std(pver) +real(r8) :: t_std(pver) +real(r8) :: dse_std(pver) +real(r8) :: nm_std(pver) +real(r8) :: netdt_std(pver) +real(r8) :: zm_std(pver) +real(r8) :: rhoi_std(pver+1) +real(r8) :: ps_std +real(r8) :: lat_std +real(r8) :: lon_std contains !========================================================================== -subroutine gw_drag_convect_dp_ml(convect_net, & - ncol, dt, & +subroutine gw_drag_convect_dp_ml(ncol, dt, & u, v, t, dse, nm, netdt, zm, rhoi, ps, lat, lon, & utgw, vtgw) ! Take data from CAM, normalise and concatenate before passing it to the Torch neural ! net to calculate u and v tendencies. - ! Neural Net as read in by FTorch - type(torch_model) :: convect_net + + ! Column dimension. integer, intent(in) :: ncol @@ -46,7 +75,7 @@ subroutine gw_drag_convect_dp_ml(convect_net, & ! Midpoint and interface Brunt-Vaisalla frequencies. real(r8), intent(in) :: nm(ncol,pver) ! Heating rate due to convection. - real(r8), intent(in) :: netdt(:,:) + real(r8), intent(in) :: netdt(ncol,pver) ! Midpoint geopotential altitudes. real(r8), intent(in) :: zm(ncol,pver) ! Interface densities. @@ -110,4 +139,219 @@ subroutine gw_drag_convect_dp_ml(convect_net, & end subroutine gw_drag_convect_dp_ml + +subroutine gw_drag_convect_dp_ml_init(neural_net_path, norms_path) + + character(len=132), intent(in) :: neural_net_path ! Filepath to PyTorch Torchscript net + character(len=132), intent(in) :: norms_path ! Filepath to NetCDF normalisation weights + + ! Load the convective drag net from TorchScript file + call torch_model_load(convect_net, neural_net_path) + ! read in normalisation weights + call read_norms(norms_path) + +end subroutine gw_drag_convect_dp_ml_init + + +subroutine read_norms(norms_path) + + use netcdf + use error_messages, only: handle_ncerr + + character(len=132), intent(in) :: norms_path ! Filepath to NetCDF normalisation weights + + integer :: ncid, varid, retva, ierr + character(len=*), parameter :: sub = 'gw_ml/F90 read_norms: ' + + ! Load weights from file in master process then broadcast + if (masterproc) then + ! Open the NetCDF file + call handle_ncerr( nf90_open(trim(norms_path), NF90_NOWRITE, ncid), & + "Error opening NetCDF norms file in gw_ml.F90") + + ! We do not need to read in dimensions here as we assume inputs match the grid. + + ! Read in variables (means and deviations). + call handle_ncerr( nf90_inq_varid(ncid, 'U_mean', varid), & + "Error getting U_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, u_mean), & + "Error getting U_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_inq_varid(ncid, 'U_std', varid), & + "Error getting U_std varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, u_std), & + "Error getting U_std varid from NetCDF Norms file in gw_ml.F90") + + call handle_ncerr( nf90_inq_varid(ncid, 'V_mean', varid), & + "Error getting V_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, v_mean), & + "Error getting V_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_inq_varid(ncid, 'V_std', varid), & + "Error getting V_std varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, v_std), & + "Error getting V_std varid from NetCDF Norms file in gw_ml.F90") + + call handle_ncerr( nf90_inq_varid(ncid, 'T_mean', varid), & + "Error getting T_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, t_mean), & + "Error getting t_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_inq_varid(ncid, 'T_std', varid), & + "Error getting T_std varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, t_std), & + "Error getting T_std varid from NetCDF Norms file in gw_ml.F90") + + call handle_ncerr( nf90_inq_varid(ncid, 'DSE_mean', varid), & + "Error getting DSE_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, dse_mean), & + "Error getting U_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_inq_varid(ncid, 'DSE_std', varid), & + "Error getting DSE_std varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, dse_std), & + "Error getting DSE_std varid from NetCDF Norms file in gw_ml.F90") + + call handle_ncerr( nf90_inq_varid(ncid, 'NMBV_mean', varid), & + "Error getting NMBV_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, nm_mean), & + "Error getting NMBV_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_inq_varid(ncid, 'NMBV_std', varid), & + "Error getting NMBV_std varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, nm_std), & + "Error getting NMBV_std varid from NetCDF Norms file in gw_ml.F90") + + call handle_ncerr( nf90_inq_varid(ncid, 'NETDT_mean', varid), & + "Error getting NETDT_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, netdt_mean), & + "Error getting NETDT_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_inq_varid(ncid, 'NETDT_std', varid), & + "Error getting NETDT_std varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, netdt_std), & + "Error getting NETDT_std varid from NetCDF Norms file in gw_ml.F90") + + call handle_ncerr( nf90_inq_varid(ncid, 'Z3_mean', varid), & + "Error getting Z3_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, zm_mean), & + "Error getting Z3_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_inq_varid(ncid, 'Z3_std', varid), & + "Error getting Z3_std varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, zm_std), & + "Error getting Z3_std varid from NetCDF Norms file in gw_ml.F90") + + call handle_ncerr( nf90_inq_varid(ncid, 'RHOI_mean', varid), & + "Error getting RHOI_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, rhoi_mean), & + "Error getting RHOI_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_inq_varid(ncid, 'RHOI_std', varid), & + "Error getting RHOI_std varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, rhoi_std), & + "Error getting RHOI_std varid from NetCDF Norms file in gw_ml.F90") + + call handle_ncerr( nf90_inq_varid(ncid, 'PS_mean', varid), & + "Error getting PS_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, ps_mean), & + "Error getting PS_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_inq_varid(ncid, 'PS_std', varid), & + "Error getting PS_std varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, ps_std), & + "Error getting PS_std varid from NetCDF Norms file in gw_ml.F90") + + call handle_ncerr( nf90_inq_varid(ncid, 'lat_mean', varid), & + "Error getting lat_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, lat_mean), & + "Error getting lat_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_inq_varid(ncid, 'lat_std', varid), & + "Error getting lat_std varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, lat_std), & + "Error getting lat_std varid from NetCDF Norms file in gw_ml.F90") + + call handle_ncerr( nf90_inq_varid(ncid, 'lon_mean', varid), & + "Error getting lon_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, lon_mean), & + "Error getting lon_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_inq_varid(ncid, 'lon_std', varid), & + "Error getting lon_std varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, lon_std), & + "Error getting lon_std varid from NetCDF Norms file in gw_ml.F90") + + call handle_ncerr( nf90_inq_varid(ncid, 'UTGWSPEC_mean', varid), & + "Error getting UTGWSPEC_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, utgw_mean), & + "Error getting UTGWSPEC_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_inq_varid(ncid, 'UTGWSPEC_std', varid), & + "Error getting UTGWSPEC_std varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, utgw_std), & + "Error getting UTGWSPEC_std varid from NetCDF Norms file in gw_ml.F90") + + call handle_ncerr( nf90_inq_varid(ncid, 'VTGWSPEC_mean', varid), & + "Error getting VTGWSPEC_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, vtgw_mean), & + "Error getting VTGWSPEC_mean varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_inq_varid(ncid, 'VTGWSPEC_std', varid), & + "Error getting VTGWSPEC_std varid from NetCDF Norms file in gw_ml.F90") + call handle_ncerr( nf90_get_var(ncid, varid, vtgw_std), & + "Error getting VTGWSPEC_std varid from NetCDF Norms file in gw_ml.F90") + + endif + + ! Broadcast normalisation variables to other processes + call mpi_bcast(utgw_mean, pver, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: utgw_mean from gw_ml.F90") + call mpi_bcast(utgw_std, pver, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: utgw_std from gw_ml.F90") + + call mpi_bcast(vtgw_mean, pver, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: vtgw_mean from gw_ml.F90") + call mpi_bcast(vtgw_std, pver, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: vtgw_std from gw_ml.F90") + + call mpi_bcast(u_mean, pver, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: u_mean from gw_ml.F90") + call mpi_bcast(u_std, pver, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: u_std from gw_ml.F90") + + call mpi_bcast(v_mean, pver, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: v_mean from gw_ml.F90") + call mpi_bcast(v_std, pver, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: v_std from gw_ml.F90") + + call mpi_bcast(t_mean, pver, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: t_mean from gw_ml.F90") + call mpi_bcast(t_std, pver, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: t_std from gw_ml.F90") + + call mpi_bcast(dse_mean, pver, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: dse_mean from gw_ml.F90") + call mpi_bcast(dse_std, pver, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: dse_std from gw_ml.F90") + + call mpi_bcast(nm_mean, pver, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: nm_mean from gw_ml.F90") + call mpi_bcast(nm_std, pver, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: nm_std from gw_ml.F90") + + call mpi_bcast(zm_mean, pver, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: zm_mean from gw_ml.F90") + call mpi_bcast(zm_std, pver, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: zm_std from gw_ml.F90") + + call mpi_bcast(rhoi_mean, pver+1, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: rhoi_mean from gw_ml.F90") + call mpi_bcast(rhoi_std, pver+1, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: rhoi_std from gw_ml.F90") + + call mpi_bcast(ps_mean, pver, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: ps_mean from gw_ml.F90") + call mpi_bcast(ps_std, pver, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: ps_std from gw_ml.F90") + + call mpi_bcast(lat_mean, 1, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: lat_mean from gw_ml.F90") + call mpi_bcast(lat_std, 1, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: lat_std from gw_ml.F90") + + call mpi_bcast(lon_mean, 1, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: lon_mean from gw_ml.F90") + call mpi_bcast(lon_std, 1, mpi_real8, mstrid, mpicom, ierr) + if (ierr /= 0) call endrun(sub//": FATAL: mpi_bcast: lon_std from gw_ml.F90") + +end subroutine read_norms + end module gw_ml