Skip to content

Commit

Permalink
clean up for raster functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
kaklise committed Oct 1, 2024
1 parent d4d5271 commit 447a104
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 56 deletions.
45 changes: 23 additions & 22 deletions wntr/gis/geospatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def snap(A, B, tolerance):
if not has_shapely or not has_geopandas:
raise ModuleNotFoundError('shapley and geopandas are required')

isinstance(A, gpd.GeoDataFrame)
assert isinstance(A, gpd.GeoDataFrame)
assert(A['geometry'].geom_type).isin(['Point']).all()
isinstance(B, gpd.GeoDataFrame)
assert isinstance(B, gpd.GeoDataFrame)
assert (B['geometry'].geom_type).isin(['Point', 'LineString', 'MultiLineString']).all()
assert A.crs == B.crs

Expand Down Expand Up @@ -205,18 +205,18 @@ def intersect(A, B, B_value=None, include_background=False, background_value=0):
if not has_shapely or not has_geopandas:
raise ModuleNotFoundError('shapley and geopandas are required')

isinstance(A, gpd.GeoDataFrame)
assert isinstance(A, gpd.GeoDataFrame)
assert (A['geometry'].geom_type).isin(['Point', 'LineString',
'MultiLineString', 'Polygon',
'MultiPolygon']).all()
isinstance(B, gpd.GeoDataFrame)
assert isinstance(B, gpd.GeoDataFrame)
assert (B['geometry'].geom_type).isin(['Point', 'LineString',
'MultiLineString', 'Polygon',
'MultiPolygon']).all()
if isinstance(B_value, str):
assert B_value in B.columns
isinstance(include_background, bool)
isinstance(background_value, (int, float))
assert isinstance(include_background, bool)
assert isinstance(background_value, (int, float))
assert A.crs == B.crs, "A and B must have the same crs."

if include_background:
Expand Down Expand Up @@ -291,46 +291,47 @@ def intersect(A, B, B_value=None, include_background=False, background_value=0):
return stats


def sample_raster(A, filepath, indexes):
"""Sample a raster (e.g., GeoTIFF file) at the point locations given
by the geometry of GeoDataFrame A.
def sample_raster(A, filepath, bands=1):
"""Sample a raster (e.g., GeoTIFF file) using Points in GeoDataFrame A.
This function can take either a filepath to a raster or a virtual raster (VRT),
which combines multiple raster tiles into a single object, opens the raster, and
samples it at the coordinates of the point geometries in A. This function
assigns nan to values that match the raster's `nodata` attribute. These sampled
values are returned as a Series which has an index matching A.
This function can take either a filepath to a raster or a virtual raster
(VRT), which combines multiple raster tiles into a single object. The
function opens the raster(s) and samples it at the coordinates of the point
geometries in A. This function assigns nan to values that match the
raster's `nodata` attribute. These sampled values are returned as a Series
which has an index matching A.
Parameters
----------
A : GeoDataFrame
Geodataframe containing point geometries (lines and polygons not yet implemented)
GeoDataFrame containing Point geometries
filepath : str
Path to raster or alternatively a VRT
band : int or list[int]
Index or indices of bands to sample
Path to raster or alternatively a virtual raster (VRT)
bands : int or list[int] (optional, default = 1)
Index or indices of the bands to sample (using 1-based indexing)
Returns
-------
Series
Pandas Series containing the sampled values for each geometry in gdf
"""
# further functionality could include the implementation for other geometries (line, polygon),
# further functionality could include other geometries (Line, Polygon),
# and use of multiprocessing to speed up querying.
if not has_rasterio:
raise ModuleNotFoundError('rasterio is required')

assert (A['geometry'].geom_type == "Point").all()
assert isinstance(filepath, str)
assert isinstance(bands, (int, list))

with rio.open(filepath) as raster:
xys = zip(A.geometry.x, A.geometry.y)

values = np.array(
tuple(raster.sample(xys, indexes)), dtype=float # force to float to allow for conversion of nodata to nan
tuple(raster.sample(xys, bands)), dtype=float # force to float to allow for conversion of nodata to nan
).squeeze()

values[values == raster.nodata] = np.nan
values = pd.Series(values, index=A.index)

return values


79 changes: 45 additions & 34 deletions wntr/tests/test_gis.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,37 +76,7 @@ def setUpClass(self):

df = pd.DataFrame(point_data)
self.points = gpd.GeoDataFrame(df, crs=None)

# raster testing
points = [
(-120.5, 38.5),
(-120.6, 38.6),
(-120.55, 38.65),
(-120.65, 38.55),
(-120.7, 38.7)
]
point_geometries = [Point(xy) for xy in points]
raster_points = gpd.GeoDataFrame(geometry=point_geometries, crs="EPSG:4326")
raster_points.index = ["A", "B", "C", "D", "E"]
self.raster_points = raster_points

# create example raster
minx, miny, maxx, maxy = raster_points.total_bounds
raster_width = 100
raster_height = 100

x = np.linspace(0, 1, raster_width)
y = np.linspace(0, 1, raster_height)
raster_data = np.cos(y)[:, np.newaxis] * np.sin(x) # arbitrary values

transform = rio.transform.from_bounds(minx, miny, maxx, maxy, raster_width, raster_height)
self.transform = transform

with rio.open(
"test_raster.tif", "w", driver="GTiff", height=raster_height, width=raster_width,
count=1, dtype=raster_data.dtype, crs="EPSG:4326", transform=transform) as dst:
dst.write(raster_data, 1)

@classmethod
def tearDownClass(self):
pass
Expand Down Expand Up @@ -348,13 +318,54 @@ def test_snap_points_to_lines(self):

assert_frame_equal(pd.DataFrame(snapped_points), expected, check_dtype=False)

@unittest.skipIf(not has_rasterio,
"Cannot test raster capabilities: rasterio is missing")
class TestRaster(unittest.TestCase):
@classmethod
def setUpClass(self):

# raster testing
points = [
(-120.5, 38.5),
(-120.6, 38.6),
(-120.55, 38.65),
(-120.65, 38.55),
(-120.7, 38.7)
]
point_geometries = [Point(xy) for xy in points]
points = gpd.GeoDataFrame(geometry=point_geometries, crs="EPSG:4326")
points.index = ["A", "B", "C", "D", "E"]
self.points = points

# create example raster
minx, miny, maxx, maxy = points.total_bounds
raster_width = 100
raster_height = 100

x = np.linspace(0, 1, raster_width)
y = np.linspace(0, 1, raster_height)
raster_data = np.cos(y)[:, np.newaxis] * np.sin(x) # arbitrary values

transform = rio.transform.from_bounds(minx, miny, maxx, maxy, raster_width, raster_height)
self.transform = transform

with rio.open(
"test_raster.tif", "w", driver="GTiff", height=raster_height, width=raster_width,
count=1, dtype=raster_data.dtype, crs="EPSG:4326", transform=transform) as dst:
dst.write(raster_data, 1)

@classmethod
def tearDownClass(self):
pass

def test_sample_raster(self):
raster_values = wntr.gis.sample_raster(self.raster_points, "test_raster.tif", 1)
raster_values = wntr.gis.sample_raster(self.points, "test_raster.tif")
assert (raster_values.index == self.points.index).all()

assert (raster_values.index == self.raster_points.index).all()
# self.raster_points.plot(column=values, legend=True)
# self.points.plot(column=values, legend=True)
expected_values = np.array([0.000000, 0.423443, 0.665369, 0.174402, 0.000000])
assert np.isclose(raster_values.values, expected_values, atol=1e-5).all()



if __name__ == "__main__":
unittest.main()

0 comments on commit 447a104

Please sign in to comment.