import os
import numpy as np
from ._exactextract import RasterSource as _RasterSource
class RasterSource(_RasterSource):
    """
    Source from which raster data can be read.

    A RasterSource provides the ability to read subsets of a single band of
    raster data. Several implementations are included in exactextract:

    - :py:class:`GDALRasterSource`
    - :py:class:`NumPyRasterSource`
    - :py:class:`RasterioRasterSource`
    - :py:class:`XArrayRasterSource`
    """

    def __init__(self):
        # Delegate all setup to the base class imported from ``._exactextract``.
        super().__init__()
class GDALRasterSource(RasterSource):
    """
    RasterSource backed by GDAL
    """

    def __init__(self, ds, band_idx: int = 1, *, name=None):
        """
        Args:
            ds: A ``gdal.Dataset`` or path from which one can be opened
            band_idx: 1-based numerical index of band to read
            name: source name, to be used in generating field names for results

        Raises:
            ValueError: if ``band_idx`` is not in ``1..RasterCount`` or the
                dataset's geotransform contains rotation terms.
        """
        super().__init__()

        from osgeo import gdal

        # Sanity check inputs (cheap check first, before opening anything)
        if band_idx is not None and band_idx <= 0:
            raise ValueError("Raster band index starts from 1!")

        if isinstance(ds, (str, os.PathLike)):
            # Normalize pathlib.Path and other PathLike objects to a plain path
            ds = gdal.Open(os.fspath(ds))
        self.ds = ds

        # An out-of-range band would otherwise surface later as an obscure
        # failure from GetRasterBand.
        if band_idx is not None and band_idx > self.ds.RasterCount:
            raise ValueError(
                f"Raster band index {band_idx} exceeds the number of bands "
                f"({self.ds.RasterCount})"
            )

        # Check for axis-aligned grid
        gt = self.ds.GetGeoTransform()
        if gt[2] != 0 or gt[4] != 0:
            raise ValueError("Rotated rasters are not supported.")

        self.band = self.ds.GetRasterBand(band_idx)
        self.isfloat = self.band.DataType in {gdal.GDT_Float32, gdal.GDT_Float64}

        # Read scale/offset once so that `scaled`, the mask decision and the
        # values applied in read_window always agree. GetScale()/GetOffset()
        # may return None; None and identity values are both stored as None.
        scale = self.band.GetScale()
        offset = self.band.GetOffset()
        self._scale = None if scale in (None, 1.0) else scale
        self._offset = None if offset in (None, 0.0) else offset
        self.scaled = self._scale is not None or self._offset is not None

        # Depends on self.scaled, so must come after it is set
        self.use_mask_band = self._calc_use_mask_band()

        if name:
            self.set_name(name)

    def res(self):
        """Return the (x, y) cell size; y is made positive via abs()."""
        gt = self.ds.GetGeoTransform()
        return gt[1], abs(gt[5])

    def extent(self):
        """Return the raster extent as ``(left, bottom, right, top)``."""
        gt = self.ds.GetGeoTransform()
        dx, dy = self.res()
        left = gt[0]
        right = left + dx * self.ds.RasterXSize
        top = gt[3]
        bottom = top - dy * self.ds.RasterYSize
        return (left, bottom, right, top)

    def nodata_value(self):
        """
        Return the band's NODATA value, or None.

        For scaled rasters None is returned: missing data is identified through
        the mask band rather than by inverting the scaling.
        Integer bands report the value as ``int``.
        """
        if self.scaled:
            return None
        val = self.band.GetNoDataValue()
        if val is None:
            return None
        return val if self.isfloat else int(val)

    def _calc_use_mask_band(self):
        """Decide whether read_window must read GDAL's mask band."""
        from osgeo import gdal

        flags = self.band.GetMaskFlags()
        if flags == gdal.GMF_ALL_VALID:
            return False
        if flags == gdal.GMF_NODATA:
            # Plain NODATA is reported via nodata_value(), except for scaled
            # rasters, whose NODATA value is not reported.
            return self.scaled
        return True

    def read_window(self, x0, y0, nx, ny):
        """
        Read an ``nx`` x ``ny`` window whose upper-left pixel is ``(x0, y0)``.

        Scale/offset are applied (integer data is promoted to float64 first).
        When the mask band is used, a masked array is returned in which
        invalid pixels are masked; otherwise a plain array is returned.
        """
        window = dict(xoff=x0, yoff=y0, win_xsize=nx, win_ysize=ny)
        arr = self.band.ReadAsArray(**window)

        mask = None
        if self.use_mask_band:
            # GDAL mask bands are nonzero where valid; NumPy masks are True
            # where invalid, hence the inversion.
            mask = ~(self.band.GetMaskBand().ReadAsArray(**window).astype(bool))

        if self.scaled:
            if issubclass(arr.dtype.type, np.integer):
                arr = arr.astype(np.float64)
            if self._scale is not None:
                arr *= self._scale
            if self._offset is not None:
                arr += self._offset

        if mask is not None:
            return np.ma.masked_array(arr, mask)
        return arr

    def srs_wkt(self):
        """Return the dataset's spatial reference as WKT, or None if unset."""
        crs = self.ds.GetSpatialRef()
        if crs:
            return crs.ExportToWkt()
        return None
class NumPyRasterSource(RasterSource):
    """
    RasterSource backed by a NumPy array
    """

    def __init__(
        self,
        mat,
        xmin=None,
        ymin=None,
        xmax=None,
        ymax=None,
        *,
        nodata=None,
        name=None,
        srs_wkt=None
    ):
        """
        Create a RasterSource that references a NumPy array.

        If spatial extent arguments are not provided, the extent will be assumed to be
        from (0,0) to (nx,ny).

        Args:
            mat: a two-dimensional NumPy array. Masked arrays are supported.
            xmin: x coordinate of left edge
            ymin: y coordinate of bottom edge
            xmax: x coordinate of right edge
            ymax: y coordinate of top edge
            nodata: Optional value used to indicate missing data.
            name: source name, to be used in generating field names for results
            srs_wkt: WKT string indicating the spatial reference system.

        Raises:
            ValueError: if ``mat`` is not two-dimensional, or if only some of
                the extent arguments are given.
        """
        super().__init__()

        # res() unpacks mat.shape into exactly (ny, nx)
        if len(mat.shape) != 2:
            raise ValueError(
                f"mat must be a two-dimensional array (got shape {mat.shape})"
            )

        # Explicit check instead of `assert`, which is stripped under `python -O`
        bounds = (xmin, ymin, xmax, ymax)
        n_missing = sum(b is None for b in bounds)
        if n_missing not in (0, len(bounds)):
            raise ValueError(
                "xmin, ymin, xmax and ymax must either all be provided or all be omitted"
            )

        self.mat = mat
        self.nodata = nodata

        if n_missing:
            # One unit per cell, origin at (0, 0): (xmin, ymin, xmax, ymax)
            self.ext = (0, 0, mat.shape[1], mat.shape[0])
        else:
            self.ext = bounds

        if name:
            self.set_name(name)

        self.srs_wkt_str = srs_wkt

    def res(self):
        """Return the (dx, dy) cell size implied by the extent and array shape."""
        ny, nx = self.mat.shape
        dy = (self.ext[3] - self.ext[1]) / ny
        dx = (self.ext[2] - self.ext[0]) / nx
        return (dx, dy)

    def srs_wkt(self):
        """Return the WKT string given at construction (may be None)."""
        return self.srs_wkt_str

    def extent(self):
        """Return the extent as ``(xmin, ymin, xmax, ymax)``."""
        return self.ext

    def nodata_value(self):
        """Return the NODATA value given at construction (may be None)."""
        return self.nodata

    def read_window(self, x0, y0, nx, ny):
        """Return a slice (a view, not a copy) of rows y0..y0+ny, columns x0..x0+nx."""
        return self.mat[y0 : y0 + ny, x0 : x0 + nx]
class RasterioRasterSource(RasterSource):
    """
    RasterSource backed by rasterio
    """

    def __init__(self, ds, band_idx=1, *, name=None):
        """
        Args:
            ds: A ``rasterio.DatasetReader`` or path from which one can be opened
            band_idx: 1-based numerical index of band to read
            name: source name, to be used in generating field names for results

        Raises:
            ValueError: if ``band_idx`` is not in ``1..count`` or the dataset's
                transform contains rotation terms.
        """
        super().__init__()

        # Without this check, band_idx=0 would silently index dtypes/scales/
        # offsets at position -1, i.e. use the *last* band's metadata.
        if band_idx <= 0:
            raise ValueError("Raster band index starts from 1!")

        if isinstance(ds, (str, os.PathLike)):
            import rasterio

            ds = rasterio.open(ds)
        self.ds = ds

        if band_idx > self.ds.count:
            raise ValueError(
                f"Raster band index {band_idx} exceeds the number of bands "
                f"({self.ds.count})"
            )

        self.band_idx = band_idx
        self.isfloat = self.ds.dtypes[band_idx - 1].startswith("float")

        # get_transform() uses GDAL ordering; [2] and [4] are the rotation terms
        gt = self.ds.get_transform()
        if gt[2] != 0 or gt[4] != 0:
            raise ValueError("Rotated rasters are not supported.")

        if name:
            self.set_name(name)

        self.scale = self.ds.scales[self.band_idx - 1]
        self.offset = self.ds.offsets[self.band_idx - 1]
        self.scaled = self.scale != 1.0 or self.offset != 0.0

    def res(self):
        """Return the (dx, dy) cell size computed from bounds and dimensions."""
        dx = (self.ds.bounds.right - self.ds.bounds.left) / self.ds.width
        dy = (self.ds.bounds.top - self.ds.bounds.bottom) / self.ds.height
        return (dx, dy)

    def srs_wkt(self):
        """Return the dataset CRS as WKT, or None if it has none."""
        crs = self.ds.crs
        if crs:
            return crs.wkt
        return None

    def extent(self):
        """Return the extent as ``(left, bottom, right, top)``."""
        return (
            self.ds.bounds.left,
            self.ds.bounds.bottom,
            self.ds.bounds.right,
            self.ds.bounds.top,
        )

    def nodata_value(self):
        """
        Return the dataset NODATA value, or None.

        None is returned for scaled bands, whose missing data is instead
        carried by the mask of the array read in read_window.
        """
        if self.ds.nodata is not None and not self.scaled:
            return self.ds.nodata if self.isfloat else int(self.ds.nodata)
        return None

    def read_window(self, x0, y0, nx, ny):
        """
        Read an ``nx`` x ``ny`` window whose upper-left pixel is ``(x0, y0)``.

        Scale/offset are applied (integer data is promoted to float64 first).
        A masked array is returned when the mask is a per-pixel array;
        otherwise only the underlying data is returned.
        """
        from rasterio.windows import Window

        arr = self.ds.read(self.band_idx, window=Window(x0, y0, nx, ny), masked=True)

        if self.scaled:
            if issubclass(arr.dtype.type, np.integer):
                arr = arr.astype(np.float64)
            arr = arr * self.scale + self.offset

        # A scalar (0-d) mask means nothing is masked; drop the wrapper
        if len(arr.mask.shape):
            return arr
        return arr.data
class XArrayRasterSource(RasterSource):
    """
    RasterSource backed by xarray

    The rio-xarray extension is used to retrieve metadata such as the
    array extent, resolution, and spatial reference system.
    """

    def __init__(self, ds, band_idx=1, *, name=None):
        """
        Args:
            ds: An xarray ``DataArray`` or a path from which one can be read.
            band_idx: 1-based numerical index of band to read
            name: source name, to be used in generating field names for results

        Raises:
            ValueError: if the array has more than one non-spatial dimension,
                or if it has a band dimension and ``band_idx`` is outside
                ``1..n_bands``.
        """
        super().__init__()

        # rioxarray must be imported to register the ``.rio`` accessor used
        # throughout this class, including for caller-supplied DataArrays.
        import rioxarray  # noqa: F401

        if isinstance(ds, (str, os.PathLike)):
            import xarray

            ds = xarray.open_dataset(ds)
            ds = ds[next(iter(ds.keys()))]  # get first variable, for now

        self.ds = ds

        if self.ds.rio.crs is None:
            # Assign a default CRS when none is present.
            # NOTE(review): this modifies the caller's DataArray in place, and
            # srs_wkt() will then report EPSG:4326. The original rationale
            # (keeping clip_box from complaining) no longer applies because
            # read_window does not use clip_box. Confirm whether it is still needed.
            self.ds.rio.set_crs("EPSG:4326", inplace=True)

        self.band_idx = band_idx
        self.band_dim = self._band_dim(self.ds)

        if self.band_dim is not None:
            # Without this check, band_idx <= 0 would wrap around via negative
            # indexing and silently read a band from the end.
            if band_idx <= 0:
                raise ValueError("Raster band index starts from 1!")
            n_bands = self.ds.sizes[self.band_dim]
            if band_idx > n_bands:
                raise ValueError(
                    f"Raster band index {band_idx} exceeds the number of bands "
                    f"({n_bands})"
                )

        self.bounds = self.ds.rio.bounds()

        if name:
            self.set_name(name)

    @staticmethod
    def _band_dim(ds):
        """Return the name of the single non-spatial dimension, or None."""
        dims = list(ds.dims)
        dims.remove(ds.rio.x_dim)
        dims.remove(ds.rio.y_dim)

        if len(dims) == 0:
            return None
        elif len(dims) == 1:
            return dims[0]
        else:
            # ValueError subclasses Exception, so existing handlers still match
            raise ValueError("Cannot handle >1 non-spatial dimension")

    def srs_wkt(self):
        """Return the CRS as WKT, or None if there is none."""
        crs = self.ds.rio.crs
        if crs:
            return crs.wkt
        return None

    def res(self):
        """Return the absolute (x, y) resolution reported by rioxarray."""
        return tuple(abs(x) for x in self.ds.rio.resolution())

    def extent(self):
        """Return the bounds computed by rioxarray at construction."""
        return self.bounds

    def nodata_value(self):
        """Return the NODATA value reported by rioxarray (may be None)."""
        return self.ds.rio.nodata

    def read_window(self, x0, y0, nx, ny):
        """
        Read an ``nx`` x ``ny`` window whose upper-left pixel is ``(x0, y0)``.

        Rows are counted from the top. If the y coordinate increases with
        index, the row window is mirrored and the result flipped vertically
        so the returned array is always top-down.
        """
        if nx == 0 or ny == 0:
            return np.array([[]], dtype=self.ds.dtype)

        x_dim = self.ds.rio.x_dim
        y_dim = self.ds.rio.y_dim

        ys = self.ds[y_dim]
        flipped = bool(len(ys) > 1 and ys[1] > ys[0])
        if flipped:
            y0 = self.ds.rio.height - y0 - ny

        # Positional selection: unlike label-based .sel, this also works for
        # dimensions without a coordinate index and with duplicate labels.
        selection = {
            x_dim: slice(x0, x0 + nx),
            y_dim: slice(y0, y0 + ny),
        }
        if self.band_dim is not None:
            # Integer index drops the band dimension from the result
            selection[self.band_dim] = self.band_idx - 1

        ret = self.ds.isel(selection).to_numpy()

        if flipped:
            ret = np.flipud(ret)

        return ret