Skip to content

Commit

Permalink
updated global ICON script to latest machinery
Browse files Browse the repository at this point in the history
checked; the dfft and lsff results have been reproduced
  • Loading branch information
ray-chew committed May 20, 2024
1 parent 1d815da commit ef2d8ae
Showing 1 changed file with 157 additions and 97 deletions.
254 changes: 157 additions & 97 deletions runs/icon_merit_global.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import pandas as pd
import matplotlib.pyplot as plt


from pycsam.src import io, var, utils, fourier
from pycsam.wrappers import interface
from pycsam.wrappers import interface, diagnostics
from pycsam.vis import plotter, cart_plot

from IPython import get_ipython
Expand Down Expand Up @@ -105,137 +106,196 @@ def autoreload():
if params.plot:
cart_plot.lat_lon_icon(topo, triangles, ncells=ncells, clon=clon, clat=clat)

# %%

print(topo.topo.shape)

# %%
tri_idx = 0
# initialise cell object
cell = var.topo_cell()
tri = var.obj()

simplex_lon = triangles[tri_idx, :, 0]
simplex_lat = triangles[tri_idx, :, 1]
nhi = params.nhi
nhj = params.nhj

utils.get_lat_lon_segments(
simplex_lat, simplex_lon, cell, topo, rect=params.rect
)
fa = interface.first_appx(nhi, nhj, params, topo)
sa = interface.second_appx(nhi, nhj, params, topo, tri)

dplot = diagnostics.diag_plotter(params, nhi, nhj)

# simplex_lon = triangles[tri_idx, :, 0]
# simplex_lat = triangles[tri_idx, :, 1]

tri.tri_lon_verts = triangles[:, :, 0]
tri.tri_lat_verts = triangles[:, :, 1]

# utils.get_lat_lon_segments(
# simplex_lat, simplex_lon, cell, topo, rect=params.rect
# )

simplex_lat = tri.tri_lat_verts[tri_idx]
simplex_lon = tri.tri_lon_verts[tri_idx]

if utils.is_land(cell, simplex_lat, simplex_lon, topo):
writer.output(c_idx, clat_rad[tri_idx], clon_rad[tri_idx], 0)
continue
else:
is_land = 1

topo_orig = np.copy(cell.topo)
# topo_orig = np.copy(cell.topo)

if params.dfft_first_guess:
nhi = len(cell.lon)
nhj = len(cell.lat)
# do tapering
if params.taper_fa:
interface.taper_quad(params, simplex_lat, simplex_lon, cell, topo)
else:
utils.get_lat_lon_segments(
simplex_lat, simplex_lon, cell, topo, rect=params.rect
)

dfft_run = interface.get_pmf(nhi, nhj, params.U, params.V)
ampls_fa, uw_fa, dat_2D_fa, kls_fa = dfft_run.dfft(cell)

cell_fa = cell

nhi = len(cell_fa.lon)
nhj = len(cell_fa.lat)

sa.nhi = nhi
sa.nhj = nhj
else:
nhi = params.nhi
nhj = params.nhj
cell_fa, ampls_fa, uw_fa, dat_2D_fa = fa.do(simplex_lat, simplex_lon)

first_guess = interface.get_pmf(nhi, nhj, params.U, params.V)
fobj_tri = fourier.f_trans(nhi, nhj)

#######################################################
# do fourier...
sols = (cell_fa, ampls_fa, uw_fa, dat_2D_fa)

if not params.dfft_first_guess:
freqs, uw_pmf_freqs, dat_2D_fg0 = first_guess.sappx(cell, params.lmbda_fa)
v_extent = [dat_2D_fa.min(), dat_2D_fa.max()]

#######################################################
# do fourier using DFFT
if params.dfft_first_guess:
dplot.show(
tri_idx, sols, kls=kls_fa, v_extent=v_extent, dfft_plot=True,
output_fig=False
)
else:
dplot.show(tri_idx, sols, v_extent=v_extent, output_fig=False)

if params.recompute_rhs:
sols, sols_rc = sa.do(tri_idx, ampls_fa)
else:
sols = sa.do(tri_idx, ampls_fa)

cell, ampls_sa, uw_sa, dat_2D_sa = sols
v_extent = [dat_2D_sa.min(), dat_2D_sa.max()]

if params.dfft_first_guess:
ampls, uw_pmf_freqs, dat_2D_fg0, kls = first_guess.dfft(cell)
freqs = np.copy(ampls)
dplot.show(
tri_idx, sols, kls=kls_fa, v_extent=v_extent, dfft_plot=True,
output_fig=False
)
else:
dplot.show(tri_idx, sols, v_extent=v_extent, output_fig=False)

print("uw_pmf_freqs_sum:", uw_pmf_freqs.sum())

fq_cpy = np.copy(freqs)
fq_cpy[
np.isnan(fq_cpy)
] = 0.0 # necessary. Otherwise, popping with fq_cpy.max() gives the np.nan entries first.

indices = []
max_ampls = []

for ii in range(params.n_modes):
max_idx = np.unravel_index(fq_cpy.argmax(), fq_cpy.shape)
indices.append(max_idx)
max_ampls.append(fq_cpy[max_idx])
max_val = fq_cpy[max_idx]
fq_cpy[max_idx] = 0.0

utils.get_lat_lon_segments(
simplex_lat, simplex_lon, cell, topo, rect=False
)
# first_guess = interface.get_pmf(nhi, nhj, params.U, params.V)
# fobj_tri = fourier.f_trans(nhi, nhj)

k_idxs = [pair[1] for pair in indices]
l_idxs = [pair[0] for pair in indices]
# #######################################################
# # do fourier...

second_guess = interface.get_pmf(nhi, nhj, params.U, params.V)
# if not params.dfft_first_guess:
# freqs, uw_pmf_freqs, dat_2D_fg0 = first_guess.sappx(cell, params.lmbda_fa)

if params.dfft_first_guess:
second_guess.fobj.set_kls(
k_idxs, l_idxs, recompute_nhij=True, components="real"
)
else:
second_guess.fobj.set_kls(k_idxs, l_idxs, recompute_nhij=False)
# #######################################################
# # do fourier using DFFT

freqs, uw, dat_2D_sg0 = second_guess.sappx(cell, lmbda=params.lmbda_sa, updt_analysis=True)
# if params.dfft_first_guess:
# ampls, uw_pmf_freqs, dat_2D_fg0, kls = first_guess.dfft(cell)
# freqs = np.copy(ampls)

cell.topo = topo_orig
# print("uw_pmf_freqs_sum:", uw_pmf_freqs.sum())

writer.output(c_idx, clat_rad[tri_idx], clon_rad[tri_idx], is_land, cell.analysis)

cell.uw = uw
# fq_cpy = np.copy(freqs)
# fq_cpy[
# np.isnan(fq_cpy)
# ] = 0.0 # necessary. Otherwise, popping with fq_cpy.max() gives the np.nan entries first.

if params.plot:
fs = (15, 9.0)
v_extent = [dat_2D_sg0.min(), dat_2D_sg0.max()]

fig, axs = plt.subplots(2, 2, figsize=fs)

fig_obj = plotter.fig_obj(
fig, second_guess.fobj.nhar_i, second_guess.fobj.nhar_j
)
axs[0, 0] = fig_obj.phys_panel(
axs[0, 0],
dat_2D_sg0,
title="T%i: Reconstruction" % tri_idx,
xlabel="longitude [km]",
ylabel="latitude [km]",
extent=[cell.lon.min(), cell.lon.max(), cell.lat.min(), cell.lat.max()],
v_extent=v_extent,
)

axs[0, 1] = fig_obj.phys_panel(
axs[0, 1],
cell.topo * cell.mask,
title="T%i: Reconstruction" % tri_idx,
xlabel="longitude [km]",
ylabel="latitude [km]",
extent=[cell.lon.min(), cell.lon.max(), cell.lat.min(), cell.lat.max()],
v_extent=v_extent,
)

if params.dfft_first_guess:
axs[1, 0] = fig_obj.fft_freq_panel(
axs[1, 0], freqs, kls[0], kls[1], typ="real"
)
axs[1, 1] = fig_obj.fft_freq_panel(
axs[1, 1], uw, kls[0], kls[1], title="PMF spectrum", typ="real"
)
else:
axs[1, 0] = fig_obj.freq_panel(axs[1, 0], freqs)
axs[1, 1] = fig_obj.freq_panel(axs[1, 1], uw, title="PMF spectrum")
# indices = []
# max_ampls = []

# for ii in range(params.n_modes):
# max_idx = np.unravel_index(fq_cpy.argmax(), fq_cpy.shape)
# indices.append(max_idx)
# max_ampls.append(fq_cpy[max_idx])
# max_val = fq_cpy[max_idx]
# fq_cpy[max_idx] = 0.0

# utils.get_lat_lon_segments(
# simplex_lat, simplex_lon, cell, topo, rect=False
# )

# k_idxs = [pair[1] for pair in indices]
# l_idxs = [pair[0] for pair in indices]

# second_guess = interface.get_pmf(nhi, nhj, params.U, params.V)

plt.tight_layout()
plt.savefig("%sT%i.pdf" % (params.path_output, tri_idx))
plt.show()
# if params.dfft_first_guess:
# second_guess.fobj.set_kls(
# k_idxs, l_idxs, recompute_nhij=True, components="real"
# )
# else:
# second_guess.fobj.set_kls(k_idxs, l_idxs, recompute_nhij=False)

# freqs, uw, dat_2D_sg0 = second_guess.sappx(cell, lmbda=params.lmbda_sa, updt_analysis=True)

# cell.topo = topo_orig

# writer.output(c_idx, clat_rad[tri_idx], clon_rad[tri_idx], is_land, cell.analysis)

# cell.uw = uw

# if params.plot:
# fs = (15, 9.0)
# v_extent = [dat_2D_sg0.min(), dat_2D_sg0.max()]

# fig, axs = plt.subplots(2, 2, figsize=fs)

# fig_obj = plotter.fig_obj(
# fig, second_guess.fobj.nhar_i, second_guess.fobj.nhar_j
# )
# axs[0, 0] = fig_obj.phys_panel(
# axs[0, 0],
# dat_2D_sg0,
# title="T%i: Reconstruction" % tri_idx,
# xlabel="longitude [km]",
# ylabel="latitude [km]",
# extent=[cell.lon.min(), cell.lon.max(), cell.lat.min(), cell.lat.max()],
# v_extent=v_extent,
# )

# axs[0, 1] = fig_obj.phys_panel(
# axs[0, 1],
# cell.topo * cell.mask,
# title="T%i: Reconstruction" % tri_idx,
# xlabel="longitude [km]",
# ylabel="latitude [km]",
# extent=[cell.lon.min(), cell.lon.max(), cell.lat.min(), cell.lat.max()],
# v_extent=v_extent,
# )

# if params.dfft_first_guess:
# axs[1, 0] = fig_obj.fft_freq_panel(
# axs[1, 0], freqs, kls[0], kls[1], typ="real"
# )
# axs[1, 1] = fig_obj.fft_freq_panel(
# axs[1, 1], uw, kls[0], kls[1], title="PMF spectrum", typ="real"
# )
# else:
# axs[1, 0] = fig_obj.freq_panel(axs[1, 0], freqs)
# axs[1, 1] = fig_obj.freq_panel(axs[1, 1], uw, title="PMF spectrum")

# plt.tight_layout()
# plt.savefig("%sT%i.pdf" % (params.path_output, tri_idx))
# plt.show()


# %%

0 comments on commit ef2d8ae

Please sign in to comment.