i/o routine now works for all cells on the ICON R2B4 grid
ray-chew committed Jun 11, 2024
1 parent fe661f6 commit e03aefd
Showing 2 changed files with 135 additions and 109 deletions.
24 changes: 11 additions & 13 deletions runs/icon_merit_global.py
@@ -211,8 +211,8 @@ def parallel_wrapper(grid, params, reader, writer):
 # autoreload()
 from pycsam.inputs.icon_regional_run import params
 
-# from dask.distributed import Client
-# import dask
+from dask.distributed import Client
+import dask
 
 # dask.config.set(scheduler='synchronous')
 
@@ -243,18 +243,16 @@ def parallel_wrapper(grid, params, reader, writer):
 
 # NetCDF-4 reader does not work well with multithreading
 # Use only 1 thread per worker! (At least on my laptop)
-# client = Client(threads_per_worker=1, n_workers=8)
+client = Client(threads_per_worker=1, n_workers=2)
 
-# lazy_results = []
+lazy_results = []
 
-for c_idx in range(n_cells)[15455:]:
-# # for c_idx in range(n_cells)[180:190]:
-# for c_idx in range(n_cells)[2046:2060]:
-    pw_run(c_idx)
-    # lazy_result = dask.delayed(pw_run)(c_idx)
-    # lazy_results.append(lazy_result)
+for c_idx in range(n_cells):
+    # pw_run(c_idx)
+    lazy_result = dask.delayed(pw_run)(c_idx)
+    lazy_results.append(lazy_result)
 
-# results = dask.compute(*lazy_results)
+results = dask.compute(*lazy_results)
 
-# for item in results:
-#     writer.duplicate(item.c_idx, item)
+for item in results:
+    writer.duplicate(item.c_idx, item)
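
Note: the change above is the standard dask.delayed fan-out — build the whole task graph lazily, then compute it in one shot. A minimal, self-contained sketch of the same pattern, with do_cell as a hypothetical stand-in for the repository's pw_run (which wraps parallel_wrapper for one grid cell):

    import dask
    from dask.distributed import Client

    def do_cell(c_idx):
        # placeholder for the per-cell work that pw_run does
        # (read topography for cell c_idx, compute, return a result object)
        return c_idx * 2

    if __name__ == "__main__":
        # one thread per worker, per the comment above about the NetCDF-4 reader
        client = Client(threads_per_worker=1, n_workers=2)

        # nothing runs here; each call only adds a node to the task graph
        lazy_results = [dask.delayed(do_cell)(c_idx) for c_idx in range(10)]

        # dask.compute triggers the whole graph at once
        results = dask.compute(*lazy_results)
        print(results)

        client.close()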
220 changes: 124 additions & 96 deletions src/io.py
@@ -6,7 +6,9 @@
 import numpy as np
 import h5py
 import os
 
+from datetime import datetime
+from scipy import interpolate
 
 from ..src import utils
 
@@ -171,6 +173,7 @@ def __init__(self, cell, params, verbose=False, is_parallel=False):
 self.merit_cg = params.merit_cg
 self.split_EW = False
 self.span = False
+self.interp_lons = []
 
 if not is_parallel:
     self.get_topo(cell)
@@ -334,103 +337,60 @@ def __load_topo(self, cell, fns, dirs, lon_cnt, lat_cnt, lat_idx_rng, lon_idx_rng, ...
 ### Handles the case where a cell spans four topographic datasets
 cnt_lat = 0
 cnt_lon = 0
-lat_low_old = np.ones((len(fns))) * np.inf
-lat_high_old = np.ones((len(fns))) * np.inf
-lon_low_old = np.ones((len(fns))) * np.inf
-lon_high_old = np.ones((len(fns))) * np.inf
-lat_nc_change, lon_nc_change = False, False
 
 for cnt, fn in enumerate(fns):
-    # try:
-    #     test.isopen()
-    # except:
+    ############################################
+    #
+    # Open data file
+    #
+    ############################################
     test = nc.Dataset(dirs[cnt] + fn, "r")
     self.opened_dfs.append(test)
 
+    ############################################
+    #
+    # Load lat data
+    #
+    ############################################
+
     lat = test["lat"]
     lat_min_idx = np.argmin(np.abs((lat - np.sign(lat) * 1e-4) - self.lat_verts.min()))
     lat_max_idx = np.argmin(np.abs((lat + np.sign(lat) * 1e-4) - self.lat_verts.max()))
 
     lat_high = np.max((lat_min_idx, lat_max_idx))
     lat_low = np.min((lat_min_idx, lat_max_idx))
 
-    # lon = test["lon"]
-    # lon_min_idx = np.argmin(np.abs((lon - np.sign(lon) * 1e-4) - (self.lon_verts.min())))
-    # lon_max_idx = np.argmin(np.abs((lon + np.sign(lon) * 1e-4) - (self.lon_verts.max())))
-
-    # lon_high = np.max((lon_min_idx, lon_max_idx))
-    # lon_low = np.min((lon_min_idx, lon_max_idx))
-
-    ### Only add lat and lon elements if there are changes to the low and high indices identified:
-    # if (lon_low not in lon_low_old) and (lon_high not in lon_high_old):
-    #     lon_nc_change = True
-
-    # if (lat_low not in lat_low_old) and (lat_high not in lat_high_old):
-    #     lat_nc_change = True
-
-    ############################################
-    lat = test["lat"]
-    lon = test["lon"]
-
-    l_lat_bound, r_lat_bound = (
-        self.fn_lat[lat_idx_rng[n_row]],
-        self.fn_lat[lat_idx_rng[n_row] + 1],
-    )
-
-    l_lon_bound, r_lon_bound = (
-        self.fn_lon[lon_idx_rng[n_col]],
-        self.fn_lon[lon_idx_rng[n_col] + 1],
-    )
-
-    lon_rng = r_lon_bound - l_lon_bound
-
-    lon_in_file = self.lon_verts[((self.lon_verts - l_lon_bound) > 0) & ((self.lon_verts - l_lon_bound) <= lon_rng)]
-
-    if len(lon_in_file) == 0:
-        lon_high = np.argmin(np.abs(lon - r_lon_bound))
-        lon_low = np.argmin(np.abs(lon - l_lon_bound))
-
-    else:
-        if not self.split_EW:
-            if lon_in_file.max() == self.lon_verts.max():
-                lon_high = np.argmin(np.abs(lon - lon_in_file.max()))
-            else:
-                lon_high = np.argmin(np.abs(lon - r_lon_bound))
-
-            if lon_in_file.min() == self.lon_verts.min():
-                lon_low = np.argmin(np.abs(lon - lon_in_file.min()))
-            else:
-                lon_low = np.argmin(np.abs(lon - l_lon_bound))
-
-        else:
-            if lon_in_file.max() == min(np.where(self.lon_verts < 0.0, self.lon_verts + 360.0, self.lon_verts)):
-                lon_high = np.argmin(np.abs(lon - r_lon_bound))
-                lon_low = np.argmin(np.abs(lon - lon_in_file.min()))
-            else:
-                lon_high = np.argmin(np.abs(lon - r_lon_bound))
-
-                if lon_in_file.min() == (max(self.lon_verts[self.lon_verts < 0.0] + 360.0) - 360.0):
-                    lon_high = np.argmin(np.abs(lon - lon_in_file.max()))
-                    lon_low = np.argmin(np.abs(lon - l_lon_bound))
-                else:
-                    lon_low = np.argmin(np.abs(lon - l_lon_bound))
-
-        # if r_lon_bound > lon_in_file.max():
-        #     lon_high = np.argmin(np.abs(lon - lon_in_file.max()))
-
-        # if lon_in_file.min() > l_lon_bound:
-        #     lon_low = np.argmin(np.abs(lon - lon_in_file.min()))
+    ############################################
+    #
+    # Load lon data
+    #
+    ############################################
+
+    # in the case where fns contains both a MERIT and a REMA dataset,
+    # compute the common lon grid once, at n_row = 0
+    if any("REMA" in fn for fn in fns) and any("MERIT" in fn for fn in fns) and (not populate):
+        if (n_row == 0):
+            # run MERIT and REMA interpolation
+            new_lon = self.__do_interp_lon_1D(dirs, fns, cnt_lon, lon_cnt, n_col, lon_idx_rng)
+            self.interp_lons.append(new_lon)
+
+        # flag stating that we have a MERIT+REMA mix
+        self.span = True
 
     lon = test["lon"]
 
-    lon_low_old[cnt] = lon_low
-    lon_high_old[cnt] = lon_high
-    lat_low_old[cnt] = lat_low
-    lat_high_old[cnt] = lat_high
+    lon_low, lon_high = self.__get_lon_idxs(lon, lon_idx_rng, n_col)
 
 
     if not populate:
         if n_row == 0:
 
-            # if (cnt_lon < (lon_cnt + 1)) and lon_nc_change:
-            nc_lon += lon_high - lon_low
+            if not self.span:
+                nc_lon += lon_high - lon_low
+            else:
+                nc_lon += len(new_lon)
             cnt_lon += 1
 
         if n_col == 0:
@@ -445,16 +405,22 @@ def __load_topo(self, cell, fns, dirs, lon_cnt, lat_cnt, lat_idx_rng, lon_idx_rng, ...
 
 else:
     topo = test["Elevation"][lat_low:lat_high, lon_low:lon_high]
 
+    curr_lon = lon[lon_low:lon_high].tolist()
+
     if n_col == 0:
-        cell.lat += lat[lat_low:lat_high].tolist()
+        if n_row == 0:
+            curr_lat = lat[lat_low:lat_high].tolist()
+            cell.lat += curr_lat
+        if not self.span:
+            if n_row == 0:
+                cell.lon += curr_lon
+        else:  # interpolate topo data to new lon grid
+            new_lon = self.interp_lons[n_col]
+            topo = self.__interp_topo_2D(topo, curr_lat, curr_lon, new_lon)
 
-        if "MERIT" in fns and "REMA" in fns:
-            self.span = True
-            # new_lon =
             if n_row == 0:
                 cell.lon += new_lon.tolist()
 
     else:
         cell.lon += lon[lon_low:lon_high].tolist()
 
-# # current dataset at n_row = 0 is a MERIT dataset
-# if "MERIT" in fn:
@@ -464,20 +430,12 @@ def __load_topo(self, cell, fns, dirs, lon_cnt, lat_cnt, lat_idx_rng, lon_idx_rng, ...
-#     if n_row > 0:
-#         if ("REMA" in fn) and (self.prev_merit):
 
-
-lon_sz = lon_high - lon_low
+if not self.span:
+    lon_sz = lon_high - lon_low
+else:
+    lon_sz = len(self.interp_lons[n_col])
 lat_sz = lat_high - lat_low
 
 
-# if lon_nc_change and cnt > 0:
-#     n_col += 1
-
-# # if n_col == (lon_cnt + 1):
-# #     n_col = 0
-# if lat_nc_change and cnt > 0:
-#     n_row += 1
-#     lat_sz_old = np.copy(lat_sz)
 
 cell.topo[
     lat_sz_old : lat_sz_old + lat_sz,
     lon_sz_old : lon_sz_old + lon_sz,
@@ -493,9 +451,6 @@ def __load_topo(self, cell, fns, dirs, lon_cnt, lat_cnt, lat_idx_rng, lon_idx_rng, ...
         n_row += 1
         lat_sz_old = np.copy(lat_sz)
 
-        lon_nc_change = False
-        lat_nc_change = False
-
     test.close()
 
 if not populate:
@@ -519,6 +474,79 @@ def __load_topo(self, cell, fns, dirs, lon_cnt, lat_cnt, lat_idx_rng, lon_idx_rng, ...
     cell.topo, (iint, iint), (iint, iint)
 ).mean(axis=(-1, -2))[::-1, :]
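
Note: the two context lines above close the coarse-graining step — an iint-by-iint strided window over cell.topo, a block mean, and a north-south flip. The helper being called is hidden above the hunk boundary (presumably a windowed-view function from the repository's utils), so the following reshape-based block mean is an equivalent sketch under that assumption, not the repository's implementation:

    import numpy as np

    def block_mean(topo, k):
        # mean over non-overlapping k x k blocks: the same result as a
        # strided window view with window == stride == (k, k) followed
        # by .mean(axis=(-1, -2))
        nlat, nlon = topo.shape
        nlat, nlon = nlat - nlat % k, nlon - nlon % k  # trim ragged edges
        blocks = topo[:nlat, :nlon].reshape(nlat // k, k, nlon // k, k)
        return blocks.mean(axis=(1, 3))

    topo = np.arange(36.0).reshape(6, 6)
    coarse = block_mean(topo, 3)   # shape (2, 2)
    print(coarse[::-1, :])         # flipped north-south, as in the diff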

+def __do_interp_lon_1D(self, dirs, fns, cnt_lon, lon_cnt, n_col, lon_idx_rng):
+    # Note: MERIT is always on n_row = 0 and REMA on n_row = 1
+
+    merit_path = dirs[cnt_lon] + fns[cnt_lon]
+    merit_dat = nc.Dataset(merit_path, "r")
+    merit_lon = merit_dat["lon"]
+
+    rema_path = dirs[cnt_lon + lon_cnt + 1] + fns[cnt_lon + lon_cnt + 1]
+    rema_dat = nc.Dataset(rema_path, "r")
+    rema_lon = rema_dat["lon"]
+
+    merit_lon_low, merit_lon_high = self.__get_lon_idxs(merit_lon, lon_idx_rng, n_col)
+    rema_lon_low, rema_lon_high = self.__get_lon_idxs(rema_lon, lon_idx_rng, n_col)
+
+    merit_lon = merit_lon[merit_lon_low:merit_lon_high].tolist()
+    rema_lon = rema_lon[rema_lon_low:rema_lon_high].tolist()
+
+    new_max = min(max(merit_lon), max(rema_lon))
+    new_min = max(min(merit_lon), min(rema_lon))
+    # use the size of the coarser of the two lon grids,
+    # which in practice is the MERIT grid:
+    new_sz = min(len(merit_lon), len(rema_lon))
+
+    new_lon = np.linspace(new_min, new_max, new_sz)
+
+    return new_lon
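
Note: a worked example of the overlap logic above, with made-up grids. The common grid spans only the intersection of the two lon ranges and keeps the point count of the coarser grid:

    import numpy as np

    merit_lon = np.linspace(-30.0, -20.0, 120).tolist()  # coarser grid
    rema_lon = np.linspace(-24.0, -18.0, 200).tolist()   # finer, offset grid

    new_max = min(max(merit_lon), max(rema_lon))  # -20.0
    new_min = max(min(merit_lon), min(rema_lon))  # -24.0
    new_sz = min(len(merit_lon), len(rema_lon))   # 120

    new_lon = np.linspace(new_min, new_max, new_sz)
    print(new_lon[0], new_lon[-1], new_lon.size)  # -24.0 -20.0 120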


+@staticmethod
+def __interp_topo_2D(topo, curr_lat, curr_lon, new_lon):
+    interp = interpolate.RegularGridInterpolator((curr_lat, curr_lon), topo)
+    XX, YY = np.meshgrid(new_lon, curr_lat)
+    return interp((YY, XX))
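
Note: a self-contained check of the regridding helper above. The meshgrid/tuple-call idiom hands (lat, lon) pairs to scipy's RegularGridInterpolator, and linear interpolation of a plane is exact, so the result is easy to verify:

    import numpy as np
    from scipy import interpolate

    curr_lat = np.linspace(-75.0, -74.0, 5)
    curr_lon = np.linspace(10.0, 11.0, 8)
    topo = np.add.outer(curr_lat, curr_lon)  # plane: z = lat + lon
    new_lon = np.linspace(10.2, 10.8, 6)     # target lon grid

    interp = interpolate.RegularGridInterpolator((curr_lat, curr_lon), topo)
    XX, YY = np.meshgrid(new_lon, curr_lat)  # XX varies lon, YY varies lat
    regridded = interp((YY, XX))             # shape (5, 6)

    assert np.allclose(regridded, np.add.outer(curr_lat, new_lon))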

+def __get_lon_idxs(self, lon, lon_idx_rng, n_col):
+    l_lon_bound, r_lon_bound = (
+        self.fn_lon[lon_idx_rng[n_col]],
+        self.fn_lon[lon_idx_rng[n_col] + 1],
+    )
+
+    lon_rng = r_lon_bound - l_lon_bound
+
+    lon_in_file = self.lon_verts[((self.lon_verts - l_lon_bound) > 0) & ((self.lon_verts - l_lon_bound) <= lon_rng)]
+
+    if len(lon_in_file) == 0:
+        lon_high = np.argmin(np.abs(lon - r_lon_bound))
+        lon_low = np.argmin(np.abs(lon - l_lon_bound))
+
+    else:
+        if not self.split_EW:
+            if lon_in_file.max() == self.lon_verts.max():
+                lon_high = np.argmin(np.abs(lon - lon_in_file.max()))
+            else:
+                lon_high = np.argmin(np.abs(lon - r_lon_bound))
+
+            if lon_in_file.min() == self.lon_verts.min():
+                lon_low = np.argmin(np.abs(lon - lon_in_file.min()))
+            else:
+                lon_low = np.argmin(np.abs(lon - l_lon_bound))
+
+        else:
+            if lon_in_file.max() == min(np.where(self.lon_verts < 0.0, self.lon_verts + 360.0, self.lon_verts)):
+                lon_high = np.argmin(np.abs(lon - r_lon_bound))
+                lon_low = np.argmin(np.abs(lon - lon_in_file.min()))
+            else:
+                lon_high = np.argmin(np.abs(lon - r_lon_bound))
+
+                if lon_in_file.min() == (max(self.lon_verts[self.lon_verts < 0.0] + 360.0) - 360.0):
+                    lon_high = np.argmin(np.abs(lon - lon_in_file.max()))
+                    lon_low = np.argmin(np.abs(lon - l_lon_bound))
+                else:
+                    lon_low = np.argmin(np.abs(lon - l_lon_bound))
+
+    return lon_low, lon_high
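
Note: every branch above reduces to nearest-neighbour bracketing with np.argmin. A small example with a made-up coordinate array, for the fallback case where the file's own bounds are used (len(lon_in_file) == 0):

    import numpy as np

    lon = np.linspace(-30.0, -25.0, 11)  # a file's lon coordinate, 0.5 deg steps
    l_lon_bound, r_lon_bound = -28.2, -26.1

    lon_low = np.argmin(np.abs(lon - l_lon_bound))   # nearest grid point: index 4
    lon_high = np.argmin(np.abs(lon - r_lon_bound))  # nearest grid point: index 8

    print(lon[lon_low:lon_high])  # [-28.  -27.5 -27.  -26.5]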

 def close_all(self):
     for df in self.opened_dfs:
