intermediate commit for robust MERIT I/O across E-W split
requires cleaning up, and some corner cases still exist...
ray-chew committed May 27, 2024
1 parent fe6cdf8 commit dbf7915
Showing 3 changed files with 152 additions and 26 deletions.
46 changes: 38 additions & 8 deletions runs/icon_merit_global.py
@@ -35,12 +35,32 @@ def do_cell(c_idx,
print(c_idx)

topo = var.topo_cell()

lat_verts = grid.clat_vertices[c_idx]
lon_verts = grid.clon_vertices[c_idx]

lat_extent = [lat_verts.min() - 0.0,lat_verts.min() - 0.0,lat_verts.max() + 0.0]
lon_extent = [lon_verts.min() - 0.0,lon_verts.min() - 0.0,lon_verts.max() + 0.0]
# if ( (lon_verts.max() - lon_verts.min()) > 180.0 ):
# lon_verts[np.argmin(lon_verts)] += 360.0

# clon = utils.rescale(grid.clon[c_idx], rng=[lon_verts.min(),lon_verts.max()])
# clat = utils.rescale(grid.clat[c_idx], rng=[lat_verts.min(),lat_verts.max()])

# check = utils.gen_triangle(lon_verts, lat_verts)

# print("is center in triangle:", check.vec_get_mask((clon, clat)))

# lat_expand = 0.0
# lat_extent = [lat_verts.min() - lat_expand,lat_verts.min() - lat_expand,lat_verts.max() + lat_expand]

# lon_expand = 0.0
# lon_extent = [lon_verts.min() - lon_expand,lon_verts.min() - lon_expand,lon_verts.max() + lon_expand]

lat_extent = lat_verts
lon_extent = lon_verts
# we only keep the topography that is inside this lat-lon extent.

lat_extent, lon_extent = utils.handle_latlon_expansion(lat_extent, lon_extent)

lat_verts = np.array(lat_extent)
lon_verts = np.array(lon_extent)

@@ -60,14 +80,24 @@ def do_cell(c_idx,

clon = np.array([grid.clon[c_idx]])
clat = np.array([grid.clat[c_idx]])
clon_vertices = np.array([grid.clon_vertices[c_idx]])
clat_vertices = np.array([grid.clat_vertices[c_idx]])
# clon = np.array([clon])
# clat = np.array([clat])
# clon_vertices = np.array([grid.clon_vertices[c_idx]])
# clat_vertices = np.array([grid.clat_vertices[c_idx]])
clon_vertices = np.array([lon_verts])
clat_vertices = np.array([lat_verts])


ncells = 1
nv = clon_vertices[0].size
# -- create the triangles
clon_vertices = np.where(clon_vertices < -180.0, clon_vertices + 360.0, clon_vertices)
clon_vertices = np.where(clon_vertices > 180.0, clon_vertices - 360.0, clon_vertices)
# clon_vertices = np.where(clon_vertices < -180.0, clon_vertices + 360.0, clon_vertices)
# clon_vertices = np.where(clon_vertices > 180.0, clon_vertices - 360.0, clon_vertices)

# if ( (clon_vertices.max() - clon_vertices.min()) > 180.0 ):
if reader.split_EW:
clon_vertices[clon_vertices < 0.0] += 360.0


triangles = np.zeros((ncells, nv, 2))
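For context, the `reader.split_EW` branch above shifts western-hemisphere vertices into a contiguous 0-360 frame so the cell triangle is not torn apart at the antimeridian. A minimal standalone sketch of that idea, with made-up vertex longitudes (the 180-degree spread test mirrors the commented-out check above):

import numpy as np

# Hypothetical vertex longitudes (degrees) of a cell straddling the antimeridian.
clon_vertices = np.array([[179.2, -179.5, 178.8]])

# A spread larger than 180 deg only happens when the cell wraps around +/-180.
split_EW = (clon_vertices.max() - clon_vertices.min()) > 180.0

if split_EW:
    # Shift negative longitudes by +360 so all vertices sit in one contiguous range.
    clon_vertices[clon_vertices < 0.0] += 360.0

print(split_EW, clon_vertices)  # True [[179.2 180.5 178.8]]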

@@ -212,8 +242,8 @@ def parallel_wrapper(grid, params, reader, writer):
# client = Client(threads_per_worker=1, n_workers=8)

# lazy_results = []

for c_idx in range(n_cells):
# for c_idx in range(n_cells)[180:190]:
for c_idx in range(n_cells)[2048:2050]:
pw_run(c_idx)
# lazy_result = dask.delayed(pw_run)(c_idx)
# lazy_results.append(lazy_result)
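The commented-out lines above sketch a Dask-parallel variant of this loop. A minimal, self-contained version of that pattern might look like the following; the stub `pw_run` and the cell count are placeholders, and the worker settings mirror the commented-out Client line:

import dask
from dask.distributed import Client

def pw_run(c_idx):          # stand-in for the per-cell routine defined above
    return c_idx

n_cells = 4                 # stand-in for the cell count used above

client = Client(threads_per_worker=1, n_workers=8)  # same settings as the commented-out line

# Build lazy tasks first, then evaluate them together on the workers.
lazy_results = [dask.delayed(pw_run)(c_idx) for c_idx in range(n_cells)]
results = dask.compute(*lazy_results)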
109 changes: 92 additions & 17 deletions src/io.py
@@ -160,7 +160,7 @@ def __init__(self, cell, params, verbose=False, is_parallel=False):
90.0,
120.0,
150.0,
180.0,
180.0
]
)
self.fn_lat = np.array([90.0, 60.0, 30.0, 0.0, -30.0, -60.0, -90.0])
@@ -169,24 +169,42 @@ def __init__(self, cell, params, verbose=False, is_parallel=False):
self.lon_verts = np.array(params.lon_extent)

self.merit_cg = params.merit_cg
self.split_EW = False

if not is_parallel:
self.get_topo(cell)

self.is_parallel = is_parallel

def get_topo(self, cell):

# if lat_verts


lat_min_idx = self.__compute_idx(self.lat_verts.min(), "min", "lat")
lat_max_idx = self.__compute_idx(self.lat_verts.max(), "max", "lat")

lon_min_idx = self.__compute_idx(self.lon_verts.min(), "min", "lon")
lon_max_idx = self.__compute_idx(self.lon_verts.max(), "max", "lon")

if ( (self.lon_verts.max() - self.lon_verts.min()) > 180.0 ):
# lon_max_idx, lon_min_idx = lon_min_idx, lon_max_idx
self.split_EW = True

lon_idx_rng = list(range(lon_max_idx, len(self.fn_lon) - 1 )) + list(range(0,lon_min_idx + 1))

else:
if lon_min_idx == lon_max_idx:
lon_max_idx += 1
lon_idx_rng = list(range(lon_min_idx, lon_max_idx))

lat_idx_rng = list(range(lat_max_idx, lat_min_idx))

fns, dirs, lon_cnt, lat_cnt = self.__get_fns(
lat_min_idx, lat_max_idx, lon_min_idx, lon_max_idx
lat_idx_rng, lon_idx_rng
)

self.__load_topo(cell, fns, dirs, lon_cnt, lat_cnt)
self.__load_topo(cell, fns, dirs, lon_cnt, lat_cnt, lat_idx_rng, lon_idx_rng)
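To illustrate the wrapped index range built above: `fn_lon` holds the 30-degree MERIT column boundaries, and for a cell crossing the antimeridian the column range runs from the eastern edge to the end of the boundary list, then continues from the start. A sketch with hypothetical index values:

import numpy as np

fn_lon = np.arange(-180.0, 181.0, 30.0)  # 30-deg column boundaries, as in __init__ above

lon_min_idx, lon_max_idx = 0, 11         # hypothetical: west vertex near -180, east vertex near +180
split_EW = True                          # detected upstream when the vertex spread exceeds 180 deg

if split_EW:
    # Wrap around the end of the boundary list: easternmost columns first, then the western ones.
    lon_idx_rng = list(range(lon_max_idx, len(fn_lon) - 1)) + list(range(0, lon_min_idx + 1))
else:
    if lon_min_idx == lon_max_idx:
        lon_max_idx += 1
    lon_idx_rng = list(range(lon_min_idx, lon_max_idx))

print(lon_idx_rng)  # [11, 0]: the two columns adjacent to the 180-deg meridian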

def __compute_idx(self, vert, typ, direction):
"""Given a point ``vert``, look up which MERIT NetCDF file contains this point."""
@@ -213,6 +231,9 @@ def __compute_idx(self, vert, typ, direction):
else:
where_idx -= 1

if where_idx == (len(fn_int) - 1):
where_idx -= 1

where_idx = int(where_idx)

if self.verbose:

return where_idx
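The guard added above handles a vertex sitting exactly on the last boundary, for example a longitude of exactly 180.0, whose raw index would otherwise point one past the final tile column. A toy stand-in for the lookup, with a searchsorted-style index in place of the real computation:

import numpy as np

fn_lon = np.array([-180., -150., -120., -90., -60., -30., 0.,
                   30., 60., 90., 120., 150., 180.])

def tile_column(lon):
    # Index of the 30-deg column whose left boundary is <= lon (toy stand-in for
    # __compute_idx; the real method also distinguishes min/max vertices).
    where_idx = np.searchsorted(fn_lon, lon, side="right") - 1
    # Guard from this commit: a vertex exactly on the last boundary (180.0)
    # must map to the last column, not one past it.
    if where_idx == (len(fn_lon) - 1):
        where_idx -= 1
    return int(where_idx)

print(tile_column(179.9), tile_column(180.0))  # 11 11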

def __get_fns(self, lat_min_idx, lat_max_idx, lon_min_idx, lon_max_idx):
def __get_fns(self, lat_idx_rng, lon_idx_rng):
"""Construct the full filenames required for the loading of the topographic data from the indices identified in :func:`src.io.ncdata.read_merit_topo.__compute_idx`"""
fns = []
dirs = []

for lat_cnt, lat_idx in enumerate(range(lat_max_idx, lat_min_idx)):
for lat_cnt, lat_idx in enumerate(lat_idx_rng):
l_lat_bound, r_lat_bound = (
self.fn_lat[lat_idx],
self.fn_lat[lat_idx + 1],
@@ -245,7 +266,7 @@ def __get_fns(self, lat_min_idx, lat_max_idx, lon_min_idx, lon_max_idx):
self.rema = False
self.dir = self.dir.replace("REMA", "MERIT")

for lon_cnt, lon_idx in enumerate(range(lon_min_idx, lon_max_idx)):
for lon_cnt, lon_idx in enumerate(lon_idx_rng):
l_lon_bound, r_lon_bound = (
self.fn_lon[lon_idx],
self.fn_lon[lon_idx + 1],
@@ -271,7 +292,7 @@ def __get_fns(self, lat_min_idx, lat_max_idx, lon_min_idx, lon_max_idx):

return fns, dirs, lon_cnt, lat_cnt

def __load_topo(self, cell, fns, dirs, lon_cnt, lat_cnt, init=True, populate=True):
def __load_topo(self, cell, fns, dirs, lon_cnt, lat_cnt, lat_idx_rng, lon_idx_rng, init=True, populate=True):
"""
This method assembles a contiguous array in ``cell.topo`` containing the regional topography to be loaded.
@@ -282,7 +303,7 @@ def __load_topo(self, cell, fns, dirs, lon_cnt, lat_cnt, init=True, populate=Tru
2. The second run populates the empty array with the information of the block arrays obtained in the first run.
"""
if (cell.topo is None) and (init):
self.__load_topo(cell, fns, dirs, lon_cnt, lat_cnt, init=False, populate=False)
self.__load_topo(cell, fns, dirs, lon_cnt, lat_cnt, lat_idx_rng, lon_idx_rng, init=False, populate=False)

if not populate:
n_col = 0
@@ -320,19 +341,68 @@ def __load_topo(self, cell, fns, dirs, lon_cnt, lat_cnt, init=True, populate=Tru
lat_high = np.max((lat_min_idx, lat_max_idx))
lat_low = np.min((lat_min_idx, lat_max_idx))

lon = test["lon"]
lon_min_idx = np.argmin(np.abs((lon - np.sign(lon) * 1e-4) - (self.lon_verts.min())))
lon_max_idx = np.argmin(np.abs((lon + np.sign(lon) * 1e-4) - (self.lon_verts.max())))
# lon = test["lon"]
# lon_min_idx = np.argmin(np.abs((lon - np.sign(lon) * 1e-4) - (self.lon_verts.min())))
# lon_max_idx = np.argmin(np.abs((lon + np.sign(lon) * 1e-4) - (self.lon_verts.max())))

lon_high = np.max((lon_min_idx, lon_max_idx))
lon_low = np.min((lon_min_idx, lon_max_idx))
# lon_high = np.max((lon_min_idx, lon_max_idx))
# lon_low = np.min((lon_min_idx, lon_max_idx))

### Only add lat and lon elements if there are changes to the low and high indices identified:
if (lon_low not in lon_low_old) and (lon_high not in lon_high_old):
lon_nc_change = True
# if (lon_low not in lon_low_old) and (lon_high not in lon_high_old):
# lon_nc_change = True

# if (lat_low not in lat_low_old) and (lat_high not in lat_high_old):
# lat_nc_change = True

############################################
lat = test["lat"]
lon = test["lon"]

l_lat_bound, r_lat_bound = (
self.fn_lat[lat_idx_rng[n_row]],
self.fn_lat[lat_idx_rng[n_row] + 1],
)

l_lon_bound, r_lon_bound = (
self.fn_lon[lon_idx_rng[n_col]],
self.fn_lon[lon_idx_rng[n_col] + 1],
)

lon_rng = r_lon_bound - l_lon_bound

lon_in_file = self.lon_verts[( (self.lon_verts - l_lon_bound) > 0 ) & ( (self.lon_verts - l_lon_bound) <= lon_rng )]

if len(lon_in_file) == 0:
lon_high = np.argmin(np.abs(lon - r_lon_bound))
lon_low = np.argmin(np.abs(lon - l_lon_bound))

else:
if not self.split_EW:
if lon_in_file.max() == self.lon_verts.max():
lon_high = np.argmin(np.abs(lon - lon_in_file.max()))
else:
lon_high = np.argmin(np.abs(lon - r_lon_bound))

if lon_in_file.min() == self.lon_verts.min():
lon_low = np.argmin(np.abs(lon - lon_in_file.min()))
else:
lon_low = np.argmin(np.abs(lon - l_lon_bound))

else:
if lon_in_file.max() == self.lon_verts.max():
lon_high = np.argmin(np.abs(lon - r_lon_bound))
lon_low = np.argmin(np.abs(lon - lon_in_file.min()))

if lon_in_file.min() == self.lon_verts.min():
lon_high = np.argmin(np.abs(lon - lon_in_file.max()))
lon_low = np.argmin(np.abs(lon - l_lon_bound))
# if r_lon_bound > lon_in_file.max():
# lon_high = np.argmin(np.abs(lon - lon_in_file.max()))

# if lon_in_file.min() > l_lon_bound:
# lon_low = np.argmin(np.abs(lon - lon_in_file.min()))

if (lat_low not in lat_low_old) and (lat_high not in lat_high_old):
lat_nc_change = True

lon_low_old[cnt] = lon_low
lon_high_old[cnt] = lon_high
@@ -399,6 +469,11 @@ def __load_topo(self, cell, fns, dirs, lon_cnt, lat_cnt, init=True, populate=Tru
if not populate:
cell.topo = np.zeros((nc_lat, nc_lon))
else:

if self.split_EW:
cell.lon = np.array(cell.lon)
cell.lon[cell.lon < 0.0] += 360.0

iint = self.merit_cg

cell.lat = utils.sliding_window_view(
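The per-file slicing above repeatedly locates grid-index bounds by a nearest-neighbour search against the tile's coordinate vector; stripped of the surrounding bookkeeping, the pattern is just the following (tile resolution and the cut longitude are made up):

import numpy as np

lon = np.linspace(150.0, 180.0, 36001)      # hypothetical coordinate vector of one 30-deg MERIT tile

lon_low = np.argmin(np.abs(lon - 178.3))    # grid point closest to the cell's western cut
lon_high = np.argmin(np.abs(lon - 180.0))   # eastern cut falls on the tile boundary itself

print(lon_low, lon_high)                    # roughly 33960 36000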
23 changes: 22 additions & 1 deletion src/utils.py
@@ -830,4 +830,25 @@ def is_land(cell, simplex_lat, simplex_lon, topo, height_tol=0.5, percent_tol=0.
if not (((cell.topo <= height_tol).sum() / cell.topo.size) > percent_tol):
return True
else:
return False
return False


def handle_latlon_expansion(clat_vertices, clon_vertices, lat_expand = 1.0, lon_expand = 1.0):
clon_vertices = np.array(clon_vertices)
clat_vertices = np.array(clat_vertices)

clon_vertices[np.where(np.abs(np.abs(clon_vertices) - 180.0) < 1e-5)] = 180.0

clat_vertices[np.argmax(clat_vertices)] += lat_expand
clon_vertices[np.argmax(clon_vertices)] += lon_expand

clat_vertices[np.argmin(clat_vertices)] -= lat_expand
clon_vertices[np.argmin(clon_vertices)] -= lon_expand

clon_vertices[np.where(clon_vertices < -180.0)] += 360.0
clon_vertices[np.where(clon_vertices > 180.0)] -= 360.0

clat_vertices = np.where(clat_vertices < -90.0, clat_vertices + 1.0, clat_vertices)
clat_vertices = np.where(clat_vertices > 90.0, clat_vertices - 1.0, clat_vertices)

return clat_vertices, clon_vertices
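As a quick check of what the new helper returns for a cell hugging the antimeridian (the vertices below are made up, and the import path assumes the repository layout shown in this commit):

from src.utils import handle_latlon_expansion  # assumed import path

lat_verts = [64.2, 64.9, 65.4]
lon_verts = [179.3, -179.95, 179.999999]  # last value lies within 1e-5 of 180 and is snapped to 180.0

lat_ext, lon_ext = handle_latlon_expansion(lat_verts, lon_verts)
print(lat_ext)  # [63.2 64.9 66.4]               min/max latitudes pushed out by 1 deg
print(lon_ext)  # approx [179.3 179.05 -179.]    expanded by 1 deg, then wrapped back into [-180, 180]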
