Skip to content

Commit

Permalink
intermediate commit for global run script
Browse files Browse the repository at this point in the history
technically, this should work now, although I need to find a way to parallelise the embarrassing loop and possibly move the writing routine out. I will also need to implement the skipping of ocean grid cells. Finally, the south pole looks pretty okay.
  • Loading branch information
ray-chew committed May 14, 2024
1 parent 5ac0be0 commit 4d1fa53
Show file tree
Hide file tree
Showing 2 changed files with 254 additions and 0 deletions.
32 changes: 32 additions & 0 deletions inputs/icon_global_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import numpy as np
from src import var

params = var.params()

params.output_path = "/home/ray/git-projects/spec_appx/outputs/"
params.output_fn = "icon_merit_reg"
params.fn_grid = "../data/icon_compact.nc"
params.fn_topo = "../data/topo_compact.nc"

### South Pole
params.lat_extent = None
params.lon_extent = None

params.tri_set = [13, 104, 105, 106]

# Setup the Fourier parameters and object.
params.nhi = 24
params.nhj = 48

params.n_modes = 50

params.U, params.V = 10.0, 0.0

params.rect = True

params.debug = False
params.dfft_first_guess = True
params.refine = False
params.verbose = False

params.plot = True
222 changes: 222 additions & 0 deletions runs/icon_merit_global.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
# %%
import sys

# set system path to find local modules
sys.path.append("..")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from src import io, var, utils, fourier, physics
from wrappers import interface
from vis import plotter, cart_plot

from IPython import get_ipython

ipython = get_ipython()

if ipython is not None:
ipython.run_line_magic("load_ext", "autoreload")
else:
print(ipython)

def autoreload():
if ipython is not None:
ipython.run_line_magic("autoreload", "2")

from sys import exit

if __name__ != "__main__":
exit(0)
# %%
autoreload()
from inputs.icon_regional_run import params

if params.self_test():
params.print()

print(params.path_compact_topo)

grid = var.grid()

# read grid
reader = io.ncdata(padding=params.padding, padding_tol=(60 - params.padding))

# writer object
writer = io.nc_writer(params)

reader.read_dat(params.path_compact_grid, grid)

clat_rad = np.copy(grid.clat)
clon_rad = np.copy(grid.clon)

grid.apply_f(utils.rad2deg)

n_cells = grid.clat.size

for c_idx in range(n_cells)[:1]:
c_idx = 90
print(c_idx)

topo = var.topo_cell()
lat_verts = grid.clat_vertices[c_idx]
lon_verts = grid.clon_vertices[c_idx]

lat_extent = [lat_verts.min() - 1.0,lat_verts.min() - 1.0,lat_verts.max() + 1.0]
lon_extent = [lon_verts.min() - 1.0,lon_verts.min() - 1.0,lon_verts.max() + 1.0]
# we only keep the topography that is inside this lat-lon extent.
lat_verts = np.array(lat_extent)
lon_verts = np.array(lon_extent)

params.lat_extent = lat_extent
params.lon_extent = lon_extent

# read topography
if not params.enable_merit:
reader.read_dat(params.fn_topo, topo)
reader.read_topo(topo, topo, lon_verts, lat_verts)
else:
reader.read_merit_topo(topo, params)
topo.topo[np.where(topo.topo < -500.0)] = -500.0

topo.gen_mgrids()
# %%
clon = np.array([grid.clon[c_idx]])
clat = np.array([grid.clat[c_idx]])
clon_vertices = np.array([grid.clon_vertices[c_idx]])
clat_vertices = np.array([grid.clat_vertices[c_idx]])

ncells = 1
nv = clon_vertices[0].size
# -- create the triangles
clon_vertices = np.where(clon_vertices < -180.0, clon_vertices + 360.0, clon_vertices)
clon_vertices = np.where(clon_vertices > 180.0, clon_vertices - 360.0, clon_vertices)

triangles = np.zeros((ncells, nv, 2), np.float32)

for i in range(0, ncells, 1):
triangles[i, :, 0] = np.array(clon_vertices[i, :])
triangles[i, :, 1] = np.array(clat_vertices[i, :])

print("--> triangles done")

cart_plot.lat_lon_icon(topo, triangles, ncells=ncells, clon=clon, clat=clat)


# %%
tri_idx = 0
# initialise cell object
cell = var.topo_cell()

simplex_lon = triangles[tri_idx, :, 0]
simplex_lat = triangles[tri_idx, :, 1]

utils.get_lat_lon_segments(
simplex_lat, simplex_lon, cell, topo, rect=params.rect
)

topo_orig = np.copy(cell.topo)

if params.dfft_first_guess:
nhi = len(cell.lon)
nhj = len(cell.lat)

first_guess = interface.get_pmf(nhi, nhj, params.U, params.V)
fobj_tri = fourier.f_trans(nhi, nhj)

#######################################################
# do fourier...

if not params.dfft_first_guess:
freqs, uw_pmf_freqs, dat_2D_fg0 = first_guess.sappx(cell, lmbda=0.0)

#######################################################
# do fourier using DFFT

if params.dfft_first_guess:
ampls, uw_pmf_freqs, dat_2D_fg0, kls = first_guess.dfft(cell)
freqs = np.copy(ampls)

print("uw_pmf_freqs_sum:", uw_pmf_freqs.sum())

fq_cpy = np.copy(freqs)

indices = []
max_ampls = []

for ii in range(params.n_modes):
max_idx = np.unravel_index(fq_cpy.argmax(), fq_cpy.shape)
indices.append(max_idx)
max_ampls.append(fq_cpy[max_idx])
max_val = fq_cpy[max_idx]
fq_cpy[max_idx] = 0.0

utils.get_lat_lon_segments(
simplex_lat, simplex_lon, cell, topo, rect=False
)

k_idxs = [pair[1] for pair in indices]
l_idxs = [pair[0] for pair in indices]

second_guess = interface.get_pmf(nhi, nhj, params.U, params.V)

if params.dfft_first_guess:
second_guess.fobj.set_kls(
k_idxs, l_idxs, recompute_nhij=True, components="real"
)
else:
second_guess.fobj.set_kls(k_idxs, l_idxs, recompute_nhij=False)

freqs, uw, dat_2D_sg0 = second_guess.sappx(cell, lmbda=1e-1, updt_analysis=True)

cell.topo = topo_orig

writer.output(tri_idx, clat_rad[tri_idx], clon_rad[tri_idx], cell.analysis)

cell.uw = uw

if params.plot:
fs = (15, 9.0)
v_extent = [dat_2D_sg0.min(), dat_2D_sg0.max()]

fig, axs = plt.subplots(2, 2, figsize=fs)

fig_obj = plotter.fig_obj(
fig, second_guess.fobj.nhar_i, second_guess.fobj.nhar_j
)
axs[0, 0] = fig_obj.phys_panel(
axs[0, 0],
dat_2D_sg0,
title="T%i: Reconstruction" % tri_idx,
xlabel="longitude [km]",
ylabel="latitude [km]",
extent=[cell.lon.min(), cell.lon.max(), cell.lat.min(), cell.lat.max()],
v_extent=v_extent,
)

axs[0, 1] = fig_obj.phys_panel(
axs[0, 1],
cell.topo * cell.mask,
title="T%i: Reconstruction" % tri_idx,
xlabel="longitude [km]",
ylabel="latitude [km]",
extent=[cell.lon.min(), cell.lon.max(), cell.lat.min(), cell.lat.max()],
v_extent=v_extent,
)

if params.dfft_first_guess:
axs[1, 0] = fig_obj.fft_freq_panel(
axs[1, 0], freqs, kls[0], kls[1], typ="real"
)
axs[1, 1] = fig_obj.fft_freq_panel(
axs[1, 1], uw, kls[0], kls[1], title="PMF spectrum", typ="real"
)
else:
axs[1, 0] = fig_obj.freq_panel(axs[1, 0], freqs)
axs[1, 1] = fig_obj.freq_panel(axs[1, 1], uw, title="PMF spectrum")

plt.tight_layout()
plt.savefig("%sT%i.pdf" % (params.path_output, tri_idx))
plt.show()
# %%

0 comments on commit 4d1fa53

Please sign in to comment.