diff --git a/runs/icon_merit_global.py b/runs/icon_merit_global.py
index b7b0610..1d9b13d 100644
--- a/runs/icon_merit_global.py
+++ b/runs/icon_merit_global.py
@@ -211,8 +211,8 @@ def parallel_wrapper(grid, params, reader, writer):
 # autoreload()
 from pycsam.inputs.icon_regional_run import params

-# from dask.distributed import Client
-# import dask
+from dask.distributed import Client
+import dask

 # dask.config.set(scheduler='synchronous')

@@ -243,18 +243,16 @@ def parallel_wrapper(grid, params, reader, writer):

     # NetCDF-4 reader does not work well with multithreading
     # Use only 1 thread per worker! (At least on my laptop)
-    # client = Client(threads_per_worker=1, n_workers=8)
+    client = Client(threads_per_worker=1, n_workers=2)

-    # lazy_results = []
+    lazy_results = []

-    for c_idx in range(n_cells)[15455:]:
-    # # for c_idx in range(n_cells)[180:190]:
-    # for c_idx in range(n_cells)[2046:2060]:
-        pw_run(c_idx)
-        # lazy_result = dask.delayed(pw_run)(c_idx)
-        # lazy_results.append(lazy_result)
+    for c_idx in range(n_cells):
+        # pw_run(c_idx)
+        lazy_result = dask.delayed(pw_run)(c_idx)
+        lazy_results.append(lazy_result)

-    # results = dask.compute(*lazy_results)
+    results = dask.compute(*lazy_results)

-    # for item in results:
-    #     writer.duplicate(item.c_idx, item)
+    for item in results:
+        writer.duplicate(item.c_idx, item)
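For orientation, the pattern this hunk enables is the standard dask delayed/compute fan-out. A minimal runnable sketch, in which the dummy pw_run and the worker count are illustrative stand-ins rather than the project's actual job:

import dask
from dask.distributed import Client

def pw_run(c_idx):
    # stand-in for the per-cell topography job; the real version reads
    # a NetCDF tile and returns a result object carrying its c_idx
    return c_idx * c_idx

if __name__ == "__main__":
    # one thread per worker: the HDF5 layer under NetCDF-4 is not
    # thread-safe, so parallelism comes from worker processes instead
    client = Client(threads_per_worker=1, n_workers=2)

    # build the task graph lazily, then run it in one parallel pass
    lazy_results = [dask.delayed(pw_run)(c_idx) for c_idx in range(16)]
    results = dask.compute(*lazy_results)

    client.close()
    print(results)

With a distributed Client active, dask.compute routes the graph to it automatically, which is why the loop body never references client directly.
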
diff --git a/src/io.py b/src/io.py
index 14e5545..42853f5 100644
--- a/src/io.py
+++ b/src/io.py
@@ -6,7 +6,9 @@
 import numpy as np
 import h5py
 import os
+
 from datetime import datetime
+from scipy import interpolate

 from ..src import utils

@@ -171,6 +173,7 @@ def __init__(self, cell, params, verbose=False, is_parallel=False):
         self.merit_cg = params.merit_cg
         self.split_EW = False
         self.span = False
+        self.interp_lons = []

         if not is_parallel:
             self.get_topo(cell)
@@ -334,19 +337,22 @@ def __load_topo(self, cell, fns, dirs, lon_cnt, lat_cnt, lat_idx_rng, lon_idx_r
         ### Handles the case where a cell spans four topographic datasets
         cnt_lat = 0
         cnt_lon = 0
-        lat_low_old = np.ones((len(fns))) * np.inf
-        lat_high_old = np.ones((len(fns))) * np.inf
-        lon_low_old = np.ones((len(fns))) * np.inf
-        lon_high_old = np.ones((len(fns))) * np.inf
-        lat_nc_change, lon_nc_change = False, False

         for cnt, fn in enumerate(fns):
-            # try:
-            #     test.isopen()
-            # except:
+            ############################################
+            #
+            # Open data file
+            #
+            ############################################
             test = nc.Dataset(dirs[cnt] + fn, "r")
             self.opened_dfs.append(test)

+            ############################################
+            #
+            # Load lat data
+            #
+            ############################################
+
             lat = test["lat"]
             lat_min_idx = np.argmin(np.abs((lat - np.sign(lat) * 1e-4) - self.lat_verts.min()))
             lat_max_idx = np.argmin(np.abs((lat + np.sign(lat) * 1e-4) - self.lat_verts.max()))
@@ -354,83 +360,37 @@
             lat_high = np.max((lat_min_idx, lat_max_idx))
             lat_low = np.min((lat_min_idx, lat_max_idx))

-            # lon = test["lon"]
-            # lon_min_idx = np.argmin(np.abs((lon - np.sign(lon) * 1e-4) - (self.lon_verts.min())))
-            # lon_max_idx = np.argmin(np.abs((lon + np.sign(lon) * 1e-4) - (self.lon_verts.max())))
-
-            # lon_high = np.max((lon_min_idx, lon_max_idx))
-            # lon_low = np.min((lon_min_idx, lon_max_idx))
-
-            ### Only add lat and lon elements if there are changes to the low and high indices identified:
-            # if (lon_low not in lon_low_old) and (lon_high not in lon_high_old):
-            #     lon_nc_change = True
-
-            # if (lat_low not in lat_low_old) and (lat_high not in lat_high_old):
-            #     lat_nc_change = True
-
-            ############################################

             lat = test["lat"]
-            lon = test["lon"]

-            l_lat_bound, r_lat_bound = (
-                self.fn_lat[lat_idx_rng[n_row]],
-                self.fn_lat[lat_idx_rng[n_row] + 1],
-            )
-
-            l_lon_bound, r_lon_bound = (
-                self.fn_lon[lon_idx_rng[n_col]],
-                self.fn_lon[lon_idx_rng[n_col] + 1],
-            )
-
-            lon_rng = r_lon_bound - l_lon_bound
-
-            lon_in_file = self.lon_verts[( (self.lon_verts - l_lon_bound) > 0 ) & ( (self.lon_verts - l_lon_bound) <= lon_rng )]
+            ############################################
+            #
+            # Load lon data
+            #
+            ############################################

-            if len(lon_in_file) == 0:
-                lon_high = np.argmin(np.abs(lon - r_lon_bound))
-                lon_low = np.argmin(np.abs(lon - l_lon_bound))
+            # if fns mixes MERIT and REMA datasets, build the common lon
+            # grid once per column on the n_row = 0 pass
+            if any("REMA" in fn for fn in fns) and any("MERIT" in fn for fn in fns) and (not populate):
+                if (n_row == 0):
+                    # run MERIT and REMA interpolation
+                    new_lon = self.__do_interp_lon_1D(dirs, fns, cnt_lon, lon_cnt, n_col, lon_idx_rng)
+                    self.interp_lons.append(new_lon)

-            else:
-                if not self.split_EW:
-                    if lon_in_file.max() == self.lon_verts.max():
-                        lon_high = np.argmin(np.abs(lon - lon_in_file.max()))
-                    else:
-                        lon_high = np.argmin(np.abs(lon - r_lon_bound))
-
-                    if lon_in_file.min() == self.lon_verts.min():
-                        lon_low = np.argmin(np.abs(lon - lon_in_file.min()))
-                    else:
-                        lon_low = np.argmin(np.abs(lon - l_lon_bound))
-
-                else:
-                    if lon_in_file.max() == min(np.where(self.lon_verts < 0.0, self.lon_verts + 360.0, self.lon_verts)):
-                        lon_high = np.argmin(np.abs(lon - r_lon_bound))
-                        lon_low = np.argmin(np.abs(lon - lon_in_file.min()))
-                    else:
-                        lon_high = np.argmin(np.abs(lon - r_lon_bound))
-
-                        if lon_in_file.min() == (max(self.lon_verts[self.lon_verts < 0.0] + 360.0) - 360.0):
-                            lon_high = np.argmin(np.abs(lon - lon_in_file.max()))
-                            lon_low = np.argmin(np.abs(lon - l_lon_bound))
-                        else:
-                            lon_low = np.argmin(np.abs(lon - l_lon_bound))
-
-            # if r_lon_bound > lon_in_file.max():
-            #     lon_high = np.argmin(np.abs(lon - lon_in_file.max()))
-
-            # if lon_in_file.min() > l_lon_bound:
-            #     lon_low = np.argmin(np.abs(lon - lon_in_file.min()))
+                # flag that this cell spans a MERIT+REMA mix
+                self.span = True

+            lon = test["lon"]

-            lon_low_old[cnt] = lon_low
-            lon_high_old[cnt] = lon_high
-            lat_low_old[cnt] = lat_low
-            lat_high_old[cnt] = lat_high
+            lon_low, lon_high = self.__get_lon_idxs(lon, lon_idx_rng, n_col)
+
             if not populate:
                 if n_row == 0:
                     # if (cnt_lon < (lon_cnt + 1)) and lon_nc_change:
-                    nc_lon += lon_high - lon_low
+                    if not self.span:
+                        nc_lon += lon_high - lon_low
+                    else:
+                        nc_lon += len(new_lon)
                     cnt_lon += 1

                 if n_col == 0:
@@ -445,16 +405,22 @@
             else:
                 topo = test["Elevation"][lat_low:lat_high, lon_low:lon_high]

+            curr_lon = lon[lon_low:lon_high].tolist()
+
             if n_col == 0:
-                cell.lat += lat[lat_low:lat_high].tolist()
-            if n_row == 0:
+                curr_lat = lat[lat_low:lat_high].tolist()
+                cell.lat += curr_lat

+            if not self.span:
+                if n_row == 0:
+                    cell.lon += curr_lon
+            else:
+                # interpolate topo data to the new lon grid
+                new_lon = self.interp_lons[n_col]
+                topo = self.__interp_topo_2D(topo, curr_lat, curr_lon, new_lon)

-                if "MERIT" in fns and "REMA" in fns:
-                    self.span = True
-                    # new_lon =
+                if n_row == 0:
+                    cell.lon += new_lon.tolist()

-                else:
-                    cell.lon += lon[lon_low:lon_high].tolist()

             # # current dataset at n_row = 0 is a MERIT dataset
             # if "MERIT" in fn:
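Review note on the span branch above: the common lon axis built by __do_interp_lon_1D (added further down) amounts to intersecting the MERIT and REMA lon ranges and resampling the overlap with the coarser point count. A standalone sketch with made-up axis values:

import numpy as np

# hypothetical lon axes where a MERIT tile meets a REMA tile
merit_lon = np.linspace(29.0, 31.0, 2400)   # coarser axis
rema_lon = np.linspace(29.5, 31.5, 3000)

# the overlap of the two ranges...
new_min = max(merit_lon.min(), rema_lon.min())   # 29.5
new_max = min(merit_lon.max(), rema_lon.max())   # 31.0
# ...sampled with the smaller point count, so neither dataset is upsampled
new_sz = min(len(merit_lon), len(rema_lon))      # 2400

new_lon = np.linspace(new_min, new_max, new_sz)
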
@@ -464,20 +430,12 @@ def __load_topo(self, cell, fns, dirs, lon_cnt, lat_cnt, lat_idx_rng, lon_idx_r
             # if n_row > 0:
             # if ("REMA" in fn) and (self.prev_merit):
-
-            lon_sz = lon_high - lon_low
+            if not self.span:
+                lon_sz = lon_high - lon_low
+            else:
+                lon_sz = len(self.interp_lons[n_col])
             lat_sz = lat_high - lat_low
-
-            # if lon_nc_change and cnt > 0:
-            #     n_col += 1
-
-            # # if n_col == (lon_cnt + 1):
-            # #     n_col = 0
-            if lat_nc_change and cnt > 0:
-                n_row += 1
-                lat_sz_old = np.copy(lat_sz)
-
             cell.topo[
                 lat_sz_old : lat_sz_old + lat_sz,
                 lon_sz_old : lon_sz_old + lon_sz,
             ]
@@ -493,9 +451,6 @@ def __load_topo(self, cell, fns, dirs, lon_cnt, lat_cnt, lat_idx_rng, lon_idx_r
                 n_row += 1
                 lat_sz_old = np.copy(lat_sz)

-            lon_nc_change = False
-            lat_nc_change = False
-
             test.close()

         if not populate:
@@ -519,6 +474,79 @@ def __load_topo(self, cell, fns, dirs, lon_cnt, lat_cnt, lat_idx_rng, lon_idx_r
                 cell.topo, (iint, iint), (iint, iint)
             ).mean(axis=(-1, -2))[::-1, :]

+    def __do_interp_lon_1D(self, dirs, fns, cnt_lon, lon_cnt, n_col, lon_idx_rng):
+        # Note: MERIT is always on n_row = 0 and REMA on n_row = 1
+
+        merit_path = dirs[cnt_lon] + fns[cnt_lon]
+        merit_dat = nc.Dataset(merit_path, "r")
+        merit_lon = merit_dat["lon"]
+
+        rema_path = dirs[cnt_lon + lon_cnt + 1] + fns[cnt_lon + lon_cnt + 1]
+        rema_dat = nc.Dataset(rema_path, "r")
+        rema_lon = rema_dat["lon"]
+
+        merit_lon_low, merit_lon_high = self.__get_lon_idxs(merit_lon, lon_idx_rng, n_col)
+        rema_lon_low, rema_lon_high = self.__get_lon_idxs(rema_lon, lon_idx_rng, n_col)
+
+        merit_lon = merit_lon[merit_lon_low:merit_lon_high].tolist()
+        rema_lon = rema_lon[rema_lon_low:rema_lon_high].tolist()
+
+        new_max = min(max(merit_lon), max(rema_lon))
+        new_min = max(min(merit_lon), min(rema_lon))
+        # use the smaller of the two point counts (in practice, the MERIT lon grid):
+        new_sz = min(len(merit_lon), len(rema_lon))
+
+        new_lon = np.linspace(new_min, new_max, new_sz)
+
+        return new_lon
+
+    @staticmethod
+    def __interp_topo_2D(topo, curr_lat, curr_lon, new_lon):
+        interp = interpolate.RegularGridInterpolator((curr_lat, curr_lon), topo)
+        XX, YY = np.meshgrid(new_lon, curr_lat)
+        return interp((YY, XX))
+
+    def __get_lon_idxs(self, lon, lon_idx_rng, n_col):
+        l_lon_bound, r_lon_bound = (
+            self.fn_lon[lon_idx_rng[n_col]],
+            self.fn_lon[lon_idx_rng[n_col] + 1],
+        )
+
+        lon_rng = r_lon_bound - l_lon_bound
+
+        lon_in_file = self.lon_verts[( (self.lon_verts - l_lon_bound) > 0 ) & ( (self.lon_verts - l_lon_bound) <= lon_rng )]
+
+        if len(lon_in_file) == 0:
+            lon_high = np.argmin(np.abs(lon - r_lon_bound))
+            lon_low = np.argmin(np.abs(lon - l_lon_bound))
+
+        else:
+            if not self.split_EW:
+                if lon_in_file.max() == self.lon_verts.max():
+                    lon_high = np.argmin(np.abs(lon - lon_in_file.max()))
+                else:
+                    lon_high = np.argmin(np.abs(lon - r_lon_bound))
+
+                if lon_in_file.min() == self.lon_verts.min():
+                    lon_low = np.argmin(np.abs(lon - lon_in_file.min()))
+                else:
+                    lon_low = np.argmin(np.abs(lon - l_lon_bound))
+
+            else:
+                if lon_in_file.max() == min(np.where(self.lon_verts < 0.0, self.lon_verts + 360.0, self.lon_verts)):
+                    lon_high = np.argmin(np.abs(lon - r_lon_bound))
+                    lon_low = np.argmin(np.abs(lon - lon_in_file.min()))
+                else:
+                    lon_high = np.argmin(np.abs(lon - r_lon_bound))
+
+                    if lon_in_file.min() == (max(self.lon_verts[self.lon_verts < 0.0] + 360.0) - 360.0):
+                        lon_high = np.argmin(np.abs(lon - lon_in_file.max()))
+                        lon_low = np.argmin(np.abs(lon - l_lon_bound))
+                    else:
+                        lon_low = np.argmin(np.abs(lon - l_lon_bound))
+
+        return lon_low, lon_high
+
     def close_all(self):
         for df in self.opened_dfs:
             df.close()
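The companion 2-D step, __interp_topo_2D, is a direct use of scipy's RegularGridInterpolator: build the interpolant on the tile's own (lat, lon) grid, then evaluate it at every point of the (lat, new_lon) grid. A self-contained sketch with toy data (all shapes and values made up):

import numpy as np
from scipy import interpolate

# toy elevation block on an ascending 5 x 8 (lat, lon) grid
curr_lat = np.linspace(-71.0, -70.0, 5)
curr_lon = np.linspace(29.0, 31.0, 8)
topo = np.random.default_rng(0).random((5, 8))

# target lon axis, e.g. the common MERIT/REMA grid
new_lon = np.linspace(29.5, 31.0, 6)

# query points ordered (lat, lon) to match the interpolant's grid order
interp = interpolate.RegularGridInterpolator((curr_lat, curr_lon), topo)
XX, YY = np.meshgrid(new_lon, curr_lat)   # both shaped (5, 6)
pts = np.stack((YY, XX), axis=-1)         # (5, 6, 2) array of (lat, lon)
topo_new = interp(pts)                    # bilinearly regridded, shape (5, 6)
print(topo_new.shape)

One caveat worth flagging in review: RegularGridInterpolator expects monotonic axes, so this relies on the lat/lon slices coming out of the NetCDF tiles already being sorted.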