-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathutils_spatial.py
97 lines (74 loc) · 2.83 KB
/
utils_spatial.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""Utilities for spatial selection and aggregation."""
def subset_lat(ds, lat_bnds):
    """Select grid points that fall within latitude bounds.

    Parameters
    ----------
    ds : Union[xarray.DataArray, xarray.Dataset]
        Input data. A ``latitude`` dimension is renamed to ``lat``.
    lat_bnds : list
        Latitude bounds: [south bound, north bound], each within [-90, 90].

    Returns
    -------
    Union[xarray.DataArray, xarray.Dataset]
        Subsetted xarray.DataArray or xarray.Dataset

    Raises
    ------
    ValueError
        If either bound lies outside [-90, 90].
    """
    if 'latitude' in ds.dims:
        ds = ds.rename({'latitude': 'lat'})

    south_bound, north_bound = lat_bnds
    for bound in (south_bound, north_bound):
        # Explicit check (not ``assert``) so validation survives ``python -O``.
        if not -90 <= bound <= 90:
            raise ValueError("Valid latitude range is [-90, 90]")

    lat_axis = ds['lat'].values
    if lat_axis.size == 0:
        # Nothing to select from; avoid indexing into an empty axis.
        return ds

    # Label-based slices must follow the axis order. ``>=`` (not ``>``) makes
    # a single-point axis take the ascending branch: pandas treats a
    # one-element index as monotonic increasing, so the reversed slice of the
    # other branch would select nothing even when the point is in range.
    if lat_axis[-1] >= lat_axis[0]:
        # monotonic increasing lat axis (e.g. -90 to 90)
        ds = ds.sel({'lat': slice(south_bound, north_bound)})
    else:
        # monotonic decreasing lat axis (e.g. 90 to -90)
        ds = ds.sel({'lat': slice(north_bound, south_bound)})
    return ds
def avoid_cyclic(ds, west_bound, east_bound):
"""Alter longitude axis if requested bounds straddle cyclic point"""
west_bound_360 = (west_bound + 360) % 360
east_bound_360 = (east_bound + 360) % 360
west_bound_180 = ((west_bound + 180) % 360) - 180
east_bound_180 = ((east_bound + 180) % 360) - 180
if east_bound_360 < west_bound_360:
ds = ds.assign_coords({'lon': ((ds['lon'] + 180) % 360) - 180})
ds = ds.sortby(ds['lon'])
elif east_bound_180 < west_bound_180:
ds = ds.assign_coords({'lon': (ds['lon'] + 360) % 360})
ds = ds.sortby(ds['lon'])
return ds
def _shift_lon_bound(bound, axis_min, axis_max):
    """Shift one longitude bound by a single 360-degree turn toward the axis range.

    A bound above ``axis_max`` is moved down by 360. A bound (possibly the
    already-shifted one) below ``axis_min`` is moved up by 360. A bound already
    inside [axis_min, axis_max] is returned unchanged. A bound in the gap
    between ``axis_max`` and ``axis_min + 360`` goes down and back up, so it
    also ends at its original value.

    Raises
    ------
    ValueError
        If one 360-degree shift is not enough to cross back over the limit.
    """
    shifted = bound
    if shifted > axis_max:
        shifted = shifted - 360
        if not shifted <= axis_max:
            raise ValueError(
                f"Longitude bound {bound} lies more than 360 degrees above "
                f"the longitude axis maximum ({axis_max})"
            )
    if shifted < axis_min:
        shifted = shifted + 360
        if not shifted >= axis_min:
            raise ValueError(
                f"Longitude bound {bound} lies more than 360 degrees below "
                f"the longitude axis minimum ({axis_min})"
            )
    return shifted


def subset_lon(ds, lon_bnds):
    """Select grid points that fall within longitude bounds.

    When the bounds wrap across the 0/360 or 180 meridian, the longitude axis
    is first remapped and sorted by ``avoid_cyclic``. Each bound is then
    shifted by up to one 360-degree turn so that it is expressed in the same
    convention as the data.

    Parameters
    ----------
    ds : Union[xarray.DataArray, xarray.Dataset]
        Input data. A ``longitude`` dimension is renamed to ``lon``.
    lon_bnds : list
        Longitude bounds: [west bound, east bound]. A request spanning 360
        degrees or more (e.g. [0, 360] or [-180, 180]) returns every longitude.

    Returns
    -------
    Union[xarray.DataArray, xarray.Dataset]
        Subsetted xarray.DataArray or xarray.Dataset

    Raises
    ------
    ValueError
        If the longitude axis does not contain at least two distinct values,
        or if a bound cannot be brought into the axis range with one
        360-degree shift.
    """
    if 'longitude' in ds.dims:
        ds = ds.rename({'longitude': 'lon'})

    # The bound-shifting below compares against the axis extent, which is
    # meaningless for a degenerate (single-valued) axis.
    lon_values = ds['lon'].values
    if not lon_values.max() > lon_values.min():
        raise ValueError("Longitude axis must contain more than one distinct value")

    west_bound, east_bound = lon_bnds

    if east_bound - west_bound >= 360:
        # Full globe requested. Without this early exit the east bound is
        # wrapped by 360 onto the west bound (e.g. [0, 360] -> slice(0, 0)),
        # which would return a single longitude column.
        return ds

    ds = avoid_cyclic(ds, west_bound, east_bound)
    lon_values = ds['lon'].values
    lon_axis_max = lon_values.max()
    lon_axis_min = lon_values.min()

    west_bound = _shift_lon_bound(west_bound, lon_axis_min, lon_axis_max)
    east_bound = _shift_lon_bound(east_bound, lon_axis_min, lon_axis_max)

    ds = ds.sel({'lon': slice(west_bound, east_bound)})
    return ds