Skip to content

Commit

Permalink
the I/O for MERIT REMA seems to work for tricky grid cells now
Browse files Browse the repository at this point in the history
I got to comment up what I am doing though, since I am dealing with the corner cases separately.
  • Loading branch information
ray-chew committed May 27, 2024
1 parent dbf7915 commit e04e117
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 27 deletions.
21 changes: 12 additions & 9 deletions runs/icon_merit_global.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,15 @@ def do_cell(c_idx,
# lon_expand = 0.0
# lon_extent = [lon_verts.min() - lon_expand,lon_verts.min() - lon_expand,lon_verts.max() + lon_expand]

lat_extent = lat_verts
lon_extent = lon_verts
# lat_extent = lat_verts
# lon_extent = lon_verts
# we only keep the topography that is inside this lat-lon extent.

lat_extent, lon_extent = utils.handle_latlon_expansion(lat_extent, lon_extent)
lat_extent, lon_extent = utils.handle_latlon_expansion(lat_verts, lon_verts)

lat_verts = np.array(lat_extent)
lon_verts = np.array(lon_extent)
# lat_verts = np.array(lat_verts)
# lon_verts = np.array(lon_verts)
lat_verts, lon_verts = utils.handle_latlon_expansion(lat_verts, lon_verts, lat_expand = 0.0, lon_expand = 0.0)

params.lat_extent = lat_extent
params.lon_extent = lon_extent
Expand Down Expand Up @@ -106,7 +107,7 @@ def do_cell(c_idx,
triangles[i, :, 1] = np.array(clat_vertices[i, :])

if params.plot:
cart_plot.lat_lon_icon(topo, triangles, ncells=ncells, clon=clon, clat=clat)
cart_plot.lat_lon_icon(topo, triangles, ncells=ncells, clon=clon, clat=clat, title=c_idx)

# %%
tri_idx = 0
Expand Down Expand Up @@ -170,7 +171,7 @@ def do_cell(c_idx,
output_fig=False
)
else:
dplot.show(tri_idx, sols, v_extent=v_extent, output_fig=False)
dplot.show(c_idx, sols, v_extent=v_extent, output_fig=False)

if params.recompute_rhs:
sols, _ = sa.do(tri_idx, ampls_fa)
Expand All @@ -190,7 +191,7 @@ def do_cell(c_idx,
output_fig=False
)
else:
dplot.show(tri_idx, sols, v_extent=v_extent, output_fig=False)
dplot.show(c_idx, sols, v_extent=v_extent, output_fig=False)

print("--> analysis done")

Expand Down Expand Up @@ -242,8 +243,10 @@ def parallel_wrapper(grid, params, reader, writer):
# client = Client(threads_per_worker=1, n_workers=8)

# lazy_results = []

# for c_idx in range(n_cells)[:20]:
# for c_idx in range(n_cells)[180:190]:
for c_idx in range(n_cells)[2048:2050]:
for c_idx in range(n_cells)[2046:2060]:
pw_run(c_idx)
# lazy_result = dask.delayed(pw_run)(c_idx)
# lazy_results.append(lazy_result)
Expand Down
36 changes: 24 additions & 12 deletions src/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,17 +180,23 @@ def get_topo(self, cell):

# if lat_verts

if ( (self.lon_verts.max() - self.lon_verts.min()) > 180.0 ):
self.split_EW = True

if self.split_EW:
min_lon = max(np.where(self.lon_verts < 0.0, self.lon_verts + 360.0, self.lon_verts)) - 360.0
max_lon = min(np.where(self.lon_verts < 0.0, self.lon_verts + 360.0, self.lon_verts))
else:
min_lon = self.lon_verts.min()
max_lon = self.lon_verts.max()

lat_min_idx = self.__compute_idx(self.lat_verts.min(), "min", "lat")
lat_max_idx = self.__compute_idx(self.lat_verts.max(), "max", "lat")

lon_min_idx = self.__compute_idx(self.lon_verts.min(), "min", "lon")
lon_max_idx = self.__compute_idx(self.lon_verts.max(), "max", "lon")
lon_min_idx = self.__compute_idx(min_lon, "min", "lon")
lon_max_idx = self.__compute_idx(max_lon, "max", "lon")

if ( (self.lon_verts.max() - self.lon_verts.min()) > 180.0 ):
# lon_max_idx, lon_min_idx = lon_min_idx, lon_max_idx
self.split_EW = True

lon_idx_rng = list(range(lon_max_idx, len(self.fn_lon) - 1 )) + list(range(0,lon_min_idx + 1))

else:
Expand Down Expand Up @@ -219,15 +225,17 @@ def __compute_idx(self, vert, typ, direction):
print(fn_int, where_idx)

if typ == "min":
if (vert - fn_int[where_idx]) < 0.0:
if ((vert - fn_int[where_idx]) < 0.0):
if direction == "lon":
where_idx -= 1
if not self.split_EW:
where_idx -= 1
else:
where_idx += 1
elif typ == "max":
if (vert - fn_int[where_idx]) > 0.0:
if ((vert - fn_int[where_idx]) > 0.0):
if direction == "lon":
where_idx += 1
if not self.split_EW:
where_idx += 1
else:
where_idx -= 1

Expand Down Expand Up @@ -390,13 +398,17 @@ def __load_topo(self, cell, fns, dirs, lon_cnt, lat_cnt, lat_idx_rng, lon_idx_r
lon_low = np.argmin(np.abs(lon - l_lon_bound))

else:
if lon_in_file.max() == self.lon_verts.max():
if lon_in_file.max() == min(np.where(self.lon_verts < 0.0, self.lon_verts + 360.0, self.lon_verts)):
lon_high = np.argmin(np.abs(lon - r_lon_bound))
lon_low = np.argmin(np.abs(lon - lon_in_file.min()))
else:
lon_high = np.argmin(np.abs(lon - r_lon_bound))

if lon_in_file.min() == self.lon_verts.min():
if lon_in_file.min() == (max(self.lon_verts[self.lon_verts < 0.0] + 360.0) - 360.0):
lon_high = np.argmin(np.abs(lon - lon_in_file.max()))
lon_low = np.argmin(np.abs(lon - l_lon_bound))
else:
lon_low = np.argmin(np.abs(lon - l_lon_bound))
# if r_lon_bound > lon_in_file.max():
# lon_high = np.argmin(np.abs(lon - lon_in_file.max()))

Expand Down Expand Up @@ -452,7 +464,7 @@ def __load_topo(self, cell, fns, dirs, lon_cnt, lat_cnt, lat_idx_rng, lon_idx_r
] = topo

n_col += 1
lon_sz_old = np.copy(lon_sz)
lon_sz_old += np.copy(lon_sz)

if n_col == (lon_cnt+1):
n_col = 0
Expand Down
12 changes: 7 additions & 5 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,10 +834,12 @@ def is_land(cell, simplex_lat, simplex_lon, topo, height_tol=0.5, percent_tol=0.


def handle_latlon_expansion(clat_vertices, clon_vertices, lat_expand = 1.0, lon_expand = 1.0):
clon_vertices = np.array(clon_vertices)
clat_vertices = np.array(clat_vertices)
clon_vertices = np.around(clon_vertices,5)
clat_vertices = np.around(clat_vertices,5)

clon_vertices[np.where(np.abs(np.abs(clon_vertices) - 180.0) < 1e-5)] = 180.0
# clon_vertices[np.where(np.abs(np.abs(clon_vertices) - 180.0) < 1e-5)] = 180.0
clon_vertices[np.where(clon_vertices == 180.0)] = np.sign(clon_vertices.min()) * 180.0
clon_vertices[np.where(clon_vertices == -180.0)] = np.sign(clon_vertices.max()) * 180.0

clat_vertices[np.argmax(clat_vertices)] += lat_expand
clon_vertices[np.argmax(clon_vertices)] += lon_expand
Expand All @@ -848,7 +850,7 @@ def handle_latlon_expansion(clat_vertices, clon_vertices, lat_expand = 1.0, lon_
clon_vertices[np.where(clon_vertices < -180.0)] += 360.0
clon_vertices[np.where(clon_vertices > 180.0)] -= 360.0

clat_vertices = np.where(clat_vertices < -90.0, clat_vertices + 1.0, clat_vertices)
clat_vertices = np.where(clat_vertices > 90.0, clat_vertices - 1.0, clat_vertices)
clat_vertices = np.where(clat_vertices < -90.0, clat_vertices + lat_expand, clat_vertices)
clat_vertices = np.where(clat_vertices > 90.0, clat_vertices - lat_expand, clat_vertices)

return clat_vertices, clon_vertices
2 changes: 1 addition & 1 deletion vis/cart_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def lat_lon_icon(
fc="r",
alpha=0.2,
linewidth=1,
transform=ccrs.Geodetic(),
transform=ccrs.PlateCarree(),
zorder=3,
)
ax.add_collection(coll)
Expand Down

0 comments on commit e04e117

Please sign in to comment.