Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial set of raster extent functions #85

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 215 additions & 0 deletions rivgraph/geo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
functionality here, and some of these functions are simply unused.

"""
import os
try:
from osgeo import gdal
except ImportError:
Expand Down Expand Up @@ -358,3 +359,217 @@ def downsample_binary_geotiff(input_file, ds_factor, output_name, thresh=None):
output_name, dtype=gdal.GDT_Byte)

return output_name


def clip_rasters_to_common_extents(raster_list, outdir, extent='intersection'):
"""
Given a list of rasters, clip each to the common extent of all rasters.

Identifies the common extents for a set of rasters and clips them as new
rasters with the prefix 'clipped_' in the specified output directory. The
extents used for clipping can be 'union' or 'intersection' for the
minimal or maximal rectangle extents, respectively.

Parameters
----------
raster_list : list
List of paths to rasters to clip.

outdir : str
Path to directory where clipped rasters will be saved.

extent : str, optional
If 'union', the minimal rectangle containing all inputs is returned.
If 'intersection', the maximum rectangle is returned.
Gets passed to the common_extent() function.
"""
# Get common extents
ce = common_extents(raster_list, extent=extent)

# Clip each raster to the common extent
for i in raster_list:
_clipped, geo_info = clip_raster_to_extent(i, ce)
# save clipped geotiff to outdir
io.write_geotiff(_clipped, geo_info['gt'], geo_info['wkt'],
os.path.join(outdir, 'clipped_' + i),
dtype=geo_info['dtype'], nbands=geo_info['nbands'],
nodata=geo_info['nodata'])


def common_extents(raster_list, extent='intersection'):
"""
Given a list of raster files, return the common extents of all rasters.

Parameters
----------
raster_list : list
List of raster files (strings) to be used in the analysis.

extent : str, optional
If 'union', the rectangle containing all input data is returned.
If 'intersection', the rectangle containing all intersecting input
data is returned (i.e. the maximum rectangle).

Returns
----------
common_extent : tuple
Tuple of the form (xmin, xmax, ymin, ymax) defining the common extent
of all rasters.

"""
# Get the extents of all rasters
extents = []
for raster in raster_list:
extents.append(get_raster_extent(raster))

# Find the common extent
if extent == 'union':
warnings.warn(
'Expansion of small rasters for union not yet supported.')
common_extent = (np.min([ext[0] for ext in extents]),
np.max([ext[1] for ext in extents]),
np.min([ext[2] for ext in extents]),
np.max([ext[3] for ext in extents]))
elif extent == 'intersection':
common_extent = (np.max([ext[0] for ext in extents]),
np.min([ext[1] for ext in extents]),
np.max([ext[2] for ext in extents]),
np.min([ext[3] for ext in extents]))
else:
raise ValueError(
'Invalid extent type. Must be "union" or "intersection".')

# logic check on the extents (min < max)
if common_extent[0] > common_extent[1]:
raise ValueError(
f'X extents, {common_extent[0], common_extent[1]} , are invalid.')
if common_extent[2] > common_extent[3]:
raise ValueError(
f'Y extents, {common_extent[2], common_extent[3]}, are invalid.')

return common_extent


def get_raster_extent(raster):
"""
Given a raster file, return the extent of the raster.

Parameters
----------
raster : str
Path to raster file.

Returns
----------
extent : tuple
Tuple of the form (xmin, xmax, ymin, ymax) defining the extent of the
raster.

"""
# Get the raster extent
ds = gdal.Open(raster)
gt = ds.GetGeoTransform()
cols = ds.RasterXSize
rows = ds.RasterYSize
xmin = gt[0]
xmax = gt[0] + cols * gt[1]
ymin = gt[3] + rows * gt[5]
ymax = gt[3]

extent = (xmin, xmax, ymin, ymax)

return extent


def clip_raster_to_extent(raster, extent):
"""
Given a raster file and an extent, clip the raster to the extent.

.. warning::

Ability to *expand* small rasters to fill larger extents in for the
"union" functionality is not yet supported!

Parameters
----------
raster : str
Path to raster file.

extent : tuple
Tuple of the form (xmin, xmax, ymin, ymax) defining the extent of the
raster.

Returns
----------
clipped : array
Array containing the clipped raster.

geo_info : dict
Dictionary containing the georeferencing information for the clipped
raster.

"""
# Get the raster extent
ds = gdal.Open(raster)
gt = ds.GetGeoTransform()
cols = ds.RasterXSize
rows = ds.RasterYSize
xmin = gt[0]
xmax = gt[0] + cols * gt[1]
ymin = gt[3] + rows * gt[5]
ymax = gt[3]

# Get the pixel coordinates of the target extent
x1, y1 = world2Pixel(gt, extent[0], extent[3])
x2, y2 = world2Pixel(gt, extent[1], extent[2])

# Clip the raster
# print(f'Clipping to {x1, y2, x2 - x1, y2 - y1}')
clipped = ds.ReadAsArray(x1, y1, x2 - x1, y2 - y1)

# Get the georeferencing information
geo_info = {}
geo_info['gt'] = (extent[0], gt[1], gt[2], extent[3], gt[4], gt[5])
geo_info['wkt'] = ds.GetProjection()
geo_info['dtype'] = ds.GetRasterBand(1).DataType
geo_info['nbands'] = ds.RasterCount
geo_info['nodata'] = ds.GetRasterBand(1).GetNoDataValue()

return clipped, geo_info


def world2Pixel(geoMatrix, x, y):
"""
Function to convert geospatial coordinates to pixel locations.

This function uses a gdal geomatrix (gdal.GetGeoTransform()) to calculate
the pixel location of a geospatial coordinate.

Parameters
----------
geoMatrix : list
A list of 6 numbers representing the geotransform matrix. Obtained
from calling gdal.GetGeoTransform() on some geotif.

x : float
The x coordinate of the point to be converted.

y : float
The y coordinate of the point to be converted.

Returns
-------
x_pixel : int
The x pixel location of the point.

y_pixel : int
The y pixel location of the point.

"""
ulX = geoMatrix[0]
ulY = geoMatrix[3]
xDist = geoMatrix[1]
yDist = geoMatrix[5]
x_pixel = int((x - ulX) / xDist)
y_pixel = int((y - ulY) / yDist)
return (x_pixel, y_pixel)