From ef2d8aecd7a16561612de22d160db15d4f977c3d Mon Sep 17 00:00:00 2001 From: raychew Date: Mon, 20 May 2024 19:45:26 +0200 Subject: [PATCH] updated global ICON script to latest machinery checked; the dfft and lsff results have been reproduced --- runs/icon_merit_global.py | 254 +++++++++++++++++++++++--------------- 1 file changed, 157 insertions(+), 97 deletions(-) diff --git a/runs/icon_merit_global.py b/runs/icon_merit_global.py index a276ff3..c99f9b9 100644 --- a/runs/icon_merit_global.py +++ b/runs/icon_merit_global.py @@ -3,8 +3,9 @@ import pandas as pd import matplotlib.pyplot as plt + from pycsam.src import io, var, utils, fourier -from pycsam.wrappers import interface +from pycsam.wrappers import interface, diagnostics from pycsam.vis import plotter, cart_plot from IPython import get_ipython @@ -105,21 +106,32 @@ def autoreload(): if params.plot: cart_plot.lat_lon_icon(topo, triangles, ncells=ncells, clon=clon, clat=clat) -# %% - - print(topo.topo.shape) - # %% tri_idx = 0 # initialise cell object cell = var.topo_cell() + tri = var.obj() - simplex_lon = triangles[tri_idx, :, 0] - simplex_lat = triangles[tri_idx, :, 1] + nhi = params.nhi + nhj = params.nhj - utils.get_lat_lon_segments( - simplex_lat, simplex_lon, cell, topo, rect=params.rect - ) + fa = interface.first_appx(nhi, nhj, params, topo) + sa = interface.second_appx(nhi, nhj, params, topo, tri) + + dplot = diagnostics.diag_plotter(params, nhi, nhj) + + # simplex_lon = triangles[tri_idx, :, 0] + # simplex_lat = triangles[tri_idx, :, 1] + + tri.tri_lon_verts = triangles[:, :, 0] + tri.tri_lat_verts = triangles[:, :, 1] + + # utils.get_lat_lon_segments( + # simplex_lat, simplex_lon, cell, topo, rect=params.rect + # ) + + simplex_lat = tri.tri_lat_verts[tri_idx] + simplex_lon = tri.tri_lon_verts[tri_idx] if utils.is_land(cell, simplex_lat, simplex_lon, topo): writer.output(c_idx, clat_rad[tri_idx], clon_rad[tri_idx], 0) @@ -127,115 +139,163 @@ def autoreload(): else: is_land = 1 - topo_orig = np.copy(cell.topo) + # topo_orig = np.copy(cell.topo) if params.dfft_first_guess: - nhi = len(cell.lon) - nhj = len(cell.lat) + # do tapering + if params.taper_fa: + interface.taper_quad(params, simplex_lat, simplex_lon, cell, topo) + else: + utils.get_lat_lon_segments( + simplex_lat, simplex_lon, cell, topo, rect=params.rect + ) + + dfft_run = interface.get_pmf(nhi, nhj, params.U, params.V) + ampls_fa, uw_fa, dat_2D_fa, kls_fa = dfft_run.dfft(cell) + + cell_fa = cell + + nhi = len(cell_fa.lon) + nhj = len(cell_fa.lat) + + sa.nhi = nhi + sa.nhj = nhj else: - nhi = params.nhi - nhj = params.nhj + cell_fa, ampls_fa, uw_fa, dat_2D_fa = fa.do(simplex_lat, simplex_lon) - first_guess = interface.get_pmf(nhi, nhj, params.U, params.V) - fobj_tri = fourier.f_trans(nhi, nhj) - ####################################################### - # do fourier... + sols = (cell_fa, ampls_fa, uw_fa, dat_2D_fa) - if not params.dfft_first_guess: - freqs, uw_pmf_freqs, dat_2D_fg0 = first_guess.sappx(cell, params.lmbda_fa) + v_extent = [dat_2D_fa.min(), dat_2D_fa.max()] - ####################################################### - # do fourier using DFFT + if params.dfft_first_guess: + dplot.show( + tri_idx, sols, kls=kls_fa, v_extent=v_extent, dfft_plot=True, + output_fig=False + ) + else: + dplot.show(tri_idx, sols, v_extent=v_extent, output_fig=False) + if params.recompute_rhs: + sols, sols_rc = sa.do(tri_idx, ampls_fa) + else: + sols = sa.do(tri_idx, ampls_fa) + + cell, ampls_sa, uw_sa, dat_2D_sa = sols + v_extent = [dat_2D_sa.min(), dat_2D_sa.max()] + if params.dfft_first_guess: - ampls, uw_pmf_freqs, dat_2D_fg0, kls = first_guess.dfft(cell) - freqs = np.copy(ampls) + dplot.show( + tri_idx, sols, kls=kls_fa, v_extent=v_extent, dfft_plot=True, + output_fig=False + ) + else: + dplot.show(tri_idx, sols, v_extent=v_extent, output_fig=False) - print("uw_pmf_freqs_sum:", uw_pmf_freqs.sum()) - fq_cpy = np.copy(freqs) - fq_cpy[ - np.isnan(fq_cpy) - ] = 0.0 # necessary. Otherwise, popping with fq_cpy.max() gives the np.nan entries first. - indices = [] - max_ampls = [] - for ii in range(params.n_modes): - max_idx = np.unravel_index(fq_cpy.argmax(), fq_cpy.shape) - indices.append(max_idx) - max_ampls.append(fq_cpy[max_idx]) - max_val = fq_cpy[max_idx] - fq_cpy[max_idx] = 0.0 - utils.get_lat_lon_segments( - simplex_lat, simplex_lon, cell, topo, rect=False - ) + # first_guess = interface.get_pmf(nhi, nhj, params.U, params.V) + # fobj_tri = fourier.f_trans(nhi, nhj) - k_idxs = [pair[1] for pair in indices] - l_idxs = [pair[0] for pair in indices] + # ####################################################### + # # do fourier... - second_guess = interface.get_pmf(nhi, nhj, params.U, params.V) + # if not params.dfft_first_guess: + # freqs, uw_pmf_freqs, dat_2D_fg0 = first_guess.sappx(cell, params.lmbda_fa) - if params.dfft_first_guess: - second_guess.fobj.set_kls( - k_idxs, l_idxs, recompute_nhij=True, components="real" - ) - else: - second_guess.fobj.set_kls(k_idxs, l_idxs, recompute_nhij=False) + # ####################################################### + # # do fourier using DFFT - freqs, uw, dat_2D_sg0 = second_guess.sappx(cell, lmbda=params.lmbda_sa, updt_analysis=True) + # if params.dfft_first_guess: + # ampls, uw_pmf_freqs, dat_2D_fg0, kls = first_guess.dfft(cell) + # freqs = np.copy(ampls) - cell.topo = topo_orig + # print("uw_pmf_freqs_sum:", uw_pmf_freqs.sum()) - writer.output(c_idx, clat_rad[tri_idx], clon_rad[tri_idx], is_land, cell.analysis) - - cell.uw = uw + # fq_cpy = np.copy(freqs) + # fq_cpy[ + # np.isnan(fq_cpy) + # ] = 0.0 # necessary. Otherwise, popping with fq_cpy.max() gives the np.nan entries first. - if params.plot: - fs = (15, 9.0) - v_extent = [dat_2D_sg0.min(), dat_2D_sg0.max()] - - fig, axs = plt.subplots(2, 2, figsize=fs) - - fig_obj = plotter.fig_obj( - fig, second_guess.fobj.nhar_i, second_guess.fobj.nhar_j - ) - axs[0, 0] = fig_obj.phys_panel( - axs[0, 0], - dat_2D_sg0, - title="T%i: Reconstruction" % tri_idx, - xlabel="longitude [km]", - ylabel="latitude [km]", - extent=[cell.lon.min(), cell.lon.max(), cell.lat.min(), cell.lat.max()], - v_extent=v_extent, - ) - - axs[0, 1] = fig_obj.phys_panel( - axs[0, 1], - cell.topo * cell.mask, - title="T%i: Reconstruction" % tri_idx, - xlabel="longitude [km]", - ylabel="latitude [km]", - extent=[cell.lon.min(), cell.lon.max(), cell.lat.min(), cell.lat.max()], - v_extent=v_extent, - ) - - if params.dfft_first_guess: - axs[1, 0] = fig_obj.fft_freq_panel( - axs[1, 0], freqs, kls[0], kls[1], typ="real" - ) - axs[1, 1] = fig_obj.fft_freq_panel( - axs[1, 1], uw, kls[0], kls[1], title="PMF spectrum", typ="real" - ) - else: - axs[1, 0] = fig_obj.freq_panel(axs[1, 0], freqs) - axs[1, 1] = fig_obj.freq_panel(axs[1, 1], uw, title="PMF spectrum") + # indices = [] + # max_ampls = [] + + # for ii in range(params.n_modes): + # max_idx = np.unravel_index(fq_cpy.argmax(), fq_cpy.shape) + # indices.append(max_idx) + # max_ampls.append(fq_cpy[max_idx]) + # max_val = fq_cpy[max_idx] + # fq_cpy[max_idx] = 0.0 + + # utils.get_lat_lon_segments( + # simplex_lat, simplex_lon, cell, topo, rect=False + # ) + + # k_idxs = [pair[1] for pair in indices] + # l_idxs = [pair[0] for pair in indices] + + # second_guess = interface.get_pmf(nhi, nhj, params.U, params.V) - plt.tight_layout() - plt.savefig("%sT%i.pdf" % (params.path_output, tri_idx)) - plt.show() + # if params.dfft_first_guess: + # second_guess.fobj.set_kls( + # k_idxs, l_idxs, recompute_nhij=True, components="real" + # ) + # else: + # second_guess.fobj.set_kls(k_idxs, l_idxs, recompute_nhij=False) + + # freqs, uw, dat_2D_sg0 = second_guess.sappx(cell, lmbda=params.lmbda_sa, updt_analysis=True) + + # cell.topo = topo_orig + + # writer.output(c_idx, clat_rad[tri_idx], clon_rad[tri_idx], is_land, cell.analysis) + + # cell.uw = uw + + # if params.plot: + # fs = (15, 9.0) + # v_extent = [dat_2D_sg0.min(), dat_2D_sg0.max()] + + # fig, axs = plt.subplots(2, 2, figsize=fs) + + # fig_obj = plotter.fig_obj( + # fig, second_guess.fobj.nhar_i, second_guess.fobj.nhar_j + # ) + # axs[0, 0] = fig_obj.phys_panel( + # axs[0, 0], + # dat_2D_sg0, + # title="T%i: Reconstruction" % tri_idx, + # xlabel="longitude [km]", + # ylabel="latitude [km]", + # extent=[cell.lon.min(), cell.lon.max(), cell.lat.min(), cell.lat.max()], + # v_extent=v_extent, + # ) + + # axs[0, 1] = fig_obj.phys_panel( + # axs[0, 1], + # cell.topo * cell.mask, + # title="T%i: Reconstruction" % tri_idx, + # xlabel="longitude [km]", + # ylabel="latitude [km]", + # extent=[cell.lon.min(), cell.lon.max(), cell.lat.min(), cell.lat.max()], + # v_extent=v_extent, + # ) + + # if params.dfft_first_guess: + # axs[1, 0] = fig_obj.fft_freq_panel( + # axs[1, 0], freqs, kls[0], kls[1], typ="real" + # ) + # axs[1, 1] = fig_obj.fft_freq_panel( + # axs[1, 1], uw, kls[0], kls[1], title="PMF spectrum", typ="real" + # ) + # else: + # axs[1, 0] = fig_obj.freq_panel(axs[1, 0], freqs) + # axs[1, 1] = fig_obj.freq_panel(axs[1, 1], uw, title="PMF spectrum") + + # plt.tight_layout() + # plt.savefig("%sT%i.pdf" % (params.path_output, tri_idx)) + # plt.show() # %% \ No newline at end of file