Chunk the grid cells and write output for each chunk
This supports restarting a run, since each chunk's output is written to its own file.
ray-chew committed Jun 12, 2024
1 parent 405d3aa commit 53b4831
Showing 5 changed files with 57 additions and 41 deletions.
29 changes: 20 additions & 9 deletions inputs/icon_global_run.py
@@ -1,32 +1,43 @@
import numpy as np
from src import var
from ..src import var, utils
from ..inputs import local_paths

params = var.params()

params.output_path = "/home/ray/git-projects/spec_appx/outputs/"
params.output_fn = "icon_merit_reg"
params.fn_grid = "../data/icon_compact.nc"
params.fn_topo = "../data/topo_compact.nc"
params.fn_output = "icon_merit_global"
utils.transfer_attributes(params, local_paths.paths, prefix="path")

### alaska
params.lat_extent = [48.0, 64.0, 64.0]
params.lon_extent = [-148.0, -148.0, -112.0]

### Tierra del Fuego
params.lat_extent = [-38.0, -56.0, -56.0]
params.lon_extent = [-76.0, -76.0, -53.0]

### South Pole
params.lat_extent = None
params.lon_extent = None
params.lat_extent = [-75.0, -61.0, -61.0]
params.lon_extent = [-77.0, -50.0, -50.0]

params.tri_set = [13, 104, 105, 106]

params.merit_cg = 100

# Setup the Fourier parameters and object.
params.nhi = 24
params.nhj = 48

params.n_modes = 50
params.padding = 10

params.U, params.V = 10.0, 0.0

params.rect = True

params.debug = False
params.dfft_first_guess = True
params.dfft_first_guess = False
params.refine = False
params.verbose = False

params.plot = True
params.plot = False
params.plot_output = False
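
The hard-coded output and data paths are replaced by utils.transfer_attributes(params, local_paths.paths, prefix="path"), which copies path settings from an untracked inputs/local_paths module onto params. The helper itself is not shown in this commit; the following is only a plausible sketch of what it might do, assuming local_paths.paths is a simple namespace-like object whose relevant attribute names start with the given prefix:

# Hypothetical sketch of utils.transfer_attributes -- not part of this commit.
# Copies every attribute of `src` whose name starts with `prefix` onto `dest`.
def transfer_attributes(dest, src, prefix=""):
    for name, value in vars(src).items():
        if name.startswith(prefix):
            setattr(dest, name, value)

With a paths object that defines, say, path_icon_grid, this would populate params.path_icon_grid as used by the reader in runs/icon_merit_global.py below, keeping machine-specific absolute paths out of the repository.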
4 changes: 2 additions & 2 deletions inputs/icon_regional_run.py
@@ -21,13 +21,13 @@

params.tri_set = [13, 104, 105, 106]

params.merit_cg = 50
params.merit_cg = 100

# Setup the Fourier parameters and object.
params.nhi = 24
params.nhj = 48

params.n_modes = 10
params.n_modes = 50
params.padding = 10

params.U, params.V = 10.0, 0.0
47 changes: 26 additions & 21 deletions runs/icon_merit_global.py
@@ -209,11 +209,12 @@ def parallel_wrapper(grid, params, reader, writer):
# %%

# autoreload()
from pycsam.inputs.icon_regional_run import params
from pycsam.inputs.icon_global_run import params

from dask.distributed import Client
import dask.bag as db
# import dask
# from dask.diagnostics import ProgressBar
# import dask.bag as db
import dask

# dask.config.set(scheduler='synchronous')

@@ -228,36 +229,40 @@ def parallel_wrapper(grid, params, reader, writer):
# reader.read_dat(params.path_compact_grid, grid)
reader.read_dat(params.path_icon_grid, grid)

# writer object
writer = io.nc_writer(params)

clat_rad = np.copy(grid.clat)
clon_rad = np.copy(grid.clon)

grid.apply_f(utils.rad2deg)

n_cells = grid.clat.size

print(n_cells)

pw_run = parallel_wrapper(grid, params, reader, writer)

# NetCDF-4 reader does not work well with multithreading
# Use only 1 thread per worker! (At least on my laptop)
client = Client(threads_per_worker=1, n_workers=2)

lazy_results = []
print(n_cells)

chunk_sz = 50
for chunk in range(0, n_cells, chunk_sz):
# writer object
sfx = "_" + str(chunk+chunk_sz)
writer = io.nc_writer(params, sfx)

pw_run = parallel_wrapper(grid, params, reader, writer)

lazy_results = []

b = db.from_sequence(range(n_cells), npartitions=10)
results = b.map(pw_run)
results = results.compute()
# with ProgressBar():
# b = db.from_sequence(range(chunk), npartitions=100)
# results = b.map(pw_run)
# results = results.compute()

# for c_idx in range(n_cells):
# # pw_run(c_idx)
# lazy_result = dask.delayed(pw_run)(c_idx)
# lazy_results.append(lazy_result)
for c_idx in range(chunk, chunk+chunk_sz):
# pw_run(c_idx)
lazy_result = dask.delayed(pw_run)(c_idx)
lazy_results.append(lazy_result)

# results = dask.compute(*lazy_results)
results = dask.compute(*lazy_results)

for item in results:
writer.duplicate(item.c_idx, item)
for item in results:
writer.duplicate(item.c_idx, item)
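
The commit message notes that chunking "supports restarting". No explicit resume check appears in the shown diff, so the snippet below is only one way a restart could work with the new layout: because each chunk writes its own suffixed NetCDF file, a rerun can skip every chunk whose output file already exists. It reuses the objects set up earlier in this script (grid, params, reader, io, parallel_wrapper); the output-directory attribute params.path_output and the existence check are assumptions, not code from this commit.

# Restart-aware variant of the chunk loop above (illustrative sketch only).
import os
import dask

chunk_sz = 50
for chunk in range(0, n_cells, chunk_sz):
    sfx = "_" + str(chunk + chunk_sz)                  # "_50", "_100", ...
    fn = params.fn_output + sfx + ".nc"

    # Skip chunks that already have an output file from a previous run.
    if os.path.exists(os.path.join(params.path_output, fn)):
        continue

    writer = io.nc_writer(params, sfx)
    pw_run = parallel_wrapper(grid, params, reader, writer)

    # Clamp the last chunk so cell indices never exceed n_cells.
    cell_range = range(chunk, min(chunk + chunk_sz, n_cells))
    lazy_results = [dask.delayed(pw_run)(c_idx) for c_idx in cell_range]
    results = dask.compute(*lazy_results)

    for item in results:
        writer.duplicate(item.c_idx, item)

Note that the loop in the diff iterates over range(chunk, chunk + chunk_sz) without clamping, so if n_cells is not a multiple of chunk_sz the final chunk will request cell indices past the end; the min() above is a small guard against that.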
16 changes: 8 additions & 8 deletions src/io.py
@@ -50,13 +50,13 @@ def read_dat(self, fn, obj):

df.close()

def open(self, fn):
self.df = nc.Dataset(fn, "r")
self.is_open = True
# def open(self, fn):
# self.df = nc.Dataset(fn, "r")
# self.is_open = True

def close(self):
if self.is_open and hasattr(self, "df"):
self.df.close()
# def close(self):
# if self.is_open and hasattr(self, "df"):
# self.df.close()

def __get_truths(self, arr, vert_pts, d_pts):
"""Assembles Boolean array selecting for data points within a given lat-lon range, including padded boundary."""
@@ -770,9 +770,9 @@ def populate(self, idx, name, data):

class nc_writer(object):

def __init__(self, params):
def __init__(self, params, sfx=""):

self.fn = params.fn_output
self.fn = params.fn_output + str(sfx)

if self.fn[-3:] != ".nc":
self.fn += '.nc'
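
With the new optional sfx argument, a single params.fn_output stem produces one NetCDF file per chunk. Assuming the stem "icon_merit_global" (the value set in the old global run script) and the chunk size of 50 from the run script above, the expected names would be:

# Illustration of the per-chunk file names produced by nc_writer(params, sfx).
fn_output = "icon_merit_global"        # example stem
chunk_sz = 50
for chunk in range(0, 150, chunk_sz):  # e.g. 150 grid cells
    sfx = "_" + str(chunk + chunk_sz)
    print(fn_output + sfx + ".nc")
# icon_merit_global_50.nc
# icon_merit_global_100.nc
# icon_merit_global_150.nc

Each suffix records the end index of its chunk, which is what makes the per-chunk restart sketched above possible.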
2 changes: 1 addition & 1 deletion vis/cart_plot.py
@@ -403,7 +403,7 @@ def lat_lon_icon(
)
ax.add_collection(coll)

print("--> polygon collection done")
# print("--> polygon collection done")

if annotate_idxs:
ncells = kwargs["ncells"]
