Skip to content

Commit

Permalink
fixed ugrid so it works with Py 3.10, 3.11
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisBarker-NOAA committed Oct 9, 2024
1 parent f123650 commit 4d02b7a
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 72 deletions.
7 changes: 3 additions & 4 deletions tests/test_visualization/test_mpl_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,14 @@ def test_plot_ugrid_start_index_1():
# SGRID tests
#############

# def test_plot_sgrid_only_grid():
# import cftime
# def test_plot_sgrid_and_nodes():
# ds = xr.open_dataset(EXAMPLE_DATA / "wcofs_small_subset.nc", decode_times=False)

# fig, axis = plt.subplots()

# plot_sgrid(axis, ds)
# plot_sgrid(axis, ds, nodes=True)

# fig.savefig(OUTPUT_DIR / "sgrid_just_plot")
# fig.savefig(OUTPUT_DIR / "sgrid_nodes")



6 changes: 4 additions & 2 deletions xarray_subset_grid/grids/ugrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,9 +403,11 @@ def assign_ugrid_topology(
mesh.__dict__.update(mesh_attrs)

# Add in the ones passed in:
mesh.__dict__.update({att: vars()[att]
variables = vars()
mesh.__dict__.update({att: variables[att]
for att in ALL_MESH_VARS
if vars()[att] is not None})
if variables[att] is not None})
mesh.start_index = start_index

if mesh.face_node_connectivity is None:
raise ValueError(
Expand Down
132 changes: 66 additions & 66 deletions xarray_subset_grid/visualization/mpl_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,73 +113,73 @@ def plot_sgrid(axes, ds, nodes=False, rho_points=False, edge_points=False):

raise NotImplementedError("have to port ugrid code to Sgrid")

mesh_defs = ds[ds.cf.cf_roles["grid_topology"][0]].attrs
lon_var, lat_var = mesh_defs["node_coordinates"].split()
nodes_lon, nodes_lat = (ds[n] for n in mesh_defs["node_coordinates"].split())


faces = ds[mesh_defs["face_node_connectivity"]]

if faces.shape[0] == 3:
# swap order for mpl triangulation
faces = faces.T
start_index = faces.attrs.get("start_index")
start_index = 0 if start_index is None else start_index
faces = faces - start_index

mpl_tri = Triangulation(nodes_lon, nodes_lat, faces)

axes.triplot(mpl_tri)
if face_numbers:
try:
face_lon, face_lat = (ds[n] for n in mesh_defs["face_coordinates"].split())
except KeyError:
raise ValueError('"face_coordinates" must be defined to plot the face numbers')
for i, point in enumerate(zip(face_lon, face_lat)):
axes.annotate(
f"{i}",
point,
xytext=(0, 0),
textcoords="offset points",
horizontalalignment="center",
verticalalignment="center",
bbox={
"facecolor": "white",
"alpha": 1.0,
"boxstyle": "round,pad=0.0",
"ec": "white",
},
)
grid_defs = ds[ds.cf.cf_roles["grid_topology"][0]].attrs
lon_var, lat_var = grid_defs["node_coordinates"].split()
nodes_lon, nodes_lat = (ds[n] for n in grid_defs["node_coordinates"].split())


# faces = ds[mesh_defs["face_node_connectivity"]]

# if faces.shape[0] == 3:
# # swap order for mpl triangulation
# faces = faces.T
# start_index = faces.attrs.get("start_index")
# start_index = 0 if start_index is None else start_index
# faces = faces - start_index

# mpl_tri = Triangulation(nodes_lon, nodes_lat, faces)

# axes.triplot(mpl_tri)
# if face_numbers:
# try:
# face_lon, face_lat = (ds[n] for n in mesh_defs["face_coordinates"].split())
# except KeyError:
# raise ValueError('"face_coordinates" must be defined to plot the face numbers')
# for i, point in enumerate(zip(face_lon, face_lat)):
# axes.annotate(
# f"{i}",
# point,
# xytext=(0, 0),
# textcoords="offset points",
# horizontalalignment="center",
# verticalalignment="center",
# bbox={
# "facecolor": "white",
# "alpha": 1.0,
# "boxstyle": "round,pad=0.0",
# "ec": "white",
# },
# )

# plot nodes
if nodes:
axes.plot(nodes_lon, nodes_lat, "o")
# plot node numbers
if node_numbers:
for i, point in enumerate(zip(nodes_lon, nodes_lat)):
axes.annotate(
f"{i}",
point,
xytext=(2, 2),
textcoords="offset points",
bbox={
"facecolor": "white",
"alpha": 1.0,
"boxstyle": "round,pad=0.0",
"ec": "white",
},
)

# boundaries -- if they are there.
if "boundary_node_connectivity" in mesh_defs:
bounds = ds[mesh_defs["boundary_node_connectivity"]]

lines = []
for bound in bounds.data:
line = (
(nodes_lon[bound[0]], nodes_lat[bound[0]]),
(nodes_lon[bound[1]], nodes_lat[bound[1]]),
)
lines.append(line)
lc = LineCollection(lines, linewidths=2, colors=(1, 0, 0, 1))
axes.add_collection(lc)
# # plot node numbers
# if node_numbers:
# for i, point in enumerate(zip(nodes_lon, nodes_lat)):
# axes.annotate(
# f"{i}",
# point,
# xytext=(2, 2),
# textcoords="offset points",
# bbox={
# "facecolor": "white",
# "alpha": 1.0,
# "boxstyle": "round,pad=0.0",
# "ec": "white",
# },
# )

# # boundaries -- if they are there.
# if "boundary_node_connectivity" in mesh_defs:
# bounds = ds[mesh_defs["boundary_node_connectivity"]]

# lines = []
# for bound in bounds.data:
# line = (
# (nodes_lon[bound[0]], nodes_lat[bound[0]]),
# (nodes_lon[bound[1]], nodes_lat[bound[1]]),
# )
# lines.append(line)
# lc = LineCollection(lines, linewidths=2, colors=(1, 0, 0, 1))
# axes.add_collection(lc)

0 comments on commit 4d02b7a

Please sign in to comment.