Skip to content

Commit

Permalink
compression with zst and significant digits quantization
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben Mather committed Nov 7, 2023
1 parent 3e1f013 commit f85ba52
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions gplately/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,10 @@ def find_label(keys, labels):
cdf_lon = cdf[key_lon][:]
cdf_lat = cdf[key_lat][:]

fill_value = cdf[key_z].missing_value

cdf_grid[np.isclose(cdf_grid, fill_value, rtol=0.1)] = np.nan

if realign:
# realign longitudes to -180/180 dateline
cdf_grid_z, cdf_lon, cdf_lat = realign_grid(cdf_grid, cdf_lon, cdf_lat)
Expand Down Expand Up @@ -243,7 +247,7 @@ def find_label(keys, labels):
else:
return cdf_grid_z

def write_netcdf_grid(filename, grid, extent=[-180,180,-90,90]):
def write_netcdf_grid(filename, grid, extent=[-180,180,-90,90], significant_digits=None, fill_value=np.nan):
""" Write geological data contained in a `grid` to a netCDF4 grid with a specified `filename`.
Notes
Expand Down Expand Up @@ -288,8 +292,8 @@ def write_netcdf_grid(filename, grid, extent=[-180,180,-90,90]):
cdf.title = "Grid produced by gplately"
cdf.createDimension('lon', lon_grid.size)
cdf.createDimension('lat', lat_grid.size)
cdf_lon = cdf.createVariable('lon', lon_grid.dtype, ('lon',), zlib=True)
cdf_lat = cdf.createVariable('lat', lat_grid.dtype, ('lat',), zlib=True)
cdf_lon = cdf.createVariable('lon', lon_grid.dtype, ('lon',), compression='zstd', complevel=9)
cdf_lat = cdf.createVariable('lat', lat_grid.dtype, ('lat',), compression='zstd', complevel=9)
cdf_lon[:] = lon_grid
cdf_lat[:] = lat_grid

Expand All @@ -301,10 +305,18 @@ def write_netcdf_grid(filename, grid, extent=[-180,180,-90,90]):
cdf_lat.standard_name = 'lat'
cdf_lat.actual_range = [lat_grid[0], lat_grid[-1]]

cdf_data = cdf.createVariable('z', grid.dtype, ('lat','lon'), zlib=True)
if significant_digits:
cdf_data = cdf.createVariable('z', grid.dtype, ('lat','lon'), compression='zstd', complevel=9,
significant_digits=int(significant_digits),
quantize_mode='GranularBitRound')

else:
cdf_data = cdf.createVariable('z', grid.dtype, ('lat','lon'), compression='zstd', complevel=9)

# netCDF4 uses the missing_value attribute as the default _FillValue
# without this, _FillValue defaults to 9.969209968386869e+36
cdf_data.missing_value = np.nan
cdf_data.missing_value = fill_value

cdf_data.standard_name = 'z'
#Ensure pygmt registers min and max z values properly
cdf_data.actual_range = [np.nanmin(grid), np.nanmax(grid)]
Expand Down

0 comments on commit f85ba52

Please sign in to comment.